Merge branch 'dev/dev_xp' into dev/dev
This commit is contained in:
@@ -3,13 +3,14 @@ package com.ai.da.common.utils;
|
||||
import com.ai.da.common.config.exception.BusinessException;
|
||||
import com.ai.da.mapper.primary.entity.ObjectItem;
|
||||
import io.minio.*;
|
||||
import io.minio.errors.MinioException;
|
||||
import io.minio.errors.*;
|
||||
import io.minio.http.Method;
|
||||
import io.minio.messages.DeleteError;
|
||||
import io.minio.messages.DeleteObject;
|
||||
import io.minio.messages.Item;
|
||||
import io.netty.util.internal.StringUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
@@ -617,6 +618,59 @@ public class MinioUtil {
|
||||
}
|
||||
}
|
||||
|
||||
public String getImageAsBase64(String path) throws IOException {
|
||||
int index = path.indexOf("/");
|
||||
String bucketName = path.substring(0, index);
|
||||
String fileName = path.substring(index + 1);
|
||||
|
||||
// 检查桶是否存在
|
||||
boolean found = doesObjectExist(bucketName, fileName);
|
||||
if (!found) {
|
||||
throw new IOException("Bucket " + bucketName + " does not exist");
|
||||
}
|
||||
|
||||
try (InputStream stream = minioClient.getObject(
|
||||
GetObjectArgs.builder()
|
||||
.bucket(bucketName)
|
||||
.object(fileName)
|
||||
.build())) {
|
||||
|
||||
byte[] bytes = IOUtils.toByteArray(stream);
|
||||
return Base64.getEncoder().encodeToString(bytes);
|
||||
} catch (ServerException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (InsufficientDataException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (ErrorResponseException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (InvalidKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (InvalidResponseException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (XmlParserException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (InternalException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public void uploadToMinio(byte[] data, String bucket, String objectName, String contentType) /*throws Exception*/ {
|
||||
try {
|
||||
minioClient.putObject(PutObjectArgs.builder()
|
||||
.bucket(bucket)
|
||||
.object(objectName)
|
||||
.stream(new ByteArrayInputStream(data), data.length, -1)
|
||||
.contentType(contentType)
|
||||
.build());
|
||||
} catch (MinioException | IOException | NoSuchAlgorithmException | InvalidKeyException e){
|
||||
log.error("图片上传到minio出错,{}", e.getMessage());
|
||||
throw new BusinessException("file upload exception");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -498,6 +498,8 @@ public class RedisUtil {
|
||||
|
||||
public final static String STRIPE_EXCEPTION_LOG = "StripeException:";
|
||||
|
||||
public final static String ANIMATE_ANYONE_TEMPLATE_ID = "AnimateAnyoneTemplateId:";
|
||||
|
||||
public void batchDeleteKeysWithSamePrefix(String prefix){
|
||||
Set<String> keys = redisTemplate.keys(prefix + "*");
|
||||
assert keys != null;
|
||||
|
||||
92
src/main/java/com/ai/da/common/utils/SendRequestUtil.java
Normal file
92
src/main/java/com/ai/da/common/utils/SendRequestUtil.java
Normal file
@@ -0,0 +1,92 @@
|
||||
package com.ai.da.common.utils;
|
||||
|
||||
import cn.hutool.http.Header;
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import cn.hutool.json.JSONObject;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SendRequestUtil {
|
||||
|
||||
@Value("${ALIYUN_API_KEY}")
|
||||
private String AliYunAPIKey;
|
||||
@Value("${FREEPIK_API_KEY}")
|
||||
private String freepikAPIKey;
|
||||
|
||||
public String sendAliYunPostAsync(String apiUrl, String requestBody){
|
||||
// 发送POST请求 todo 异常处理
|
||||
HttpResponse execute = HttpRequest.post(apiUrl)
|
||||
.header(Header.AUTHORIZATION, "Bearer " + AliYunAPIKey)
|
||||
.header("X-DashScope-Async", "enable")
|
||||
.header(Header.CONTENT_TYPE, "application/json")
|
||||
.body(requestBody)
|
||||
.timeout(20000) // 设置超时时间20秒
|
||||
.execute();
|
||||
int status = execute.getStatus();
|
||||
if (status == 200){
|
||||
String body = execute.body();
|
||||
JSONObject bodyJson = JSONUtil.parseObj(body);
|
||||
return body;
|
||||
}
|
||||
log.warn("请求失败,状态码为 : {}", status);
|
||||
return null;
|
||||
}
|
||||
|
||||
public static final String FREE_PIK = "https://api.freepik.com/v1/ai/beta/text-to-image/reimagine-flux";
|
||||
public String sendFreepikPost( String requestBody){
|
||||
// 发送POST请求 todo 异常处理
|
||||
HttpResponse execute = HttpRequest.post(FREE_PIK)
|
||||
.header(Header.CONTENT_TYPE, "application/json")
|
||||
.header("x-freepik-api-key", freepikAPIKey)
|
||||
.body(requestBody)
|
||||
.timeout(20000) // 设置超时时间20秒
|
||||
.execute();
|
||||
int status = execute.getStatus();
|
||||
if (status == 200){
|
||||
return execute.body();
|
||||
}
|
||||
log.warn("请求失败,状态码为 : {}", status);
|
||||
return null;
|
||||
}
|
||||
|
||||
public String sendAliYunGet(String fullUrl){
|
||||
// 发送GET请求 todo 异常处理
|
||||
HttpResponse httpResponse = HttpRequest.get(fullUrl)
|
||||
.header(Header.AUTHORIZATION, "Bearer " + AliYunAPIKey)
|
||||
.timeout(20000) // 设置超时时间20秒
|
||||
.execute();
|
||||
int status = httpResponse.getStatus();
|
||||
if (status == 200){
|
||||
return httpResponse.body();
|
||||
}else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public String sendPost(String url, String requestBodyStr){
|
||||
int status;
|
||||
String body;
|
||||
try (HttpResponse execute = HttpRequest.post(url)
|
||||
.header("Content-Type", "application/json") // 必须设置 Content-Type
|
||||
.body(requestBodyStr) // Hutool 会自动处理 JSON 序列化
|
||||
.timeout(120000) // 设置超时(毫秒)
|
||||
.execute()) {
|
||||
|
||||
status = execute.getStatus();
|
||||
body = execute.body();
|
||||
if (status == 200) {
|
||||
return body;
|
||||
}
|
||||
}
|
||||
log.warn("请求失败,状态码为 : {}, body: {}", status, body);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -9,10 +9,10 @@ import io.swagger.annotations.ApiOperation;
|
||||
import io.swagger.annotations.ApiParam;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import javax.validation.Valid;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -98,16 +98,14 @@ public class GenerateController {
|
||||
}
|
||||
|
||||
@ApiOperation(value = "请求进行姿势变换")
|
||||
@GetMapping("/poseTransform")
|
||||
public Response<String> poseTransform(@ApiParam("projectId") @RequestParam Long projectId,
|
||||
@ApiParam("productImage") @RequestParam String productImage,
|
||||
@ApiParam("poseId") @RequestParam int poseId) {
|
||||
return Response.success(generateService.poseTransform(projectId, productImage, poseId));
|
||||
@PostMapping("/poseTransform")
|
||||
public Response<String> poseTransform(@Valid @RequestBody PoseTransformDTO poseTransformDTO) {
|
||||
return Response.success(generateService.poseTransform(poseTransformDTO));
|
||||
}
|
||||
|
||||
@ApiOperation(value = "获取姿势变换生成结果")
|
||||
@GetMapping("/poseTransformResult")
|
||||
public Response<PoseTransformationVO> getPoseTransformationResults(@ApiParam("taskId") @RequestParam String taskId) {
|
||||
@PostMapping("/poseTransformResult")
|
||||
public Response<PoseTransformationVO> getPoseTransformationResults(@RequestParam String taskId) {
|
||||
PoseTransformationVO generateResult = generateService.getPoseTransformationResult(taskId);
|
||||
return Response.success(generateResult);
|
||||
}
|
||||
@@ -145,6 +143,52 @@ public class GenerateController {
|
||||
return Response.success(generateService.getAllPose());
|
||||
}
|
||||
|
||||
/*@ApiOperation(value = "万象 t2i 创建异步任务")
|
||||
@GetMapping("/createAsyncTask")
|
||||
public Response<String> createAsyncTask(@RequestParam("prompt") String prompt){
|
||||
return Response.success(generateService.createAsyncTask(87L, prompt, ""));
|
||||
}
|
||||
|
||||
@ApiOperation(value = "万象 t2i 获取异步任务结果")
|
||||
@GetMapping("/waitAsyncTask")
|
||||
public Response<GenerateResultVO> waitAsyncTask(@RequestParam("taskId") String taskId){
|
||||
return Response.success(generateService.getAsyncTaskResult(taskId));
|
||||
}
|
||||
|
||||
@ApiOperation(value = "万象 图生动图")
|
||||
@GetMapping("/animateAnyone")
|
||||
public Response<String> animateAnyone(@Valid @RequestBody PoseTransformDTO poseTransformDTO){
|
||||
return Response.success(generateService.animateAnyone(poseTransformDTO, null));
|
||||
}
|
||||
|
||||
@ApiOperation(value = "万象 获取动图模板id")
|
||||
@GetMapping("/getVideoTemplateId")
|
||||
public Response<String> getVideoTemplateId(@RequestParam("videoPath") String videoPath){
|
||||
return Response.success(generateService.getVideoTemplateId(videoPath));
|
||||
}
|
||||
|
||||
@ApiOperation(value = "万象 获取动图结果")
|
||||
@GetMapping("/getAnimateResult")
|
||||
public Response<PoseTransformationVO> getAnimateResult(@RequestParam("taskId") String taskId){
|
||||
return Response.success(generateService.getAnimateResult(taskId));
|
||||
}*/
|
||||
|
||||
@ApiOperation(value = "freepik toProductImage")
|
||||
@GetMapping("/reimagineFreePik")
|
||||
public Response<String> reimagineFreePik(@RequestParam("path") String path,
|
||||
@RequestParam("prompt") String prompt,
|
||||
@RequestParam("style") String style) throws IOException {
|
||||
return Response.success(generateService.reimagineFreePik(path, prompt, style));
|
||||
}
|
||||
|
||||
@ApiOperation(value = "获取图片描述")
|
||||
@GetMapping("/getImageDescription")
|
||||
public Response<String> getImageDescription(@RequestParam("path") String path) {
|
||||
return Response.success(generateService.getImageDescription(path));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -104,4 +104,16 @@ public class Generate {
|
||||
*/
|
||||
private Date updateDate;
|
||||
|
||||
public Generate() {
|
||||
}
|
||||
|
||||
public Generate(Long accountId, String uniqueId, String level1Type, String text, String generateType, String modelName, Date createDate) {
|
||||
this.accountId = accountId;
|
||||
this.uniqueId = uniqueId;
|
||||
this.level1Type = level1Type;
|
||||
this.text = text;
|
||||
this.generateType = generateType;
|
||||
this.modelName = modelName;
|
||||
this.createDate = createDate;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,5 +31,14 @@ public class PoseTransformation extends BaseEntity {
|
||||
|
||||
private byte isDeleted;
|
||||
|
||||
public PoseTransformation() {
|
||||
}
|
||||
|
||||
public PoseTransformation(Long projectId, Long accountId, String uniqueId, String productImage, int poseId) {
|
||||
this.projectId = projectId;
|
||||
this.accountId = accountId;
|
||||
this.uniqueId = uniqueId;
|
||||
this.productImage = productImage;
|
||||
this.poseId = poseId;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,8 +39,8 @@ public class GenerateThroughImageTextDTO {
|
||||
private String gender;
|
||||
|
||||
|
||||
@ApiModelProperty("选择的模型名 high || fast")
|
||||
private String version;
|
||||
@ApiModelProperty("选择的模型名 high || fast || wx || fp")
|
||||
private String modelName;
|
||||
|
||||
@NotBlank(message = "timeZone cannot be empty!")
|
||||
@ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取")
|
||||
|
||||
24
src/main/java/com/ai/da/model/dto/PoseTransformDTO.java
Normal file
24
src/main/java/com/ai/da/model/dto/PoseTransformDTO.java
Normal file
@@ -0,0 +1,24 @@
|
||||
package com.ai.da.model.dto;
|
||||
|
||||
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import javax.validation.constraints.NotBlank;
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
@Data
|
||||
public class PoseTransformDTO {
|
||||
@ApiModelProperty("项目id")
|
||||
private Long projectId;
|
||||
|
||||
@ApiModelProperty("图片的minio地址")
|
||||
@NotBlank(message = "please select a product image")
|
||||
private String productImage;
|
||||
|
||||
@ApiModelProperty("pose的编号")
|
||||
@NotNull(message = "please select a pose")
|
||||
private Integer poseId;
|
||||
|
||||
private String modelName;
|
||||
}
|
||||
@@ -17,6 +17,7 @@ public class GenerateResultVO {
|
||||
|
||||
private String url;
|
||||
|
||||
// Success || Executing || Invalid || Failed
|
||||
private String status;
|
||||
|
||||
private String category;
|
||||
|
||||
@@ -23,6 +23,7 @@ public class PoseTransformationVO implements AllCollectionVO{
|
||||
|
||||
private byte isLiked;
|
||||
|
||||
// Success || Executing || Invalid || Failed
|
||||
private String status;
|
||||
private String collectionType;
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ import com.ai.da.mapper.primary.entity.GenerateDetail;
|
||||
import com.ai.da.model.dto.*;
|
||||
import com.ai.da.model.vo.*;
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -46,7 +46,7 @@ public interface GenerateService extends IService<Generate> {
|
||||
|
||||
GenerateResultVO modifySketch(GenerateModifyDTO generateModifyDTO);
|
||||
|
||||
String poseTransform(Long projectId, String productImage, int poseId);
|
||||
String poseTransform(PoseTransformDTO poseTransformDTO);
|
||||
|
||||
void processPoseTransformResult(String taskId, String gifUrl, String videoUrl, String imageUrl);
|
||||
|
||||
@@ -65,4 +65,18 @@ public interface GenerateService extends IService<Generate> {
|
||||
List<Map<String, String>> getAllPose();
|
||||
|
||||
void processPoseTransformResultBatch(String taskId, String gifUrl, String videoUrl, String imageUrl, String progress);
|
||||
|
||||
String createAsyncTask(GenerateThroughImageTextDTO generateThroughImageTextDTO);
|
||||
|
||||
GenerateResultVO getAsyncTaskResult(String taskId);
|
||||
|
||||
String animateAnyone(PoseTransformDTO poseTransformDTO, Long accountId);
|
||||
|
||||
String getVideoTemplateId(String videoPath);
|
||||
|
||||
PoseTransformationVO getAnimateResult(String taskId);
|
||||
|
||||
String reimagineFreePik(String path, String prompt, String style) throws IOException;
|
||||
|
||||
String getImageDescription(String imagePath);
|
||||
}
|
||||
|
||||
@@ -331,6 +331,8 @@ public class CreditsServiceImpl extends ServiceImpl<CreditsDetailMapper, Credits
|
||||
credits = CreditsEventsEnum.SLOGAN.getValue();
|
||||
}else if (changeEvent.equals("PoseTransformation")){
|
||||
credits = CreditsEventsEnum.POSE_TRANSFORMATION.getValue();
|
||||
}else if (changeEvent.equals("Other")){
|
||||
credits = CreditsEventsEnum.OTHER.getValue();
|
||||
}
|
||||
|
||||
// BigDecimal finalCredits = currentCredits.subtract(new BigDecimal(credits));
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
package com.ai.da.service.impl;
|
||||
|
||||
import cn.hutool.core.img.gif.AnimatedGifEncoder;
|
||||
import cn.hutool.http.Header;
|
||||
import cn.hutool.http.HttpRequest;
|
||||
import cn.hutool.http.HttpResponse;
|
||||
import cn.hutool.json.JSONObject;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.ai.da.common.config.exception.BusinessException;
|
||||
import com.ai.da.common.constant.CommonConstant;
|
||||
import com.ai.da.common.context.UserContext;
|
||||
@@ -14,6 +20,12 @@ import com.ai.da.model.enums.SketchStyle;
|
||||
import com.ai.da.model.vo.*;
|
||||
import com.ai.da.python.PythonService;
|
||||
import com.ai.da.service.*;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
|
||||
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
|
||||
import com.alibaba.dashscope.exception.ApiException;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import com.alibaba.dashscope.utils.JsonUtils;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
@@ -24,16 +36,28 @@ import com.google.gson.Gson;
|
||||
import io.minio.errors.MinioException;
|
||||
import io.netty.util.internal.StringUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.http.client.methods.HttpGet;
|
||||
import org.apache.http.impl.client.CloseableHttpClient;
|
||||
import org.apache.http.impl.client.HttpClients;
|
||||
import org.bytedeco.javacv.FFmpegFrameGrabber;
|
||||
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.dao.DuplicateKeyException;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.io.IOException;
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.*;
|
||||
import java.math.BigDecimal;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*;
|
||||
|
||||
@@ -63,6 +87,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
private GenerateCancelMapper generateCancelMapper;
|
||||
@Resource
|
||||
private SketchReconstructionMapper sketchReconstructionMapper;
|
||||
@Resource
|
||||
private SendRequestUtil sendRequestUtil;
|
||||
|
||||
@Value("${redis.key.orderForGenerate}")
|
||||
private String consumptionOrderKey;
|
||||
@@ -91,6 +117,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
@Value("${access.python.generate_sr_port}")
|
||||
private String generateServicePort;
|
||||
|
||||
@Value("${ollama.url}")
|
||||
private String ollamaUrl;
|
||||
|
||||
@Resource
|
||||
private AccountService accountService;
|
||||
|
||||
@@ -128,7 +157,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ?
|
||||
generateType + " (" + generateThroughImageTextDTO.getGender() + ")" :
|
||||
generateType);
|
||||
generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion());
|
||||
generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getModelName());
|
||||
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
|
||||
generate.setElementSource(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getDesignType()) ? null : generateThroughImageTextDTO.getDesignType());
|
||||
|
||||
@@ -155,10 +184,10 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
String jsonString = "";
|
||||
HashMap<String, String> params = new HashMap<>();
|
||||
String version = null;
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) && generateThroughImageTextDTO.getVersion().equals("high")){
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("high")){
|
||||
version = "high";
|
||||
params.put("version","high");
|
||||
}else if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) && generateThroughImageTextDTO.getVersion().equals("fast")){
|
||||
}else if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("fast")){
|
||||
version = "fast";
|
||||
params.put("version","fast");
|
||||
}
|
||||
@@ -487,21 +516,21 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
if (Objects.isNull(generateThroughImageTextDTO.getUserId())) {
|
||||
throw new BusinessException("userId cannot be empty");
|
||||
}
|
||||
/*String generateType = generateThroughImageTextDTO.getGenerateType();
|
||||
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)) {
|
||||
throw new BusinessException("unknown.generate.type");
|
||||
}*/
|
||||
|
||||
// 判断试用用户是否还有剩余试用机会
|
||||
// ** 不再通过生成次数限制试用用户,统一使用积分限制
|
||||
/*int trialsCount = 0;
|
||||
if (generateThroughImageTextDTO.getIsTestUser()) {
|
||||
trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type());
|
||||
if (trialsCount >= 2) {
|
||||
return new PrepareForGenerateVO(0);
|
||||
}
|
||||
}*/
|
||||
CreditsEventsEnum creditsEventsEnum = CreditsEventsEnum.OTHER;
|
||||
|
||||
if (generateThroughImageTextDTO.getModelName().equals("wx")){
|
||||
String taskId = createAsyncTask(generateThroughImageTextDTO);
|
||||
// String taskId = "e53c86ea-53be-424b-8ac7-3c01c141f4f7";
|
||||
// 6、添加预扣除积分到redis
|
||||
creditsService.addRecordToCreditsDeduction(generateThroughImageTextDTO.getUserId(), taskId, creditsEventsEnum);
|
||||
// 6.1 添加积分扣除记录到db
|
||||
creditsService.preInsert(generateThroughImageTextDTO.getUserId(), creditsEventsEnum.getName(), taskId, Boolean.TRUE, null);
|
||||
|
||||
// 7、返回唯一id
|
||||
return new PrepareForGenerateVO(Collections.singletonList(taskId), 2);
|
||||
}
|
||||
|
||||
int times = 4;
|
||||
// 当level1Type为Print_board时,level2Type为pattern时需要确定generateType
|
||||
if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())) {
|
||||
@@ -524,7 +553,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
// 模型迁移SD1.? -> flux,从而产生了不同模型的选择,
|
||||
// high -> 生成图片质量高,但生成速度慢,每次生成只返回一张图片
|
||||
// fast -> 生成图片质量低,但生成速度快,每次生成返回四张图片
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) && generateThroughImageTextDTO.getVersion().equals("high")){
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("high")){
|
||||
times = 1;
|
||||
}
|
||||
}
|
||||
@@ -583,12 +612,12 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
}
|
||||
} else if (generateThroughImageTextDTO.getLevel1Type().equals(MOOD_BOARD.getRealName())) {
|
||||
creditsEventsEnum = CreditsEventsEnum.MOOD_BOARD;
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) && generateThroughImageTextDTO.getVersion().equals("high")){
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("high")){
|
||||
times = 1;
|
||||
}
|
||||
} else if (generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName())) {
|
||||
creditsEventsEnum = CreditsEventsEnum.SKETCH_BOARD;
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) && generateThroughImageTextDTO.getVersion().equals("high")){
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getModelName()) && generateThroughImageTextDTO.getModelName().equals("high")){
|
||||
times = 1;
|
||||
}
|
||||
}
|
||||
@@ -604,7 +633,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
|
||||
// 除了 Moodboard || Printboard->Pattern(可以区分三种风格) || Sketchboard(Generate Sketch)这三个地方需要区分high || fast之外,其他地方保持原样
|
||||
if (generateThroughImageTextDTO.getLevel1Type().equals("Printboard") && !generateThroughImageTextDTO.getLevel2Type().equals("Pattern")){
|
||||
generateThroughImageTextDTO.setVersion(null);
|
||||
generateThroughImageTextDTO.setModelName(null);
|
||||
}
|
||||
|
||||
ArrayList<String> taskIdList = new ArrayList<>();
|
||||
@@ -647,22 +676,37 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
public List<GenerateResultVO> getGenerateResultList(List<String> taskIdList) {
|
||||
List<GenerateResultVO> results = new ArrayList<>();
|
||||
Set<String> collect = new HashSet<>();
|
||||
taskIdList.forEach(taskId -> {
|
||||
boolean flag = true;
|
||||
String type = null;
|
||||
for (String taskId : taskIdList) {
|
||||
if (flag) {
|
||||
type = resolveModelType(taskId);
|
||||
flag = false;
|
||||
}
|
||||
// 暂定万象每次生成1个
|
||||
if (type.equals("wx")){
|
||||
return Collections.singletonList(getAsyncTaskResult(taskId));
|
||||
}
|
||||
|
||||
String key = generateResultKey + ":" + taskId;
|
||||
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
|
||||
if (!Objects.isNull(generateResultVO) && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) {
|
||||
|
||||
if (generateResultVO != null && !StringUtil.isNullOrEmpty(generateResultVO.getUrl())) {
|
||||
String url = generateResultVO.getUrl();
|
||||
if (url.substring(url.lastIndexOf("/") + 1).equals("white_image.jpg")) {
|
||||
generateResultVO.setStatus("Invalid");
|
||||
} else {
|
||||
generateResultVO.setUrl(minioUtil.getPreSignedUrl(url, CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
|
||||
}
|
||||
} else if (Objects.isNull(generateResultVO)) {
|
||||
} else if (generateResultVO == null) {
|
||||
generateResultVO = new GenerateResultVO();
|
||||
}
|
||||
if (!StringUtil.isNullOrEmpty(generateResultVO.getStatus())) collect.add(generateResultVO.getStatus());
|
||||
|
||||
if (!StringUtil.isNullOrEmpty(generateResultVO.getStatus())) {
|
||||
collect.add(generateResultVO.getStatus());
|
||||
}
|
||||
results.add(generateResultVO);
|
||||
});
|
||||
}
|
||||
|
||||
if (taskIdList.size() == 4 && collect.size() == 1 && collect.contains("Fail")) {
|
||||
log.info("当前4个生成结果均为失败");
|
||||
@@ -934,8 +978,11 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
return new GenerateResultVO(generateDetailId, minioUtil.getPreSignedUrl(minioPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME, true), "Success", category);
|
||||
}
|
||||
|
||||
public String poseTransform(Long projectId, String productImage, int poseId){
|
||||
public String poseTransform(PoseTransformDTO poseTransformDTO){
|
||||
Long accountId = UserContext.getUserHolder().getId();
|
||||
Long projectId = poseTransformDTO.getProjectId();
|
||||
String productImage = poseTransformDTO.getProductImage();
|
||||
Integer poseId = poseTransformDTO.getPoseId();
|
||||
|
||||
// 1、判断用户当前积分是否够本次生成消耗
|
||||
CreditsEventsEnum creditsEventsEnum = CreditsEventsEnum.POSE_TRANSFORMATION;
|
||||
@@ -945,8 +992,16 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
}
|
||||
|
||||
// 3、生成唯一id 使用uuid,由于uuid重复的几率很小,故取消对uuid重复性的校验
|
||||
String uuid = UUID.randomUUID().toString();
|
||||
String taskId = uuid + "-" + accountId;
|
||||
String taskId;
|
||||
Boolean flag = false;
|
||||
if (poseTransformDTO.getModelName().equals("wx")){
|
||||
taskId = animateAnyone(poseTransformDTO, accountId);
|
||||
if (!StringUtil.isNullOrEmpty(taskId)) flag = true;
|
||||
}else {
|
||||
String uuid = UUID.randomUUID().toString();
|
||||
taskId = uuid + "-" + accountId;
|
||||
flag = pythonService.poseTransformation(productImage, poseId, taskId);
|
||||
}
|
||||
|
||||
PoseTransformation poseTransformation = new PoseTransformation();
|
||||
poseTransformation.setProjectId(projectId);
|
||||
@@ -957,12 +1012,11 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
poseTransformation.setCreateTime(LocalDateTime.now());
|
||||
poseTransformationMapper.insert(poseTransformation);
|
||||
|
||||
Boolean b = pythonService.poseTransformation(productImage, poseId, taskId);
|
||||
if (b){
|
||||
if (flag){
|
||||
// 6、添加预扣除积分到redis
|
||||
creditsService.addRecordToCreditsDeduction(accountId, uuid, creditsEventsEnum);
|
||||
creditsService.addRecordToCreditsDeduction(accountId, taskId, creditsEventsEnum);
|
||||
// 6.1 添加积分扣除记录到db
|
||||
creditsService.preInsert(accountId, creditsEventsEnum.getName(), uuid, Boolean.TRUE, null);
|
||||
creditsService.preInsert(accountId, creditsEventsEnum.getName(), taskId, Boolean.TRUE, null);
|
||||
return taskId;
|
||||
}
|
||||
throw new BusinessException("pose transformation error", ResultEnum.ERROR.getCode());
|
||||
@@ -1002,12 +1056,16 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
|
||||
// 3、执行积分扣除
|
||||
String accountId = taskId.substring(taskId.lastIndexOf("-") + 1);
|
||||
String uuid = taskId.substring(0, taskId.lastIndexOf("-"));
|
||||
Boolean flag = creditsService.taskCreditsDeduction(Long.parseLong(accountId), uuid);
|
||||
if (flag) creditsService.updateChangedCredits(accountId, uuid);
|
||||
// String uuid = taskId.substring(0, taskId.lastIndexOf("-"));
|
||||
Boolean flag = creditsService.taskCreditsDeduction(Long.parseLong(accountId), taskId);
|
||||
if (flag) creditsService.updateChangedCredits(accountId, taskId);
|
||||
}
|
||||
|
||||
public PoseTransformationVO getPoseTransformationResult(String taskId){
|
||||
String type = resolveModelType(taskId);
|
||||
if (type.equals("wx")){
|
||||
return getAnimateResult(taskId);
|
||||
}
|
||||
String key = generateResultKey + ":" + taskId;
|
||||
String resultJson = redisUtil.getFromString(key);
|
||||
|
||||
@@ -1210,6 +1268,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
String posePath = "aida-sys-image/pose/pose-1.gif";
|
||||
String firstFramePath = "aida-sys-image/pose/pose-1-first_frame.jpeg";
|
||||
HashMap<String, String> resp = new HashMap<>();
|
||||
// todo 以后要返回poseId
|
||||
resp.put("gif", minioUtil.getPreSignedUrl(posePath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
|
||||
resp.put("firstFrame", minioUtil.getPreSignedUrl(firstFramePath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
|
||||
return Arrays.asList(resp);
|
||||
@@ -1265,4 +1324,554 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
if (flag) creditsService.updateChangedCredits(accountId, uuid);
|
||||
}
|
||||
|
||||
/**
|
||||
* 万象专业版
|
||||
* 1、MoodBoard t2i
|
||||
* 2、PrintBoard t2i
|
||||
* 3、SketchBoard t2i
|
||||
* 4、pose transfer 图生舞蹈视频-舞动人像AnimateAnyone
|
||||
*/
|
||||
|
||||
/**
|
||||
* 创建异步任务
|
||||
* @return taskId
|
||||
*/
|
||||
public String createAsyncTask(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||
// String prompt = "一间有着精致窗户的花店,漂亮的木质门,摆放着花朵";
|
||||
String level1Type = generateThroughImageTextDTO.getLevel1Type();
|
||||
String prompt = generateThroughImageTextDTO.getText();
|
||||
Long userId = generateThroughImageTextDTO.getUserId();
|
||||
String gender = generateThroughImageTextDTO.getGender();
|
||||
|
||||
// 添加预设prompt,使生成结果更加具有指向性(区分不同的board)
|
||||
switch(level1Type){
|
||||
case "Moodboard":
|
||||
break;
|
||||
case "Printboard":
|
||||
prompt = "pattern image, " + prompt;
|
||||
break;
|
||||
case "Sketchboard":
|
||||
prompt = "a single item of sketch of " + prompt + ", clean white background, simple lines";
|
||||
break;
|
||||
default:
|
||||
log.warn("未知类型 type:{}", level1Type);
|
||||
}
|
||||
HashMap<String, Boolean> promptExtend = new HashMap<>();
|
||||
promptExtend.put("prompt_extend", false);
|
||||
ImageSynthesisParam param =
|
||||
ImageSynthesisParam.builder()
|
||||
.apiKey(System.getenv("DASHSCOPE_API_KEY"))
|
||||
.model("wanx2.1-t2i-plus")
|
||||
.prompt(prompt)
|
||||
.n(1)
|
||||
.size("1024*1024")
|
||||
.parameters(promptExtend)
|
||||
.build();
|
||||
|
||||
log.info(param.toString());
|
||||
|
||||
ImageSynthesis imageSynthesis = new ImageSynthesis();
|
||||
ImageSynthesisResult result = null;
|
||||
try {
|
||||
result = imageSynthesis.asyncCall(param);
|
||||
} catch (Exception e){
|
||||
throw new RuntimeException(e.getMessage());
|
||||
}
|
||||
String taskId = result.getOutput().getTaskId();
|
||||
log.info("wx text2image 请求生成:{}, taskId:{}", JsonUtils.toJson(result), taskId);
|
||||
|
||||
Generate generate = new Generate(userId, taskId, level1Type, prompt, "text(" + gender + ")", "wx", new Date());
|
||||
save(generate);
|
||||
return taskId;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取异步任务结果
|
||||
* @param taskId 任务id
|
||||
* */
|
||||
public GenerateResultVO getAsyncTaskResult(String taskId) {
|
||||
ImageSynthesis imageSynthesis = new ImageSynthesis();
|
||||
ImageSynthesisResult result = null;
|
||||
try {
|
||||
//如果已经在环境变量中设置了 DASHSCOPE_API_KEY,wait()方法可将apiKey设置为null
|
||||
result = imageSynthesis.fetch(taskId, null);
|
||||
log.info(JsonUtils.toJson(result));
|
||||
//PENDING:任务排队中; RUNNING:任务处理中; SUCCEEDED:任务执行成功; FAILED:任务执行失败; CANCELED:任务取消成功; UNKNOWN:任务不存在或状态未知
|
||||
String taskStatus = result.getOutput().getTaskStatus();
|
||||
|
||||
if (taskStatus.equals("SUCCEEDED")){
|
||||
List<Generate> generates = selectListByUniqueId(taskId);
|
||||
String url = result.getOutput().getResults().get(0).get("url");
|
||||
|
||||
String path = null;
|
||||
if (!generates.isEmpty()){
|
||||
Generate generate = generates.get(0);
|
||||
Long accountId = generate.getAccountId();
|
||||
|
||||
// 1、下载图片
|
||||
// InputStream inputStream = downloadImageFromAliyun(url);
|
||||
byte[] bytes = downloadVideoOrImage(url);
|
||||
// 2、上传图片到minio保存
|
||||
String objectName = accountId + "/" + generate.getLevel1Type().toLowerCase() + "/" + taskId + ".png";
|
||||
minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png");
|
||||
path = userBucket + "/" + objectName;
|
||||
// 3、生成结果保存到db
|
||||
GenerateDetail generateDetail = new GenerateDetail();
|
||||
generateDetail.setGenerateId(generate.getId());
|
||||
generateDetail.setUrl(path);
|
||||
generateDetail.setMd5(MD5Utils.encryptFile(minioUtil.getPreSignedUrl(path, 24 * 60),false));
|
||||
generateDetail.setCreateDate(LocalDateTime.now());
|
||||
generateDetailMapper.insert(generateDetail);
|
||||
// 4、扣积分
|
||||
Boolean flag = creditsService.taskCreditsDeduction(accountId, taskId);
|
||||
if (flag) creditsService.updateChangedCredits(String.valueOf(generate.getAccountId()), taskId);
|
||||
|
||||
GenerateResultVO generateResultVO = new GenerateResultVO(taskId, generateDetail.getId(), minioUtil.getPreSignedUrl(path, 24 * 60), "Success");
|
||||
if (generate.getLevel1Type().equals(SKETCH_BOARD.getRealName())){
|
||||
String gender = extractGender(generate.getGenerateType());
|
||||
if (!StringUtil.isNullOrEmpty(gender)){
|
||||
String clothCategory = pythonService.getClothCategory(path, gender);
|
||||
generateResultVO.setCategory(clothCategory);
|
||||
}else {
|
||||
log.warn("未提取到性别");
|
||||
}
|
||||
}
|
||||
|
||||
return generateResultVO;
|
||||
}else {
|
||||
throw new BusinessException("Unknown generate task");
|
||||
}
|
||||
} else if(taskStatus.equals("PENDING") || taskStatus.equals("RUNNING")){
|
||||
log.info("万象 异步接口返回生成状态为:{}", taskStatus);
|
||||
return new GenerateResultVO(taskId, null, null, "Executing");
|
||||
} else {
|
||||
log.warn("万象 异步接口返回生成状态为:{}", taskStatus);
|
||||
return new GenerateResultVO(taskId, null, null, "Failed");
|
||||
}
|
||||
} catch (ApiException | NoApiKeyException e){
|
||||
throw new RuntimeException(e.getMessage());
|
||||
} catch (Exception e){
|
||||
log.error("从aliyun下载图片失败, {}", e.getMessage());
|
||||
throw new BusinessException("Generation result retrieval failed");
|
||||
}
|
||||
}
|
||||
|
||||
private static final String IMAGE_DETECT = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/aa-detect";
|
||||
private static final String TEMPLATE_ID_GEN = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/aa-template-generation/";
|
||||
private static final String GET_ASYNC_RESULT = "https://dashscope.aliyuncs.com/api/v1/tasks/";
|
||||
private static final String ANIMATE = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2video/video-synthesis/";
|
||||
private static final String API_KEY = System.getenv("DASHSCOPE_API_KEY"); // 替换为你的实际API密钥
|
||||
public String animateAnyone(PoseTransformDTO poseTransformDTO, Long accountId){
|
||||
|
||||
accountId = 87L;
|
||||
String inputImage = poseTransformDTO.getProductImage();
|
||||
inputImage = "aida-users/87/product_image/03983c74-741b-4d4d-820a-7c0a98a8f500-0-87.png";
|
||||
String inputImageUrl = minioUtil.getPreSignedUrl(inputImage, CommonConstant.MINIO_IMAGE_EXPIRE_TIME);
|
||||
// 1、输入图片检测
|
||||
// checkImage(inputImageUrl);
|
||||
|
||||
// 2、动作模板生成
|
||||
/* 目前只有一个pose,所以不调获取templateId的方法,写死
|
||||
String videoPath = "aida-sys-image/pose/WeChat_20250408175337.mp4";
|
||||
String videoTemplateId = getVideoTemplateId(videoPath);*/
|
||||
String videoTemplateId = "AACT.8090e67b.-E3pujumEfCbDTI_rjSH-A.LwIlGT3j";;
|
||||
|
||||
// 3、生成动图
|
||||
JSONObject requestBody1 = new JSONObject();
|
||||
requestBody1.set("model", "animate-anyone-gen2");
|
||||
|
||||
JSONObject input1 = new JSONObject();
|
||||
input1.set("image_url", inputImageUrl); // 替换为实际图片URL
|
||||
input1.set("template_id", videoTemplateId); // 替换为实际图片URL
|
||||
JSONObject parameters1 = new JSONObject();
|
||||
parameters1.set("use_ref_img_bg", false);
|
||||
parameters1.set("video_ratio", "9:16");
|
||||
|
||||
requestBody1.set("input", input1);
|
||||
requestBody1.set("parameters", parameters1);
|
||||
|
||||
// String resp = sendRequestUtil.sendAliYunPostAsync(ANIMATE, requestBody1.toString());
|
||||
String resp = "{\"request_id\":\"656c4339-59e5-9b34-a010-b5aa625a4008\",\"output\":{\"task_id\":\"05c0fe3e-8d93-4754-babe-28a1efc62151\",\"task_status\":\"PENDING\"}}";
|
||||
log.info("wx pose transform 请求生成,获取taskId:{}", resp);
|
||||
JSONObject jsonResponse = JSONUtil.parseObj(resp);
|
||||
JSONObject output = jsonResponse.getJSONObject("output");
|
||||
String status = output.getStr("task_status");
|
||||
if (status.equals(STATUS_FAILED) || status.equals(STATUS_UNKNOWN)){
|
||||
return null;
|
||||
}
|
||||
String taskId = output.getStr("task_id");
|
||||
|
||||
PoseTransformation poseTransformation = new PoseTransformation(poseTransformDTO.getProjectId(),
|
||||
accountId, taskId, inputImage, poseTransformDTO.getPoseId());
|
||||
poseTransformation.setCreateTime(LocalDateTime.now());
|
||||
poseTransformationMapper.insert(poseTransformation);
|
||||
return taskId;
|
||||
}
|
||||
|
||||
public void checkImage(String inputImageUrl){
|
||||
JSONObject requestBody = new JSONObject();
|
||||
requestBody.set("model", "animate-anyone-detect-gen2");
|
||||
|
||||
JSONObject input = new JSONObject();
|
||||
input.set("image_url", inputImageUrl); // 替换为实际图片URL
|
||||
JSONObject parameters = new JSONObject();
|
||||
|
||||
requestBody.set("input", input);
|
||||
requestBody.set("parameters", parameters);
|
||||
|
||||
String response = sendRequestUtil.sendPost(IMAGE_DETECT, requestBody.toString());
|
||||
|
||||
System.out.println("API响应: " + response);
|
||||
JSONObject jsonResponse = JSONUtil.parseObj(response);
|
||||
// 获取check_pass值
|
||||
JSONObject output = jsonResponse.getJSONObject("output");
|
||||
Boolean checkPass = output.getBool("check_pass");
|
||||
|
||||
if (!checkPass){
|
||||
String reason = output.getStr("reason");
|
||||
log.info("原因: {}", reason);
|
||||
throw new BusinessException("输入的图片不满足要求");
|
||||
}
|
||||
}
|
||||
|
||||
// 轮询配置
|
||||
private static final int MAX_RETRIES = 30; // 最大重试次数
|
||||
private static final int POLL_INTERVAL = 2000; // 轮询间隔(毫秒)
|
||||
|
||||
public String getVideoTemplateId(String videoPath){
|
||||
String key = RedisUtil.ANIMATE_ANYONE_TEMPLATE_ID + videoPath;
|
||||
String templateId = redisUtil.getFromString(key);
|
||||
|
||||
if (StringUtil.isNullOrEmpty(templateId)){
|
||||
String videoUrl = minioUtil.getPreSignedUrl(videoPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME);
|
||||
JSONObject requestBody = new JSONObject();
|
||||
requestBody.set("model", "animate-anyone-template-gen2");
|
||||
JSONObject input = new JSONObject();
|
||||
input.set("video_url", videoUrl); // 替换为实际图片URL
|
||||
JSONObject parameters = new JSONObject();
|
||||
requestBody.set("input", input);
|
||||
requestBody.set("parameters", parameters);
|
||||
|
||||
String resp = sendRequestUtil.sendPost(TEMPLATE_ID_GEN, requestBody.toString());
|
||||
|
||||
if (StringUtil.isNullOrEmpty(resp)){
|
||||
throw new BusinessException("请求获取video template id失败");
|
||||
}
|
||||
JSONObject jsonResponse = JSONUtil.parseObj(resp);
|
||||
log.info("getVideoTemplateId response:{}", jsonResponse);
|
||||
JSONObject output = jsonResponse.getJSONObject("output");
|
||||
String taskId = output.getStr("task_id");
|
||||
|
||||
// 暂时用while循环轮询
|
||||
templateId = pollTemplateIdResult(taskId);
|
||||
if (StringUtil.isNullOrEmpty(templateId)){
|
||||
throw new BusinessException("获取动作模板失败");
|
||||
}
|
||||
// templateId = "AACT.8090e67b.-E3pujumEfCbDTI_rjSH-A.LwIlGT3j";
|
||||
redisUtil.addToString(key, templateId);
|
||||
}
|
||||
|
||||
return templateId;
|
||||
}
|
||||
|
||||
// 定义任务状态常量
|
||||
private static final String STATUS_SUCCESS = "SUCCEEDED";
|
||||
private static final String STATUS_FAILED = "FAILED";
|
||||
private static final String STATUS_UNKNOWN = "UNKNOWN";
|
||||
private static final String STATUS_RUNNING = "RUNNING";
|
||||
private static final String STATUS_PENDING = "PENDING";
|
||||
|
||||
public String pollTemplateIdResult(String taskId) {
|
||||
int attempt = 0;
|
||||
boolean isCompleted = false;
|
||||
String templateId = null;
|
||||
|
||||
while (attempt < MAX_RETRIES && !isCompleted) {
|
||||
attempt++;
|
||||
System.out.printf("尝试第 %d 次查询...%n", attempt);
|
||||
|
||||
try {
|
||||
// 发送GET请求查询任务状态
|
||||
HttpResponse httpResponse = HttpRequest.get(GET_ASYNC_RESULT + taskId)
|
||||
.header(Header.AUTHORIZATION, "Bearer " + API_KEY)
|
||||
.timeout(10000)
|
||||
.execute();
|
||||
|
||||
if (httpResponse.getStatus() == 200) {
|
||||
JSONObject response = JSONUtil.parseObj(httpResponse.body());
|
||||
JSONObject output = JSONUtil.parseObj(response.getStr("output"));
|
||||
String taskStatus = output.getStr("task_status", "UNKNOWN");
|
||||
|
||||
System.out.println("当前任务状态: " + taskStatus);
|
||||
|
||||
switch (taskStatus) {
|
||||
case STATUS_SUCCESS:
|
||||
templateId = handleSuccessResponse(response);
|
||||
isCompleted = true;
|
||||
break;
|
||||
case STATUS_FAILED:
|
||||
case STATUS_UNKNOWN:
|
||||
handleFailedResponse(response);
|
||||
isCompleted = true;
|
||||
break;
|
||||
case STATUS_RUNNING:
|
||||
case STATUS_PENDING:
|
||||
// 任务仍在运行,继续等待
|
||||
break;
|
||||
default:
|
||||
System.out.println("未知状态: " + taskStatus);
|
||||
}
|
||||
} else {
|
||||
System.out.println("请求失败,状态码: " + httpResponse.getStatus());
|
||||
}
|
||||
|
||||
// 如果不是最终状态,等待一段时间再重试
|
||||
if (!isCompleted && attempt < MAX_RETRIES) {
|
||||
TimeUnit.MILLISECONDS.sleep(POLL_INTERVAL);
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
System.out.println("轮询被中断");
|
||||
break;
|
||||
} catch (Exception e) {
|
||||
System.out.println("请求发生异常: " + e.getMessage());
|
||||
// 发生异常时可以选择重试或退出
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!isCompleted) {
|
||||
System.out.println("达到最大重试次数仍未获取最终结果");
|
||||
}
|
||||
return templateId;
|
||||
}
|
||||
|
||||
private static String handleSuccessResponse(JSONObject response) {
|
||||
log.info("任务执行成功!");
|
||||
// 提取成功结果
|
||||
JSONObject output = response.getJSONObject("output");
|
||||
if (output != null) {
|
||||
log.info("任务输出: {}", output.toStringPretty());
|
||||
return output.getStr("template_id");
|
||||
}else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static void handleFailedResponse(JSONObject response) {
|
||||
log.info("任务执行失败!");
|
||||
// 提取失败原因
|
||||
String errorMsg = response.getStr("error_message", "未知错误");
|
||||
log.info("失败原因: {}", errorMsg);
|
||||
}
|
||||
|
||||
|
||||
public PoseTransformationVO getAnimateResult(String taskId){
|
||||
String fullUrl = GET_ASYNC_RESULT + taskId;
|
||||
String respBody = sendRequestUtil.sendAliYunGet(fullUrl);
|
||||
log.info("获取wx pose transform 的结果: {}", respBody);
|
||||
|
||||
String outputStr = JSONUtil.parseObj(respBody).getStr("output");
|
||||
JSONObject output = JSONUtil.parseObj(outputStr);
|
||||
String videoUrl = output.getStr("video_url");
|
||||
String status = output.getStr("task_status");
|
||||
|
||||
PoseTransformationVO poseTransformationVO = new PoseTransformationVO();
|
||||
switch (status) {
|
||||
case STATUS_SUCCESS:
|
||||
poseTransformationVO.setStatus("Success");
|
||||
List<PoseTransformation> poseTransformations = poseTransformationMapper.selectList(new QueryWrapper<PoseTransformation>().eq("unique_id", taskId).orderByDesc("id"));
|
||||
if (!poseTransformations.isEmpty()){
|
||||
PoseTransformation poseTransformation = poseTransformations.get(0);
|
||||
// 生成视频的gif和第一帧图片
|
||||
processVideo(videoUrl, poseTransformation);
|
||||
poseTransformationVO.setId(poseTransformation.getId());
|
||||
if (!StringUtil.isNullOrEmpty(poseTransformation.getGifUrl())){
|
||||
poseTransformationVO.setGifUrl(minioUtil.getPreSignedUrl(poseTransformation.getGifUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
|
||||
}
|
||||
if (!StringUtil.isNullOrEmpty(poseTransformation.getVideoUrl())){
|
||||
poseTransformationVO.setVideoUrl(minioUtil.getPreSignedUrl(poseTransformation.getVideoUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
|
||||
}
|
||||
if (!StringUtil.isNullOrEmpty(poseTransformation.getFirstFrameUrl())){
|
||||
poseTransformationVO.setFirstFrameUrl(minioUtil.getPreSignedUrl(poseTransformation.getFirstFrameUrl(), CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
|
||||
}
|
||||
// 执行积分扣除
|
||||
Long accountId = poseTransformation.getAccountId();
|
||||
Boolean flag = creditsService.taskCreditsDeduction(accountId, taskId);
|
||||
if (flag) creditsService.updateChangedCredits(String.valueOf(accountId), taskId);
|
||||
}
|
||||
break;
|
||||
case STATUS_FAILED:
|
||||
case STATUS_UNKNOWN:
|
||||
poseTransformationVO.setStatus("Failed");
|
||||
break;
|
||||
case STATUS_RUNNING:
|
||||
case STATUS_PENDING:
|
||||
// 任务仍在运行,继续等待
|
||||
poseTransformationVO.setStatus("Executing");
|
||||
break;
|
||||
default:
|
||||
log.info("未知状态: {}", status);
|
||||
poseTransformationVO.setStatus("Failed");
|
||||
}
|
||||
poseTransformationVO.setTaskId(taskId);
|
||||
|
||||
return poseTransformationVO;
|
||||
}
|
||||
|
||||
public void processVideo(String aliyunVideoUrl, PoseTransformation poseTransformation) /*throws Exception*/ {
|
||||
// 1. 从阿里云下载视频到内存
|
||||
byte[] videoBytes = downloadVideoOrImage(aliyunVideoUrl);
|
||||
Long accountId = poseTransformation.getAccountId();
|
||||
String taskId = poseTransformation.getUniqueId();
|
||||
|
||||
// 2. 提取第一帧和生成GIF
|
||||
ByteArrayOutputStream firstFrameOutput = new ByteArrayOutputStream();
|
||||
ByteArrayOutputStream gifOutput = new ByteArrayOutputStream();
|
||||
|
||||
try (FFmpegFrameGrabber grabber = new FFmpegFrameGrabber(new ByteArrayInputStream(videoBytes))) {
|
||||
grabber.start();
|
||||
// 提取第一帧
|
||||
BufferedImage firstFrame = new Java2DFrameConverter().convert(grabber.grabImage());
|
||||
ImageIO.write(firstFrame, "jpg", firstFrameOutput);
|
||||
|
||||
// 生成GIF(取前10秒,50帧)
|
||||
generateGif(grabber, gifOutput, 10, 50);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// 3. 上传所有文件到MinIO
|
||||
String videoPrefix = accountId + "/pose_transform_video/" + taskId + ".mp4";
|
||||
String imgPrefix = accountId + "/pose_transform_first_img/" + taskId + ".jpg";
|
||||
String gifPrefix = accountId + "/pose_transform_gif/" + taskId + ".gif";
|
||||
|
||||
minioUtil.uploadToMinio(videoBytes, userBucket, videoPrefix, "video/mp4");
|
||||
minioUtil.uploadToMinio(firstFrameOutput.toByteArray(), userBucket, imgPrefix, "image/jpeg");
|
||||
minioUtil.uploadToMinio(gifOutput.toByteArray(), userBucket, gifPrefix, "image/gif");
|
||||
// 存储数据到数据库
|
||||
poseTransformation.setGifUrl(userBucket + "/" + gifPrefix);
|
||||
poseTransformation.setVideoUrl(userBucket + "/" + videoPrefix);
|
||||
poseTransformation.setFirstFrameUrl(userBucket + "/" + imgPrefix);
|
||||
poseTransformation.setUpdateTime(LocalDateTime.now());
|
||||
poseTransformationMapper.updateById(poseTransformation);
|
||||
}
|
||||
|
||||
private byte[] downloadVideoOrImage(String url) {
|
||||
try (CloseableHttpClient client = HttpClients.createDefault();
|
||||
InputStream in = client.execute(new HttpGet(url)).getEntity().getContent()) {
|
||||
return IOUtils.toByteArray(in);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public static void generateGif(FFmpegFrameGrabber grabber, OutputStream output,
|
||||
int durationSec, int frameCount) throws Exception {
|
||||
Java2DFrameConverter converter = new Java2DFrameConverter();
|
||||
AnimatedGifEncoder gifEncoder = new AnimatedGifEncoder();
|
||||
|
||||
// 配置GIF参数
|
||||
gifEncoder.start(output);
|
||||
gifEncoder.setDelay(100); // 每帧延迟(毫秒)
|
||||
gifEncoder.setRepeat(0); // 0=无限循环
|
||||
|
||||
int totalFrames = (int) (grabber.getFrameRate() * durationSec);
|
||||
int step = Math.max(1, totalFrames / frameCount);
|
||||
|
||||
// 逐帧处理
|
||||
for (int i = 0; i < totalFrames; i += step) {
|
||||
grabber.setVideoFrameNumber(i);
|
||||
BufferedImage frame = converter.convert(grabber.grabImage());
|
||||
if (frame != null) {
|
||||
gifEncoder.addFrame(frame);
|
||||
}
|
||||
}
|
||||
gifEncoder.finish();
|
||||
}
|
||||
|
||||
/**
|
||||
* Freepik
|
||||
* To Product Image
|
||||
*/
|
||||
public String reimagineFreePik(String path, String prompt, String style) throws IOException {
|
||||
String imageAsBase64 = minioUtil.getImageAsBase64(path);
|
||||
log.info(minioUtil.getPreSignedUrl(path, CommonConstant.MINIO_IMAGE_EXPIRE_TIME));
|
||||
JSONObject requestBody = new JSONObject();
|
||||
requestBody.set("image", imageAsBase64);
|
||||
requestBody.set("prompt", prompt);
|
||||
requestBody.set("imagination", style);
|
||||
|
||||
String resp = sendRequestUtil.sendFreepikPost(requestBody.toString());
|
||||
if (!StringUtil.isNullOrEmpty(resp)){
|
||||
JSONObject jsonResp = JSONUtil.parseObj(resp);
|
||||
JSONObject data = JSONUtil.parseObj(jsonResp.get("data"));
|
||||
String status = data.getStr("status");
|
||||
if (status.equals("COMPLETED")){
|
||||
List<String> generated = data.getBeanList("generated", String.class);
|
||||
return generated.get(0);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* ollama
|
||||
* prompt 助手
|
||||
*/
|
||||
public String getImageDescription(String imagePath) {
|
||||
// 1. 读取图片并编码为 Base64
|
||||
String imageAsBase64 = null;
|
||||
try {
|
||||
imageAsBase64 = minioUtil.getImageAsBase64(imagePath);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
// 2. 构建 JSON 请求体
|
||||
JSONObject message = new JSONObject();
|
||||
message.set("role", "user");
|
||||
message.set("content", "Please describe the clothing in the image and provide a line art description of the outfit. The description should allow for the reconstruction of the corresponding line art based on the details given.");
|
||||
message.set("images", JSONUtil.createArray().set(imageAsBase64));
|
||||
|
||||
JSONObject requestBody = new JSONObject();
|
||||
requestBody.set("model", "llama3.2-vision");
|
||||
requestBody.set("messages", JSONUtil.createArray().set(message));
|
||||
requestBody.set("stream", false);
|
||||
|
||||
// log.info("request body:{}", requestBody);
|
||||
String description = sendRequestUtil.sendPost(ollamaUrl, requestBody.toString());
|
||||
if (StringUtil.isNullOrEmpty(description)){
|
||||
throw new BusinessException("从ollama获取图片描述失败");
|
||||
}
|
||||
log.info("image :{}, description: {}",
|
||||
minioUtil.getPreSignedUrl(imagePath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), description);
|
||||
return description;
|
||||
}
|
||||
|
||||
private String resolveModelType(String taskId){
|
||||
// 判断当前task来自哪个模型
|
||||
// 判断taskId的结构
|
||||
int count = StringUtils.countMatches(taskId, "-");
|
||||
String lastPart = taskId.substring(taskId.lastIndexOf("-") + 1);
|
||||
String type;
|
||||
if (count == 4 && lastPart.length() == 12){
|
||||
// 万象
|
||||
type = "wx";
|
||||
}else {
|
||||
// 本地部署的模型
|
||||
type = "local";
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
public static String extractGender(String text) {
|
||||
// 匹配末尾的 (Male) 或 (Female),忽略大小写
|
||||
Pattern pattern = Pattern.compile("\\(([Mm]ale|[Ff]emale)\\)$");
|
||||
Matcher matcher = pattern.matcher(text);
|
||||
|
||||
if (matcher.find()) {
|
||||
return matcher.group(1); // 返回括号内的内容
|
||||
}
|
||||
return null; // 未匹配到性别
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user