diff --git a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java index 573b1ecb..47464a59 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java @@ -294,6 +294,8 @@ public class GenerateConsumer { exceptionInfo.put(generateResult.get("tasks_id"), generateResult.get("message")); // 存redis redisUtil.addToMap(exceptionMapKey, exceptionInfo); + // 记录失败状态并向用户发送提示消息 + generateService.processPTFailSituation(generateResult.get("tasks_id")); } } catch (Exception e) { e.printStackTrace(); @@ -312,6 +314,8 @@ public class GenerateConsumer { exceptionInfo.put(String.valueOf(generateResult.get("tasks_id")), exceptionMessage); // 存redis redisUtil.addToMap(exceptionMapKey, exceptionInfo); + // 记录失败状态并向用户发送提示消息 + generateService.processPTFailSituation(generateResult.get("tasks_id")); } long end = System.currentTimeMillis(); diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index ccd3cda5..d8af260c 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -57,6 +57,8 @@ public interface GenerateService extends IService { void processPoseTransformResult(String taskId, String gifUrl, String videoUrl, String imageUrl); + void processPTFailSituation(String taskId); + List getPoseTransformationResult(List taskIdList, Long projectId, Boolean like); void updatePoseTransferStatus(String taskId, String status, PoseTransformation poseTransformation); diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 8d21451d..f315a662 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -78,6 +78,7 @@ import java.util.regex.Pattern; import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*; import static com.ai.da.common.enums.CreditsEventsEnum.PATTERN; +import static com.ai.da.common.enums.CreditsEventsEnum.POSE_TRANSFORMATION; import static com.ai.da.common.enums.CreditsEventsEnum.TO_PRODUCT_IMAGE; import static com.ai.da.common.enums.CreditsEventsEnum.TO_PRODUCT_IMAGE_ADVANCED; import static com.ai.da.common.enums.WangXiangTaskStatusEnum.FAILED; @@ -2469,6 +2470,7 @@ public class GenerateServiceImpl extends ServiceImpl i Boolean isRequestSuccess = false; PoseTransformation poseTransformation = new PoseTransformation(); if (!StringUtil.isNullOrEmpty(poseTransformDTO.getModelName()) && poseTransformDTO.getModelName().equals("wx")) { + // 请求生成视频 taskId = animateAnyone(poseTransformDTO, accountId); if (!StringUtil.isNullOrEmpty(taskId)) { isRequestSuccess = true; @@ -2482,6 +2484,7 @@ public class GenerateServiceImpl extends ServiceImpl i com.alibaba.fastjson.JSONObject params = createParamsForMotion(poseTransformDTO, taskId); String api = params.getString("api"); params.remove("api"); + // 请求生成视频 isRequestSuccess = pythonService.poseTransformation(params, api); } @@ -2564,16 +2567,11 @@ public class GenerateServiceImpl extends ServiceImpl i public void processPoseTransformResult(String taskId, String gifUrl, String videoUrl, String imageUrl) { // 1、存储模型返回的数据 - PoseTransformation poseTransformation; - QueryWrapper qw = new QueryWrapper<>(); - qw.eq("unique_id", taskId); - List poseTransformations = poseTransformationMapper.selectList(qw); - if (poseTransformations != null && poseTransformations.size() > 1) { - log.warn("通过taskId {} 查询到的PoseTransformation的结果不止一条", taskId); - } else if (poseTransformations == null || poseTransformations.isEmpty()) { + PoseTransformation poseTransformation = getPoseTransformationByTaskId(taskId); + if (Objects.isNull(poseTransformation)) { return; } - poseTransformation = poseTransformations.get(0); + poseTransformation.setGifUrl(gifUrl); poseTransformation.setVideoUrl(videoUrl); poseTransformation.setFirstFrameUrl(imageUrl); @@ -2596,10 +2594,42 @@ public class GenerateServiceImpl extends ServiceImpl i Boolean flag = creditsService.taskCreditsDeduction(Long.parseLong(accountId), taskId); if (flag) creditsService.updateChangedCredits(accountId, taskId); + // 发消息 + sendSysMsgForPT(poseTransformation, true); + } + + // 处理PoseTransformation失败的情况 + public void processPTFailSituation(String taskId) { + PoseTransformation poseTransformation = getPoseTransformationByTaskId(taskId); + if (Objects.isNull(poseTransformation)) { + return; + } + + // 更新生成记录的状态 + poseTransformation.setTaskStatus("Fail"); + poseTransformation.setUpdateTime(LocalDateTime.now()); + poseTransformationMapper.updateById(poseTransformation); + // 发消息 + sendSysMsgForPT(poseTransformation, false); + } + + private PoseTransformation getPoseTransformationByTaskId(String taskId) { + QueryWrapper qw = new QueryWrapper<>(); + qw.eq("unique_id", taskId); + List poseTransformations = poseTransformationMapper.selectList(qw); + if (poseTransformations != null && poseTransformations.size() > 1) { + log.warn("通过taskId {} 查询到的PoseTransformation的结果不止一条", taskId); + } else if (poseTransformations == null || poseTransformations.isEmpty()) { + return null; + } + return poseTransformations.get(0); + } + + public void sendSysMsgForPT(PoseTransformation poseTransformation, boolean isSuccess){ Project project = projectService.getById(poseTransformation.getProjectId()); // 发通知 if (Objects.nonNull(project) && !StringUtil.isNullOrEmpty(project.getName())) { - messageCenterService.videoFinishedMsg(poseTransformation.getAccountId(), project.getName(), true); + messageCenterService.videoFinishedMsg(poseTransformation.getAccountId(), project.getName(), isSuccess); } } @@ -2641,7 +2671,7 @@ public class GenerateServiceImpl extends ServiceImpl i } private PoseTransformationVO buildPoseTransformationVO(String taskId, PoseTransformation dbItem) { - String type = resolveModelType(taskId, CreditsEventsEnum.POSE_TRANSFORMATION.getValue()); + String type = resolveModelType(taskId, POSE_TRANSFORMATION.getName()); String key = generateResultKey + ":" + taskId; String resultJson = redisUtil.getFromString(key); @@ -3020,16 +3050,11 @@ public class GenerateServiceImpl extends ServiceImpl i @Transactional public void processPoseTransformResultBatch(String taskId, String gifUrl, String videoUrl, String imageUrl, String progress) { // 1、存储模型返回的数据 - PoseTransformation poseTransformation; - QueryWrapper qw = new QueryWrapper<>(); - qw.eq("unique_id", taskId); - List poseTransformations = poseTransformationMapper.selectList(qw); - if (poseTransformations != null && poseTransformations.size() > 1) { - log.warn("通过taskId {} 查询到的PoseTransformation的结果不止一条", taskId); - } else if (poseTransformations == null || poseTransformations.isEmpty()) { + PoseTransformation poseTransformation = getPoseTransformationByTaskId(taskId); + if (Objects.isNull(poseTransformation)) { return; } - poseTransformation = poseTransformations.get(0); + poseTransformation.setGifUrl(gifUrl); poseTransformation.setVideoUrl(videoUrl); poseTransformation.setFirstFrameUrl(imageUrl); @@ -3072,17 +3097,11 @@ public class GenerateServiceImpl extends ServiceImpl i @Transactional public void processPoseTransformResultBatch(String progress, String taskId) { // 1、存储模型返回的数据 - PoseTransformation poseTransformation; - QueryWrapper qw = new QueryWrapper<>(); - qw.eq("unique_id", taskId); - List poseTransformations = poseTransformationMapper.selectList(qw); - log.info("poseTransformations : {}", poseTransformations); - if (poseTransformations != null && poseTransformations.size() > 1) { - log.warn("通过taskId {} 查询到的PoseTransformation的结果不止一条", taskId); - } else if (poseTransformations == null || poseTransformations.isEmpty()) { + PoseTransformation poseTransformation = getPoseTransformationByTaskId(taskId); + if (Objects.isNull(poseTransformation)) { return; } - poseTransformation = poseTransformations.get(0); + String taskIdBatch = poseTransformation.getTaskIdBatch(); log.info("progress:{}", progress); log.info("taskIdBatch:{}", taskIdBatch); @@ -3914,7 +3933,7 @@ public class GenerateServiceImpl extends ServiceImpl i private String resolveModelType(String taskId, String func) { // 判断当前task来自哪个模型 if (!StringUtil.isNullOrEmpty(func) - && func.equals(CreditsEventsEnum.POSE_TRANSFORMATION.getValue())) { + && func.equals(POSE_TRANSFORMATION.getName())) { List poseTransformations = poseTransformationMapper.selectList( new QueryWrapper().eq("unique_id", taskId)); if (!poseTransformations.isEmpty() diff --git a/src/main/java/com/ai/da/service/impl/MessageCenterServiceImpl.java b/src/main/java/com/ai/da/service/impl/MessageCenterServiceImpl.java index 24fd21c5..7c8460ab 100644 --- a/src/main/java/com/ai/da/service/impl/MessageCenterServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/MessageCenterServiceImpl.java @@ -287,9 +287,15 @@ public class MessageCenterServiceImpl extends ServiceImpl notificationIdList) { Long id = UserContext.getUserHolder().getId(); for (Long notificationId : notificationIdList) { - SysNotificationReadStatus sysNotificationReadStatus = new SysNotificationReadStatus(notificationId, id); - sysNotificationReadStatus.setCreateTime(LocalDateTime.now()); - sysNotificationReadStatusMapper.insert(sysNotificationReadStatus); + Notification notification = getById(notificationId); + if (Objects.nonNull(notification) && notification.getType().equals("system")) { + // 当系统消息指定了接收人员时,不允许其他人员已读 + if (Objects.isNull(notification.getReceiverId()) || notification.getReceiverId().equals(id)) { + SysNotificationReadStatus sysNotificationReadStatus = new SysNotificationReadStatus(notificationId, id); + sysNotificationReadStatus.setCreateTime(LocalDateTime.now()); + sysNotificationReadStatusMapper.insert(sysNotificationReadStatus); + } + } } }