TASK:1、将imageToSketch接口调用转为异步 2、imageToSketch加入flux
This commit is contained in:
21
src/main/java/com/ai/da/common/config/AsyncConfig.java
Normal file
21
src/main/java/com/ai/da/common/config/AsyncConfig.java
Normal file
@@ -0,0 +1,21 @@
|
||||
package com.ai.da.common.config;
|
||||
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
|
||||
|
||||
import java.util.concurrent.Executor;
|
||||
|
||||
@Configuration
|
||||
public class AsyncConfig {
|
||||
@Bean("asyncTaskExecutor")
|
||||
public Executor asyncTaskExecutor() {
|
||||
ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
|
||||
executor.setCorePoolSize(5);
|
||||
executor.setMaxPoolSize(10);
|
||||
executor.setQueueCapacity(100);
|
||||
executor.setThreadNamePrefix("Async-ImageToSketch-");
|
||||
executor.initialize();
|
||||
return executor;
|
||||
}
|
||||
}
|
||||
@@ -87,7 +87,7 @@ public class GenerateController {
|
||||
|
||||
@ApiOperation(value = "imageToSketch")
|
||||
@PostMapping("/imageToSketch")
|
||||
public Response<GenerateResultVO> imageToSketch(@Valid @RequestBody ImageToSketchDTO imageToSketchDTO) {
|
||||
public Response<String> imageToSketch(@Valid @RequestBody ImageToSketchDTO imageToSketchDTO) {
|
||||
return Response.success(generateService.imageToSketchAsync(imageToSketchDTO, null, null));
|
||||
// return Response.success(generateService.imageToSketch(imageToSketchDTO, null, null));
|
||||
}
|
||||
@@ -209,7 +209,7 @@ public class GenerateController {
|
||||
// @ApiOperation(value = "获取flux结果")
|
||||
// @GetMapping("/fluxResult")
|
||||
public Response<String> fluxResult(@RequestParam("taskId") String taskId){
|
||||
return Response.success(generateService.getFluxResult(taskId, 87L));
|
||||
return Response.success(generateService.getFluxResult(taskId, "87/" + taskId + ".png"));
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -58,5 +58,13 @@ public class GenerateDetail {
|
||||
*/
|
||||
private Date updateDate;
|
||||
|
||||
public GenerateDetail() {
|
||||
}
|
||||
|
||||
public GenerateDetail(Long generateId, String url, String md5, LocalDateTime createDate) {
|
||||
this.generateId = generateId;
|
||||
this.url = url;
|
||||
this.md5 = md5;
|
||||
this.createDate = createDate;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,9 @@ public class ImageToSketchDTO {
|
||||
@ApiModelProperty("性别")
|
||||
private String gender;
|
||||
|
||||
@ApiModelProperty("模型名")
|
||||
private String modelName;
|
||||
|
||||
public ImageToSketchDTO() {
|
||||
}
|
||||
|
||||
|
||||
@@ -35,4 +35,9 @@ public class GenerateResultVO {
|
||||
this.status = status;
|
||||
this.category = category;
|
||||
}
|
||||
|
||||
public GenerateResultVO(String taskId, String status) {
|
||||
this.taskId = taskId;
|
||||
this.status = status;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ public interface GenerateService extends IService<Generate> {
|
||||
|
||||
GenerateResultVO imageToSketch(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId);
|
||||
|
||||
GenerateResultVO imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId);
|
||||
String imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId);
|
||||
|
||||
GenerateResultVO modifySketch(GenerateModifyDTO generateModifyDTO);
|
||||
|
||||
@@ -85,7 +85,7 @@ public interface GenerateService extends IService<Generate> {
|
||||
|
||||
String flux(CreditsEventsEnum func, String prompt, String imagePath);
|
||||
|
||||
String getFluxResult(String taskId, Long accountId);
|
||||
String getFluxResult(String taskId, String objectName);
|
||||
|
||||
byte[] downloadVideoOrImage(String url);
|
||||
}
|
||||
|
||||
@@ -56,6 +56,8 @@ import java.io.*;
|
||||
import java.math.BigDecimal;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
@@ -685,6 +687,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
boolean flag = true;
|
||||
String type = null;
|
||||
for (String taskId : taskIdList) {
|
||||
String key = generateResultKey + ":" + taskId;
|
||||
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
|
||||
|
||||
if (flag) {
|
||||
type = resolveModelType(taskId);
|
||||
flag = false;
|
||||
@@ -692,11 +697,14 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
// 暂定万象每次生成1个
|
||||
if (type.equals("wx")){
|
||||
return Collections.singletonList(getAsyncTaskResult(taskId));
|
||||
} else if (type.equals("freepik")){
|
||||
results.add(generateResultVO);
|
||||
continue;
|
||||
} else if (type.equals("flux")){
|
||||
results.add(getFluxResultAndSave(taskId));
|
||||
continue;
|
||||
}
|
||||
|
||||
String key = generateResultKey + ":" + taskId;
|
||||
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
|
||||
|
||||
if (generateResultVO != null && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) {
|
||||
String url = generateResultVO.getUrl();
|
||||
if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")) {
|
||||
@@ -886,7 +894,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
// 线稿提取
|
||||
String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode);
|
||||
// 存数据库
|
||||
Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, accountId, styleCode);
|
||||
Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId,
|
||||
accountId, styleCode, "local", "0");
|
||||
GenerateResultVO generateResultVO = saveExtractSketchResult(generate, sketchPath, imageToSketchDTO.getGender());
|
||||
// 积分扣除
|
||||
doCreditsSubtract(accountId, event);
|
||||
@@ -921,16 +930,18 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
}
|
||||
|
||||
private Generate saveExtractSketchRequest(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl,
|
||||
Long projectId, Long accountId, String styleCode){
|
||||
Long projectId, Long accountId, String styleCode,
|
||||
String modelName, String taskId){
|
||||
// 存DB
|
||||
Generate generate = new Generate();
|
||||
generate.setAccountId(accountId);
|
||||
generate.setUniqueId(String.valueOf(0));
|
||||
generate.setUniqueId(taskId);
|
||||
generate.setLevel1Type(SKETCH_BOARD.getRealName());
|
||||
generate.setLevel2Type("ImageToSketch");
|
||||
generate.setElementSource("collection");
|
||||
generate.setElementId(imageToSketchDTO.getElementId());
|
||||
generate.setGenerateType("image");
|
||||
generate.setGenerateType("image(" + imageToSketchDTO.getGender() + ")");
|
||||
generate.setModelName(modelName);
|
||||
generate.setSketchStyle(styleCode);
|
||||
generate.setStyleImageElementId(imageToSketchDTO.getStyleImageId());
|
||||
generate.setProjectId(projectId);
|
||||
@@ -962,52 +973,92 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
creditsService.preInsert(accountId, event.getName(), null, Boolean.FALSE, event.getValue());
|
||||
}
|
||||
|
||||
// freepik以后会变成异步的吗? 目前同步
|
||||
public GenerateResultVO imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId){
|
||||
// 注入线程池(可在配置类中定义)
|
||||
@Resource
|
||||
private Executor asyncTaskExecutor;
|
||||
|
||||
public String imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId) {
|
||||
Long accountId = UserContext.getUserHolder().getId();
|
||||
log.info("imageToSketch parameter : {}", imageToSketchDTO);
|
||||
|
||||
// 检查积分是否够本次扣除
|
||||
CreditsEventsEnum event = CreditsEventsEnum.IMAGE_TO_SKETCH;
|
||||
Boolean b = creditsService.checkCredits(accountId, event, 1);
|
||||
if (!b){
|
||||
if (!b) {
|
||||
throw new BusinessException("remaining.credits.insufficient", ResultEnum.PROMPT.getCode());
|
||||
}
|
||||
|
||||
// 生成唯一任务ID
|
||||
String taskId;
|
||||
if (!StringUtil.isNullOrEmpty(imageToSketchDTO.getModelName())
|
||||
&& imageToSketchDTO.getModelName().equals("flux")){
|
||||
String imagePath;
|
||||
// todo 拼贴图的线稿提取是否能用flux
|
||||
if (StringUtil.isNullOrEmpty(collagePictureUrl)){
|
||||
CollectionElement collectionElement = collectionElementService.getById(imageToSketchDTO.getElementId());
|
||||
imagePath = collectionElement.getUrl();
|
||||
}else {
|
||||
imagePath = collagePictureUrl;
|
||||
}
|
||||
taskId = flux(CreditsEventsEnum.IMAGE_TO_SKETCH, null, imagePath);
|
||||
// 存数据库
|
||||
saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId,
|
||||
accountId, imageToSketchDTO.getStyle(), "flux", taskId);
|
||||
return taskId;
|
||||
}
|
||||
|
||||
taskId = UUID.randomUUID().toString();
|
||||
// 异步执行耗时操作
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try {
|
||||
processImageToSketch(taskId, imageToSketchDTO, collagePictureUrl, projectId, accountId, event);
|
||||
} catch (Exception e) {
|
||||
log.error("异步处理图片转sketch失败, taskId: {}", taskId, e);
|
||||
// 更新redis
|
||||
redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(new GenerateResultVO(taskId, "Failed")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
|
||||
}
|
||||
}, asyncTaskExecutor);
|
||||
return taskId;
|
||||
}
|
||||
|
||||
private void processImageToSketch(String taskId, ImageToSketchDTO imageToSketchDTO,
|
||||
String collagePictureUrl, Long projectId,
|
||||
Long accountId, CreditsEventsEnum event) throws IOException {
|
||||
// 设置任务状态为处理中
|
||||
redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(new GenerateResultVO(taskId, "Executing")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
|
||||
String style = imageToSketchDTO.getStyle();
|
||||
String styleCode = style.equals(SketchStyle.THICK.getValue()) ? "1" :
|
||||
style.equals(SketchStyle.MEDIUM.getValue()) ? "2" :
|
||||
style.equals(SketchStyle.THIN.getValue()) ? "3" : "Custom";
|
||||
|
||||
// 请求记录存数据库
|
||||
Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId,
|
||||
accountId, styleCode, "freepik", taskId);
|
||||
// 1、初步提取结果
|
||||
String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode);
|
||||
|
||||
// 2、获取输入图的描述
|
||||
String imageDescription = getImageDescription(sketchPath);
|
||||
|
||||
try {
|
||||
// 请求freepik reimage
|
||||
String dataStr = reimagineFreePik(sketchPath, imageDescription, "vivid");
|
||||
if (StringUtil.isNullOrEmpty(dataStr)){
|
||||
throw new BusinessException("extract sketch failed");
|
||||
}
|
||||
JSONObject data = JSONUtil.parseObj(dataStr);
|
||||
String upgradeImageUrl = data.getBeanList("generated", String.class).get(0);
|
||||
String taskId = data.getStr("task_id");
|
||||
|
||||
// 下载图片 freepik
|
||||
// byte[] bytes = downloadWithProxy(upgradeImageUrl);
|
||||
byte[] bytes = downloadVideoOrImage(upgradeImageUrl);
|
||||
// 2、上传图片到minio保存
|
||||
String objectName = accountId + "/imageToSketch/" + taskId + ".png";
|
||||
minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png");
|
||||
// 存数据库
|
||||
Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, accountId, styleCode);
|
||||
GenerateResultVO generateResultVO = saveExtractSketchResult(generate, userBucket + "/" + objectName, imageToSketchDTO.getGender());
|
||||
// 积分扣除
|
||||
doCreditsSubtract(accountId, event);
|
||||
return generateResultVO;
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
// 3、请求freepik reimage
|
||||
String dataStr = reimagineFreePik(sketchPath, imageDescription, "vivid");
|
||||
if (StringUtil.isNullOrEmpty(dataStr)) {
|
||||
throw new BusinessException("extract sketch failed");
|
||||
}
|
||||
|
||||
JSONObject data = JSONUtil.parseObj(dataStr);
|
||||
String upgradeImageUrl = data.getBeanList("generated", String.class).get(0);
|
||||
String freepikTaskId = data.getStr("task_id");
|
||||
|
||||
// 4、下载图片
|
||||
byte[] bytes = downloadVideoOrImage(upgradeImageUrl);
|
||||
// byte[] bytes = downloadWithProxy(upgradeImageUrl);
|
||||
// 5、上传图片到minio保存
|
||||
String objectName = accountId + "/imageToSketch/" + freepikTaskId + ".png";
|
||||
minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png");
|
||||
// 6、保存结果到db
|
||||
GenerateResultVO generateResultVO = saveExtractSketchResult(generate, userBucket + "/" + objectName, imageToSketchDTO.getGender());
|
||||
// 7、积分扣除
|
||||
doCreditsSubtract(accountId, event);
|
||||
// 8、将结果存入Redis
|
||||
redisUtil.addToString(generateResultKey + ":" + taskId, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
|
||||
}
|
||||
|
||||
// 对提取出来的sketch做调整
|
||||
@@ -1835,8 +1886,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
public byte[] downloadWithProxy(String url) throws IOException {
|
||||
// 获取系统代理设置(适用于大多数VPN)
|
||||
// String proxyHost = System.getProperty("http.proxyHost");
|
||||
String proxyHost = "localhost";
|
||||
// String proxyPort = System.getProperty("http.proxyPort");
|
||||
String proxyHost = "localhost";
|
||||
String proxyPort = "7890";
|
||||
|
||||
CloseableHttpClient client;
|
||||
@@ -1954,18 +2005,15 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
|
||||
private String resolveModelType(String taskId){
|
||||
// 判断当前task来自哪个模型
|
||||
// 判断taskId的结构
|
||||
int count = StringUtils.countMatches(taskId, "-");
|
||||
String lastPart = taskId.substring(taskId.lastIndexOf("-") + 1);
|
||||
String type;
|
||||
if (count == 4 && lastPart.length() == 12){
|
||||
// 万象
|
||||
type = "wx";
|
||||
Generate generate = selectByUniqueId(taskId);
|
||||
if (!StringUtil.isNullOrEmpty(generate.getModelName()) &&
|
||||
(generate.getModelName().equals("wx")
|
||||
|| generate.getModelName().equals("freepik")
|
||||
|| generate.getModelName().equals("flux") )){
|
||||
return generate.getModelName();
|
||||
}else {
|
||||
// 本地部署的模型
|
||||
type = "local";
|
||||
return "local";
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
public static String extractGender(String text) {
|
||||
@@ -2027,7 +2075,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
return respObj.getStr("id");
|
||||
}
|
||||
|
||||
public String getFluxResult(String taskId, Long accountId){
|
||||
public String getFluxResult(String taskId, String objectName){
|
||||
String fluxResultRequestUrl = "https://api.bfl.ai/v1/get_result";
|
||||
HashMap<String, Object> params = new HashMap<>();
|
||||
params.put("id", taskId);
|
||||
@@ -2047,7 +2095,6 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
// 已完成 获取结果
|
||||
String fluxResult = JSONUtil.parseObj(respObj.getStr("result")).getStr("sample");
|
||||
byte[] bytes = downloadVideoOrImage(fluxResult);
|
||||
String objectName = accountId + "/product_image/" + taskId + ".png";
|
||||
minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png");
|
||||
|
||||
// return minioUtil.getPreSignedUrl(userBucket + "/" + objectName, CommonConstant.MINIO_IMAGE_EXPIRE_TIME);
|
||||
@@ -2058,4 +2105,42 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private GenerateResultVO getFluxResultAndSave(String taskId){
|
||||
Generate generate = selectByUniqueId(taskId);
|
||||
if (Objects.nonNull(generate)){
|
||||
GenerateDetail generateDetail = generateDetailMapper.selectOne(new QueryWrapper<GenerateDetail>().eq("generate_id", generate.getId()));
|
||||
Long accountId = generate.getAccountId();
|
||||
String objectName = accountId + "/imageToSketch/" + taskId + ".png";
|
||||
String fluxResult = getFluxResult(taskId, objectName);
|
||||
if (Objects.isNull(generateDetail)){
|
||||
if (fluxResult.equals("Failed") || fluxResult.equals("Pending")){
|
||||
String status = fluxResult.equals("Failed") ? "Failed" : "Executing";
|
||||
return new GenerateResultVO(taskId, status);
|
||||
}
|
||||
|
||||
generateDetail = new GenerateDetail(generate.getId(), fluxResult,
|
||||
MD5Utils.encryptFile(
|
||||
minioUtil.getPreSignedUrl(fluxResult, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false),
|
||||
LocalDateTime.now());
|
||||
generateDetailMapper.insert(generateDetail);
|
||||
} else if (StringUtil.isNullOrEmpty(generateDetail.getUrl())){
|
||||
// 一般来说这条线应该走不到
|
||||
generateDetail.setGenerateId(generate.getId());
|
||||
generateDetail.setUrl(fluxResult);
|
||||
generateDetail.setMd5(MD5Utils.encryptFile(
|
||||
minioUtil.getPreSignedUrl(fluxResult, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), false));
|
||||
generateDetail.setUpdateDate(new Date());
|
||||
generateDetailMapper.updateById(generateDetail);
|
||||
}
|
||||
String url = generateDetail.getUrl();
|
||||
String clothCategory = pythonService.getClothCategory(url, extractGender(generate.getGenerateType()));
|
||||
return new GenerateResultVO(taskId, generateDetail.getId(),
|
||||
minioUtil.getPreSignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), "Success", clothCategory);
|
||||
}else {
|
||||
throw new BusinessException("unknown generate");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -626,7 +626,8 @@ public class UserLikeGroupServiceImpl extends ServiceImpl<UserLikeGroupMapper, U
|
||||
if (Objects.isNull(project)){
|
||||
throw new BusinessException("unknown project");
|
||||
}
|
||||
String fluxResult = generateService.getFluxResult(taskId, project.getAccountId());
|
||||
String objectName = project.getAccountId() + "/product_image/" + taskId + ".png";
|
||||
String fluxResult = generateService.getFluxResult(taskId, objectName);
|
||||
if (StringUtil.isNullOrEmpty(fluxResult)){
|
||||
results.add(new MagicToolResultVO());
|
||||
} else if (fluxResult.equals("Failed") || fluxResult.equals("Pending")) {
|
||||
@@ -1016,7 +1017,8 @@ public class UserLikeGroupServiceImpl extends ServiceImpl<UserLikeGroupMapper, U
|
||||
if (Objects.isNull(project)){
|
||||
throw new BusinessException("unknown project");
|
||||
}
|
||||
String fluxResult = generateService.getFluxResult(taskId, project.getAccountId());
|
||||
String objectName = project.getAccountId() + "/product_image/" + taskId + ".png";
|
||||
String fluxResult = generateService.getFluxResult(taskId, objectName);
|
||||
|
||||
if (StringUtil.isNullOrEmpty(fluxResult)){
|
||||
results.add(new MagicToolResultVO());
|
||||
|
||||
Reference in New Issue
Block a user