diff --git a/docker-compose.yml b/docker-compose.yml
index c1843a76..a37c5319 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -3,8 +3,8 @@ services:
aida_back:
container_name: stable-version-aida-back
build: .
- volumes:
- # 数据挂载
- - /workspace/home/aida/file/:/workspace/home/aida/file/
+# volumes:
+# # 数据挂载
+# - /workspace/home/aida/file/:/workspace/home/aida/file/
ports:
- "10086:5567"
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 50c4b965..14913240 100644
--- a/pom.xml
+++ b/pom.xml
@@ -157,6 +157,19 @@
commons-fileupload
1.4
+
+
+
+ org.springframework.boot
+ spring-boot-starter-amqp
+
+
+
+
+ org.apache.commons
+ commons-pool2
+
+
diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java
new file mode 100644
index 00000000..d4e24508
--- /dev/null
+++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java
@@ -0,0 +1,45 @@
+package com.ai.da.common.RabbitMQ;
+
+import org.springframework.amqp.core.Binding;
+import org.springframework.amqp.core.BindingBuilder;
+import org.springframework.amqp.core.FanoutExchange;
+import org.springframework.amqp.core.Queue;
+import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
+import org.springframework.amqp.rabbit.connection.ConnectionFactory;
+import org.springframework.amqp.rabbit.core.RabbitAdmin;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.beans.factory.annotation.Value;
+
+@Configuration
+public class MQConfig {
+
+ public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange";
+ public static final String GENERATE_QUEUE = "generate-queue-prod";
+
+ public MQConfig() {
+ }
+
+// @Bean
+// FanoutExchange fanoutRasaExchange() {
+// return new FanoutExchange(GENERATE_EXCHANGE_FANOUT);
+// }
+
+ /**
+ * 创建队列,使用工作模式,不用定义交换机
+ */
+ @Bean
+ public Queue queueRasa() {
+ return new Queue(GENERATE_QUEUE);
+ }
+
+ /**
+ * 将队列绑定到交换机上【队列订阅交换机】
+ */
+// @Bean
+// Binding bindingExchangeRasa() {
+// return BindingBuilder.bind(queueRasa()).to(fanoutRasaExchange());
+// }
+
+}
diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java
new file mode 100644
index 00000000..77f46dbd
--- /dev/null
+++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConsumer.java
@@ -0,0 +1,161 @@
+package com.ai.da.common.RabbitMQ;
+
+import com.ai.da.common.config.exception.BusinessException;
+import com.ai.da.common.utils.RedisUtil;
+import com.ai.da.model.dto.GenerateThroughImageTextDTO;
+import com.ai.da.model.vo.GenerateCollectionVO;
+import com.ai.da.service.GenerateService;
+import com.alibaba.fastjson.JSONObject;
+import com.rabbitmq.client.Channel;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.amqp.core.Message;
+import org.springframework.amqp.rabbit.annotation.RabbitHandler;
+import org.springframework.amqp.rabbit.annotation.RabbitListener;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.Resource;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Objects;
+
+
+@Slf4j
+@Component
+public class MQConsumer {
+
+ @Resource
+ private GenerateService generateService;
+
+ @Resource
+ private RedisUtil redisUtil;
+
+ @Value("${redis.key.consumptionOrder}")
+ private String consumptionOrderKey;
+
+ @Value("${redis.key.cancelSet}")
+ private String cancelSetKey;
+
+ @Value("${redis.key.exceptionMap}")
+ private String exceptionMapKey;
+
+ @Value("${redis.key.resultMap}")
+ private String resultMapKey;
+
+ public void generate(Message msg, Channel channel, String consumerName) {
+ log.info("============start listening==========");
+ long start = System.currentTimeMillis();
+
+ GenerateThroughImageTextDTO generateThroughImageTextDTO = JSONObject.parseObject(msg.getBody(), GenerateThroughImageTextDTO.class);
+ String uniqueId = generateThroughImageTextDTO.getUniqueId();
+ log.info("From " + consumerName + " : " + uniqueId);
+
+ try {
+ // 2、判断当前消息是否在取消列表中
+ Boolean isMember = redisUtil.isElementExistsInSet(cancelSetKey, uniqueId);
+ if (isMember) {
+ try {
+ // 2.1 手动确认该消息
+ channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
+ } catch (IOException ex) {
+ log.error("手动确认,不返回队列重新消费");
+ }
+ // 2.2 将该消息从取消列表中删除
+// redisUtil.removeFromSet(cancelSetKey, uniqueId);
+ } else {
+ /*try {
+ Thread.sleep(15000);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }*/
+ GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO);
+ // 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
+ redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
+ if (!Objects.isNull(generateCollectionVO)) {
+ HashMap generateResult = new HashMap<>();
+ generateResult.put(uniqueId, JSONObject.toJSONString(generateCollectionVO));
+ // 将结果存在redis中 ,为空时不要存
+ redisUtil.addToMap(resultMapKey, generateResult);
+ }
+
+ }
+ } catch (BusinessException e) {
+ log.error(e.getMsg());
+ // channel.basicNack() 为不确认deliveryTag对应的消息,第二个参数是否应用于多消息,第三个参数是否requeue
+ try {
+ // 第二个参数,是否批量确认消息,当传false时,只确认当前 deliveryTag对应的消息;当传true时,会确认当前及之前所有未确认的消息。
+ channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
+ // 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
+ redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
+ } catch (IOException exception) {
+ log.error("手动确认,取消返回队列,不再重新消费");
+ }
+ // 将入参和错误信息存入数据库
+ String exceptionMessage = JSONObject.toJSONString(generateThroughImageTextDTO) +
+ " Exception message : " + e.getMsg();
+ HashMap exceptionInfo = new HashMap<>();
+ exceptionInfo.put(String.valueOf(uniqueId), exceptionMessage);
+ // 存redis
+ redisUtil.addToMap(exceptionMapKey, exceptionInfo);
+ }
+ long end = System.currentTimeMillis();
+
+ log.info(" task_id: " + uniqueId + "----------" + consumerName + " 执行时长:" + (end - start) + "毫秒");
+ log.info("=============end listening===========");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer1(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 1");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer2(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 2");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer3(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 3");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer4(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 4");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer5(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 5");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer6(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 6");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer7(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 7");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer8(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 8");
+ }
+
+ @RabbitListener(queues = MQConfig.GENERATE_QUEUE)
+ @RabbitHandler
+ public void generateConsumer9(Message msg, Channel channel) {
+ generate(msg, channel, "consumer 9");
+ }
+
+}
diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java
new file mode 100644
index 00000000..b0429110
--- /dev/null
+++ b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java
@@ -0,0 +1,24 @@
+package com.ai.da.common.RabbitMQ;
+
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.amqp.core.AmqpTemplate;
+import org.springframework.stereotype.Component;
+
+import javax.annotation.Resource;
+
+@Slf4j
+@Component
+public class MQPublisher {
+
+ private final String url = "http://localhost:15672/api/queues/%2f/generate-queue";
+
+ @Resource
+ private AmqpTemplate amqpTemplate;
+
+ public void sendGenerateMessage(String mm) {
+ log.info("send message:" + mm);
+ amqpTemplate.convertAndSend(MQConfig.GENERATE_QUEUE, mm);
+
+ }
+
+}
diff --git a/src/main/java/com/ai/da/common/config/MyTaskScheduler.java b/src/main/java/com/ai/da/common/config/MyTaskScheduler.java
index 893c3771..0101d5bc 100644
--- a/src/main/java/com/ai/da/common/config/MyTaskScheduler.java
+++ b/src/main/java/com/ai/da/common/config/MyTaskScheduler.java
@@ -29,6 +29,9 @@ public class MyTaskScheduler {
// 用户到期时间戳
Long timestamp = account.getValidEndTime(); // 替换为你的时间戳
+ if (null == timestamp) {
+ continue;
+ }
// 获取当前时间戳
Long currentTimestamp = System.currentTimeMillis();
diff --git a/src/main/java/com/ai/da/common/config/RedisConfig.java b/src/main/java/com/ai/da/common/config/RedisConfig.java
new file mode 100644
index 00000000..5f01753c
--- /dev/null
+++ b/src/main/java/com/ai/da/common/config/RedisConfig.java
@@ -0,0 +1,42 @@
+package com.ai.da.common.config;
+
+import com.fasterxml.jackson.annotation.JsonAutoDetect;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import com.fasterxml.jackson.annotation.PropertyAccessor;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.data.redis.connection.RedisConnectionFactory;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
+import org.springframework.data.redis.serializer.StringRedisSerializer;
+
+@Configuration
+public class RedisConfig {
+ @Bean(name = "redisTemplate")
+ public RedisTemplate getRedisTemplate(RedisConnectionFactory factory) {
+ RedisTemplate redisTemplate = new RedisTemplate<>();
+ redisTemplate.setConnectionFactory(factory);
+
+ StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
+
+ redisTemplate.setKeySerializer(stringRedisSerializer); // key的序列化类型
+
+ Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
+ ObjectMapper objectMapper = new ObjectMapper();
+ objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
+ // 方法过期,改为下面代码
+// objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
+ objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance,
+ ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
+ jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
+ jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
+
+ redisTemplate.setValueSerializer(jackson2JsonRedisSerializer); // value的序列化类型
+ redisTemplate.setHashKeySerializer(stringRedisSerializer);
+ redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
+ redisTemplate.afterPropertiesSet();
+ return redisTemplate;
+ }
+}
diff --git a/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java b/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java
new file mode 100644
index 00000000..35039431
--- /dev/null
+++ b/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java
@@ -0,0 +1,73 @@
+package com.ai.da.common.utils;
+
+import com.ai.da.common.config.exception.BusinessException;
+import com.ai.da.model.dto.GenerateToPythonDTO;
+import com.ai.da.python.PythonService;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Component;
+
+import java.util.*;
+import java.util.concurrent.*;
+
+@Slf4j
+@Component
+public class AsyncCallerUtil {
+
+ public static Map waitingStatus = new HashMap<>();
+
+ private static PythonService pythonService;
+
+ @Autowired
+ public void setPythonService(PythonService pythonService) {
+ AsyncCallerUtil.pythonService = pythonService;
+ }
+
+ public CompletableFuture> callGenerateAsync(GenerateToPythonDTO generateToPython) {
+ return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython));
+ }
+
+ public List generate(GenerateToPythonDTO generateToPython) {
+ ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(5);
+ String taskId = generateToPython.getTasks_id();
+ ScheduledFuture> timeoutTask = null;
+ if (!waitingStatus.containsKey(taskId)) waitingStatus.put(taskId, true);
+
+ try {
+ CompletableFuture> generateResult = callGenerateAsync(generateToPython);
+ // 5秒后第一次确认,之后每隔10秒确认一次用户选择结果
+ timeoutTask = scheduledExecutorService.scheduleAtFixedRate(() -> {
+ // 调用另一个接口获取用户的选择
+ if (!waitingStatus.get(taskId)) {
+ // 如果用户选择取消,则取消对generate的调用
+ generateResult.cancel(true);
+ waitingStatus.remove(taskId);
+ } else log.info("===============持续等待===============");
+ }, 5, 10, TimeUnit.SECONDS);
+
+ log.info("阻塞等待结果...");
+ // 阻塞,等待结果
+ List result = generateResult.get();
+ // 取消定时任务
+ timeoutTask.cancel(true);
+ waitingStatus.remove(taskId);
+ return result;
+ } catch (CancellationException e) {
+ // generateResult.cancel(true);通过抛出异常取消该任务
+ log.info("==========成功取消generate任务==========");
+ return null;
+ } catch (InterruptedException | ExecutionException | BusinessException e) {
+ // 处理异常
+ log.error("发生错误 : " + e);
+ // 取消定时任务
+ assert timeoutTask != null;
+ timeoutTask.cancel(true);
+ throw new BusinessException(e.getMessage());
+ } finally {
+ // 关闭线程池
+// executorService.shutdown();
+// scheduledExecutorService.shutdown();
+ }
+ }
+
+}
diff --git a/src/main/java/com/ai/da/common/utils/MinioUtil.java b/src/main/java/com/ai/da/common/utils/MinioUtil.java
index 5e98eecb..352b4d79 100644
--- a/src/main/java/com/ai/da/common/utils/MinioUtil.java
+++ b/src/main/java/com/ai/da/common/utils/MinioUtil.java
@@ -188,7 +188,7 @@ public class MinioUtil {
}
}
}
- return bucketName + path;
+ return bucketName + "/" + path;
}
// public String upload(String bucketName, String path, File file) {
diff --git a/src/main/java/com/ai/da/common/utils/RedisUtil.java b/src/main/java/com/ai/da/common/utils/RedisUtil.java
new file mode 100644
index 00000000..5121e883
--- /dev/null
+++ b/src/main/java/com/ai/da/common/utils/RedisUtil.java
@@ -0,0 +1,126 @@
+package com.ai.da.common.utils;
+
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.data.redis.core.ZSetOperations;
+import org.springframework.stereotype.Component;
+import org.springframework.util.CollectionUtils;
+
+import javax.annotation.Resource;
+import java.util.Map;
+import java.util.Set;
+
+@Slf4j
+@Component
+public class RedisUtil {
+
+ @Resource
+ private RedisTemplate redisTemplate;
+
+ //- - - - - - - - - - - - - - - - - - - - - ZSet类型 - - - - - - - - - - - - - - - - - - - -
+
+ /**
+ * 向ZSet中添加元素
+ */
+ public void addToZSet(String key, String value, Double score) {
+ redisTemplate.opsForZSet().add(key, value, score);
+ }
+
+ /**
+ * 从ZSet中删除元素
+ */
+ public void removeFromZSet(String key, String value) {
+ redisTemplate.opsForZSet().remove(key, value);
+ }
+
+ /**
+ * 获取指定元素的当前排列顺序
+ */
+ public Long getRank(String key, String value) {
+ return redisTemplate.opsForZSet().rank(key, value);
+ }
+
+ /**
+ * 获取当前ZSet中的最大score
+ */
+ public Double getMaxScore(String key) {
+ Set> set = redisTemplate.opsForZSet().reverseRangeWithScores(key, 0, 0);
+
+ if (!CollectionUtils.isEmpty(set)) {
+ Double score = set.iterator().next().getScore();
+ return score + 1.0;
+ } else {
+ return 1.0;
+ }
+ }
+
+ /**
+ * 判断元素是否存在
+ */
+ public Boolean isElementExistsInZSet(String key, String value) {
+ return redisTemplate.opsForZSet().score(key, value) != null;
+ }
+
+ /**
+ * 获取当前ZSet中数据量的总和
+ */
+ public Long getZSetTotal(String key) {
+ return redisTemplate.opsForZSet().zCard(key);
+ }
+
+ //- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - -
+
+ /**
+ * 将数据放入set缓存
+ */
+ public void addToSet(String key, String value) {
+ redisTemplate.opsForSet().add(key, value);
+ }
+
+ /**
+ * 弹出变量中的元素
+ */
+ public void removeFromSet(String key, String value) {
+ redisTemplate.opsForSet().remove(key, value);
+ }
+
+ /**
+ * 检查给定的元素是否在变量中。
+ */
+ public Boolean isElementExistsInSet(String key, String obj) {
+ return redisTemplate.opsForSet().isMember(key, obj);
+ }
+
+
+ //- - - - - - - - - - - - - - - - - - - - - hash类型 - - - - - - - - - - - - - - - - - - - -
+
+ /**
+ * 加入缓存
+ */
+ public void addToMap(String key, Map map) {
+ redisTemplate.opsForHash().putAll(key, map);
+ }
+
+ /**
+ * 验证指定 key 下 有没有指定的 hashkey
+ */
+ public Boolean isElementExistsInMap(String key, String hashKey) {
+ return redisTemplate.opsForHash().hasKey(key, hashKey);
+ }
+
+ /**
+ * 获取指定key的值string
+ */
+ public String getMapValue(String key1, String key2) {
+ return String.valueOf(redisTemplate.opsForHash().get(key1, key2));
+ }
+
+ /**
+ * 删除指定 hash 的 HashKey
+ *
+ * @return 删除成功的 数量
+ */
+ public Long removeFromMap(String key, String hashKeys) {
+ return redisTemplate.opsForHash().delete(key, hashKeys);
+ }
+}
diff --git a/src/main/java/com/ai/da/common/utils/SendEmailUtil.java b/src/main/java/com/ai/da/common/utils/SendEmailUtil.java
index b4f7969c..82d6d163 100644
--- a/src/main/java/com/ai/da/common/utils/SendEmailUtil.java
+++ b/src/main/java/com/ai/da/common/utils/SendEmailUtil.java
@@ -34,6 +34,7 @@ public class SendEmailUtil {
* 发信地址
*/
private static String SEND_ADDRESS = "info@aida.com.hk";
+ private final static String CODE_CREATE_SEND_ADDRESS = "info@code-create.com.hk";
/**
* 登入主题
*/
@@ -219,18 +220,20 @@ public class SendEmailUtil {
JSONObject jsonObject = new JSONObject();
// 设置试用订单相关数据
jsonObject.put("userName", account.getUserName());
+
// 用户到期时间戳
Long timestamp = account.getValidEndTime(); // 替换为你的时间戳
+ if (null != timestamp) {
+ // 获取当前时间戳
+ Long currentTimestamp = System.currentTimeMillis();
- // 获取当前时间戳
- Long currentTimestamp = System.currentTimeMillis();
+ // 计算时间差(毫秒)
+ long timeDifference = currentTimestamp - timestamp;
- // 计算时间差(毫秒)
- long timeDifference = currentTimestamp - timestamp;
-
- // 向上取整计算天数
- long days = (timeDifference + 24 * 60 * 60 * 1000 - 1) / (24 * 60 * 60 * 1000);
- jsonObject.put("days", days);
+ // 向上取整计算天数
+ long days = (timeDifference + 24 * 60 * 60 * 1000 - 1) / (24 * 60 * 60 * 1000);
+ jsonObject.put("days", days);
+ }
return jsonObject.toJSONString();
}
@@ -269,4 +272,40 @@ public class SendEmailUtil {
jsonObject.put("email", trialOrder.getEmail());
return jsonObject.toJSONString();
}
+
+ private final static Long UPGRADE_NOTIFICATION_ID = 118855L;
+ public static void sendUpgradeNotification(Account account, String senderAddress) {
+ try {
+ // 实例化一个认证对象
+ Credential cred = new Credential(SECRET_ID, SECRET_KEy);
+ HttpProfile httpProfile = new HttpProfile();
+ httpProfile.setEndpoint("ses.tencentcloudapi.com");
+ ClientProfile clientProfile = new ClientProfile();
+ clientProfile.setHttpProfile(httpProfile);
+ SesClient client = new SesClient(cred, "ap-hongkong", clientProfile);
+ SendEmailRequest req = new SendEmailRequest();
+ if (StringUtils.isEmpty(senderAddress)) {
+ senderAddress = CODE_CREATE_SEND_ADDRESS;
+ }
+ req.setFromEmailAddress(senderAddress);
+ req.setDestination(new String[]{account.getUserEmail()});
+
+ // 根据邮件类型设置不同的主题和模板
+ String subject = "";
+ Template template = new Template();
+ subject = "Upcoming AiDA 3.0 Launch and Scheduled Maintenance";
+ template.setTemplateID(UPGRADE_NOTIFICATION_ID);
+ template.setTemplateData(buildAccountData(account));
+
+ req.setSubject(subject);
+ req.setTemplate(template);
+
+ // 发送邮件
+ SendEmailResponse resp = client.SendEmail(req);
+ log.info("短信发送结果res###{}", SendEmailResponse.toJsonString(resp));
+ } catch (TencentCloudSDKException e) {
+ log.info("邮件发送失败###{}", e.toString());
+ throw new BusinessException("failed.to.send.mail");
+ }
+ }
}
diff --git a/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java b/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java
new file mode 100644
index 00000000..03298fb4
--- /dev/null
+++ b/src/main/java/com/ai/da/common/utils/SnowflakeUtil.java
@@ -0,0 +1,137 @@
+package com.ai.da.common.utils;
+
+
+public class SnowflakeUtil {
+
+ // ==============================Fields===========================================
+ /** 开始时间截 (2015-01-01) */
+ private final long twepoch = 1420041600000L;
+
+ /** 机器id所占的位数 */
+ private final long workerIdBits = 5L;
+
+ /** 数据标识id所占的位数 */
+ private final long datacenterIdBits = 5L;
+
+ /** 支持的最大机器id,结果是31 (这个移位算法可以很快的计算出几位二进制数所能表示的最大十进制数) */
+ private final long maxWorkerId = -1L ^ (-1L << workerIdBits);
+
+ /** 支持的最大数据标识id,结果是31 */
+ private final long maxDatacenterId = -1L ^ (-1L << datacenterIdBits);
+
+ /** 序列在id中占的位数 */
+ private final long sequenceBits = 12L;
+
+ /** 机器ID向左移12位 */
+ private final long workerIdShift = sequenceBits;
+
+ /** 数据标识id向左移17位(12+5) */
+ private final long datacenterIdShift = sequenceBits + workerIdBits;
+
+ /** 时间截向左移22位(5+5+12) */
+ private final long timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits;
+
+ /** 生成序列的掩码,这里为4095 (0b111111111111=0xfff=4095) */
+ private final long sequenceMask = -1L ^ (-1L << sequenceBits);
+
+ /** 工作机器ID(0~31) */
+ private long workerId;
+
+ /** 数据中心ID(0~31) */
+ private long datacenterId;
+
+ /** 毫秒内序列(0~4095) */
+ private long sequence = 0L;
+
+ /** 上次生成ID的时间截 */
+ private long lastTimestamp = -1L;
+
+ //==============================Constructors=====================================
+ /**
+ * 构造函数
+ * @param workerId 工作ID (0~31)
+ * @param datacenterId 数据中心ID (0~31)
+ */
+ public SnowflakeUtil(long workerId, long datacenterId) {
+ if (workerId > maxWorkerId || workerId < 0) {
+ throw new IllegalArgumentException(String.format
+ ("worker Id can't be greater than %d or less than 0", maxWorkerId));
+ }
+ if (datacenterId > maxDatacenterId || datacenterId < 0) {
+ throw new IllegalArgumentException(String.format
+ ("datacenter Id can't be greater than %d or less than 0", maxDatacenterId));
+ }
+ this.workerId = workerId;
+ this.datacenterId = datacenterId;
+ }
+
+ // ==============================Methods==========================================
+ /**
+ * 获得下一个ID (该方法是线程安全的)
+ * @return SnowflakeId
+ */
+ public synchronized long nextId() {
+ long timestamp = timeGen();
+
+ //如果当前时间小于上一次ID生成的时间戳,说明系统时钟回退过这个时候应当抛出异常
+ if (timestamp < lastTimestamp) {
+ throw new RuntimeException(
+ String.format
+ ("Clock moved backwards. Refusing to generate id for %d milliseconds", lastTimestamp - timestamp));
+ }
+
+ //如果是同一时间生成的,则进行毫秒内序列
+ if (lastTimestamp == timestamp) {
+ sequence = (sequence + 1) & sequenceMask;
+ //毫秒内序列溢出
+ if (sequence == 0) {
+ //阻塞到下一个毫秒,获得新的时间戳
+ timestamp = tilNextMillis(lastTimestamp);
+ }
+ }
+ //时间戳改变,毫秒内序列重置
+ else {
+ sequence = 0L;
+ }
+
+ //上次生成ID的时间截
+ lastTimestamp = timestamp;
+
+ //移位并通过或运算拼到一起组成64位的ID
+ return ((timestamp - twepoch) << timestampLeftShift) //
+ | (datacenterId << datacenterIdShift) //
+ | (workerId << workerIdShift) //
+ | sequence;
+ }
+
+ /**
+ * 阻塞到下一个毫秒,直到获得新的时间戳
+ * @param lastTimestamp 上次生成ID的时间截
+ * @return 当前时间戳
+ */
+ protected long tilNextMillis(long lastTimestamp) {
+ long timestamp = timeGen();
+ while (timestamp <= lastTimestamp) {
+ timestamp = timeGen();
+ }
+ return timestamp;
+ }
+
+ /**
+ * 返回以毫秒为单位的当前时间
+ * @return 当前时间(毫秒)
+ */
+ protected long timeGen() {
+ return System.currentTimeMillis();
+ }
+
+ //==============================Test=============================================
+ /** 测试 */
+ public static void main(String[] args) {
+ SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
+ long id = idWorker.nextId();
+ System.out.println("id:"+id);
+ //id:768842202204864512
+ }
+}
+
diff --git a/src/main/java/com/ai/da/controller/AccountController.java b/src/main/java/com/ai/da/controller/AccountController.java
index 73aa3bc2..d9ed7803 100644
--- a/src/main/java/com/ai/da/controller/AccountController.java
+++ b/src/main/java/com/ai/da/controller/AccountController.java
@@ -2,11 +2,13 @@ package com.ai.da.controller;
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.common.response.Response;
-import com.ai.da.mapper.primary.entity.TrialOrder;
-import com.ai.da.mapper.secondary.entity.FemaleDress;
+import com.ai.da.common.security.jwt.JWTTokenHelper;
+import com.ai.da.mapper.entity.TrialOrder;
import com.ai.da.model.dto.*;
+import com.ai.da.model.enums.Language;
import com.ai.da.model.vo.AccountLoginVO;
import com.ai.da.model.vo.AccountPreLoginVO;
+import com.ai.da.model.vo.QueryLibraryPageVO;
import com.ai.da.service.AccountService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
@@ -148,10 +150,10 @@ public class AccountController {
return Response.success(accountService.noLoginRequired(noLoginRequiredDTO, request));
}
- @ApiOperation(value = "免密登录")
- @PostMapping("/test")
- public Response test(){
- return Response.success(accountService.test());
+ @PostMapping("upgradeNotification")
+ @ApiOperation(value = "升级邮件通知")
+ public Response upgradeNotification() {
+ accountService.upgradeNotification();
+ return Response.success(true);
}
-
}
diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java
index 5fa712cc..b3f6fb5b 100644
--- a/src/main/java/com/ai/da/controller/GenerateController.java
+++ b/src/main/java/com/ai/da/controller/GenerateController.java
@@ -53,4 +53,25 @@ public class GenerateController {
return Response.success(generateService.generateDislike(generateDetailId, timeZone));
}
+ @ApiOperation(value = "发起生成请求,异步获取结果")
+ @PostMapping("/prepare")
+ public Response prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) {
+ return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO));
+ }
+
+ @ApiOperation(value = "取消继续生成")
+ @GetMapping("/stopWaiting")
+ public Response stopWaiting(@RequestParam("userId") Long userId,
+ @RequestParam("uniqueId") String uniqueId,
+ @RequestParam("timeZone") String timeZone) {
+ generateService.cancelGenerate(userId, uniqueId, timeZone);
+ return Response.success("stop waiting successfully");
+ }
+
+ @ApiOperation(value = "获取生成结果")
+ @GetMapping("/result")
+ public Response getGenerateResult(@RequestParam("uniqueId") String uniqueId) {
+ GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId);
+ return Response.success(generateResult);
+ }
}
diff --git a/src/main/java/com/ai/da/controller/LibraryController.java b/src/main/java/com/ai/da/controller/LibraryController.java
index 09263d09..9c3a6ebb 100644
--- a/src/main/java/com/ai/da/controller/LibraryController.java
+++ b/src/main/java/com/ai/da/controller/LibraryController.java
@@ -26,6 +26,7 @@ import org.springframework.web.multipart.MultipartFile;
import javax.annotation.Resource;
import javax.validation.Valid;
import java.io.File;
+import java.text.ParseException;
import java.util.Date;
import java.util.Objects;
import java.util.UUID;
@@ -194,4 +195,11 @@ public class LibraryController {
return "/workspace/python_code/Multi-layer-Virtual-Try-on/dataset_for_test/Img_model.png";
}
+ @PostMapping("moveLibraryData")
+ @ApiOperation(value = "用户library数据迁移")
+ public Response moveLibraryDate() throws ParseException {
+ libraryService.moveLibraryDate();
+ return Response.success(true);
+ }
+
}
diff --git a/src/main/java/com/ai/da/mapper/GenerateCancelMapper.java b/src/main/java/com/ai/da/mapper/GenerateCancelMapper.java
new file mode 100644
index 00000000..3e24ea09
--- /dev/null
+++ b/src/main/java/com/ai/da/mapper/GenerateCancelMapper.java
@@ -0,0 +1,7 @@
+package com.ai.da.mapper;
+
+import com.ai.da.common.config.mybatis.plus.CommonMapper;
+import com.ai.da.mapper.entity.GenerateCancel;
+
+public interface GenerateCancelMapper extends CommonMapper {
+}
diff --git a/src/main/java/com/ai/da/mapper/LibraryCopyMapper.java b/src/main/java/com/ai/da/mapper/LibraryCopyMapper.java
new file mode 100644
index 00000000..3cec4a2e
--- /dev/null
+++ b/src/main/java/com/ai/da/mapper/LibraryCopyMapper.java
@@ -0,0 +1,15 @@
+package com.ai.da.mapper;
+
+import com.ai.da.common.config.mybatis.plus.CommonMapper;
+import com.ai.da.mapper.entity.Library;
+import com.ai.da.mapper.entity.LibraryCopy;
+
+/**
+ * Mapper 接口
+ *
+ * @author easy-generator
+ * @since 2022-06-13
+ */
+public interface LibraryCopyMapper extends CommonMapper {
+
+}
diff --git a/src/main/java/com/ai/da/mapper/LibraryModelPointCopyMapper.java b/src/main/java/com/ai/da/mapper/LibraryModelPointCopyMapper.java
new file mode 100644
index 00000000..05447f66
--- /dev/null
+++ b/src/main/java/com/ai/da/mapper/LibraryModelPointCopyMapper.java
@@ -0,0 +1,15 @@
+package com.ai.da.mapper;
+
+import com.ai.da.common.config.mybatis.plus.CommonMapper;
+import com.ai.da.mapper.entity.LibraryModelPoint;
+import com.ai.da.mapper.entity.LibraryModelPointCopy;
+
+/**
+ * Mapper 接口
+ *
+ * @author easy-generator
+ * @since 2022-11-11
+ */
+public interface LibraryModelPointCopyMapper extends CommonMapper {
+
+}
diff --git a/src/main/java/com/ai/da/mapper/primary/entity/Generate.java b/src/main/java/com/ai/da/mapper/primary/entity/Generate.java
index 009e7fe0..8da4dbcb 100644
--- a/src/main/java/com/ai/da/mapper/primary/entity/Generate.java
+++ b/src/main/java/com/ai/da/mapper/primary/entity/Generate.java
@@ -27,6 +27,11 @@ public class Generate {
*/
private Long accountId;
+ /**
+ * 唯一id
+ */
+ private String uniqueId;
+
/**
* Sketchboard Printboard
*/
diff --git a/src/main/java/com/ai/da/mapper/primary/entity/GenerateCancel.java b/src/main/java/com/ai/da/mapper/primary/entity/GenerateCancel.java
new file mode 100644
index 00000000..8726b9ed
--- /dev/null
+++ b/src/main/java/com/ai/da/mapper/primary/entity/GenerateCancel.java
@@ -0,0 +1,44 @@
+package com.ai.da.mapper.entity;
+
+import com.baomidou.mybatisplus.annotation.IdType;
+import com.baomidou.mybatisplus.annotation.TableId;
+import com.baomidou.mybatisplus.annotation.TableName;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.experimental.Accessors;
+
+import java.util.Date;
+
+@Data
+@EqualsAndHashCode(callSuper = false)
+@Accessors(chain = true)
+@TableName("t_generate_cancel")
+public class GenerateCancel {
+
+ /**
+ * ID
+ */
+ @TableId(value = "id", type = IdType.AUTO)
+ private Long id;
+
+ /**
+ * 用户ID
+ */
+ private Long accountId;
+
+ /**
+ * 唯一id(任务id)
+ */
+ private String uniqueId;
+
+ /**
+ * 创建时间
+ */
+ private Date createDate;
+
+ public GenerateCancel(Long accountId, String uniqueId, Date createDate) {
+ this.accountId = accountId;
+ this.uniqueId = uniqueId;
+ this.createDate = createDate;
+ }
+}
diff --git a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java
index ca12d04a..fd45533c 100644
--- a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java
+++ b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java
@@ -5,10 +5,14 @@ import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.NotBlank;
+import javax.validation.constraints.NotNull;
@Data
@ApiModel("GenerateThroughImageTextDTO")
public class GenerateThroughImageTextDTO {
+ @NotNull(message = "userId cannot be empty")
+ @ApiModelProperty("用户id")
+ Long userId;
@ApiModelProperty("caption")
String text;
@@ -40,4 +44,7 @@ public class GenerateThroughImageTextDTO {
@NotBlank(message = "timeZone cannot be empty!")
@ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取")
String timeZone;
+
+ @ApiModelProperty("唯一id,用于保持消息唯一性")
+ String uniqueId;
}
diff --git a/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java b/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java
new file mode 100644
index 00000000..927fff70
--- /dev/null
+++ b/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java
@@ -0,0 +1,27 @@
+package com.ai.da.model.dto;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class GenerateToPythonDTO {
+
+ private Long user_id;
+
+ private String image_url;
+
+ private String category;
+
+ private String content;
+
+ private Integer mode;
+
+ private String version;
+
+ private String gender;
+
+ private String tasks_id;
+}
diff --git a/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java b/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java
index 433e86c8..aa8fb7c6 100644
--- a/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java
+++ b/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java
@@ -19,6 +19,13 @@ public class GenerateCollectionVO {
@ApiModelProperty("生成的图片信息")
private List generatedCollectionItems;
+ @ApiModelProperty("在当前队列中的排序")
+ private Long rankPosition;
+
+ public GenerateCollectionVO(Long rankPosition) {
+ this.rankPosition = rankPosition;
+ }
+
public GenerateCollectionVO(Long generateId, Long collectionId, List generatedCollectionItems) {
this.generateId = generateId;
this.collectionId = collectionId;
diff --git a/src/main/java/com/ai/da/model/vo/PageQueryBaseVo.java b/src/main/java/com/ai/da/model/vo/PageQueryBaseVo.java
index 496f0122..8986433c 100644
--- a/src/main/java/com/ai/da/model/vo/PageQueryBaseVo.java
+++ b/src/main/java/com/ai/da/model/vo/PageQueryBaseVo.java
@@ -18,6 +18,6 @@ public class PageQueryBaseVo {
@ApiModelProperty("每页数量")
@Min(value = 0, message = "The minimum size is 1")
- @Max(value = 50, message = "The maximum size is 50")
+// @Max(value = 50, message = "The maximum size is 50")
private Integer size = 20;
}
diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java
index 42fa011a..cc8f32ae 100644
--- a/src/main/java/com/ai/da/python/PythonService.java
+++ b/src/main/java/com/ai/da/python/PythonService.java
@@ -56,6 +56,7 @@ public class PythonService {
private String accessPythonIp;
@Value("${access.python.port:''}")
private String accessPythonPort;
+
@Resource
private PythonTAllInfoService pythonTAllInfoService;
@@ -106,7 +107,7 @@ public class PythonService {
if (Objects.nonNull(response.body())) {
responseBody = response.body().string();
JSONObject responseObj = JSON.parseObject(responseBody);
- log.info("moodboard与printboard图片合成 python返回###{}",responseObj);
+ log.info("moodboard与printboard图片合成 python返回###{}", responseObj);
return responseObj.get("data").toString();
}
} catch (IOException | JSONException e) {
@@ -410,7 +411,7 @@ public class PythonService {
all.addAll(new ArrayList<>(DesignPythonItem.SKIRT_TROUSERS));
return all;
}
- }else if (modelSex.equals(Sex.MALE.getValue())) {
+ } else if (modelSex.equals(Sex.MALE.getValue())) {
Long randomIndex = RandomsUtil.randomSysFile(0L, 3L);
if (randomIndex == 0) {
return DesignPythonItem.TOPS;
@@ -443,11 +444,11 @@ public class PythonService {
long noPinNum = printBoardElements.stream().filter(f -> f.getHasPin() == 0).count();
if (noPinNum == 0L) {
return 0;
- }else {
+ } else {
long pinNum = printBoardElements.stream().filter(f -> f.getHasPin() == 1).count();
if (8 - pinNum < 4) {
return RandomsUtil.randomSysFile(0L, 8 - pinNum + 1);
- }else {
+ } else {
return RandomsUtil.randomSysFile(0L, 4L + 1);
}
}
@@ -1621,7 +1622,7 @@ public class PythonService {
printBoardElements = elementVO.getPrintBoardElements()
.stream()
.filter(f -> !elementVO.getHasUseMd5List().contains(f.getMd5())).collect(Collectors.toList());
- }else {
+ } else {
printBoardElements = elementVO.getPrintBoardElements();
}
if (CollectionUtil.isEmpty(printBoardElements)) {
@@ -2102,7 +2103,7 @@ public class PythonService {
skirt.setPrint(designPythonItemPrint);
skirt.setPath("aida-sys-image/images/female/trousers/trousers_974.jpg");
response.add(skirt);
- }else {
+ } else {
DesignPythonItem top = new DesignPythonItem();
top.setType(MalePosition.TOPS.getValue());
top.setColor("none");
@@ -2216,7 +2217,9 @@ public class PythonService {
throw new BusinessException("design.interface.exception");
}
- /** 暂时未用 */
+ /**
+ * 暂时未用
+ */
public String generateSketchCaption(String url) {
//限流校验
AccessLimitUtils.validate("generateSketchCaption", 5);
@@ -2259,9 +2262,9 @@ public class PythonService {
throw new BusinessException("system error!");
}
- public List generateSketchOrPrint(Long userId, String url, String category, String text, int mode, String modelName, String gender) {
+ public List generateSketchOrPrint(GenerateToPythonDTO generateToPythonDTO) {
//限流校验
- AccessLimitUtils.validate("generateSketchOrPrint", 5);
+// AccessLimitUtils.validate("generateSketchOrPrint", 5);
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
@@ -2269,46 +2272,45 @@ public class PythonService {
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
MediaType mediaType = MediaType.parse("application/json");
- Map content = Maps.newHashMap();
- content.put("user_id", userId);
- content.put("image_url", url);
- content.put("category", category);
- content.put("mode", mode);
- content.put("str", text);
- content.put("version", "1");
- content.put("gender", gender);
- RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content, SerializerFeature.WriteMapNullValue));
+ RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue));
Request request = new Request.Builder()
-// .url(accessPythonIp + ":2828/aida/diffusion")
// .url("http://18.167.251.121:9992")
- .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion")
+// .url("http://127.0.0.1:5000/api/diffusion")
+// .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion")
+ .url(accessPythonIp + ":" + accessPythonPort + "/api/generate_image")
.method("POST", body)
- .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
+// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
.addHeader("Content-Type", "application/json")
.build();
Response response = null;
- String bodyString ;
+ String bodyString;
try {
- log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(content, SerializerFeature.WriteMapNullValue));
+ log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue));
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("PythonService##generateSketchOrPrint异常###{}", ExceptionUtil.getThrowableList(ioException));
+// throw new BusinessException("generate.interface.error");
+ throw new BusinessException(ioException.getMessage());
}
//去除限流
- AccessLimitUtils.validateOut("generateSketchOrPrint");
+// AccessLimitUtils.validateOut("generateSketchOrPrint");
// 判断是否生成失败
- if (Objects.isNull(response) || Objects.isNull(response.body())) {
+ if (Objects.isNull(response.body())) {
log.error("PythonService##generateSketchOrPrint异常###{}", "response or body is empty!");
- throw new BusinessException("generate.interface.error");
- } else if (response.code() != HttpURLConnection.HTTP_OK){
+// throw new BusinessException("generate.interface.error");
+ throw new BusinessException("PythonService##generateSketchOrPrint异常###: response or body is empty!");
+ } else if (response.code() != HttpURLConnection.HTTP_OK) {
log.error("PythonService##generateSketchOrPrint异常###{}", "Response error!Response code ## " + response.code() + " ##");
- throw new BusinessException("generate.interface.error");
+// throw new BusinessException("generate.interface.error");
+ throw new BusinessException("PythonService##generateSketchOrPrint异常### Response error!Response code ## " + response.code() + " ##");
} else {
try {
bodyString = response.body().string();
} catch (IOException e) {
- throw new BusinessException("generate.interface.error");
+ log.error(e.getMessage());
+// throw new BusinessException("generate.interface.error");
+ throw new BusinessException(e.getMessage());
}
}
JSONObject jsonObject = JSON.parseObject(bodyString);
@@ -2378,7 +2380,7 @@ public class PythonService {
if (Objects.isNull(response) || Objects.isNull(response.body())) {
log.error("PythonService##composeLayers异常###{}", "response or body is empty!");
throw new BusinessException("compose-layer.interface.exception");
- }else if (response.code() != HttpURLConnection.HTTP_OK){
+ } else if (response.code() != HttpURLConnection.HTTP_OK) {
log.error("PythonService##composeLayers异常###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("compose-layer.interface.exception");
} else {
@@ -2403,10 +2405,10 @@ public class PythonService {
return item0.getString("synthesis_url");
}
- public String getClothCategory(String path,String gender){
+ public String getClothCategory(String path, String gender) {
HashMap content = new HashMap<>();
- content.put("sketch_img_url",path);
- content.put("colony",gender);
+ content.put("sketch_img_url", path);
+ content.put("colony", gender);
List> contents = Collections.singletonList(content);
String jsonString = JSON.toJSONString(contents, SerializerFeature.WriteNullStringAsEmpty);
@@ -2419,7 +2421,7 @@ public class PythonService {
if (Objects.isNull(response) || Objects.isNull(response.body())) {
log.error("PythonService##GetClothCategory###{}", "response or body is empty!");
throw new BusinessException("cloth-classification.interface.exception");
- } else if (response.code() != HttpURLConnection.HTTP_OK){
+ } else if (response.code() != HttpURLConnection.HTTP_OK) {
log.error("PythonService##GetClothCategory###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("cloth-classification.interface.exception");
} else {
@@ -2430,7 +2432,7 @@ public class PythonService {
}
}
JSONObject jsonObject = JSON.parseObject(bodyString);
- try{
+ try {
Boolean result = JSON.parseObject(JSON.toJSONString(response)).getBoolean("successful");
if (result && jsonObject.get("msg").equals("OK!")) {
JSONObject data = jsonObject.getJSONObject("data");
@@ -2438,7 +2440,7 @@ public class PythonService {
JSONObject map = (JSONObject) list.get(0);
return map.get("category").toString();
}
- }catch (NullPointerException e){
+ } catch (NullPointerException e) {
log.info("getClothCategory 失败###{},未返回category", jsonObject);
throw new BusinessException("cloth-classification.interface.exception");
}
@@ -2446,4 +2448,35 @@ public class PythonService {
//生成失败
throw new BusinessException("cloth-classification.interface.exception");
}
+
+ public Boolean cancelGenerateTask(String taskId) {
+ OkHttpClient client = new OkHttpClient().newBuilder()
+ .connectTimeout(30, TimeUnit.SECONDS)
+ .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
+ .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
+ .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
+ .build();
+ String url = accessPythonIp + ":" + accessPythonPort + "/api/generate_cancel/" + taskId;
+ Request request = new Request.Builder()
+ .url(url)
+// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
+// .addHeader("Content-Type", "application/json")
+ .build();
+ Response response;
+ try {
+ log.info("cancelGenerateTask请求入参content###{}", taskId);
+ response = client.newCall(request).execute();
+ } catch (IOException ioException) {
+ log.error("PythonService##cancelGenerateTask异常###{}", ExceptionUtil.getThrowableList(ioException));
+ return null;
+ }
+
+ if (response.code() != HttpURLConnection.HTTP_OK) {
+ log.info("generate-python 取消请求失败");
+ return Boolean.FALSE;
+ }
+ log.info("generate-python 取消请求成功");
+ return Boolean.TRUE;
+ }
+
}
diff --git a/src/main/java/com/ai/da/service/AccountService.java b/src/main/java/com/ai/da/service/AccountService.java
index 8df6b601..d2982dd1 100644
--- a/src/main/java/com/ai/da/service/AccountService.java
+++ b/src/main/java/com/ai/da/service/AccountService.java
@@ -126,5 +126,7 @@ public interface AccountService extends IService {
Boolean deleteNoLoginRequiredNew(NoLoginRequiredDTO noLoginRequiredDTO, HttpServletRequest request);
- FemaleDress test();
+ void upgradeNotification();
+
+ void moveLibraryDate();
}
diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java
index 762ffbe9..09b10687 100644
--- a/src/main/java/com/ai/da/service/GenerateService.java
+++ b/src/main/java/com/ai/da/service/GenerateService.java
@@ -24,4 +24,13 @@ public interface GenerateService extends IService {
void updateLikeStatusBatch(List generateDetailIdList, Byte hasLike, Long libraryId, String timeZone);
List selectBatchByLibraryId(List libraryId);
+
+ GenerateCollectionVO getGenerateResult(String uniqueId);
+
+ String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO);
+
+ Long getRankPosition(String uniqueId);
+
+ void cancelGenerate(Long userId, String uniqueId, String timeZone);
+
}
diff --git a/src/main/java/com/ai/da/service/LibraryService.java b/src/main/java/com/ai/da/service/LibraryService.java
index 440a9577..d2083312 100644
--- a/src/main/java/com/ai/da/service/LibraryService.java
+++ b/src/main/java/com/ai/da/service/LibraryService.java
@@ -8,6 +8,8 @@ import com.ai.da.model.vo.LibraryVo;
import com.ai.da.model.vo.QueryLibraryPageVO;
import com.baomidou.mybatisplus.extension.service.IService;
+import javax.validation.Valid;
+import java.text.ParseException;
import java.util.List;
/**
@@ -84,4 +86,6 @@ public interface LibraryService extends IService {
void batchDeleteLibrary(LibraryDeleteDTO deleteDTO);
void deleteTrialData(Long id);
+
+ void moveLibraryDate() throws ParseException;
}
diff --git a/src/main/java/com/ai/da/service/RabbitMQService.java b/src/main/java/com/ai/da/service/RabbitMQService.java
new file mode 100644
index 00000000..b03a578d
--- /dev/null
+++ b/src/main/java/com/ai/da/service/RabbitMQService.java
@@ -0,0 +1,11 @@
+package com.ai.da.service;
+
+import org.springframework.stereotype.Service;
+
+@Service
+public interface RabbitMQService {
+
+ void publishMessage(String message);
+
+ Integer getMessageCount(String queueUrl);
+}
diff --git a/src/main/java/com/ai/da/service/impl/AccountServiceImpl.java b/src/main/java/com/ai/da/service/impl/AccountServiceImpl.java
index 446186c6..de53e544 100644
--- a/src/main/java/com/ai/da/service/impl/AccountServiceImpl.java
+++ b/src/main/java/com/ai/da/service/impl/AccountServiceImpl.java
@@ -476,7 +476,14 @@ public class AccountServiceImpl extends ServiceImpl impl
.eq(Account::getUserName, accountTrialDTO.getUserName());
List accountList = accountMapper.selectList(qw);
if (CollectionUtil.isNotEmpty(accountList)) {
- throw new BusinessException("The username or email has already been registered", ResultEnum.PROMPT.getCode());
+ if (accountList.get(0).getIsTrial() == 1) {
+ throw new BusinessException("The username or email has already been registered", ResultEnum.PROMPT.getCode());
+ }else {
+ Account account = accountList.get(0);
+ if (null == account.getValidEndTime() || account.getValidEndTime() > System.currentTimeMillis()) {
+ throw new BusinessException("The username or email has already been registered", ResultEnum.PROMPT.getCode());
+ }
+ }
}
// 接收到数据后要形成一条使用订单信息
TrialOrder trialOrder = CopyUtil.copyObject(accountTrialDTO, TrialOrder.class);
@@ -491,16 +498,25 @@ public class AccountServiceImpl extends ServiceImpl impl
trialOrder.setUpdateTime(LocalDateTime.now());
trialOrderMapper.updateById(trialOrder);
Account account = new Account();
- account.setUserName(trialOrder.getUserName());
- account.setUserPassword("Third-000000");
- account.setUserEmail(trialOrder.getEmail());
- account.setLanguage(Language.ENGLISH.name());
- account.setValidStartTime(System.currentTimeMillis());
- account.setValidEndTime(Instant.now().plus(3, ChronoUnit.DAYS).toEpochMilli());
- account.setCreateDate(new Date());
- account.setIsTrial(1);
- account.setIsBeginner(1);
- accountMapper.insert(account);
+ if (CollectionUtil.isNotEmpty(accountList)) {
+ account = CopyUtil.copyObject(accountList.get(0), Account.class);
+ account.setIsTrial(1);
+ account.setIsBeginner(1);
+ account.setValidStartTime(System.currentTimeMillis());
+ account.setValidEndTime(Instant.now().plus(3, ChronoUnit.DAYS).toEpochMilli());
+ accountMapper.updateById(account);
+ }else {
+ account.setUserName(trialOrder.getUserName());
+ account.setUserPassword("Third-000000");
+ account.setUserEmail(trialOrder.getEmail());
+ account.setLanguage(Language.ENGLISH.name());
+ account.setValidStartTime(System.currentTimeMillis());
+ account.setValidEndTime(Instant.now().plus(3, ChronoUnit.DAYS).toEpochMilli());
+ account.setCreateDate(new Date());
+ account.setIsTrial(1);
+ account.setIsBeginner(1);
+ accountMapper.insert(account);
+ }
// 发送邮件提醒用户试用用户已创建
SendEmailUtil.sendCustomEmail("1023316923@qq.com", null, trialOrder,2);
SendEmailUtil.sendCustomEmail(account.getUserEmail(), null, trialOrder, 3);
@@ -522,17 +538,33 @@ public class AccountServiceImpl extends ServiceImpl impl
trialOrder.setStatus(1);
trialOrder.setUpdateTime(LocalDateTime.now());
trialOrderMapper.updateById(trialOrder);
+
+ QueryWrapper qw = new QueryWrapper<>();
+ qw.lambda().eq(Account::getUserEmail, trialOrder.getEmail())
+ .or()
+ .eq(Account::getUserName, trialOrder.getUserName());
+ List accountList = accountMapper.selectList(qw);
+
Account account = new Account();
- account.setUserName(trialOrder.getUserName());
- account.setUserPassword("Third-000000");
- account.setUserEmail(trialOrder.getEmail());
- account.setLanguage(Language.ENGLISH.name());
- account.setValidStartTime(System.currentTimeMillis());
- account.setValidEndTime(Instant.now().plus(3, ChronoUnit.DAYS).toEpochMilli());
- account.setCreateDate(new Date());
- account.setIsTrial(1);
- account.setIsBeginner(1);
- accountMapper.insert(account);
+ if (CollectionUtil.isNotEmpty(accountList)) {
+ account = CopyUtil.copyObject(accountList.get(0), Account.class);
+ account.setIsTrial(1);
+ account.setIsBeginner(1);
+ account.setValidStartTime(System.currentTimeMillis());
+ account.setValidEndTime(Instant.now().plus(3, ChronoUnit.DAYS).toEpochMilli());
+ accountMapper.updateById(account);
+ }else {
+ account.setUserName(trialOrder.getUserName());
+ account.setUserPassword("Third-000000");
+ account.setUserEmail(trialOrder.getEmail());
+ account.setLanguage(Language.ENGLISH.name());
+ account.setValidStartTime(System.currentTimeMillis());
+ account.setValidEndTime(Instant.now().plus(3, ChronoUnit.DAYS).toEpochMilli());
+ account.setCreateDate(new Date());
+ account.setIsTrial(1);
+ account.setIsBeginner(1);
+ accountMapper.insert(account);
+ }
// 发送邮件提醒用户试用用户已创建
SendEmailUtil.sendCustomEmail("1023316923@qq.com", null, trialOrder,2);
SendEmailUtil.sendCustomEmail(account.getUserEmail(), null, trialOrder, 3);
@@ -835,10 +867,25 @@ public class AccountServiceImpl extends ServiceImpl impl
return Boolean.TRUE;
}
- @Resource
- private FemaleDressMapper femaleDressMapper;
@Override
- public FemaleDress test() {
- return femaleDressMapper.selectById(33056);
+ public void upgradeNotification() {
+ QueryWrapper queryWrapper = new QueryWrapper<>();
+ queryWrapper.and(wrapper ->
+ wrapper.gt("valid_end_time", 1706112000000L)
+ .or().isNull("valid_end_time"))
+ .isNotNull("user_email");
+
+ List accountList = accountMapper.selectList(queryWrapper);
+ System.out.println(accountList);
+ for (Account account : accountList) {
+ SendEmailUtil.sendUpgradeNotification(account, null);
+ }
+ }
+
+ @Override
+ public void moveLibraryDate() {
+ // 查询生产全部library数据,遍历数据,根据用户id和md5查询是否已经迁移过
+
+ // 未迁移过的进行迁移,注意模特数据迁移打点信息以及转换模特格式
}
}
diff --git a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java
index 5b9d3710..82146f3b 100644
--- a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java
+++ b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java
@@ -810,9 +810,11 @@ public class CollectionElementServiceImpl extends ServiceImpl implements GenerateService {
@@ -52,10 +59,31 @@ public class GenerateServiceImpl extends ServiceImpl i
@Resource
private MinioUtil minioUtil;
+ @Resource
+ private RabbitMQService rabbitMQService;
+
+ @Resource
+ private RedisUtil redisUtil;
+
+ @Resource
+ private GenerateCancelMapper generateCancelMapper;
+
+ @Value("${redis.key.consumptionOrder}")
+ private String consumptionOrderKey;
+
+ @Value("${redis.key.cancelSet}")
+ private String cancelSetKey;
+
+ @Value("${redis.key.exceptionMap}")
+ private String exceptionMapKey;
+
+ @Value("${redis.key.resultMap}")
+ private String resultMapKey;
+
@Override
public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
- if (Objects.isNull(collectionElement)){
+ if (Objects.isNull(collectionElement)) {
throw new BusinessException("the.image.does.not.exist.please.reselect");
}
String url = collectionElement.getUrl();
@@ -69,45 +97,46 @@ public class GenerateServiceImpl extends ServiceImpl i
@Transactional(rollbackFor = Exception.class)
public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、获取用户信息
- AuthPrincipalVo userHolder = UserContext.getUserHolder();
+ Long accountId = generateThroughImageTextDTO.getUserId();
String generateType = generateThroughImageTextDTO.getGenerateType();
- Long accountId = userHolder.getId();
- if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){
- throw new BusinessException("unknown.generate.type");
- }
// 2、判断必须入参是否为非空
Generate generate = new Generate();
generate.setAccountId(accountId);
+ generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type());
// 当level1type是sketchboard时,存数据库需要加上当前性别
generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ?
- generateType + " (" +generateThroughImageTextDTO.getGender() + ")":
+ generateType + " (" + generateThroughImageTextDTO.getGender() + ")" :
generateType);
generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion());
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
-
String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
- validateGeneraType(generate, text, elementId,generateType);
+ validateGeneraType(generate, text, elementId, generateType);
- // 3、将请求信息落库
- // 3.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
+ // 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type());
- // 3.2 将本次generate的请求信息添加到t_generate表中
- save(generate);
-
- // 4、向模型发起请求
+ // 3、向模型发起请求
int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
GenerateModeEnum.TEXT.getCode() :
GenerateModeEnum.TEXT_IMAGE.getCode();
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
-// text = !StringUtil.isNullOrEmpty(text) && generateThroughImageTextDTO.getVersion().equals("1") ? "painting style, " + text : text;
- List generatedSketchUrl = pythonService.generateSketchOrPrint(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
- category, text, mode, generateThroughImageTextDTO.getVersion(), generateThroughImageTextDTO.getGender());
+ AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
+ List generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
+ category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
+// List generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
+// category, text, mode, "1", generateThroughImageTextDTO.getGender()));
+ log.info("generate 响应 : " + generatedSketchUrl);
+ if (CollectionUtils.isEmpty(generatedSketchUrl)) {
+ return null;
+ }
+
+ // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
+ save(generate);
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
@@ -117,8 +146,8 @@ public class GenerateServiceImpl extends ServiceImpl i
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(item, 24 * 60), Boolean.FALSE);
// 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过
- List