generate 添加Logo与Slogan

This commit is contained in:
2024-06-03 17:13:48 +08:00
parent 86e7119cfb
commit 49b086ad10
10 changed files with 261 additions and 102 deletions

View File

@@ -3,6 +3,7 @@ package com.ai.da.service.impl;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.constant.CommonConstant;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.CollectionLevel2TypeEnum;
import com.ai.da.common.enums.GenerateModeEnum;
import com.ai.da.common.enums.ModelNameEnum;
import com.ai.da.common.utils.*;
@@ -21,7 +22,7 @@ import com.ai.da.service.GenerateService;
import com.ai.da.service.LibraryService;
import com.ai.da.service.RabbitMQService;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
@@ -32,13 +33,11 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import javax.annotation.Resource;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*;
import java.util.stream.Collectors;
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*;
@@ -85,6 +84,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Value("${redis.key.generateResult}")
private String generateResultKey;
@Value("${minio.bucketName.slogan}")
private String sloganBucket;
@Override
public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
@@ -110,6 +112,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generate.setAccountId(accountId);
generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type());
generate.setLevel2Type(generateThroughImageTextDTO.getLevel2Type());
generate.setSeed(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getSeed()) ? "" : generateThroughImageTextDTO.getSeed());
// 当level1type是sketchboard时存数据库需要加上当前性别
generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ?
generateType + " (" + generateThroughImageTextDTO.getGender() + ")" :
@@ -119,12 +123,15 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generate.setElementSource(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getDesignType()) ? null : generateThroughImageTextDTO.getDesignType());
String text = generateThroughImageTextDTO.getText();
generate.setText(text);
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(generate, text, elementId, generateType);
// validateGeneraType(generate, text, elementId);
if (!StringUtil.isNullOrEmpty(text)) {
text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type());
}
// todo 这一步现在还是有必要的吗?
// 2.1 sketch或print在t_collection_element表/t_library表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType());
@@ -134,15 +141,38 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
GenerateModeEnum.TEXT_IMAGE.getType();
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
// AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
// List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
// category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
Boolean requestResult = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
mode, category, generateThroughImageTextDTO.getGender()));
// log.info("generate 响应 " + generatedSketchUrl);
// if (CollectionUtils.isEmpty(generatedSketchUrl)) {
// return null;
// }
String path = CommonConstant.GENERATE_PATH;
String jsonString = "";
HashMap<String, String> params = new HashMap<>();
// 3.1 确定不同类型的印花分别调哪个接口
if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())){
switch(generateThroughImageTextDTO.getLevel2Type()){
case "Logo":
path = CommonConstant.GENERATE_SINGLE_LOGO;
params.put("tasks_id",generateThroughImageTextDTO.getUniqueId());
params.put("prompt", text);
params.put("seed", generateThroughImageTextDTO.getSeed());
jsonString = JSON.toJSONString(params, SerializerFeature.WriteMapNullValue);
break;
case "Slogan":
path = CommonConstant.GENERATE_SLOGAN;
params.put("tasks_id",generateThroughImageTextDTO.getUniqueId());
params.put("prompt", text);
params.put("svg", collectionElement.getUrl());
jsonString = JSON.toJSONString(params, SerializerFeature.WriteMapNullValue);
break;
case "Pattern":
GenerateToPythonDTO generateToPythonDTO = new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
mode, category, generateThroughImageTextDTO.getGender());
jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue);
}
}else {
GenerateToPythonDTO generateToPythonDTO = new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
mode, category, generateThroughImageTextDTO.getGender());
jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue);
}
Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, path);
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
save(generate);
@@ -229,8 +259,24 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
}
private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) {
switch (generateType) {
private void validateGeneraType(Generate generate, String text, Long elementId) {
String generateType = "";
if (StringUtil.isNullOrEmpty(text.trim()) && Objects.isNull(elementId)) {
throw new BusinessException("please.input.the.caption.or.choose.an.image");
} else if (!StringUtil.isNullOrEmpty(text.trim()) && !Objects.isNull(elementId)) {
generateType = "text-image";
generate.setText(text);
generate.setElementId(elementId);
} else if (!StringUtil.isNullOrEmpty(text.trim())) {
generateType = "text";
generate.setText(text);
} else if (!Objects.isNull(elementId)) {
generateType = "image";
generate.setElementId(elementId);
}
generate.setGenerateType(generateType);
/*switch (generateType) {
case "text":
if (StringUtil.isNullOrEmpty(text)) {
throw new BusinessException("please.input.the.caption");
@@ -250,7 +296,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generate.setText(text);
generate.setElementId(elementId);
default:
}
}*/
}
private String modifyPrompt(String userInput, Generate generate, String level1Type) {
@@ -263,13 +309,15 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
break;
case "Printboard":
if (userInput.contains("Painting Style")) {
userInput = "Picasso,increased color saturation,increased glossiness," + translated;
text = "Picasso,increased color saturation,increased glossiness," + translated + ", fabric print, high quality";
} else if (userInput.contains("Illustration Style")) {
userInput = "Flat coating,romantic,soft,pencil strokes,accentuating and widening the depth of pencil strokes,paper patterns,block colors,crayons,reducing image contrast,and hand drawn painting marks," + translated;
text = "Flat coating,romantic,soft,pencil strokes,accentuating and widening the depth of pencil strokes,paper patterns,block colors,crayons,reducing image contrast,and hand drawn painting marks," + translated + ", fabric print, high quality";
} else if (userInput.contains("Real Style")) {
userInput = "Still life photography,hyper realism,3d,deepened projection,increased permutation value,increased concavity and convexity value," + translated;
text = "Still life photography,hyper realism,3d,deepened projection,increased permutation value,increased concavity and convexity value," + translated + ", fabric print, high quality";
}else {
text = translated;
}
text = userInput + ", fabric print, high quality";
// text = userInput + ", fabric print, high quality";
// generate.setText(text);
break;
case "Sketchboard":
@@ -408,10 +456,10 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
if (Objects.isNull(generateThroughImageTextDTO.getUserId())) {
throw new BusinessException("userId cannot be empty");
}
String generateType = generateThroughImageTextDTO.getGenerateType();
/*String generateType = generateThroughImageTextDTO.getGenerateType();
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)) {
throw new BusinessException("unknown.generate.type");
}
}*/
// 判断试用用户是否还有剩余试用机会
int trialsCount = 0;
@@ -422,15 +470,75 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
}
}
String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(new Generate(), text, elementId, generateType);
int times = 4;
// 当level1Type为Print_board时level2Type为pattern时需要确定generateType
if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())){
if (StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getLevel2Type())){
throw new BusinessException("level2Type.cannot.be.empty");
}else if (!CollectionLevel2TypeEnum.printType().contains(generateThroughImageTextDTO.getLevel2Type())){
throw new BusinessException("unknown.parameter.level2Type");
}
// Pattern 参数校验
if (generateThroughImageTextDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.Pattern.getRealName())){
String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
Generate generate = new Generate();
validateGeneraType(generate, text, elementId);
// 校验后获取
generateThroughImageTextDTO.setGenerateType(generate.getGenerateType());
}
// Slogan 参数校验
if (generateThroughImageTextDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.SLOGAN.getRealName())){
if (StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getSloganBase64())){
log.error("Printboard-Slogan模式下slogan image为空");
throw new BusinessException("Slogan can not be empty!");
}
// 将图片上传到图片服务器
String path = minioUtil.base64Upload(generateThroughImageTextDTO.getSloganBase64(), sloganBucket);
String name = path.substring(path.lastIndexOf("/") + 1, path.lastIndexOf("."));
// 保存到db,collection-element
CollectionElement collectionElement = new CollectionElement();
collectionElement.setAccountId(generateThroughImageTextDTO.getUserId());
collectionElement.setCollectionId(0L);
collectionElement.setLevel1Type(PRINT_BOARD.getRealName());
collectionElement.setLevel2Type(CollectionLevel2TypeEnum.SLOGAN.getRealName());
collectionElement.setName(name);
collectionElement.setUrl(path);
collectionElement.setHasPin((byte) 0);
collectionElement.setMd5(MD5Utils.encryptFile(minioUtil.getPresignedUrl(path, 24 * 60), Boolean.FALSE));
collectionElement.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
collectionElementService.save(collectionElement);
// 将上传后的地址放在指定字段
generateThroughImageTextDTO.setCollectionElementId(collectionElement.getId());
generateThroughImageTextDTO.setSloganBase64(null);
generateThroughImageTextDTO.setDesignType("collection");
}
// Logo参数校验
if (generateThroughImageTextDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.LOGO.getRealName())){
// logo模式下一次只生成一张
times = 1;
// 校验是否输入内容
if (StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getText().trim())){
throw new BusinessException("please.input.the.prompt");
}
// 校验seed的取值范围
int seed = Integer.parseInt(generateThroughImageTextDTO.getSeed());
if (seed < 0 || seed > 99999){
throw new BusinessException("the.value.range.of.seed");
}
}
}
// 2、生成唯一id 使用uuid,由于uuid重复的几率很小故取消对uuid重复性的校验
String uuid = UUID.randomUUID().toString();
ArrayList<String> taskIdList = new ArrayList<>();
for (int i = 1; i <= 4; i++) {
for (int i = 1; i <= times; i++) {
String temp = uuid;
temp += "-" + i + "-" + generateThroughImageTextDTO.getUserId();
taskIdList.add(temp);
@@ -588,7 +696,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
String key = generateResultKey + ":" + uniqueId;
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
// 判断当前task的状态是不是Fail
if (!generateResultVO.getStatus().equals("Fail")){
if (!generateResultVO.getStatus().equals("Fail")) {
// 2、不是直接发送取消请求到python端
pythonService.cancelGenerateTask(uniqueId);
// 3、更改result中当前taskId的状态