From 9627239f9a163e1b1b3d7acf42889ce768284d4a Mon Sep 17 00:00:00 2001 From: xupei Date: Thu, 25 Apr 2024 14:23:37 +0800 Subject: [PATCH] =?UTF-8?q?generate--print=20=E5=A2=9E=E5=8A=A0=E9=A3=8E?= =?UTF-8?q?=E6=A0=BC=E6=8F=90=E7=A4=BA=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../da/service/impl/GenerateServiceImpl.java | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 52f0a9c9..694bee4f 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -124,7 +124,7 @@ public class GenerateServiceImpl extends ServiceImpl i String text = generateThroughImageTextDTO.getText(); Long elementId = generateThroughImageTextDTO.getCollectionElementId(); validateGeneraType(generate, text, elementId, generateType); - if (generateType.equals("text") || generateType.equals("text-image")){ + if (generateType.equals("text") || generateType.equals("text-image")) { text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type()); } @@ -140,7 +140,7 @@ public class GenerateServiceImpl extends ServiceImpl i // AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil(); // List generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), // category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId())); - Boolean requestResult = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text,Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), + Boolean requestResult = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), mode, category, generateThroughImageTextDTO.getGender())); // log.info("generate 响应 : " + generatedSketchUrl); // if (CollectionUtils.isEmpty(generatedSketchUrl)) { @@ -153,9 +153,9 @@ public class GenerateServiceImpl extends ServiceImpl i // 5、将本次请求存入redis String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId(); String status; - if (requestResult){ + if (requestResult) { status = "Executing"; - }else { + } else { status = "Fail"; } GenerateResultVO generateResultVO = new GenerateResultVO(generateThroughImageTextDTO.getUniqueId(), null, null, status); @@ -193,19 +193,19 @@ public class GenerateServiceImpl extends ServiceImpl i @Override @Transactional(rollbackFor = Exception.class) - public void processGenerateResult(String taskId, String url, String category){ + public void processGenerateResult(String taskId, String url, String category) { // 5、处理模型返回的数据 // 5.1 将相应的url保存到数据库 GenerateDetail generateDetail = new GenerateDetail(); GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); - Generate generate ; - try{ + Generate generate; + try { generate = selectByUniqueId(taskId); - }catch (MybatisPlusException e){ + } catch (MybatisPlusException e) { log.error(e.getMessage()); - if (e.getMessage().equals("One record is expected, but the query result is multiple records")){ + if (e.getMessage().equals("One record is expected, but the query result is multiple records")) { generate = selectListByUniqueId(taskId).get(0); - }else { + } else { throw new BusinessException("There are some problems with database query, please try again."); } @@ -256,7 +256,7 @@ public class GenerateServiceImpl extends ServiceImpl i } } - private String modifyPrompt(String userInput, Generate generate, String level1Type){ + private String modifyPrompt(String userInput, Generate generate, String level1Type) { String text = ""; switch (level1Type) { case "Moodboard": @@ -264,6 +264,13 @@ public class GenerateServiceImpl extends ServiceImpl i generate.setText(text); break; case "Printboard": + if (userInput.contains("Painting Style")) { + userInput = "Picasso,increased color saturation,increased glossiness," + userInput; + } else if (userInput.contains("Illustration Style")) { + userInput = "Flat coating,romantic,soft,pencil strokes,accentuating and widening the depth of pencil strokes,paper patterns,block colors,crayons,reducing image contrast,and hand drawn painting marks," + userInput; + } else if (userInput.contains("Real Style")) { + userInput = "Still life photography,hyper realism,3d,deepened projection,increased permutation value,increased concavity and convexity value," + userInput; + } text = userInput + ", fabric print, high quality"; generate.setText(text); break; @@ -410,9 +417,9 @@ public class GenerateServiceImpl extends ServiceImpl i // 判断试用用户是否还有剩余试用机会 int trialsCount = 0; - if (generateThroughImageTextDTO.getIsTestUser()){ + if (generateThroughImageTextDTO.getIsTestUser()) { trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type()); - if (trialsCount >= 2){ + if (trialsCount >= 2) { return new PrepareForGenerateVO(0); } } @@ -443,7 +450,7 @@ public class GenerateServiceImpl extends ServiceImpl i } ArrayList taskIdList = new ArrayList<>(); - for (int i = 1 ; i <= 4 ; i++){ + for (int i = 1; i <= 4; i++) { String temp = uuid; temp += "-" + i + "-" + generateThroughImageTextDTO.getUserId(); taskIdList.add(temp); @@ -528,19 +535,19 @@ public class GenerateServiceImpl extends ServiceImpl i GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class); if (!Objects.isNull(generateResultVO) && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) { String url = generateResultVO.getUrl(); - if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")){ + if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")) { generateResultVO.setStatus("Invalid"); - }else { + } else { generateResultVO.setUrl(minioUtil.getPresignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME)); } - }else if (Objects.isNull(generateResultVO)){ + } else if (Objects.isNull(generateResultVO)) { generateResultVO = new GenerateResultVO(); } if (!StringUtil.isNullOrEmpty(generateResultVO.getStatus())) collect.add(generateResultVO.getStatus()); results.add(generateResultVO); }); // todo - if (taskIdList.size() == 4 && collect.size() == 1 && collect.contains("Fail")){ + if (taskIdList.size() == 4 && collect.size() == 1 && collect.contains("Fail")) { log.info("当前4个生成结果均为失败"); throw new BusinessException("generate.interface.error"); } @@ -601,24 +608,24 @@ public class GenerateServiceImpl extends ServiceImpl i } // 判断试用用户试用generate机会是否使用完毕 每个board 3次机会 - private int getTrialsCount(Long userId, String level1Type){ + private int getTrialsCount(Long userId, String level1Type) { List getGenerateList = getGenerateByAccountId(userId, level1Type); - int trialsCount ; - if (getGenerateList.isEmpty()){ + int trialsCount; + if (getGenerateList.isEmpty()) { trialsCount = 0; } else if (getGenerateList.size() == 1 || (getGenerateList.size() >= 4 && getGenerateList.size() / 4 == 1)) { trialsCount = 1; } else if (getGenerateList.size() == 2 || (getGenerateList.size() >= 4 && getGenerateList.size() / 4 == 2)) { trialsCount = 2; - }else { + } else { trialsCount = 2; } return trialsCount; } - public List getGenerateByAccountId(Long accountId, String level1Type){ + public List getGenerateByAccountId(Long accountId, String level1Type) { QueryWrapper qw = new QueryWrapper<>(); - qw.eq("account_id",accountId); + qw.eq("account_id", accountId); qw.eq("level1_type", level1Type); return baseMapper.selectList(qw);