Merge branch 'dev/dev_xp' into dev/dev

This commit is contained in:
2024-04-25 14:25:33 +08:00

View File

@@ -124,7 +124,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
String text = generateThroughImageTextDTO.getText(); String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId(); Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(generate, text, elementId, generateType); validateGeneraType(generate, text, elementId, generateType);
if (generateType.equals("text") || generateType.equals("text-image")){ if (generateType.equals("text") || generateType.equals("text-image")) {
text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type()); text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type());
} }
@@ -140,7 +140,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil(); // AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
// List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), // List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
// category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId())); // category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
Boolean requestResult = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text,Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), Boolean requestResult = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
mode, category, generateThroughImageTextDTO.getGender())); mode, category, generateThroughImageTextDTO.getGender()));
// log.info("generate 响应 " + generatedSketchUrl); // log.info("generate 响应 " + generatedSketchUrl);
// if (CollectionUtils.isEmpty(generatedSketchUrl)) { // if (CollectionUtils.isEmpty(generatedSketchUrl)) {
@@ -153,9 +153,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 5、将本次请求存入redis // 5、将本次请求存入redis
String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId(); String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId();
String status; String status;
if (requestResult){ if (requestResult) {
status = "Executing"; status = "Executing";
}else { } else {
status = "Fail"; status = "Fail";
} }
GenerateResultVO generateResultVO = new GenerateResultVO(generateThroughImageTextDTO.getUniqueId(), null, null, status); GenerateResultVO generateResultVO = new GenerateResultVO(generateThroughImageTextDTO.getUniqueId(), null, null, status);
@@ -193,19 +193,19 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void processGenerateResult(String taskId, String url, String category){ public void processGenerateResult(String taskId, String url, String category) {
// 5、处理模型返回的数据 // 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库 // 5.1 将相应的url保存到数据库
GenerateDetail generateDetail = new GenerateDetail(); GenerateDetail generateDetail = new GenerateDetail();
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
Generate generate ; Generate generate;
try{ try {
generate = selectByUniqueId(taskId); generate = selectByUniqueId(taskId);
}catch (MybatisPlusException e){ } catch (MybatisPlusException e) {
log.error(e.getMessage()); log.error(e.getMessage());
if (e.getMessage().equals("One record is expected, but the query result is multiple records")){ if (e.getMessage().equals("One record is expected, but the query result is multiple records")) {
generate = selectListByUniqueId(taskId).get(0); generate = selectListByUniqueId(taskId).get(0);
}else { } else {
throw new BusinessException("There are some problems with database query, please try again."); throw new BusinessException("There are some problems with database query, please try again.");
} }
@@ -256,7 +256,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
} }
} }
private String modifyPrompt(String userInput, Generate generate, String level1Type){ private String modifyPrompt(String userInput, Generate generate, String level1Type) {
String text = ""; String text = "";
switch (level1Type) { switch (level1Type) {
case "Moodboard": case "Moodboard":
@@ -264,6 +264,13 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generate.setText(text); generate.setText(text);
break; break;
case "Printboard": case "Printboard":
if (userInput.contains("Painting Style")) {
userInput = "Picasso,increased color saturation,increased glossiness," + userInput;
} else if (userInput.contains("Illustration Style")) {
userInput = "Flat coating,romantic,soft,pencil strokes,accentuating and widening the depth of pencil strokes,paper patterns,block colors,crayons,reducing image contrast,and hand drawn painting marks," + userInput;
} else if (userInput.contains("Real Style")) {
userInput = "Still life photography,hyper realism,3d,deepened projection,increased permutation value,increased concavity and convexity value," + userInput;
}
text = userInput + ", fabric print, high quality"; text = userInput + ", fabric print, high quality";
generate.setText(text); generate.setText(text);
break; break;
@@ -410,9 +417,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 判断试用用户是否还有剩余试用机会 // 判断试用用户是否还有剩余试用机会
int trialsCount = 0; int trialsCount = 0;
if (generateThroughImageTextDTO.getIsTestUser()){ if (generateThroughImageTextDTO.getIsTestUser()) {
trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type()); trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type());
if (trialsCount >= 2){ if (trialsCount >= 2) {
return new PrepareForGenerateVO(0); return new PrepareForGenerateVO(0);
} }
} }
@@ -443,7 +450,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
} }
ArrayList<String> taskIdList = new ArrayList<>(); ArrayList<String> taskIdList = new ArrayList<>();
for (int i = 1 ; i <= 4 ; i++){ for (int i = 1; i <= 4; i++) {
String temp = uuid; String temp = uuid;
temp += "-" + i + "-" + generateThroughImageTextDTO.getUserId(); temp += "-" + i + "-" + generateThroughImageTextDTO.getUserId();
taskIdList.add(temp); taskIdList.add(temp);
@@ -528,19 +535,19 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class); GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
if (!Objects.isNull(generateResultVO) && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) { if (!Objects.isNull(generateResultVO) && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) {
String url = generateResultVO.getUrl(); String url = generateResultVO.getUrl();
if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")){ if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")) {
generateResultVO.setStatus("Invalid"); generateResultVO.setStatus("Invalid");
}else { } else {
generateResultVO.setUrl(minioUtil.getPresignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME)); generateResultVO.setUrl(minioUtil.getPresignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
} }
}else if (Objects.isNull(generateResultVO)){ } else if (Objects.isNull(generateResultVO)) {
generateResultVO = new GenerateResultVO(); generateResultVO = new GenerateResultVO();
} }
if (!StringUtil.isNullOrEmpty(generateResultVO.getStatus())) collect.add(generateResultVO.getStatus()); if (!StringUtil.isNullOrEmpty(generateResultVO.getStatus())) collect.add(generateResultVO.getStatus());
results.add(generateResultVO); results.add(generateResultVO);
}); });
// todo // todo
if (taskIdList.size() == 4 && collect.size() == 1 && collect.contains("Fail")){ if (taskIdList.size() == 4 && collect.size() == 1 && collect.contains("Fail")) {
log.info("当前4个生成结果均为失败"); log.info("当前4个生成结果均为失败");
throw new BusinessException("generate.interface.error"); throw new BusinessException("generate.interface.error");
} }
@@ -601,24 +608,24 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
} }
// 判断试用用户试用generate机会是否使用完毕 每个board 3次机会 // 判断试用用户试用generate机会是否使用完毕 每个board 3次机会
private int getTrialsCount(Long userId, String level1Type){ private int getTrialsCount(Long userId, String level1Type) {
List<Generate> getGenerateList = getGenerateByAccountId(userId, level1Type); List<Generate> getGenerateList = getGenerateByAccountId(userId, level1Type);
int trialsCount ; int trialsCount;
if (getGenerateList.isEmpty()){ if (getGenerateList.isEmpty()) {
trialsCount = 0; trialsCount = 0;
} else if (getGenerateList.size() == 1 || (getGenerateList.size() >= 4 && getGenerateList.size() / 4 == 1)) { } else if (getGenerateList.size() == 1 || (getGenerateList.size() >= 4 && getGenerateList.size() / 4 == 1)) {
trialsCount = 1; trialsCount = 1;
} else if (getGenerateList.size() == 2 || (getGenerateList.size() >= 4 && getGenerateList.size() / 4 == 2)) { } else if (getGenerateList.size() == 2 || (getGenerateList.size() >= 4 && getGenerateList.size() / 4 == 2)) {
trialsCount = 2; trialsCount = 2;
}else { } else {
trialsCount = 2; trialsCount = 2;
} }
return trialsCount; return trialsCount;
} }
public List<Generate> getGenerateByAccountId(Long accountId, String level1Type){ public List<Generate> getGenerateByAccountId(Long accountId, String level1Type) {
QueryWrapper<Generate> qw = new QueryWrapper<>(); QueryWrapper<Generate> qw = new QueryWrapper<>();
qw.eq("account_id",accountId); qw.eq("account_id", accountId);
qw.eq("level1_type", level1Type); qw.eq("level1_type", level1Type);
return baseMapper.selectList(qw); return baseMapper.selectList(qw);