From 595effa04c86c4c656133467676b761dad28b3ba Mon Sep 17 00:00:00 2001 From: xupei Date: Tue, 10 Jun 2025 17:50:04 +0800 Subject: [PATCH] =?UTF-8?q?TASK:=20=E6=8E=A5=E5=85=A5=E7=AC=AC=E4=B8=89?= =?UTF-8?q?=E6=96=B9api=20freepik=20=E6=96=B0=E5=A2=9Esketch=20extract=20v?= =?UTF-8?q?ariation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/da/common/utils/SendRequestUtil.java | 2 +- .../ai/da/controller/GenerateController.java | 3 +- .../com/ai/da/service/GenerateService.java | 2 + .../da/service/impl/GenerateServiceImpl.java | 142 +++++++++++++++--- 4 files changed, 130 insertions(+), 19 deletions(-) diff --git a/src/main/java/com/ai/da/common/utils/SendRequestUtil.java b/src/main/java/com/ai/da/common/utils/SendRequestUtil.java index 8ed6321c..1b170d64 100644 --- a/src/main/java/com/ai/da/common/utils/SendRequestUtil.java +++ b/src/main/java/com/ai/da/common/utils/SendRequestUtil.java @@ -74,7 +74,7 @@ public class SendRequestUtil { try (HttpResponse execute = HttpRequest.post(url) .header("Content-Type", "application/json") // 必须设置 Content-Type .body(requestBodyStr) // Hutool 会自动处理 JSON 序列化 - .timeout(120000) // 设置超时(毫秒) + .timeout(180000) // 设置超时(毫秒) .execute()) { status = execute.getStatus(); diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java index e3791854..4d53f283 100644 --- a/src/main/java/com/ai/da/controller/GenerateController.java +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -87,7 +87,8 @@ public class GenerateController { @ApiOperation(value = "imageToSketch") @PostMapping("/imageToSketch") public Response imageToSketch(@Valid @RequestBody ImageToSketchDTO imageToSketchDTO) { - return Response.success(generateService.imageToSketch(imageToSketchDTO, null, null)); + return Response.success(generateService.imageToSketchAsync(imageToSketchDTO, null, null)); +// return Response.success(generateService.imageToSketch(imageToSketchDTO, null, null)); } // modifySketch diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index 38e53dec..e02be05d 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -44,6 +44,8 @@ public interface GenerateService extends IService { GenerateResultVO imageToSketch(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId); + GenerateResultVO imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId); + GenerateResultVO modifySketch(GenerateModifyDTO generateModifyDTO); String poseTransform(PoseTransformDTO poseTransformDTO); diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 5d46c6cd..174b42a1 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -37,6 +37,8 @@ import io.netty.util.internal.StringUtil; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; +import org.apache.http.HttpHost; +import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; @@ -116,9 +118,6 @@ public class GenerateServiceImpl extends ServiceImpl i @Value("${access.python.generate_sr_port}") private String generateServicePort; - @Value("${ollama.url}") - private String ollamaUrl; - @Value("${ALIYUN_API_KEY}") private String ALIYUN_API_KEY; @@ -868,7 +867,6 @@ public class GenerateServiceImpl extends ServiceImpl i @Transactional(rollbackFor = Exception.class) public GenerateResultVO imageToSketch(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId) { - String bucket = userBucket; Long accountId = UserContext.getUserHolder().getId(); log.info("imageToSketch parameter : {}", imageToSketchDTO); @@ -879,6 +877,24 @@ public class GenerateServiceImpl extends ServiceImpl i throw new BusinessException("remaining.credits.insufficient", ResultEnum.PROMPT.getCode()); } + String style = imageToSketchDTO.getStyle(); + String styleCode = style.equals(SketchStyle.THICK.getValue()) ? "1" : + style.equals(SketchStyle.MEDIUM.getValue()) ? "2" : + style.equals(SketchStyle.THIN.getValue()) ? "3" : "Custom"; + + // 线稿提取 + String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode); + // 存数据库 + Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, accountId, styleCode); + GenerateResultVO generateResultVO = saveExtractSketchResult(generate, sketchPath, imageToSketchDTO.getGender()); + // 积分扣除 + doCreditsSubtract(accountId, event); + + return generateResultVO; + } + + private String requestSketchExtract(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, + Long accountId, String styleCode){ String imagePath; if (StringUtil.isNullOrEmpty(collagePictureUrl)){ CollectionElement collectionElement = collectionElementService.getById(imageToSketchDTO.getElementId()); @@ -890,10 +906,7 @@ public class GenerateServiceImpl extends ServiceImpl i log.info(minioUtil.getPreSignedUrl(imagePath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME)); String imageName = UUID.randomUUID().toString(); String objectName = accountId + "/imageToSketch/" + imageName; - String style = imageToSketchDTO.getStyle(); - String styleCode = style.equals(SketchStyle.THICK.getValue()) ? "1" : - style.equals(SketchStyle.MEDIUM.getValue()) ? "2" : - style.equals(SketchStyle.THIN.getValue()) ? "3" : "Custom"; + String styleImage; if (!Objects.isNull(imageToSketchDTO.getStyleImageId())){ CollectionElement styleElement = collectionElementService.getById(imageToSketchDTO.getElementId()); @@ -901,8 +914,13 @@ public class GenerateServiceImpl extends ServiceImpl i } else { styleImage = ""; } - String sketchPath = pythonService.imageToSketch(imagePath, bucket, objectName, styleCode, styleImage); + String sketchPath = pythonService.imageToSketch(imagePath, userBucket, objectName, styleCode, styleImage); + log.info("初步图片提取结果:{}", sketchPath); + return sketchPath; + } + private Generate saveExtractSketchRequest(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, + Long projectId, Long accountId, String styleCode){ // 存DB Generate generate = new Generate(); generate.setAccountId(accountId); @@ -918,7 +936,10 @@ public class GenerateServiceImpl extends ServiceImpl i generate.setInputImageUrl(collagePictureUrl); generate.setCreateDate(new Date()); baseMapper.insert(generate); + return generate; + } + public GenerateResultVO saveExtractSketchResult(Generate generate, String sketchPath, String gender){ // 将生成结果存入DB GenerateDetail generateDetail = new GenerateDetail(); generateDetail.setGenerateId(generate.getId()); @@ -928,14 +949,64 @@ public class GenerateServiceImpl extends ServiceImpl i generateDetail.setCreateDate(LocalDateTime.now()); generateDetailMapper.insert(generateDetail); - String clothCategory = pythonService.getClothCategory(sketchPath, imageToSketchDTO.getGender()); + String clothCategory = pythonService.getClothCategory(sketchPath, gender); + return new GenerateResultVO(generateDetail.getId(), minioUtil.getPreSignedUrl(sketchPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), "Success", clothCategory); + } + + public void doCreditsSubtract(Long accountId, CreditsEventsEnum event){ BigDecimal existingCredits = accountService.getById(accountId).getCredits(); BigDecimal subtract = existingCredits.subtract(new BigDecimal(event.getValue())); accountService.updateCreditsAndEndTime(accountId, subtract.toString(), null); creditsService.preInsert(accountId, event.getName(), null, Boolean.FALSE, event.getValue()); + } - return new GenerateResultVO(generateDetail.getId(), minioUtil.getPreSignedUrl(sketchPath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), "Success", clothCategory); + // freepik以后会变成异步的吗? 目前同步 + public GenerateResultVO imageToSketchAsync(ImageToSketchDTO imageToSketchDTO, String collagePictureUrl, Long projectId){ + Long accountId = UserContext.getUserHolder().getId(); + log.info("imageToSketch parameter : {}", imageToSketchDTO); + + // 检查积分是否够本次扣除 + CreditsEventsEnum event = CreditsEventsEnum.IMAGE_TO_SKETCH; + Boolean b = creditsService.checkCredits(accountId, event, 1); + if (!b){ + throw new BusinessException("remaining.credits.insufficient", ResultEnum.PROMPT.getCode()); + } + + String style = imageToSketchDTO.getStyle(); + String styleCode = style.equals(SketchStyle.THICK.getValue()) ? "1" : + style.equals(SketchStyle.MEDIUM.getValue()) ? "2" : + style.equals(SketchStyle.THIN.getValue()) ? "3" : "Custom"; + + String sketchPath = requestSketchExtract(imageToSketchDTO, collagePictureUrl, accountId, styleCode); + + String imageDescription = getImageDescription(sketchPath); + + try { + // 请求freepik reimage + String dataStr = reimagineFreePik(sketchPath, imageDescription, "vivid"); + if (StringUtil.isNullOrEmpty(dataStr)){ + throw new BusinessException("extract sketch failed"); + } + JSONObject data = JSONUtil.parseObj(dataStr); + String upgradeImageUrl = data.getBeanList("generated", String.class).get(0); + String taskId = data.getStr("task_id"); + + // 下载图片 +// byte[] bytes = downloadWithProxy(upgradeImageUrl); + byte[] bytes = downloadVideoOrImage(upgradeImageUrl); + // 2、上传图片到minio保存 + String objectName = accountId + "/imageToSketch/" + taskId + ".png"; + minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png"); + // 存数据库 + Generate generate = saveExtractSketchRequest(imageToSketchDTO, collagePictureUrl, projectId, accountId, styleCode); + GenerateResultVO generateResultVO = saveExtractSketchResult(generate, userBucket + "/" + objectName, imageToSketchDTO.getGender()); + // 积分扣除 + doCreditsSubtract(accountId, event); + return generateResultVO; + } catch (IOException e) { + throw new RuntimeException(e); + } } // 对提取出来的sketch做调整 @@ -1760,6 +1831,36 @@ public class GenerateServiceImpl extends ServiceImpl i } } + public byte[] downloadWithProxy(String url) throws IOException { + // 获取系统代理设置(适用于大多数VPN) +// String proxyHost = System.getProperty("http.proxyHost"); + String proxyHost = "localhost"; +// String proxyPort = System.getProperty("http.proxyPort"); + String proxyPort = "7890"; + + CloseableHttpClient client; + if (proxyHost != null && proxyPort != null) { + // 配置代理 + HttpHost proxy = new HttpHost(proxyHost, Integer.parseInt(proxyPort)); + RequestConfig config = RequestConfig.custom().setProxy(proxy).build(); + client = HttpClients.custom().setDefaultRequestConfig(config).build(); + } else { + client = HttpClients.createDefault(); + } + + try { + return client.execute(new HttpGet(url), response -> { + if (response.getStatusLine().getStatusCode() == 200) { + return IOUtils.toByteArray(response.getEntity().getContent()); + } else { + throw new IOException("HTTP Error: " + response.getStatusLine()); + } + }); + } finally { + client.close(); + } + } + public static void generateGif(FFmpegFrameGrabber grabber, OutputStream output, int durationSec, int frameCount) throws Exception { Java2DFrameConverter converter = new Java2DFrameConverter(); @@ -1802,19 +1903,21 @@ public class GenerateServiceImpl extends ServiceImpl i JSONObject data = JSONUtil.parseObj(jsonResp.get("data")); String status = data.getStr("status"); if (status.equals("COMPLETED")){ - List generated = data.getBeanList("generated", String.class); - return generated.get(0); +// List generated = data.getBeanList("generated", String.class); + log.info("freepik 调用结果:{}", jsonResp); + return jsonResp.getStr("data"); } } return null; } /** + * imagePath 图片的minio地址 * ollama * prompt 助手 */ public String getImageDescription(String imagePath) { - // 1. 读取图片并编码为 Base64 +/* // 1. 读取图片并编码为 Base64 String imageAsBase64 = null; try { imageAsBase64 = minioUtil.getImageAsBase64(imagePath); @@ -1831,14 +1934,19 @@ public class GenerateServiceImpl extends ServiceImpl i JSONObject requestBody = new JSONObject(); requestBody.set("model", "llama3.2-vision"); requestBody.set("messages", JSONUtil.createArray().set(message)); - requestBody.set("stream", false); + requestBody.set("stream", false);*/ // log.info("request body:{}", requestBody); - String description = sendRequestUtil.sendPost(ollamaUrl, requestBody.toString()); +// String resp = sendRequestUtil.sendPost("http://18.167.251.121:9994/api/img2prompt", imagePath); + JSONObject requestBody = new JSONObject(); + requestBody.set("img", imagePath); + String description = sendRequestUtil.sendPost("http://localhost:8000/api/img2prompt", requestBody.toString()); if (StringUtil.isNullOrEmpty(description)){ throw new BusinessException("从ollama获取图片描述失败"); } - log.info("image :{}, description: {}", + /*Object msg = JSONUtil.parseObj(resp).get("message"); + String description = JSONUtil.parseObj(msg).getStr("content");*/ + log.info("image :{} \n, description: {}", minioUtil.getPreSignedUrl(imagePath, CommonConstant.MINIO_IMAGE_EXPIRE_TIME), description); return description; }