diff --git a/src/main/java/com/ai/da/mapper/primary/entity/ToProductImageResult.java b/src/main/java/com/ai/da/mapper/primary/entity/ToProductImageResult.java index 796c677a..b3556e7c 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/ToProductImageResult.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/ToProductImageResult.java @@ -62,7 +62,7 @@ public class ToProductImageResult implements Serializable { private String modelName; - private String taskStatus; + private String status; @ApiModelProperty(value = "是否删除1:是 0:否") @TableField diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 4b0740bf..09305555 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -41,6 +41,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.apache.http.HttpHost; import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; @@ -48,6 +49,7 @@ import org.bytedeco.javacv.FFmpegFrameGrabber; import org.bytedeco.javacv.Java2DFrameConverter; import org.springframework.beans.factory.annotation.Value; import org.springframework.dao.DuplicateKeyException; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import org.springframework.util.StringUtils; @@ -331,7 +333,7 @@ public class GenerateServiceImpl extends ServiceImpl i } ToProductImageResult toProductImageResult = toProductImageResults.get(0); toProductImageResult.setUrl(url); - toProductImageResult.setTaskStatus("Success"); + toProductImageResult.setStatus("Success"); // toProductImageResult.setResultType("ToProductImage"); toProductImageResultMapper.updateById(toProductImageResult); @@ -359,8 +361,8 @@ public class GenerateServiceImpl extends ServiceImpl i List toProductImageResults = toProductImageResultMapper.selectList(qw); if (!CollectionUtils.isEmpty(toProductImageResults)) { ToProductImageResult toProductImageResult = toProductImageResults.get(0); - if (StringUtil.isNullOrEmpty(toProductImageResult.getTaskStatus()) || !toProductImageResult.getTaskStatus().equals(status)){ - toProductImageResult.setTaskStatus(status); + if (StringUtil.isNullOrEmpty(toProductImageResult.getStatus()) || !toProductImageResult.getStatus().equals(status)){ + toProductImageResult.setStatus(status); toProductImageResultMapper.updateById(toProductImageResult); } } @@ -1064,7 +1066,7 @@ public class GenerateServiceImpl extends ServiceImpl i pythonService.bright(url, toProductImageResult.getBrightenValue()); } toProductImageResult.setUrl(url); - toProductImageResult.setTaskStatus("Success"); + toProductImageResult.setStatus("Success"); // toProductImageResult.setResultType("Relight"); toProductImageResultMapper.updateById(toProductImageResult); @@ -1462,16 +1464,19 @@ public class GenerateServiceImpl extends ServiceImpl i // 3、生成唯一id 使用uuid,由于uuid重复的几率很小,故取消对uuid重复性的校验 String taskId; - Boolean flag = false; + Boolean isRequestSuccess = false; PoseTransformation poseTransformation = new PoseTransformation(); if (!StringUtil.isNullOrEmpty(poseTransformDTO.getModelName()) && poseTransformDTO.getModelName().equals("wx")) { taskId = animateAnyone(poseTransformDTO, accountId); - if (!StringUtil.isNullOrEmpty(taskId)) flag = true; + if (!StringUtil.isNullOrEmpty(taskId)){ + isRequestSuccess = true; + addAPIGenerateRecordAsync(taskId, Module.poseTransfer.getValue(), "wx", "Pending"); + } poseTransformation.setModelName("wx"); } else { String uuid = UUID.randomUUID().toString(); taskId = uuid + "-" + accountId; - flag = pythonService.poseTransformation(productImage, poseId, taskId); + isRequestSuccess = pythonService.poseTransformation(productImage, poseId, taskId); } poseTransformation.setProjectId(projectId); @@ -1480,7 +1485,7 @@ public class GenerateServiceImpl extends ServiceImpl i poseTransformation.setProductImage(productImage); poseTransformation.setPoseId(poseId); poseTransformation.setIsLiked((byte) 0); - String taskStatus = flag ? "Executing" : "Fail"; + String taskStatus = isRequestSuccess ? "Executing" : "Fail"; poseTransformation.setTaskStatus(taskStatus); poseTransformation.setCreateTime(LocalDateTime.now()); poseTransformationMapper.insert(poseTransformation); @@ -1489,7 +1494,7 @@ public class GenerateServiceImpl extends ServiceImpl i toProductImageResultVO.setParentId(poseTransformDTO.getParentId()); toProductImageResultVO.setResultType(Module.poseTransfer.getValue()); toProductImageResultVO.setTaskId(taskId); - toProductImageResultVO.setTaskStatus(taskStatus); + toProductImageResultVO.setStatus(taskStatus); toProductImageResultVO.setSourceUrl(minioUtil.getPreSignedUrl(productImage, CommonConstant.MINIO_IMAGE_EXPIRE_TIME)); toProductImageResultVO.setPoseId(poseId); toProductImageResultVO.setModelName(poseTransformDTO.getModelName()); @@ -1508,7 +1513,7 @@ public class GenerateServiceImpl extends ServiceImpl i } - if (flag) { + if (isRequestSuccess) { // 6、添加预扣除积分到redis creditsService.addRecordToCreditsDeduction(accountId, taskId, creditsEventsEnum); // 6.1 添加积分扣除记录到db @@ -2279,13 +2284,7 @@ public class GenerateServiceImpl extends ServiceImpl i if (status.equals(STATUS_FAILED) || status.equals(STATUS_UNKNOWN)) { return null; } - String taskId = output.getStr("task_id"); - - /*PoseTransformation poseTransformation = new PoseTransformation(poseTransformDTO.getProjectId(), - accountId, taskId, inputImage, poseTransformDTO.getPoseId()); - poseTransformation.setCreateTime(LocalDateTime.now()); - poseTransformationMapper.insert(poseTransformation);*/ - return taskId; + return output.getStr("task_id"); } public void checkImage(String inputImageUrl) { @@ -2456,6 +2455,8 @@ public class GenerateServiceImpl extends ServiceImpl i String videoUrl = output.getStr("video_url"); String status = output.getStr("task_status"); + updateTaskStatusAsync(taskId, status); + PoseTransformationVO poseTransformationVO = new PoseTransformationVO(); switch (status) { case STATUS_SUCCESS: @@ -2553,6 +2554,37 @@ public class GenerateServiceImpl extends ServiceImpl i } } + // 增强版下载方法 todo 最好不要报错 + private byte[] downloadVideoOrImageWithValidation(String url) throws IOException { + CloseableHttpClient client = HttpClients.createDefault(); + HttpGet request = new HttpGet(url); + + try (CloseableHttpResponse response = client.execute(request)) { + // 状态码检查 + if (response.getStatusLine().getStatusCode() != 200) { + throw new IOException("Invalid status: " + response.getStatusLine()); + } + + // 内容类型检查 + org.apache.http.Header contentTypeHeader = response.getFirstHeader("Content-Type"); + if (contentTypeHeader == null || !contentTypeHeader.getValue().startsWith("image/")) { + throw new IOException("Invalid content type: " + + (contentTypeHeader != null ? contentTypeHeader.getValue() : "null")); + } + + // 内容长度检查 + org.apache.http.Header contentLengthHeader = response.getFirstHeader("Content-Length"); + if (contentLengthHeader != null) { + long length = Long.parseLong(contentLengthHeader.getValue()); + if (length <= 0) { + throw new IOException("Empty content"); + } + } + + return IOUtils.toByteArray(response.getEntity().getContent()); + } + } + public byte[] downloadWithProxy(String url) throws IOException { // 获取系统代理设置(适用于大多数VPN) // String proxyHost = System.getProperty("http.proxyHost"); @@ -2770,7 +2802,6 @@ public class GenerateServiceImpl extends ServiceImpl i } String resp = sendRequestUtil.sendFluxPost(fluxRequestUrl, requestBody.toString()); -// JSONObject respObj = JSONUtil.parseObj(null); JSONObject respObj = JSONUtil.parseObj(resp); log.info("flux 发起生成请求返回结果: {}", respObj); String taskId = respObj.getStr("id"); @@ -2783,47 +2814,98 @@ public class GenerateServiceImpl extends ServiceImpl i String pollingUrl = respObj.getStr("polling_url"); String key = RedisUtil.FLUX_POLLING_URL + taskId; redisUtil.addToString(key, pollingUrl, CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + // 添加到api_generate表中,以便之后对结果查询做补偿 + addAPIGenerateRecordAsync(taskId, func.getName(), "flux", "Pending"); return taskId; } + @Override public String getFluxResult(String taskId, String objectName) { + // 获取轮询URL String pollingUrl = redisUtil.getFromString(RedisUtil.FLUX_POLLING_URL + taskId); - String fluxResultRequestUrl; + + // 准备请求参数 + String fluxResultRequestUrl = StringUtil.isNullOrEmpty(pollingUrl) + ? "https://api.bfl.ai/v1/get_result" + : pollingUrl; + HashMap params = new HashMap<>(); if (StringUtil.isNullOrEmpty(pollingUrl)) { - fluxResultRequestUrl = "https://api.bfl.ai/v1/get_result"; params.put("id", taskId); - } else { - fluxResultRequestUrl = pollingUrl; } + // 发送请求并解析响应 String resp = sendRequestUtil.sendGet(fluxResultRequestUrl, params); log.info("获取flux生成的结果为:{}", resp); + JSONObject respObj = JSONUtil.parseObj(resp); String status = respObj.getStr("status"); + + // 异步更新状态 + updateTaskStatusAsync(taskId, status); + + // 处理不同状态 switch (status) { case "Task not found": + // 审核没过 + case "Request Moderated": + // 审核没过 + case "Content Moderated": + // 出错 + case "Error": return "Fail"; case "Pending": return "Pending"; - case "Request Moderated": - case "Content Moderated": - // 审核没过 - return "Fail"; case "Ready": // 已完成 获取结果 - String fluxResult = JSONUtil.parseObj(respObj.getStr("result")).getStr("sample"); - byte[] bytes = downloadVideoOrImage(fluxResult); - minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png"); - -// return minioUtil.getPreSignedUrl(userBucket + "/" + objectName, CommonConstant.MINIO_IMAGE_EXPIRE_TIME); - return userBucket + "/" + objectName; - case "Error": - // 出错 - return "Fail"; + return handleReadyStatus(respObj, objectName); + default: + return null; + } + } + + private String handleReadyStatus(JSONObject respObj, String objectName) { + // 1. 首先检查MinIO中是否已存在该图片 + if (minioUtil.doesObjectExist(userBucket, objectName)) { + return userBucket + "/" + objectName; + } + + // 2. 解析响应获取结果URL和生成时间 + JSONObject resultObj = JSONUtil.parseObj(respObj.getStr("result")); + String fluxResult = resultObj.getStr("sample"); + double endTime = resultObj.getDouble("end_time"); // 获取任务结束时间戳 + + // 3. 检查图片链接是否已过期(超过10分钟) + long currentTime = System.currentTimeMillis() / 1000; // 当前Unix时间戳(秒) + long generateTime = (long) endTime; // 生成结束时间戳 + // 图片10分钟过期,保险起见,保留一分钟 + long tenMinutesInSeconds = 9 * 60; + + if (currentTime - generateTime > tenMinutesInSeconds) { + log.warn("Flux result image has expired, generateTime: {}, currentTime: {}", + generateTime, currentTime); + return null; + } + + // 4. 图片未过期,下载并上传到MinIO + try { + byte[] bytes = downloadVideoOrImage(fluxResult); + minioUtil.uploadToMinio(bytes, userBucket, objectName, "image/png"); + return userBucket + "/" + objectName; + } catch (Exception e) { + log.error("Failed to download or upload Flux result image", e); + return null; + } + } + + @Async + public void updateTaskStatusAsync(String taskId, String status) { + try { + updateAPIGenerateStatusAsync(taskId, status); + } catch (Exception e) { + log.error("更新任务状态失败, taskId: {}, status: {}", taskId, status, e); } - return null; } private GenerateResultVO getFluxResultAndSave(String taskId) { @@ -2849,7 +2931,7 @@ public class GenerateServiceImpl extends ServiceImpl i Boolean flag = creditsService.taskCreditsDeduction(accountId, taskId); if (flag) creditsService.updateChangedCredits(String.valueOf(accountId), taskId); } else if (StringUtil.isNullOrEmpty(generateDetail.getUrl())) { - // 一般来说这条线应该走不到 + // 结果已经存入db,一般走不到这条线 generateDetail.setGenerateId(generate.getId()); generateDetail.setUrl(fluxResult); generateDetail.setMd5(MD5Utils.encryptFile( @@ -2884,4 +2966,44 @@ public class GenerateServiceImpl extends ServiceImpl i return null; } } + + @Async + @Transactional + public void addAPIGenerateRecordAsync(String taskId, String function, String modelName, String status){ + try { + log.info("异步执行添加"); + if (!StringUtil.isNullOrEmpty(taskId) && !StringUtil.isNullOrEmpty(modelName)){ + APIGenerate apiGenerate = new APIGenerate(); + apiGenerate.setTaskId(taskId); + apiGenerate.setFunc(function); + apiGenerate.setModelName(modelName); + apiGenerate.setStatus(status); + apiGenerate.setRetry_count(0); + apiGenerate.setCreateTime(LocalDateTime.now()); + apiGenerateMapper.insert(apiGenerate); + } + } catch (Exception e){ + log.error(e.getMessage()); + } + } + + @Async + @Transactional + public void updateAPIGenerateStatusAsync(String taskId, String status){ + log.info("异步执行修改"); + QueryWrapper qw = new QueryWrapper<>(); + qw.lambda().eq(APIGenerate::getTaskId, taskId); + APIGenerate apiGenerate = apiGenerateMapper.selectOne(qw); + if (Objects.nonNull(apiGenerate)){ + if (apiGenerate.getStatus().equals("Ready") || apiGenerate.getStatus().equals("SUCCEEDED")) { + log.warn("当前任务 {} 状态已达Success, 不做修改", taskId); + } else { + apiGenerate.setStatus(status); + apiGenerate.setUpdateTime(LocalDateTime.now()); + apiGenerateMapper.updateById(apiGenerate); + } + } else { + log.error("任务 {} 在api_generate表中找不到", taskId); + } + } } diff --git a/src/main/java/com/ai/da/service/impl/UserLikeGroupServiceImpl.java b/src/main/java/com/ai/da/service/impl/UserLikeGroupServiceImpl.java index b50133fc..54378bd2 100644 --- a/src/main/java/com/ai/da/service/impl/UserLikeGroupServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/UserLikeGroupServiceImpl.java @@ -490,7 +490,7 @@ public class UserLikeGroupServiceImpl extends ServiceImpl 不用修改,直接处理回参 + if (!toProductImageResult.getStatus().equals("Success") + && !StringUtil.isNullOrEmpty(toProductImageResult.getUrl())){ + toProductImageResult.setStatus("Success"); + toProductImageResult.setUrl(fluxImgMinioPath); + toProductImageResultMapper.updateById(toProductImageResult); + } MagicToolResultVO magicToolResultVO = CopyUtil.copyObject(toProductImageResult, MagicToolResultVO.class); magicToolResultVO.setTaskId(taskId); @@ -1127,7 +1137,7 @@ public class UserLikeGroupServiceImpl extends ServiceImpl