diff --git a/src/main/java/com/ai/da/common/task/GenerateTask.java b/src/main/java/com/ai/da/common/task/GenerateTask.java index 281ca082..4a11e4e4 100644 --- a/src/main/java/com/ai/da/common/task/GenerateTask.java +++ b/src/main/java/com/ai/da/common/task/GenerateTask.java @@ -6,12 +6,10 @@ import com.ai.da.mapper.primary.PoseTransformationMapper; import com.ai.da.mapper.primary.ToProductImageResultMapper; import com.ai.da.mapper.primary.entity.*; import com.ai.da.model.vo.PoseTransformationVO; -import com.ai.da.service.APIGenerateService; -import com.ai.da.service.CreditsService; -import com.ai.da.service.GenerateService; -import com.ai.da.service.MessageCenterService; +import com.ai.da.service.*; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import io.netty.util.internal.StringUtil; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; @@ -25,19 +23,13 @@ import static com.ai.da.common.enums.CreditsEventsEnum.TO_PRODUCT_IMAGE; @Slf4j @Component +@RequiredArgsConstructor public class GenerateTask { - @Resource - private APIGenerateService apiGenerateService; - @Resource - private CreditsService creditsService; - @Resource - private GenerateService generateService; - @Resource - private MessageCenterService messageCenterService; - @Resource - private ToProductImageResultMapper toProductImageResultMapper; - @Resource - private PoseTransformationMapper poseTransformationMapper; + private final APIGenerateService apiGenerateService; + private final CreditsService creditsService; + private final GenerateService generateService; + private final PoseTransformationMapper poseTransformationMapper; + private final ToProductImageResultMapper toProductImageResultMapper; /* @@ -106,9 +98,10 @@ public class GenerateTask { } - // 万相 -> pose transformation 补偿 一小时执行一次 + // 万相 -> pose transformation 补偿 当前任务执行完后,5分钟再执行一次(不会出现任务重叠的情况) @Scheduled(fixedDelay = 5 * 60 * 1000) public void wxCompensationMechanism(){ + log.info("=====万相补偿获取结果开始====="); List apiGenerates = apiGenerateService.getPendingTaskByStatus("wx"); if (apiGenerates != null && !apiGenerates.isEmpty()){ for (APIGenerate apiGenerate : apiGenerates){ @@ -122,7 +115,6 @@ public class GenerateTask { PoseTransformationVO animateResult = generateService.getAnimateResult(taskId); if (animateResult.getStatus().equals("Success")){ log.info("补偿获取结果成功,发送系统消息"); - sendSysMsgToUser(poseTransformation.getAccountId(), "您的姿势变换生成任务已完成"); } } catch (BusinessException e){ log.warn("万相 animation 生成失败,原因:{}", e.getMessage()); @@ -136,19 +128,13 @@ public class GenerateTask { apiGenerate.setStatus("Fail"); apiGenerate.setUpdateTime(LocalDateTime.now()); apiGenerateService.updateById(apiGenerate); + generateService.sendSysMsgForPT(poseTransformation); } } } } - public void sendSysMsgToUser(Long accountId, String content){ - Notification notification = new Notification(); - notification.setType("system"); - notification.setReceiverId(accountId); - notification.setContent(content); - messageCenterService.prePushMessage(notification); - } diff --git a/src/main/java/com/ai/da/common/utils/RedisUtil.java b/src/main/java/com/ai/da/common/utils/RedisUtil.java index 7f9e4ae7..c40edac7 100644 --- a/src/main/java/com/ai/da/common/utils/RedisUtil.java +++ b/src/main/java/com/ai/da/common/utils/RedisUtil.java @@ -97,6 +97,8 @@ public class RedisUtil { //- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - - + public final static String VIDEO_FINISHED_TASKS = "VideoFinishedTasks"; + /** * 将数据放入set缓存 */ diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index d8af260c..e321d79d 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -59,6 +59,8 @@ public interface GenerateService extends IService { void processPTFailSituation(String taskId); + void sendSysMsgForPT(PoseTransformation poseTransformation); + List getPoseTransformationResult(List taskIdList, Long projectId, Boolean like); void updatePoseTransferStatus(String taskId, String status, PoseTransformation poseTransformation); diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index f315a662..f7c1ef9b 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -2595,7 +2595,7 @@ public class GenerateServiceImpl extends ServiceImpl i if (flag) creditsService.updateChangedCredits(accountId, taskId); // 发消息 - sendSysMsgForPT(poseTransformation, true); + sendSysMsgForPT(poseTransformation); } // 处理PoseTransformation失败的情况 @@ -2610,7 +2610,7 @@ public class GenerateServiceImpl extends ServiceImpl i poseTransformation.setUpdateTime(LocalDateTime.now()); poseTransformationMapper.updateById(poseTransformation); // 发消息 - sendSysMsgForPT(poseTransformation, false); + sendSysMsgForPT(poseTransformation); } private PoseTransformation getPoseTransformationByTaskId(String taskId) { @@ -2625,11 +2625,31 @@ public class GenerateServiceImpl extends ServiceImpl i return poseTransformations.get(0); } - public void sendSysMsgForPT(PoseTransformation poseTransformation, boolean isSuccess){ + @Override + public void sendSysMsgForPT(PoseTransformation poseTransformation) { + // 确认当前任务是否已通知过,是 -> 不再通知;否 -> 通知 + Boolean elementExistsInSet = redisUtil.isElementExistsInSet(RedisUtil.VIDEO_FINISHED_TASKS, poseTransformation.getUniqueId()); + if (elementExistsInSet) { + // 已通知过,不再通知 + return; + } + + boolean isSuccess; + if (!StringUtil.isNullOrEmpty(poseTransformation.getTaskStatus()) && poseTransformation.getTaskStatus().equals("Success")) { + isSuccess = true; + } else if (!StringUtil.isNullOrEmpty(poseTransformation.getTaskStatus()) && poseTransformation.getTaskStatus().equals("Fail")) { + isSuccess = false; + } else { + // 不通知 + return; + } + Project project = projectService.getById(poseTransformation.getProjectId()); // 发通知 if (Objects.nonNull(project) && !StringUtil.isNullOrEmpty(project.getName())) { messageCenterService.videoFinishedMsg(poseTransformation.getAccountId(), project.getName(), isSuccess); + // 添加已通知记录到redis + redisUtil.addToSet(RedisUtil.VIDEO_FINISHED_TASKS, poseTransformation.getUniqueId(), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); } } @@ -3727,6 +3747,8 @@ public class GenerateServiceImpl extends ServiceImpl i creditsService.deleteCreditsDeduction(accountId, taskId); } poseTransformationVO.setTaskId(taskId); + // 发送提示消息 + sendSysMsgForPT(poseTransformation); return poseTransformationVO; }