diff --git a/pom.xml b/pom.xml
index 50c4b965..14913240 100644
--- a/pom.xml
+++ b/pom.xml
@@ -157,6 +157,19 @@
commons-fileupload
1.4
+
+
+
+ org.springframework.boot
+ spring-boot-starter-amqp
+
+
+
+
+ org.apache.commons
+ commons-pool2
+
+
diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java
new file mode 100644
index 00000000..880def93
--- /dev/null
+++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java
@@ -0,0 +1,45 @@
+package com.ai.da.common.RabbitMQ;
+
+import org.springframework.amqp.core.Binding;
+import org.springframework.amqp.core.BindingBuilder;
+import org.springframework.amqp.core.FanoutExchange;
+import org.springframework.amqp.core.Queue;
+import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
+import org.springframework.amqp.rabbit.connection.ConnectionFactory;
+import org.springframework.amqp.rabbit.core.RabbitAdmin;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.beans.factory.annotation.Value;
+
+@Configuration
+public class MQConfig {
+
+ public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange";
+ public static final String GENERATE_QUEUE = "generate-queue";
+
+ public MQConfig() {
+ }
+
+ @Bean
+ FanoutExchange fanoutRasaExchange() {
+ return new FanoutExchange(GENERATE_EXCHANGE_FANOUT);
+ }
+
+ /**
+ * 创建队列,使用工作模式,不用定义交换机
+ */
+ @Bean
+ public Queue queueRasa() {
+ return new Queue(GENERATE_QUEUE);
+ }
+
+ /**
+ * 将队列绑定到交换机上【队列订阅交换机】
+ */
+ @Bean
+ Binding bindingExchangeRasa() {
+ return BindingBuilder.bind(queueRasa()).to(fanoutRasaExchange());
+ }
+
+}
diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java
new file mode 100644
index 00000000..bce55fcc
--- /dev/null
+++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java
@@ -0,0 +1,108 @@
+package com.ai.da.common.RabbitMQ;
+
+import com.ai.da.common.config.exception.BusinessException;
+import com.ai.da.common.utils.RedisUtil;
+import com.ai.da.model.dto.GenerateThroughImageTextDTO;
+import com.ai.da.model.vo.GenerateCollectionVO;
+import com.ai.da.service.GenerateService;
+import com.alibaba.fastjson.JSONObject;
+import com.rabbitmq.client.Channel;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.amqp.core.Message;
+import org.springframework.amqp.rabbit.annotation.RabbitHandler;
+import org.springframework.amqp.rabbit.annotation.RabbitListener;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.Resource;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Objects;
+
+
+@Slf4j
+@Component
+public class MQConsumer {
+
+ @Resource
+ private GenerateService generateService;
+
+ @Resource
+ private RedisUtil redisUtil;
+
+ @Value("${redis.key.consumptionOrder}")
+ private String consumptionOrderKey;
+
+ @Value("${redis.key.cancelSet}")
+ private String cancelSetKey;
+
+ @Value("${redis.key.exceptionMap}")
+ private String exceptionMapKey;
+
+ @Value("${redis.key.resultMap}")
+ private String resultMapKey;
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generate(Message msg, Channel channel) {
+ log.info("============start listening==========");
+
+ GenerateThroughImageTextDTO generateThroughImageTextDTO = JSONObject.parseObject(msg.getBody(), GenerateThroughImageTextDTO.class);
+ Long uniqueId = generateThroughImageTextDTO.getUniqueId();
+ // 1、将消息从redis排队队列中删除
+ redisUtil.removeFromZSet(consumptionOrderKey, String.valueOf(uniqueId));
+ try {
+ // 2、判断当前消息是否在取消列表中
+ Boolean isMember = redisUtil.isElementExistsInSet(cancelSetKey, String.valueOf(uniqueId));
+ if (isMember) {
+ try {
+ // 2.1 手动确认该消息
+ channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
+ } catch (IOException ex) {
+ log.error("手动确认,不返回队列重新消费");
+ }
+ // 2.2 将该消息从取消列表中删除
+ redisUtil.removeFromSet(cancelSetKey, String.valueOf(uniqueId));
+ } else {
+ try {
+ // 模拟耗时
+ Thread.sleep(40000);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO);
+ if (!Objects.isNull(generateCollectionVO)){
+ HashMap generateResult = new HashMap<>();
+ generateResult.put(String.valueOf(uniqueId), JSONObject.toJSONString(generateCollectionVO));
+ // 将结果存在redis中 ,为空时不要存
+ redisUtil.addToMap(resultMapKey, generateResult);
+ }
+
+ }
+ } catch (BusinessException e) {
+ log.error(e.getMessage());
+ // channel.basicNack() 为不确认deliveryTag对应的消息,第二个参数是否应用于多消息,第三个参数是否requeue
+ try {
+ // 第二个参数,是否批量确认消息,当传false时,只确认当前 deliveryTag对应的消息;当传true时,会确认当前及之前所有未确认的消息。
+ channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
+ } catch (IOException exception) {
+ log.error("手动确认,取消返回队列,不再重新消费");
+ }
+ // 将入参和错误信息存入数据库
+ String exceptionMessage = JSONObject.toJSONString(generateThroughImageTextDTO) + " Exception message : " + e.getMessage();
+ HashMap exceptionInfo = new HashMap<>();
+ exceptionInfo.put(String.valueOf(uniqueId), exceptionMessage);
+ // 存redis
+ redisUtil.addToMap(exceptionMapKey, exceptionInfo);
+ }
+
+// log.info(JSONObject.parseObject(msg.getBody(), GenerateThroughImageTextDTO.class).toString());
+// try {
+// Thread.sleep(10000);
+// } catch (InterruptedException e) {
+// throw new RuntimeException(e);
+// }
+ log.info("============end listening==========");
+ }
+
+}
diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java
new file mode 100644
index 00000000..40b7baea
--- /dev/null
+++ b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java
@@ -0,0 +1,39 @@
+package com.ai.da.common.RabbitMQ;
+
+import com.ai.da.service.RabbitMQService;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.amqp.core.AmqpTemplate;
+import org.springframework.amqp.rabbit.core.RabbitAdmin;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.Resource;
+
+@Slf4j
+@Component
+public class MQPublisher {
+
+ private final String url = "http://localhost:15672/api/queues/%2f/generate-queue";
+
+ @Resource
+ private AmqpTemplate amqpTemplate;
+ @Resource
+ private RabbitMQService rabbitMQService;
+
+ public void sendGenerateMessage(String mm) {
+ log.info("send message:" + mm);
+ amqpTemplate.convertAndSend(MQConfig.GENERATE_QUEUE, mm);
+
+ }
+
+ public void getMsgCount() {
+//// AMQP.Queue.DeclareOk declareOk = rabbitTemplate.execute(channel -> channel.queueDeclarePassive(MQConfig.GENERATE_QUEUE));
+//
+// QueueInformation queueInfo = rabbitAdmin.getQueueInfo(MQConfig.GENERATE_QUEUE);
+//// assert queueInfo != null;
+//
+//// System.out.println(declareOk.getMessageCount());
+// System.out.println(queueInfo.getMessageCount());
+//// return declareOk.getMessageCount();
+// return queueInfo.getMessageCount();
+ }
+}
diff --git a/src/main/java/com/ai/da/common/config/RedisConfig.java b/src/main/java/com/ai/da/common/config/RedisConfig.java
new file mode 100644
index 00000000..5f01753c
--- /dev/null
+++ b/src/main/java/com/ai/da/common/config/RedisConfig.java
@@ -0,0 +1,42 @@
+package com.ai.da.common.config;
+
+import com.fasterxml.jackson.annotation.JsonAutoDetect;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import com.fasterxml.jackson.annotation.PropertyAccessor;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.data.redis.connection.RedisConnectionFactory;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
+import org.springframework.data.redis.serializer.StringRedisSerializer;
+
+@Configuration
+public class RedisConfig {
+ @Bean(name = "redisTemplate")
+ public RedisTemplate getRedisTemplate(RedisConnectionFactory factory) {
+ RedisTemplate redisTemplate = new RedisTemplate<>();
+ redisTemplate.setConnectionFactory(factory);
+
+ StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
+
+ redisTemplate.setKeySerializer(stringRedisSerializer); // key的序列化类型
+
+ Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
+ ObjectMapper objectMapper = new ObjectMapper();
+ objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
+ // 方法过期,改为下面代码
+// objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
+ objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance,
+ ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
+ jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
+ jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
+
+ redisTemplate.setValueSerializer(jackson2JsonRedisSerializer); // value的序列化类型
+ redisTemplate.setHashKeySerializer(stringRedisSerializer);
+ redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
+ redisTemplate.afterPropertiesSet();
+ return redisTemplate;
+ }
+}
diff --git a/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java b/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java
new file mode 100644
index 00000000..8f5dac91
--- /dev/null
+++ b/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java
@@ -0,0 +1,73 @@
+package com.ai.da.common.utils;
+
+import com.ai.da.common.config.exception.BusinessException;
+import com.ai.da.model.dto.GenerateToPythonDTO;
+import com.ai.da.python.PythonService;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Component;
+
+import java.util.*;
+import java.util.concurrent.*;
+
+@Slf4j
+@Component
+public class AsyncCallerUtil {
+
+ public static Map waitingStatus = new HashMap<>();
+
+ private static PythonService pythonService;
+
+ @Autowired
+ public void setPythonService(PythonService pythonService) {
+ AsyncCallerUtil.pythonService = pythonService;
+ }
+
+ public CompletableFuture> callGenerateAsync(GenerateToPythonDTO generateToPython) {
+ return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython));
+ }
+
+ public List generate(GenerateToPythonDTO generateToPython, Long requestId) {
+ ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(1);
+ waitingStatus.put(requestId, true);
+ ScheduledFuture> timeoutTask = null;
+
+ try {
+ CompletableFuture> generateResult = callGenerateAsync(generateToPython);
+ // 10秒后第一次确认,之后每隔10秒确认一次用户选择结果
+ timeoutTask = scheduledExecutorService.scheduleAtFixedRate(() -> {
+ // 调用另一个接口获取用户的选择
+ if (!waitingStatus.get(requestId)) {
+ // 如果用户选择取消,则取消对generate的调用,cancel判断是否成功取消
+ generateResult.cancel(true);
+ waitingStatus.remove(requestId);
+ }
+ System.out.println("持续等待...... : " + DateUtil.getByTimeZone("Asia/Shanghai"));
+ }, 10, 10, TimeUnit.SECONDS);
+
+ System.out.println("开始阻塞 : " + DateUtil.getByTimeZone("Asia/Shanghai"));
+ // 阻塞,等待结果
+ List result = generateResult.get();
+ // 取消定时任务
+ timeoutTask.cancel(true);
+
+ // 处理结果
+ System.out.println("generate 响应: " + result);
+ System.out.println("schedule finish time : " + DateUtil.getByTimeZone("Asia/Shanghai"));
+ waitingStatus.remove(requestId);
+ return result;
+ } catch (InterruptedException | ExecutionException | BusinessException e) {
+ // 处理异常
+ log.error("发生错误 : " + e);
+ // 取消定时任务
+ assert timeoutTask != null;
+ timeoutTask.cancel(true);
+ throw new BusinessException("generate.interface.error");
+ } finally {
+ // 关闭线程池
+// executorService.shutdown();
+// scheduledExecutorService.shutdown();
+ }
+ }
+
+}
diff --git a/src/main/java/com/ai/da/common/utils/RedisUtil.java b/src/main/java/com/ai/da/common/utils/RedisUtil.java
new file mode 100644
index 00000000..5121e883
--- /dev/null
+++ b/src/main/java/com/ai/da/common/utils/RedisUtil.java
@@ -0,0 +1,126 @@
+package com.ai.da.common.utils;
+
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.data.redis.core.ZSetOperations;
+import org.springframework.stereotype.Component;
+import org.springframework.util.CollectionUtils;
+
+import javax.annotation.Resource;
+import java.util.Map;
+import java.util.Set;
+
+@Slf4j
+@Component
+public class RedisUtil {
+
+ @Resource
+ private RedisTemplate redisTemplate;
+
+ //- - - - - - - - - - - - - - - - - - - - - ZSet类型 - - - - - - - - - - - - - - - - - - - -
+
+ /**
+ * 向ZSet中添加元素
+ */
+ public void addToZSet(String key, String value, Double score) {
+ redisTemplate.opsForZSet().add(key, value, score);
+ }
+
+ /**
+ * 从ZSet中删除元素
+ */
+ public void removeFromZSet(String key, String value) {
+ redisTemplate.opsForZSet().remove(key, value);
+ }
+
+ /**
+ * 获取指定元素的当前排列顺序
+ */
+ public Long getRank(String key, String value) {
+ return redisTemplate.opsForZSet().rank(key, value);
+ }
+
+ /**
+ * 获取当前ZSet中的最大score
+ */
+ public Double getMaxScore(String key) {
+ Set> set = redisTemplate.opsForZSet().reverseRangeWithScores(key, 0, 0);
+
+ if (!CollectionUtils.isEmpty(set)) {
+ Double score = set.iterator().next().getScore();
+ return score + 1.0;
+ } else {
+ return 1.0;
+ }
+ }
+
+ /**
+ * 判断元素是否存在
+ */
+ public Boolean isElementExistsInZSet(String key, String value) {
+ return redisTemplate.opsForZSet().score(key, value) != null;
+ }
+
+ /**
+ * 获取当前ZSet中数据量的总和
+ */
+ public Long getZSetTotal(String key) {
+ return redisTemplate.opsForZSet().zCard(key);
+ }
+
+ //- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - -
+
+ /**
+ * 将数据放入set缓存
+ */
+ public void addToSet(String key, String value) {
+ redisTemplate.opsForSet().add(key, value);
+ }
+
+ /**
+ * 弹出变量中的元素
+ */
+ public void removeFromSet(String key, String value) {
+ redisTemplate.opsForSet().remove(key, value);
+ }
+
+ /**
+ * 检查给定的元素是否在变量中。
+ */
+ public Boolean isElementExistsInSet(String key, String obj) {
+ return redisTemplate.opsForSet().isMember(key, obj);
+ }
+
+
+ //- - - - - - - - - - - - - - - - - - - - - hash类型 - - - - - - - - - - - - - - - - - - - -
+
+ /**
+ * 加入缓存
+ */
+ public void addToMap(String key, Map map) {
+ redisTemplate.opsForHash().putAll(key, map);
+ }
+
+ /**
+ * 验证指定 key 下 有没有指定的 hashkey
+ */
+ public Boolean isElementExistsInMap(String key, String hashKey) {
+ return redisTemplate.opsForHash().hasKey(key, hashKey);
+ }
+
+ /**
+ * 获取指定key的值string
+ */
+ public String getMapValue(String key1, String key2) {
+ return String.valueOf(redisTemplate.opsForHash().get(key1, key2));
+ }
+
+ /**
+ * 删除指定 hash 的 HashKey
+ *
+ * @return 删除成功的 数量
+ */
+ public Long removeFromMap(String key, String hashKeys) {
+ return redisTemplate.opsForHash().delete(key, hashKeys);
+ }
+}
diff --git a/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java b/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java
new file mode 100644
index 00000000..03298fb4
--- /dev/null
+++ b/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java
@@ -0,0 +1,137 @@
+package com.ai.da.common.utils;
+
+
+public class SnowflakeUtil {
+
+ // ==============================Fields===========================================
+ /** 开始时间截 (2015-01-01) */
+ private final long twepoch = 1420041600000L;
+
+ /** 机器id所占的位数 */
+ private final long workerIdBits = 5L;
+
+ /** 数据标识id所占的位数 */
+ private final long datacenterIdBits = 5L;
+
+ /** 支持的最大机器id,结果是31 (这个移位算法可以很快的计算出几位二进制数所能表示的最大十进制数) */
+ private final long maxWorkerId = -1L ^ (-1L << workerIdBits);
+
+ /** 支持的最大数据标识id,结果是31 */
+ private final long maxDatacenterId = -1L ^ (-1L << datacenterIdBits);
+
+ /** 序列在id中占的位数 */
+ private final long sequenceBits = 12L;
+
+ /** 机器ID向左移12位 */
+ private final long workerIdShift = sequenceBits;
+
+ /** 数据标识id向左移17位(12+5) */
+ private final long datacenterIdShift = sequenceBits + workerIdBits;
+
+ /** 时间截向左移22位(5+5+12) */
+ private final long timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits;
+
+ /** 生成序列的掩码,这里为4095 (0b111111111111=0xfff=4095) */
+ private final long sequenceMask = -1L ^ (-1L << sequenceBits);
+
+ /** 工作机器ID(0~31) */
+ private long workerId;
+
+ /** 数据中心ID(0~31) */
+ private long datacenterId;
+
+ /** 毫秒内序列(0~4095) */
+ private long sequence = 0L;
+
+ /** 上次生成ID的时间截 */
+ private long lastTimestamp = -1L;
+
+ //==============================Constructors=====================================
+ /**
+ * 构造函数
+ * @param workerId 工作ID (0~31)
+ * @param datacenterId 数据中心ID (0~31)
+ */
+ public SnowflakeUtil(long workerId, long datacenterId) {
+ if (workerId > maxWorkerId || workerId < 0) {
+ throw new IllegalArgumentException(String.format
+ ("worker Id can't be greater than %d or less than 0", maxWorkerId));
+ }
+ if (datacenterId > maxDatacenterId || datacenterId < 0) {
+ throw new IllegalArgumentException(String.format
+ ("datacenter Id can't be greater than %d or less than 0", maxDatacenterId));
+ }
+ this.workerId = workerId;
+ this.datacenterId = datacenterId;
+ }
+
+ // ==============================Methods==========================================
+ /**
+ * 获得下一个ID (该方法是线程安全的)
+ * @return SnowflakeId
+ */
+ public synchronized long nextId() {
+ long timestamp = timeGen();
+
+ //如果当前时间小于上一次ID生成的时间戳,说明系统时钟回退过这个时候应当抛出异常
+ if (timestamp < lastTimestamp) {
+ throw new RuntimeException(
+ String.format
+ ("Clock moved backwards. Refusing to generate id for %d milliseconds", lastTimestamp - timestamp));
+ }
+
+ //如果是同一时间生成的,则进行毫秒内序列
+ if (lastTimestamp == timestamp) {
+ sequence = (sequence + 1) & sequenceMask;
+ //毫秒内序列溢出
+ if (sequence == 0) {
+ //阻塞到下一个毫秒,获得新的时间戳
+ timestamp = tilNextMillis(lastTimestamp);
+ }
+ }
+ //时间戳改变,毫秒内序列重置
+ else {
+ sequence = 0L;
+ }
+
+ //上次生成ID的时间截
+ lastTimestamp = timestamp;
+
+ //移位并通过或运算拼到一起组成64位的ID
+ return ((timestamp - twepoch) << timestampLeftShift) //
+ | (datacenterId << datacenterIdShift) //
+ | (workerId << workerIdShift) //
+ | sequence;
+ }
+
+ /**
+ * 阻塞到下一个毫秒,直到获得新的时间戳
+ * @param lastTimestamp 上次生成ID的时间截
+ * @return 当前时间戳
+ */
+ protected long tilNextMillis(long lastTimestamp) {
+ long timestamp = timeGen();
+ while (timestamp <= lastTimestamp) {
+ timestamp = timeGen();
+ }
+ return timestamp;
+ }
+
+ /**
+ * 返回以毫秒为单位的当前时间
+ * @return 当前时间(毫秒)
+ */
+ protected long timeGen() {
+ return System.currentTimeMillis();
+ }
+
+ //==============================Test=============================================
+ /** 测试 */
+ public static void main(String[] args) {
+ SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
+ long id = idWorker.nextId();
+ System.out.println("id:"+id);
+ //id:768842202204864512
+ }
+}
+
diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java
index 5fa712cc..27a20693 100644
--- a/src/main/java/com/ai/da/controller/GenerateController.java
+++ b/src/main/java/com/ai/da/controller/GenerateController.java
@@ -53,4 +53,22 @@ public class GenerateController {
return Response.success(generateService.generateDislike(generateDetailId, timeZone));
}
+ @PostMapping("/prepare")
+ public Response prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO){
+ return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO));
+ }
+
+ @ApiOperation(value = "取消继续生成")
+ @PostMapping("/stopWaiting")
+ public Response stopWaiting(@RequestParam("uniqueId") Long uniqueId){
+ generateService.cancelGenerate(uniqueId);
+ return Response.success("stop waiting successfully");
+ }
+
+ @ApiOperation(value = "获取生成结果")
+ @PostMapping("/result")
+ public Response getGenerateResult(@RequestParam("uniqueId") Long uniqueId){
+ GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId);
+ return Response.success(generateResult);
+ }
}
diff --git a/src/main/java/com/ai/da/mapper/entity/Generate.java b/src/main/java/com/ai/da/mapper/entity/Generate.java
index fc1f4e33..f45cb144 100644
--- a/src/main/java/com/ai/da/mapper/entity/Generate.java
+++ b/src/main/java/com/ai/da/mapper/entity/Generate.java
@@ -27,6 +27,11 @@ public class Generate {
*/
private Long accountId;
+ /**
+ * 唯一id,用于保持消息的唯一性
+ */
+ private Long uniqueId;
+
/**
* Sketchboard Printboard
*/
diff --git a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java
index ca12d04a..df1c9ab9 100644
--- a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java
+++ b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java
@@ -5,10 +5,14 @@ import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.NotBlank;
+import javax.validation.constraints.NotNull;
@Data
@ApiModel("GenerateThroughImageTextDTO")
public class GenerateThroughImageTextDTO {
+ @NotNull(message = "userId cannot be empty")
+ @ApiModelProperty("用户id")
+ Long userId;
@ApiModelProperty("caption")
String text;
@@ -40,4 +44,7 @@ public class GenerateThroughImageTextDTO {
@NotBlank(message = "timeZone cannot be empty!")
@ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取")
String timeZone;
+
+ @ApiModelProperty("唯一id,用于保持消息唯一性")
+ Long uniqueId;
}
diff --git a/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java b/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java
new file mode 100644
index 00000000..29c77243
--- /dev/null
+++ b/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java
@@ -0,0 +1,25 @@
+package com.ai.da.model.dto;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class GenerateToPythonDTO {
+
+ private Long user_id;
+
+ private String image_url;
+
+ private String category;
+
+ private String str;
+
+ private Integer mode;
+
+ private String version;
+
+ private String gender;
+}
diff --git a/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java b/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java
index 433e86c8..aa8fb7c6 100644
--- a/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java
+++ b/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java
@@ -19,6 +19,13 @@ public class GenerateCollectionVO {
@ApiModelProperty("生成的图片信息")
private List generatedCollectionItems;
+ @ApiModelProperty("在当前队列中的排序")
+ private Long rankPosition;
+
+ public GenerateCollectionVO(Long rankPosition) {
+ this.rankPosition = rankPosition;
+ }
+
public GenerateCollectionVO(Long generateId, Long collectionId, List generatedCollectionItems) {
this.generateId = generateId;
this.collectionId = collectionId;
diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java
index 578de9ef..e12eef12 100644
--- a/src/main/java/com/ai/da/python/PythonService.java
+++ b/src/main/java/com/ai/da/python/PythonService.java
@@ -23,7 +23,6 @@ import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
-import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.beans.factory.annotation.Value;
@@ -2238,7 +2237,7 @@ public class PythonService {
throw new BusinessException("system error!");
}
- public List generateSketchOrPrint(Long userId, String url, String category, String text, int mode, String modelName, String gender) {
+ public List generateSketchOrPrint(GenerateToPythonDTO generateToPythonDTO) {
//限流校验
AccessLimitUtils.validate("generateSketchOrPrint", 5);
OkHttpClient client = new OkHttpClient().newBuilder()
@@ -2248,18 +2247,10 @@ public class PythonService {
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
MediaType mediaType = MediaType.parse("application/json");
- Map content = Maps.newHashMap();
- content.put("user_id", userId);
- content.put("image_url", url);
- content.put("category", category);
- content.put("mode", mode);
- content.put("str", text);
- content.put("version", "1");
- content.put("gender", gender);
- RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content, SerializerFeature.WriteMapNullValue));
+ RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue));
Request request = new Request.Builder()
-// .url(accessPythonIp + ":2828/aida/diffusion")
// .url("http://18.167.251.121:9992")
+// .url("http://127.0.0.1:5000/api/diffusion")
.url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion")
.method("POST", body)
.addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
@@ -2268,10 +2259,11 @@ public class PythonService {
Response response = null;
String bodyString ;
try {
- log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(content, SerializerFeature.WriteMapNullValue));
+ log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue));
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("PythonService##generateSketchOrPrint异常###{}", ExceptionUtil.getThrowableList(ioException));
+ throw new BusinessException("generate.interface.error");
}
//去除限流
AccessLimitUtils.validateOut("generateSketchOrPrint");
diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java
index 0d92a07f..9f17e7ea 100644
--- a/src/main/java/com/ai/da/service/GenerateService.java
+++ b/src/main/java/com/ai/da/service/GenerateService.java
@@ -24,4 +24,13 @@ public interface GenerateService extends IService {
void updateLikeStatusBatch(List generateDetailIdList, Byte hasLike, Long libraryId, String timeZone);
List selectBatchByLibraryId(List libraryId);
+
+ GenerateCollectionVO getGenerateResult(Long uniqueId);
+
+ Long prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO);
+
+ Long getRankPosition(Long uniqueId);
+
+ void cancelGenerate(Long uniqueId);
+
}
diff --git a/src/main/java/com/ai/da/service/RabbitMQService.java b/src/main/java/com/ai/da/service/RabbitMQService.java
new file mode 100644
index 00000000..b03a578d
--- /dev/null
+++ b/src/main/java/com/ai/da/service/RabbitMQService.java
@@ -0,0 +1,11 @@
+package com.ai.da.service;
+
+import org.springframework.stereotype.Service;
+
+@Service
+public interface RabbitMQService {
+
+ void publishMessage(String message);
+
+ Integer getMessageCount(String queueUrl);
+}
diff --git a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java
index 1c6307d9..a44defca 100644
--- a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java
+++ b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java
@@ -810,9 +810,11 @@ public class CollectionElementServiceImpl extends ServiceImpl i
@Resource
private MinioUtil minioUtil;
+ @Resource
+ private RabbitMQService rabbitMQService;
+
+ @Resource
+ private RedisUtil redisUtil;
+
+ @Value("${redis.key.consumptionOrder}")
+ private String consumptionOrderKey;
+
+ @Value("${redis.key.cancelSet}")
+ private String cancelSetKey;
+
+ @Value("${redis.key.exceptionMap}")
+ private String exceptionMapKey;
+
+ @Value("${redis.key.resultMap}")
+ private String resultMapKey;
+
@Override
public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
@@ -69,16 +91,13 @@ public class GenerateServiceImpl extends ServiceImpl i
@Transactional(rollbackFor = Exception.class)
public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、获取用户信息
- AuthPrincipalVo userHolder = UserContext.getUserHolder();
+ Long accountId = generateThroughImageTextDTO.getUserId();
String generateType = generateThroughImageTextDTO.getGenerateType();
- Long accountId = userHolder.getId();
- if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){
- throw new BusinessException("unknown.generate.type");
- }
// 2、判断必须入参是否为非空
Generate generate = new Generate();
generate.setAccountId(accountId);
+ generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type());
// 当level1type是sketchboard时,存数据库需要加上当前性别
generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ?
@@ -87,27 +106,30 @@ public class GenerateServiceImpl extends ServiceImpl i
generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion());
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
-
String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(generate, text, elementId,generateType);
- // 3、将请求信息落库
- // 3.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
+ // 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type());
- // 3.2 将本次generate的请求信息添加到t_generate表中
- save(generate);
-
- // 4、向模型发起请求
+ // 3、向模型发起请求
int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
GenerateModeEnum.TEXT.getCode() :
GenerateModeEnum.TEXT_IMAGE.getCode();
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
-// text = !StringUtil.isNullOrEmpty(text) && generateThroughImageTextDTO.getVersion().equals("1") ? "painting style, " + text : text;
- List generatedSketchUrl = pythonService.generateSketchOrPrint(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
- category, text, mode, generateThroughImageTextDTO.getVersion(), generateThroughImageTextDTO.getGender());
+ AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
+ List generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? null : collectionElement.getUrl(),
+ category, text, mode, "1", generateThroughImageTextDTO.getGender()),0L);
+// List generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
+// category, text, mode, "1", generateThroughImageTextDTO.getGender()));
+ if (CollectionUtils.isEmpty(generatedSketchUrl)){
+ return null;
+ }
+
+ // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
+ save(generate);
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
@@ -283,4 +305,122 @@ public class GenerateServiceImpl extends ServiceImpl i
return generateDetailMapper.selectList(qw);
}
+
+ @Override
+ public Long prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
+ // 1、参数检查,判断必须参数是否为空
+ if (Objects.isNull(generateThroughImageTextDTO.getUserId())){
+ throw new BusinessException("userId cannot be empty");
+ }
+ String generateType = generateThroughImageTextDTO.getGenerateType();
+ if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){
+ throw new BusinessException("unknown.generate.type");
+ }
+ String text = generateThroughImageTextDTO.getText();
+ Long elementId = generateThroughImageTextDTO.getCollectionElementId();
+ validateGeneraType(new Generate(), text, elementId,generateType);
+
+ // 2、生成唯一id
+ SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
+ long snowflakeId = idWorker.nextId();
+
+ if (AsyncCallerUtil.waitingStatus.containsKey(snowflakeId)){
+ snowflakeId = idWorker.nextId();
+ }
+ generateThroughImageTextDTO.setUniqueId(snowflakeId);
+ String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
+
+ // 3、加入redis排队,便于获取实时排队信息
+ Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
+ redisUtil.addToZSet(consumptionOrderKey, String.valueOf(snowflakeId),maxScore);
+
+ // 4、将消息发布到MQ消息队列
+ rabbitMQService.publishMessage(jsonString);
+
+ // 5、返回唯一id
+ return snowflakeId;
+ }
+
+ @Override
+ public Long getRankPosition(Long uniqueId) {
+ return redisUtil.getRank(consumptionOrderKey, String.valueOf(uniqueId));
+ }
+
+ @Override
+ public GenerateCollectionVO getGenerateResult(Long uniqueId) {
+ // 1、判断该请求是否已经异常
+ Boolean isMember = redisUtil.isElementExistsInMap(exceptionMapKey, String.valueOf(uniqueId));
+ if (isMember){
+ throw new BusinessException("generate.interface.error");
+ }
+
+ // 2、判断该请求是否还在排队
+ Boolean existsInZSet = redisUtil.isElementExistsInZSet(consumptionOrderKey, String.valueOf(uniqueId));
+ if (existsInZSet){
+ // 排队中,给出当前排序位置
+ return new GenerateCollectionVO(getRankPosition(uniqueId) + 1L);
+ }
+
+ // 3、判断redis中有没有
+ boolean hasHashKey = redisUtil.isElementExistsInMap(resultMapKey, String.valueOf(uniqueId));
+ if (hasHashKey){
+ // 3.1 有直接从redis中拿
+ String resultString = redisUtil.getMapValue(resultMapKey, String.valueOf(uniqueId));
+ return JSONObject.parseObject(resultString,GenerateCollectionVO.class);
+ }
+
+ // 3.2 判断数据库中有没有
+ Generate generate = selectByUniqueId(uniqueId);
+ if (Objects.isNull(generate)){
+ // 3.3 还没执行完,给出当前位置
+ return new GenerateCollectionVO(0L);
+ }
+ Long generateId = generate.getId();
+ QueryWrapper qw = new QueryWrapper<>();
+ qw.eq("generate_id",generateId);
+ List generateDetails = generateDetailMapper.selectList(qw);
+ if (CollectionUtils.isEmpty(generateDetails)){
+ // 会有这种情况吗?存到generate中,但是还没存到generateDetail中
+ return new GenerateCollectionVO(0L);
+ }
+
+ List generatedCollectionItems = new ArrayList<>();
+ generateDetails.forEach(item -> {
+ GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
+ generateCollectionItemVO.setGenerateItemId(item.getId());
+ generateCollectionItemVO.setGenerateItemUrl(minioUtil.getPresignedUrl(item.getUrl(), 24 * 60));
+ generatedCollectionItems.add(generateCollectionItemVO);
+ });
+
+ return new GenerateCollectionVO(generateId, null, generatedCollectionItems);
+ }
+
+ public Generate selectByUniqueId(Long uniqueId){
+ QueryWrapper qw = new QueryWrapper<>();
+ qw.eq("unique_id",uniqueId);
+
+ return getOne(qw);
+ }
+
+ @Override
+ public void cancelGenerate(Long uniqueId) {
+ // 1、确认当前消息是否还在排队中
+ Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, String.valueOf(uniqueId));
+ if (exists){
+ // 1.1、将需要取消的唯一id加入redis,以便及时取消生成
+ redisUtil.addToSet(cancelSetKey, String.valueOf(uniqueId));
+ // 1.2 将需要取消的id从redis的ConsumptionOrder中删除
+ redisUtil.removeFromZSet(consumptionOrderKey, String.valueOf(uniqueId));
+ }else {
+ // 2、判断该消息是否异常
+ boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, String.valueOf(uniqueId));
+ // 3、判断该消息是否已经消费结束
+ Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, String.valueOf(uniqueId));
+ if (!hasKey && !existsInResult){
+ // 设置取等待状态为false
+ AsyncCallerUtil.waitingStatus.put(uniqueId,false);
+ // 3、直接发送取消请求到python端
+ }
+ }
+ }
}
diff --git a/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java b/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java
new file mode 100644
index 00000000..41743964
--- /dev/null
+++ b/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java
@@ -0,0 +1,76 @@
+package com.ai.da.service.impl;
+
+import cn.hutool.core.exceptions.ExceptionUtil;
+import com.ai.da.common.RabbitMQ.MQPublisher;
+import com.ai.da.common.config.exception.BusinessException;
+import com.ai.da.service.RabbitMQService;
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONObject;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.MediaType;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.Response;
+import org.springframework.stereotype.Service;
+
+import javax.annotation.Resource;
+import java.io.IOException;
+import java.net.HttpURLConnection;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+
+@Slf4j
+@Service
+public class RabbitMQServiceImpl implements RabbitMQService {
+
+ @Resource
+ private MQPublisher mqPublisher;
+
+ @Override
+ public void publishMessage(String message) {
+ mqPublisher.sendGenerateMessage(message);
+ }
+
+ @Override
+ public Integer getMessageCount(String queueUrl) {
+
+ String url = "http://localhost:15672/api/queues/%2f/generate-queue";
+
+ OkHttpClient client = new OkHttpClient().newBuilder()
+ .connectTimeout(30, TimeUnit.SECONDS)
+ .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
+ .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
+ .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
+ .build();
+ Request request = new Request.Builder()
+ .url(queueUrl)
+ .method("GET",null)
+ .addHeader("Authorization", "Basic Z3Vlc3Q6Z3Vlc3Q=")
+ .build();
+ Response response = null;
+ try {
+ response = client.newCall(request).execute();
+ } catch (IOException ioException) {
+ log.error("RabbitMQService##" + "getMessage异常###{}", ExceptionUtil.getThrowableList(ioException));
+ }
+
+ String bodyString;
+ // 生成失败
+ if (Objects.isNull(response) || Objects.isNull(response.body())) {
+ log.error("RabbitMQService##getMessageCount异常###{}", "response or body is empty!");
+ throw new BusinessException("compose-layer.interface.exception");
+ }else if (response.code() != HttpURLConnection.HTTP_OK){
+ log.error("RabbitMQService##getMessageCount异常###{}", "Response error!Response code ## " + response.code() + " ##");
+ throw new BusinessException("compose-layer.interface.exception");
+ } else {
+ try {
+ bodyString = response.body().string();
+ } catch (IOException e) {
+ throw new BusinessException("compose-layer.interface.exception");
+ }
+ }
+ JSONObject jsonObject = JSON.parseObject(bodyString);
+ String messageCount = jsonObject.get("messages").toString();
+ return Integer.parseInt(messageCount);
+ }
+}
diff --git a/src/main/resources/application-test.properties b/src/main/resources/application-test.properties
index 9027b77f..da9009b5 100644
--- a/src/main/resources/application-test.properties
+++ b/src/main/resources/application-test.properties
@@ -62,6 +62,24 @@ minio.bucketName.results=aida-results
minio.bucketName.sysImage=aida-sys-image
minio.bucketName.users=aida-users
minio.bucketName.collectionElement=aida-collection-element
+redirect_url=http://18.167.251.121:7788
+spring.rabbitmq.host=18.167.251.121
+spring.rabbitmq.port=5672
+spring.rabbitmq.username=rabbit
+spring.rabbitmq.password=123456
+spring.rabbitmq.virtual-host=/
+spring.redis.host=127.0.0.1
+spring.redis.port=6379
+spring.redis.database=1
+spring.redis.lettuce.pool.max-active=8
+spring.redis.lettuce.pool.max-idle=8
+spring.redis.lettuce.pool.min-idle=0
+spring.redis.lettuce.pool.max-wait=5
+
+redis.key.consumptionOrder=ConsumptionOrder
+redis.key.cancelSet=CancelSet
+redis.key.exceptionMap=ExceptionMap
+redis.key.resultMap=ResultMap
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
index 8550475d..ec94939b 100644
--- a/src/main/resources/application.properties
+++ b/src/main/resources/application.properties
@@ -1,8 +1,8 @@
#����application-test�ļ�(���Ի���)
-#spring.profiles.active=test
+spring.profiles.active=test
#����application-prod�ļ�(��������)
-spring.profiles.active=prod
+#spring.profiles.active=prod
#����application-dev�ļ�(��������)
#spring.profiles.active=dev