From 174d1bf0d0dcfe2e3e33eda72cd863d9c2aedb68 Mon Sep 17 00:00:00 2001 From: xupei Date: Fri, 13 Jun 2025 16:37:45 +0800 Subject: [PATCH] =?UTF-8?q?TASK:1=E3=80=81=E5=B0=86imageToSketch=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E8=B0=83=E7=94=A8=E8=BD=AC=E4=B8=BA=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=202=E3=80=81imageToSketch=E5=8A=A0=E5=85=A5flux?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../com/ai/da/common/config/AsyncConfig.java | 21 ++ .../ai/da/controller/GenerateController.java | 4 +- .../mapper/primary/entity/GenerateDetail.java | 8 + .../com/ai/da/model/dto/ImageToSketchDTO.java | 3 + .../com/ai/da/model/vo/GenerateResultVO.java | 5 + .../com/ai/da/service/GenerateService.java | 4 +- .../da/service/impl/GenerateServiceImpl.java | 185 +++++++++++++----- .../impl/UserLikeGroupServiceImpl.java | 6 +- 8 files changed, 180 insertions(+), 56 deletions(-) create mode 100644 src/main/java/com/ai/da/common/config/AsyncConfig.java diff --git a/src/main/java/com/ai/da/common/config/AsyncConfig.java b/src/main/java/com/ai/da/common/config/AsyncConfig.java new file mode 100644 index 00000000..6aa86e4e --- /dev/null +++ b/src/main/java/com/ai/da/common/config/AsyncConfig.java @@ -0,0 +1,21 @@ +package com.ai.da.common.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; + +import java.util.concurrent.Executor; + +@Configuration +public class AsyncConfig { + @Bean("asyncTaskExecutor") + public Executor asyncTaskExecutor() { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + executor.setCorePoolSize(5); + executor.setMaxPoolSize(10); + executor.setQueueCapacity(100); + executor.setThreadNamePrefix("Async-ImageToSketch-"); + executor.initialize(); + return executor; + } +} diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java index 837d535c..55e34619 100644 --- a/src/main/java/com/ai/da/controller/GenerateController.java +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -87,7 +87,7 @@ public class GenerateController { @ApiOperation(value = "imageToSketch") @PostMapping("/imageToSketch") - public Response imageToSketch(@Valid @RequestBody ImageToSketchDTO imageToSketchDTO) { + public Response imageToSketch(@Valid @RequestBody ImageToSketchDTO imageToSketchDTO) { return Response.success(generateService.imageToSketchAsync(imageToSketchDTO, null, null)); // return Response.success(generateService.imageToSketch(imageToSketchDTO, null, null)); } @@ -209,7 +209,7 @@ public class GenerateController { // @ApiOperation(value = "获取flux结果") // @GetMapping("/fluxResult") public Response fluxResult(@RequestParam("taskId") String taskId){ - return Response.success(generateService.getFluxResult(taskId, 87L)); + return Response.success(generateService.getFluxResult(taskId, "87/" + taskId + ".png")); } diff --git a/src/main/java/com/ai/da/mapper/primary/entity/GenerateDetail.java b/src/main/java/com/ai/da/mapper/primary/entity/GenerateDetail.java index 6cf9f0b3..8b8a36b0 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/GenerateDetail.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/GenerateDetail.java @@ -58,5 +58,13 @@ public class GenerateDetail { */ private Date updateDate; + public GenerateDetail() { + } + public GenerateDetail(Long generateId, String url, String md5, LocalDateTime createDate) { + this.generateId = generateId; + this.url = url; + this.md5 = md5; + this.createDate = createDate; + } } diff --git a/src/main/java/com/ai/da/model/dto/ImageToSketchDTO.java b/src/main/java/com/ai/da/model/dto/ImageToSketchDTO.java index e356a7d1..85949622 100644 --- a/src/main/java/com/ai/da/model/dto/ImageToSketchDTO.java +++ b/src/main/java/com/ai/da/model/dto/ImageToSketchDTO.java @@ -20,6 +20,9 @@ public class ImageToSketchDTO { @ApiModelProperty("性别") private String gender; + @ApiModelProperty("模型名") + private String modelName; + public ImageToSketchDTO() { } diff --git a/src/main/java/com/ai/da/model/vo/GenerateResultVO.java b/src/main/java/com/ai/da/model/vo/GenerateResultVO.java index 1865cf1a..4dd1cc5e 100644 --- a/src/main/java/com/ai/da/model/vo/GenerateResultVO.java +++ b/src/main/java/com/ai/da/model/vo/GenerateResultVO.java @@ -35,4 +35,9 @@ public class GenerateResultVO { this.status = status; this.category = category; } + + public GenerateResultVO(String taskId, String status) { + this.taskId = taskId; + this.status = status; + } } diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index 8fa067e2..a4627212 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -45,7 +45,7 @@ public interface GenerateService extends IService { GenerateResultVO imageToSketch(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId); - GenerateResultVO imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId); + String imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId); GenerateResultVO modifySketch(GenerateModifyDTO generateModifyDTO); @@ -85,7 +85,7 @@ public interface GenerateService extends IService { String flux(CreditsEventsEnum func, String prompt, String imagePath); - String getFluxResult(String taskId, Long accountId); + String getFluxResult(String taskId, String objectName); byte[] downloadVideoOrImage(String url); } diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 42923843..280ba2b6 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -56,6 +56,8 @@ import java.io.*; import java.math.BigDecimal; import java.time.LocalDateTime; import java.util.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -685,6 +687,9 @@ public class GenerateServiceImpl extends ServiceImpl i boolean flag = true; String type = null; for (String taskId : taskIdList) { + String key = generateResultKey + ":" + taskId; + GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class); + if (flag) { type = resolveModelType(taskId); flag = false; @@ -692,11 +697,14 @@ public class GenerateServiceImpl extends ServiceImpl i // 暂定万象每次生成1个 if (type.equals("wx")){ return Collections.singletonList(getAsyncTaskResult(taskId)); + } else if (type.equals("freepik")){ + results.add(generateResultVO); + continue; + } else if (type.equals("flux")){ + results.add(getFluxResultAndSave(taskId)); + continue; } - String key = generateResultKey + ":" + taskId; - GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class); - if (generateResultVO != null && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) { String url = generateResultVO.getUrl(); if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")) { @@ -886,7 +894,8 @@ public class GenerateServiceImpl extends ServiceImpl i // 线稿提取 String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode); // 存数据库 - Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, accountId, styleCode); + Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, + accountId, styleCode, "local", "0"); GenerateResultVO generateResultVO = saveExtractSketchResult(generate, sketchPath, imageToSketchDTO.getGender()); // 积分扣除 doCreditsSubtract(accountId, event); @@ -921,16 +930,18 @@ public class GenerateServiceImpl extends ServiceImpl i } private Generate saveExtractSketchRequest(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, - Long projectId, Long accountId, String styleCode){ + Long projectId, Long accountId, String styleCode, + String modelName, String taskId){ // 存DB Generate generate = new Generate(); generate.setAccountId(accountId); - generate.setUniqueId(String.valueOf(0)); + generate.setUniqueId(taskId); generate.setLevel1Type(SKETCH_BOARD.getRealName()); generate.setLevel2Type("ImageToSketch"); generate.setElementSource("collection"); generate.setElementId(imageToSketchDTO.getElementId()); - generate.setGenerateType("image"); + generate.setGenerateType("image(" + imageToSketchDTO.getGender() + ")"); + generate.setModelName(modelName); generate.setSketchStyle(styleCode); generate.setStyleImageElementId(imageToSketchDTO.getStyleImageId()); generate.setProjectId(projectId); @@ -962,52 +973,92 @@ public class GenerateServiceImpl extends ServiceImpl i creditsService.preInsert(accountId, event.getName(), null, Boolean.FALSE, event.getValue()); } - // freepik以后会变成异步的吗? 目前同步 - public GenerateResultVO imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId){ + // 注入线程池(可在配置类中定义) + @Resource + private Executor asyncTaskExecutor; + + public String imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId) { Long accountId = UserContext.getUserHolder().getId(); log.info("imageToSketch parameter : {}", imageToSketchDTO); // 检查积分是否够本次扣除 CreditsEventsEnum event = CreditsEventsEnum.IMAGE_TO_SKETCH; Boolean b = creditsService.checkCredits(accountId, event, 1); - if (!b){ + if (!b) { throw new BusinessException("remaining.credits.insufficient", ResultEnum.PROMPT.getCode()); } + // 生成唯一任务ID + String taskId; + if (!StringUtil.isNullOrEmpty(imageToSketchDTO.getModelName()) + && imageToSketchDTO.getModelName().equals("flux")){ + String imagePath; + // todo 拼贴图的线稿提取是否能用flux + if (StringUtil.isNullOrEmpty(collagePictureUrl)){ + CollectionElement collectionElement = collectionElementService.getById(imageToSketchDTO.getElementId()); + imagePath = collectionElement.getUrl(); + }else { + imagePath = collagePictureUrl; + } + taskId = flux(CreditsEventsEnum.IMAGE_TO_SKETCH, null, imagePath); + // 存数据库 + saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, + accountId, imageToSketchDTO.getStyle(), "flux", taskId); + return taskId; + } + + taskId = UUID.randomUUID().toString(); + // 异步执行耗时操作 + CompletableFuture.runAsync(() -> { + try { + processImageToSketch(taskId, imageToSketchDTO, collagePictureUrl, projectId, accountId, event); + } catch (Exception e) { + log.error("异步处理图片转sketch失败, taskId: {}", taskId, e); + // 更新redis + redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(new GenerateResultVO(taskId, "Failed")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + } + }, asyncTaskExecutor); + return taskId; + } + + private void processImageToSketch(String taskId, ImageToSketchDTO imageToSketchDTO, + String collagePictureUrl, Long projectId, + Long accountId, CreditsEventsEnum event) throws IOException { + // 设置任务状态为处理中 + redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(new GenerateResultVO(taskId, "Executing")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); String style = imageToSketchDTO.getStyle(); String styleCode = style.equals(SketchStyle.THICK.getValue()) ? "1" : style.equals(SketchStyle.MEDIUM.getValue()) ? "2" : style.equals(SketchStyle.THIN.getValue()) ? "3" : "Custom"; - + // 请求记录存数据库 + Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, + accountId, styleCode, "freepik", taskId); + // 1、初步提取结果 String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode); - + // 2、获取输入图的描述 String imageDescription = getImageDescription(sketchPath); - - try { - // 请求freepik reimage - String dataStr = reimagineFreePik(sketchPath, imageDescription, "vivid"); - if (StringUtil.isNullOrEmpty(dataStr)){ - throw new BusinessException("extract sketch failed"); - } - JSONObject data = JSONUtil.parseObj(dataStr); - String upgradeImageUrl = data.getBeanList("generated", String.class).get(0); - String taskId = data.getStr("task_id"); - - // 下载图片 freepik -// byte[] bytes = downloadWithProxy(upgradeImageUrl); - byte[] bytes = downloadVideoOrImage(upgradeImageUrl); - // 2、上传图片到minio保存 - String objectName = accountId + "/imageToSketch/" + taskId + ".png"; - minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png"); - // 存数据库 - Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, accountId, styleCode); - GenerateResultVO generateResultVO = saveExtractSketchResult(generate, userBucket + "/" + objectName, imageToSketchDTO.getGender()); - // 积分扣除 - doCreditsSubtract(accountId, event); - return generateResultVO; - } catch (IOException e) { - throw new RuntimeException(e); + // 3、请求freepik reimage + String dataStr = reimagineFreePik(sketchPath, imageDescription, "vivid"); + if (StringUtil.isNullOrEmpty(dataStr)) { + throw new BusinessException("extract sketch failed"); } + + JSONObject data = JSONUtil.parseObj(dataStr); + String upgradeImageUrl = data.getBeanList("generated", String.class).get(0); + String freepikTaskId = data.getStr("task_id"); + + // 4、下载图片 + byte[] bytes = downloadVideoOrImage(upgradeImageUrl); +// byte[] bytes = downloadWithProxy(upgradeImageUrl); + // 5、上传图片到minio保存 + String objectName = accountId + "/imageToSketch/" + freepikTaskId + ".png"; + minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png"); + // 6、保存结果到db + GenerateResultVO generateResultVO = saveExtractSketchResult(generate, userBucket + "/" + objectName, imageToSketchDTO.getGender()); + // 7、积分扣除 + doCreditsSubtract(accountId, event); + // 8、将结果存入Redis + redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); } // 对提取出来的sketch做调整 @@ -1835,8 +1886,8 @@ public class GenerateServiceImpl extends ServiceImpl i public byte[] downloadWithProxy(String url) throws IOException { // 获取系统代理设置(适用于大多数VPN) // String proxyHost = System.getProperty("http.proxyHost"); - String proxyHost = "localhost"; // String proxyPort = System.getProperty("http.proxyPort"); + String proxyHost = "localhost"; String proxyPort = "7890"; CloseableHttpClient client; @@ -1954,18 +2005,15 @@ public class GenerateServiceImpl extends ServiceImpl i private String resolveModelType(String taskId){ // 判断当前task来自哪个模型 - // 判断taskId的结构 - int count = StringUtils.countMatches(taskId, "-"); - String lastPart = taskId.substring(taskId.lastIndexOf("-") + 1); - String type; - if (count == 4 && lastPart.length() == 12){ - // 万象 - type = "wx"; + Generate generate = selectByUniqueId(taskId); + if (!StringUtil.isNullOrEmpty(generate.getModelName()) && + (generate.getModelName().equals("wx") + || generate.getModelName().equals("freepik") + || generate.getModelName().equals("flux") )){ + return generate.getModelName(); }else { - // 本地部署的模型 - type = "local"; + return "local"; } - return type; } public static String extractGender(String text) { @@ -2027,7 +2075,7 @@ public class GenerateServiceImpl extends ServiceImpl i return respObj.getStr("id"); } - public String getFluxResult(String taskId, Long accountId){ + public String getFluxResult(String taskId, String objectName){ String fluxResultRequestUrl = "https://api.bfl.ai/v1/get_result"; HashMap params = new HashMap<>(); params.put("id", taskId); @@ -2047,7 +2095,6 @@ public class GenerateServiceImpl extends ServiceImpl i // 已完成 获取结果 String fluxResult = JSONUtil.parseObj(respObj.getStr("result")).getStr("sample"); byte[] bytes = downloadVideoOrImage(fluxResult); - String objectName = accountId + "/product_image/" + taskId + ".png"; minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png"); // return minioUtil.getPreSignedUrl(userBucket + "/" + objectName, CommonConstant.MINIO_IMAGE_EXPIRE_TIME); @@ -2058,4 +2105,42 @@ public class GenerateServiceImpl extends ServiceImpl i } return null; } + + private GenerateResultVO getFluxResultAndSave(String taskId){ + Generate generate = selectByUniqueId(taskId); + if (Objects.nonNull(generate)){ + GenerateDetail generateDetail = generateDetailMapper.selectOne(new QueryWrapper().eq("generate_id", generate.getId())); + Long accountId = generate.getAccountId(); + String objectName = accountId + "/imageToSketch/" + taskId + ".png"; + String fluxResult = getFluxResult(taskId, objectName); + if (Objects.isNull(generateDetail)){ + if (fluxResult.equals("Failed") || fluxResult.equals("Pending")){ + String status = fluxResult.equals("Failed") ? "Failed" : "Executing"; + return new GenerateResultVO(taskId, status); + } + + generateDetail = new GenerateDetail(generate.getId(), fluxResult, + MD5Utils.encryptFile( + minioUtil.getPreSignedUrl(fluxResult, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false), + LocalDateTime.now()); + generateDetailMapper.insert(generateDetail); + } else if (StringUtil.isNullOrEmpty(generateDetail.getUrl())){ + // 一般来说这条线应该走不到 + generateDetail.setGenerateId(generate.getId()); + generateDetail.setUrl(fluxResult); + generateDetail.setMd5(MD5Utils.encryptFile( + minioUtil.getPreSignedUrl(fluxResult, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false)); + generateDetail.setUpdateDate(new Date()); + generateDetailMapper.updateById(generateDetail); + } + String url = generateDetail.getUrl(); + String clothCategory = pythonService.getClothCategory(url, extractGender(generate.getGenerateType())); + return new GenerateResultVO(taskId, generateDetail.getId(), + minioUtil.getPreSignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), "Success", clothCategory); + }else { + throw new BusinessException("unknown generate"); + } + } + + } diff --git a/src/main/java/com/ai/da/service/impl/UserLikeGroupServiceImpl.java b/src/main/java/com/ai/da/service/impl/UserLikeGroupServiceImpl.java index 2ca6efb3..215cbf3c 100644 --- a/src/main/java/com/ai/da/service/impl/UserLikeGroupServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/UserLikeGroupServiceImpl.java @@ -626,7 +626,8 @@ public class UserLikeGroupServiceImpl extends ServiceImpl