1、generate 异步生成及获取排队情况
2、generate 取消生成
This commit is contained in:
@@ -810,9 +810,11 @@ public class CollectionElementServiceImpl extends ServiceImpl<CollectionElementM
|
||||
CollectionElement collectionElement = null;
|
||||
if (!Objects.isNull(elementId)) {
|
||||
collectionElement = collectionElementMapper.selectById(elementId);
|
||||
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(level2Type)) {
|
||||
collectionElement.setLevel2Type(level2Type);
|
||||
updateById(collectionElement);
|
||||
if (!Objects.isNull(collectionElement)) {
|
||||
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(level2Type)) {
|
||||
collectionElement.setLevel2Type(level2Type);
|
||||
updateById(collectionElement);
|
||||
}
|
||||
}
|
||||
}
|
||||
return collectionElement;
|
||||
|
||||
@@ -4,26 +4,30 @@ import com.ai.da.common.config.exception.BusinessException;
|
||||
import com.ai.da.common.context.UserContext;
|
||||
import com.ai.da.common.enums.GenerateModeEnum;
|
||||
import com.ai.da.common.enums.ModelNameEnum;
|
||||
import com.ai.da.common.utils.DateUtil;
|
||||
import com.ai.da.common.utils.MD5Utils;
|
||||
import com.ai.da.common.utils.MinioUtil;
|
||||
import com.ai.da.common.utils.*;
|
||||
import com.ai.da.mapper.CollectionElementMapper;
|
||||
import com.ai.da.mapper.GenerateDetailMapper;
|
||||
import com.ai.da.mapper.GenerateMapper;
|
||||
import com.ai.da.mapper.entity.*;
|
||||
import com.ai.da.model.dto.GenerateLikeDTO;
|
||||
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
||||
import com.ai.da.model.dto.GenerateToPythonDTO;
|
||||
import com.ai.da.model.vo.*;
|
||||
import com.ai.da.python.PythonService;
|
||||
import com.ai.da.service.CollectionElementService;
|
||||
import com.ai.da.service.GenerateService;
|
||||
import com.ai.da.service.LibraryService;
|
||||
import com.ai.da.service.RabbitMQService;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import io.minio.errors.MinioException;
|
||||
import io.netty.util.internal.StringUtil;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.IOException;
|
||||
@@ -52,6 +56,24 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
@Resource
|
||||
private MinioUtil minioUtil;
|
||||
|
||||
@Resource
|
||||
private RabbitMQService rabbitMQService;
|
||||
|
||||
@Resource
|
||||
private RedisUtil redisUtil;
|
||||
|
||||
@Value("${redis.key.consumptionOrder}")
|
||||
private String consumptionOrderKey;
|
||||
|
||||
@Value("${redis.key.cancelSet}")
|
||||
private String cancelSetKey;
|
||||
|
||||
@Value("${redis.key.exceptionMap}")
|
||||
private String exceptionMapKey;
|
||||
|
||||
@Value("${redis.key.resultMap}")
|
||||
private String resultMapKey;
|
||||
|
||||
@Override
|
||||
public GenerateCaptionVO generateCaption(Long sketchElementId) {
|
||||
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
|
||||
@@ -69,16 +91,13 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||
// 1、获取用户信息
|
||||
AuthPrincipalVo userHolder = UserContext.getUserHolder();
|
||||
Long accountId = generateThroughImageTextDTO.getUserId();
|
||||
String generateType = generateThroughImageTextDTO.getGenerateType();
|
||||
Long accountId = userHolder.getId();
|
||||
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){
|
||||
throw new BusinessException("unknown.generate.type");
|
||||
}
|
||||
|
||||
// 2、判断必须入参是否为非空
|
||||
Generate generate = new Generate();
|
||||
generate.setAccountId(accountId);
|
||||
generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
|
||||
generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type());
|
||||
// 当level1type是sketchboard时,存数据库需要加上当前性别
|
||||
generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ?
|
||||
@@ -87,27 +106,30 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion());
|
||||
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
|
||||
|
||||
|
||||
String text = generateThroughImageTextDTO.getText();
|
||||
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
|
||||
validateGeneraType(generate, text, elementId,generateType);
|
||||
|
||||
// 3、将请求信息落库
|
||||
// 3.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
|
||||
// 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
|
||||
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type());
|
||||
|
||||
// 3.2 将本次generate的请求信息添加到t_generate表中
|
||||
save(generate);
|
||||
|
||||
// 4、向模型发起请求
|
||||
// 3、向模型发起请求
|
||||
int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
|
||||
GenerateModeEnum.TEXT.getCode() :
|
||||
GenerateModeEnum.TEXT_IMAGE.getCode();
|
||||
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
|
||||
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
|
||||
// text = !StringUtil.isNullOrEmpty(text) && generateThroughImageTextDTO.getVersion().equals("1") ? "painting style, " + text : text;
|
||||
List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
|
||||
category, text, mode, generateThroughImageTextDTO.getVersion(), generateThroughImageTextDTO.getGender());
|
||||
AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
|
||||
List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? null : collectionElement.getUrl(),
|
||||
category, text, mode, "1", generateThroughImageTextDTO.getGender()),0L);
|
||||
// List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
|
||||
// category, text, mode, "1", generateThroughImageTextDTO.getGender()));
|
||||
if (CollectionUtils.isEmpty(generatedSketchUrl)){
|
||||
return null;
|
||||
}
|
||||
|
||||
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
|
||||
save(generate);
|
||||
|
||||
// 5、处理模型返回的数据
|
||||
// 5.1 将相应的url保存到数据库
|
||||
@@ -283,4 +305,122 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
|
||||
return generateDetailMapper.selectList(qw);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||
// 1、参数检查,判断必须参数是否为空
|
||||
if (Objects.isNull(generateThroughImageTextDTO.getUserId())){
|
||||
throw new BusinessException("userId cannot be empty");
|
||||
}
|
||||
String generateType = generateThroughImageTextDTO.getGenerateType();
|
||||
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){
|
||||
throw new BusinessException("unknown.generate.type");
|
||||
}
|
||||
String text = generateThroughImageTextDTO.getText();
|
||||
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
|
||||
validateGeneraType(new Generate(), text, elementId,generateType);
|
||||
|
||||
// 2、生成唯一id
|
||||
SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
|
||||
long snowflakeId = idWorker.nextId();
|
||||
|
||||
if (AsyncCallerUtil.waitingStatus.containsKey(snowflakeId)){
|
||||
snowflakeId = idWorker.nextId();
|
||||
}
|
||||
generateThroughImageTextDTO.setUniqueId(snowflakeId);
|
||||
String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
|
||||
|
||||
// 3、加入redis排队,便于获取实时排队信息
|
||||
Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
|
||||
redisUtil.addToZSet(consumptionOrderKey, String.valueOf(snowflakeId),maxScore);
|
||||
|
||||
// 4、将消息发布到MQ消息队列
|
||||
rabbitMQService.publishMessage(jsonString);
|
||||
|
||||
// 5、返回唯一id
|
||||
return snowflakeId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getRankPosition(Long uniqueId) {
|
||||
return redisUtil.getRank(consumptionOrderKey, String.valueOf(uniqueId));
|
||||
}
|
||||
|
||||
@Override
|
||||
public GenerateCollectionVO getGenerateResult(Long uniqueId) {
|
||||
// 1、判断该请求是否已经异常
|
||||
Boolean isMember = redisUtil.isElementExistsInMap(exceptionMapKey, String.valueOf(uniqueId));
|
||||
if (isMember){
|
||||
throw new BusinessException("generate.interface.error");
|
||||
}
|
||||
|
||||
// 2、判断该请求是否还在排队
|
||||
Boolean existsInZSet = redisUtil.isElementExistsInZSet(consumptionOrderKey, String.valueOf(uniqueId));
|
||||
if (existsInZSet){
|
||||
// 排队中,给出当前排序位置
|
||||
return new GenerateCollectionVO(getRankPosition(uniqueId) + 1L);
|
||||
}
|
||||
|
||||
// 3、判断redis中有没有
|
||||
boolean hasHashKey = redisUtil.isElementExistsInMap(resultMapKey, String.valueOf(uniqueId));
|
||||
if (hasHashKey){
|
||||
// 3.1 有直接从redis中拿
|
||||
String resultString = redisUtil.getMapValue(resultMapKey, String.valueOf(uniqueId));
|
||||
return JSONObject.parseObject(resultString,GenerateCollectionVO.class);
|
||||
}
|
||||
|
||||
// 3.2 判断数据库中有没有
|
||||
Generate generate = selectByUniqueId(uniqueId);
|
||||
if (Objects.isNull(generate)){
|
||||
// 3.3 还没执行完,给出当前位置
|
||||
return new GenerateCollectionVO(0L);
|
||||
}
|
||||
Long generateId = generate.getId();
|
||||
QueryWrapper<GenerateDetail> qw = new QueryWrapper<>();
|
||||
qw.eq("generate_id",generateId);
|
||||
List<GenerateDetail> generateDetails = generateDetailMapper.selectList(qw);
|
||||
if (CollectionUtils.isEmpty(generateDetails)){
|
||||
// 会有这种情况吗?存到generate中,但是还没存到generateDetail中
|
||||
return new GenerateCollectionVO(0L);
|
||||
}
|
||||
|
||||
List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
|
||||
generateDetails.forEach(item -> {
|
||||
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
|
||||
generateCollectionItemVO.setGenerateItemId(item.getId());
|
||||
generateCollectionItemVO.setGenerateItemUrl(minioUtil.getPresignedUrl(item.getUrl(), 24 * 60));
|
||||
generatedCollectionItems.add(generateCollectionItemVO);
|
||||
});
|
||||
|
||||
return new GenerateCollectionVO(generateId, null, generatedCollectionItems);
|
||||
}
|
||||
|
||||
public Generate selectByUniqueId(Long uniqueId){
|
||||
QueryWrapper<Generate> qw = new QueryWrapper<>();
|
||||
qw.eq("unique_id",uniqueId);
|
||||
|
||||
return getOne(qw);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cancelGenerate(Long uniqueId) {
|
||||
// 1、确认当前消息是否还在排队中
|
||||
Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, String.valueOf(uniqueId));
|
||||
if (exists){
|
||||
// 1.1、将需要取消的唯一id加入redis,以便及时取消生成
|
||||
redisUtil.addToSet(cancelSetKey, String.valueOf(uniqueId));
|
||||
// 1.2 将需要取消的id从redis的ConsumptionOrder中删除
|
||||
redisUtil.removeFromZSet(consumptionOrderKey, String.valueOf(uniqueId));
|
||||
}else {
|
||||
// 2、判断该消息是否异常
|
||||
boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, String.valueOf(uniqueId));
|
||||
// 3、判断该消息是否已经消费结束
|
||||
Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, String.valueOf(uniqueId));
|
||||
if (!hasKey && !existsInResult){
|
||||
// 设置取等待状态为false
|
||||
AsyncCallerUtil.waitingStatus.put(uniqueId,false);
|
||||
// 3、直接发送取消请求到python端
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.ai.da.service.impl;
|
||||
|
||||
import cn.hutool.core.exceptions.ExceptionUtil;
|
||||
import com.ai.da.common.RabbitMQ.MQPublisher;
|
||||
import com.ai.da.common.config.exception.BusinessException;
|
||||
import com.ai.da.service.RabbitMQService;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.MediaType;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.Response;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.IOException;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class RabbitMQServiceImpl implements RabbitMQService {
|
||||
|
||||
@Resource
|
||||
private MQPublisher mqPublisher;
|
||||
|
||||
@Override
|
||||
public void publishMessage(String message) {
|
||||
mqPublisher.sendGenerateMessage(message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getMessageCount(String queueUrl) {
|
||||
|
||||
String url = "http://localhost:15672/api/queues/%2f/generate-queue";
|
||||
|
||||
OkHttpClient client = new OkHttpClient().newBuilder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
|
||||
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
|
||||
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
|
||||
.build();
|
||||
Request request = new Request.Builder()
|
||||
.url(queueUrl)
|
||||
.method("GET",null)
|
||||
.addHeader("Authorization", "Basic Z3Vlc3Q6Z3Vlc3Q=")
|
||||
.build();
|
||||
Response response = null;
|
||||
try {
|
||||
response = client.newCall(request).execute();
|
||||
} catch (IOException ioException) {
|
||||
log.error("RabbitMQService##" + "getMessage异常###{}", ExceptionUtil.getThrowableList(ioException));
|
||||
}
|
||||
|
||||
String bodyString;
|
||||
// 生成失败
|
||||
if (Objects.isNull(response) || Objects.isNull(response.body())) {
|
||||
log.error("RabbitMQService##getMessageCount异常###{}", "response or body is empty!");
|
||||
throw new BusinessException("compose-layer.interface.exception");
|
||||
}else if (response.code() != HttpURLConnection.HTTP_OK){
|
||||
log.error("RabbitMQService##getMessageCount异常###{}", "Response error!Response code ## " + response.code() + " ##");
|
||||
throw new BusinessException("compose-layer.interface.exception");
|
||||
} else {
|
||||
try {
|
||||
bodyString = response.body().string();
|
||||
} catch (IOException e) {
|
||||
throw new BusinessException("compose-layer.interface.exception");
|
||||
}
|
||||
}
|
||||
JSONObject jsonObject = JSON.parseObject(bodyString);
|
||||
String messageCount = jsonObject.get("messages").toString();
|
||||
return Integer.parseInt(messageCount);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user