diff --git a/docker-compose.yml b/docker-compose.yml index c1843a76..a37c5319 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,8 +3,8 @@ services: aida_back: container_name: stable-version-aida-back build: . - volumes: - # 数据挂载 - - /workspace/home/aida/file/:/workspace/home/aida/file/ +# volumes: +# # 数据挂载 +# - /workspace/home/aida/file/:/workspace/home/aida/file/ ports: - "10086:5567" \ No newline at end of file diff --git a/pom.xml b/pom.xml index 50c4b965..14913240 100644 --- a/pom.xml +++ b/pom.xml @@ -157,6 +157,19 @@ commons-fileupload 1.4 + + + + org.springframework.boot + spring-boot-starter-amqp + + + + + org.apache.commons + commons-pool2 + + diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java new file mode 100644 index 00000000..d4e24508 --- /dev/null +++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java @@ -0,0 +1,45 @@ +package com.ai.da.common.RabbitMQ; + +import org.springframework.amqp.core.Binding; +import org.springframework.amqp.core.BindingBuilder; +import org.springframework.amqp.core.FanoutExchange; +import org.springframework.amqp.core.Queue; +import org.springframework.amqp.rabbit.connection.CachingConnectionFactory; +import org.springframework.amqp.rabbit.connection.ConnectionFactory; +import org.springframework.amqp.rabbit.core.RabbitAdmin; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.beans.factory.annotation.Value; + +@Configuration +public class MQConfig { + + public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange"; + public static final String GENERATE_QUEUE = "generate-queue-prod"; + + public MQConfig() { + } + +// @Bean +// FanoutExchange fanoutRasaExchange() { +// return new FanoutExchange(GENERATE_EXCHANGE_FANOUT); +// } + + /** + * 创建队列,使用工作模式,不用定义交换机 + */ + @Bean + public Queue queueRasa() { + return new Queue(GENERATE_QUEUE); + } + + /** + * 将队列绑定到交换机上【队列订阅交换机】 + */ +// @Bean +// Binding bindingExchangeRasa() { +// return BindingBuilder.bind(queueRasa()).to(fanoutRasaExchange()); +// } + +} diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java new file mode 100644 index 00000000..77f46dbd --- /dev/null +++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java @@ -0,0 +1,161 @@ +package com.ai.da.common.RabbitMQ; + +import com.ai.da.common.config.exception.BusinessException; +import com.ai.da.common.utils.RedisUtil; +import com.ai.da.model.dto.GenerateThroughImageTextDTO; +import com.ai.da.model.vo.GenerateCollectionVO; +import com.ai.da.service.GenerateService; +import com.alibaba.fastjson.JSONObject; +import com.rabbitmq.client.Channel; +import lombok.extern.slf4j.Slf4j; +import org.springframework.amqp.core.Message; +import org.springframework.amqp.rabbit.annotation.RabbitHandler; +import org.springframework.amqp.rabbit.annotation.RabbitListener; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; +import java.io.IOException; +import java.util.HashMap; +import java.util.Objects; + + +@Slf4j +@Component +public class MQConsumer { + + @Resource + private GenerateService generateService; + + @Resource + private RedisUtil redisUtil; + + @Value("${redis.key.consumptionOrder}") + private String consumptionOrderKey; + + @Value("${redis.key.cancelSet}") + private String cancelSetKey; + + @Value("${redis.key.exceptionMap}") + private String exceptionMapKey; + + @Value("${redis.key.resultMap}") + private String resultMapKey; + + public void generate(Message msg, Channel channel, String consumerName) { + log.info("============start listening=========="); + long start = System.currentTimeMillis(); + + GenerateThroughImageTextDTO generateThroughImageTextDTO = JSONObject.parseObject(msg.getBody(), GenerateThroughImageTextDTO.class); + String uniqueId = generateThroughImageTextDTO.getUniqueId(); + log.info("From " + consumerName + " : " + uniqueId); + + try { + // 2、判断当前消息是否在取消列表中 + Boolean isMember = redisUtil.isElementExistsInSet(cancelSetKey, uniqueId); + if (isMember) { + try { + // 2.1 手动确认该消息 + channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false); + } catch (IOException ex) { + log.error("手动确认,不返回队列重新消费"); + } + // 2.2 将该消息从取消列表中删除 +// redisUtil.removeFromSet(cancelSetKey, uniqueId); + } else { + /*try { + Thread.sleep(15000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + }*/ + GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO); + // 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除 + redisUtil.removeFromZSet(consumptionOrderKey, uniqueId); + if (!Objects.isNull(generateCollectionVO)) { + HashMap generateResult = new HashMap<>(); + generateResult.put(uniqueId, JSONObject.toJSONString(generateCollectionVO)); + // 将结果存在redis中 ,为空时不要存 + redisUtil.addToMap(resultMapKey, generateResult); + } + + } + } catch (BusinessException e) { + log.error(e.getMsg()); + // channel.basicNack() 为不确认deliveryTag对应的消息,第二个参数是否应用于多消息,第三个参数是否requeue + try { + // 第二个参数,是否批量确认消息,当传false时,只确认当前 deliveryTag对应的消息;当传true时,会确认当前及之前所有未确认的消息。 + channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false); + // 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除 + redisUtil.removeFromZSet(consumptionOrderKey, uniqueId); + } catch (IOException exception) { + log.error("手动确认,取消返回队列,不再重新消费"); + } + // 将入参和错误信息存入数据库 + String exceptionMessage = JSONObject.toJSONString(generateThroughImageTextDTO) + + " Exception message : " + e.getMsg(); + HashMap exceptionInfo = new HashMap<>(); + exceptionInfo.put(String.valueOf(uniqueId), exceptionMessage); + // 存redis + redisUtil.addToMap(exceptionMapKey, exceptionInfo); + } + long end = System.currentTimeMillis(); + + log.info(" task_id: " + uniqueId + "----------" + consumerName + " 执行时长:" + (end - start) + "毫秒"); + log.info("=============end listening==========="); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer1(Message msg, Channel channel) { + generate(msg, channel, "consumer 1"); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer2(Message msg, Channel channel) { + generate(msg, channel, "consumer 2"); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer3(Message msg, Channel channel) { + generate(msg, channel, "consumer 3"); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer4(Message msg, Channel channel) { + generate(msg, channel, "consumer 4"); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer5(Message msg, Channel channel) { + generate(msg, channel, "consumer 5"); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer6(Message msg, Channel channel) { + generate(msg, channel, "consumer 6"); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer7(Message msg, Channel channel) { + generate(msg, channel, "consumer 7"); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer8(Message msg, Channel channel) { + generate(msg, channel, "consumer 8"); + } + + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) + @RabbitHandler + public void generateConsumer9(Message msg, Channel channel) { + generate(msg, channel, "consumer 9"); + } + +} diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java new file mode 100644 index 00000000..b0429110 --- /dev/null +++ b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java @@ -0,0 +1,24 @@ +package com.ai.da.common.RabbitMQ; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.amqp.core.AmqpTemplate; +import org.springframework.stereotype.Component; + +import javax.annotation.Resource; + +@Slf4j +@Component +public class MQPublisher { + + private final String url = "http://localhost:15672/api/queues/%2f/generate-queue"; + + @Resource + private AmqpTemplate amqpTemplate; + + public void sendGenerateMessage(String mm) { + log.info("send message:" + mm); + amqpTemplate.convertAndSend(MQConfig.GENERATE_QUEUE, mm); + + } + +} diff --git a/src/main/java/com/ai/da/common/config/RedisConfig.java b/src/main/java/com/ai/da/common/config/RedisConfig.java new file mode 100644 index 00000000..5f01753c --- /dev/null +++ b/src/main/java/com/ai/da/common/config/RedisConfig.java @@ -0,0 +1,42 @@ +package com.ai.da.common.config; + +import com.fasterxml.jackson.annotation.JsonAutoDetect; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.PropertyAccessor; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; +import org.springframework.data.redis.serializer.StringRedisSerializer; + +@Configuration +public class RedisConfig { + @Bean(name = "redisTemplate") + public RedisTemplate getRedisTemplate(RedisConnectionFactory factory) { + RedisTemplate redisTemplate = new RedisTemplate<>(); + redisTemplate.setConnectionFactory(factory); + + StringRedisSerializer stringRedisSerializer = new StringRedisSerializer(); + + redisTemplate.setKeySerializer(stringRedisSerializer); // key的序列化类型 + + Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class); + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY); + // 方法过期,改为下面代码 +// objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL); + objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, + ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY); + jackson2JsonRedisSerializer.setObjectMapper(objectMapper); + jackson2JsonRedisSerializer.setObjectMapper(objectMapper); + + redisTemplate.setValueSerializer(jackson2JsonRedisSerializer); // value的序列化类型 + redisTemplate.setHashKeySerializer(stringRedisSerializer); + redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer); + redisTemplate.afterPropertiesSet(); + return redisTemplate; + } +} diff --git a/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java b/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java new file mode 100644 index 00000000..35039431 --- /dev/null +++ b/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java @@ -0,0 +1,73 @@ +package com.ai.da.common.utils; + +import com.ai.da.common.config.exception.BusinessException; +import com.ai.da.model.dto.GenerateToPythonDTO; +import com.ai.da.python.PythonService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.*; +import java.util.concurrent.*; + +@Slf4j +@Component +public class AsyncCallerUtil { + + public static Map waitingStatus = new HashMap<>(); + + private static PythonService pythonService; + + @Autowired + public void setPythonService(PythonService pythonService) { + AsyncCallerUtil.pythonService = pythonService; + } + + public CompletableFuture> callGenerateAsync(GenerateToPythonDTO generateToPython) { + return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython)); + } + + public List generate(GenerateToPythonDTO generateToPython) { + ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(5); + String taskId = generateToPython.getTasks_id(); + ScheduledFuture timeoutTask = null; + if (!waitingStatus.containsKey(taskId)) waitingStatus.put(taskId, true); + + try { + CompletableFuture> generateResult = callGenerateAsync(generateToPython); + // 5秒后第一次确认,之后每隔10秒确认一次用户选择结果 + timeoutTask = scheduledExecutorService.scheduleAtFixedRate(() -> { + // 调用另一个接口获取用户的选择 + if (!waitingStatus.get(taskId)) { + // 如果用户选择取消,则取消对generate的调用 + generateResult.cancel(true); + waitingStatus.remove(taskId); + } else log.info("===============持续等待==============="); + }, 5, 10, TimeUnit.SECONDS); + + log.info("阻塞等待结果..."); + // 阻塞,等待结果 + List result = generateResult.get(); + // 取消定时任务 + timeoutTask.cancel(true); + waitingStatus.remove(taskId); + return result; + } catch (CancellationException e) { + // generateResult.cancel(true);通过抛出异常取消该任务 + log.info("==========成功取消generate任务=========="); + return null; + } catch (InterruptedException | ExecutionException | BusinessException e) { + // 处理异常 + log.error("发生错误 : " + e); + // 取消定时任务 + assert timeoutTask != null; + timeoutTask.cancel(true); + throw new BusinessException(e.getMessage()); + } finally { + // 关闭线程池 +// executorService.shutdown(); +// scheduledExecutorService.shutdown(); + } + } + +} diff --git a/src/main/java/com/ai/da/common/utils/RedisUtil.java b/src/main/java/com/ai/da/common/utils/RedisUtil.java new file mode 100644 index 00000000..5121e883 --- /dev/null +++ b/src/main/java/com/ai/da/common/utils/RedisUtil.java @@ -0,0 +1,126 @@ +package com.ai.da.common.utils; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.ZSetOperations; +import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; + +import javax.annotation.Resource; +import java.util.Map; +import java.util.Set; + +@Slf4j +@Component +public class RedisUtil { + + @Resource + private RedisTemplate redisTemplate; + + //- - - - - - - - - - - - - - - - - - - - - ZSet类型 - - - - - - - - - - - - - - - - - - - - + + /** + * 向ZSet中添加元素 + */ + public void addToZSet(String key, String value, Double score) { + redisTemplate.opsForZSet().add(key, value, score); + } + + /** + * 从ZSet中删除元素 + */ + public void removeFromZSet(String key, String value) { + redisTemplate.opsForZSet().remove(key, value); + } + + /** + * 获取指定元素的当前排列顺序 + */ + public Long getRank(String key, String value) { + return redisTemplate.opsForZSet().rank(key, value); + } + + /** + * 获取当前ZSet中的最大score + */ + public Double getMaxScore(String key) { + Set> set = redisTemplate.opsForZSet().reverseRangeWithScores(key, 0, 0); + + if (!CollectionUtils.isEmpty(set)) { + Double score = set.iterator().next().getScore(); + return score + 1.0; + } else { + return 1.0; + } + } + + /** + * 判断元素是否存在 + */ + public Boolean isElementExistsInZSet(String key, String value) { + return redisTemplate.opsForZSet().score(key, value) != null; + } + + /** + * 获取当前ZSet中数据量的总和 + */ + public Long getZSetTotal(String key) { + return redisTemplate.opsForZSet().zCard(key); + } + + //- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - - + + /** + * 将数据放入set缓存 + */ + public void addToSet(String key, String value) { + redisTemplate.opsForSet().add(key, value); + } + + /** + * 弹出变量中的元素 + */ + public void removeFromSet(String key, String value) { + redisTemplate.opsForSet().remove(key, value); + } + + /** + * 检查给定的元素是否在变量中。 + */ + public Boolean isElementExistsInSet(String key, String obj) { + return redisTemplate.opsForSet().isMember(key, obj); + } + + + //- - - - - - - - - - - - - - - - - - - - - hash类型 - - - - - - - - - - - - - - - - - - - - + + /** + * 加入缓存 + */ + public void addToMap(String key, Map map) { + redisTemplate.opsForHash().putAll(key, map); + } + + /** + * 验证指定 key 下 有没有指定的 hashkey + */ + public Boolean isElementExistsInMap(String key, String hashKey) { + return redisTemplate.opsForHash().hasKey(key, hashKey); + } + + /** + * 获取指定key的值string + */ + public String getMapValue(String key1, String key2) { + return String.valueOf(redisTemplate.opsForHash().get(key1, key2)); + } + + /** + * 删除指定 hash 的 HashKey + * + * @return 删除成功的 数量 + */ + public Long removeFromMap(String key, String hashKeys) { + return redisTemplate.opsForHash().delete(key, hashKeys); + } +} diff --git a/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java b/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java new file mode 100644 index 00000000..03298fb4 --- /dev/null +++ b/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java @@ -0,0 +1,137 @@ +package com.ai.da.common.utils; + + +public class SnowflakeUtil { + + // ==============================Fields=========================================== + /** 开始时间截 (2015-01-01) */ + private final long twepoch = 1420041600000L; + + /** 机器id所占的位数 */ + private final long workerIdBits = 5L; + + /** 数据标识id所占的位数 */ + private final long datacenterIdBits = 5L; + + /** 支持的最大机器id,结果是31 (这个移位算法可以很快的计算出几位二进制数所能表示的最大十进制数) */ + private final long maxWorkerId = -1L ^ (-1L << workerIdBits); + + /** 支持的最大数据标识id,结果是31 */ + private final long maxDatacenterId = -1L ^ (-1L << datacenterIdBits); + + /** 序列在id中占的位数 */ + private final long sequenceBits = 12L; + + /** 机器ID向左移12位 */ + private final long workerIdShift = sequenceBits; + + /** 数据标识id向左移17位(12+5) */ + private final long datacenterIdShift = sequenceBits + workerIdBits; + + /** 时间截向左移22位(5+5+12) */ + private final long timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits; + + /** 生成序列的掩码,这里为4095 (0b111111111111=0xfff=4095) */ + private final long sequenceMask = -1L ^ (-1L << sequenceBits); + + /** 工作机器ID(0~31) */ + private long workerId; + + /** 数据中心ID(0~31) */ + private long datacenterId; + + /** 毫秒内序列(0~4095) */ + private long sequence = 0L; + + /** 上次生成ID的时间截 */ + private long lastTimestamp = -1L; + + //==============================Constructors===================================== + /** + * 构造函数 + * @param workerId 工作ID (0~31) + * @param datacenterId 数据中心ID (0~31) + */ + public SnowflakeUtil(long workerId, long datacenterId) { + if (workerId > maxWorkerId || workerId < 0) { + throw new IllegalArgumentException(String.format + ("worker Id can't be greater than %d or less than 0", maxWorkerId)); + } + if (datacenterId > maxDatacenterId || datacenterId < 0) { + throw new IllegalArgumentException(String.format + ("datacenter Id can't be greater than %d or less than 0", maxDatacenterId)); + } + this.workerId = workerId; + this.datacenterId = datacenterId; + } + + // ==============================Methods========================================== + /** + * 获得下一个ID (该方法是线程安全的) + * @return SnowflakeId + */ + public synchronized long nextId() { + long timestamp = timeGen(); + + //如果当前时间小于上一次ID生成的时间戳,说明系统时钟回退过这个时候应当抛出异常 + if (timestamp < lastTimestamp) { + throw new RuntimeException( + String.format + ("Clock moved backwards. Refusing to generate id for %d milliseconds", lastTimestamp - timestamp)); + } + + //如果是同一时间生成的,则进行毫秒内序列 + if (lastTimestamp == timestamp) { + sequence = (sequence + 1) & sequenceMask; + //毫秒内序列溢出 + if (sequence == 0) { + //阻塞到下一个毫秒,获得新的时间戳 + timestamp = tilNextMillis(lastTimestamp); + } + } + //时间戳改变,毫秒内序列重置 + else { + sequence = 0L; + } + + //上次生成ID的时间截 + lastTimestamp = timestamp; + + //移位并通过或运算拼到一起组成64位的ID + return ((timestamp - twepoch) << timestampLeftShift) // + | (datacenterId << datacenterIdShift) // + | (workerId << workerIdShift) // + | sequence; + } + + /** + * 阻塞到下一个毫秒,直到获得新的时间戳 + * @param lastTimestamp 上次生成ID的时间截 + * @return 当前时间戳 + */ + protected long tilNextMillis(long lastTimestamp) { + long timestamp = timeGen(); + while (timestamp <= lastTimestamp) { + timestamp = timeGen(); + } + return timestamp; + } + + /** + * 返回以毫秒为单位的当前时间 + * @return 当前时间(毫秒) + */ + protected long timeGen() { + return System.currentTimeMillis(); + } + + //==============================Test============================================= + /** 测试 */ + public static void main(String[] args) { + SnowflakeUtil idWorker = new SnowflakeUtil(0, 0); + long id = idWorker.nextId(); + System.out.println("id:"+id); + //id:768842202204864512 + } +} + diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java index 5fa712cc..b3f6fb5b 100644 --- a/src/main/java/com/ai/da/controller/GenerateController.java +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -53,4 +53,25 @@ public class GenerateController { return Response.success(generateService.generateDislike(generateDetailId, timeZone)); } + @ApiOperation(value = "发起生成请求,异步获取结果") + @PostMapping("/prepare") + public Response prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { + return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO)); + } + + @ApiOperation(value = "取消继续生成") + @GetMapping("/stopWaiting") + public Response stopWaiting(@RequestParam("userId") Long userId, + @RequestParam("uniqueId") String uniqueId, + @RequestParam("timeZone") String timeZone) { + generateService.cancelGenerate(userId, uniqueId, timeZone); + return Response.success("stop waiting successfully"); + } + + @ApiOperation(value = "获取生成结果") + @GetMapping("/result") + public Response getGenerateResult(@RequestParam("uniqueId") String uniqueId) { + GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId); + return Response.success(generateResult); + } } diff --git a/src/main/java/com/ai/da/mapper/GenerateCancelMapper.java b/src/main/java/com/ai/da/mapper/GenerateCancelMapper.java new file mode 100644 index 00000000..3e24ea09 --- /dev/null +++ b/src/main/java/com/ai/da/mapper/GenerateCancelMapper.java @@ -0,0 +1,7 @@ +package com.ai.da.mapper; + +import com.ai.da.common.config.mybatis.plus.CommonMapper; +import com.ai.da.mapper.entity.GenerateCancel; + +public interface GenerateCancelMapper extends CommonMapper { +} diff --git a/src/main/java/com/ai/da/mapper/entity/Generate.java b/src/main/java/com/ai/da/mapper/entity/Generate.java index fc1f4e33..9fad23fb 100644 --- a/src/main/java/com/ai/da/mapper/entity/Generate.java +++ b/src/main/java/com/ai/da/mapper/entity/Generate.java @@ -27,6 +27,11 @@ public class Generate { */ private Long accountId; + /** + * 唯一id + */ + private String uniqueId; + /** * Sketchboard Printboard */ diff --git a/src/main/java/com/ai/da/mapper/entity/GenerateCancel.java b/src/main/java/com/ai/da/mapper/entity/GenerateCancel.java new file mode 100644 index 00000000..8726b9ed --- /dev/null +++ b/src/main/java/com/ai/da/mapper/entity/GenerateCancel.java @@ -0,0 +1,44 @@ +package com.ai.da.mapper.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.Accessors; + +import java.util.Date; + +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(chain = true) +@TableName("t_generate_cancel") +public class GenerateCancel { + + /** + * ID + */ + @TableId(value = "id", type = IdType.AUTO) + private Long id; + + /** + * 用户ID + */ + private Long accountId; + + /** + * 唯一id(任务id) + */ + private String uniqueId; + + /** + * 创建时间 + */ + private Date createDate; + + public GenerateCancel(Long accountId, String uniqueId, Date createDate) { + this.accountId = accountId; + this.uniqueId = uniqueId; + this.createDate = createDate; + } +} diff --git a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java index ca12d04a..fd45533c 100644 --- a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java +++ b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java @@ -5,10 +5,14 @@ import io.swagger.annotations.ApiModelProperty; import lombok.Data; import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; @Data @ApiModel("GenerateThroughImageTextDTO") public class GenerateThroughImageTextDTO { + @NotNull(message = "userId cannot be empty") + @ApiModelProperty("用户id") + Long userId; @ApiModelProperty("caption") String text; @@ -40,4 +44,7 @@ public class GenerateThroughImageTextDTO { @NotBlank(message = "timeZone cannot be empty!") @ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取") String timeZone; + + @ApiModelProperty("唯一id,用于保持消息唯一性") + String uniqueId; } diff --git a/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java b/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java new file mode 100644 index 00000000..93553db0 --- /dev/null +++ b/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java @@ -0,0 +1,27 @@ +package com.ai.da.model.dto; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class GenerateToPythonDTO { + + private Long user_id; + + private String image_url; + + private String category; + + private String str; + + private Integer mode; + + private String version; + + private String gender; + + private String tasks_id; +} diff --git a/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java b/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java index 433e86c8..aa8fb7c6 100644 --- a/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java +++ b/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java @@ -19,6 +19,13 @@ public class GenerateCollectionVO { @ApiModelProperty("生成的图片信息") private List generatedCollectionItems; + @ApiModelProperty("在当前队列中的排序") + private Long rankPosition; + + public GenerateCollectionVO(Long rankPosition) { + this.rankPosition = rankPosition; + } + public GenerateCollectionVO(Long generateId, Long collectionId, List generatedCollectionItems) { this.generateId = generateId; this.collectionId = collectionId; diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index 578de9ef..5a37b3b7 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -23,7 +23,6 @@ import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.serializer.SerializerFeature; import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import io.netty.util.internal.StringUtil; import lombok.extern.slf4j.Slf4j; import okhttp3.*; import org.springframework.beans.factory.annotation.Value; @@ -57,6 +56,7 @@ public class PythonService { private String accessPythonIp; @Value("${access.python.port:''}") private String accessPythonPort; + @Resource private PythonTAllInfoService pythonTAllInfoService; @@ -107,7 +107,7 @@ public class PythonService { if (Objects.nonNull(response.body())) { responseBody = response.body().string(); JSONObject responseObj = JSON.parseObject(responseBody); - log.info("moodboard与printboard图片合成 python返回###{}",responseObj); + log.info("moodboard与printboard图片合成 python返回###{}", responseObj); return responseObj.get("data").toString(); } } catch (IOException | JSONException e) { @@ -389,7 +389,7 @@ public class PythonService { all.addAll(new ArrayList<>(DesignPythonItem.SKIRT_TROUSERS)); return all; } - }else if (modelSex.equals(Sex.MALE.getValue())) { + } else if (modelSex.equals(Sex.MALE.getValue())) { Long randomIndex = RandomsUtil.randomSysFile(0L, 3L); if (randomIndex == 0) { return DesignPythonItem.TOPS; @@ -422,11 +422,11 @@ public class PythonService { long noPinNum = printBoardElements.stream().filter(f -> f.getHasPin() == 0).count(); if (noPinNum == 0L) { return 0; - }else { + } else { long pinNum = printBoardElements.stream().filter(f -> f.getHasPin() == 1).count(); if (8 - pinNum < 4) { return RandomsUtil.randomSysFile(0L, 8 - pinNum + 1); - }else { + } else { return RandomsUtil.randomSysFile(0L, 4L + 1); } } @@ -1600,7 +1600,7 @@ public class PythonService { printBoardElements = elementVO.getPrintBoardElements() .stream() .filter(f -> !elementVO.getHasUseMd5List().contains(f.getMd5())).collect(Collectors.toList()); - }else { + } else { printBoardElements = elementVO.getPrintBoardElements(); } if (CollectionUtil.isEmpty(printBoardElements)) { @@ -2081,7 +2081,7 @@ public class PythonService { skirt.setPrint(designPythonItemPrint); skirt.setPath("aida-sys-image/images/female/trousers/trousers_974.jpg"); response.add(skirt); - }else { + } else { DesignPythonItem top = new DesignPythonItem(); top.setType(MalePosition.TOPS.getValue()); top.setColor("none"); @@ -2195,7 +2195,9 @@ public class PythonService { throw new BusinessException("design.interface.exception"); } - /** 暂时未用 */ + /** + * 暂时未用 + */ public String generateSketchCaption(String url) { //限流校验 AccessLimitUtils.validate("generateSketchCaption", 5); @@ -2238,9 +2240,9 @@ public class PythonService { throw new BusinessException("system error!"); } - public List generateSketchOrPrint(Long userId, String url, String category, String text, int mode, String modelName, String gender) { + public List generateSketchOrPrint(GenerateToPythonDTO generateToPythonDTO) { //限流校验 - AccessLimitUtils.validate("generateSketchOrPrint", 5); +// AccessLimitUtils.validate("generateSketchOrPrint", 5); OkHttpClient client = new OkHttpClient().newBuilder() .connectTimeout(30, TimeUnit.SECONDS) .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒) @@ -2248,46 +2250,45 @@ public class PythonService { .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) .build(); MediaType mediaType = MediaType.parse("application/json"); - Map content = Maps.newHashMap(); - content.put("user_id", userId); - content.put("image_url", url); - content.put("category", category); - content.put("mode", mode); - content.put("str", text); - content.put("version", "1"); - content.put("gender", gender); - RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content, SerializerFeature.WriteMapNullValue)); + RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue)); Request request = new Request.Builder() -// .url(accessPythonIp + ":2828/aida/diffusion") // .url("http://18.167.251.121:9992") - .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion") +// .url("http://127.0.0.1:5000/api/diffusion") +// .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion") + .url(accessPythonIp + ":" + accessPythonPort + "/api/generate_image") .method("POST", body) - .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==") +// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==") .addHeader("Content-Type", "application/json") .build(); Response response = null; - String bodyString ; + String bodyString; try { - log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(content, SerializerFeature.WriteMapNullValue)); + log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue)); response = client.newCall(request).execute(); } catch (IOException ioException) { log.error("PythonService##generateSketchOrPrint异常###{}", ExceptionUtil.getThrowableList(ioException)); +// throw new BusinessException("generate.interface.error"); + throw new BusinessException(ioException.getMessage()); } //去除限流 - AccessLimitUtils.validateOut("generateSketchOrPrint"); +// AccessLimitUtils.validateOut("generateSketchOrPrint"); // 判断是否生成失败 - if (Objects.isNull(response) || Objects.isNull(response.body())) { + if (Objects.isNull(response.body())) { log.error("PythonService##generateSketchOrPrint异常###{}", "response or body is empty!"); - throw new BusinessException("generate.interface.error"); - } else if (response.code() != HttpURLConnection.HTTP_OK){ +// throw new BusinessException("generate.interface.error"); + throw new BusinessException("PythonService##generateSketchOrPrint异常###: response or body is empty!"); + } else if (response.code() != HttpURLConnection.HTTP_OK) { log.error("PythonService##generateSketchOrPrint异常###{}", "Response error!Response code ## " + response.code() + " ##"); - throw new BusinessException("generate.interface.error"); +// throw new BusinessException("generate.interface.error"); + throw new BusinessException("PythonService##generateSketchOrPrint异常### Response error!Response code ## " + response.code() + " ##"); } else { try { bodyString = response.body().string(); } catch (IOException e) { - throw new BusinessException("generate.interface.error"); + log.error(e.getMessage()); +// throw new BusinessException("generate.interface.error"); + throw new BusinessException(e.getMessage()); } } JSONObject jsonObject = JSON.parseObject(bodyString); @@ -2357,7 +2358,7 @@ public class PythonService { if (Objects.isNull(response) || Objects.isNull(response.body())) { log.error("PythonService##composeLayers异常###{}", "response or body is empty!"); throw new BusinessException("compose-layer.interface.exception"); - }else if (response.code() != HttpURLConnection.HTTP_OK){ + } else if (response.code() != HttpURLConnection.HTTP_OK) { log.error("PythonService##composeLayers异常###{}", "Response error!Response code ## " + response.code() + " ##"); throw new BusinessException("compose-layer.interface.exception"); } else { @@ -2382,10 +2383,10 @@ public class PythonService { return item0.getString("synthesis_url"); } - public String getClothCategory(String path,String gender){ + public String getClothCategory(String path, String gender) { HashMap content = new HashMap<>(); - content.put("sketch_img_url",path); - content.put("colony",gender); + content.put("sketch_img_url", path); + content.put("colony", gender); List> contents = Collections.singletonList(content); String jsonString = JSON.toJSONString(contents, SerializerFeature.WriteNullStringAsEmpty); @@ -2398,7 +2399,7 @@ public class PythonService { if (Objects.isNull(response) || Objects.isNull(response.body())) { log.error("PythonService##GetClothCategory###{}", "response or body is empty!"); throw new BusinessException("cloth-classification.interface.exception"); - } else if (response.code() != HttpURLConnection.HTTP_OK){ + } else if (response.code() != HttpURLConnection.HTTP_OK) { log.error("PythonService##GetClothCategory###{}", "Response error!Response code ## " + response.code() + " ##"); throw new BusinessException("cloth-classification.interface.exception"); } else { @@ -2409,7 +2410,7 @@ public class PythonService { } } JSONObject jsonObject = JSON.parseObject(bodyString); - try{ + try { Boolean result = JSON.parseObject(JSON.toJSONString(response)).getBoolean("successful"); if (result && jsonObject.get("msg").equals("OK!")) { JSONObject data = jsonObject.getJSONObject("data"); @@ -2417,7 +2418,7 @@ public class PythonService { JSONObject map = (JSONObject) list.get(0); return map.get("category").toString(); } - }catch (NullPointerException e){ + } catch (NullPointerException e) { log.info("getClothCategory 失败###{},未返回category", jsonObject); throw new BusinessException("cloth-classification.interface.exception"); } @@ -2425,4 +2426,35 @@ public class PythonService { //生成失败 throw new BusinessException("cloth-classification.interface.exception"); } + + public Boolean cancelGenerateTask(String taskId) { + OkHttpClient client = new OkHttpClient().newBuilder() + .connectTimeout(30, TimeUnit.SECONDS) + .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒) + .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒) + .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) + .build(); + String url = accessPythonIp + ":" + accessPythonPort + "/api/generate_cancel/" + taskId; + Request request = new Request.Builder() + .url(url) +// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==") +// .addHeader("Content-Type", "application/json") + .build(); + Response response; + try { + log.info("cancelGenerateTask请求入参content###{}", taskId); + response = client.newCall(request).execute(); + } catch (IOException ioException) { + log.error("PythonService##cancelGenerateTask异常###{}", ExceptionUtil.getThrowableList(ioException)); + return null; + } + + if (response.code() != HttpURLConnection.HTTP_OK) { + log.info("generate-python 取消请求失败"); + return Boolean.FALSE; + } + log.info("generate-python 取消请求成功"); + return Boolean.TRUE; + } + } diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index 0d92a07f..d1a327bd 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -24,4 +24,13 @@ public interface GenerateService extends IService { void updateLikeStatusBatch(List generateDetailIdList, Byte hasLike, Long libraryId, String timeZone); List selectBatchByLibraryId(List libraryId); + + GenerateCollectionVO getGenerateResult(String uniqueId); + + String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO); + + Long getRankPosition(String uniqueId); + + void cancelGenerate(Long userId, String uniqueId, String timeZone); + } diff --git a/src/main/java/com/ai/da/service/RabbitMQService.java b/src/main/java/com/ai/da/service/RabbitMQService.java new file mode 100644 index 00000000..b03a578d --- /dev/null +++ b/src/main/java/com/ai/da/service/RabbitMQService.java @@ -0,0 +1,11 @@ +package com.ai.da.service; + +import org.springframework.stereotype.Service; + +@Service +public interface RabbitMQService { + + void publishMessage(String message); + + Integer getMessageCount(String queueUrl); +} diff --git a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java index 1c6307d9..a44defca 100644 --- a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java @@ -810,9 +810,11 @@ public class CollectionElementServiceImpl extends ServiceImpl implements GenerateService { @@ -52,10 +59,31 @@ public class GenerateServiceImpl extends ServiceImpl i @Resource private MinioUtil minioUtil; + @Resource + private RabbitMQService rabbitMQService; + + @Resource + private RedisUtil redisUtil; + + @Resource + private GenerateCancelMapper generateCancelMapper; + + @Value("${redis.key.consumptionOrder}") + private String consumptionOrderKey; + + @Value("${redis.key.cancelSet}") + private String cancelSetKey; + + @Value("${redis.key.exceptionMap}") + private String exceptionMapKey; + + @Value("${redis.key.resultMap}") + private String resultMapKey; + @Override public GenerateCaptionVO generateCaption(Long sketchElementId) { CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId); - if (Objects.isNull(collectionElement)){ + if (Objects.isNull(collectionElement)) { throw new BusinessException("the.image.does.not.exist.please.reselect"); } String url = collectionElement.getUrl(); @@ -69,45 +97,46 @@ public class GenerateServiceImpl extends ServiceImpl i @Transactional(rollbackFor = Exception.class) public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、获取用户信息 - AuthPrincipalVo userHolder = UserContext.getUserHolder(); + Long accountId = generateThroughImageTextDTO.getUserId(); String generateType = generateThroughImageTextDTO.getGenerateType(); - Long accountId = userHolder.getId(); - if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){ - throw new BusinessException("unknown.generate.type"); - } // 2、判断必须入参是否为非空 Generate generate = new Generate(); generate.setAccountId(accountId); + generate.setUniqueId(generateThroughImageTextDTO.getUniqueId()); generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type()); // 当level1type是sketchboard时,存数据库需要加上当前性别 generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? - generateType + " (" +generateThroughImageTextDTO.getGender() + ")": + generateType + " (" + generateThroughImageTextDTO.getGender() + ")" : generateType); generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion()); generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone())); - String text = generateThroughImageTextDTO.getText(); Long elementId = generateThroughImageTextDTO.getCollectionElementId(); - validateGeneraType(generate, text, elementId,generateType); + validateGeneraType(generate, text, elementId, generateType); - // 3、将请求信息落库 - // 3.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type + // 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type()); - // 3.2 将本次generate的请求信息添加到t_generate表中 - save(generate); - - // 4、向模型发起请求 + // 3、向模型发起请求 int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ? GenerateModeEnum.TEXT.getCode() : GenerateModeEnum.TEXT_IMAGE.getCode(); String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" : generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard"; -// text = !StringUtil.isNullOrEmpty(text) && generateThroughImageTextDTO.getVersion().equals("1") ? "painting style, " + text : text; - List generatedSketchUrl = pythonService.generateSketchOrPrint(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(), - category, text, mode, generateThroughImageTextDTO.getVersion(), generateThroughImageTextDTO.getGender()); + AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil(); + List generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), + category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId())); +// List generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(), +// category, text, mode, "1", generateThroughImageTextDTO.getGender())); + log.info("generate 响应 : " + generatedSketchUrl); + if (CollectionUtils.isEmpty(generatedSketchUrl)) { + return null; + } + + // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中 + save(generate); // 5、处理模型返回的数据 // 5.1 将相应的url保存到数据库 @@ -117,8 +146,8 @@ public class GenerateServiceImpl extends ServiceImpl i GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(item, 24 * 60), Boolean.FALSE); // 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过 - List> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generateThroughImageTextDTO.getLevel1Type()); - if (!libraryIdList.isEmpty()){ + List> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generateThroughImageTextDTO.getLevel1Type()); + if (!libraryIdList.isEmpty()) { generateDetail.setIsLike((byte) 1); generateDetail.setLibraryId(libraryIdList.get(0).get("library_id")); generateCollectionItemVO.setIsLiked(Boolean.TRUE); @@ -139,22 +168,22 @@ public class GenerateServiceImpl extends ServiceImpl i return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems); } - private void validateGeneraType(Generate generate, String text, Long elementId,String generateType) { + private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) { switch (generateType) { case "text": - if (StringUtil.isNullOrEmpty(text)){ + if (StringUtil.isNullOrEmpty(text)) { throw new BusinessException("please.input.the.caption"); } generate.setText(text); break; case "image": - if (Objects.isNull(elementId)){ + if (Objects.isNull(elementId)) { throw new BusinessException("please.choose.an.image"); } generate.setCollectionElementId(elementId); break; case "text-image": - if (StringUtil.isNullOrEmpty(text) || Objects.isNull(elementId)){ + if (StringUtil.isNullOrEmpty(text) || Objects.isNull(elementId)) { throw new BusinessException("please.input.the.caption.and.choose.an.image"); } generate.setText(text); @@ -169,21 +198,21 @@ public class GenerateServiceImpl extends ServiceImpl i // 1、判断参数是否正确 // 1.1 必须参数是否非空 if (SKETCH_BOARD.getRealName().equals(generateLikeDTO.getLevel1Type())) { - if (StringUtil.isNullOrEmpty(generateLikeDTO.getLevel2Type())){ + if (StringUtil.isNullOrEmpty(generateLikeDTO.getLevel2Type())) { throw new BusinessException("level2Type.cannot.be.empty"); } - if (StringUtil.isNullOrEmpty(generateLikeDTO.getGender())){ + if (StringUtil.isNullOrEmpty(generateLikeDTO.getGender())) { throw new BusinessException("gender.cannot.be.empty"); } } // 1.2 判断参数是否真实有效 Long generateDetailId = generateLikeDTO.getGenerateDetailId(); GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId); - if (Objects.isNull(generateDetail)){ + if (Objects.isNull(generateDetail)) { throw new BusinessException("generateItem.does.not.exist"); } Generate generate = getById(generateDetail.getGenerateId()); - if (!generateLikeDTO.getLevel1Type().equals(generate.getLevel1Type())){ + if (!generateLikeDTO.getLevel1Type().equals(generate.getLevel1Type())) { throw new BusinessException("level1Type.does.not.match"); } @@ -191,8 +220,8 @@ public class GenerateServiceImpl extends ServiceImpl i // 2.1、不能重复喜欢 // 2.1.1 判断该图片是否被喜欢过 Library libraryDetail = libraryService.getById(generateDetail.getLibraryId()); - if ( (Objects.nonNull(generateDetail.getLibraryId()) && !generateDetail.getLibraryId().equals(0L)) - || Objects.nonNull(libraryDetail)){ + if ((Objects.nonNull(generateDetail.getLibraryId()) && !generateDetail.getLibraryId().equals(0L)) + || Objects.nonNull(libraryDetail)) { throw new BusinessException("duplicate.likes.are.not.allowed"); } @@ -215,7 +244,7 @@ public class GenerateServiceImpl extends ServiceImpl i public Boolean generateDislike(Long generateDetailId, String timeZone) { // 1、确定generateDetail中是否有这条记录 GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId); - if (Objects.isNull(generateDetail)){ + if (Objects.isNull(generateDetail)) { throw new BusinessException("generateItem.does.not.exist"); } @@ -265,7 +294,7 @@ public class GenerateServiceImpl extends ServiceImpl i generateDetailMapper.update(generateDetail, queryWrapper); } - public void updateLikeStatusBatch(List generateDetailIdList, Byte hasLike, Long libraryId, String timeZone){ + public void updateLikeStatusBatch(List generateDetailIdList, Byte hasLike, Long libraryId, String timeZone) { QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.in("id", generateDetailIdList); @@ -277,10 +306,156 @@ public class GenerateServiceImpl extends ServiceImpl i generateDetailMapper.update(generateDetail, queryWrapper); } - public List selectBatchByLibraryId(List libraryId){ + public List selectBatchByLibraryId(List libraryId) { QueryWrapper qw = new QueryWrapper<>(); - qw.in("library_id",libraryId); + qw.in("library_id", libraryId); return generateDetailMapper.selectList(qw); } + + @Override + public String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + // 1、参数检查,判断必须参数是否为空 + if (Objects.isNull(generateThroughImageTextDTO.getUserId())) { + throw new BusinessException("userId cannot be empty"); + } + String generateType = generateThroughImageTextDTO.getGenerateType(); + if (!GenerateModeEnum.getGenerateModeList().contains(generateType)) { + throw new BusinessException("unknown.generate.type"); + } + String text = generateThroughImageTextDTO.getText(); + Long elementId = generateThroughImageTextDTO.getCollectionElementId(); + validateGeneraType(new Generate(), text, elementId, generateType); + + // 2、生成唯一id 使用uuid + String uuid = UUID.randomUUID().toString(); + +// SnowflakeUtil idWorker = new SnowflakeUtil(0, 0); +// long snowflakeId = idWorker.nextId(); + + int num = 1; + // 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id + while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) || + redisUtil.isElementExistsInZSet(consumptionOrderKey, uuid)) + && num < 10) { + uuid = UUID.randomUUID().toString(); + num++; + } + // 无依据确定的数字 + if (num > 10) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + uuid = UUID.randomUUID().toString(); + } + generateThroughImageTextDTO.setUniqueId(uuid); + String jsonString = JSON.toJSONString(generateThroughImageTextDTO); + + // 3、加入redis排队,便于获取实时排队信息 + Double maxScore = redisUtil.getMaxScore(consumptionOrderKey); + redisUtil.addToZSet(consumptionOrderKey, uuid, maxScore); + + // 4、将消息发布到MQ消息队列 + rabbitMQService.publishMessage(jsonString); + + // 5、返回唯一id + return uuid; + } + + @Override + public Long getRankPosition(String uniqueId) { + // rank 从0开始 + return redisUtil.getRank(consumptionOrderKey, uniqueId); + } + + @Override + public GenerateCollectionVO getGenerateResult(String uniqueId) { + // 1、判断该请求是否已经异常 + Boolean isMember = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId); + if (isMember) { + throw new BusinessException("generate.interface.error"); + } + + // 2、判断该请求是否还在排队 + Boolean existsInZSet = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId); + if (existsInZSet) { + // 排队中,给出当前排序位置,rank从0开始 + Long rankPosition = getRankPosition(uniqueId); + // 有9个消费者,所以当rank>8即当前请求至少排在第九位时,其实际排队位置为9-8+1,当rank <=8,请求均在处理中 + return new GenerateCollectionVO(rankPosition > 8L ? rankPosition - 8 + 1 : 1L); + } + + // 3、判断redis中有没有 + boolean hasHashKey = redisUtil.isElementExistsInMap(resultMapKey, uniqueId); + if (hasHashKey) { + // 3.1 有直接从redis中拿 + String resultString = redisUtil.getMapValue(resultMapKey, uniqueId); + return JSONObject.parseObject(resultString, GenerateCollectionVO.class); + } + + // 3.2 判断数据库中有没有 + Generate generate = selectByUniqueId(uniqueId); + if (Objects.isNull(generate)) { + // 3.3 还没执行完,给出当前位置 + return new GenerateCollectionVO(1L); + } + Long generateId = generate.getId(); + QueryWrapper qw = new QueryWrapper<>(); + qw.eq("generate_id", generateId); + List generateDetails = generateDetailMapper.selectList(qw); + if (CollectionUtils.isEmpty(generateDetails)) { + // 会有这种情况吗?存到generate中,但是还没存到generateDetail中 + return new GenerateCollectionVO(1L); + } + + List generatedCollectionItems = new ArrayList<>(); + generateDetails.forEach(item -> { + GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); + generateCollectionItemVO.setGenerateItemId(item.getId()); + generateCollectionItemVO.setGenerateItemUrl(minioUtil.getPresignedUrl(item.getUrl(), 24 * 60)); + generatedCollectionItems.add(generateCollectionItemVO); + }); + + return new GenerateCollectionVO(generateId, null, generatedCollectionItems); + } + + public Generate selectByUniqueId(String uniqueId) { + QueryWrapper qw = new QueryWrapper<>(); + qw.eq("unique_id", uniqueId); + + return getOne(qw); + } + + @Override + @Transactional(rollbackFor = Exception.class) + public void cancelGenerate(Long userId, String uniqueId, String timeZone) { + // 1、确认当前消息是否还在排队中 + Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId); + Boolean flag = Boolean.FALSE; + if (exists) flag = redisUtil.getRank(consumptionOrderKey, uniqueId) > 1L ? Boolean.TRUE : Boolean.FALSE; + // 不管flag的默认值是true还是false,只要exists为false,&& 将短路 + if (exists && flag) { + // 1.1、将需要取消的唯一id加入redis,以便及时取消生成 + redisUtil.addToSet(cancelSetKey, uniqueId); + // 1.2 将需要取消的id从redis的ConsumptionOrder中删除 + redisUtil.removeFromZSet(consumptionOrderKey, uniqueId); + } else { + // 2、判断该消息是否异常 + boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId); + // 3、判断该消息是否已经消费结束 + Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, uniqueId); + if (!hasKey && !existsInResult) { + // 设置取等待状态为false + AsyncCallerUtil.waitingStatus.put(uniqueId, false); + // 3、直接发送取消请求到python端 + pythonService.cancelGenerateTask(uniqueId); + } + } + + // 3、考虑加一张表,专门用于记录哪些用户在什么时间进行了取消操作 + GenerateCancel generateCancel = new GenerateCancel(userId, uniqueId, DateUtil.getByTimeZone(timeZone)); + generateCancelMapper.insert(generateCancel); + } } diff --git a/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java b/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java new file mode 100644 index 00000000..41743964 --- /dev/null +++ b/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java @@ -0,0 +1,76 @@ +package com.ai.da.service.impl; + +import cn.hutool.core.exceptions.ExceptionUtil; +import com.ai.da.common.RabbitMQ.MQPublisher; +import com.ai.da.common.config.exception.BusinessException; +import com.ai.da.service.RabbitMQService; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; +import lombok.extern.slf4j.Slf4j; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.springframework.stereotype.Service; + +import javax.annotation.Resource; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +@Slf4j +@Service +public class RabbitMQServiceImpl implements RabbitMQService { + + @Resource + private MQPublisher mqPublisher; + + @Override + public void publishMessage(String message) { + mqPublisher.sendGenerateMessage(message); + } + + @Override + public Integer getMessageCount(String queueUrl) { + + String url = "http://localhost:15672/api/queues/%2f/generate-queue"; + + OkHttpClient client = new OkHttpClient().newBuilder() + .connectTimeout(30, TimeUnit.SECONDS) + .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒) + .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒) + .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) + .build(); + Request request = new Request.Builder() + .url(queueUrl) + .method("GET",null) + .addHeader("Authorization", "Basic Z3Vlc3Q6Z3Vlc3Q=") + .build(); + Response response = null; + try { + response = client.newCall(request).execute(); + } catch (IOException ioException) { + log.error("RabbitMQService##" + "getMessage异常###{}", ExceptionUtil.getThrowableList(ioException)); + } + + String bodyString; + // 生成失败 + if (Objects.isNull(response) || Objects.isNull(response.body())) { + log.error("RabbitMQService##getMessageCount异常###{}", "response or body is empty!"); + throw new BusinessException("compose-layer.interface.exception"); + }else if (response.code() != HttpURLConnection.HTTP_OK){ + log.error("RabbitMQService##getMessageCount异常###{}", "Response error!Response code ## " + response.code() + " ##"); + throw new BusinessException("compose-layer.interface.exception"); + } else { + try { + bodyString = response.body().string(); + } catch (IOException e) { + throw new BusinessException("compose-layer.interface.exception"); + } + } + JSONObject jsonObject = JSON.parseObject(bodyString); + String messageCount = jsonObject.get("messages").toString(); + return Integer.parseInt(messageCount); + } +} diff --git a/src/main/resources/application-prod.properties b/src/main/resources/application-prod.properties index 3abdcf3b..a9510207 100644 --- a/src/main/resources/application-prod.properties +++ b/src/main/resources/application-prod.properties @@ -54,3 +54,24 @@ minio.bucketName.sysImage=aida-sys-image minio.bucketName.users=aida-users minio.bucketName.collectionElement=aida-collection-element redirect_url=http://18.167.251.121:7788 + +spring.rabbitmq.host=18.167.251.121 +spring.rabbitmq.port=5672 +spring.rabbitmq.username=rabbit +spring.rabbitmq.password=123456 +spring.rabbitmq.virtual-host=/ + +spring.redis.host=172.31.11.32 +#spring.redis.host=18.167.251.121 +spring.redis.port=6379 +spring.redis.database=1 +spring.redis.password=Aidlab +spring.redis.lettuce.pool.max-active=8 +spring.redis.lettuce.pool.max-idle=8 +spring.redis.lettuce.pool.min-idle=0 +spring.redis.lettuce.pool.max-wait=5 + +redis.key.consumptionOrder=ConsumptionOrder +redis.key.cancelSet=CancelSet +redis.key.exceptionMap=ExceptionMap +redis.key.resultMap=ResultMap \ No newline at end of file diff --git a/src/main/resources/application-test.properties b/src/main/resources/application-test.properties index 9027b77f..01678779 100644 --- a/src/main/resources/application-test.properties +++ b/src/main/resources/application-test.properties @@ -50,8 +50,8 @@ spring.servlet.multipart.max-request-size= 5MB #access.python.ip=http://43.198.80.117 access.python.ip=http://18.167.251.121 #access.python.ip=http://18.167.251.121:9991/ -#access.python.port=9992 -access.python.port=9990 +access.python.port=9992 +#access.python.port=9991 # minIO服务配置之信息 minio.endpoint=https://www.minio.aida.com.hk:9000 @@ -62,6 +62,26 @@ minio.bucketName.results=aida-results minio.bucketName.sysImage=aida-sys-image minio.bucketName.users=aida-users minio.bucketName.collectionElement=aida-collection-element +redirect_url=http://18.167.251.121:7788 +spring.rabbitmq.host=18.167.251.121 +spring.rabbitmq.port=5672 +spring.rabbitmq.username=rabbit +spring.rabbitmq.password=123456 +spring.rabbitmq.virtual-host=/ +spring.redis.host=172.31.11.32 +#spring.redis.host=18.167.251.121 +spring.redis.port=6379 +spring.redis.database=1 +spring.redis.password=Aidlab +spring.redis.lettuce.pool.max-active=8 +spring.redis.lettuce.pool.max-idle=8 +spring.redis.lettuce.pool.min-idle=0 +spring.redis.lettuce.pool.max-wait=5 + +redis.key.consumptionOrder=ConsumptionOrder +redis.key.cancelSet=CancelSet +redis.key.exceptionMap=ExceptionMap +redis.key.resultMap=ResultMap