From 896120fea4149abebcf11d46c2843e477f752191 Mon Sep 17 00:00:00 2001 From: xupei Date: Thu, 18 Apr 2024 14:07:20 +0800 Subject: [PATCH] =?UTF-8?q?generate=E6=A8=A1=E5=9E=8B=E6=9B=B4=E6=8D=A2?= =?UTF-8?q?=E5=90=8E=E7=9A=84=E6=8E=A5=E5=8F=A3=E6=9B=B4=E6=94=B9=E5=8F=8A?= =?UTF-8?q?=E5=BC=82=E6=AD=A5=E8=8E=B7=E5=8F=96=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../da/common/RabbitMQ/GenerateConsumer.java | 53 ++++++-- .../com/ai/da/common/RabbitMQ/MQConfig.java | 15 ++- .../ai/da/common/constant/CommonConstant.java | 3 + .../ai/da/common/enums/GenerateModeEnum.java | 13 +- .../ai/da/common/utils/AsyncCallerUtil.java | 3 +- .../ai/da/controller/GenerateController.java | 24 ++-- .../mapper/primary/entity/GenerateDetail.java | 3 +- .../ai/da/model/dto/GenerateToPythonDTO.java | 31 +++-- .../java/com/ai/da/python/PythonService.java | 16 ++- .../com/ai/da/service/GenerateService.java | 13 +- .../impl/CollectionElementServiceImpl.java | 5 +- .../da/service/impl/GenerateServiceImpl.java | 123 +++++++++++++----- src/main/resources/application-dev.properties | 5 +- 13 files changed, 222 insertions(+), 85 deletions(-) diff --git a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java index 58f9435a..76bf15ec 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java @@ -3,9 +3,10 @@ package com.ai.da.common.RabbitMQ; import com.ai.da.common.config.exception.BusinessException; import com.ai.da.common.utils.RedisUtil; import com.ai.da.model.dto.GenerateThroughImageTextDTO; -import com.ai.da.model.vo.GenerateCollectionVO; +import com.ai.da.model.vo.GenerateResultVO; import com.ai.da.service.GenerateService; import com.alibaba.fastjson.JSONObject; +import com.google.gson.Gson; import com.rabbitmq.client.Channel; import lombok.extern.slf4j.Slf4j; import org.springframework.amqp.core.Message; @@ -17,7 +18,7 @@ import org.springframework.stereotype.Component; import javax.annotation.Resource; import java.io.IOException; import java.util.HashMap; -import java.util.Objects; +import java.util.Map; @Slf4j @@ -42,6 +43,9 @@ public class GenerateConsumer { @Value("${redis.key.resultMap}") private String resultMapKey; + @Value("${redis.key.generateResult}") + private String generateResultKey; + public void generate(Message msg, Channel channel, String consumerName) { log.info("============start listening=========="); long start = System.currentTimeMillis(); @@ -63,20 +67,16 @@ public class GenerateConsumer { // 2.2 将该消息从取消列表中删除 // redisUtil.removeFromSet(cancelSetKey, uniqueId); } else { - /*try { - Thread.sleep(15000); - } catch (InterruptedException e) { - throw new RuntimeException(e); - }*/ - GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO); +// GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO); + generateService.generateThroughImageText(generateThroughImageTextDTO); // 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除 redisUtil.removeFromZSet(consumptionOrderKey, uniqueId); - if (!Objects.isNull(generateCollectionVO)) { + /*if (!Objects.isNull(generateCollectionVO)) { HashMap generateResult = new HashMap<>(); generateResult.put(uniqueId, JSONObject.toJSONString(generateCollectionVO)); // 将结果存在redis中 ,为空时不要存 redisUtil.addToMap(resultMapKey, generateResult); - } + }*/ } } catch (BusinessException e) { @@ -104,6 +104,34 @@ public class GenerateConsumer { log.info("=============end listening==========="); } + public void processGenerateResult(Message msg, Channel channel){ + log.info("============ProcessGenerateResult listening=========="); + long start = System.currentTimeMillis(); + + Map generateResult = JSONObject.parseObject(msg.getBody(), Map.class); +// log.info("tasks_id : {}, message : {}",generateResult.get("tasks_id"), generateResult.get("message") ); + if (generateResult.get("status").equals("SUCCESS")){ + String url = generateResult.get("data"); + String taskId = generateResult.get("tasks_id"); + generateService.processGenerateResult(taskId, url); + }else { + // 修改redis中的数据状态为exception + String key = generateResultKey + ":" + generateResult.get("tasks_id"); + Long expire = redisUtil.getExpire(key); + redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(null, null, "Fail")), expire); + // 将异常信息存到exception中 + HashMap exceptionInfo = new HashMap<>(); + exceptionInfo.put(generateResult.get("tasks_id"), generateResult.get("data")); + // 存redis + redisUtil.addToMap(exceptionMapKey, exceptionInfo); + } + + long end = System.currentTimeMillis(); + log.info("tasks_id : {}, message : {}, 执行时长: {} 毫秒",generateResult.get("tasks_id"), generateResult.get("message"), (end - start)); + log.info("============ProcessGenerateResult End listening=========="); + + } + @RabbitListener(queues = MQConfig.GENERATE_QUEUE) @RabbitHandler public void generateConsumer1(Message msg, Channel channel) { @@ -158,4 +186,9 @@ public class GenerateConsumer { generate(msg, channel, "consumer 9"); } + @RabbitListener(queues = MQConfig.GENERATE_RESULT_QUEUE) + @RabbitHandler + public void getGenerateResult(Message msg, Channel channel){ + processGenerateResult(msg, channel); + } } diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java index 78e6fc25..e9df0264 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java @@ -10,14 +10,17 @@ public class MQConfig { public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange"; // public static final String GENERATE_QUEUE = "generate-queue-prod"; // public static final String GENERATE_QUEUE = "generate-queue-test"; -// public static final String GENERATE_QUEUE = "generate-queue-dev"; - public static final String GENERATE_QUEUE = "generate-queue-local"; +// public static final String GENERATE_QUEUE = "generate-queue-local"; + public static final String GENERATE_QUEUE = "generate-queue-dev"; -// public static final String SR_QUEUE = "SR-queue-dev"; - public static final String SR_QUEUE = "SR-queue-local"; +// public static final String SR_QUEUE = "SR-queue-local"; + public static final String SR_QUEUE = "SR-queue-dev"; - public static final String SR_RESULT_QUEUE = "SuperResolution-local"; -// public static final String SR_RESULT_QUEUE = "SuperResolution-dev"; +// public static final String SR_RESULT_QUEUE = "SuperResolution-local"; + public static final String SR_RESULT_QUEUE = "SuperResolution-dev"; + +// public static final String GENERATE_RESULT_QUEUE = "GenerateImage-local"; + public static final String GENERATE_RESULT_QUEUE = "GenerateImage-dev"; public MQConfig() { } diff --git a/src/main/java/com/ai/da/common/constant/CommonConstant.java b/src/main/java/com/ai/da/common/constant/CommonConstant.java index f7ad3b1c..63076fd4 100644 --- a/src/main/java/com/ai/da/common/constant/CommonConstant.java +++ b/src/main/java/com/ai/da/common/constant/CommonConstant.java @@ -7,4 +7,7 @@ public class CommonConstant { public static final Long CREDITS_EXPIRE_TIME = 2 * 24 * 60 * 60L; // 单位 分钟 public static final Integer MINIO_IMAGE_EXPIRE_TIME = 24 * 60; + // 单位 秒 一天过期 in redis + public static final Long GENERATE_RESULT_EXPIRE_TIME = 24 * 60 * 60L; + } diff --git a/src/main/java/com/ai/da/common/enums/GenerateModeEnum.java b/src/main/java/com/ai/da/common/enums/GenerateModeEnum.java index 44409731..fb4f8e94 100644 --- a/src/main/java/com/ai/da/common/enums/GenerateModeEnum.java +++ b/src/main/java/com/ai/da/common/enums/GenerateModeEnum.java @@ -12,26 +12,33 @@ public enum GenerateModeEnum { /** * 通过文本生成 */ - TEXT(1, "text"), + TEXT(1, "text","txt2img"), /** * 通过图片生成 */ - IMAGE(2, "image"), + IMAGE(2, "image", "img2img"), /** * 通过文本和图片生成 */ - TEXT_IMAGE(2, "text-image"); + TEXT_IMAGE(2, "text-image","txt2img"); private Integer code; private String value; + private String type; GenerateModeEnum(int code, String value) { this.code = code; this.value = value; } + GenerateModeEnum(Integer code, String value, String type) { + this.code = code; + this.value = value; + this.type = type; + } + public static List getGenerateModeList(){ return Stream.of(TEXT,IMAGE,TEXT_IMAGE).map(GenerateModeEnum::getValue).collect(Collectors.toList()); } diff --git a/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java b/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java index 29801a6b..40ae25db 100644 --- a/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java +++ b/src/main/java/com/ai/da/common/utils/AsyncCallerUtil.java @@ -25,7 +25,8 @@ public class AsyncCallerUtil { } public CompletableFuture> callGenerateAsync(GenerateToPythonDTO generateToPython) { - return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython)); +// return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython)); + return null; } public List generate(GenerateToPythonDTO generateToPython) { diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java index c752fbc0..895e5195 100644 --- a/src/main/java/com/ai/da/controller/GenerateController.java +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -3,10 +3,7 @@ package com.ai.da.controller; import com.ai.da.common.response.Response; import com.ai.da.model.dto.GenerateLikeDTO; import com.ai.da.model.dto.GenerateThroughImageTextDTO; -import com.ai.da.model.vo.GenerateCaptionVO; -import com.ai.da.model.vo.GenerateCollectionVO; -import com.ai.da.model.vo.GenerateLikeVO; -import com.ai.da.model.vo.PrepareForGenerateVO; +import com.ai.da.model.vo.*; import com.ai.da.service.GenerateService; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; @@ -16,6 +13,7 @@ import org.springframework.web.bind.annotation.*; import javax.annotation.Resource; import javax.validation.Valid; +import java.util.List; /** * @author XP @@ -37,8 +35,9 @@ public class GenerateController { @ApiOperation("通过文字、图片生成图片") @PostMapping("/sketchAndPrint") - public Response generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { - return Response.success(generateService.generateThroughImageText(generateThroughImageTextDTO)); + public void generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { +// return Response.success(generateService.generateThroughImageText(generateThroughImageTextDTO)); + generateService.generateThroughImageText(generateThroughImageTextDTO); } @ApiOperation("喜欢生成的图片") @@ -56,7 +55,7 @@ public class GenerateController { @ApiOperation(value = "发起生成请求,异步获取结果") @PostMapping("/prepare") - public Response prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public Response> prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO)); } @@ -69,10 +68,19 @@ public class GenerateController { return Response.success("stop waiting successfully"); } - @ApiOperation(value = "获取生成结果") + /*@ApiOperation(value = "获取生成结果") @GetMapping("/result") public Response getGenerateResult(@RequestParam("uniqueId") String uniqueId) { GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId); return Response.success(generateResult); + }*/ + + @ApiOperation(value = "获取生成结果") + @PostMapping("/result") + public Response> getGenerateResults(@Valid @RequestBody List taskIdList) { + List generateResult = generateService.getGenerateResultList(taskIdList); + return Response.success(generateResult); } + + } diff --git a/src/main/java/com/ai/da/mapper/primary/entity/GenerateDetail.java b/src/main/java/com/ai/da/mapper/primary/entity/GenerateDetail.java index b016e7fa..6cf9f0b3 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/GenerateDetail.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/GenerateDetail.java @@ -8,6 +8,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.experimental.Accessors; +import java.time.LocalDateTime; import java.util.Date; @Data @@ -50,7 +51,7 @@ public class GenerateDetail { /** * 创建时间 */ - private Date createDate; + private LocalDateTime createDate; /** * 更新时间 diff --git a/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java b/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java index 927fff70..ea62e425 100644 --- a/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java +++ b/src/main/java/com/ai/da/model/dto/GenerateToPythonDTO.java @@ -6,22 +6,31 @@ import lombok.NoArgsConstructor; @Data @NoArgsConstructor -@AllArgsConstructor +//@AllArgsConstructor public class GenerateToPythonDTO { - - private Long user_id; + // 去掉 +// private Long user_id; private String image_url; private String category; + // 改为prompt +// private String content; + private String prompt; - private String content; - - private Integer mode; - - private String version; - - private String gender; - + private String mode; + // 去除 +// private String version; + // 去掉 +// private String gender; + // taskId的最后拼接用户id private String tasks_id; + + public GenerateToPythonDTO(String tasks_id, String prompt, String image_url, String mode, String category) { + this.image_url = image_url; + this.category = category; + this.prompt = prompt; + this.mode = mode; + this.tasks_id = tasks_id; + } } diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index 9ad94860..7f7c4e5f 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -2880,7 +2880,7 @@ public class PythonService { throw new BusinessException("system error!"); } - public List generateSketchOrPrint(GenerateToPythonDTO generateToPythonDTO) { + public Boolean generateSketchOrPrint(GenerateToPythonDTO generateToPythonDTO) { //限流校验 // AccessLimitUtils.validate("generateSketchOrPrint", 5); OkHttpClient client = new OkHttpClient().newBuilder() @@ -2895,7 +2895,8 @@ public class PythonService { // .url("http://18.167.251.121:9992") // .url("http://127.0.0.1:5000/api/diffusion") // .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion") - .url(accessPythonIp + ":" + accessPythonPort + "/api/generate_image") +// .url(accessPythonIp + ":" + accessPythonPort + "/api/generate_image") + .url(srPythonPort + "/api/generate_image") .method("POST", body) // .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==") .addHeader("Content-Type", "application/json") @@ -2936,12 +2937,13 @@ public class PythonService { if (result && jsonObject.get("code").equals(200)) { log.info("Generate##responseObject###{}", jsonObject); - return setGenerateImageList(jsonObject.getJSONObject("data")); +// return setGenerateImageList(jsonObject.getJSONObject("data")); + return Boolean.TRUE; + }else { + log.info("generateSketchOrPrintPrint失败###{}", jsonObject); + log.info("Generate Exception! Code : " + jsonObject.get("code")); + return Boolean.FALSE; } - log.info("generateSketchOrPrintPrint失败###{}", jsonObject); - log.info("Generate Exception! Code : " + jsonObject.get("code")); - //生成失败 - throw new BusinessException("generate.interface.error"); } public Response sendPostToModel(String content, String portAndRoute, String functionName) { diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index 8e3a841c..3b6c60b0 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -4,10 +4,7 @@ import com.ai.da.mapper.primary.entity.Generate; import com.ai.da.mapper.primary.entity.GenerateDetail; import com.ai.da.model.dto.GenerateLikeDTO; import com.ai.da.model.dto.GenerateThroughImageTextDTO; -import com.ai.da.model.vo.GenerateCaptionVO; -import com.ai.da.model.vo.GenerateCollectionVO; -import com.ai.da.model.vo.GenerateLikeVO; -import com.ai.da.model.vo.PrepareForGenerateVO; +import com.ai.da.model.vo.*; import com.baomidou.mybatisplus.extension.service.IService; import java.util.List; @@ -16,7 +13,9 @@ public interface GenerateService extends IService { GenerateCaptionVO generateCaption(Long sketchElementId); - GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO); + void generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO); + + void processGenerateResult(String taskId, String url); GenerateLikeVO generateLike(GenerateLikeDTO generateLikeDTO); @@ -28,7 +27,9 @@ public interface GenerateService extends IService { GenerateCollectionVO getGenerateResult(String uniqueId); - PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO); + List getGenerateResultList(List taskIdList); + + List prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO); Long getRankPosition(String uniqueId); diff --git a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java index 1a9a2357..5c9d526d 100644 --- a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java @@ -37,6 +37,7 @@ import javax.annotation.Resource; import java.io.File; import java.io.IOException; import java.math.BigDecimal; +import java.time.LocalDateTime; import java.util.*; import java.util.stream.Collectors; @@ -834,6 +835,8 @@ public class CollectionElementServiceImpl extends ServiceImpl i @Value("${redis.key.resultMap}") private String resultMapKey; + @Value("${redis.key.generateResult}") + private String generateResultKey; + @Override public GenerateCaptionVO generateCaption(Long sketchElementId) { CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId); @@ -95,12 +101,12 @@ public class GenerateServiceImpl extends ServiceImpl i @Override @Transactional(rollbackFor = Exception.class) - public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public void generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、获取用户信息 Long accountId = generateThroughImageTextDTO.getUserId(); String generateType = generateThroughImageTextDTO.getGenerateType(); - // 2、判断必须入参是否为非空 + // 2、判断必须入参是否为非空(在prepare阶段已校验) Generate generate = new Generate(); generate.setAccountId(accountId); generate.setUniqueId(generateThroughImageTextDTO.getUniqueId()); @@ -121,27 +127,38 @@ public class GenerateServiceImpl extends ServiceImpl i CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType()); // 3、向模型发起请求 - int mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ? - GenerateModeEnum.TEXT.getCode() : - GenerateModeEnum.TEXT_IMAGE.getCode(); + String mode = GenerateModeEnum.TEXT.getValue().equals(generateType) ? + GenerateModeEnum.TEXT.getType() : + GenerateModeEnum.TEXT_IMAGE.getType(); String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" : generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard"; - AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil(); - List generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), - category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId())); -// List generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(), -// category, text, mode, "1", generateThroughImageTextDTO.getGender())); - log.info("generate 响应 : " + generatedSketchUrl); - if (CollectionUtils.isEmpty(generatedSketchUrl)) { - return null; - } +// AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil(); +// List generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), +// category, text, mode, "1", generateThroughImageTextDTO.getGender(), generateThroughImageTextDTO.getUniqueId())); + Boolean requestResult = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text,Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), + mode, category)); +// log.info("generate 响应 : " + generatedSketchUrl); +// if (CollectionUtils.isEmpty(generatedSketchUrl)) { +// return null; +// } // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中 save(generate); + // 5、将本次请求存入redis + String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId(); + String status; + if (requestResult){ + status = "Executing"; + }else { + status = "Fail"; + } + GenerateResultVO generateResultVO = new GenerateResultVO(null, null, status); + redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + // 5、处理模型返回的数据 // 5.1 将相应的url保存到数据库 - List generatedCollectionItems = new ArrayList<>(); + /*List generatedCollectionItems = new ArrayList<>(); generatedSketchUrl.forEach(item -> { GenerateDetail generateDetail = new GenerateDetail(); GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); @@ -166,7 +183,35 @@ public class GenerateServiceImpl extends ServiceImpl i // 6、将模型返回的图片地址返回给前端 Long collectionId = Objects.isNull(collectionElement) ? null : collectionElement.getCollectionId(); - return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems); + return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems);*/ + } + + @Override + @Transactional(rollbackFor = Exception.class) + public void processGenerateResult(String taskId, String url){ + // 5、处理模型返回的数据 + // 5.1 将相应的url保存到数据库 + GenerateDetail generateDetail = new GenerateDetail(); + GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); + Generate generate = selectByUniqueId(taskId); + String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(url, 24 * 60), Boolean.FALSE); + // 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过 + List> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generate.getLevel1Type()); + if (!libraryIdList.isEmpty()) { + generateDetail.setIsLike((byte) 1); + generateDetail.setLibraryId(libraryIdList.get(0).get("library_id")); + generateCollectionItemVO.setIsLiked(Boolean.TRUE); + } + generateDetail.setUrl(url); + generateDetail.setGenerateId(generate.getId()); + generateDetail.setCreateDate(LocalDateTime.now()); + generateDetail.setMd5(md5); + generateDetailMapper.insert(generateDetail); + + String key = generateResultKey + ":" + taskId; + Long expire = redisUtil.getExpire(key); + GenerateResultVO generateResultVO = new GenerateResultVO(generateDetail.getId(), url, "Success"); + redisUtil.addToString(key, new Gson().toJson(generateResultVO), expire); } private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) { @@ -315,7 +360,8 @@ public class GenerateServiceImpl extends ServiceImpl i } @Override - public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { +// public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public List prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、参数检查,判断必须参数是否为空 if (Objects.isNull(generateThroughImageTextDTO.getUserId())) { throw new BusinessException("userId cannot be empty"); @@ -330,7 +376,7 @@ public class GenerateServiceImpl extends ServiceImpl i if (generateThroughImageTextDTO.getIsTestUser()){ trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type()); if (trialsCount >= 2){ - return new PrepareForGenerateVO(0); + return new ArrayList<>(); } } @@ -341,9 +387,6 @@ public class GenerateServiceImpl extends ServiceImpl i // 2、生成唯一id 使用uuid String uuid = UUID.randomUUID().toString(); -// SnowflakeUtil idWorker = new SnowflakeUtil(0, 0); -// long snowflakeId = idWorker.nextId(); - int num = 1; // 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) || @@ -361,18 +404,25 @@ public class GenerateServiceImpl extends ServiceImpl i } uuid = UUID.randomUUID().toString(); } - generateThroughImageTextDTO.setUniqueId(uuid); - String jsonString = JSON.toJSONString(generateThroughImageTextDTO); - // 3、加入redis排队,便于获取实时排队信息 - Double maxScore = redisUtil.getMaxScore(consumptionOrderKey); - redisUtil.addToZSet(consumptionOrderKey, uuid, maxScore); + ArrayList taskIdList = new ArrayList<>(); + for (int i = 1 ; i <= 4 ; i++){ + String temp = uuid; + temp += "-" + i + "-" + generateThroughImageTextDTO.getUserId(); + taskIdList.add(temp); + generateThroughImageTextDTO.setUniqueId(temp); + String jsonString = JSON.toJSONString(generateThroughImageTextDTO); - // 4、将消息发布到MQ消息队列 - rabbitMQService.publishMessageToGenerate(jsonString); + // 3、加入redis排队,便于获取实时排队信息 + Double maxScore = redisUtil.getMaxScore(consumptionOrderKey); + redisUtil.addToZSet(consumptionOrderKey, temp, maxScore); + + // 4、将消息发布到MQ消息队列 + rabbitMQService.publishMessageToGenerate(jsonString); + } // 5、返回唯一id - return new PrepareForGenerateVO(uuid, 2 - trialsCount); + return taskIdList; } @Override @@ -432,6 +482,21 @@ public class GenerateServiceImpl extends ServiceImpl i return new GenerateCollectionVO(generateId, null, generatedCollectionItems); } + @Override + public List getGenerateResultList(List taskIdList) { + List results = new ArrayList<>(); + taskIdList.forEach(taskId -> { + String key = generateResultKey + ":" + taskId; + GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class); + if (!StringUtil.isNullOrEmpty(generateResultVO.getUrl())) { + generateResultVO.setUrl(minioUtil.getPresignedUrl(generateResultVO.getUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME)); + } + results.add(generateResultVO); + }); + return results; + } + + public Generate selectByUniqueId(String uniqueId) { QueryWrapper qw = new QueryWrapper<>(); qw.eq("unique_id", uniqueId); diff --git a/src/main/resources/application-dev.properties b/src/main/resources/application-dev.properties index c5e8ddb4..63a1a5be 100644 --- a/src/main/resources/application-dev.properties +++ b/src/main/resources/application-dev.properties @@ -77,10 +77,11 @@ spring.redis.lettuce.pool.max-wait=5 redis.key.orderForGenerate=OrderForGenerate redis.key.generateCancelSet=GenerateCancelSet -redis.key.generateExceptionMap=GenerateExceptionMap +redis.key.generateExceptionMap=Generate:Exception redis.key.resultMap=ResultMap redis.key.orderForSR=OrderForSR redis.key.SRCancelSet=SRCancelSet redis.key.SRExceptionMap=SRExceptionMap redis.key.taskList=TaskList -redis.key.credits.pre-deduction=Credits:PreDeduction \ No newline at end of file +redis.key.credits.pre-deduction=Credits:PreDeduction +redis.key.generateResult=Generate:Result \ No newline at end of file