From d2076a81d58fb6dc054d92ff912bdcc04f79cd7c Mon Sep 17 00:00:00 2001 From: xupei Date: Thu, 18 Apr 2024 16:48:19 +0800 Subject: [PATCH] =?UTF-8?q?generate=20=E8=81=94=E8=B0=83=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../da/common/RabbitMQ/GenerateConsumer.java | 3 +- .../ai/da/controller/GenerateController.java | 8 +- .../dto/GenerateThroughImageTextDTO.java | 1 + .../com/ai/da/model/vo/GenerateResultVO.java | 8 ++ .../ai/da/model/vo/PrepareForGenerateVO.java | 6 +- .../com/ai/da/service/GenerateService.java | 4 +- .../da/service/impl/GenerateServiceImpl.java | 99 ++++++++++++------- 7 files changed, 82 insertions(+), 47 deletions(-) diff --git a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java index 76bf15ec..349dcca0 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java @@ -109,6 +109,7 @@ public class GenerateConsumer { long start = System.currentTimeMillis(); Map generateResult = JSONObject.parseObject(msg.getBody(), Map.class); + log.info("tasks_id : {} start ",generateResult.get("tasks_id")); // log.info("tasks_id : {}, message : {}",generateResult.get("tasks_id"), generateResult.get("message") ); if (generateResult.get("status").equals("SUCCESS")){ String url = generateResult.get("data"); @@ -118,7 +119,7 @@ public class GenerateConsumer { // 修改redis中的数据状态为exception String key = generateResultKey + ":" + generateResult.get("tasks_id"); Long expire = redisUtil.getExpire(key); - redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(null, null, "Fail")), expire); + redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(generateResult.get("tasks_id"), null, null, "Fail")), expire); // 将异常信息存到exception中 HashMap exceptionInfo = new HashMap<>(); exceptionInfo.put(generateResult.get("tasks_id"), generateResult.get("data")); diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java index 895e5195..4e1566dd 100644 --- a/src/main/java/com/ai/da/controller/GenerateController.java +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -33,12 +33,12 @@ public class GenerateController { return Response.success(generateService.generateCaption(sketchElementId)); } - @ApiOperation("通过文字、图片生成图片") + /*@ApiOperation("通过文字、图片生成图片") @PostMapping("/sketchAndPrint") public void generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { // return Response.success(generateService.generateThroughImageText(generateThroughImageTextDTO)); generateService.generateThroughImageText(generateThroughImageTextDTO); - } + }*/ @ApiOperation("喜欢生成的图片") @PostMapping("/like") @@ -55,14 +55,14 @@ public class GenerateController { @ApiOperation(value = "发起生成请求,异步获取结果") @PostMapping("/prepare") - public Response> prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public Response prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO)); } @ApiOperation(value = "取消继续生成") @GetMapping("/stopWaiting") public Response stopWaiting(@RequestParam("userId") Long userId, - @RequestParam("uniqueId") String uniqueId, + @RequestParam("uniqueId") List uniqueId, @RequestParam("timeZone") String timeZone) { generateService.cancelGenerate(userId, uniqueId, timeZone); return Response.success("stop waiting successfully"); diff --git a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java index 1adfdbef..fcd60ad1 100644 --- a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java +++ b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java @@ -48,6 +48,7 @@ public class GenerateThroughImageTextDTO { @ApiModelProperty("唯一id,用于保持消息唯一性") String uniqueId; + @NotNull(message = "Please check if the required fields are empty.(isTestUser)") @ApiModelProperty("是否是测试用户") Boolean isTestUser; } diff --git a/src/main/java/com/ai/da/model/vo/GenerateResultVO.java b/src/main/java/com/ai/da/model/vo/GenerateResultVO.java index 90ec7cde..3f28fc59 100644 --- a/src/main/java/com/ai/da/model/vo/GenerateResultVO.java +++ b/src/main/java/com/ai/da/model/vo/GenerateResultVO.java @@ -9,9 +9,17 @@ import lombok.Data; @AllArgsConstructor public class GenerateResultVO { + private String taskId; + private Long id; private String url; private String status; + + public GenerateResultVO(Long id, String url, String status) { + this.id = id; + this.url = url; + this.status = status; + } } diff --git a/src/main/java/com/ai/da/model/vo/PrepareForGenerateVO.java b/src/main/java/com/ai/da/model/vo/PrepareForGenerateVO.java index 438d94db..78c7260c 100644 --- a/src/main/java/com/ai/da/model/vo/PrepareForGenerateVO.java +++ b/src/main/java/com/ai/da/model/vo/PrepareForGenerateVO.java @@ -4,17 +4,19 @@ import io.swagger.annotations.ApiModel; import io.swagger.annotations.ApiModelProperty; import lombok.Data; +import java.util.List; + @Data @ApiModel("prepare for generate响应vo") public class PrepareForGenerateVO { @ApiModelProperty("uniqueId") - private String uniqueId; + private List uniqueId; @ApiModelProperty("剩余使用次数") private Integer leftUsageCount; - public PrepareForGenerateVO(String uniqueId, Integer leftUsageCount) { + public PrepareForGenerateVO(List uniqueId, Integer leftUsageCount) { this.uniqueId = uniqueId; this.leftUsageCount = leftUsageCount; } diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index 3b6c60b0..65297e79 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -29,10 +29,10 @@ public interface GenerateService extends IService { List getGenerateResultList(List taskIdList); - List prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO); + PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO); Long getRankPosition(String uniqueId); - void cancelGenerate(Long userId, String uniqueId, String timeZone); + void cancelGenerate(Long userId, List uniqueId, String timeZone); } diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index bc8ece64..68daf1bb 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -23,6 +23,7 @@ import com.ai.da.service.RabbitMQService; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.google.gson.Gson; import io.minio.errors.MinioException; @@ -153,7 +154,7 @@ public class GenerateServiceImpl extends ServiceImpl i }else { status = "Fail"; } - GenerateResultVO generateResultVO = new GenerateResultVO(null, null, status); + GenerateResultVO generateResultVO = new GenerateResultVO(generateThroughImageTextDTO.getUniqueId(), null, null, status); redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); // 5、处理模型返回的数据 @@ -193,7 +194,19 @@ public class GenerateServiceImpl extends ServiceImpl i // 5.1 将相应的url保存到数据库 GenerateDetail generateDetail = new GenerateDetail(); GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); - Generate generate = selectByUniqueId(taskId); + Generate generate ; + try{ + generate = selectByUniqueId(taskId); + }catch (MybatisPlusException e){ + log.error(e.getMessage()); + if (e.getMessage().equals("One record is expected, but the query result is multiple records")){ + generate = selectListByUniqueId(taskId).get(0); + }else { + throw new BusinessException("There are some problems with database query, please try again."); + } + + } +// Generate generate = selectByUniqueId(taskId); String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(url, 24 * 60), Boolean.FALSE); // 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过 List> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generate.getLevel1Type()); @@ -210,7 +223,7 @@ public class GenerateServiceImpl extends ServiceImpl i String key = generateResultKey + ":" + taskId; Long expire = redisUtil.getExpire(key); - GenerateResultVO generateResultVO = new GenerateResultVO(generateDetail.getId(), url, "Success"); + GenerateResultVO generateResultVO = new GenerateResultVO(taskId, generateDetail.getId(), url, "Success"); redisUtil.addToString(key, new Gson().toJson(generateResultVO), expire); } @@ -360,8 +373,8 @@ public class GenerateServiceImpl extends ServiceImpl i } @Override -// public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { - public List prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { +// public List prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、参数检查,判断必须参数是否为空 if (Objects.isNull(generateThroughImageTextDTO.getUserId())) { throw new BusinessException("userId cannot be empty"); @@ -376,7 +389,7 @@ public class GenerateServiceImpl extends ServiceImpl i if (generateThroughImageTextDTO.getIsTestUser()){ trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type()); if (trialsCount >= 2){ - return new ArrayList<>(); + return new PrepareForGenerateVO(0); } } @@ -422,7 +435,7 @@ public class GenerateServiceImpl extends ServiceImpl i } // 5、返回唯一id - return taskIdList; + return new PrepareForGenerateVO(taskIdList, 2 - trialsCount); } @Override @@ -504,46 +517,56 @@ public class GenerateServiceImpl extends ServiceImpl i return getOne(qw); } - @Override - @Transactional(rollbackFor = Exception.class) - public void cancelGenerate(Long userId, String uniqueId, String timeZone) { - // 1、确认当前消息是否还在排队中 - Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId); - Boolean flag = Boolean.FALSE; - if (exists) flag = redisUtil.getRank(consumptionOrderKey, uniqueId) > 1L ? Boolean.TRUE : Boolean.FALSE; - // 不管flag的默认值是true还是false,只要exists为false,&& 将短路 - if (exists && flag) { - // 1.1、将需要取消的唯一id加入redis,以便及时取消生成 - redisUtil.addToSet(cancelSetKey, uniqueId); - // 1.2 将需要取消的id从redis的ConsumptionOrder中删除 - redisUtil.removeFromZSet(consumptionOrderKey, uniqueId); - } else { - // 2、判断该消息是否异常 - boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId); - // 3、判断该消息是否已经消费结束 - Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, uniqueId); - if (!hasKey && !existsInResult) { - // 设置取等待状态为false - AsyncCallerUtil.waitingStatus.put(uniqueId, false); - // 3、直接发送取消请求到python端 - pythonService.cancelGenerateTask(uniqueId); - } - } + public List selectListByUniqueId(String uniqueId) { + QueryWrapper qw = new QueryWrapper<>(); + qw.eq("unique_id", uniqueId).orderByDesc("id"); - // 3、考虑加一张表,专门用于记录哪些用户在什么时间进行了取消操作,包括已经异常的请求 - GenerateCancel generateCancel = new GenerateCancel(userId, uniqueId, DateUtil.getByTimeZone(timeZone)); - generateCancelMapper.insert(generateCancel); + return baseMapper.selectList(qw); } - // 判断试用用户试用generate机会是否使用完毕 + @Override + @Transactional(rollbackFor = Exception.class) + public void cancelGenerate(Long userId, List uniqueIdList, String timeZone) { + uniqueIdList.forEach(uniqueId -> { + // 1、确认当前消息是否还在排队中 + Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId); + Boolean flag = Boolean.FALSE; + if (exists) flag = redisUtil.getRank(consumptionOrderKey, uniqueId) > 1L ? Boolean.TRUE : Boolean.FALSE; + // 不管flag的默认值是true还是false,只要exists为false,&& 将短路 + if (exists && flag) { + // 1.1、将需要取消的唯一id加入redis,以便及时取消生成 + redisUtil.addToSet(cancelSetKey, uniqueId); + // 1.2 将需要取消的id从redis的ConsumptionOrder中删除 + redisUtil.removeFromZSet(consumptionOrderKey, uniqueId); + } else { + // 2、判断该消息是否异常 + boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId); + // 3、判断该消息是否已经消费结束 + Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, uniqueId); + if (!hasKey && !existsInResult) { + // 设置取等待状态为false + AsyncCallerUtil.waitingStatus.put(uniqueId, false); + // 3、直接发送取消请求到python端 + pythonService.cancelGenerateTask(uniqueId); + } + } + // 3、考虑加一张表,专门用于记录哪些用户在什么时间进行了取消操作,包括已经异常的请求 + GenerateCancel generateCancel = new GenerateCancel(userId, uniqueId, DateUtil.getByTimeZone(timeZone)); + generateCancelMapper.insert(generateCancel); + }); + + + } + + // 判断试用用户试用generate机会是否使用完毕 每个board 3次机会 private int getTrialsCount(Long userId, String level1Type){ List getGenerateList = getGenerateByAccountId(userId, level1Type); int trialsCount ; if (getGenerateList.isEmpty()){ trialsCount = 0; - } else if (getGenerateList.size() == 1) { + } else if (getGenerateList.size() == 1 || (getGenerateList.size() >= 4 && getGenerateList.size() / 4 == 1)) { trialsCount = 1; - } else if (getGenerateList.size() == 2) { + } else if (getGenerateList.size() == 2 || (getGenerateList.size() >= 4 && getGenerateList.size() / 4 == 2)) { trialsCount = 2; }else { trialsCount = 2;