Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
shahaibo
2024-01-26 13:18:37 +08:00
24 changed files with 1167 additions and 82 deletions

View File

@@ -810,9 +810,11 @@ public class CollectionElementServiceImpl extends ServiceImpl<CollectionElementM
CollectionElement collectionElement = null;
if (!Objects.isNull(elementId)) {
collectionElement = collectionElementMapper.selectById(elementId);
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(level2Type)) {
collectionElement.setLevel2Type(level2Type);
updateById(collectionElement);
if (!Objects.isNull(collectionElement)) {
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(level2Type)) {
collectionElement.setLevel2Type(level2Type);
updateById(collectionElement);
}
}
}
return collectionElement;

View File

@@ -4,26 +4,32 @@ import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.GenerateModeEnum;
import com.ai.da.common.enums.ModelNameEnum;
import com.ai.da.common.utils.DateUtil;
import com.ai.da.common.utils.MD5Utils;
import com.ai.da.common.utils.MinioUtil;
import com.ai.da.common.utils.*;
import com.ai.da.mapper.CollectionElementMapper;
import com.ai.da.mapper.GenerateCancelMapper;
import com.ai.da.mapper.GenerateDetailMapper;
import com.ai.da.mapper.GenerateMapper;
import com.ai.da.mapper.entity.*;
import com.ai.da.model.dto.GenerateLikeDTO;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.dto.GenerateToPythonDTO;
import com.ai.da.model.vo.*;
import com.ai.da.python.PythonService;
import com.ai.da.service.CollectionElementService;
import com.ai.da.service.GenerateService;
import com.ai.da.service.LibraryService;
import com.ai.da.service.RabbitMQService;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import io.minio.errors.MinioException;
import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import javax.annotation.Resource;
import java.io.IOException;
@@ -31,6 +37,7 @@ import java.util.*;
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*;
@Slf4j
@Service
public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> implements GenerateService {
@@ -52,10 +59,31 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Resource
private MinioUtil minioUtil;
@Resource
private RabbitMQService rabbitMQService;
@Resource
private RedisUtil redisUtil;
@Resource
private GenerateCancelMapper generateCancelMapper;
@Value("${redis.key.consumptionOrder}")
private String consumptionOrderKey;
@Value("${redis.key.cancelSet}")
private String cancelSetKey;
@Value("${redis.key.exceptionMap}")
private String exceptionMapKey;
@Value("${redis.key.resultMap}")
private String resultMapKey;
@Override
public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
if (Objects.isNull(collectionElement)){
if (Objects.isNull(collectionElement)) {
throw new BusinessException("the.image.does.not.exist.please.reselect");
}
String url = collectionElement.getUrl();
@@ -69,45 +97,46 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Transactional(rollbackFor = Exception.class)
public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、获取用户信息
AuthPrincipalVo userHolder = UserContext.getUserHolder();
Long accountId = generateThroughImageTextDTO.getUserId();
String generateType = generateThroughImageTextDTO.getGenerateType();
Long accountId = userHolder.getId();
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){
throw new BusinessException("unknown.generate.type");
}
// 2、判断必须入参是否为非空
Generate generate = new Generate();
generate.setAccountId(accountId);
generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type());
// 当level1type是sketchboard时存数据库需要加上当前性别
generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ?
generateType + " (" +generateThroughImageTextDTO.getGender() + ")":
generateType + " (" + generateThroughImageTextDTO.getGender() + ")" :
generateType);
generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion());
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(generate, text, elementId,generateType);
validateGeneraType(generate, text, elementId, generateType);
// 3、将请求信息落库
// 3.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
// 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type());
// 3.2 将本次generate的请求信息添加到t_generate表中
save(generate);
// 4、向模型发起请求
// 3、向模型发起请求
int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
GenerateModeEnum.TEXT.getCode() :
GenerateModeEnum.TEXT_IMAGE.getCode();
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
// text = !StringUtil.isNullOrEmpty(text) && generateThroughImageTextDTO.getVersion().equals("1") ? "painting style, " + text : text;
List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
category, text, mode, generateThroughImageTextDTO.getVersion(), generateThroughImageTextDTO.getGender());
AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
// List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
// category, text, mode, "1", generateThroughImageTextDTO.getGender()));
log.info("generate 响应 " + generatedSketchUrl);
if (CollectionUtils.isEmpty(generatedSketchUrl)) {
return null;
}
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
save(generate);
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
@@ -117,8 +146,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(item, 24 * 60), Boolean.FALSE);
// 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过
List<Map<String,Long>> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generateThroughImageTextDTO.getLevel1Type());
if (!libraryIdList.isEmpty()){
List<Map<String, Long>> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generateThroughImageTextDTO.getLevel1Type());
if (!libraryIdList.isEmpty()) {
generateDetail.setIsLike((byte) 1);
generateDetail.setLibraryId(libraryIdList.get(0).get("library_id"));
generateCollectionItemVO.setIsLiked(Boolean.TRUE);
@@ -139,22 +168,22 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems);
}
private void validateGeneraType(Generate generate, String text, Long elementId,String generateType) {
private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) {
switch (generateType) {
case "text":
if (StringUtil.isNullOrEmpty(text)){
if (StringUtil.isNullOrEmpty(text)) {
throw new BusinessException("please.input.the.caption");
}
generate.setText(text);
break;
case "image":
if (Objects.isNull(elementId)){
if (Objects.isNull(elementId)) {
throw new BusinessException("please.choose.an.image");
}
generate.setCollectionElementId(elementId);
break;
case "text-image":
if (StringUtil.isNullOrEmpty(text) || Objects.isNull(elementId)){
if (StringUtil.isNullOrEmpty(text) || Objects.isNull(elementId)) {
throw new BusinessException("please.input.the.caption.and.choose.an.image");
}
generate.setText(text);
@@ -169,21 +198,21 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 1、判断参数是否正确
// 1.1 必须参数是否非空
if (SKETCH_BOARD.getRealName().equals(generateLikeDTO.getLevel1Type())) {
if (StringUtil.isNullOrEmpty(generateLikeDTO.getLevel2Type())){
if (StringUtil.isNullOrEmpty(generateLikeDTO.getLevel2Type())) {
throw new BusinessException("level2Type.cannot.be.empty");
}
if (StringUtil.isNullOrEmpty(generateLikeDTO.getGender())){
if (StringUtil.isNullOrEmpty(generateLikeDTO.getGender())) {
throw new BusinessException("gender.cannot.be.empty");
}
}
// 1.2 判断参数是否真实有效
Long generateDetailId = generateLikeDTO.getGenerateDetailId();
GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId);
if (Objects.isNull(generateDetail)){
if (Objects.isNull(generateDetail)) {
throw new BusinessException("generateItem.does.not.exist");
}
Generate generate = getById(generateDetail.getGenerateId());
if (!generateLikeDTO.getLevel1Type().equals(generate.getLevel1Type())){
if (!generateLikeDTO.getLevel1Type().equals(generate.getLevel1Type())) {
throw new BusinessException("level1Type.does.not.match");
}
@@ -191,8 +220,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 2.1、不能重复喜欢
// 2.1.1 判断该图片是否被喜欢过
Library libraryDetail = libraryService.getById(generateDetail.getLibraryId());
if ( (Objects.nonNull(generateDetail.getLibraryId()) && !generateDetail.getLibraryId().equals(0L))
|| Objects.nonNull(libraryDetail)){
if ((Objects.nonNull(generateDetail.getLibraryId()) && !generateDetail.getLibraryId().equals(0L))
|| Objects.nonNull(libraryDetail)) {
throw new BusinessException("duplicate.likes.are.not.allowed");
}
@@ -215,7 +244,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
public Boolean generateDislike(Long generateDetailId, String timeZone) {
// 1、确定generateDetail中是否有这条记录
GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId);
if (Objects.isNull(generateDetail)){
if (Objects.isNull(generateDetail)) {
throw new BusinessException("generateItem.does.not.exist");
}
@@ -265,7 +294,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generateDetailMapper.update(generateDetail, queryWrapper);
}
public void updateLikeStatusBatch(List<Long> generateDetailIdList, Byte hasLike, Long libraryId, String timeZone){
public void updateLikeStatusBatch(List<Long> generateDetailIdList, Byte hasLike, Long libraryId, String timeZone) {
QueryWrapper<GenerateDetail> queryWrapper = new QueryWrapper<>();
queryWrapper.in("id", generateDetailIdList);
@@ -277,10 +306,156 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generateDetailMapper.update(generateDetail, queryWrapper);
}
public List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId){
public List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId) {
QueryWrapper<GenerateDetail> qw = new QueryWrapper<>();
qw.in("library_id",libraryId);
qw.in("library_id", libraryId);
return generateDetailMapper.selectList(qw);
}
@Override
public String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、参数检查判断必须参数是否为空
if (Objects.isNull(generateThroughImageTextDTO.getUserId())) {
throw new BusinessException("userId cannot be empty");
}
String generateType = generateThroughImageTextDTO.getGenerateType();
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)) {
throw new BusinessException("unknown.generate.type");
}
String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(new Generate(), text, elementId, generateType);
// 2、生成唯一id 使用uuid
String uuid = UUID.randomUUID().toString();
// SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
// long snowflakeId = idWorker.nextId();
int num = 1;
// 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id
while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) ||
redisUtil.isElementExistsInZSet(consumptionOrderKey, uuid))
&& num < 10) {
uuid = UUID.randomUUID().toString();
num++;
}
// 无依据确定的数字
if (num > 10) {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
uuid = UUID.randomUUID().toString();
}
generateThroughImageTextDTO.setUniqueId(uuid);
String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
// 3、加入redis排队便于获取实时排队信息
Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
redisUtil.addToZSet(consumptionOrderKey, uuid, maxScore);
// 4、将消息发布到MQ消息队列
rabbitMQService.publishMessage(jsonString);
// 5、返回唯一id
return uuid;
}
@Override
public Long getRankPosition(String uniqueId) {
// rank 从0开始
return redisUtil.getRank(consumptionOrderKey, uniqueId);
}
@Override
public GenerateCollectionVO getGenerateResult(String uniqueId) {
// 1、判断该请求是否已经异常
Boolean isMember = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId);
if (isMember) {
throw new BusinessException("generate.interface.error");
}
// 2、判断该请求是否还在排队
Boolean existsInZSet = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId);
if (existsInZSet) {
// 排队中,给出当前排序位置,rank从0开始
Long rankPosition = getRankPosition(uniqueId);
// 有9个消费者所以当rank>8即当前请求至少排在第九位时其实际排队位置为9-8+1当rank <=8请求均在处理中
return new GenerateCollectionVO(rankPosition > 8L ? rankPosition - 8 + 1 : 1L);
}
// 3、判断redis中有没有
boolean hasHashKey = redisUtil.isElementExistsInMap(resultMapKey, uniqueId);
if (hasHashKey) {
// 3.1 有直接从redis中拿
String resultString = redisUtil.getMapValue(resultMapKey, uniqueId);
return JSONObject.parseObject(resultString, GenerateCollectionVO.class);
}
// 3.2 判断数据库中有没有
Generate generate = selectByUniqueId(uniqueId);
if (Objects.isNull(generate)) {
// 3.3 还没执行完,给出当前位置
return new GenerateCollectionVO(1L);
}
Long generateId = generate.getId();
QueryWrapper<GenerateDetail> qw = new QueryWrapper<>();
qw.eq("generate_id", generateId);
List<GenerateDetail> generateDetails = generateDetailMapper.selectList(qw);
if (CollectionUtils.isEmpty(generateDetails)) {
// 会有这种情况吗存到generate中但是还没存到generateDetail中
return new GenerateCollectionVO(1L);
}
List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
generateDetails.forEach(item -> {
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
generateCollectionItemVO.setGenerateItemId(item.getId());
generateCollectionItemVO.setGenerateItemUrl(minioUtil.getPresignedUrl(item.getUrl(), 24 * 60));
generatedCollectionItems.add(generateCollectionItemVO);
});
return new GenerateCollectionVO(generateId, null, generatedCollectionItems);
}
public Generate selectByUniqueId(String uniqueId) {
QueryWrapper<Generate> qw = new QueryWrapper<>();
qw.eq("unique_id", uniqueId);
return getOne(qw);
}
@Override
@Transactional(rollbackFor = Exception.class)
public void cancelGenerate(Long userId, String uniqueId, String timeZone) {
// 1、确认当前消息是否还在排队中
Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId);
Boolean flag = Boolean.FALSE;
if (exists) flag = redisUtil.getRank(consumptionOrderKey, uniqueId) > 1L ? Boolean.TRUE : Boolean.FALSE;
// 不管flag的默认值是true还是false只要exists为false&& 将短路
if (exists && flag) {
// 1.1、将需要取消的唯一id加入redis以便及时取消生成
redisUtil.addToSet(cancelSetKey, uniqueId);
// 1.2 将需要取消的id从redis的ConsumptionOrder中删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
} else {
// 2、判断该消息是否异常
boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId);
// 3、判断该消息是否已经消费结束
Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, uniqueId);
if (!hasKey && !existsInResult) {
// 设置取等待状态为false
AsyncCallerUtil.waitingStatus.put(uniqueId, false);
// 3、直接发送取消请求到python端
pythonService.cancelGenerateTask(uniqueId);
}
}
// 3、考虑加一张表专门用于记录哪些用户在什么时间进行了取消操作
GenerateCancel generateCancel = new GenerateCancel(userId, uniqueId, DateUtil.getByTimeZone(timeZone));
generateCancelMapper.insert(generateCancel);
}
}

View File

@@ -0,0 +1,76 @@
package com.ai.da.service.impl;
import cn.hutool.core.exceptions.ExceptionUtil;
import com.ai.da.common.RabbitMQ.MQPublisher;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.service.RabbitMQService;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
@Slf4j
@Service
public class RabbitMQServiceImpl implements RabbitMQService {
@Resource
private MQPublisher mqPublisher;
@Override
public void publishMessage(String message) {
mqPublisher.sendGenerateMessage(message);
}
@Override
public Integer getMessageCount(String queueUrl) {
String url = "http://localhost:15672/api/queues/%2f/generate-queue";
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
Request request = new Request.Builder()
.url(queueUrl)
.method("GET",null)
.addHeader("Authorization", "Basic Z3Vlc3Q6Z3Vlc3Q=")
.build();
Response response = null;
try {
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("RabbitMQService##" + "getMessage异常###{}", ExceptionUtil.getThrowableList(ioException));
}
String bodyString;
// 生成失败
if (Objects.isNull(response) || Objects.isNull(response.body())) {
log.error("RabbitMQService##getMessageCount异常###{}", "response or body is empty!");
throw new BusinessException("compose-layer.interface.exception");
}else if (response.code() != HttpURLConnection.HTTP_OK){
log.error("RabbitMQService##getMessageCount异常###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("compose-layer.interface.exception");
} else {
try {
bodyString = response.body().string();
} catch (IOException e) {
throw new BusinessException("compose-layer.interface.exception");
}
}
JSONObject jsonObject = JSON.parseObject(bodyString);
String messageCount = jsonObject.get("messages").toString();
return Integer.parseInt(messageCount);
}
}