diff --git a/src/main/java/com/ai/da/mapper/primary/entity/PoseTransformation.java b/src/main/java/com/ai/da/mapper/primary/entity/PoseTransformation.java index cf385bdc..90aef6de 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/PoseTransformation.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/PoseTransformation.java @@ -31,6 +31,8 @@ public class PoseTransformation extends BaseEntity { private byte isDeleted; + private String modelName; + public PoseTransformation() { } diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index b93e92c8..71fc05ce 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -36,7 +36,6 @@ import io.minio.errors.MinioException; import io.netty.util.internal.StringUtil; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; -import org.apache.commons.lang3.StringUtils; import org.apache.http.HttpHost; import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.HttpGet; @@ -691,7 +690,7 @@ public class GenerateServiceImpl extends ServiceImpl i GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class); if (flag) { - type = resolveModelType(taskId); + type = resolveModelType(taskId, null); flag = false; } // 暂定万象每次生成1个 @@ -1123,16 +1122,17 @@ public class GenerateServiceImpl extends ServiceImpl i // 3、生成唯一id 使用uuid,由于uuid重复的几率很小,故取消对uuid重复性的校验 String taskId; Boolean flag = false; + PoseTransformation poseTransformation = new PoseTransformation(); if (poseTransformDTO.getModelName().equals("wx")){ taskId = animateAnyone(poseTransformDTO, accountId); if (!StringUtil.isNullOrEmpty(taskId)) flag = true; + poseTransformation.setModelName("wx"); }else { String uuid = UUID.randomUUID().toString(); taskId = uuid + "-" + accountId; flag = pythonService.poseTransformation(productImage, poseId, taskId); } - PoseTransformation poseTransformation = new PoseTransformation(); poseTransformation.setProjectId(projectId); poseTransformation.setAccountId(accountId); poseTransformation.setUniqueId(taskId); @@ -1189,7 +1189,7 @@ public class GenerateServiceImpl extends ServiceImpl i } public PoseTransformationVO getPoseTransformationResult(String taskId){ - String type = resolveModelType(taskId); + String type = resolveModelType(taskId, CreditsEventsEnum.POSE_TRANSFORMATION.getValue()); if (type.equals("wx")){ return getAnimateResult(taskId); } @@ -2003,8 +2003,19 @@ public class GenerateServiceImpl extends ServiceImpl i return description; } - private String resolveModelType(String taskId){ + private String resolveModelType(String taskId, String func){ // 判断当前task来自哪个模型 + if (!StringUtil.isNullOrEmpty(func) + && func.equals(CreditsEventsEnum.POSE_TRANSFORMATION.getValue())){ + List poseTransformations = poseTransformationMapper.selectList( + new QueryWrapper().eq("unique_id", taskId)); + if (!poseTransformations.isEmpty() && poseTransformations.get(0).getModelName().equals("wx")){ + return "wx"; + }else { + return "local"; + } + } + Generate generate = selectByUniqueId(taskId); if (!StringUtil.isNullOrEmpty(generate.getModelName()) && (generate.getModelName().equals("wx")