From 249351bf52294663ed5e377739c4b4e3da2feab1 Mon Sep 17 00:00:00 2001 From: shahaibo <1023316923@qq.com> Date: Thu, 5 Jun 2025 17:06:11 +0800 Subject: [PATCH] TASK:batch toProductImage;chatStream; --- .../da/common/RabbitMQ/GenerateConsumer.java | 32 ++--- .../com/ai/da/model/dto/BatchParamDTO.java | 2 + .../java/com/ai/da/python/PythonService.java | 11 +- .../ai/da/service/impl/DesignServiceImpl.java | 117 +++++++++++------- 4 files changed, 97 insertions(+), 65 deletions(-) diff --git a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java index e02ee7cd..018a37c9 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java @@ -387,11 +387,11 @@ public class GenerateConsumer { } } else { // 修改redis中的数据状态为exception - String key = toProductImageResultKey + ":" + generateResult.get("task_id"); - redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(generateResult.getString("task_id"), null, null, "Fail")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + String key = toProductImageResultKey + ":" + generateResult.get("tasks_id"); + redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(generateResult.getString("tasks_id"), null, null, "Fail")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); // 将异常信息存到exception中 HashMap exceptionInfo = new HashMap<>(); - exceptionInfo.put(generateResult.getString("task_id"), generateResult.getString("data")); + exceptionInfo.put(generateResult.getString("tasks_id"), generateResult.getString("data")); // 存redis redisUtil.addToMap(exceptionMapKey, exceptionInfo); } @@ -400,7 +400,7 @@ public class GenerateConsumer { try { channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false); // 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除 - redisUtil.removeFromZSet(consumptionOrderKey, generateResult.getString("task_id")); + redisUtil.removeFromZSet(consumptionOrderKey, generateResult.getString("tasks_id")); } catch (IOException exception) { log.error("手动确认,取消返回队列,不再重新消费"); } @@ -408,13 +408,13 @@ public class GenerateConsumer { String exceptionMessage = JSONObject.toJSONString(generateResult) + " Exception message : " + e.getMessage(); HashMap exceptionInfo = new HashMap<>(); - exceptionInfo.put(String.valueOf(generateResult.get("task_id")), exceptionMessage); + exceptionInfo.put(String.valueOf(generateResult.get("tasks_id")), exceptionMessage); // 存redis redisUtil.addToMap(exceptionMapKey, exceptionInfo); } long end = System.currentTimeMillis(); - log.info("tasks_id : {}, end , message : {}, 执行时长: {} 毫秒", generateResult.get("task_id"), generateResult.get("message"), (end - start)); + log.info("tasks_id : {}, end , message : {}, 执行时长: {} 毫秒", generateResult.get("tasks_id"), generateResult.get("message"), (end - start)); log.info("============ProcessToProductImageBatchResult End listening=========="); } @@ -426,23 +426,23 @@ public class GenerateConsumer { log.info("relightBatch response : {}", generateResult); try { - log.info("task_id : {} start ", generateResult.get("task_id")); + log.info("task_id : {} start ", generateResult.get("tasks_id")); if (!StringUtils.isEmpty(generateResult.getString("progress"))) { String progress = generateResult.getString("progress"); - JSONArray result = generateResult.getJSONArray("result"); + JSONObject result = generateResult.getJSONObject("result_data"); String url = null; if (!StringUtils.isEmpty(result)) { - url = result.getString(0); - String taskId = generateResult.getString("task_id"); + url = result.getString("image_url"); + String taskId = generateResult.getString("tasks_id"); userLikeGroupService.relightBatch(taskId, url, progress); } } else { // 修改redis中的数据状态为exception - String key = relightResultKey + ":" + generateResult.get("task_id"); - redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(generateResult.getString("task_id"), null, null, "Fail")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + String key = relightResultKey + ":" + generateResult.get("tasks_id"); + redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(generateResult.getString("tasks_id"), null, null, "Fail")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); // 将异常信息存到exception中 HashMap exceptionInfo = new HashMap<>(); - exceptionInfo.put(generateResult.getString("task_id"), generateResult.getString("data")); + exceptionInfo.put(generateResult.getString("tasks_id"), generateResult.getString("data")); // 存redis redisUtil.addToMap(exceptionMapKey, exceptionInfo); } @@ -451,7 +451,7 @@ public class GenerateConsumer { try { channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false); // 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除 - redisUtil.removeFromZSet(consumptionOrderKey, generateResult.getString("task_id")); + redisUtil.removeFromZSet(consumptionOrderKey, generateResult.getString("tasks_id")); } catch (IOException exception) { log.error("手动确认,取消返回队列,不再重新消费"); } @@ -459,13 +459,13 @@ public class GenerateConsumer { String exceptionMessage = JSONObject.toJSONString(generateResult) + " Exception message : " + e.getMessage(); HashMap exceptionInfo = new HashMap<>(); - exceptionInfo.put(String.valueOf(generateResult.get("task_id")), exceptionMessage); + exceptionInfo.put(String.valueOf(generateResult.get("tasks_id")), exceptionMessage); // 存redis redisUtil.addToMap(exceptionMapKey, exceptionInfo); } long end = System.currentTimeMillis(); - log.info("task_id : {}, end , message : {}, 执行时长: {} 毫秒", generateResult.get("task_id"), generateResult.get("message"), (end - start)); + log.info("task_id : {}, end , message : {}, 执行时长: {} 毫秒", generateResult.get("tasks_id"), generateResult.get("message"), (end - start)); log.info("============ProcessRelightBatchResult End listening=========="); } diff --git a/src/main/java/com/ai/da/model/dto/BatchParamDTO.java b/src/main/java/com/ai/da/model/dto/BatchParamDTO.java index 81ca9401..68c36240 100644 --- a/src/main/java/com/ai/da/model/dto/BatchParamDTO.java +++ b/src/main/java/com/ai/da/model/dto/BatchParamDTO.java @@ -11,4 +11,6 @@ public class BatchParamDTO { private BigDecimal image_strength; private String image_url; private String product_type; + + private String direction; } diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index fcf34798..1dd9d3ec 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -4314,7 +4314,7 @@ public class PythonService { throw new BusinessException("toProductImage.interface.exception"); } - public Boolean relightBatch(String url, String taskId, String prompt, String direction, String relightType) { + public Boolean relightBatch(String taskIdBatch, List paramList, String userId) { // todo 限流校验 // AccessLimitUtils.validate("design",5); OkHttpClient client = new OkHttpClient().newBuilder() @@ -4326,12 +4326,9 @@ public class PythonService { MediaType mediaType = MediaType.parse("application/json"); //关闭FastJson的引用检测 防止出现$ref 现象 Map map = new HashMap<>(); - map.put("tasks_id", taskId); - map.put("image_url", url); - map.put("prompt", prompt); - map.put("direction", direction); - map.put("product_type", relightType); - map.put("batch_size", 1); + map.put("batch_tasks_id", taskIdBatch); + map.put("user_id", userId); + map.put("batch_data_list", paramList); log.info("relightImage请求python 参数:####{}", map); String param = JSON.toJSONString(map, SerializerFeature.WriteNullStringAsEmpty); log.info(param); diff --git a/src/main/java/com/ai/da/service/impl/DesignServiceImpl.java b/src/main/java/com/ai/da/service/impl/DesignServiceImpl.java index 673d9e30..aaca297e 100644 --- a/src/main/java/com/ai/da/service/impl/DesignServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/DesignServiceImpl.java @@ -2039,9 +2039,9 @@ public class DesignServiceImpl extends ServiceImpl impleme }else { s = "Snow moutain, snowy day, natural light"; } + Map toProductImageVOIntegerMap = allocateElements(toProductImageDTO.getToProductImageVOList(), cloudTaskDTO.getNums()); for (ToProductImageVO toProductImageVO : toProductImageDTO.getToProductImageVOList()) { - String taskId = UUID.randomUUID() + "-" + i + "-" + userHolder.getId(); - i ++; + String taskId; if (toProductImageVO.getElementType().equals("ToProductImage")) { ToProductImageResult toProductImageResult1 = toProductImageResultMapper.selectById(toProductImageVO.getElementId()); String relightType = "overall"; @@ -2052,53 +2052,86 @@ public class DesignServiceImpl extends ServiceImpl impleme if (design.getSingleOverall().equals("single")) { relightType = "single"; } + List paramList = new ArrayList<>(); + List promptList = pythonService.getPrompt(s, toProductImageVOIntegerMap.get(toProductImageVO)); + for (int i1 = 0; i1 < toProductImageVOIntegerMap.get(toProductImageVO); i1++) { + BatchParamDTO batchParamDTO = new BatchParamDTO(); + taskId = UUID.randomUUID() + "-" + i + "-" + userHolder.getId(); + batchParamDTO.setTasks_id(taskId); + batchParamDTO.setPrompt(promptList.get(i1)); + batchParamDTO.setImage_url(toProductImageResult1.getUrl()); + batchParamDTO.setProduct_type(relightType); + batchParamDTO.setDirection(toProductImageDTO.getDirection()); + paramList.add(batchParamDTO); + + ToProductImageResult toProductImageResult = new ToProductImageResult(); + toProductImageResult.setElementId(toProductImageResult1.getId()); + toProductImageResult.setElementType("ToProductImage"); + toProductImageResult.setCreateTime(LocalDateTime.now()); + toProductImageResult.setToProductImageRecordId(toProductImageRecord.getId()); + toProductImageResult.setIsLike(0); + toProductImageResult.setTaskId(taskId); + toProductImageResult.setProjectId(projectId); + toProductImageResult.setTaskIdBatch(batchTaskId); + if (null != userLikeGroupId) { + toProductImageResult.setUserLikeGroupId(userLikeGroupId); + } + if (toProductImageDTO.getBrightenValue() != null) { + toProductImageResult.setBrightenValue(toProductImageDTO.getBrightenValue()); + } + toProductImageResult.setDirection(toProductImageDTO.getDirection()); + toProductImageResultMapper.insert(toProductImageResult); + result.add(toProductImageResult); + + // 添加需要扣除的积分到预扣除区 + creditsService.addRecordToCreditsDeduction(userHolder.getId(), taskId, CreditsEventsEnum.RELIGHT); + i ++; + } + // 走模型 + pythonService.relightBatch(batchTaskId, paramList, userHolder.getId().toString()); } - // 走模型 - pythonService.relightBatch(toProductImageResult1.getUrl(), taskId, s, toProductImageDTO.getDirection(), relightType); - ToProductImageResult toProductImageResult = new ToProductImageResult(); - toProductImageResult.setElementId(toProductImageResult1.getId()); - toProductImageResult.setElementType("ToProductImage"); - toProductImageResult.setCreateTime(LocalDateTime.now()); - toProductImageResult.setToProductImageRecordId(toProductImageRecord.getId()); - toProductImageResult.setIsLike(0); - toProductImageResult.setTaskId(taskId); - toProductImageResult.setProjectId(projectId); - toProductImageResult.setTaskIdBatch(batchTaskId); - if (null != userLikeGroupId) { - toProductImageResult.setUserLikeGroupId(userLikeGroupId); - } - if (toProductImageDTO.getBrightenValue() != null) { - toProductImageResult.setBrightenValue(toProductImageDTO.getBrightenValue()); - } - toProductImageResult.setDirection(toProductImageDTO.getDirection()); - toProductImageResultMapper.insert(toProductImageResult); - result.add(toProductImageResult); }else { ToProductElement toProductElement = toProductElementMapper.selectById(toProductImageVO.getElementId()); // 走模型 - pythonService.relightBatch(toProductElement.getUrl(), taskId, s, toProductImageDTO.getDirection(), "overall"); - ToProductImageResult toProductImageResult = new ToProductImageResult(); - toProductImageResult.setElementId(toProductElement.getId()); - toProductImageResult.setElementType("ProductElement"); - toProductImageResult.setCreateTime(LocalDateTime.now()); - toProductImageResult.setToProductImageRecordId(toProductImageRecord.getId()); - toProductImageResult.setIsLike(0); - toProductImageResult.setTaskId(taskId); - toProductImageResult.setProjectId(projectId); - toProductImageResult.setTaskIdBatch(batchTaskId); - if (null != userLikeGroupId) { - toProductImageResult.setUserLikeGroupId(userLikeGroupId); + List paramList = new ArrayList<>(); + List promptList = pythonService.getPrompt(s, toProductImageVOIntegerMap.get(toProductImageVO)); + for (int i1 = 0; i1 < toProductImageVOIntegerMap.get(toProductImageVO); i1++) { + BatchParamDTO batchParamDTO = new BatchParamDTO(); + taskId = UUID.randomUUID() + "-" + i + "-" + userHolder.getId(); + batchParamDTO.setTasks_id(taskId); + batchParamDTO.setPrompt(promptList.get(i1)); + batchParamDTO.setImage_url(toProductElement.getUrl()); + batchParamDTO.setProduct_type("overall"); + batchParamDTO.setDirection(toProductImageDTO.getDirection()); + paramList.add(batchParamDTO); + + ToProductImageResult toProductImageResult = new ToProductImageResult(); + toProductImageResult.setElementId(toProductElement.getId()); + toProductImageResult.setElementType("ProductElement"); + toProductImageResult.setCreateTime(LocalDateTime.now()); + toProductImageResult.setToProductImageRecordId(toProductImageRecord.getId()); + toProductImageResult.setIsLike(0); + toProductImageResult.setTaskId(taskId); + toProductImageResult.setProjectId(projectId); + toProductImageResult.setTaskIdBatch(batchTaskId); + if (null != userLikeGroupId) { + toProductImageResult.setUserLikeGroupId(userLikeGroupId); + } + if (toProductImageDTO.getBrightenValue() != null) { + toProductImageResult.setBrightenValue(toProductImageDTO.getBrightenValue()); + } + toProductImageResult.setDirection(toProductImageDTO.getDirection()); + toProductImageResultMapper.insert(toProductImageResult); + result.add(toProductImageResult); + + // 添加需要扣除的积分到预扣除区 + creditsService.addRecordToCreditsDeduction(userHolder.getId(), taskId, CreditsEventsEnum.RELIGHT); + i ++; } - if (toProductImageDTO.getBrightenValue() != null) { - toProductImageResult.setBrightenValue(toProductImageDTO.getBrightenValue()); - } - toProductImageResult.setDirection(toProductImageDTO.getDirection()); - toProductImageResultMapper.insert(toProductImageResult); - result.add(toProductImageResult); + // 走模型 + pythonService.relightBatch(batchTaskId, paramList, userHolder.getId().toString()); } - // 添加需要扣除的积分到预扣除区 - creditsService.addRecordToCreditsDeduction(userHolder.getId(), taskId, CreditsEventsEnum.RELIGHT); } CloudTask cloudTask = CopyUtil.copyObject(cloudTaskDTO, CloudTask.class); cloudTask.setProjectId(projectId);