Merge branch 'dev/dev_xp' into dev/dev

This commit is contained in:
2025-06-13 16:41:15 +08:00
8 changed files with 180 additions and 56 deletions

View File

@@ -0,0 +1,21 @@
package com.ai.da.common.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.concurrent.Executor;
@Configuration
public class AsyncConfig {
@Bean("asyncTaskExecutor")
public Executor asyncTaskExecutor() {
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
executor.setCorePoolSize(5);
executor.setMaxPoolSize(10);
executor.setQueueCapacity(100);
executor.setThreadNamePrefix("Async-ImageToSketch-");
executor.initialize();
return executor;
}
}

View File

@@ -87,7 +87,7 @@ public class GenerateController {
@ApiOperation(value = "imageToSketch") @ApiOperation(value = "imageToSketch")
@PostMapping("/imageToSketch") @PostMapping("/imageToSketch")
public Response<GenerateResultVO> imageToSketch(@Valid @RequestBody ImageToSketchDTO imageToSketchDTO) { public Response<String> imageToSketch(@Valid @RequestBody ImageToSketchDTO imageToSketchDTO) {
return Response.success(generateService.imageToSketchAsync(imageToSketchDTO, null, null)); return Response.success(generateService.imageToSketchAsync(imageToSketchDTO, null, null));
// return Response.success(generateService.imageToSketch(imageToSketchDTO, null, null)); // return Response.success(generateService.imageToSketch(imageToSketchDTO, null, null));
} }
@@ -209,7 +209,7 @@ public class GenerateController {
// @ApiOperation(value = "获取flux结果") // @ApiOperation(value = "获取flux结果")
// @GetMapping("/fluxResult") // @GetMapping("/fluxResult")
public Response<String> fluxResult(@RequestParam("taskId") String taskId){ public Response<String> fluxResult(@RequestParam("taskId") String taskId){
return Response.success(generateService.getFluxResult(taskId, 87L)); return Response.success(generateService.getFluxResult(taskId, "87/" + taskId + ".png"));
} }

View File

@@ -58,5 +58,13 @@ public class GenerateDetail {
*/ */
private Date updateDate; private Date updateDate;
public GenerateDetail() {
}
public GenerateDetail(Long generateId, String url, String md5, LocalDateTime createDate) {
this.generateId = generateId;
this.url = url;
this.md5 = md5;
this.createDate = createDate;
}
} }

View File

@@ -20,6 +20,9 @@ public class ImageToSketchDTO {
@ApiModelProperty("性别") @ApiModelProperty("性别")
private String gender; private String gender;
@ApiModelProperty("模型名")
private String modelName;
public ImageToSketchDTO() { public ImageToSketchDTO() {
} }

View File

@@ -35,4 +35,9 @@ public class GenerateResultVO {
this.status = status; this.status = status;
this.category = category; this.category = category;
} }
public GenerateResultVO(String taskId, String status) {
this.taskId = taskId;
this.status = status;
}
} }

View File

@@ -45,7 +45,7 @@ public interface GenerateService extends IService<Generate> {
GenerateResultVO imageToSketch(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId); GenerateResultVO imageToSketch(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId);
GenerateResultVO imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId); String imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId);
GenerateResultVO modifySketch(GenerateModifyDTO generateModifyDTO); GenerateResultVO modifySketch(GenerateModifyDTO generateModifyDTO);
@@ -85,7 +85,7 @@ public interface GenerateService extends IService<Generate> {
String flux(CreditsEventsEnum func, String prompt, String imagePath); String flux(CreditsEventsEnum func, String prompt, String imagePath);
String getFluxResult(String taskId, Long accountId); String getFluxResult(String taskId, String objectName);
byte[] downloadVideoOrImage(String url); byte[] downloadVideoOrImage(String url);
} }

View File

@@ -57,6 +57,8 @@ import java.io.*;
import java.math.BigDecimal; import java.math.BigDecimal;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.*; import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@@ -686,6 +688,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
boolean flag = true; boolean flag = true;
String type = null; String type = null;
for (String taskId : taskIdList) { for (String taskId : taskIdList) {
String key = generateResultKey + ":" + taskId;
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
if (flag) { if (flag) {
type = resolveModelType(taskId); type = resolveModelType(taskId);
flag = false; flag = false;
@@ -693,11 +698,14 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 暂定万象每次生成1个 // 暂定万象每次生成1个
if (type.equals("wx")){ if (type.equals("wx")){
return Collections.singletonList(getAsyncTaskResult(taskId)); return Collections.singletonList(getAsyncTaskResult(taskId));
} else if (type.equals("freepik")){
results.add(generateResultVO);
continue;
} else if (type.equals("flux")){
results.add(getFluxResultAndSave(taskId));
continue;
} }
String key = generateResultKey + ":" + taskId;
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
if (generateResultVO != null && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) { if (generateResultVO != null && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) {
String url = generateResultVO.getUrl(); String url = generateResultVO.getUrl();
if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")) { if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")) {
@@ -887,7 +895,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 线稿提取 // 线稿提取
String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode); String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode);
// 存数据库 // 存数据库
Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, accountId, styleCode); Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId,
accountId, styleCode, "local", "0");
GenerateResultVO generateResultVO = saveExtractSketchResult(generate, sketchPath, imageToSketchDTO.getGender()); GenerateResultVO generateResultVO = saveExtractSketchResult(generate, sketchPath, imageToSketchDTO.getGender());
// 积分扣除 // 积分扣除
doCreditsSubtract(accountId, event); doCreditsSubtract(accountId, event);
@@ -922,16 +931,18 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
} }
private Generate saveExtractSketchRequest(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, private Generate saveExtractSketchRequest(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl,
Long projectId, Long accountId, String styleCode){ Long projectId, Long accountId, String styleCode,
String modelName, String taskId){
// 存DB // 存DB
Generate generate = new Generate(); Generate generate = new Generate();
generate.setAccountId(accountId); generate.setAccountId(accountId);
generate.setUniqueId(String.valueOf(0)); generate.setUniqueId(taskId);
generate.setLevel1Type(SKETCH_BOARD.getRealName()); generate.setLevel1Type(SKETCH_BOARD.getRealName());
generate.setLevel2Type("ImageToSketch"); generate.setLevel2Type("ImageToSketch");
generate.setElementSource("collection"); generate.setElementSource("collection");
generate.setElementId(imageToSketchDTO.getElementId()); generate.setElementId(imageToSketchDTO.getElementId());
generate.setGenerateType("image"); generate.setGenerateType("image(" + imageToSketchDTO.getGender() + ")");
generate.setModelName(modelName);
generate.setSketchStyle(styleCode); generate.setSketchStyle(styleCode);
generate.setStyleImageElementId(imageToSketchDTO.getStyleImageId()); generate.setStyleImageElementId(imageToSketchDTO.getStyleImageId());
generate.setProjectId(projectId); generate.setProjectId(projectId);
@@ -963,52 +974,92 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
creditsService.preInsert(accountId, event.getName(), null, Boolean.FALSE, event.getValue()); creditsService.preInsert(accountId, event.getName(), null, Boolean.FALSE, event.getValue());
} }
// freepik以后会变成异步的吗 目前同步 // 注入线程池(可在配置类中定义)
public GenerateResultVO imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId){ @Resource
private Executor asyncTaskExecutor;
public String imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId) {
Long accountId = UserContext.getUserHolder().getId(); Long accountId = UserContext.getUserHolder().getId();
log.info("imageToSketch parameter : {}", imageToSketchDTO); log.info("imageToSketch parameter : {}", imageToSketchDTO);
// 检查积分是否够本次扣除 // 检查积分是否够本次扣除
CreditsEventsEnum event = CreditsEventsEnum.IMAGE_TO_SKETCH; CreditsEventsEnum event = CreditsEventsEnum.IMAGE_TO_SKETCH;
Boolean b = creditsService.checkCredits(accountId, event, 1); Boolean b = creditsService.checkCredits(accountId, event, 1);
if (!b){ if (!b) {
throw new BusinessException("remaining.credits.insufficient", ResultEnum.PROMPT.getCode()); throw new BusinessException("remaining.credits.insufficient", ResultEnum.PROMPT.getCode());
} }
// 生成唯一任务ID
String taskId;
if (!StringUtil.isNullOrEmpty(imageToSketchDTO.getModelName())
&& imageToSketchDTO.getModelName().equals("flux")){
String imagePath;
// todo 拼贴图的线稿提取是否能用flux
if (StringUtil.isNullOrEmpty(collagePictureUrl)){
CollectionElement collectionElement = collectionElementService.getById(imageToSketchDTO.getElementId());
imagePath = collectionElement.getUrl();
}else {
imagePath = collagePictureUrl;
}
taskId = flux(CreditsEventsEnum.IMAGE_TO_SKETCH, null, imagePath);
// 存数据库
saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId,
accountId, imageToSketchDTO.getStyle(), "flux", taskId);
return taskId;
}
taskId = UUID.randomUUID().toString();
// 异步执行耗时操作
CompletableFuture.runAsync(() -> {
try {
processImageToSketch(taskId, imageToSketchDTO, collagePictureUrl, projectId, accountId, event);
} catch (Exception e) {
log.error("异步处理图片转sketch失败, taskId: {}", taskId, e);
// 更新redis
redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(new GenerateResultVO(taskId, "Failed")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
}
}, asyncTaskExecutor);
return taskId;
}
private void processImageToSketch(String taskId, ImageToSketchDTO imageToSketchDTO,
String collagePictureUrl, Long projectId,
Long accountId, CreditsEventsEnum event) throws IOException {
// 设置任务状态为处理中
redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(new GenerateResultVO(taskId, "Executing")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
String style = imageToSketchDTO.getStyle(); String style = imageToSketchDTO.getStyle();
String styleCode = style.equals(SketchStyle.THICK.getValue()) ? "1" : String styleCode = style.equals(SketchStyle.THICK.getValue()) ? "1" :
style.equals(SketchStyle.MEDIUM.getValue()) ? "2" : style.equals(SketchStyle.MEDIUM.getValue()) ? "2" :
style.equals(SketchStyle.THIN.getValue()) ? "3" : "Custom"; style.equals(SketchStyle.THIN.getValue()) ? "3" : "Custom";
// 请求记录存数据库
Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId,
accountId, styleCode, "freepik", taskId);
// 1、初步提取结果
String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode); String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode);
// 2、获取输入图的描述
String imageDescription = getImageDescription(sketchPath); String imageDescription = getImageDescription(sketchPath);
// 3、请求freepik reimage
try { String dataStr = reimagineFreePik(sketchPath, imageDescription, "vivid");
// 请求freepik reimage if (StringUtil.isNullOrEmpty(dataStr)) {
String dataStr = reimagineFreePik(sketchPath, imageDescription, "vivid"); throw new BusinessException("extract sketch failed");
if (StringUtil.isNullOrEmpty(dataStr)){
throw new BusinessException("extract sketch failed");
}
JSONObject data = JSONUtil.parseObj(dataStr);
String upgradeImageUrl = data.getBeanList("generated", String.class).get(0);
String taskId = data.getStr("task_id");
// 下载图片 freepik
// byte[] bytes = downloadWithProxy(upgradeImageUrl);
byte[] bytes = downloadVideoOrImage(upgradeImageUrl);
// 2、上传图片到minio保存
String objectName = accountId + "/imageToSketch/" + taskId + ".png";
minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png");
// 存数据库
Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, accountId, styleCode);
GenerateResultVO generateResultVO = saveExtractSketchResult(generate, userBucket + "/" + objectName, imageToSketchDTO.getGender());
// 积分扣除
doCreditsSubtract(accountId, event);
return generateResultVO;
} catch (IOException e) {
throw new RuntimeException(e);
} }
JSONObject data = JSONUtil.parseObj(dataStr);
String upgradeImageUrl = data.getBeanList("generated", String.class).get(0);
String freepikTaskId = data.getStr("task_id");
// 4、下载图片
byte[] bytes = downloadVideoOrImage(upgradeImageUrl);
// byte[] bytes = downloadWithProxy(upgradeImageUrl);
// 5、上传图片到minio保存
String objectName = accountId + "/imageToSketch/" + freepikTaskId + ".png";
minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png");
// 6、保存结果到db
GenerateResultVO generateResultVO = saveExtractSketchResult(generate, userBucket + "/" + objectName, imageToSketchDTO.getGender());
// 7、积分扣除
doCreditsSubtract(accountId, event);
// 8、将结果存入Redis
redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
} }
// 对提取出来的sketch做调整 // 对提取出来的sketch做调整
@@ -1846,8 +1897,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
public byte[] downloadWithProxy(String url) throws IOException { public byte[] downloadWithProxy(String url) throws IOException {
// 获取系统代理设置适用于大多数VPN // 获取系统代理设置适用于大多数VPN
// String proxyHost = System.getProperty("http.proxyHost"); // String proxyHost = System.getProperty("http.proxyHost");
String proxyHost = "localhost";
// String proxyPort = System.getProperty("http.proxyPort"); // String proxyPort = System.getProperty("http.proxyPort");
String proxyHost = "localhost";
String proxyPort = "7890"; String proxyPort = "7890";
CloseableHttpClient client; CloseableHttpClient client;
@@ -1965,18 +2016,15 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
private String resolveModelType(String taskId){ private String resolveModelType(String taskId){
// 判断当前task来自哪个模型 // 判断当前task来自哪个模型
// 判断taskId的结构 Generate generate = selectByUniqueId(taskId);
int count = StringUtils.countMatches(taskId, "-"); if (!StringUtil.isNullOrEmpty(generate.getModelName()) &&
String lastPart = taskId.substring(taskId.lastIndexOf("-") + 1); (generate.getModelName().equals("wx")
String type; || generate.getModelName().equals("freepik")
if (count == 4 && lastPart.length() == 12){ || generate.getModelName().equals("flux") )){
// 万象 return generate.getModelName();
type = "wx";
}else { }else {
// 本地部署的模型 return "local";
type = "local";
} }
return type;
} }
public static String extractGender(String text) { public static String extractGender(String text) {
@@ -2038,7 +2086,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return respObj.getStr("id"); return respObj.getStr("id");
} }
public String getFluxResult(String taskId, Long accountId){ public String getFluxResult(String taskId, String objectName){
String fluxResultRequestUrl = "https://api.bfl.ai/v1/get_result"; String fluxResultRequestUrl = "https://api.bfl.ai/v1/get_result";
HashMap<String, Object> params = new HashMap<>(); HashMap<String, Object> params = new HashMap<>();
params.put("id", taskId); params.put("id", taskId);
@@ -2058,7 +2106,6 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 已完成 获取结果 // 已完成 获取结果
String fluxResult = JSONUtil.parseObj(respObj.getStr("result")).getStr("sample"); String fluxResult = JSONUtil.parseObj(respObj.getStr("result")).getStr("sample");
byte[] bytes = downloadVideoOrImage(fluxResult); byte[] bytes = downloadVideoOrImage(fluxResult);
String objectName = accountId + "/product_image/" + taskId + ".png";
minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png"); minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png");
// return minioUtil.getPreSignedUrl(userBucket + "/" + objectName, CommonConstant.MINIO_IMAGE_EXPIRE_TIME); // return minioUtil.getPreSignedUrl(userBucket + "/" + objectName, CommonConstant.MINIO_IMAGE_EXPIRE_TIME);
@@ -2069,4 +2116,42 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
} }
return null; return null;
} }
private GenerateResultVO getFluxResultAndSave(String taskId){
Generate generate = selectByUniqueId(taskId);
if (Objects.nonNull(generate)){
GenerateDetail generateDetail = generateDetailMapper.selectOne(new QueryWrapper<GenerateDetail>().eq("generate_id", generate.getId()));
Long accountId = generate.getAccountId();
String objectName = accountId + "/imageToSketch/" + taskId + ".png";
String fluxResult = getFluxResult(taskId, objectName);
if (Objects.isNull(generateDetail)){
if (fluxResult.equals("Failed") || fluxResult.equals("Pending")){
String status = fluxResult.equals("Failed") ? "Failed" : "Executing";
return new GenerateResultVO(taskId, status);
}
generateDetail = new GenerateDetail(generate.getId(), fluxResult,
MD5Utils.encryptFile(
minioUtil.getPreSignedUrl(fluxResult, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false),
LocalDateTime.now());
generateDetailMapper.insert(generateDetail);
} else if (StringUtil.isNullOrEmpty(generateDetail.getUrl())){
// 一般来说这条线应该走不到
generateDetail.setGenerateId(generate.getId());
generateDetail.setUrl(fluxResult);
generateDetail.setMd5(MD5Utils.encryptFile(
minioUtil.getPreSignedUrl(fluxResult, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false));
generateDetail.setUpdateDate(new Date());
generateDetailMapper.updateById(generateDetail);
}
String url = generateDetail.getUrl();
String clothCategory = pythonService.getClothCategory(url, extractGender(generate.getGenerateType()));
return new GenerateResultVO(taskId, generateDetail.getId(),
minioUtil.getPreSignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), "Success", clothCategory);
}else {
throw new BusinessException("unknown generate");
}
}
} }

View File

@@ -643,7 +643,8 @@ public class UserLikeGroupServiceImpl extends ServiceImpl<UserLikeGroupMapper, U
if (Objects.isNull(project)){ if (Objects.isNull(project)){
throw new BusinessException("unknown project"); throw new BusinessException("unknown project");
} }
String fluxResult = generateService.getFluxResult(taskId, project.getAccountId()); String objectName = project.getAccountId() + "/product_image/" + taskId + ".png";
String fluxResult = generateService.getFluxResult(taskId, objectName);
if (StringUtil.isNullOrEmpty(fluxResult)){ if (StringUtil.isNullOrEmpty(fluxResult)){
results.add(new MagicToolResultVO()); results.add(new MagicToolResultVO());
} else if (fluxResult.equals("Failed") || fluxResult.equals("Pending")) { } else if (fluxResult.equals("Failed") || fluxResult.equals("Pending")) {
@@ -1039,7 +1040,8 @@ public class UserLikeGroupServiceImpl extends ServiceImpl<UserLikeGroupMapper, U
if (Objects.isNull(project)){ if (Objects.isNull(project)){
throw new BusinessException("unknown project"); throw new BusinessException("unknown project");
} }
String fluxResult = generateService.getFluxResult(taskId, project.getAccountId()); String objectName = project.getAccountId() + "/product_image/" + taskId + ".png";
String fluxResult = generateService.getFluxResult(taskId, objectName);
if (StringUtil.isNullOrEmpty(fluxResult)){ if (StringUtil.isNullOrEmpty(fluxResult)){
results.add(new MagicToolResultVO()); results.add(new MagicToolResultVO());