From c19e9094d1ec6750dde5b6474e6ee65a87df333d Mon Sep 17 00:00:00 2001 From: xupei Date: Fri, 9 May 2025 17:03:42 +0800 Subject: [PATCH] =?UTF-8?q?BUGFIX:=20generate=20mode=E5=AD=97=E6=AE=B5?= =?UTF-8?q?=E4=BC=A0=E9=80=92=E4=B8=8D=E5=87=86=E7=A1=AE=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E7=94=9F=E6=88=90=E7=BB=93=E6=9E=9C=E4=B8=8E=E5=9B=BE=E7=89=87?= =?UTF-8?q?=E6=B2=A1=E6=9C=89=E5=85=B3=E8=81=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../da/service/impl/GenerateServiceImpl.java | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index cbe440a2..b96d04b0 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -111,8 +111,9 @@ public class GenerateServiceImpl extends ServiceImpl i public void generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、获取用户信息 Long accountId = generateThroughImageTextDTO.getUserId(); - String generateType = generateThroughImageTextDTO.getGenerateType(); + GenerateModeEnum modeEnum = getMode(generateThroughImageTextDTO); + String generateType = modeEnum.getValue(); // 2、判断必须入参是否为非空(在prepare阶段已校验) Generate generate = new Generate(); generate.setAccountId(accountId); @@ -141,9 +142,7 @@ public class GenerateServiceImpl extends ServiceImpl i CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType()); // 3、向模型发起请求 - String mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ? - GenerateModeEnum.TEXT.getType() : - GenerateModeEnum.TEXT_IMAGE.getType(); + String mode = modeEnum.getType(); String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" : generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard"; String path = CommonConstant.GENERATE_PATH; @@ -188,7 +187,6 @@ public class GenerateServiceImpl extends ServiceImpl i jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue); } - Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, port, path); // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中 @@ -207,6 +205,21 @@ public class GenerateServiceImpl extends ServiceImpl i } + public GenerateModeEnum getMode(GenerateThroughImageTextDTO generateThroughImageTextDTO){ + if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getText())){ + if (Objects.nonNull(generateThroughImageTextDTO.getCollectionElementId())){ + return GenerateModeEnum.TEXT_IMAGE; + }else { + return GenerateModeEnum.TEXT; + } + }else { + if (Objects.nonNull(generateThroughImageTextDTO.getCollectionElementId())){ + return GenerateModeEnum.IMAGE; + } + } + return GenerateModeEnum.TEXT; + } + @Override @Transactional(rollbackFor = Exception.class) public void processGenerateResult(String taskId, String url, String category) {