diff --git a/src/main/java/com/ai/da/mapper/primary/entity/CollectionElement.java b/src/main/java/com/ai/da/mapper/primary/entity/CollectionElement.java index 64103278..437ba163 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/CollectionElement.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/CollectionElement.java @@ -90,4 +90,19 @@ public class CollectionElement implements Serializable { private Long projectId; private Integer isCompositeImage; + + public CollectionElement() { + } + + public CollectionElement(Long accountId, String level1Type, String level2Type, String name, String url, Byte hasPin, String md5, Date createDate, Long projectId) { + this.accountId = accountId; + this.level1Type = level1Type; + this.level2Type = level2Type; + this.name = name; + this.url = url; + this.hasPin = hasPin; + this.md5 = md5; + this.createDate = createDate; + this.projectId = projectId; + } } diff --git a/src/main/java/com/ai/da/mapper/primary/entity/Generate.java b/src/main/java/com/ai/da/mapper/primary/entity/Generate.java index cd0e3743..a8f04670 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/Generate.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/Generate.java @@ -107,10 +107,11 @@ public class Generate { public Generate() { } - public Generate(Long accountId, String uniqueId, String level1Type, String text, String generateType, String modelName, Date createDate) { + public Generate(Long accountId, String uniqueId, String level1Type, String level2Type, String text, String generateType, String modelName, Date createDate) { this.accountId = accountId; this.uniqueId = uniqueId; this.level1Type = level1Type; + this.level2Type = level2Type; this.text = text; this.generateType = generateType; this.modelName = modelName; diff --git a/src/main/java/com/ai/da/model/dto/GenerateModifyDTO.java b/src/main/java/com/ai/da/model/dto/GenerateModifyDTO.java index 9c5fe97a..38ddf43a 100644 --- a/src/main/java/com/ai/da/model/dto/GenerateModifyDTO.java +++ b/src/main/java/com/ai/da/model/dto/GenerateModifyDTO.java @@ -28,7 +28,7 @@ public class GenerateModifyDTO { private Long originalId; @NotBlank(message = "original Id Source cannot be empty") - @ApiModelProperty(value = "原图id的来源", required = true) + @ApiModelProperty(value = "原图id的来源 Library || Generate || Collection", required = true) private String originalIdSource; @NotNull(message = "isOverride cannot be empty") diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 0c0cc944..2febb2b5 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -67,6 +67,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*; +import static com.ai.da.common.enums.CreditsEventsEnum.PATTERN; import static com.ai.da.common.enums.CreditsEventsEnum.TO_PRODUCT_IMAGE; import static com.ai.da.common.enums.CreditsEventsEnum.TO_PRODUCT_IMAGE_FLUX; @@ -288,9 +289,14 @@ public class GenerateServiceImpl extends ServiceImpl i // 将相应的url保存到数据库 generateDetailMapper.insert(generateDetail); + String uuid = taskId.substring(0, taskId.substring(0, taskId.lastIndexOf("-")).lastIndexOf("-")); String key = generateResultKey + ":" + taskId; String imageName = url.substring(url.lastIndexOf("/") + 1); String status = imageName.equals("white_image.jpg") ? "Invalid" : "Success"; + if (StringUtil.isNullOrEmpty(category)){ + Generate generateRecord = selectByUniqueId(taskId); + category = generateRecord.getLevel2Type(); + } GenerateResultVO generateResultVO = new GenerateResultVO(taskId, generateDetail.getId(), url, status, category); // 更新redis redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); @@ -299,7 +305,6 @@ public class GenerateServiceImpl extends ServiceImpl i // ** 注:如果生成的图片都是空白 则不扣积分 if (!status.equals("Invalid")) { String accountId = taskId.substring(taskId.lastIndexOf("-") + 1); - String uuid = taskId.substring(0, taskId.substring(0, taskId.lastIndexOf("-")).lastIndexOf("-")); Boolean flag = creditsService.taskCreditsDeduction(Long.parseLong(accountId), uuid); if (flag) creditsService.updateChangedCredits(accountId, uuid); } @@ -390,6 +395,7 @@ public class GenerateServiceImpl extends ServiceImpl i // generate.setText(text); break; case "Printboard": + case "Pattern": text = translated; /*if (prefix.contains("Painting Style")) { text = "Picasso,increased color saturation,increased glossiness," + translated + ", fabric print, high quality"; @@ -542,75 +548,117 @@ public class GenerateServiceImpl extends ServiceImpl i } @Override - public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateDTO) { // public List prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、参数检查,判断必须参数是否为空 - if (Objects.isNull(generateThroughImageTextDTO.getUserId())) { + if (Objects.isNull(generateDTO.getUserId())) { throw new BusinessException("userId cannot be empty"); } CreditsEventsEnum creditsEventsEnum = CreditsEventsEnum.OTHER; - if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("wx")) { - String taskId = createAsyncTask(generateThroughImageTextDTO); -// String taskId = "e53c86ea-53be-424b-8ac7-3c01c141f4f7"; + if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) && generateDTO.getModelName().equals("wx")) { + String taskId = createAsyncTask(generateDTO); creditsEventsEnum = CreditsEventsEnum.WX_TEXT2IMG; // 6、添加预扣除积分到redis - creditsService.addRecordToCreditsDeduction(generateThroughImageTextDTO.getUserId(), taskId, creditsEventsEnum); + creditsService.addRecordToCreditsDeduction(generateDTO.getUserId(), taskId, creditsEventsEnum); // 6.1 添加积分扣除记录到db - creditsService.preInsert(generateThroughImageTextDTO.getUserId(), creditsEventsEnum.getName(), taskId, Boolean.TRUE, null); + creditsService.preInsert(generateDTO.getUserId(), creditsEventsEnum.getName(), taskId, Boolean.TRUE, null); + // 7、返回唯一id + return new PrepareForGenerateVO(Collections.singletonList(taskId), 2); + } else if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) + && generateDTO.getModelName().equals("flux") && generateDTO.getLevel2Type().equals("Pattern")){ + String imagePath = null; + if (Objects.nonNull(generateDTO.getCollectionElementId()) && !StringUtil.isNullOrEmpty(generateDTO.getDesignType())){ + switch (generateDTO.getDesignType()){ + case "collection": + CollectionElement collectionElement = collectionElementMapper.selectById(generateDTO.getCollectionElementId()); + if (Objects.nonNull(collectionElement)){ + imagePath = collectionElement.getUrl(); + } + break; + case "library": + Library library = libraryService.getById(generateDTO.getCollectionElementId()); + if (Objects.nonNull(library)){ + imagePath = library.getUrl(); + } + } + } + String taskId = flux(PATTERN, generateDTO.getText(), imagePath, false); + Generate generate = CopyUtil.copyObject(generateDTO, Generate.class); + generate.setAccountId(generateDTO.getUserId()); + generate.setUniqueId(taskId); + generate.setElementSource(generateDTO.getDesignType()); + generate.setElementId(generateDTO.getCollectionElementId()); + String generateType; + if (Objects.nonNull(generateDTO.getCollectionElementId()) && !StringUtil.isNullOrEmpty(generateDTO.getText())){ + generateType = "text-image"; + } else if (Objects.nonNull(generateDTO.getCollectionElementId()) && StringUtil.isNullOrEmpty(generateDTO.getText())){ + generateType = "image"; + } else { + generateType = "text"; + } + generate.setGenerateType(generateType); + generate.setModelName("flux"); + generate.setCreateDate(new Date()); + save(generate); + + creditsEventsEnum = CreditsEventsEnum.FLUX_IMG2IMG; + // 6、添加预扣除积分到redis + creditsService.addRecordToCreditsDeduction(generateDTO.getUserId(), taskId, creditsEventsEnum); + // 6.1 添加积分扣除记录到db + creditsService.preInsert(generateDTO.getUserId(), creditsEventsEnum.getName(), taskId, Boolean.TRUE, null); // 7、返回唯一id return new PrepareForGenerateVO(Collections.singletonList(taskId), 2); } int times = 4; // 当level1Type为Print_board时,level2Type为pattern时需要确定generateType - if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())) { - if (StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getLevel2Type())) { + if (generateDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())) { + if (StringUtil.isNullOrEmpty(generateDTO.getLevel2Type())) { throw new BusinessException("level2Type.cannot.be.empty"); - } else if (!CollectionLevel2TypeEnum.printType().contains(generateThroughImageTextDTO.getLevel2Type())) { + } else if (!CollectionLevel2TypeEnum.printType().contains(generateDTO.getLevel2Type())) { throw new BusinessException("unknown.parameter.level2Type"); } // Pattern 参数校验 - if (generateThroughImageTextDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.Pattern.getRealName())) { - String text = generateThroughImageTextDTO.getText(); - Long elementId = generateThroughImageTextDTO.getCollectionElementId(); + if (generateDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.Pattern.getRealName())) { + String text = generateDTO.getText(); + Long elementId = generateDTO.getCollectionElementId(); Generate generate = new Generate(); validateGeneraType(generate, text, elementId); // 校验后获取 - generateThroughImageTextDTO.setGenerateType(generate.getGenerateType()); + generateDTO.setGenerateType(generate.getGenerateType()); // creditsEventsEnum = CreditsEventsEnum.PATTERN; creditsEventsEnum = CreditsEventsEnum.PATTERN; // 模型迁移SD1.? -> flux,从而产生了不同模型的选择, // high -> 生成图片质量高,但生成速度慢,每次生成只返回一张图片 // fast -> 生成图片质量低,但生成速度快,每次生成返回四张图片 - if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("high")) { + if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) && generateDTO.getModelName().equals("high")) { creditsEventsEnum = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; times = 1; } } // Slogan 参数校验 slogan目前只能开一个接口。所以只有生产环境上能使用 - if (generateThroughImageTextDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.SLOGAN.getRealName())) { - if (StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getSloganBase64())) { + if (generateDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.SLOGAN.getRealName())) { + if (StringUtil.isNullOrEmpty(generateDTO.getSloganBase64())) { log.error("Printboard-Slogan模式下,slogan image为空"); throw new BusinessException("slogan.image.cannot.be.empty"); } - if (StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getText())) { + if (StringUtil.isNullOrEmpty(generateDTO.getText())) { log.error("Printboard-Slogan模式下,slogan text为空"); throw new BusinessException("slogan.style.cannot.be.empty"); } times = 1; // 将图片上传到图片服务器 - String path = minioUtil.base64UploadToPath(generateThroughImageTextDTO.getSloganBase64(), sloganBucket, null); -// String path = "test/7c9114f93d08a702e00da928e66f321.png"; + String path = minioUtil.base64UploadToPath(generateDTO.getSloganBase64(), sloganBucket, null); String name = path.substring(path.lastIndexOf("/") + 1, path.lastIndexOf(".")); // 保存到db,collection-element CollectionElement collectionElement = new CollectionElement(); - collectionElement.setAccountId(generateThroughImageTextDTO.getUserId()); + collectionElement.setAccountId(generateDTO.getUserId()); collectionElement.setCollectionId(0L); collectionElement.setLevel1Type(PRINT_BOARD.getRealName()); collectionElement.setLevel2Type(CollectionLevel2TypeEnum.SLOGAN.getRealName()); @@ -618,41 +666,41 @@ public class GenerateServiceImpl extends ServiceImpl i collectionElement.setUrl(path); collectionElement.setHasPin((byte) 0); collectionElement.setMd5(MD5Utils.encryptFile(minioUtil.getPreSignedUrl(path, 24 * 60), Boolean.FALSE)); - collectionElement.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone())); + collectionElement.setCreateDate(DateUtil.getByTimeZone(generateDTO.getTimeZone())); collectionElementService.save(collectionElement); // 将上传后的地址放在指定字段 - generateThroughImageTextDTO.setCollectionElementId(collectionElement.getId()); - generateThroughImageTextDTO.setSloganBase64(null); - generateThroughImageTextDTO.setDesignType("collection"); + generateDTO.setCollectionElementId(collectionElement.getId()); + generateDTO.setSloganBase64(null); + generateDTO.setDesignType("collection"); creditsEventsEnum = CreditsEventsEnum.SLOGAN; } // Logo参数校验 - if (generateThroughImageTextDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.LOGO.getRealName())) { + if (generateDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.LOGO.getRealName())) { // logo模式下一次只生成一张 times = 1; // 校验是否输入内容 - if (StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getText().trim())) { + if (StringUtil.isNullOrEmpty(generateDTO.getText().trim())) { throw new BusinessException("please.input.the.prompt"); } // 校验seed的取值范围 int seed = random.nextInt(501); log.info("随机种子:{}", seed); - generateThroughImageTextDTO.setSeed(String.valueOf(seed)); + generateDTO.setSeed(String.valueOf(seed)); creditsEventsEnum = CreditsEventsEnum.LOGO; } - } else if (generateThroughImageTextDTO.getLevel1Type().equals(MOOD_BOARD.getRealName())) { + } else if (generateDTO.getLevel1Type().equals(MOOD_BOARD.getRealName())) { creditsEventsEnum = CreditsEventsEnum.MOOD_BOARD; - if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("high")) { + if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) && generateDTO.getModelName().equals("high")) { creditsEventsEnum = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; times = 1; } - } else if (generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName())) { + } else if (generateDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName())) { creditsEventsEnum = CreditsEventsEnum.SKETCH_BOARD; - if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("high")) { + if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) && generateDTO.getModelName().equals("high")) { creditsEventsEnum = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; times = 1; } @@ -668,17 +716,17 @@ public class GenerateServiceImpl extends ServiceImpl i String uuid = UUID.randomUUID().toString(); // 除了 Moodboard || Printboard->Pattern(可以区分三种风格) || Sketchboard(Generate Sketch)这三个地方需要区分high || fast之外,其他地方保持原样 - if (generateThroughImageTextDTO.getLevel1Type().equals("Printboard") && !generateThroughImageTextDTO.getLevel2Type().equals("Pattern")) { - generateThroughImageTextDTO.setModelName(null); + if (generateDTO.getLevel1Type().equals("Printboard") && !generateDTO.getLevel2Type().equals("Pattern")) { + generateDTO.setModelName(null); } ArrayList taskIdList = new ArrayList<>(); for (int i = 1; i <= times; i++) { String taskId = uuid; - taskId += "-" + i + "-" + generateThroughImageTextDTO.getUserId(); + taskId += "-" + i + "-" + generateDTO.getUserId(); taskIdList.add(taskId); - generateThroughImageTextDTO.setUniqueId(taskId); - String jsonString = JSON.toJSONString(generateThroughImageTextDTO); + generateDTO.setUniqueId(taskId); + String jsonString = JSON.toJSONString(generateDTO); // 4、加入redis排队,便于获取实时排队信息 Double maxScore = redisUtil.getMaxScore(consumptionOrderKey); @@ -686,7 +734,7 @@ public class GenerateServiceImpl extends ServiceImpl i // 加入resultMap String key = generateResultKey + ":" + taskId; - GenerateResultVO generateResultVO = new GenerateResultVO(generateThroughImageTextDTO.getUniqueId(), null, null, "Waiting"); + GenerateResultVO generateResultVO = new GenerateResultVO(generateDTO.getUniqueId(), null, null, "Waiting"); redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); // 5、将消息发布到MQ消息队列 @@ -695,9 +743,9 @@ public class GenerateServiceImpl extends ServiceImpl i // update 积分扣除由按次收费改为按生成图片数量收费 --> 改回按次收费 // 6、添加预扣除积分到redis - creditsService.addRecordToCreditsDeduction(generateThroughImageTextDTO.getUserId(), uuid, creditsEventsEnum); + creditsService.addRecordToCreditsDeduction(generateDTO.getUserId(), uuid, creditsEventsEnum); // 6.1 添加积分扣除记录到db - creditsService.preInsert(generateThroughImageTextDTO.getUserId(), creditsEventsEnum.getName(), uuid, Boolean.TRUE, null); + creditsService.preInsert(generateDTO.getUserId(), creditsEventsEnum.getName(), uuid, Boolean.TRUE, null); // 7、返回唯一id return new PrepareForGenerateVO(taskIdList, 2); @@ -1112,13 +1160,15 @@ public class GenerateServiceImpl extends ServiceImpl i String gender = generateModifyDTO.getGender(); String category = generateModifyDTO.getCategory(); Long originalId = generateModifyDTO.getOriginalId(); + String originalIdSource = generateModifyDTO.getOriginalIdSource(); boolean isOverride = generateModifyDTO.getIsOverride(); - boolean isFromLibrary = !StringUtil.isNullOrEmpty(generateModifyDTO.getOriginalIdSource()) - && generateModifyDTO.getOriginalIdSource().equals("Library"); boolean isSketch = generateModifyDTO.getType().equals(SKETCH_BOARD.getRealName()); // 获取原始路径和可能的generateId - PathInfo pathInfo = getOriginalPathAndGenerateId(isFromLibrary, originalId); + PathInfo pathInfo = getOriginalPathAndGenerateId(originalIdSource, originalId); + if (Objects.isNull(pathInfo)){ + throw new BusinessException("unknown sourceIdType", ResultEnum.PROMPT.getCode()); + } // 确定存储路径 String storagePath = isOverride @@ -1132,9 +1182,11 @@ public class GenerateServiceImpl extends ServiceImpl i log.info("修改后的图片:{}", minioPath); // 保存到数据库并返回结果 - return isFromLibrary + return originalIdSource.equals("Library") ? handleLibrarySave(accountId, originalId, minioPath, category, gender, isOverride, generateModifyDTO.getType()) - : handleGenerateSave(originalId, pathInfo.generateId, minioPath, category, isOverride); + : originalIdSource.equals("Generate") + ? handleGenerateSave(originalId, pathInfo.generateId, minioPath, category, isOverride) + : handleUploadSave(accountId, originalId, minioPath, category, isOverride, generateModifyDTO.getType()); } private static class PathInfo { @@ -1147,18 +1199,25 @@ public class GenerateServiceImpl extends ServiceImpl i } } - private PathInfo getOriginalPathAndGenerateId(boolean isFromLibrary, Long originalId) { - if (isFromLibrary) { - return new PathInfo(libraryService.getById(originalId).getUrl(), null); - } else { - GenerateDetail detail = generateDetailMapper.selectById(originalId); - return new PathInfo(detail.getUrl(), detail.getGenerateId()); + private PathInfo getOriginalPathAndGenerateId(String originalIdSource, Long originalId) { + switch (originalIdSource) { + case "Library": + return new PathInfo(libraryService.getById(originalId).getUrl(), null); + case "Generate": + GenerateDetail detail = generateDetailMapper.selectById(originalId); + return new PathInfo(detail.getUrl(), detail.getGenerateId()); + case "Collection": + CollectionElement collectionElement = collectionElementMapper.selectById(originalId); + return new PathInfo(collectionElement.getUrl(), null); + default: + return null; } } private GenerateResultVO handleLibrarySave(Long accountId, Long libraryId, String minioPath, - String category, String gender, boolean isOverride, String type) { + String category, String gender, boolean isOverride, String level1Type) { Library library; + String md5 = MD5Utils.encryptFile(minioUtil.getPreSignedUrl(minioPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false); if (isOverride) { library = new Library(); library.setId(libraryId); @@ -1166,9 +1225,7 @@ public class GenerateServiceImpl extends ServiceImpl i library.setUpdateDate(new Date()); libraryService.updateById(library); } else { - library = new Library(accountId, type, category, gender, minioPath, - MD5Utils.encryptFile(minioUtil.getPreSignedUrl(minioPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false), - new Date()); + library = new Library(accountId, level1Type, category, gender, minioPath, md5, new Date()); libraryService.save(library); libraryId = library.getId(); } @@ -1178,6 +1235,7 @@ public class GenerateServiceImpl extends ServiceImpl i private GenerateResultVO handleGenerateSave(Long originalId, Long generateId, String minioPath, String category, boolean isOverride) { GenerateDetail generateDetail = new GenerateDetail(); + String md5 = MD5Utils.encryptFile(minioUtil.getPreSignedUrl(minioPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME, true), Boolean.FALSE); if (isOverride) { generateDetail.setId(originalId); generateDetail.setUrl(minioPath); @@ -1187,8 +1245,7 @@ public class GenerateServiceImpl extends ServiceImpl i generateDetail.setGenerateId(generateId); generateDetail.setUrl(minioPath); generateDetail.setIsLike((byte)0); - generateDetail.setMd5(MD5Utils.encryptFile(minioUtil.getPreSignedUrl(minioPath, - CommonConstant.MINIO_IMAGE_EXPIRE_TIME, true), Boolean.FALSE)); + generateDetail.setMd5(md5); generateDetail.setCreateDate(LocalDateTime.now()); generateDetailMapper.insert(generateDetail); originalId = generateDetail.getId(); @@ -1196,6 +1253,26 @@ public class GenerateServiceImpl extends ServiceImpl i return buildResultVO(originalId, minioPath, category); } + private GenerateResultVO handleUploadSave(Long accountId, Long originalId, String minioPath, + String category, boolean isOverride, String level1Type){ + CollectionElement collectionElement; + String md5 = MD5Utils.encryptFile(minioUtil.getPreSignedUrl(minioPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false); + if (isOverride){ + collectionElement = new CollectionElement(); + collectionElement.setId(originalId); + collectionElement.setUrl(minioPath); + collectionElement.setMd5(md5); + collectionElement.setUpdateDate(new Date()); + collectionElementMapper.updateById(collectionElement); + }else { + CollectionElement originalElement = collectionElementMapper.selectById(originalId); + String name = minioPath.substring(minioPath.lastIndexOf("/") + 1, minioPath.lastIndexOf(".")); + collectionElement = new CollectionElement(accountId, level1Type, category, name, minioPath, (byte)0, md5, new Date(), originalElement.getProjectId()); + collectionElementMapper.insert(collectionElement); + } + return buildResultVO(collectionElement.getId(), minioPath, category); + } + private GenerateResultVO buildResultVO(Long id, String minioPath, String category) { String url = minioUtil.getPreSignedUrl(minioPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME, true); return new GenerateResultVO(id, url, "Success", category); @@ -1912,12 +1989,13 @@ public class GenerateServiceImpl extends ServiceImpl i * * @return taskId */ - public String createAsyncTask(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public String createAsyncTask(GenerateThroughImageTextDTO generateDTO) { // String prompt = "一间有着精致窗户的花店,漂亮的木质门,摆放着花朵"; - String level1Type = generateThroughImageTextDTO.getLevel1Type(); - String prompt = generateThroughImageTextDTO.getText(); - Long userId = generateThroughImageTextDTO.getUserId(); - String gender = generateThroughImageTextDTO.getGender(); + String level1Type = generateDTO.getLevel1Type(); + String level2Type = generateDTO.getLevel2Type(); + String prompt = generateDTO.getText(); + Long userId = generateDTO.getUserId(); + String gender = generateDTO.getGender(); // 添加预设prompt,使生成结果更加具有指向性(区分不同的board) switch (level1Type) { @@ -1956,7 +2034,7 @@ public class GenerateServiceImpl extends ServiceImpl i String taskId = result.getOutput().getTaskId(); log.info("wx text2image 请求生成:{}, taskId:{}", JsonUtils.toJson(result), taskId); - Generate generate = new Generate(userId, taskId, level1Type, prompt, "text(" + gender + ")", "wx", new Date()); + Generate generate = new Generate(userId, taskId, level1Type, level2Type, prompt, "text(" + gender + ")", "wx", new Date()); save(generate); return taskId; } @@ -2012,8 +2090,10 @@ public class GenerateServiceImpl extends ServiceImpl i } else { log.warn("未提取到性别"); } + }else if (generate.getLevel1Type().equals(PRINT_BOARD.getRealName())){ + Generate generateRecord = selectByUniqueId(taskId); + generateResultVO.setCategory(generateRecord.getLevel2Type()); } - return generateResultVO; } else { throw new BusinessException("Unknown generate task"); @@ -2528,6 +2608,18 @@ public class GenerateServiceImpl extends ServiceImpl i if (childStyle) prompt = prompt + ", Children's face"; break; } + } else { + prompt = modifyPrompt(prompt, null, func.getName(), null); + switch (func) { + case PATTERN: + prompt = "pattern image, " + prompt; + break; + case SKETCH_BOARD: + prompt = "a single item of sketch of " + prompt + ", clean white background, simple lines"; + break; + default: + log.warn("未知类型 type:{}", func); + } } JSONObject requestBody = new JSONObject(); @@ -2581,7 +2673,7 @@ public class GenerateServiceImpl extends ServiceImpl i fluxResultRequestUrl = pollingUrl; } - String resp = sendRequestUtil.sendGet(fluxResultRequestUrl, params); + String resp = sendRequestUtil.sendFluxGet(fluxResultRequestUrl, params); log.info("获取flux生成的结果为:{}", resp); JSONObject respObj = JSONUtil.parseObj(resp); String status = respObj.getStr("status"); @@ -2641,9 +2733,15 @@ public class GenerateServiceImpl extends ServiceImpl i generateDetailMapper.updateById(generateDetail); } String url = generateDetail.getUrl(); - String clothCategory = pythonService.getClothCategory(url, extractGender(generate.getGenerateType())); + String category ; + if (generate.getLevel1Type().equals(SKETCH_BOARD.getRealName())){ + category = pythonService.getClothCategory(url, extractGender(generate.getGenerateType())); + } else { + category = generate.getLevel2Type(); + } + return new GenerateResultVO(taskId, generateDetail.getId(), - minioUtil.getPreSignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), "Success", clothCategory); + minioUtil.getPreSignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), "Success", category); } else { throw new BusinessException("unknown generate"); }