generate模型更换后的接口更改及异步获取结果

This commit is contained in:
2024-04-18 14:07:20 +08:00
parent 8d330e8ad9
commit 896120fea4
13 changed files with 222 additions and 85 deletions

View File

@@ -4,10 +4,7 @@ import com.ai.da.mapper.primary.entity.Generate;
import com.ai.da.mapper.primary.entity.GenerateDetail;
import com.ai.da.model.dto.GenerateLikeDTO;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.GenerateCaptionVO;
import com.ai.da.model.vo.GenerateCollectionVO;
import com.ai.da.model.vo.GenerateLikeVO;
import com.ai.da.model.vo.PrepareForGenerateVO;
import com.ai.da.model.vo.*;
import com.baomidou.mybatisplus.extension.service.IService;
import java.util.List;
@@ -16,7 +13,9 @@ public interface GenerateService extends IService<Generate> {
GenerateCaptionVO generateCaption(Long sketchElementId);
GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO);
void generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO);
void processGenerateResult(String taskId, String url);
GenerateLikeVO generateLike(GenerateLikeDTO generateLikeDTO);
@@ -28,7 +27,9 @@ public interface GenerateService extends IService<Generate> {
GenerateCollectionVO getGenerateResult(String uniqueId);
PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO);
List<GenerateResultVO> getGenerateResultList(List<String> taskIdList);
List<String> prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO);
Long getRankPosition(String uniqueId);

View File

@@ -37,6 +37,7 @@ import javax.annotation.Resource;
import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.time.LocalDateTime;
import java.util.*;
import java.util.stream.Collectors;
@@ -834,6 +835,8 @@ public class CollectionElementServiceImpl extends ServiceImpl<CollectionElementM
} else {
throw new BusinessException("element source type cannot be empty!");
}
}else {
return null;
}
return collectionElement;
}
@@ -867,7 +870,7 @@ public class CollectionElementServiceImpl extends ServiceImpl<CollectionElementM
generateDetail.setLibraryId(libraryIds.get(0).get("library_id"));
}
generateDetail.setMd5(md5);
generateDetail.setCreateDate(DateUtil.getByTimeZone(timeZone));
generateDetail.setCreateDate(LocalDateTime.now());
return generateDetail;
}

View File

@@ -1,6 +1,7 @@
package com.ai.da.service.impl;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.constant.CommonConstant;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.GenerateModeEnum;
import com.ai.da.common.enums.ModelNameEnum;
@@ -23,6 +24,7 @@ import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.google.gson.Gson;
import io.minio.errors.MinioException;
import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
@@ -33,6 +35,7 @@ import org.springframework.util.CollectionUtils;
import javax.annotation.Resource;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.*;
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*;
@@ -80,6 +83,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Value("${redis.key.resultMap}")
private String resultMapKey;
@Value("${redis.key.generateResult}")
private String generateResultKey;
@Override
public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
@@ -95,12 +101,12 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Override
@Transactional(rollbackFor = Exception.class)
public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
public void generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、获取用户信息
Long accountId = generateThroughImageTextDTO.getUserId();
String generateType = generateThroughImageTextDTO.getGenerateType();
// 2、判断必须入参是否为非空
// 2、判断必须入参是否为非空(在prepare阶段已校验)
Generate generate = new Generate();
generate.setAccountId(accountId);
generate.setUniqueId(generateThroughImageTextDTO.getUniqueId());
@@ -121,27 +127,38 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType());
// 3、向模型发起请求
int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
GenerateModeEnum.TEXT.getCode() :
GenerateModeEnum.TEXT_IMAGE.getCode();
String mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ?
GenerateModeEnum.TEXT.getType() :
GenerateModeEnum.TEXT_IMAGE.getType();
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
// List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
// category, text, mode, "1", generateThroughImageTextDTO.getGender()));
log.info("generate 响应 " + generatedSketchUrl);
if (CollectionUtils.isEmpty(generatedSketchUrl)) {
return null;
}
// AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
// List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
// category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId()));
Boolean requestResult = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text,Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
mode, category));
// log.info("generate 响应 " + generatedSketchUrl);
// if (CollectionUtils.isEmpty(generatedSketchUrl)) {
// return null;
// }
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
save(generate);
// 5、将本次请求存入redis
String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId();
String status;
if (requestResult){
status = "Executing";
}else {
status = "Fail";
}
GenerateResultVO generateResultVO = new GenerateResultVO(null, null, status);
redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
/*List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
generatedSketchUrl.forEach(item -> {
GenerateDetail generateDetail = new GenerateDetail();
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
@@ -166,7 +183,35 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 6、将模型返回的图片地址返回给前端
Long collectionId = Objects.isNull(collectionElement) ? null : collectionElement.getCollectionId();
return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems);
return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems);*/
}
@Override
@Transactional(rollbackFor = Exception.class)
public void processGenerateResult(String taskId, String url){
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
GenerateDetail generateDetail = new GenerateDetail();
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
Generate generate = selectByUniqueId(taskId);
String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(url, 24 * 60), Boolean.FALSE);
// 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过
List<Map<String, Long>> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generate.getLevel1Type());
if (!libraryIdList.isEmpty()) {
generateDetail.setIsLike((byte) 1);
generateDetail.setLibraryId(libraryIdList.get(0).get("library_id"));
generateCollectionItemVO.setIsLiked(Boolean.TRUE);
}
generateDetail.setUrl(url);
generateDetail.setGenerateId(generate.getId());
generateDetail.setCreateDate(LocalDateTime.now());
generateDetail.setMd5(md5);
generateDetailMapper.insert(generateDetail);
String key = generateResultKey + ":" + taskId;
Long expire = redisUtil.getExpire(key);
GenerateResultVO generateResultVO = new GenerateResultVO(generateDetail.getId(), url, "Success");
redisUtil.addToString(key, new Gson().toJson(generateResultVO), expire);
}
private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) {
@@ -315,7 +360,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
}
@Override
public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
public List<String> prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、参数检查判断必须参数是否为空
if (Objects.isNull(generateThroughImageTextDTO.getUserId())) {
throw new BusinessException("userId cannot be empty");
@@ -330,7 +376,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
if (generateThroughImageTextDTO.getIsTestUser()){
trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type());
if (trialsCount >= 2){
return new PrepareForGenerateVO(0);
return new ArrayList<>();
}
}
@@ -341,9 +387,6 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 2、生成唯一id 使用uuid
String uuid = UUID.randomUUID().toString();
// SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
// long snowflakeId = idWorker.nextId();
int num = 1;
// 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id
while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) ||
@@ -361,18 +404,25 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
}
uuid = UUID.randomUUID().toString();
}
generateThroughImageTextDTO.setUniqueId(uuid);
String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
// 3、加入redis排队便于获取实时排队信息
Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
redisUtil.addToZSet(consumptionOrderKey, uuid, maxScore);
ArrayList<String> taskIdList = new ArrayList<>();
for (int i = 1 ; i <= 4 ; i++){
String temp = uuid;
temp += "-" + i + "-" + generateThroughImageTextDTO.getUserId();
taskIdList.add(temp);
generateThroughImageTextDTO.setUniqueId(temp);
String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
// 4、将消息发布到MQ消息队列
rabbitMQService.publishMessageToGenerate(jsonString);
// 3、加入redis排队便于获取实时排队信息
Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
redisUtil.addToZSet(consumptionOrderKey, temp, maxScore);
// 4、将消息发布到MQ消息队列
rabbitMQService.publishMessageToGenerate(jsonString);
}
// 5、返回唯一id
return new PrepareForGenerateVO(uuid, 2 - trialsCount);
return taskIdList;
}
@Override
@@ -432,6 +482,21 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return new GenerateCollectionVO(generateId, null, generatedCollectionItems);
}
@Override
public List<GenerateResultVO> getGenerateResultList(List<String> taskIdList) {
List<GenerateResultVO> results = new ArrayList<>();
taskIdList.forEach(taskId -> {
String key = generateResultKey + ":" + taskId;
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
if (!StringUtil.isNullOrEmpty(generateResultVO.getUrl())) {
generateResultVO.setUrl(minioUtil.getPresignedUrl(generateResultVO.getUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
}
results.add(generateResultVO);
});
return results;
}
public Generate selectByUniqueId(String uniqueId) {
QueryWrapper<Generate> qw = new QueryWrapper<>();
qw.eq("unique_id", uniqueId);