diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 2febb2b5..785c2917 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -549,206 +549,368 @@ public class GenerateServiceImpl extends ServiceImpl i @Override public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateDTO) { -// public List prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、参数检查,判断必须参数是否为空 + validateRequiredParams(generateDTO); + + // 2、处理特殊模型情况(wx/flux) + if (isWxModel(generateDTO)) { + return handleWxModelGeneration(generateDTO); + } + if (isFluxPatternModel(generateDTO)) { + return handleFluxPatternGeneration(generateDTO); + } + + // 3、处理标准生成流程 + return handleStandardGeneration(generateDTO); + } + +// ============== 以下是辅助方法 ============== + + /** + * 参数校验 + */ + private void validateRequiredParams(GenerateThroughImageTextDTO generateDTO) { if (Objects.isNull(generateDTO.getUserId())) { throw new BusinessException("userId cannot be empty"); } - CreditsEventsEnum creditsEventsEnum = CreditsEventsEnum.OTHER; - if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) && generateDTO.getModelName().equals("wx")) { - String taskId = createAsyncTask(generateDTO); - creditsEventsEnum = CreditsEventsEnum.WX_TEXT2IMG; - // 6、添加预扣除积分到redis - creditsService.addRecordToCreditsDeduction(generateDTO.getUserId(), taskId, creditsEventsEnum); - // 6.1 添加积分扣除记录到db - creditsService.preInsert(generateDTO.getUserId(), creditsEventsEnum.getName(), taskId, Boolean.TRUE, null); + // Printboard必须要有level2Type + if (generateDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) + && StringUtil.isNullOrEmpty(generateDTO.getLevel2Type())) { + throw new BusinessException("level2Type.cannot.be.empty"); + } + } - // 7、返回唯一id - return new PrepareForGenerateVO(Collections.singletonList(taskId), 2); - } else if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) - && generateDTO.getModelName().equals("flux") && generateDTO.getLevel2Type().equals("Pattern")){ - String imagePath = null; - if (Objects.nonNull(generateDTO.getCollectionElementId()) && !StringUtil.isNullOrEmpty(generateDTO.getDesignType())){ - switch (generateDTO.getDesignType()){ - case "collection": - CollectionElement collectionElement = collectionElementMapper.selectById(generateDTO.getCollectionElementId()); - if (Objects.nonNull(collectionElement)){ - imagePath = collectionElement.getUrl(); - } - break; - case "library": - Library library = libraryService.getById(generateDTO.getCollectionElementId()); - if (Objects.nonNull(library)){ - imagePath = library.getUrl(); - } - } - } - String taskId = flux(PATTERN, generateDTO.getText(), imagePath, false); - Generate generate = CopyUtil.copyObject(generateDTO, Generate.class); - generate.setAccountId(generateDTO.getUserId()); - generate.setUniqueId(taskId); - generate.setElementSource(generateDTO.getDesignType()); - generate.setElementId(generateDTO.getCollectionElementId()); - String generateType; - if (Objects.nonNull(generateDTO.getCollectionElementId()) && !StringUtil.isNullOrEmpty(generateDTO.getText())){ - generateType = "text-image"; - } else if (Objects.nonNull(generateDTO.getCollectionElementId()) && StringUtil.isNullOrEmpty(generateDTO.getText())){ - generateType = "image"; - } else { - generateType = "text"; - } - generate.setGenerateType(generateType); - generate.setModelName("flux"); - generate.setCreateDate(new Date()); - save(generate); + /** + * 判断是否为wx模型 + */ + private boolean isWxModel(GenerateThroughImageTextDTO generateDTO) { + return !StringUtil.isNullOrEmpty(generateDTO.getModelName()) + && "wx".equals(generateDTO.getModelName()); + } - creditsEventsEnum = CreditsEventsEnum.FLUX_IMG2IMG; - // 6、添加预扣除积分到redis - creditsService.addRecordToCreditsDeduction(generateDTO.getUserId(), taskId, creditsEventsEnum); - // 6.1 添加积分扣除记录到db - creditsService.preInsert(generateDTO.getUserId(), creditsEventsEnum.getName(), taskId, Boolean.TRUE, null); - // 7、返回唯一id - return new PrepareForGenerateVO(Collections.singletonList(taskId), 2); + /** + * 处理wx模型生成 + */ + private PrepareForGenerateVO handleWxModelGeneration(GenerateThroughImageTextDTO generateDTO) { + String taskId = createAsyncTask(generateDTO); + processCreditDeduction(generateDTO.getUserId(), taskId, CreditsEventsEnum.WX_TEXT2IMG); + return new PrepareForGenerateVO(Collections.singletonList(taskId), 2); + } + + /** + * 判断是否为flux模型且类型为Pattern + */ + private boolean isFluxPatternModel(GenerateThroughImageTextDTO generateDTO) { + return !StringUtil.isNullOrEmpty(generateDTO.getModelName()) + && "flux".equals(generateDTO.getModelName()) + && "Pattern".equals(generateDTO.getLevel2Type()); + } + + /** + * 处理flux pattern生成 + */ + private PrepareForGenerateVO handleFluxPatternGeneration(GenerateThroughImageTextDTO generateDTO) { + // 获取图片路径 + String imagePath = getImagePathForFlux(generateDTO); + + // 创建生成任务 + String taskId = flux(PATTERN, generateDTO.getText(), imagePath, false); + + // 保存生成记录 + saveGenerateRecord(generateDTO, taskId, imagePath); + + // 处理积分扣除 + processCreditDeduction(generateDTO.getUserId(), taskId, CreditsEventsEnum.FLUX_IMG2IMG); + + return new PrepareForGenerateVO(Collections.singletonList(taskId), 2); + } + + /** + * 获取flux模型需要的图片路径 + */ + private String getImagePathForFlux(GenerateThroughImageTextDTO generateDTO) { + if (Objects.isNull(generateDTO.getCollectionElementId()) + || StringUtil.isNullOrEmpty(generateDTO.getDesignType())) { + return null; } + switch (generateDTO.getDesignType()) { + case "collection": + CollectionElement element = collectionElementMapper.selectById(generateDTO.getCollectionElementId()); + return element != null ? element.getUrl() : null; + case "library": + Library library = libraryService.getById(generateDTO.getCollectionElementId()); + return library != null ? library.getUrl() : null; + default: + return null; + } + } + + /** + * 保存生成记录 + */ + private void saveGenerateRecord(GenerateThroughImageTextDTO generateDTO, String taskId, String imagePath) { + Generate generate = CopyUtil.copyObject(generateDTO, Generate.class); + generate.setAccountId(generateDTO.getUserId()); + generate.setUniqueId(taskId); + generate.setElementSource(generateDTO.getDesignType()); + generate.setElementId(generateDTO.getCollectionElementId()); + + // 确定生成类型 + String generateType = determineGenerateType(generateDTO); + generate.setGenerateType(generateType); + generate.setModelName("flux"); + generate.setCreateDate(new Date()); + + save(generate); + } + + /** + * 确定生成类型 + */ + private String determineGenerateType(GenerateThroughImageTextDTO generateDTO) { + if (Objects.nonNull(generateDTO.getCollectionElementId())) { + return StringUtil.isNullOrEmpty(generateDTO.getText()) ? "image" : "text-image"; + } + return "text"; + } + + /** + * 处理标准生成流程 + */ + private PrepareForGenerateVO handleStandardGeneration(GenerateThroughImageTextDTO generateDTO) { + // 确定积分事件和生成次数 + GenerationConfig config = determineGenerationConfig(generateDTO); + + // 校验积分是否足够 + validateCredits(config.creditsEvent); + + // 创建生成任务 + List taskIds = createGenerationTasks(generateDTO, config.times); + + // 处理积分扣除(使用第一个任务的UUID前缀) + processCreditDeduction(generateDTO.getUserId(), taskIds.get(0).split("-")[0], config.creditsEvent); + + return new PrepareForGenerateVO(taskIds, 2); + } + + /** + * 确定生成配置(积分事件和生成次数) + */ + private GenerationConfig determineGenerationConfig(GenerateThroughImageTextDTO generateDTO) { + CreditsEventsEnum creditsEvent = CreditsEventsEnum.OTHER; int times = 4; - // 当level1Type为Print_board时,level2Type为pattern时需要确定generateType - if (generateDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())) { - if (StringUtil.isNullOrEmpty(generateDTO.getLevel2Type())) { - throw new BusinessException("level2Type.cannot.be.empty"); - } else if (!CollectionLevel2TypeEnum.printType().contains(generateDTO.getLevel2Type())) { - throw new BusinessException("unknown.parameter.level2Type"); - } - // Pattern 参数校验 - if (generateDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.Pattern.getRealName())) { - String text = generateDTO.getText(); - Long elementId = generateDTO.getCollectionElementId(); - Generate generate = new Generate(); - validateGeneraType(generate, text, elementId); - // 校验后获取 - generateDTO.setGenerateType(generate.getGenerateType()); -// creditsEventsEnum = CreditsEventsEnum.PATTERN; - creditsEventsEnum = CreditsEventsEnum.PATTERN; - - // 模型迁移SD1.? -> flux,从而产生了不同模型的选择, - // high -> 生成图片质量高,但生成速度慢,每次生成只返回一张图片 - // fast -> 生成图片质量低,但生成速度快,每次生成返回四张图片 - if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) && generateDTO.getModelName().equals("high")) { - creditsEventsEnum = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; + // 根据不同类型确定配置 + // high -> 生成图片质量高,但生成速度慢,每次生成只返回一张图片 + // fast -> 生成图片质量低,但生成速度快,每次生成返回四张图片 + switch (generateDTO.getLevel1Type()) { + case "Printboard": + GenerationConfig generationConfig = handlePrintboardConfig(generateDTO); + creditsEvent = generationConfig.creditsEvent; + times = generationConfig.times; + break; + case "Moodboard": + creditsEvent = CreditsEventsEnum.MOOD_BOARD; + if (isHighModel(generateDTO)) { + creditsEvent = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; times = 1; } - } - // Slogan 参数校验 slogan目前只能开一个接口。所以只有生产环境上能使用 - if (generateDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.SLOGAN.getRealName())) { - if (StringUtil.isNullOrEmpty(generateDTO.getSloganBase64())) { - log.error("Printboard-Slogan模式下,slogan image为空"); - throw new BusinessException("slogan.image.cannot.be.empty"); + break; + case "Sketchboard": + creditsEvent = CreditsEventsEnum.SKETCH_BOARD; + if (isHighModel(generateDTO)) { + creditsEvent = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; + times = 1; } - - if (StringUtil.isNullOrEmpty(generateDTO.getText())) { - log.error("Printboard-Slogan模式下,slogan text为空"); - throw new BusinessException("slogan.style.cannot.be.empty"); - } - - times = 1; - // 将图片上传到图片服务器 - String path = minioUtil.base64UploadToPath(generateDTO.getSloganBase64(), sloganBucket, null); - String name = path.substring(path.lastIndexOf("/") + 1, path.lastIndexOf(".")); - // 保存到db,collection-element - CollectionElement collectionElement = new CollectionElement(); - collectionElement.setAccountId(generateDTO.getUserId()); - collectionElement.setCollectionId(0L); - collectionElement.setLevel1Type(PRINT_BOARD.getRealName()); - collectionElement.setLevel2Type(CollectionLevel2TypeEnum.SLOGAN.getRealName()); - collectionElement.setName(name); - collectionElement.setUrl(path); - collectionElement.setHasPin((byte) 0); - collectionElement.setMd5(MD5Utils.encryptFile(minioUtil.getPreSignedUrl(path, 24 * 60), Boolean.FALSE)); - collectionElement.setCreateDate(DateUtil.getByTimeZone(generateDTO.getTimeZone())); - collectionElementService.save(collectionElement); - - // 将上传后的地址放在指定字段 - generateDTO.setCollectionElementId(collectionElement.getId()); - generateDTO.setSloganBase64(null); - generateDTO.setDesignType("collection"); - creditsEventsEnum = CreditsEventsEnum.SLOGAN; - } - - // Logo参数校验 - if (generateDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.LOGO.getRealName())) { - // logo模式下一次只生成一张 - times = 1; - // 校验是否输入内容 - if (StringUtil.isNullOrEmpty(generateDTO.getText().trim())) { - throw new BusinessException("please.input.the.prompt"); - } - - // 校验seed的取值范围 - int seed = random.nextInt(501); - log.info("随机种子:{}", seed); - generateDTO.setSeed(String.valueOf(seed)); - - creditsEventsEnum = CreditsEventsEnum.LOGO; - } - } else if (generateDTO.getLevel1Type().equals(MOOD_BOARD.getRealName())) { - creditsEventsEnum = CreditsEventsEnum.MOOD_BOARD; - if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) && generateDTO.getModelName().equals("high")) { - creditsEventsEnum = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; - times = 1; - } - } else if (generateDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName())) { - creditsEventsEnum = CreditsEventsEnum.SKETCH_BOARD; - if (!StringUtil.isNullOrEmpty(generateDTO.getModelName()) && generateDTO.getModelName().equals("high")) { - creditsEventsEnum = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; - times = 1; - } + break; } - // 2、判断用户当前积分是否够本次生成消耗 - Boolean preDeduction = creditsService.creditsPreDeduction(creditsEventsEnum, 1); - if (!preDeduction) { + return new GenerationConfig(creditsEvent, times); + } + + /** + * 处理Printboard的特殊配置 + */ + private GenerationConfig handlePrintboardConfig(GenerateThroughImageTextDTO generateDTO) { + String level2Type = generateDTO.getLevel2Type(); + CreditsEventsEnum creditsEvent = CreditsEventsEnum.OTHER; + int times = 4; + + if (!CollectionLevel2TypeEnum.printType().contains(level2Type)) { + throw new BusinessException("unknown.parameter.level2Type"); + } + + switch (level2Type) { + case "Pattern": + // Pattern参数校验 + Generate generate = new Generate(); + validateGeneraType(generate, generateDTO.getText(), generateDTO.getCollectionElementId()); + generateDTO.setGenerateType(generate.getGenerateType()); + + creditsEvent = CreditsEventsEnum.PATTERN; + if (isHighModel(generateDTO)) { + creditsEvent = CreditsEventsEnum.LOCAL_TEXT2IMG_HIGH; + times = 1; + } + break; + + case "Slogan": + validateSloganParams(generateDTO); + processSloganImage(generateDTO); + creditsEvent = CreditsEventsEnum.SLOGAN; + times = 1; + break; + + case "Logo": + validateLogoParams(generateDTO); + generateDTO.setSeed(String.valueOf(random.nextInt(501))); + creditsEvent = CreditsEventsEnum.LOGO; + times = 1; + break; + } + + return new GenerationConfig(creditsEvent, times); + } + + /** + * 校验Slogan参数 + */ + private void validateSloganParams(GenerateThroughImageTextDTO generateDTO) { + if (StringUtil.isNullOrEmpty(generateDTO.getSloganBase64())) { + log.error("Printboard-Slogan模式下,slogan image为空"); + throw new BusinessException("slogan.image.cannot.be.empty"); + } + if (StringUtil.isNullOrEmpty(generateDTO.getText())) { + log.error("Printboard-Slogan模式下,slogan text为空"); + throw new BusinessException("slogan.style.cannot.be.empty"); + } + } + + /** + * 处理Slogan图片上传 + */ + private void processSloganImage(GenerateThroughImageTextDTO generateDTO) { + // 上传图片到服务器 + String path = minioUtil.base64UploadToPath(generateDTO.getSloganBase64(), sloganBucket, null); + String name = path.substring(path.lastIndexOf("/") + 1, path.lastIndexOf(".")); + + // 保存到数据库 + CollectionElement element = new CollectionElement(); + element.setAccountId(generateDTO.getUserId()); + element.setCollectionId(0L); + element.setLevel1Type(PRINT_BOARD.getRealName()); + element.setLevel2Type(CollectionLevel2TypeEnum.SLOGAN.getRealName()); + element.setName(name); + element.setUrl(path); + element.setHasPin((byte) 0); + element.setMd5(MD5Utils.encryptFile(minioUtil.getPreSignedUrl(path, 24 * 60), Boolean.FALSE)); + element.setCreateDate(DateUtil.getByTimeZone(generateDTO.getTimeZone())); + collectionElementService.save(element); + + // 更新DTO + generateDTO.setCollectionElementId(element.getId()); + generateDTO.setSloganBase64(null); + generateDTO.setDesignType("collection"); + } + + /** + * 校验Logo参数 + */ + private void validateLogoParams(GenerateThroughImageTextDTO generateDTO) { + if (StringUtil.isNullOrEmpty(generateDTO.getText().trim())) { + throw new BusinessException("please.input.the.prompt"); + } + } + + /** + * 判断是否为high模型 + */ + private boolean isHighModel(GenerateThroughImageTextDTO generateDTO) { + return !StringUtil.isNullOrEmpty(generateDTO.getModelName()) + && "high".equals(generateDTO.getModelName()); + } + + /** + * 校验积分是否足够 + */ + private void validateCredits(CreditsEventsEnum creditsEvent) { + if (!creditsService.creditsPreDeduction(creditsEvent, 1)) { throw new BusinessException("remaining.credits.insufficient", ResultEnum.WARNING.getCode()); } + } - // 3、生成唯一id 使用uuid,由于uuid重复的几率很小,故取消对uuid重复性的校验 + /** + * 创建生成任务 + */ + private List createGenerationTasks(GenerateThroughImageTextDTO generateDTO, int times) { String uuid = UUID.randomUUID().toString(); + List taskIds = new ArrayList<>(); - // 除了 Moodboard || Printboard->Pattern(可以区分三种风格) || Sketchboard(Generate Sketch)这三个地方需要区分high || fast之外,其他地方保持原样 - if (generateDTO.getLevel1Type().equals("Printboard") && !generateDTO.getLevel2Type().equals("Pattern")) { + // 特殊处理:某些情况下需要清空modelName + if ("Printboard".equals(generateDTO.getLevel1Type()) + && !"Pattern".equals(generateDTO.getLevel2Type())) { + // Logo 和 Slogan 没有模型可选 generateDTO.setModelName(null); } - ArrayList taskIdList = new ArrayList<>(); for (int i = 1; i <= times; i++) { - String taskId = uuid; - taskId += "-" + i + "-" + generateDTO.getUserId(); - taskIdList.add(taskId); + String taskId = uuid + "-" + i + "-" + generateDTO.getUserId(); + taskIds.add(taskId); generateDTO.setUniqueId(taskId); + + // 序列化为JSON String jsonString = JSON.toJSONString(generateDTO); - // 4、加入redis排队,便于获取实时排队信息 - Double maxScore = redisUtil.getMaxScore(consumptionOrderKey); - redisUtil.addToZSet(consumptionOrderKey, taskId, maxScore); - - // 加入resultMap - String key = generateResultKey + ":" + taskId; - GenerateResultVO generateResultVO = new GenerateResultVO(generateDTO.getUniqueId(), null, null, "Waiting"); - redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); - - // 5、将消息发布到MQ消息队列 - rabbitMQService.publishMessageToGenerate(jsonString); + // 加入Redis队列 + addToRedisQueue(taskId, jsonString); } - // update 积分扣除由按次收费改为按生成图片数量收费 --> 改回按次收费 - // 6、添加预扣除积分到redis - creditsService.addRecordToCreditsDeduction(generateDTO.getUserId(), uuid, creditsEventsEnum); - // 6.1 添加积分扣除记录到db - creditsService.preInsert(generateDTO.getUserId(), creditsEventsEnum.getName(), uuid, Boolean.TRUE, null); + return taskIds; + } - // 7、返回唯一id - return new PrepareForGenerateVO(taskIdList, 2); + /** + * 添加到Redis队列 + */ + private void addToRedisQueue(String taskId, String jsonString) { + // 加入排队队列 +// Double maxScore = redisUtil.getMaxScore(consumptionOrderKey); +// redisUtil.addToZSet(consumptionOrderKey, taskId, maxScore); + + // 加入结果映射 + String key = generateResultKey + ":" + taskId; + GenerateResultVO resultVO = new GenerateResultVO(taskId, null, null, "Waiting"); + redisUtil.addToString(key, new Gson().toJson(resultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + + // 发布到MQ + rabbitMQService.publishMessageToGenerate(jsonString); + } + + /** + * 处理积分扣除 + */ + private void processCreditDeduction(Long userId, String taskId, CreditsEventsEnum creditsEvent) { + // 添加到Redis + creditsService.addRecordToCreditsDeduction(userId, taskId, creditsEvent); + // 预插入到数据库 + creditsService.preInsert(userId, creditsEvent.getName(), taskId, Boolean.TRUE, null); + } + +// ============== 配置类 ============== + + /** + * 生成任务配置类 + * 包含积分事件类型和生成次数 + */ + private static class GenerationConfig { + final CreditsEventsEnum creditsEvent; + final int times; + + GenerationConfig(CreditsEventsEnum creditsEvent, int times) { + this.creditsEvent = creditsEvent; + this.times = times; + } } @Override