PoseTransformation-初版

This commit is contained in:
2025-03-20 17:42:16 +08:00
parent 6a625ed4ea
commit 6b62cf7299
17 changed files with 425 additions and 4 deletions

View File

@@ -1,24 +1,22 @@
package com.ai.da.common.RabbitMQ;
import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.constant.CommonConstant;
import com.ai.da.common.utils.RedisUtil;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.GenerateResultVO;
import com.ai.da.model.vo.PoseTransformationVO;
import com.ai.da.service.GenerateService;
import com.ai.da.service.UserLikeGroupService;
import com.alibaba.fastjson.JSONObject;
import com.google.gson.Gson;
import com.rabbitmq.client.Channel;
import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.jni.Time;
import org.springframework.amqp.core.Message;
import org.springframework.amqp.rabbit.annotation.RabbitHandler;
import org.springframework.amqp.rabbit.annotation.RabbitListener;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import javax.annotation.Resource;
import java.io.IOException;
@@ -258,6 +256,55 @@ public class GenerateConsumer {
log.info("============ProcessRelightResult End listening==========");
}
public void processPoseTransformResult(Message msg, Channel channel) {
log.info("============ProcessPoseTransformResult listening==========");
long start = System.currentTimeMillis();
Map<String, String> generateResult = JSONObject.parseObject(msg.getBody(), Map.class);
log.info("PoseTransformation response : {}", generateResult);
try {
log.info("tasks_id : {} start ", generateResult.get("tasks_id"));
if (generateResult.get("status").equals("SUCCESS")) {
String gifUrl = generateResult.get("gif_url");
String taskId = generateResult.get("tasks_id");
String videoUrl = generateResult.get("video_url");
String imageUrl = generateResult.get("image_url");
generateService.processPoseTransformResult(taskId, gifUrl, videoUrl, imageUrl);
} else {
// 修改redis中的数据状态为exception
String key = generateResultKey + ":" + generateResult.get("tasks_id");
redisUtil.addToString(key, new Gson().toJson(new PoseTransformationVO(null, generateResult.get("tasks_id"),null, null, null, (byte)0, "Fail")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
// 将异常信息存到exception中
HashMap<String, String> exceptionInfo = new HashMap<>();
exceptionInfo.put(generateResult.get("tasks_id"), generateResult.get("message"));
// 存redis
redisUtil.addToMap(exceptionMapKey, exceptionInfo);
}
} catch (Exception e) {
log.error(e.getMessage());
try {
channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
redisUtil.removeFromZSet(consumptionOrderKey, generateResult.get("tasks_id"));
} catch (IOException exception) {
log.error("手动确认,取消返回队列,不再重新消费");
}
// 将入参和错误信息存入数据库
String exceptionMessage = JSONObject.toJSONString(generateResult) +
" Exception message " + e.getMessage();
HashMap<String, String> exceptionInfo = new HashMap<>();
exceptionInfo.put(String.valueOf(generateResult.get("tasks_id")), exceptionMessage);
// 存redis
redisUtil.addToMap(exceptionMapKey, exceptionInfo);
}
long end = System.currentTimeMillis();
log.info("tasks_id : {}, end , message : {}, 执行时长: {} 毫秒", generateResult.get("tasks_id"), generateResult.get("message"), (end - start));
log.info("============ProcessPoseTransformResult End listening==========");
}
@RabbitListener(queues = "#{rabbitMQProperties.queues.generate}")
@RabbitHandler
public void generateConsumer1(Message msg, Channel channel) {
@@ -329,4 +376,10 @@ public class GenerateConsumer {
public void getRelightResult(Message msg, Channel channel) {
processRelightResult(msg, channel);
}
@RabbitListener(queues = "#{rabbitMQProperties.queues.poseTransform}")
@RabbitHandler
public void getPoseTransformationResult(Message msg, Channel channel) {
processPoseTransformResult(msg, channel);
}
}

View File

@@ -20,6 +20,7 @@ public class RabbitMQProperties {
private String generateResult;
private String toProductImageResult;
private String relightResult;
private String poseTransform;
}
@Data

View File

@@ -30,6 +30,8 @@ public class CommonConstant {
public static final String GENERATE_LOGO_SINGLE_CANCEL = "/api/generate_single_logo_cancel/";
public static final String POSE_TRANSFORMATION_CANCEL = "/api/pose_transform_cancel/";
public static final String PYTHON_PORT_9996 = "9996";
public static final String PYTHON_PORT_9997 = "9997";

View File

@@ -35,6 +35,7 @@ public enum CreditsEventsEnum {
RELIGHT("Relight","5"),
QUESTIONNAIRE("Questionnaire","100"),
IMAGE_TO_SKETCH("ImageToSketch","5"),
POSE_TRANSFORMATION("PoseTransformation","10"),
OTHER("Other","5");

View File

@@ -16,7 +16,6 @@ import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import javax.validation.Valid;
import java.util.List;
/**
* @author XP
*/
@@ -98,4 +97,30 @@ public class GenerateController {
return Response.success(generateService.modifySketch(generateModifyDTO));
}
@ApiOperation(value = "请求poseTransform异步获取结果")
@PostMapping("/poseTransform")
public Response<String> poseTransform(@ApiParam("projectId") @RequestParam Long projectId,
@ApiParam("productImage") @RequestParam String productImage,
@ApiParam("poseId") @RequestParam int poseId) {
return Response.success(generateService.poseTransform(projectId, productImage, poseId));
}
@ApiOperation(value = "获取pose transformation生成结果")
@PostMapping("/poseTransformResult")
public Response<PoseTransformationVO> getPoseTransformationResults(@ApiParam("taskId") @RequestParam String taskId) {
PoseTransformationVO generateResult = generateService.getPoseTransformationResult(taskId);
return Response.success(generateResult);
}
public Response<String> modifyModelProportion(){
return null;
}
public Response<String> sketchReconstruction(){
return null;
}
}

View File

@@ -0,0 +1,7 @@
package com.ai.da.mapper.primary;
import com.ai.da.common.config.mybatis.plus.CommonMapper;
import com.ai.da.mapper.primary.entity.PoseTransformation;
public interface PoseTransformationMapper extends CommonMapper<PoseTransformation> {
}

View File

@@ -0,0 +1,7 @@
package com.ai.da.mapper.primary;
import com.ai.da.mapper.primary.entity.SketchReconstruction;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
public interface SketchReconstructionMapper extends BaseMapper<SketchReconstruction> {
}

View File

@@ -0,0 +1,33 @@
package com.ai.da.mapper.primary.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
@EqualsAndHashCode(callSuper = true)
@TableName("t_pose_transformation")
@Data
public class PoseTransformation extends BaseEntity {
private Long projectId;
private Long accountId;
private String uniqueId;
private String productImage;
private int poseId;
private String gifUrl;
private String videoUrl;
// GIF第一帧截图
private String imageUrl;
private byte isLiked;
private byte isDeleted;
}

View File

@@ -0,0 +1,22 @@
package com.ai.da.mapper.primary.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
@EqualsAndHashCode(callSuper = true)
@Data
@TableName("t_sketch_reconstruction")
public class SketchReconstruction extends BaseEntity{
private Long projectId;
private Long elementId;
// upload、library、generate
private String elementSource;
private String path;
}

View File

@@ -19,4 +19,13 @@ public class ImageToSketchDTO {
@ApiModelProperty("性别")
private String gender;
public ImageToSketchDTO() {
}
public ImageToSketchDTO(Long elementId, String style, String gender) {
this.elementId = elementId;
this.style = style;
this.gender = gender;
}
}

View File

@@ -0,0 +1,32 @@
package com.ai.da.model.dto;
import lombok.Data;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
@Data
public class SketchReconstructionDTO {
private Long projectId;
private String collagePicture;
private List<Element> elements;
private MultipartFile file;
// like到library时分类用
private String gender;
private boolean Save;
@Data
public static class Element{
private Long elementId;
private String elementSource;
private String path;
}
}

View File

@@ -0,0 +1,25 @@
package com.ai.da.model.vo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class PoseTransformationVO {
private Long id;
private String taskId;
private String gifUrl;
private String videoUrl;
// GIF第一帧截图
private String imageUrl;
private byte isLiked;
private String status;
}

View File

@@ -3799,4 +3799,62 @@ public class PythonService {
throw new BusinessException("design.interface.exception");
}
public Boolean poseTransformation(String productImage, int poseId, String taskId) {
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
MediaType mediaType = MediaType.parse("application/json");
Map<String, String> content = Maps.newHashMap();
content.put("image_url", productImage);
content.put("tasks_id", taskId);
content.put("pose_id", String.valueOf(poseId));
RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content));
log.info("poseTransformation 请求地址: {}", accessPythonIp + ":" + accessPythonPort + "/api/pose_transform");
Request request = new Request.Builder()
.url(accessPythonIp + ":" + accessPythonPort + "/api/pose_transform")
.method("POST", body)
.addHeader("Content-Type", "application/json")
.build();
Response response = null;
String bodyString;
try {
log.info("poseTransformation请求入参content###{}", JSON.toJSONString(content));
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("PythonService##poseTransformation异常###{}", ExceptionUtil.getThrowableList(ioException));
throw new BusinessException(ioException.getMessage());
}
// 判断是否生成失败
if (Objects.isNull(response.body())) {
log.error("PythonService##poseTransformation异常###{}", "response or body is empty!");
throw new BusinessException("PythonService##poseTransformation异常###: response or body is empty!");
} else if (response.code() != HttpURLConnection.HTTP_OK) {
log.error("PythonService##poseTransformation异常###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("PythonService##poseTransformation异常### Response error!Response code ## " + response.code() + " ##");
} else {
try {
bodyString = response.body().string();
} catch (IOException e) {
log.error(e.getMessage());
throw new BusinessException(e.getMessage());
}
}
JSONObject jsonObject = JSON.parseObject(bodyString);
Boolean result = JSON.parseObject(JSON.toJSONString(response)).getBoolean("successful");
if (result && jsonObject.get("code").equals(200)) {
log.info("poseTransformation##responseObject###{}", jsonObject);
return Boolean.TRUE;
} else {
log.info("poseTransformation失败###{}", jsonObject);
log.info("poseTransformation Exception! Code : {}", jsonObject.get("code"));
return Boolean.FALSE;
}
}
}

View File

@@ -47,4 +47,10 @@ public interface GenerateService extends IService<Generate> {
GenerateResultVO imageToSketch(ImageToSketchDTO imageToSketchDTO);
GenerateResultVO modifySketch(GenerateModifyDTO generateModifyDTO);
String poseTransform(Long projectId, String productImage, int poseId);
void processPoseTransformResult(String taskId, String gifUrl, String videoUrl, String imageUrl);
PoseTransformationVO getPoseTransformationResult(String taskId);
}

View File

@@ -59,6 +59,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
private RedisUtil redisUtil;
@Resource
private GenerateCancelMapper generateCancelMapper;
@Resource
private SketchReconstructionMapper sketchReconstructionMapper;
@Value("${redis.key.orderForGenerate}")
private String consumptionOrderKey;
@@ -715,6 +717,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
String path;
if (type.equals("Logo")) {
path = CommonConstant.GENERATE_LOGO_SINGLE_CANCEL;
} else if(type.equals("PoseTransformation")){
path =CommonConstant.POSE_TRANSFORMATION_CANCEL;
} else {
path = CommonConstant.GENERATE_CANCEL;
}
@@ -920,4 +924,138 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return new GenerateResultVO(generateDetailId, minioUtil.getPreSignedUrl(minioPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME, true), "Success", category);
}
public String poseTransform(Long projectId, String productImage, int poseId){
Long accountId = UserContext.getUserHolder().getId();
// 1、判断用户当前积分是否够本次生成消耗
CreditsEventsEnum creditsEventsEnum = CreditsEventsEnum.POSE_TRANSFORMATION;
Boolean preDeduction = creditsService.creditsPreDeduction(creditsEventsEnum, 1);
if (!preDeduction) {
throw new BusinessException("remaining.credits.insufficient", ResultEnum.WARNING.getCode());
}
// 3、生成唯一id 使用uuid,由于uuid重复的几率很小故取消对uuid重复性的校验
String uuid = UUID.randomUUID().toString();
String taskId = uuid + "-" + accountId;
PoseTransformation poseTransformation = new PoseTransformation();
poseTransformation.setProjectId(projectId);
poseTransformation.setAccountId(accountId);
poseTransformation.setUniqueId(taskId);
poseTransformation.setProductImage(productImage);
poseTransformation.setPoseId(poseId);
poseTransformation.setCreateTime(LocalDateTime.now());
poseTransformationMapper.insert(poseTransformation);
Boolean b = pythonService.poseTransformation(productImage, poseId, taskId);
if (b){
// 6、添加预扣除积分到redis
creditsService.addRecordToCreditsDeduction(accountId, uuid, creditsEventsEnum);
// 6.1 添加积分扣除记录到db
creditsService.preInsert(accountId, creditsEventsEnum.getName(), uuid, Boolean.TRUE, null);
return taskId;
}
throw new BusinessException("pose transformation error", ResultEnum.ERROR.getCode());
}
@Resource
private PoseTransformationMapper poseTransformationMapper;
public void processPoseTransformResult(String taskId, String gifUrl, String videoUrl, String imageUrl){
// 1、存储模型返回的数据
PoseTransformation poseTransformation;
QueryWrapper<PoseTransformation> qw = new QueryWrapper<>();
qw.eq("unique_id", taskId);
List<PoseTransformation> poseTransformations = poseTransformationMapper.selectList(qw);
if (poseTransformations != null && poseTransformations.size() > 1){
log.warn("通过taskId {} 查询到的PoseTransformation的结果不止一条", taskId);
}else if (poseTransformations == null || poseTransformations.isEmpty()){
return ;
}
poseTransformation = poseTransformations.get(0);
poseTransformation.setGifUrl(gifUrl);
poseTransformation.setVideoUrl(videoUrl);
poseTransformation.setImageUrl(imageUrl);
poseTransformation.setUpdateTime(LocalDateTime.now());
poseTransformationMapper.updateById(poseTransformation);
String key = generateResultKey + ":" + taskId;
PoseTransformationVO poseTransformationVO = new PoseTransformationVO(
poseTransformation.getId(), taskId, gifUrl, videoUrl, imageUrl, (byte) 0, "Success");
// 2、更新redis
redisUtil.addToString(key, new Gson().toJson(poseTransformationVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
// 3、执行积分扣除
String accountId = taskId.substring(taskId.lastIndexOf("-") + 1);
String uuid = taskId.substring(0, taskId.lastIndexOf("-"));
Boolean flag = creditsService.taskCreditsDeduction(Long.parseLong(accountId), uuid);
if (flag) creditsService.updateChangedCredits(accountId, uuid);
}
public PoseTransformationVO getPoseTransformationResult(String taskId){
String key = generateResultKey + ":" + taskId;
String resultJson = redisUtil.getFromString(key);
if (!StringUtil.isNullOrEmpty(resultJson)){
PoseTransformationVO poseTransformationVO = new Gson().fromJson(redisUtil.getFromString(key), PoseTransformationVO.class);
if (poseTransformationVO.getStatus().equals("Success")){
poseTransformationVO.setGifUrl(minioUtil.getPreSignedUrl(poseTransformationVO.getGifUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
poseTransformationVO.setVideoUrl(minioUtil.getPreSignedUrl(poseTransformationVO.getVideoUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
poseTransformationVO.setImageUrl(minioUtil.getPreSignedUrl(poseTransformationVO.getImageUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
}
return poseTransformationVO;
}else {
return new PoseTransformationVO();
}
}
/**
* String collagePicture(Base64)
* List<DTO(id, type)> elements
* File file
* @return
*/
public String sketchReconstruction(SketchReconstructionDTO sketchReconstructionDTO){
Long accountId = UserContext.getUserHolder().getId();
// 1、线稿生成
String collagePictureBase64 = sketchReconstructionDTO.getCollagePicture();
String path = accountId + "/CollagePicture/" + UUID.randomUUID();
String minioPath = minioUtil.base64UploadToPath(collagePictureBase64, userBucket, path);
CollectionElement collectionElement = new CollectionElement();
collectionElement.setAccountId(accountId);
collectionElement.setLevel1Type(SKETCH_BOARD.getRealName());
collectionElement.setUrl(minioPath);
collectionElement.setMd5(MD5Utils.encryptFile(minioPath, false));
collectionElement.setCreateDate(new Date());
collectionElementService.save(collectionElement);
GenerateResultVO generateResultVO = imageToSketch(new ImageToSketchDTO(collectionElement.getId(), "2", sketchReconstructionDTO.getGender()));
// 2、以文件形式保存元素同时还要将使用的元素单独存储
if (sketchReconstructionDTO.isSave() && !sketchReconstructionDTO.getElements().isEmpty()){
// 将使用的元素全部都保存到新建表
// 先判断该project下有没有数据无 --> 直接保存;有 --> 先删除,再保存
QueryWrapper<SketchReconstruction> qw = new QueryWrapper<>();
qw.eq("project_id", sketchReconstructionDTO.getProjectId());
List<SketchReconstruction> sketchReconstructions = sketchReconstructionMapper.selectList(qw);
if (!sketchReconstructions.isEmpty()){
sketchReconstructionMapper.delete(qw);
}
sketchReconstructionDTO.getElements().forEach(element -> {
SketchReconstruction sketchReconstruction = new SketchReconstruction();
sketchReconstruction.setProjectId(sketchReconstructionDTO.getProjectId());
sketchReconstruction.setElementId(element.getElementId());
sketchReconstruction.setElementSource(element.getElementSource());
sketchReconstruction.setPath(element.getPath());
sketchReconstructionMapper.insert(sketchReconstruction);
});
// 将画布文件上传到minio,地址保存到project表中
String canvasPath = minioUtil.upload("aida-users", accountId + "/CollagePicture/CanvasFile", sketchReconstructionDTO.getFile());
}
// 需要返回哪些信息呢?
return null;
}
}