generate模型更换后的接口更改及异步获取结果

This commit is contained in:
2024-04-18 14:07:20 +08:00
parent 8d330e8ad9
commit 896120fea4
13 changed files with 222 additions and 85 deletions

View File

@@ -3,9 +3,10 @@ package com.ai.da.common.RabbitMQ;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.utils.RedisUtil;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.GenerateCollectionVO;
import com.ai.da.model.vo.GenerateResultVO;
import com.ai.da.service.GenerateService;
import com.alibaba.fastjson.JSONObject;
import com.google.gson.Gson;
import com.rabbitmq.client.Channel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.core.Message;
@@ -17,7 +18,7 @@ import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.io.IOException;
import java.util.HashMap;
import java.util.Objects;
import java.util.Map;
@Slf4j
@@ -42,6 +43,9 @@ public class GenerateConsumer {
@Value("${redis.key.resultMap}")
private String resultMapKey;
@Value("${redis.key.generateResult}")
private String generateResultKey;
public void generate(Message msg, Channel channel, String consumerName) {
log.info("============start listening==========");
long start = System.currentTimeMillis();
@@ -63,20 +67,16 @@ public class GenerateConsumer {
// 2.2 将该消息从取消列表中删除
// redisUtil.removeFromSet(cancelSetKey, uniqueId);
} else {
/*try {
Thread.sleep(15000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}*/
GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO);
// GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO);
generateService.generateThroughImageText(generateThroughImageTextDTO);
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
if (!Objects.isNull(generateCollectionVO)) {
/*if (!Objects.isNull(generateCollectionVO)) {
HashMap<String, String> generateResult = new HashMap<>();
generateResult.put(uniqueId, JSONObject.toJSONString(generateCollectionVO));
// 将结果存在redis中 ,为空时不要存
redisUtil.addToMap(resultMapKey, generateResult);
}
}*/
}
} catch (BusinessException e) {
@@ -104,6 +104,34 @@ public class GenerateConsumer {
log.info("=============end listening===========");
}
public void processGenerateResult(Message msg, Channel channel){
log.info("============ProcessGenerateResult listening==========");
long start = System.currentTimeMillis();
Map<String, String> generateResult = JSONObject.parseObject(msg.getBody(), Map.class);
// log.info("tasks_id : {}, message : {}",generateResult.get("tasks_id"), generateResult.get("message") );
if (generateResult.get("status").equals("SUCCESS")){
String url = generateResult.get("data");
String taskId = generateResult.get("tasks_id");
generateService.processGenerateResult(taskId, url);
}else {
// 修改redis中的数据状态为exception
String key = generateResultKey + ":" + generateResult.get("tasks_id");
Long expire = redisUtil.getExpire(key);
redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(null, null, "Fail")), expire);
// 将异常信息存到exception中
HashMap<String, String> exceptionInfo = new HashMap<>();
exceptionInfo.put(generateResult.get("tasks_id"), generateResult.get("data"));
// 存redis
redisUtil.addToMap(exceptionMapKey, exceptionInfo);
}
long end = System.currentTimeMillis();
log.info("tasks_id : {}, message : {}, 执行时长: {} 毫秒",generateResult.get("tasks_id"), generateResult.get("message"), (end - start));
log.info("============ProcessGenerateResult End listening==========");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer1(Message msg, Channel channel) {
@@ -158,4 +186,9 @@ public class GenerateConsumer {
generate(msg, channel, "consumer 9");
}
@RabbitListener(queues = MQConfig.GENERATE_RESULT_QUEUE)
@RabbitHandler
public void getGenerateResult(Message msg, Channel channel){
processGenerateResult(msg, channel);
}
}

View File

@@ -10,14 +10,17 @@ public class MQConfig {
public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange";
// public static final String GENERATE_QUEUE = "generate-queue-prod";
// public static final String GENERATE_QUEUE = "generate-queue-test";
// public static final String GENERATE_QUEUE = "generate-queue-dev";
public static final String GENERATE_QUEUE = "generate-queue-local";
// public static final String GENERATE_QUEUE = "generate-queue-local";
public static final String GENERATE_QUEUE = "generate-queue-dev";
// public static final String SR_QUEUE = "SR-queue-dev";
public static final String SR_QUEUE = "SR-queue-local";
// public static final String SR_QUEUE = "SR-queue-local";
public static final String SR_QUEUE = "SR-queue-dev";
public static final String SR_RESULT_QUEUE = "SuperResolution-local";
// public static final String SR_RESULT_QUEUE = "SuperResolution-dev";
// public static final String SR_RESULT_QUEUE = "SuperResolution-local";
public static final String SR_RESULT_QUEUE = "SuperResolution-dev";
// public static final String GENERATE_RESULT_QUEUE = "GenerateImage-local";
public static final String GENERATE_RESULT_QUEUE = "GenerateImage-dev";
public MQConfig() {
}

View File

@@ -7,4 +7,7 @@ public class CommonConstant {
public static final Long CREDITS_EXPIRE_TIME = 2 * 24 * 60 * 60L;
// 单位 分钟
public static final Integer MINIO_IMAGE_EXPIRE_TIME = 24 * 60;
// 单位 秒 一天过期 in redis
public static final Long GENERATE_RESULT_EXPIRE_TIME = 24 * 60 * 60L;
}

View File

@@ -12,26 +12,33 @@ public enum GenerateModeEnum {
/**
* 通过文本生成
*/
TEXT(1, "text"),
TEXT(1, "text","txt2img"),
/**
* 通过图片生成
*/
IMAGE(2, "image"),
IMAGE(2, "image", "img2img"),
/**
* 通过文本和图片生成
*/
TEXT_IMAGE(2, "text-image");
TEXT_IMAGE(2, "text-image","txt2img");
private Integer code;
private String value;
private String type;
GenerateModeEnum(int code, String value) {
this.code = code;
this.value = value;
}
GenerateModeEnum(Integer code, String value, String type) {
this.code = code;
this.value = value;
this.type = type;
}
public static List<String> getGenerateModeList(){
return Stream.of(TEXT,IMAGE,TEXT_IMAGE).map(GenerateModeEnum::getValue).collect(Collectors.toList());
}

View File

@@ -25,7 +25,8 @@ public class AsyncCallerUtil {
}
public CompletableFuture<List<String>> callGenerateAsync(GenerateToPythonDTO generateToPython) {
return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython));
// return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython));
return null;
}
public List<String> generate(GenerateToPythonDTO generateToPython) {