Merge branch 'refs/heads/dev/3.1_release_merge' into dev/dev_xp
This commit is contained in:
5
pom.xml
5
pom.xml
@@ -427,6 +427,11 @@
|
||||
<artifactId>bcpkix-jdk18on</artifactId>
|
||||
<version>1.78.1</version>
|
||||
</dependency>
|
||||
<!-- AOP -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-aop</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
|
||||
@@ -28,6 +28,11 @@ public class MQPublisher {
|
||||
amqpTemplate.convertAndSend(rabbitMQProperties.getQueues().getSr(), mm);
|
||||
}
|
||||
|
||||
public void sendGenerateResultMessage(String mm) {
|
||||
log.info("send generate result message: {}", mm);
|
||||
amqpTemplate.convertAndSend(rabbitMQProperties.getQueues().getGenerateResult(), mm);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param mailParams 含有的字段
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
package com.ai.da.common.aspect;
|
||||
|
||||
import com.ai.da.common.context.UserContext;
|
||||
import com.ai.da.model.vo.AuthPrincipalVo;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import org.aspectj.lang.JoinPoint;
|
||||
import org.aspectj.lang.annotation.*;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.context.request.RequestContextHolder;
|
||||
import org.springframework.web.context.request.ServletRequestAttributes;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Controller日志切面
|
||||
* 记录所有Controller接口的请求参数和用户信息
|
||||
*/
|
||||
@Aspect
|
||||
@Component
|
||||
public class ControllerLoggingAspect {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(ControllerLoggingAspect.class);
|
||||
|
||||
/**
|
||||
* 定义切点:所有Controller方法
|
||||
*/
|
||||
@Pointcut("execution(* com.ai.da.controller..*(..))")
|
||||
public void controllerMethods() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Controller方法执行前记录日志
|
||||
*/
|
||||
// @Before("controllerMethods()")
|
||||
public void logControllerBefore(JoinPoint joinPoint) {
|
||||
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
if (attributes != null) {
|
||||
HttpServletRequest request = attributes.getRequest();
|
||||
|
||||
// 获取当前用户ID
|
||||
Long userId = null;
|
||||
AuthPrincipalVo authPrincipalVo = UserContext.getUserHolder();
|
||||
if (authPrincipalVo != null) {
|
||||
userId = authPrincipalVo.getId();
|
||||
}
|
||||
|
||||
// 获取请求参数
|
||||
Map<String, Object> params = getRequestParams(joinPoint, request);
|
||||
|
||||
logger.info("=== 请求开始 ===");
|
||||
logger.info("用户ID: {}", userId);
|
||||
logger.info("请求URL: {}", request.getRequestURL().toString());
|
||||
logger.info("请求方法: {}", request.getMethod());
|
||||
logger.info("请求IP: {}", getClientIpAddress(request));
|
||||
logger.info("调用方法: {}.{}", joinPoint.getSignature().getDeclaringType().getSimpleName(), joinPoint.getSignature().getName());
|
||||
logger.info("请求参数: {}", params);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取请求参数
|
||||
*/
|
||||
private Map<String, Object> getRequestParams(JoinPoint joinPoint, HttpServletRequest request) {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
|
||||
// 1. 获取Query String参数
|
||||
String queryString = request.getQueryString();
|
||||
if (queryString != null && !queryString.isEmpty()) {
|
||||
params.put("queryString", queryString);
|
||||
}
|
||||
|
||||
// 2. 获取方法参数(包含 @PathVariable, @RequestParam, @RequestBody 等)
|
||||
Object[] args = joinPoint.getArgs();
|
||||
|
||||
if (args != null && args.length > 0) {
|
||||
Map<String, Object> methodParams = new HashMap<>();
|
||||
for (int i = 0; i < args.length; i++) {
|
||||
Object arg = args[i];
|
||||
// 过滤掉不可序列化的参数
|
||||
if (arg != null) {
|
||||
if (isIgnorable(arg)) {
|
||||
// 对于可忽略的类型,记录类型名
|
||||
methodParams.put("arg" + i, "[" + arg.getClass().getSimpleName() + "]");
|
||||
} else {
|
||||
try {
|
||||
methodParams.put("arg" + i, arg);
|
||||
} catch (Exception e) {
|
||||
methodParams.put("arg" + i, arg.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!methodParams.isEmpty()) {
|
||||
params.put("methodParams", methodParams);
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断是否需要过滤的参数类型
|
||||
*/
|
||||
private boolean isIgnorable(Object obj) {
|
||||
return obj instanceof HttpServletRequest
|
||||
|| obj instanceof HttpServletResponse
|
||||
|| obj instanceof MultipartFile
|
||||
|| obj instanceof MultipartFile[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Controller方法抛出异常时记录日志
|
||||
*/
|
||||
@AfterThrowing(pointcut = "controllerMethods()", throwing = "exception")
|
||||
public void logControllerAfterThrowing(JoinPoint joinPoint, Throwable exception) {
|
||||
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
|
||||
|
||||
Long userId = null;
|
||||
AuthPrincipalVo authPrincipalVo = UserContext.getUserHolder();
|
||||
if (authPrincipalVo != null) {
|
||||
userId = authPrincipalVo.getId();
|
||||
}
|
||||
|
||||
// 获取请求参数
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
if (attributes != null) {
|
||||
HttpServletRequest request = attributes.getRequest();
|
||||
params = getRequestParams(joinPoint, request);
|
||||
}
|
||||
|
||||
logger.error("=== 请求异常 ===");
|
||||
logger.error("用户ID: {}", userId);
|
||||
logger.error("调用方法: {}.{}", joinPoint.getSignature().getDeclaringType().getSimpleName(), joinPoint.getSignature().getName());
|
||||
logger.error("请求参数: {}", params);
|
||||
logger.error("异常信息: ", exception);
|
||||
logger.error("=== 异常结束 ===");
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取客户端真实IP地址
|
||||
*/
|
||||
private String getClientIpAddress(HttpServletRequest request) {
|
||||
String xForwardedFor = request.getHeader("X-Forwarded-For");
|
||||
if (xForwardedFor != null && !xForwardedFor.isEmpty() && !"unknown".equalsIgnoreCase(xForwardedFor)) {
|
||||
return xForwardedFor.split(",")[0];
|
||||
}
|
||||
|
||||
String xRealIp = request.getHeader("X-Real-IP");
|
||||
if (xRealIp != null && !xRealIp.isEmpty() && !"unknown".equalsIgnoreCase(xRealIp)) {
|
||||
return xRealIp;
|
||||
}
|
||||
|
||||
String proxyClientIp = request.getHeader("Proxy-Client-IP");
|
||||
if (proxyClientIp != null && !proxyClientIp.isEmpty() && !"unknown".equalsIgnoreCase(proxyClientIp)) {
|
||||
return proxyClientIp;
|
||||
}
|
||||
|
||||
String wlProxyClientIp = request.getHeader("WL-Proxy-Client-IP");
|
||||
if (wlProxyClientIp != null && !wlProxyClientIp.isEmpty() && !"unknown".equalsIgnoreCase(wlProxyClientIp)) {
|
||||
return wlProxyClientIp;
|
||||
}
|
||||
|
||||
return request.getRemoteAddr();
|
||||
}
|
||||
}
|
||||
@@ -202,7 +202,7 @@ public class MyTaskScheduler {
|
||||
}
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 0 9 * * ?")
|
||||
// @Scheduled(cron = "0 0 9 * * ?")
|
||||
public void sendTrialOrderExcelToManagements() {
|
||||
// 获取前一天日期
|
||||
LocalDate yesterday = LocalDate.now().minusDays(1);
|
||||
|
||||
@@ -23,6 +23,7 @@ public class CommonConstant {
|
||||
}
|
||||
|
||||
public static final String GENERATE_PATH = "/api/generate_image";
|
||||
public static final String GENERATE_PATH_FLUX2_KLEIN = "/api/generate_image_flux2_klein";
|
||||
|
||||
public static final String GENERATE_SINGLE_LOGO = "/api/generate_single_logo";
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ public class ModelConstants {
|
||||
public static final String PRINTBOARD_ADVANCED_T2I = "qwen-image";
|
||||
public static final String MOODBOARD_ADVANCED = "doubao-seedream-3-0-t2i-250415";
|
||||
public static final String PRINTBOARD_HIGH_T2I = "doubao-seedream-3-0-t2i-250415";
|
||||
public static final String PRINTBOARD_HIGH_I2I = "doubao-seededit-3-0-i2i-250628";
|
||||
public static final String PRINTBOARD_HIGH_I2I = "doubao-seedream-4-0-250828-fast";
|
||||
public static final String PRINTBOARD_ADVANCED_I2I = "doubao-seedream-4-0-250828";
|
||||
public static final String IMAGEN_MODEL = "imagen-4.0-generate-001";
|
||||
public static final String NANO_BANANA = "gemini-2.5-flash-image";
|
||||
|
||||
@@ -34,7 +34,7 @@ public class AccountTask {
|
||||
accountService.refreshCreditsMonthly();
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes
|
||||
// @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes
|
||||
public void getPaidUser() {
|
||||
// 获取code-create 表中 指定日期之后 订单状态为wc-processing的订单
|
||||
accountService.extendValidityForCC();
|
||||
|
||||
@@ -45,7 +45,7 @@ public class PaymentTask {
|
||||
@Resource
|
||||
private PayPalCheckoutService payPalCheckoutService;
|
||||
|
||||
// @Scheduled(cron = "0/30 * * * * ?")
|
||||
// @Scheduled(cron = "0/30 * * * * ?")
|
||||
public void orderConfirmForPaypal() throws SerializeException {
|
||||
|
||||
// log.info("PayPal orderConfirm 被执行......");
|
||||
@@ -97,7 +97,7 @@ public class PaymentTask {
|
||||
//
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes
|
||||
// @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes
|
||||
public void updateAffiliateInfoWithPayment(){
|
||||
// log.info("佣金计算定时器");
|
||||
affiliateService.updateAffiliateInfoWithPayment();
|
||||
@@ -109,7 +109,7 @@ public class PaymentTask {
|
||||
affiliateService.syncLinkViewCountToDB();
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 0 8 28-31 * ?")
|
||||
// @Scheduled(cron = "0 0 8 28-31 * ?")
|
||||
public void commissionSummaryReminder(){
|
||||
// 每个月末的最后一天的早上八点执行
|
||||
LocalDate today = LocalDate.now();
|
||||
|
||||
@@ -40,7 +40,7 @@ public class SubscriptionReminderTask {
|
||||
REMINDER_DAYS_CONFIG.put("year", 14);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 0 9 * * ?")
|
||||
// @Scheduled(cron = "0 0 9 * * ?")
|
||||
public void subscriptionReminder() {
|
||||
// 获取所有需要通知的订阅
|
||||
List<SubscriptionInfo> subscriptionInfos = getDueSubscriptions();
|
||||
@@ -97,7 +97,7 @@ public class SubscriptionReminderTask {
|
||||
return subscriptionInfoMapper.selectList(qw);
|
||||
}
|
||||
|
||||
// @Scheduled(cron = "0 0 9 * * ?")
|
||||
// @Scheduled(cron = "0 0 9 * * ?")
|
||||
public void trialReminder() {
|
||||
// 今天的 00:00:00 和 23:59:59
|
||||
LocalDateTime startOfDay = LocalDateTime.now().toLocalDate().atStartOfDay();
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
55
src/main/java/com/ai/da/model/dto/ImageProcessRequest.java
Normal file
55
src/main/java/com/ai/da/model/dto/ImageProcessRequest.java
Normal file
@@ -0,0 +1,55 @@
|
||||
package com.ai.da.model.dto;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 图片处理请求体
|
||||
*/
|
||||
@Data
|
||||
@Builder
|
||||
public class ImageProcessRequest {
|
||||
|
||||
/**
|
||||
* OSS桶名(bucket_name)
|
||||
*/
|
||||
private String bucket_name;
|
||||
|
||||
/**
|
||||
* OSS对象名(object_name)
|
||||
*/
|
||||
private String object_name;
|
||||
|
||||
/**
|
||||
* 输入图片路径列表(input_image_paths)
|
||||
*/
|
||||
private List<String> input_image_paths;
|
||||
|
||||
/**
|
||||
* 图像宽度(width)
|
||||
*/
|
||||
private Integer width;
|
||||
|
||||
/**
|
||||
* 图像高度(height)
|
||||
*/
|
||||
private Integer height;
|
||||
|
||||
/**
|
||||
* 文本提示(prompt)
|
||||
*/
|
||||
private String prompt;
|
||||
|
||||
/**
|
||||
* 推理步数(steps)
|
||||
*/
|
||||
private Integer steps;
|
||||
|
||||
/**
|
||||
* 引导系数(guidance)
|
||||
*/
|
||||
private Double guidance;
|
||||
|
||||
}
|
||||
@@ -2,8 +2,10 @@ package com.ai.da.python;
|
||||
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import cn.hutool.core.exceptions.ExceptionUtil;
|
||||
import com.ai.da.common.RabbitMQ.RabbitMQProperties;
|
||||
import com.ai.da.common.config.FileProperties;
|
||||
import com.ai.da.common.config.exception.BusinessException;
|
||||
import com.ai.da.common.constant.CommonConstant;
|
||||
import com.ai.da.common.context.UserContext;
|
||||
import com.ai.da.common.enums.*;
|
||||
import com.ai.da.common.utils.*;
|
||||
@@ -20,6 +22,7 @@ import com.ai.da.model.vo.*;
|
||||
import com.ai.da.python.vo.*;
|
||||
import com.ai.da.service.DesignHistoryService;
|
||||
import com.ai.da.service.PythonTAllInfoService;
|
||||
import com.ai.da.service.RabbitMQService;
|
||||
import com.ai.da.service.SysFileService;
|
||||
import com.alibaba.fastjson.*;
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
@@ -69,6 +72,8 @@ public class PythonService {
|
||||
private String accessPythonPort;
|
||||
@Value("${minio.bucketName.gradient}")
|
||||
private String gradientBucketName;
|
||||
@Value("${minio.bucketName.users}")
|
||||
private String userBucketName;
|
||||
@Value("${access.python.generate_sr_port}")
|
||||
private String srServicePort;
|
||||
|
||||
@@ -84,6 +89,12 @@ public class PythonService {
|
||||
@Resource
|
||||
private RedisUtil redisUtil;
|
||||
|
||||
@Resource
|
||||
private RabbitMQService rabbitMQService;
|
||||
|
||||
@Resource
|
||||
private RabbitMQProperties rabbitMQProperties;
|
||||
|
||||
/**
|
||||
* 生成打印的图片 二合一 (废弃于2024/01/02)
|
||||
*
|
||||
@@ -3334,7 +3345,7 @@ public class PythonService {
|
||||
throw new BusinessException("system error!");
|
||||
}
|
||||
|
||||
public Boolean generateSketchOrPrint(String params, String port, String servicePath) {
|
||||
public Boolean generateSketchOrPrint(String params, String port, String servicePath, String taskId) {
|
||||
//限流校验
|
||||
// AccessLimitUtils.validate("generateSketchOrPrint", 5);
|
||||
OkHttpClient client = new OkHttpClient().newBuilder()
|
||||
@@ -3396,12 +3407,36 @@ public class PythonService {
|
||||
if (result && jsonObject.get("code").equals(200)) {
|
||||
log.info("Generate##responseObject###{}", jsonObject);
|
||||
// return setGenerateImageList(jsonObject.getJSONObject("data"));
|
||||
if (servicePath == CommonConstant.GENERATE_PATH_FLUX2_KLEIN) {
|
||||
//放入结果到mq
|
||||
JSONObject data = jsonObject.getJSONObject("data");
|
||||
String outputPath = data.getString("output_path");
|
||||
|
||||
|
||||
Map<String, String> mqMessage = new HashMap<>();
|
||||
mqMessage.put("tasks_id", taskId);
|
||||
mqMessage.put("status", "SUCCESS");
|
||||
mqMessage.put("message", "success");
|
||||
mqMessage.put("image_url", outputPath);
|
||||
mqMessage.put("category", "");
|
||||
String mqMessageBody = JSON.toJSONString(mqMessage);
|
||||
rabbitMQService.publishMessageToGenerateResult(mqMessageBody);
|
||||
}
|
||||
return Boolean.TRUE;
|
||||
} else {
|
||||
log.info("generateSketchOrPrintPrint失败###{}", jsonObject);
|
||||
log.info("Generate Exception! Code : " + jsonObject.get("code"));
|
||||
Map<String, String> mqMessage = new HashMap<>();
|
||||
mqMessage.put("tasks_id", taskId);
|
||||
mqMessage.put("status", "ERROR");
|
||||
mqMessage.put("message", "");
|
||||
mqMessage.put("image_url", "");
|
||||
mqMessage.put("category", "");
|
||||
String mqMessageBody = JSON.toJSONString(mqMessage);
|
||||
rabbitMQService.publishMessageToGenerateResult(mqMessageBody);
|
||||
return Boolean.FALSE;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public Response sendPostToModel(String content, String portAndRoute, String functionName) {
|
||||
@@ -4139,6 +4174,9 @@ public class PythonService {
|
||||
.writeTimeout(60, TimeUnit.SECONDS)
|
||||
.build();
|
||||
MediaType mediaType = MediaType.parse("application/json");
|
||||
content.put("bucket", userBucketName);
|
||||
content.put("object_name", content.get("user_id") + "/" + "segment" + "/" + UUID.randomUUID() + ".png");
|
||||
content.remove("user_id");
|
||||
RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content));
|
||||
|
||||
String url = accessPythonIp + ":" + accessPythonPort + "/api/seg_anything";
|
||||
|
||||
@@ -7,6 +7,8 @@ public interface RabbitMQService {
|
||||
|
||||
void publishMessageToGenerate(String message);
|
||||
|
||||
void publishMessageToGenerateResult(String message);
|
||||
|
||||
void publishMessageToSR(String message);
|
||||
|
||||
Integer getMessageCount(String queueUrl);
|
||||
|
||||
@@ -48,6 +48,8 @@ import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.springframework.transaction.support.TransactionSynchronization;
|
||||
import org.springframework.transaction.support.TransactionSynchronizationManager;
|
||||
import org.apache.http.HttpHost;
|
||||
import org.apache.http.client.config.RequestConfig;
|
||||
import org.apache.http.client.methods.CloseableHttpResponse;
|
||||
@@ -59,10 +61,12 @@ import org.bytedeco.javacv.Java2DFrameConverter;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.dao.DuplicateKeyException;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Propagation;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import jakarta.annotation.Resource;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.*;
|
||||
@@ -194,10 +198,13 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
generate.setText(text);
|
||||
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
|
||||
// validateGeneraType(generate, text, elementId);
|
||||
if (!StringUtil.isNullOrEmpty(text)) {
|
||||
text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type(), generateThroughImageTextDTO.getAgeGroup());
|
||||
if (!(generateThroughImageTextDTO.getLevel1Type().equals(MOOD_BOARD.getRealName())&&generateThroughImageTextDTO.getModelName().equals("high"))){
|
||||
if (!StringUtil.isNullOrEmpty(text)) {
|
||||
text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type(), generateThroughImageTextDTO.getAgeGroup());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// todo 这一步现在还是有必要的吗?
|
||||
// 2.1 sketch或print在t_collection_element表/t_library表中的信息是否需要更新 如 level2Type
|
||||
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType());
|
||||
@@ -218,6 +225,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
version = "fast";
|
||||
params.put("version", "fast");
|
||||
}
|
||||
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
|
||||
saveGenerateImmediately(generate);
|
||||
// 3.1 确定不同类型的印花分别调哪个接口
|
||||
if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())) {
|
||||
switch (generateThroughImageTextDTO.getLevel2Type()) {
|
||||
@@ -243,15 +252,28 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue);
|
||||
}
|
||||
} else {
|
||||
GenerateToPythonDTO generateToPythonDTO = new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
|
||||
mode, category, generateThroughImageTextDTO.getGender(), version);
|
||||
jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue);
|
||||
if (Objects.equals(version, "fast")) {
|
||||
GenerateToPythonDTO generateToPythonDTO = new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
|
||||
mode, category, generateThroughImageTextDTO.getGender(), version);
|
||||
jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue);
|
||||
} else {
|
||||
|
||||
|
||||
path = CommonConstant.GENERATE_PATH_FLUX2_KLEIN;
|
||||
// 构建object_name: {userId}/{category}/{uuid}.png
|
||||
String objectName = generateThroughImageTextDTO.getUserId() + "/" + category + "/" + UUID.randomUUID() + ".png";
|
||||
|
||||
ImageProcessRequest imageProcessRequest = ImageProcessRequest.builder()
|
||||
.object_name(objectName)
|
||||
.bucket_name(userBucket)
|
||||
.prompt(text).build();
|
||||
jsonString = JSON.toJSONString(imageProcessRequest);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, port, path);
|
||||
Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, port, path, generateThroughImageTextDTO.getUniqueId());
|
||||
|
||||
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
|
||||
save(generate);
|
||||
|
||||
// 5、将本次请求存入redis
|
||||
String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId();
|
||||
@@ -266,6 +288,40 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
|
||||
}
|
||||
|
||||
public void saveGenerateImmediately(Generate generate) {
|
||||
save(generate);
|
||||
// 使用 TransactionSynchronizationManager 在事务真正提交后再设锁
|
||||
// 否则 save() 完成后事务尚未 commit,MQ 消费者立即读到 null
|
||||
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
|
||||
@Override
|
||||
public void afterCommit() {
|
||||
String lockKey = "generate:lock:" + generate.getUniqueId();
|
||||
redisUtil.addToString(lockKey, "1", 60L);
|
||||
log.debug("Save lock set after commit for uniqueId: {}", generate.getUniqueId());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void waitForSaveLock(String uniqueId) {
|
||||
String lockKey = "generate:lock:" + uniqueId;
|
||||
int maxRetries = 30;
|
||||
int retryIntervalMs = 200;
|
||||
for (int i = 0; i < maxRetries; i++) {
|
||||
if (Boolean.TRUE.equals(redisUtil.hasKey(lockKey))) {
|
||||
log.debug("Save lock acquired for uniqueId: {} after {} retries", uniqueId, i);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
Thread.sleep(retryIntervalMs);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
log.warn("Interrupted while waiting for save lock: {}", uniqueId);
|
||||
return;
|
||||
}
|
||||
}
|
||||
log.warn("Save lock timeout for uniqueId: {}, proceeding anyway", uniqueId);
|
||||
}
|
||||
|
||||
public GenerateModeEnum getMode(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getText())) {
|
||||
if (Objects.nonNull(generateThroughImageTextDTO.getCollectionElementId())) {
|
||||
@@ -284,11 +340,16 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
@Override
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public void processGenerateResult(String taskId, String url, String category) {
|
||||
log.info("============ProcessGenerateResult listening==========");
|
||||
log.debug("taskId: " + taskId);
|
||||
String status = null;
|
||||
// 1、处理模型返回的数据
|
||||
GenerateDetail generateDetail = new GenerateDetail();
|
||||
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
|
||||
Generate generate;
|
||||
try {
|
||||
// 等待 HTTP 线程写入完成后再查库
|
||||
waitForSaveLock(taskId);
|
||||
generate = selectByUniqueId(taskId);
|
||||
} catch (MybatisPlusException e) {
|
||||
log.error(e.getMessage());
|
||||
@@ -311,14 +372,15 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
generateDetail.setUrl(url);
|
||||
generateDetail.setGenerateId(generate.getId());
|
||||
generateDetail.setCreateDate(LocalDateTime.now());
|
||||
generateDetail.setMd5(md5);
|
||||
generateDetail.setMd5("");
|
||||
// 将相应的url保存到数据库
|
||||
generateDetailMapper.insert(generateDetail);
|
||||
log.debug("generateDetail: " + generateDetail.toString());
|
||||
|
||||
// String uuid = taskId.substring(0, taskId.substring(0, taskId.lastIndexOf("-")).lastIndexOf("-"));
|
||||
String key = generateResultKey + ":" + taskId;
|
||||
String imageName = url.substring(url.lastIndexOf("/") + 1);
|
||||
String status = imageName.equals("white_image.jpg") ? "Invalid" : "Success";
|
||||
status = imageName.equals("white_image.jpg") ? "Invalid" : "Success";
|
||||
if (StringUtil.isNullOrEmpty(category)) {
|
||||
Generate generateRecord = selectByUniqueId(taskId);
|
||||
category = generateRecord.getLevel2Type();
|
||||
@@ -326,6 +388,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
GenerateResultVO generateResultVO = new GenerateResultVO(taskId, generateDetail.getId(), url, status, category);
|
||||
// 更新redis
|
||||
redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
|
||||
log.debug("generateResultVO: " + generateResultVO.toString());
|
||||
|
||||
|
||||
// 执行积分扣除
|
||||
// ** 注:如果生成的图片都是空白 则不扣积分
|
||||
@@ -785,8 +849,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
long requestEndTime = System.currentTimeMillis();
|
||||
log.info("HTTP请求完成 - 响应状态: {}, 耗时: {}ms, taskId: {}",
|
||||
response.code(), (requestEndTime - requestStartTime), taskId);
|
||||
String result = response.body().string();
|
||||
if (!response.isSuccessful()) {
|
||||
log.warn("Google API响应失败,状态码: {} for taskId: {}", response.code(), taskId);
|
||||
log.warn("Google API响应失败,状态码: {} for taskId: {},结果:{}", response.code(), taskId, result);
|
||||
if (attempt < maxRetries) {
|
||||
Thread.sleep(retryDelay * attempt); // 递增延迟
|
||||
continue;
|
||||
@@ -795,7 +860,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
}
|
||||
}
|
||||
|
||||
String result = response.body().string();
|
||||
|
||||
// log.info("Google 响应结果:{}", result);
|
||||
com.alibaba.fastjson.JSONObject jsonResponse = JSON.parseObject(result);
|
||||
|
||||
@@ -1065,6 +1130,12 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
|
||||
String result = response.body().string();
|
||||
|
||||
if (response.code() != 200) {
|
||||
log.error("Google API 请求失败 - taskId: {}, 尝试: {}, URL: {}, 状态码: {}, 响应结果: {}",
|
||||
taskId, attempt, endpoint, response.code(), result);
|
||||
throw new BusinessException("system.error");
|
||||
}
|
||||
|
||||
// log.info("Google 响应结果:{}", result);
|
||||
com.alibaba.fastjson.JSONObject jsonResponse = JSON.parseObject(result);
|
||||
|
||||
@@ -1203,7 +1274,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
* @param modelName advanced high normal
|
||||
*/
|
||||
private HashMap<String, String> chooseModelAndPrompt(GenerateThroughImageTextDTO generateDTO, String modelName) {
|
||||
if (StringUtil.isNullOrEmpty(modelName)){
|
||||
if (StringUtil.isNullOrEmpty(modelName)) {
|
||||
throw new BusinessException("system error");
|
||||
}
|
||||
HashMap<String, String> modelAndPromptMap = new HashMap<>();
|
||||
@@ -1221,7 +1292,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
String style = generateDTO.getText().substring(0, firstCommaIndex).trim();
|
||||
|
||||
String prompt = generateDTO.getText().substring(firstCommaIndex + 1).trim();
|
||||
prompt = getPrintboardPrompt(style, prompt,modelName,isUseImage);
|
||||
prompt = getPrintboardPrompt(style, prompt, modelName, isUseImage);
|
||||
modelAndPromptMap.put(ModelConstants.PROMPT, prompt);
|
||||
|
||||
|
||||
@@ -1260,14 +1331,31 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
modelAndPromptMap.put(ModelConstants.USE_MODEL, ModelConstants.LOCAL_MODEL);
|
||||
}
|
||||
} else if (ModelConstants.SKETCHBOARD.equals(generateDTO.getLevel1Type())) {
|
||||
String[] split = generateDTO.getText().split(",");
|
||||
String style = split[0].trim();
|
||||
if ("Lolita".equals( style)){
|
||||
style = "洛丽塔";
|
||||
String style = "";
|
||||
String userPrompt = "";
|
||||
// 找到第一个逗号的位置
|
||||
int firstCommaIndex = generateDTO.getText().indexOf(",");
|
||||
if (firstCommaIndex != -1) {
|
||||
// 截取第一个逗号前的内容作为style
|
||||
style = generateDTO.getText().substring(0, firstCommaIndex).trim();
|
||||
// 截取第一个逗号后的所有内容作为userPrompt(去除首尾空格)
|
||||
userPrompt = generateDTO.getText().substring(firstCommaIndex + 1).trim();
|
||||
|
||||
if ("Lolita".equals(style)) {
|
||||
style = "洛丽塔";
|
||||
}
|
||||
} else {
|
||||
// 兼容无逗号的情况:style为空,全部内容作为userPrompt
|
||||
userPrompt = generateDTO.getText().trim();
|
||||
}
|
||||
String userPrompt = split[1];
|
||||
|
||||
String prompt = userPrompt + "rules:front view sketch only,plain white background, single garment only, orthographic, centered on white background, borderless canvas, thin monochrome black line art.\n" +
|
||||
" No clothes hanger, no fake clothes hanger, no human-related lines, no color fill, no words, no text, no black background, no boundary or frame.sketch style:"+ style;
|
||||
" No clothes hanger, no fake clothes hanger, no human-related lines, no color fill, no words, no text, no black background, no boundary or frame.";
|
||||
|
||||
if (!style.trim().isEmpty() && !"all".equalsIgnoreCase(style)) {
|
||||
prompt += ".sketch style:" + style.trim();
|
||||
}
|
||||
|
||||
modelAndPromptMap.put(ModelConstants.PROMPT, prompt);
|
||||
if (isUseImage) {
|
||||
if (ModelConstants.ADVANCED.equals(modelName)) {
|
||||
@@ -1465,6 +1553,13 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
if (imagePath != null) {
|
||||
requestBuilder.image(finalImagePath1);
|
||||
}
|
||||
if (useModel.equals(ModelConstants.PRINTBOARD_HIGH_I2I)) {
|
||||
GenerateImagesRequest.OptimizePromptOptions optimizePromptOptions = new GenerateImagesRequest.OptimizePromptOptions();
|
||||
optimizePromptOptions.setMode("fast");
|
||||
requestBuilder.optimizePromptOptions(optimizePromptOptions);
|
||||
//由于PRINTBOARD_HIGH_I2I与PRINTBOARD_ADVANCED_I2I使用模型一致,为了区别积分扣除,PRINTBOARD_HIGH_I2I加入了-fast,但传入模型时需要去掉-fast,用PRINTBOARD_ADVANCED_I2I的常量做替代
|
||||
requestBuilder.model(ModelConstants.PRINTBOARD_ADVANCED_I2I);
|
||||
}
|
||||
|
||||
// 保存生成记录到数据库
|
||||
Generate generate = new Generate(
|
||||
@@ -1602,14 +1697,14 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
"Flat textile pattern printed directly on fabric surface, no three-dimensional objects, no items placed on cloth. \n" +
|
||||
"Real style: fabric print, realistic woven/printed pattern, detailed surface pattern only";
|
||||
}
|
||||
}else {
|
||||
throw new BusinessException("style error:"+ style);
|
||||
} else {
|
||||
throw new BusinessException("style error:" + style);
|
||||
}
|
||||
|
||||
if (userInput == null || userInput.trim().isEmpty()) {
|
||||
if (isUseImage){
|
||||
if (isUseImage) {
|
||||
prompt = "Theme: Image content" + "\nRequirement: " + systemPrompt;
|
||||
}else {
|
||||
} else {
|
||||
throw new BusinessException("prompt null");
|
||||
}
|
||||
} else {
|
||||
@@ -2011,7 +2106,9 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
public Generate selectByUniqueId(String uniqueId) {
|
||||
QueryWrapper<Generate> qw = new QueryWrapper<>();
|
||||
qw.eq("unique_id", uniqueId);
|
||||
|
||||
log.debug("selectByUniqueId: " + uniqueId);
|
||||
Generate one = getOne(qw);
|
||||
log.debug("Generate: " + one);
|
||||
return getOne(qw);
|
||||
}
|
||||
|
||||
@@ -4182,11 +4279,11 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
|
||||
// 处理不同状态
|
||||
switch (statusEnum) {
|
||||
case TASK_NOT_FOUND:
|
||||
// 审核没过
|
||||
// 审核没过
|
||||
case REQUEST_MODERATED:
|
||||
// 审核没过
|
||||
// 审核没过
|
||||
case CONTENT_MODERATED:
|
||||
// 出错
|
||||
// 出错
|
||||
case ERROR:
|
||||
return "Fail";
|
||||
case PENDING_F:
|
||||
|
||||
@@ -31,6 +31,11 @@ public class RabbitMQServiceImpl implements RabbitMQService {
|
||||
mqPublisher.sendGenerateMessage(message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void publishMessageToGenerateResult(String message) {
|
||||
mqPublisher.sendGenerateResultMessage(message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void publishMessageToSR(String message) {
|
||||
mqPublisher.sendSRMessage(message);
|
||||
|
||||
Reference in New Issue
Block a user