Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
shahaibo
2024-01-26 13:18:37 +08:00
24 changed files with 1167 additions and 82 deletions

View File

@@ -0,0 +1,45 @@
package com.ai.da.common.RabbitMQ;
import org.springframework.amqp.core.Binding;
import org.springframework.amqp.core.BindingBuilder;
import org.springframework.amqp.core.FanoutExchange;
import org.springframework.amqp.core.Queue;
import org.springframework.amqp.rabbit.connection.CachingConnectionFactory;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.amqp.rabbit.core.RabbitAdmin;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.beans.factory.annotation.Value;
@Configuration
public class MQConfig {
public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange";
public static final String GENERATE_QUEUE = "generate-queue-prod";
public MQConfig() {
}
// @Bean
// FanoutExchange fanoutRasaExchange() {
// return new FanoutExchange(GENERATE_EXCHANGE_FANOUT);
// }
/**
* 创建队列,使用工作模式,不用定义交换机
*/
@Bean
public Queue queueRasa() {
return new Queue(GENERATE_QUEUE);
}
/**
* 将队列绑定到交换机上【队列订阅交换机】
*/
// @Bean
// Binding bindingExchangeRasa() {
// return BindingBuilder.bind(queueRasa()).to(fanoutRasaExchange());
// }
}

View File

@@ -0,0 +1,161 @@
package com.ai.da.common.RabbitMQ;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.utils.RedisUtil;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.GenerateCollectionVO;
import com.ai.da.service.GenerateService;
import com.alibaba.fastjson.JSONObject;
import com.rabbitmq.client.Channel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.core.Message;
import org.springframework.amqp.rabbit.annotation.RabbitHandler;
import org.springframework.amqp.rabbit.annotation.RabbitListener;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.io.IOException;
import java.util.HashMap;
import java.util.Objects;
@Slf4j
@Component
public class MQConsumer {
@Resource
private GenerateService generateService;
@Resource
private RedisUtil redisUtil;
@Value("${redis.key.consumptionOrder}")
private String consumptionOrderKey;
@Value("${redis.key.cancelSet}")
private String cancelSetKey;
@Value("${redis.key.exceptionMap}")
private String exceptionMapKey;
@Value("${redis.key.resultMap}")
private String resultMapKey;
public void generate(Message msg, Channel channel, String consumerName) {
log.info("============start listening==========");
long start = System.currentTimeMillis();
GenerateThroughImageTextDTO generateThroughImageTextDTO = JSONObject.parseObject(msg.getBody(), GenerateThroughImageTextDTO.class);
String uniqueId = generateThroughImageTextDTO.getUniqueId();
log.info("From " + consumerName + " : " + uniqueId);
try {
// 2、判断当前消息是否在取消列表中
Boolean isMember = redisUtil.isElementExistsInSet(cancelSetKey, uniqueId);
if (isMember) {
try {
// 2.1 手动确认该消息
channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
} catch (IOException ex) {
log.error("手动确认,不返回队列重新消费");
}
// 2.2 将该消息从取消列表中删除
// redisUtil.removeFromSet(cancelSetKey, uniqueId);
} else {
/*try {
Thread.sleep(15000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}*/
GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO);
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
if (!Objects.isNull(generateCollectionVO)) {
HashMap<String, String> generateResult = new HashMap<>();
generateResult.put(uniqueId, JSONObject.toJSONString(generateCollectionVO));
// 将结果存在redis中 ,为空时不要存
redisUtil.addToMap(resultMapKey, generateResult);
}
}
} catch (BusinessException e) {
log.error(e.getMsg());
// channel.basicNack() 为不确认deliveryTag对应的消息第二个参数是否应用于多消息第三个参数是否requeue
try {
// 第二个参数是否批量确认消息当传false时只确认当前 deliveryTag对应的消息;当传true时会确认当前及之前所有未确认的消息。
channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
} catch (IOException exception) {
log.error("手动确认,取消返回队列,不再重新消费");
}
// 将入参和错误信息存入数据库
String exceptionMessage = JSONObject.toJSONString(generateThroughImageTextDTO) +
" Exception message " + e.getMsg();
HashMap<String, String> exceptionInfo = new HashMap<>();
exceptionInfo.put(String.valueOf(uniqueId), exceptionMessage);
// 存redis
redisUtil.addToMap(exceptionMapKey, exceptionInfo);
}
long end = System.currentTimeMillis();
log.info(" task_id " + uniqueId + "----------" + consumerName + " 执行时长:" + (end - start) + "毫秒");
log.info("=============end listening===========");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer1(Message msg, Channel channel) {
generate(msg, channel, "consumer 1");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer2(Message msg, Channel channel) {
generate(msg, channel, "consumer 2");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer3(Message msg, Channel channel) {
generate(msg, channel, "consumer 3");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer4(Message msg, Channel channel) {
generate(msg, channel, "consumer 4");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer5(Message msg, Channel channel) {
generate(msg, channel, "consumer 5");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer6(Message msg, Channel channel) {
generate(msg, channel, "consumer 6");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer7(Message msg, Channel channel) {
generate(msg, channel, "consumer 7");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer8(Message msg, Channel channel) {
generate(msg, channel, "consumer 8");
}
@RabbitListener(queues = MQConfig.GENERATE_QUEUE)
@RabbitHandler
public void generateConsumer9(Message msg, Channel channel) {
generate(msg, channel, "consumer 9");
}
}

View File

@@ -0,0 +1,24 @@
package com.ai.da.common.RabbitMQ;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.core.AmqpTemplate;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
@Slf4j
@Component
public class MQPublisher {
private final String url = "http://localhost:15672/api/queues/%2f/generate-queue";
@Resource
private AmqpTemplate amqpTemplate;
public void sendGenerateMessage(String mm) {
log.info("send message:" + mm);
amqpTemplate.convertAndSend(MQConfig.GENERATE_QUEUE, mm);
}
}

View File

@@ -0,0 +1,42 @@
package com.ai.da.common.config;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
@Configuration
public class RedisConfig {
@Bean(name = "redisTemplate")
public RedisTemplate<String, Object> getRedisTemplate(RedisConnectionFactory factory) {
RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
redisTemplate.setConnectionFactory(factory);
StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
redisTemplate.setKeySerializer(stringRedisSerializer); // key的序列化类型
Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
// 方法过期,改为下面代码
// objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance,
ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
redisTemplate.setValueSerializer(jackson2JsonRedisSerializer); // value的序列化类型
redisTemplate.setHashKeySerializer(stringRedisSerializer);
redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
redisTemplate.afterPropertiesSet();
return redisTemplate;
}
}

View File

@@ -0,0 +1,73 @@
package com.ai.da.common.utils;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.model.dto.GenerateToPythonDTO;
import com.ai.da.python.PythonService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.*;
@Slf4j
@Component
public class AsyncCallerUtil {
public static Map<String, Boolean> waitingStatus = new HashMap<>();
private static PythonService pythonService;
@Autowired
public void setPythonService(PythonService pythonService) {
AsyncCallerUtil.pythonService = pythonService;
}
public CompletableFuture<List<String>> callGenerateAsync(GenerateToPythonDTO generateToPython) {
return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython));
}
public List<String> generate(GenerateToPythonDTO generateToPython) {
ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(5);
String taskId = generateToPython.getTasks_id();
ScheduledFuture<?> timeoutTask = null;
if (!waitingStatus.containsKey(taskId)) waitingStatus.put(taskId, true);
try {
CompletableFuture<List<String>> generateResult = callGenerateAsync(generateToPython);
// 5秒后第一次确认之后每隔10秒确认一次用户选择结果
timeoutTask = scheduledExecutorService.scheduleAtFixedRate(() -> {
// 调用另一个接口获取用户的选择
if (!waitingStatus.get(taskId)) {
// 如果用户选择取消则取消对generate的调用
generateResult.cancel(true);
waitingStatus.remove(taskId);
} else log.info("===============持续等待===============");
}, 5, 10, TimeUnit.SECONDS);
log.info("阻塞等待结果...");
// 阻塞,等待结果
List<String> result = generateResult.get();
// 取消定时任务
timeoutTask.cancel(true);
waitingStatus.remove(taskId);
return result;
} catch (CancellationException e) {
// generateResult.cancel(true);通过抛出异常取消该任务
log.info("==========成功取消generate任务==========");
return null;
} catch (InterruptedException | ExecutionException | BusinessException e) {
// 处理异常
log.error("发生错误 " + e);
// 取消定时任务
assert timeoutTask != null;
timeoutTask.cancel(true);
throw new BusinessException(e.getMessage());
} finally {
// 关闭线程池
// executorService.shutdown();
// scheduledExecutorService.shutdown();
}
}
}

View File

@@ -0,0 +1,126 @@
package com.ai.da.common.utils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ZSetOperations;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import javax.annotation.Resource;
import java.util.Map;
import java.util.Set;
@Slf4j
@Component
public class RedisUtil {
@Resource
private RedisTemplate<String, String> redisTemplate;
//- - - - - - - - - - - - - - - - - - - - - ZSet类型 - - - - - - - - - - - - - - - - - - - -
/**
* 向ZSet中添加元素
*/
public void addToZSet(String key, String value, Double score) {
redisTemplate.opsForZSet().add(key, value, score);
}
/**
* 从ZSet中删除元素
*/
public void removeFromZSet(String key, String value) {
redisTemplate.opsForZSet().remove(key, value);
}
/**
* 获取指定元素的当前排列顺序
*/
public Long getRank(String key, String value) {
return redisTemplate.opsForZSet().rank(key, value);
}
/**
* 获取当前ZSet中的最大score
*/
public Double getMaxScore(String key) {
Set<ZSetOperations.TypedTuple<String>> set = redisTemplate.opsForZSet().reverseRangeWithScores(key, 0, 0);
if (!CollectionUtils.isEmpty(set)) {
Double score = set.iterator().next().getScore();
return score + 1.0;
} else {
return 1.0;
}
}
/**
* 判断元素是否存在
*/
public Boolean isElementExistsInZSet(String key, String value) {
return redisTemplate.opsForZSet().score(key, value) != null;
}
/**
* 获取当前ZSet中数据量的总和
*/
public Long getZSetTotal(String key) {
return redisTemplate.opsForZSet().zCard(key);
}
//- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - -
/**
* 将数据放入set缓存
*/
public void addToSet(String key, String value) {
redisTemplate.opsForSet().add(key, value);
}
/**
* 弹出变量中的元素
*/
public void removeFromSet(String key, String value) {
redisTemplate.opsForSet().remove(key, value);
}
/**
* 检查给定的元素是否在变量中。
*/
public Boolean isElementExistsInSet(String key, String obj) {
return redisTemplate.opsForSet().isMember(key, obj);
}
//- - - - - - - - - - - - - - - - - - - - - hash类型 - - - - - - - - - - - - - - - - - - - -
/**
* 加入缓存
*/
public void addToMap(String key, Map<String, String> map) {
redisTemplate.opsForHash().putAll(key, map);
}
/**
* 验证指定 key 下 有没有指定的 hashkey
*/
public Boolean isElementExistsInMap(String key, String hashKey) {
return redisTemplate.opsForHash().hasKey(key, hashKey);
}
/**
* 获取指定key的值string
*/
public String getMapValue(String key1, String key2) {
return String.valueOf(redisTemplate.opsForHash().get(key1, key2));
}
/**
* 删除指定 hash 的 HashKey
*
* @return 删除成功的 数量
*/
public Long removeFromMap(String key, String hashKeys) {
return redisTemplate.opsForHash().delete(key, hashKeys);
}
}

View File

@@ -0,0 +1,137 @@
package com.ai.da.common.utils;
public class SnowflakeUtil {
// ==============================Fields===========================================
/** 开始时间截 (2015-01-01) */
private final long twepoch = 1420041600000L;
/** 机器id所占的位数 */
private final long workerIdBits = 5L;
/** 数据标识id所占的位数 */
private final long datacenterIdBits = 5L;
/** 支持的最大机器id结果是31 (这个移位算法可以很快的计算出几位二进制数所能表示的最大十进制数) */
private final long maxWorkerId = -1L ^ (-1L << workerIdBits);
/** 支持的最大数据标识id结果是31 */
private final long maxDatacenterId = -1L ^ (-1L << datacenterIdBits);
/** 序列在id中占的位数 */
private final long sequenceBits = 12L;
/** 机器ID向左移12位 */
private final long workerIdShift = sequenceBits;
/** 数据标识id向左移17位(12+5) */
private final long datacenterIdShift = sequenceBits + workerIdBits;
/** 时间截向左移22位(5+5+12) */
private final long timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits;
/** 生成序列的掩码这里为4095 (0b111111111111=0xfff=4095) */
private final long sequenceMask = -1L ^ (-1L << sequenceBits);
/** 工作机器ID(0~31) */
private long workerId;
/** 数据中心ID(0~31) */
private long datacenterId;
/** 毫秒内序列(0~4095) */
private long sequence = 0L;
/** 上次生成ID的时间截 */
private long lastTimestamp = -1L;
//==============================Constructors=====================================
/**
* 构造函数
* @param workerId 工作ID (0~31)
* @param datacenterId 数据中心ID (0~31)
*/
public SnowflakeUtil(long workerId, long datacenterId) {
if (workerId > maxWorkerId || workerId < 0) {
throw new IllegalArgumentException(String.format
("worker Id can't be greater than %d or less than 0", maxWorkerId));
}
if (datacenterId > maxDatacenterId || datacenterId < 0) {
throw new IllegalArgumentException(String.format
("datacenter Id can't be greater than %d or less than 0", maxDatacenterId));
}
this.workerId = workerId;
this.datacenterId = datacenterId;
}
// ==============================Methods==========================================
/**
* 获得下一个ID (该方法是线程安全的)
* @return SnowflakeId
*/
public synchronized long nextId() {
long timestamp = timeGen();
//如果当前时间小于上一次ID生成的时间戳说明系统时钟回退过这个时候应当抛出异常
if (timestamp < lastTimestamp) {
throw new RuntimeException(
String.format
("Clock moved backwards. Refusing to generate id for %d milliseconds", lastTimestamp - timestamp));
}
//如果是同一时间生成的,则进行毫秒内序列
if (lastTimestamp == timestamp) {
sequence = (sequence + 1) & sequenceMask;
//毫秒内序列溢出
if (sequence == 0) {
//阻塞到下一个毫秒,获得新的时间戳
timestamp = tilNextMillis(lastTimestamp);
}
}
//时间戳改变,毫秒内序列重置
else {
sequence = 0L;
}
//上次生成ID的时间截
lastTimestamp = timestamp;
//移位并通过或运算拼到一起组成64位的ID
return ((timestamp - twepoch) << timestampLeftShift) //
| (datacenterId << datacenterIdShift) //
| (workerId << workerIdShift) //
| sequence;
}
/**
* 阻塞到下一个毫秒,直到获得新的时间戳
* @param lastTimestamp 上次生成ID的时间截
* @return 当前时间戳
*/
protected long tilNextMillis(long lastTimestamp) {
long timestamp = timeGen();
while (timestamp <= lastTimestamp) {
timestamp = timeGen();
}
return timestamp;
}
/**
* 返回以毫秒为单位的当前时间
* @return 当前时间(毫秒)
*/
protected long timeGen() {
return System.currentTimeMillis();
}
//==============================Test=============================================
/** 测试 */
public static void main(String[] args) {
SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
long id = idWorker.nextId();
System.out.println("id:"+id);
//id:768842202204864512
}
}

View File

@@ -53,4 +53,25 @@ public class GenerateController {
return Response.success(generateService.generateDislike(generateDetailId, timeZone));
}
@ApiOperation(value = "发起生成请求,异步获取结果")
@PostMapping("/prepare")
public Response<String> prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) {
return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO));
}
@ApiOperation(value = "取消继续生成")
@GetMapping("/stopWaiting")
public Response<String> stopWaiting(@RequestParam("userId") Long userId,
@RequestParam("uniqueId") String uniqueId,
@RequestParam("timeZone") String timeZone) {
generateService.cancelGenerate(userId, uniqueId, timeZone);
return Response.success("stop waiting successfully");
}
@ApiOperation(value = "获取生成结果")
@GetMapping("/result")
public Response<GenerateCollectionVO> getGenerateResult(@RequestParam("uniqueId") String uniqueId) {
GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId);
return Response.success(generateResult);
}
}

View File

@@ -0,0 +1,7 @@
package com.ai.da.mapper;
import com.ai.da.common.config.mybatis.plus.CommonMapper;
import com.ai.da.mapper.entity.GenerateCancel;
public interface GenerateCancelMapper extends CommonMapper<GenerateCancel> {
}

View File

@@ -27,6 +27,11 @@ public class Generate {
*/
private Long accountId;
/**
* 唯一id
*/
private String uniqueId;
/**
* Sketchboard Printboard
*/

View File

@@ -0,0 +1,44 @@
package com.ai.da.mapper.entity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.Accessors;
import java.util.Date;
@Data
@EqualsAndHashCode(callSuper = false)
@Accessors(chain = true)
@TableName("t_generate_cancel")
public class GenerateCancel {
/**
* ID
*/
@TableId(value = "id", type = IdType.AUTO)
private Long id;
/**
* 用户ID
*/
private Long accountId;
/**
* 唯一id(任务id)
*/
private String uniqueId;
/**
* 创建时间
*/
private Date createDate;
public GenerateCancel(Long accountId, String uniqueId, Date createDate) {
this.accountId = accountId;
this.uniqueId = uniqueId;
this.createDate = createDate;
}
}

View File

@@ -5,10 +5,14 @@ import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
@Data
@ApiModel("GenerateThroughImageTextDTO")
public class GenerateThroughImageTextDTO {
@NotNull(message = "userId cannot be empty")
@ApiModelProperty("用户id")
Long userId;
@ApiModelProperty("caption")
String text;
@@ -40,4 +44,7 @@ public class GenerateThroughImageTextDTO {
@NotBlank(message = "timeZone cannot be empty!")
@ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取")
String timeZone;
@ApiModelProperty("唯一id用于保持消息唯一性")
String uniqueId;
}

View File

@@ -0,0 +1,27 @@
package com.ai.da.model.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class GenerateToPythonDTO {
private Long user_id;
private String image_url;
private String category;
private String str;
private Integer mode;
private String version;
private String gender;
private String tasks_id;
}

View File

@@ -19,6 +19,13 @@ public class GenerateCollectionVO {
@ApiModelProperty("生成的图片信息")
private List<GenerateCollectionItemVO> generatedCollectionItems;
@ApiModelProperty("在当前队列中的排序")
private Long rankPosition;
public GenerateCollectionVO(Long rankPosition) {
this.rankPosition = rankPosition;
}
public GenerateCollectionVO(Long generateId, Long collectionId, List<GenerateCollectionItemVO> generatedCollectionItems) {
this.generateId = generateId;
this.collectionId = collectionId;

View File

@@ -23,7 +23,6 @@ import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.springframework.beans.factory.annotation.Value;
@@ -57,6 +56,7 @@ public class PythonService {
private String accessPythonIp;
@Value("${access.python.port:''}")
private String accessPythonPort;
@Resource
private PythonTAllInfoService pythonTAllInfoService;
@@ -107,7 +107,7 @@ public class PythonService {
if (Objects.nonNull(response.body())) {
responseBody = response.body().string();
JSONObject responseObj = JSON.parseObject(responseBody);
log.info("moodboard与printboard图片合成 python返回###{}",responseObj);
log.info("moodboard与printboard图片合成 python返回###{}", responseObj);
return responseObj.get("data").toString();
}
} catch (IOException | JSONException e) {
@@ -389,7 +389,7 @@ public class PythonService {
all.addAll(new ArrayList<>(DesignPythonItem.SKIRT_TROUSERS));
return all;
}
}else if (modelSex.equals(Sex.MALE.getValue())) {
} else if (modelSex.equals(Sex.MALE.getValue())) {
Long randomIndex = RandomsUtil.randomSysFile(0L, 3L);
if (randomIndex == 0) {
return DesignPythonItem.TOPS;
@@ -422,11 +422,11 @@ public class PythonService {
long noPinNum = printBoardElements.stream().filter(f -> f.getHasPin() == 0).count();
if (noPinNum == 0L) {
return 0;
}else {
} else {
long pinNum = printBoardElements.stream().filter(f -> f.getHasPin() == 1).count();
if (8 - pinNum < 4) {
return RandomsUtil.randomSysFile(0L, 8 - pinNum + 1);
}else {
} else {
return RandomsUtil.randomSysFile(0L, 4L + 1);
}
}
@@ -1600,7 +1600,7 @@ public class PythonService {
printBoardElements = elementVO.getPrintBoardElements()
.stream()
.filter(f -> !elementVO.getHasUseMd5List().contains(f.getMd5())).collect(Collectors.toList());
}else {
} else {
printBoardElements = elementVO.getPrintBoardElements();
}
if (CollectionUtil.isEmpty(printBoardElements)) {
@@ -2081,7 +2081,7 @@ public class PythonService {
skirt.setPrint(designPythonItemPrint);
skirt.setPath("aida-sys-image/images/female/trousers/trousers_974.jpg");
response.add(skirt);
}else {
} else {
DesignPythonItem top = new DesignPythonItem();
top.setType(MalePosition.TOPS.getValue());
top.setColor("none");
@@ -2195,7 +2195,9 @@ public class PythonService {
throw new BusinessException("design.interface.exception");
}
/** 暂时未用 */
/**
* 暂时未用
*/
public String generateSketchCaption(String url) {
//限流校验
AccessLimitUtils.validate("generateSketchCaption", 5);
@@ -2238,9 +2240,9 @@ public class PythonService {
throw new BusinessException("system error!");
}
public List<String> generateSketchOrPrint(Long userId, String url, String category, String text, int mode, String modelName, String gender) {
public List<String> generateSketchOrPrint(GenerateToPythonDTO generateToPythonDTO) {
//限流校验
AccessLimitUtils.validate("generateSketchOrPrint", 5);
// AccessLimitUtils.validate("generateSketchOrPrint", 5);
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
@@ -2248,46 +2250,45 @@ public class PythonService {
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
MediaType mediaType = MediaType.parse("application/json");
Map<String, Object> content = Maps.newHashMap();
content.put("user_id", userId);
content.put("image_url", url);
content.put("category", category);
content.put("mode", mode);
content.put("str", text);
content.put("version", "1");
content.put("gender", gender);
RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content, SerializerFeature.WriteMapNullValue));
RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue));
Request request = new Request.Builder()
// .url(accessPythonIp + ":2828/aida/diffusion")
// .url("http://18.167.251.121:9992")
.url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion")
// .url("http://127.0.0.1:5000/api/diffusion")
// .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion")
.url(accessPythonIp + ":" + accessPythonPort + "/api/generate_image")
.method("POST", body)
.addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
.addHeader("Content-Type", "application/json")
.build();
Response response = null;
String bodyString ;
String bodyString;
try {
log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(content, SerializerFeature.WriteMapNullValue));
log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue));
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("PythonService##generateSketchOrPrint异常###{}", ExceptionUtil.getThrowableList(ioException));
// throw new BusinessException("generate.interface.error");
throw new BusinessException(ioException.getMessage());
}
//去除限流
AccessLimitUtils.validateOut("generateSketchOrPrint");
// AccessLimitUtils.validateOut("generateSketchOrPrint");
// 判断是否生成失败
if (Objects.isNull(response) || Objects.isNull(response.body())) {
if (Objects.isNull(response.body())) {
log.error("PythonService##generateSketchOrPrint异常###{}", "response or body is empty!");
throw new BusinessException("generate.interface.error");
} else if (response.code() != HttpURLConnection.HTTP_OK){
// throw new BusinessException("generate.interface.error");
throw new BusinessException("PythonService##generateSketchOrPrint异常###: response or body is empty!");
} else if (response.code() != HttpURLConnection.HTTP_OK) {
log.error("PythonService##generateSketchOrPrint异常###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("generate.interface.error");
// throw new BusinessException("generate.interface.error");
throw new BusinessException("PythonService##generateSketchOrPrint异常### Response error!Response code ## " + response.code() + " ##");
} else {
try {
bodyString = response.body().string();
} catch (IOException e) {
throw new BusinessException("generate.interface.error");
log.error(e.getMessage());
// throw new BusinessException("generate.interface.error");
throw new BusinessException(e.getMessage());
}
}
JSONObject jsonObject = JSON.parseObject(bodyString);
@@ -2357,7 +2358,7 @@ public class PythonService {
if (Objects.isNull(response) || Objects.isNull(response.body())) {
log.error("PythonService##composeLayers异常###{}", "response or body is empty!");
throw new BusinessException("compose-layer.interface.exception");
}else if (response.code() != HttpURLConnection.HTTP_OK){
} else if (response.code() != HttpURLConnection.HTTP_OK) {
log.error("PythonService##composeLayers异常###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("compose-layer.interface.exception");
} else {
@@ -2382,10 +2383,10 @@ public class PythonService {
return item0.getString("synthesis_url");
}
public String getClothCategory(String path,String gender){
public String getClothCategory(String path, String gender) {
HashMap<String, String> content = new HashMap<>();
content.put("sketch_img_url",path);
content.put("colony",gender);
content.put("sketch_img_url", path);
content.put("colony", gender);
List<HashMap<String, String>> contents = Collections.singletonList(content);
String jsonString = JSON.toJSONString(contents, SerializerFeature.WriteNullStringAsEmpty);
@@ -2398,7 +2399,7 @@ public class PythonService {
if (Objects.isNull(response) || Objects.isNull(response.body())) {
log.error("PythonService##GetClothCategory###{}", "response or body is empty!");
throw new BusinessException("cloth-classification.interface.exception");
} else if (response.code() != HttpURLConnection.HTTP_OK){
} else if (response.code() != HttpURLConnection.HTTP_OK) {
log.error("PythonService##GetClothCategory###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("cloth-classification.interface.exception");
} else {
@@ -2409,7 +2410,7 @@ public class PythonService {
}
}
JSONObject jsonObject = JSON.parseObject(bodyString);
try{
try {
Boolean result = JSON.parseObject(JSON.toJSONString(response)).getBoolean("successful");
if (result && jsonObject.get("msg").equals("OK!")) {
JSONObject data = jsonObject.getJSONObject("data");
@@ -2417,7 +2418,7 @@ public class PythonService {
JSONObject map = (JSONObject) list.get(0);
return map.get("category").toString();
}
}catch (NullPointerException e){
} catch (NullPointerException e) {
log.info("getClothCategory 失败###{}未返回category", jsonObject);
throw new BusinessException("cloth-classification.interface.exception");
}
@@ -2425,4 +2426,35 @@ public class PythonService {
//生成失败
throw new BusinessException("cloth-classification.interface.exception");
}
public Boolean cancelGenerateTask(String taskId) {
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
String url = accessPythonIp + ":" + accessPythonPort + "/api/generate_cancel/" + taskId;
Request request = new Request.Builder()
.url(url)
// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
// .addHeader("Content-Type", "application/json")
.build();
Response response;
try {
log.info("cancelGenerateTask请求入参content###{}", taskId);
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("PythonService##cancelGenerateTask异常###{}", ExceptionUtil.getThrowableList(ioException));
return null;
}
if (response.code() != HttpURLConnection.HTTP_OK) {
log.info("generate-python 取消请求失败");
return Boolean.FALSE;
}
log.info("generate-python 取消请求成功");
return Boolean.TRUE;
}
}

View File

@@ -24,4 +24,13 @@ public interface GenerateService extends IService<Generate> {
void updateLikeStatusBatch(List<Long> generateDetailIdList, Byte hasLike, Long libraryId, String timeZone);
List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId);
GenerateCollectionVO getGenerateResult(String uniqueId);
String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO);
Long getRankPosition(String uniqueId);
void cancelGenerate(Long userId, String uniqueId, String timeZone);
}

View File

@@ -0,0 +1,11 @@
package com.ai.da.service;
import org.springframework.stereotype.Service;
@Service
public interface RabbitMQService {
void publishMessage(String message);
Integer getMessageCount(String queueUrl);
}

View File

@@ -810,9 +810,11 @@ public class CollectionElementServiceImpl extends ServiceImpl<CollectionElementM
CollectionElement collectionElement = null;
if (!Objects.isNull(elementId)) {
collectionElement = collectionElementMapper.selectById(elementId);
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(level2Type)) {
collectionElement.setLevel2Type(level2Type);
updateById(collectionElement);
if (!Objects.isNull(collectionElement)) {
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(level2Type)) {
collectionElement.setLevel2Type(level2Type);
updateById(collectionElement);
}
}
}
return collectionElement;

View File

@@ -4,26 +4,32 @@ import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.GenerateModeEnum;
import com.ai.da.common.enums.ModelNameEnum;
import com.ai.da.common.utils.DateUtil;
import com.ai.da.common.utils.MD5Utils;
import com.ai.da.common.utils.MinioUtil;
import com.ai.da.common.utils.*;
import com.ai.da.mapper.CollectionElementMapper;
import com.ai.da.mapper.GenerateCancelMapper;
import com.ai.da.mapper.GenerateDetailMapper;
import com.ai.da.mapper.GenerateMapper;
import com.ai.da.mapper.entity.*;
import com.ai.da.model.dto.GenerateLikeDTO;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.dto.GenerateToPythonDTO;
import com.ai.da.model.vo.*;
import com.ai.da.python.PythonService;
import com.ai.da.service.CollectionElementService;
import com.ai.da.service.GenerateService;
import com.ai.da.service.LibraryService;
import com.ai.da.service.RabbitMQService;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import io.minio.errors.MinioException;
import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import javax.annotation.Resource;
import java.io.IOException;
@@ -31,6 +37,7 @@ import java.util.*;
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*;
@Slf4j
@Service
public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> implements GenerateService {
@@ -52,10 +59,31 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Resource
private MinioUtil minioUtil;
@Resource
private RabbitMQService rabbitMQService;
@Resource
private RedisUtil redisUtil;
@Resource
private GenerateCancelMapper generateCancelMapper;
@Value("${redis.key.consumptionOrder}")
private String consumptionOrderKey;
@Value("${redis.key.cancelSet}")
private String cancelSetKey;
@Value("${redis.key.exceptionMap}")
private String exceptionMapKey;
@Value("${redis.key.resultMap}")
private String resultMapKey;
@Override
public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
if (Objects.isNull(collectionElement)){
if (Objects.isNull(collectionElement)) {
throw new BusinessException("the.image.does.not.exist.please.reselect");
}
String url = collectionElement.getUrl();
@@ -69,45 +97,46 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Transactional(rollbackFor = Exception.class)
public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、获取用户信息
AuthPrincipalVo userHolder = UserContext.getUserHolder();
Long accountId = generateThroughImageTextDTO.getUserId();
String generateType = generateThroughImageTextDTO.getGenerateType();
Long accountId = userHolder.getId();
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){
throw new BusinessException("unknown.generate.type");
}
// 2、判断必须入参是否为非空
Generate generate = new Generate();
generate.setAccountId(accountId);
generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type());
// 当level1type是sketchboard时存数据库需要加上当前性别
generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ?
generateType + " (" +generateThroughImageTextDTO.getGender() + ")":
generateType + " (" + generateThroughImageTextDTO.getGender() + ")" :
generateType);
generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion());
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(generate, text, elementId,generateType);
validateGeneraType(generate, text, elementId, generateType);
// 3、将请求信息落库
// 3.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
// 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type());
// 3.2 将本次generate的请求信息添加到t_generate表中
save(generate);
// 4、向模型发起请求
// 3、向模型发起请求
int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
GenerateModeEnum.TEXT.getCode() :
GenerateModeEnum.TEXT_IMAGE.getCode();
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
// text = !StringUtil.isNullOrEmpty(text) && generateThroughImageTextDTO.getVersion().equals("1") ? "painting style, " + text : text;
List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
category, text, mode, generateThroughImageTextDTO.getVersion(), generateThroughImageTextDTO.getGender());
AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
// List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
// category, text, mode, "1", generateThroughImageTextDTO.getGender()));
log.info("generate 响应 " + generatedSketchUrl);
if (CollectionUtils.isEmpty(generatedSketchUrl)) {
return null;
}
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
save(generate);
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
@@ -117,8 +146,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(item, 24 * 60), Boolean.FALSE);
// 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过
List<Map<String,Long>> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generateThroughImageTextDTO.getLevel1Type());
if (!libraryIdList.isEmpty()){
List<Map<String, Long>> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generateThroughImageTextDTO.getLevel1Type());
if (!libraryIdList.isEmpty()) {
generateDetail.setIsLike((byte) 1);
generateDetail.setLibraryId(libraryIdList.get(0).get("library_id"));
generateCollectionItemVO.setIsLiked(Boolean.TRUE);
@@ -139,22 +168,22 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems);
}
private void validateGeneraType(Generate generate, String text, Long elementId,String generateType) {
private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) {
switch (generateType) {
case "text":
if (StringUtil.isNullOrEmpty(text)){
if (StringUtil.isNullOrEmpty(text)) {
throw new BusinessException("please.input.the.caption");
}
generate.setText(text);
break;
case "image":
if (Objects.isNull(elementId)){
if (Objects.isNull(elementId)) {
throw new BusinessException("please.choose.an.image");
}
generate.setCollectionElementId(elementId);
break;
case "text-image":
if (StringUtil.isNullOrEmpty(text) || Objects.isNull(elementId)){
if (StringUtil.isNullOrEmpty(text) || Objects.isNull(elementId)) {
throw new BusinessException("please.input.the.caption.and.choose.an.image");
}
generate.setText(text);
@@ -169,21 +198,21 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 1、判断参数是否正确
// 1.1 必须参数是否非空
if (SKETCH_BOARD.getRealName().equals(generateLikeDTO.getLevel1Type())) {
if (StringUtil.isNullOrEmpty(generateLikeDTO.getLevel2Type())){
if (StringUtil.isNullOrEmpty(generateLikeDTO.getLevel2Type())) {
throw new BusinessException("level2Type.cannot.be.empty");
}
if (StringUtil.isNullOrEmpty(generateLikeDTO.getGender())){
if (StringUtil.isNullOrEmpty(generateLikeDTO.getGender())) {
throw new BusinessException("gender.cannot.be.empty");
}
}
// 1.2 判断参数是否真实有效
Long generateDetailId = generateLikeDTO.getGenerateDetailId();
GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId);
if (Objects.isNull(generateDetail)){
if (Objects.isNull(generateDetail)) {
throw new BusinessException("generateItem.does.not.exist");
}
Generate generate = getById(generateDetail.getGenerateId());
if (!generateLikeDTO.getLevel1Type().equals(generate.getLevel1Type())){
if (!generateLikeDTO.getLevel1Type().equals(generate.getLevel1Type())) {
throw new BusinessException("level1Type.does.not.match");
}
@@ -191,8 +220,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 2.1、不能重复喜欢
// 2.1.1 判断该图片是否被喜欢过
Library libraryDetail = libraryService.getById(generateDetail.getLibraryId());
if ( (Objects.nonNull(generateDetail.getLibraryId()) && !generateDetail.getLibraryId().equals(0L))
|| Objects.nonNull(libraryDetail)){
if ((Objects.nonNull(generateDetail.getLibraryId()) && !generateDetail.getLibraryId().equals(0L))
|| Objects.nonNull(libraryDetail)) {
throw new BusinessException("duplicate.likes.are.not.allowed");
}
@@ -215,7 +244,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
public Boolean generateDislike(Long generateDetailId, String timeZone) {
// 1、确定generateDetail中是否有这条记录
GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId);
if (Objects.isNull(generateDetail)){
if (Objects.isNull(generateDetail)) {
throw new BusinessException("generateItem.does.not.exist");
}
@@ -265,7 +294,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generateDetailMapper.update(generateDetail, queryWrapper);
}
public void updateLikeStatusBatch(List<Long> generateDetailIdList, Byte hasLike, Long libraryId, String timeZone){
public void updateLikeStatusBatch(List<Long> generateDetailIdList, Byte hasLike, Long libraryId, String timeZone) {
QueryWrapper<GenerateDetail> queryWrapper = new QueryWrapper<>();
queryWrapper.in("id", generateDetailIdList);
@@ -277,10 +306,156 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generateDetailMapper.update(generateDetail, queryWrapper);
}
public List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId){
public List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId) {
QueryWrapper<GenerateDetail> qw = new QueryWrapper<>();
qw.in("library_id",libraryId);
qw.in("library_id", libraryId);
return generateDetailMapper.selectList(qw);
}
@Override
public String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、参数检查判断必须参数是否为空
if (Objects.isNull(generateThroughImageTextDTO.getUserId())) {
throw new BusinessException("userId cannot be empty");
}
String generateType = generateThroughImageTextDTO.getGenerateType();
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)) {
throw new BusinessException("unknown.generate.type");
}
String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(new Generate(), text, elementId, generateType);
// 2、生成唯一id 使用uuid
String uuid = UUID.randomUUID().toString();
// SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
// long snowflakeId = idWorker.nextId();
int num = 1;
// 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id
while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) ||
redisUtil.isElementExistsInZSet(consumptionOrderKey, uuid))
&& num < 10) {
uuid = UUID.randomUUID().toString();
num++;
}
// 无依据确定的数字
if (num > 10) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
uuid = UUID.randomUUID().toString();
}
generateThroughImageTextDTO.setUniqueId(uuid);
String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
// 3、加入redis排队便于获取实时排队信息
Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
redisUtil.addToZSet(consumptionOrderKey, uuid, maxScore);
// 4、将消息发布到MQ消息队列
rabbitMQService.publishMessage(jsonString);
// 5、返回唯一id
return uuid;
}
@Override
public Long getRankPosition(String uniqueId) {
// rank 从0开始
return redisUtil.getRank(consumptionOrderKey, uniqueId);
}
@Override
public GenerateCollectionVO getGenerateResult(String uniqueId) {
// 1、判断该请求是否已经异常
Boolean isMember = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId);
if (isMember) {
throw new BusinessException("generate.interface.error");
}
// 2、判断该请求是否还在排队
Boolean existsInZSet = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId);
if (existsInZSet) {
// 排队中,给出当前排序位置,rank从0开始
Long rankPosition = getRankPosition(uniqueId);
// 有9个消费者所以当rank>8即当前请求至少排在第九位时其实际排队位置为9-8+1当rank <=8请求均在处理中
return new GenerateCollectionVO(rankPosition > 8L ? rankPosition - 8 + 1 : 1L);
}
// 3、判断redis中有没有
boolean hasHashKey = redisUtil.isElementExistsInMap(resultMapKey, uniqueId);
if (hasHashKey) {
// 3.1 有直接从redis中拿
String resultString = redisUtil.getMapValue(resultMapKey, uniqueId);
return JSONObject.parseObject(resultString, GenerateCollectionVO.class);
}
// 3.2 判断数据库中有没有
Generate generate = selectByUniqueId(uniqueId);
if (Objects.isNull(generate)) {
// 3.3 还没执行完,给出当前位置
return new GenerateCollectionVO(1L);
}
Long generateId = generate.getId();
QueryWrapper<GenerateDetail> qw = new QueryWrapper<>();
qw.eq("generate_id", generateId);
List<GenerateDetail> generateDetails = generateDetailMapper.selectList(qw);
if (CollectionUtils.isEmpty(generateDetails)) {
// 会有这种情况吗存到generate中但是还没存到generateDetail中
return new GenerateCollectionVO(1L);
}
List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
generateDetails.forEach(item -> {
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
generateCollectionItemVO.setGenerateItemId(item.getId());
generateCollectionItemVO.setGenerateItemUrl(minioUtil.getPresignedUrl(item.getUrl(), 24 * 60));
generatedCollectionItems.add(generateCollectionItemVO);
});
return new GenerateCollectionVO(generateId, null, generatedCollectionItems);
}
public Generate selectByUniqueId(String uniqueId) {
QueryWrapper<Generate> qw = new QueryWrapper<>();
qw.eq("unique_id", uniqueId);
return getOne(qw);
}
@Override
@Transactional(rollbackFor = Exception.class)
public void cancelGenerate(Long userId, String uniqueId, String timeZone) {
// 1、确认当前消息是否还在排队中
Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId);
Boolean flag = Boolean.FALSE;
if (exists) flag = redisUtil.getRank(consumptionOrderKey, uniqueId) > 1L ? Boolean.TRUE : Boolean.FALSE;
// 不管flag的默认值是true还是false只要exists为false&& 将短路
if (exists && flag) {
// 1.1、将需要取消的唯一id加入redis以便及时取消生成
redisUtil.addToSet(cancelSetKey, uniqueId);
// 1.2 将需要取消的id从redis的ConsumptionOrder中删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
} else {
// 2、判断该消息是否异常
boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId);
// 3、判断该消息是否已经消费结束
Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, uniqueId);
if (!hasKey && !existsInResult) {
// 设置取等待状态为false
AsyncCallerUtil.waitingStatus.put(uniqueId, false);
// 3、直接发送取消请求到python端
pythonService.cancelGenerateTask(uniqueId);
}
}
// 3、考虑加一张表专门用于记录哪些用户在什么时间进行了取消操作
GenerateCancel generateCancel = new GenerateCancel(userId, uniqueId, DateUtil.getByTimeZone(timeZone));
generateCancelMapper.insert(generateCancel);
}
}

View File

@@ -0,0 +1,76 @@
package com.ai.da.service.impl;
import cn.hutool.core.exceptions.ExceptionUtil;
import com.ai.da.common.RabbitMQ.MQPublisher;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.service.RabbitMQService;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
@Slf4j
@Service
public class RabbitMQServiceImpl implements RabbitMQService {
@Resource
private MQPublisher mqPublisher;
@Override
public void publishMessage(String message) {
mqPublisher.sendGenerateMessage(message);
}
@Override
public Integer getMessageCount(String queueUrl) {
String url = "http://localhost:15672/api/queues/%2f/generate-queue";
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
Request request = new Request.Builder()
.url(queueUrl)
.method("GET",null)
.addHeader("Authorization", "Basic Z3Vlc3Q6Z3Vlc3Q=")
.build();
Response response = null;
try {
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("RabbitMQService##" + "getMessage异常###{}", ExceptionUtil.getThrowableList(ioException));
}
String bodyString;
// 生成失败
if (Objects.isNull(response) || Objects.isNull(response.body())) {
log.error("RabbitMQService##getMessageCount异常###{}", "response or body is empty!");
throw new BusinessException("compose-layer.interface.exception");
}else if (response.code() != HttpURLConnection.HTTP_OK){
log.error("RabbitMQService##getMessageCount异常###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("compose-layer.interface.exception");
} else {
try {
bodyString = response.body().string();
} catch (IOException e) {
throw new BusinessException("compose-layer.interface.exception");
}
}
JSONObject jsonObject = JSON.parseObject(bodyString);
String messageCount = jsonObject.get("messages").toString();
return Integer.parseInt(messageCount);
}
}