generate模型更换后的接口更改及异步获取结果
This commit is contained in:
@@ -3,9 +3,10 @@ package com.ai.da.common.RabbitMQ;
|
|||||||
import com.ai.da.common.config.exception.BusinessException;
|
import com.ai.da.common.config.exception.BusinessException;
|
||||||
import com.ai.da.common.utils.RedisUtil;
|
import com.ai.da.common.utils.RedisUtil;
|
||||||
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
||||||
import com.ai.da.model.vo.GenerateCollectionVO;
|
import com.ai.da.model.vo.GenerateResultVO;
|
||||||
import com.ai.da.service.GenerateService;
|
import com.ai.da.service.GenerateService;
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import com.google.gson.Gson;
|
||||||
import com.rabbitmq.client.Channel;
|
import com.rabbitmq.client.Channel;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.amqp.core.Message;
|
import org.springframework.amqp.core.Message;
|
||||||
@@ -17,7 +18,7 @@ import org.springframework.stereotype.Component;
|
|||||||
import javax.annotation.Resource;
|
import javax.annotation.Resource;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Objects;
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -42,6 +43,9 @@ public class GenerateConsumer {
|
|||||||
@Value("${redis.key.resultMap}")
|
@Value("${redis.key.resultMap}")
|
||||||
private String resultMapKey;
|
private String resultMapKey;
|
||||||
|
|
||||||
|
@Value("${redis.key.generateResult}")
|
||||||
|
private String generateResultKey;
|
||||||
|
|
||||||
public void generate(Message msg, Channel channel, String consumerName) {
|
public void generate(Message msg, Channel channel, String consumerName) {
|
||||||
log.info("============start listening==========");
|
log.info("============start listening==========");
|
||||||
long start = System.currentTimeMillis();
|
long start = System.currentTimeMillis();
|
||||||
@@ -63,20 +67,16 @@ public class GenerateConsumer {
|
|||||||
// 2.2 将该消息从取消列表中删除
|
// 2.2 将该消息从取消列表中删除
|
||||||
// redisUtil.removeFromSet(cancelSetKey, uniqueId);
|
// redisUtil.removeFromSet(cancelSetKey, uniqueId);
|
||||||
} else {
|
} else {
|
||||||
/*try {
|
// GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO);
|
||||||
Thread.sleep(15000);
|
generateService.generateThroughImageText(generateThroughImageTextDTO);
|
||||||
} catch (InterruptedException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}*/
|
|
||||||
GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO);
|
|
||||||
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
|
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
|
||||||
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
|
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
|
||||||
if (!Objects.isNull(generateCollectionVO)) {
|
/*if (!Objects.isNull(generateCollectionVO)) {
|
||||||
HashMap<String, String> generateResult = new HashMap<>();
|
HashMap<String, String> generateResult = new HashMap<>();
|
||||||
generateResult.put(uniqueId, JSONObject.toJSONString(generateCollectionVO));
|
generateResult.put(uniqueId, JSONObject.toJSONString(generateCollectionVO));
|
||||||
// 将结果存在redis中 ,为空时不要存
|
// 将结果存在redis中 ,为空时不要存
|
||||||
redisUtil.addToMap(resultMapKey, generateResult);
|
redisUtil.addToMap(resultMapKey, generateResult);
|
||||||
}
|
}*/
|
||||||
|
|
||||||
}
|
}
|
||||||
} catch (BusinessException e) {
|
} catch (BusinessException e) {
|
||||||
@@ -104,6 +104,34 @@ public class GenerateConsumer {
|
|||||||
log.info("=============end listening===========");
|
log.info("=============end listening===========");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void processGenerateResult(Message msg, Channel channel){
|
||||||
|
log.info("============ProcessGenerateResult listening==========");
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
|
|
||||||
|
Map<String, String> generateResult = JSONObject.parseObject(msg.getBody(), Map.class);
|
||||||
|
// log.info("tasks_id : {}, message : {}",generateResult.get("tasks_id"), generateResult.get("message") );
|
||||||
|
if (generateResult.get("status").equals("SUCCESS")){
|
||||||
|
String url = generateResult.get("data");
|
||||||
|
String taskId = generateResult.get("tasks_id");
|
||||||
|
generateService.processGenerateResult(taskId, url);
|
||||||
|
}else {
|
||||||
|
// 修改redis中的数据状态为exception
|
||||||
|
String key = generateResultKey + ":" + generateResult.get("tasks_id");
|
||||||
|
Long expire = redisUtil.getExpire(key);
|
||||||
|
redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(null, null, "Fail")), expire);
|
||||||
|
// 将异常信息存到exception中
|
||||||
|
HashMap<String, String> exceptionInfo = new HashMap<>();
|
||||||
|
exceptionInfo.put(generateResult.get("tasks_id"), generateResult.get("data"));
|
||||||
|
// 存redis
|
||||||
|
redisUtil.addToMap(exceptionMapKey, exceptionInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
long end = System.currentTimeMillis();
|
||||||
|
log.info("tasks_id : {}, message : {}, 执行时长: {} 毫秒",generateResult.get("tasks_id"), generateResult.get("message"), (end - start));
|
||||||
|
log.info("============ProcessGenerateResult End listening==========");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
|
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
|
||||||
@RabbitHandler
|
@RabbitHandler
|
||||||
public void generateConsumer1(Message msg, Channel channel) {
|
public void generateConsumer1(Message msg, Channel channel) {
|
||||||
@@ -158,4 +186,9 @@ public class GenerateConsumer {
|
|||||||
generate(msg, channel, "consumer 9");
|
generate(msg, channel, "consumer 9");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@RabbitListener(queues = MQConfig.GENERATE_RESULT_QUEUE)
|
||||||
|
@RabbitHandler
|
||||||
|
public void getGenerateResult(Message msg, Channel channel){
|
||||||
|
processGenerateResult(msg, channel);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,14 +10,17 @@ public class MQConfig {
|
|||||||
public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange";
|
public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange";
|
||||||
// public static final String GENERATE_QUEUE = "generate-queue-prod";
|
// public static final String GENERATE_QUEUE = "generate-queue-prod";
|
||||||
// public static final String GENERATE_QUEUE = "generate-queue-test";
|
// public static final String GENERATE_QUEUE = "generate-queue-test";
|
||||||
// public static final String GENERATE_QUEUE = "generate-queue-dev";
|
// public static final String GENERATE_QUEUE = "generate-queue-local";
|
||||||
public static final String GENERATE_QUEUE = "generate-queue-local";
|
public static final String GENERATE_QUEUE = "generate-queue-dev";
|
||||||
|
|
||||||
// public static final String SR_QUEUE = "SR-queue-dev";
|
// public static final String SR_QUEUE = "SR-queue-local";
|
||||||
public static final String SR_QUEUE = "SR-queue-local";
|
public static final String SR_QUEUE = "SR-queue-dev";
|
||||||
|
|
||||||
public static final String SR_RESULT_QUEUE = "SuperResolution-local";
|
// public static final String SR_RESULT_QUEUE = "SuperResolution-local";
|
||||||
// public static final String SR_RESULT_QUEUE = "SuperResolution-dev";
|
public static final String SR_RESULT_QUEUE = "SuperResolution-dev";
|
||||||
|
|
||||||
|
// public static final String GENERATE_RESULT_QUEUE = "GenerateImage-local";
|
||||||
|
public static final String GENERATE_RESULT_QUEUE = "GenerateImage-dev";
|
||||||
|
|
||||||
public MQConfig() {
|
public MQConfig() {
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,4 +7,7 @@ public class CommonConstant {
|
|||||||
public static final Long CREDITS_EXPIRE_TIME = 2 * 24 * 60 * 60L;
|
public static final Long CREDITS_EXPIRE_TIME = 2 * 24 * 60 * 60L;
|
||||||
// 单位 分钟
|
// 单位 分钟
|
||||||
public static final Integer MINIO_IMAGE_EXPIRE_TIME = 24 * 60;
|
public static final Integer MINIO_IMAGE_EXPIRE_TIME = 24 * 60;
|
||||||
|
// 单位 秒 一天过期 in redis
|
||||||
|
public static final Long GENERATE_RESULT_EXPIRE_TIME = 24 * 60 * 60L;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,26 +12,33 @@ public enum GenerateModeEnum {
|
|||||||
/**
|
/**
|
||||||
* 通过文本生成
|
* 通过文本生成
|
||||||
*/
|
*/
|
||||||
TEXT(1, "text"),
|
TEXT(1, "text","txt2img"),
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 通过图片生成
|
* 通过图片生成
|
||||||
*/
|
*/
|
||||||
IMAGE(2, "image"),
|
IMAGE(2, "image", "img2img"),
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 通过文本和图片生成
|
* 通过文本和图片生成
|
||||||
*/
|
*/
|
||||||
TEXT_IMAGE(2, "text-image");
|
TEXT_IMAGE(2, "text-image","txt2img");
|
||||||
|
|
||||||
private Integer code;
|
private Integer code;
|
||||||
private String value;
|
private String value;
|
||||||
|
private String type;
|
||||||
|
|
||||||
GenerateModeEnum(int code, String value) {
|
GenerateModeEnum(int code, String value) {
|
||||||
this.code = code;
|
this.code = code;
|
||||||
this.value = value;
|
this.value = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GenerateModeEnum(Integer code, String value, String type) {
|
||||||
|
this.code = code;
|
||||||
|
this.value = value;
|
||||||
|
this.type = type;
|
||||||
|
}
|
||||||
|
|
||||||
public static List<String> getGenerateModeList(){
|
public static List<String> getGenerateModeList(){
|
||||||
return Stream.of(TEXT,IMAGE,TEXT_IMAGE).map(GenerateModeEnum::getValue).collect(Collectors.toList());
|
return Stream.of(TEXT,IMAGE,TEXT_IMAGE).map(GenerateModeEnum::getValue).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ public class AsyncCallerUtil {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public CompletableFuture<List<String>> callGenerateAsync(GenerateToPythonDTO generateToPython) {
|
public CompletableFuture<List<String>> callGenerateAsync(GenerateToPythonDTO generateToPython) {
|
||||||
return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython));
|
// return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython));
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<String> generate(GenerateToPythonDTO generateToPython) {
|
public List<String> generate(GenerateToPythonDTO generateToPython) {
|
||||||
|
|||||||
@@ -3,10 +3,7 @@ package com.ai.da.controller;
|
|||||||
import com.ai.da.common.response.Response;
|
import com.ai.da.common.response.Response;
|
||||||
import com.ai.da.model.dto.GenerateLikeDTO;
|
import com.ai.da.model.dto.GenerateLikeDTO;
|
||||||
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
||||||
import com.ai.da.model.vo.GenerateCaptionVO;
|
import com.ai.da.model.vo.*;
|
||||||
import com.ai.da.model.vo.GenerateCollectionVO;
|
|
||||||
import com.ai.da.model.vo.GenerateLikeVO;
|
|
||||||
import com.ai.da.model.vo.PrepareForGenerateVO;
|
|
||||||
import com.ai.da.service.GenerateService;
|
import com.ai.da.service.GenerateService;
|
||||||
import io.swagger.annotations.Api;
|
import io.swagger.annotations.Api;
|
||||||
import io.swagger.annotations.ApiOperation;
|
import io.swagger.annotations.ApiOperation;
|
||||||
@@ -16,6 +13,7 @@ import org.springframework.web.bind.annotation.*;
|
|||||||
|
|
||||||
import javax.annotation.Resource;
|
import javax.annotation.Resource;
|
||||||
import javax.validation.Valid;
|
import javax.validation.Valid;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author XP
|
* @author XP
|
||||||
@@ -37,8 +35,9 @@ public class GenerateController {
|
|||||||
|
|
||||||
@ApiOperation("通过文字、图片生成图片")
|
@ApiOperation("通过文字、图片生成图片")
|
||||||
@PostMapping("/sketchAndPrint")
|
@PostMapping("/sketchAndPrint")
|
||||||
public Response<GenerateCollectionVO> generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
public void generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||||
return Response.success(generateService.generateThroughImageText(generateThroughImageTextDTO));
|
// return Response.success(generateService.generateThroughImageText(generateThroughImageTextDTO));
|
||||||
|
generateService.generateThroughImageText(generateThroughImageTextDTO);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ApiOperation("喜欢生成的图片")
|
@ApiOperation("喜欢生成的图片")
|
||||||
@@ -56,7 +55,7 @@ public class GenerateController {
|
|||||||
|
|
||||||
@ApiOperation(value = "发起生成请求,异步获取结果")
|
@ApiOperation(value = "发起生成请求,异步获取结果")
|
||||||
@PostMapping("/prepare")
|
@PostMapping("/prepare")
|
||||||
public Response<PrepareForGenerateVO> prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
public Response<List<String>> prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||||
return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO));
|
return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,10 +68,19 @@ public class GenerateController {
|
|||||||
return Response.success("stop waiting successfully");
|
return Response.success("stop waiting successfully");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ApiOperation(value = "获取生成结果")
|
/*@ApiOperation(value = "获取生成结果")
|
||||||
@GetMapping("/result")
|
@GetMapping("/result")
|
||||||
public Response<GenerateCollectionVO> getGenerateResult(@RequestParam("uniqueId") String uniqueId) {
|
public Response<GenerateCollectionVO> getGenerateResult(@RequestParam("uniqueId") String uniqueId) {
|
||||||
GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId);
|
GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId);
|
||||||
return Response.success(generateResult);
|
return Response.success(generateResult);
|
||||||
|
}*/
|
||||||
|
|
||||||
|
@ApiOperation(value = "获取生成结果")
|
||||||
|
@PostMapping("/result")
|
||||||
|
public Response<List<GenerateResultVO>> getGenerateResults(@Valid @RequestBody List<String> taskIdList) {
|
||||||
|
List<GenerateResultVO> generateResult = generateService.getGenerateResultList(taskIdList);
|
||||||
|
return Response.success(generateResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import lombok.Data;
|
|||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.experimental.Accessors;
|
import lombok.experimental.Accessors;
|
||||||
|
|
||||||
|
import java.time.LocalDateTime;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -50,7 +51,7 @@ public class GenerateDetail {
|
|||||||
/**
|
/**
|
||||||
* 创建时间
|
* 创建时间
|
||||||
*/
|
*/
|
||||||
private Date createDate;
|
private LocalDateTime createDate;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 更新时间
|
* 更新时间
|
||||||
|
|||||||
@@ -6,22 +6,31 @@ import lombok.NoArgsConstructor;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
//@AllArgsConstructor
|
||||||
public class GenerateToPythonDTO {
|
public class GenerateToPythonDTO {
|
||||||
|
// 去掉
|
||||||
private Long user_id;
|
// private Long user_id;
|
||||||
|
|
||||||
private String image_url;
|
private String image_url;
|
||||||
|
|
||||||
private String category;
|
private String category;
|
||||||
|
// 改为prompt
|
||||||
|
// private String content;
|
||||||
|
private String prompt;
|
||||||
|
|
||||||
private String content;
|
private String mode;
|
||||||
|
// 去除
|
||||||
private Integer mode;
|
// private String version;
|
||||||
|
// 去掉
|
||||||
private String version;
|
// private String gender;
|
||||||
|
// taskId的最后拼接用户id
|
||||||
private String gender;
|
|
||||||
|
|
||||||
private String tasks_id;
|
private String tasks_id;
|
||||||
|
|
||||||
|
public GenerateToPythonDTO(String tasks_id, String prompt, String image_url, String mode, String category) {
|
||||||
|
this.image_url = image_url;
|
||||||
|
this.category = category;
|
||||||
|
this.prompt = prompt;
|
||||||
|
this.mode = mode;
|
||||||
|
this.tasks_id = tasks_id;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2880,7 +2880,7 @@ public class PythonService {
|
|||||||
throw new BusinessException("system error!");
|
throw new BusinessException("system error!");
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<String> generateSketchOrPrint(GenerateToPythonDTO generateToPythonDTO) {
|
public Boolean generateSketchOrPrint(GenerateToPythonDTO generateToPythonDTO) {
|
||||||
//限流校验
|
//限流校验
|
||||||
// AccessLimitUtils.validate("generateSketchOrPrint", 5);
|
// AccessLimitUtils.validate("generateSketchOrPrint", 5);
|
||||||
OkHttpClient client = new OkHttpClient().newBuilder()
|
OkHttpClient client = new OkHttpClient().newBuilder()
|
||||||
@@ -2895,7 +2895,8 @@ public class PythonService {
|
|||||||
// .url("http://18.167.251.121:9992")
|
// .url("http://18.167.251.121:9992")
|
||||||
// .url("http://127.0.0.1:5000/api/diffusion")
|
// .url("http://127.0.0.1:5000/api/diffusion")
|
||||||
// .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion")
|
// .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion")
|
||||||
.url(accessPythonIp + ":" + accessPythonPort + "/api/generate_image")
|
// .url(accessPythonIp + ":" + accessPythonPort + "/api/generate_image")
|
||||||
|
.url(srPythonPort + "/api/generate_image")
|
||||||
.method("POST", body)
|
.method("POST", body)
|
||||||
// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
|
// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
|
||||||
.addHeader("Content-Type", "application/json")
|
.addHeader("Content-Type", "application/json")
|
||||||
@@ -2936,12 +2937,13 @@ public class PythonService {
|
|||||||
|
|
||||||
if (result && jsonObject.get("code").equals(200)) {
|
if (result && jsonObject.get("code").equals(200)) {
|
||||||
log.info("Generate##responseObject###{}", jsonObject);
|
log.info("Generate##responseObject###{}", jsonObject);
|
||||||
return setGenerateImageList(jsonObject.getJSONObject("data"));
|
// return setGenerateImageList(jsonObject.getJSONObject("data"));
|
||||||
}
|
return Boolean.TRUE;
|
||||||
|
}else {
|
||||||
log.info("generateSketchOrPrintPrint失败###{}", jsonObject);
|
log.info("generateSketchOrPrintPrint失败###{}", jsonObject);
|
||||||
log.info("Generate Exception! Code : " + jsonObject.get("code"));
|
log.info("Generate Exception! Code : " + jsonObject.get("code"));
|
||||||
//生成失败
|
return Boolean.FALSE;
|
||||||
throw new BusinessException("generate.interface.error");
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Response sendPostToModel(String content, String portAndRoute, String functionName) {
|
public Response sendPostToModel(String content, String portAndRoute, String functionName) {
|
||||||
|
|||||||
@@ -4,10 +4,7 @@ import com.ai.da.mapper.primary.entity.Generate;
|
|||||||
import com.ai.da.mapper.primary.entity.GenerateDetail;
|
import com.ai.da.mapper.primary.entity.GenerateDetail;
|
||||||
import com.ai.da.model.dto.GenerateLikeDTO;
|
import com.ai.da.model.dto.GenerateLikeDTO;
|
||||||
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
||||||
import com.ai.da.model.vo.GenerateCaptionVO;
|
import com.ai.da.model.vo.*;
|
||||||
import com.ai.da.model.vo.GenerateCollectionVO;
|
|
||||||
import com.ai.da.model.vo.GenerateLikeVO;
|
|
||||||
import com.ai.da.model.vo.PrepareForGenerateVO;
|
|
||||||
import com.baomidou.mybatisplus.extension.service.IService;
|
import com.baomidou.mybatisplus.extension.service.IService;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -16,7 +13,9 @@ public interface GenerateService extends IService<Generate> {
|
|||||||
|
|
||||||
GenerateCaptionVO generateCaption(Long sketchElementId);
|
GenerateCaptionVO generateCaption(Long sketchElementId);
|
||||||
|
|
||||||
GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO);
|
void generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO);
|
||||||
|
|
||||||
|
void processGenerateResult(String taskId, String url);
|
||||||
|
|
||||||
GenerateLikeVO generateLike(GenerateLikeDTO generateLikeDTO);
|
GenerateLikeVO generateLike(GenerateLikeDTO generateLikeDTO);
|
||||||
|
|
||||||
@@ -28,7 +27,9 @@ public interface GenerateService extends IService<Generate> {
|
|||||||
|
|
||||||
GenerateCollectionVO getGenerateResult(String uniqueId);
|
GenerateCollectionVO getGenerateResult(String uniqueId);
|
||||||
|
|
||||||
PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO);
|
List<GenerateResultVO> getGenerateResultList(List<String> taskIdList);
|
||||||
|
|
||||||
|
List<String> prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO);
|
||||||
|
|
||||||
Long getRankPosition(String uniqueId);
|
Long getRankPosition(String uniqueId);
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ import javax.annotation.Resource;
|
|||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.math.BigDecimal;
|
import java.math.BigDecimal;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -834,6 +835,8 @@ public class CollectionElementServiceImpl extends ServiceImpl<CollectionElementM
|
|||||||
} else {
|
} else {
|
||||||
throw new BusinessException("element source type cannot be empty!");
|
throw new BusinessException("element source type cannot be empty!");
|
||||||
}
|
}
|
||||||
|
}else {
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
return collectionElement;
|
return collectionElement;
|
||||||
}
|
}
|
||||||
@@ -867,7 +870,7 @@ public class CollectionElementServiceImpl extends ServiceImpl<CollectionElementM
|
|||||||
generateDetail.setLibraryId(libraryIds.get(0).get("library_id"));
|
generateDetail.setLibraryId(libraryIds.get(0).get("library_id"));
|
||||||
}
|
}
|
||||||
generateDetail.setMd5(md5);
|
generateDetail.setMd5(md5);
|
||||||
generateDetail.setCreateDate(DateUtil.getByTimeZone(timeZone));
|
generateDetail.setCreateDate(LocalDateTime.now());
|
||||||
|
|
||||||
return generateDetail;
|
return generateDetail;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.ai.da.service.impl;
|
package com.ai.da.service.impl;
|
||||||
|
|
||||||
import com.ai.da.common.config.exception.BusinessException;
|
import com.ai.da.common.config.exception.BusinessException;
|
||||||
|
import com.ai.da.common.constant.CommonConstant;
|
||||||
import com.ai.da.common.context.UserContext;
|
import com.ai.da.common.context.UserContext;
|
||||||
import com.ai.da.common.enums.GenerateModeEnum;
|
import com.ai.da.common.enums.GenerateModeEnum;
|
||||||
import com.ai.da.common.enums.ModelNameEnum;
|
import com.ai.da.common.enums.ModelNameEnum;
|
||||||
@@ -23,6 +24,7 @@ import com.alibaba.fastjson.JSON;
|
|||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
|
import com.google.gson.Gson;
|
||||||
import io.minio.errors.MinioException;
|
import io.minio.errors.MinioException;
|
||||||
import io.netty.util.internal.StringUtil;
|
import io.netty.util.internal.StringUtil;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -33,6 +35,7 @@ import org.springframework.util.CollectionUtils;
|
|||||||
|
|
||||||
import javax.annotation.Resource;
|
import javax.annotation.Resource;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*;
|
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*;
|
||||||
@@ -80,6 +83,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
@Value("${redis.key.resultMap}")
|
@Value("${redis.key.resultMap}")
|
||||||
private String resultMapKey;
|
private String resultMapKey;
|
||||||
|
|
||||||
|
@Value("${redis.key.generateResult}")
|
||||||
|
private String generateResultKey;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public GenerateCaptionVO generateCaption(Long sketchElementId) {
|
public GenerateCaptionVO generateCaption(Long sketchElementId) {
|
||||||
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
|
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
|
||||||
@@ -95,12 +101,12 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
@Transactional(rollbackFor = Exception.class)
|
@Transactional(rollbackFor = Exception.class)
|
||||||
public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
public void generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||||
// 1、获取用户信息
|
// 1、获取用户信息
|
||||||
Long accountId = generateThroughImageTextDTO.getUserId();
|
Long accountId = generateThroughImageTextDTO.getUserId();
|
||||||
String generateType = generateThroughImageTextDTO.getGenerateType();
|
String generateType = generateThroughImageTextDTO.getGenerateType();
|
||||||
|
|
||||||
// 2、判断必须入参是否为非空
|
// 2、判断必须入参是否为非空(在prepare阶段已校验)
|
||||||
Generate generate = new Generate();
|
Generate generate = new Generate();
|
||||||
generate.setAccountId(accountId);
|
generate.setAccountId(accountId);
|
||||||
generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
|
generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
|
||||||
@@ -121,27 +127,38 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType());
|
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType());
|
||||||
|
|
||||||
// 3、向模型发起请求
|
// 3、向模型发起请求
|
||||||
int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
|
String mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
|
||||||
GenerateModeEnum.TEXT.getCode() :
|
GenerateModeEnum.TEXT.getType() :
|
||||||
GenerateModeEnum.TEXT_IMAGE.getCode();
|
GenerateModeEnum.TEXT_IMAGE.getType();
|
||||||
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
|
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
|
||||||
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
|
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
|
||||||
AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
|
// AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
|
||||||
List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
|
// List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
|
||||||
category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
|
// category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
|
||||||
// List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
|
Boolean requestResult = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text,Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
|
||||||
// category, text, mode, "1", generateThroughImageTextDTO.getGender()));
|
mode, category));
|
||||||
log.info("generate 响应 : " + generatedSketchUrl);
|
// log.info("generate 响应 : " + generatedSketchUrl);
|
||||||
if (CollectionUtils.isEmpty(generatedSketchUrl)) {
|
// if (CollectionUtils.isEmpty(generatedSketchUrl)) {
|
||||||
return null;
|
// return null;
|
||||||
}
|
// }
|
||||||
|
|
||||||
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
|
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
|
||||||
save(generate);
|
save(generate);
|
||||||
|
|
||||||
|
// 5、将本次请求存入redis
|
||||||
|
String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId();
|
||||||
|
String status;
|
||||||
|
if (requestResult){
|
||||||
|
status = "Executing";
|
||||||
|
}else {
|
||||||
|
status = "Fail";
|
||||||
|
}
|
||||||
|
GenerateResultVO generateResultVO = new GenerateResultVO(null, null, status);
|
||||||
|
redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
|
||||||
|
|
||||||
// 5、处理模型返回的数据
|
// 5、处理模型返回的数据
|
||||||
// 5.1 将相应的url保存到数据库
|
// 5.1 将相应的url保存到数据库
|
||||||
List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
|
/*List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
|
||||||
generatedSketchUrl.forEach(item -> {
|
generatedSketchUrl.forEach(item -> {
|
||||||
GenerateDetail generateDetail = new GenerateDetail();
|
GenerateDetail generateDetail = new GenerateDetail();
|
||||||
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
|
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
|
||||||
@@ -166,7 +183,35 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
|
|
||||||
// 6、将模型返回的图片地址返回给前端
|
// 6、将模型返回的图片地址返回给前端
|
||||||
Long collectionId = Objects.isNull(collectionElement) ? null : collectionElement.getCollectionId();
|
Long collectionId = Objects.isNull(collectionElement) ? null : collectionElement.getCollectionId();
|
||||||
return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems);
|
return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems);*/
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
@Transactional(rollbackFor = Exception.class)
|
||||||
|
public void processGenerateResult(String taskId, String url){
|
||||||
|
// 5、处理模型返回的数据
|
||||||
|
// 5.1 将相应的url保存到数据库
|
||||||
|
GenerateDetail generateDetail = new GenerateDetail();
|
||||||
|
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
|
||||||
|
Generate generate = selectByUniqueId(taskId);
|
||||||
|
String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(url, 24 * 60), Boolean.FALSE);
|
||||||
|
// 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过
|
||||||
|
List<Map<String, Long>> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generate.getLevel1Type());
|
||||||
|
if (!libraryIdList.isEmpty()) {
|
||||||
|
generateDetail.setIsLike((byte) 1);
|
||||||
|
generateDetail.setLibraryId(libraryIdList.get(0).get("library_id"));
|
||||||
|
generateCollectionItemVO.setIsLiked(Boolean.TRUE);
|
||||||
|
}
|
||||||
|
generateDetail.setUrl(url);
|
||||||
|
generateDetail.setGenerateId(generate.getId());
|
||||||
|
generateDetail.setCreateDate(LocalDateTime.now());
|
||||||
|
generateDetail.setMd5(md5);
|
||||||
|
generateDetailMapper.insert(generateDetail);
|
||||||
|
|
||||||
|
String key = generateResultKey + ":" + taskId;
|
||||||
|
Long expire = redisUtil.getExpire(key);
|
||||||
|
GenerateResultVO generateResultVO = new GenerateResultVO(generateDetail.getId(), url, "Success");
|
||||||
|
redisUtil.addToString(key, new Gson().toJson(generateResultVO), expire);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) {
|
private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) {
|
||||||
@@ -315,7 +360,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
// public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||||
|
public List<String> prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||||
// 1、参数检查,判断必须参数是否为空
|
// 1、参数检查,判断必须参数是否为空
|
||||||
if (Objects.isNull(generateThroughImageTextDTO.getUserId())) {
|
if (Objects.isNull(generateThroughImageTextDTO.getUserId())) {
|
||||||
throw new BusinessException("userId cannot be empty");
|
throw new BusinessException("userId cannot be empty");
|
||||||
@@ -330,7 +376,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
if (generateThroughImageTextDTO.getIsTestUser()){
|
if (generateThroughImageTextDTO.getIsTestUser()){
|
||||||
trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type());
|
trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type());
|
||||||
if (trialsCount >= 2){
|
if (trialsCount >= 2){
|
||||||
return new PrepareForGenerateVO(0);
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -341,9 +387,6 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
// 2、生成唯一id 使用uuid
|
// 2、生成唯一id 使用uuid
|
||||||
String uuid = UUID.randomUUID().toString();
|
String uuid = UUID.randomUUID().toString();
|
||||||
|
|
||||||
// SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
|
|
||||||
// long snowflakeId = idWorker.nextId();
|
|
||||||
|
|
||||||
int num = 1;
|
int num = 1;
|
||||||
// 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id
|
// 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id
|
||||||
while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) ||
|
while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) ||
|
||||||
@@ -361,18 +404,25 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
}
|
}
|
||||||
uuid = UUID.randomUUID().toString();
|
uuid = UUID.randomUUID().toString();
|
||||||
}
|
}
|
||||||
generateThroughImageTextDTO.setUniqueId(uuid);
|
|
||||||
|
ArrayList<String> taskIdList = new ArrayList<>();
|
||||||
|
for (int i = 1 ; i <= 4 ; i++){
|
||||||
|
String temp = uuid;
|
||||||
|
temp += "-" + i + "-" + generateThroughImageTextDTO.getUserId();
|
||||||
|
taskIdList.add(temp);
|
||||||
|
generateThroughImageTextDTO.setUniqueId(temp);
|
||||||
String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
|
String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
|
||||||
|
|
||||||
// 3、加入redis排队,便于获取实时排队信息
|
// 3、加入redis排队,便于获取实时排队信息
|
||||||
Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
|
Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
|
||||||
redisUtil.addToZSet(consumptionOrderKey, uuid, maxScore);
|
redisUtil.addToZSet(consumptionOrderKey, temp, maxScore);
|
||||||
|
|
||||||
// 4、将消息发布到MQ消息队列
|
// 4、将消息发布到MQ消息队列
|
||||||
rabbitMQService.publishMessageToGenerate(jsonString);
|
rabbitMQService.publishMessageToGenerate(jsonString);
|
||||||
|
}
|
||||||
|
|
||||||
// 5、返回唯一id
|
// 5、返回唯一id
|
||||||
return new PrepareForGenerateVO(uuid, 2 - trialsCount);
|
return taskIdList;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -432,6 +482,21 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
|||||||
return new GenerateCollectionVO(generateId, null, generatedCollectionItems);
|
return new GenerateCollectionVO(generateId, null, generatedCollectionItems);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<GenerateResultVO> getGenerateResultList(List<String> taskIdList) {
|
||||||
|
List<GenerateResultVO> results = new ArrayList<>();
|
||||||
|
taskIdList.forEach(taskId -> {
|
||||||
|
String key = generateResultKey + ":" + taskId;
|
||||||
|
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
|
||||||
|
if (!StringUtil.isNullOrEmpty(generateResultVO.getUrl())) {
|
||||||
|
generateResultVO.setUrl(minioUtil.getPresignedUrl(generateResultVO.getUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
|
||||||
|
}
|
||||||
|
results.add(generateResultVO);
|
||||||
|
});
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public Generate selectByUniqueId(String uniqueId) {
|
public Generate selectByUniqueId(String uniqueId) {
|
||||||
QueryWrapper<Generate> qw = new QueryWrapper<>();
|
QueryWrapper<Generate> qw = new QueryWrapper<>();
|
||||||
qw.eq("unique_id", uniqueId);
|
qw.eq("unique_id", uniqueId);
|
||||||
|
|||||||
@@ -77,10 +77,11 @@ spring.redis.lettuce.pool.max-wait=5
|
|||||||
|
|
||||||
redis.key.orderForGenerate=OrderForGenerate
|
redis.key.orderForGenerate=OrderForGenerate
|
||||||
redis.key.generateCancelSet=GenerateCancelSet
|
redis.key.generateCancelSet=GenerateCancelSet
|
||||||
redis.key.generateExceptionMap=GenerateExceptionMap
|
redis.key.generateExceptionMap=Generate:Exception
|
||||||
redis.key.resultMap=ResultMap
|
redis.key.resultMap=ResultMap
|
||||||
redis.key.orderForSR=OrderForSR
|
redis.key.orderForSR=OrderForSR
|
||||||
redis.key.SRCancelSet=SRCancelSet
|
redis.key.SRCancelSet=SRCancelSet
|
||||||
redis.key.SRExceptionMap=SRExceptionMap
|
redis.key.SRExceptionMap=SRExceptionMap
|
||||||
redis.key.taskList=TaskList
|
redis.key.taskList=TaskList
|
||||||
redis.key.credits.pre-deduction=Credits:PreDeduction
|
redis.key.credits.pre-deduction=Credits:PreDeduction
|
||||||
|
redis.key.generateResult=Generate:Result
|
||||||
Reference in New Issue
Block a user