diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index cbe440a2..b96d04b0 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -111,8 +111,9 @@ public class GenerateServiceImpl extends ServiceImpl i public void generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、获取用户信息 Long accountId = generateThroughImageTextDTO.getUserId(); - String generateType = generateThroughImageTextDTO.getGenerateType(); + GenerateModeEnum modeEnum = getMode(generateThroughImageTextDTO); + String generateType = modeEnum.getValue(); // 2、判断必须入参是否为非空(在prepare阶段已校验) Generate generate = new Generate(); generate.setAccountId(accountId); @@ -141,9 +142,7 @@ public class GenerateServiceImpl extends ServiceImpl i CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType()); // 3、向模型发起请求 - String mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ? - GenerateModeEnum.TEXT.getType() : - GenerateModeEnum.TEXT_IMAGE.getType(); + String mode = modeEnum.getType(); String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" : generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard"; String path = CommonConstant.GENERATE_PATH; @@ -188,7 +187,6 @@ public class GenerateServiceImpl extends ServiceImpl i jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue); } - Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, port, path); // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中 @@ -207,6 +205,21 @@ public class GenerateServiceImpl extends ServiceImpl i } + public GenerateModeEnum getMode(GenerateThroughImageTextDTO generateThroughImageTextDTO){ + if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getText())){ + if (Objects.nonNull(generateThroughImageTextDTO.getCollectionElementId())){ + return GenerateModeEnum.TEXT_IMAGE; + }else { + return GenerateModeEnum.TEXT; + } + }else { + if (Objects.nonNull(generateThroughImageTextDTO.getCollectionElementId())){ + return GenerateModeEnum.IMAGE; + } + } + return GenerateModeEnum.TEXT; + } + @Override @Transactional(rollbackFor = Exception.class) public void processGenerateResult(String taskId, String url, String category) {