diff --git a/pom.xml b/pom.xml
index db8b8b4d..9c14ce4b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -427,6 +427,11 @@
bcpkix-jdk18on
1.78.1
+
+
+ org.springframework.boot
+ spring-boot-starter-aop
+
diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java
index 1312ba83..833df8ca 100644
--- a/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java
+++ b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java
@@ -28,6 +28,11 @@ public class MQPublisher {
amqpTemplate.convertAndSend(rabbitMQProperties.getQueues().getSr(), mm);
}
+ public void sendGenerateResultMessage(String mm) {
+ log.info("send generate result message: {}", mm);
+ amqpTemplate.convertAndSend(rabbitMQProperties.getQueues().getGenerateResult(), mm);
+ }
+
/**
*
* @param mailParams 含有的字段
diff --git a/src/main/java/com/ai/da/common/aspect/ControllerLoggingAspect.java b/src/main/java/com/ai/da/common/aspect/ControllerLoggingAspect.java
new file mode 100644
index 00000000..391a0dda
--- /dev/null
+++ b/src/main/java/com/ai/da/common/aspect/ControllerLoggingAspect.java
@@ -0,0 +1,170 @@
+package com.ai.da.common.aspect;
+
+import com.ai.da.common.context.UserContext;
+import com.ai.da.model.vo.AuthPrincipalVo;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import org.aspectj.lang.JoinPoint;
+import org.aspectj.lang.annotation.*;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Component;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+import org.springframework.web.multipart.MultipartFile;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Controller日志切面
+ * 记录所有Controller接口的请求参数和用户信息
+ */
+@Aspect
+@Component
+public class ControllerLoggingAspect {
+
+ private static final Logger logger = LoggerFactory.getLogger(ControllerLoggingAspect.class);
+
+ /**
+ * 定义切点:所有Controller方法
+ */
+ @Pointcut("execution(* com.ai.da.controller..*(..))")
+ public void controllerMethods() {
+ }
+
+ /**
+ * Controller方法执行前记录日志
+ */
+// @Before("controllerMethods()")
+ public void logControllerBefore(JoinPoint joinPoint) {
+ ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
+ if (attributes != null) {
+ HttpServletRequest request = attributes.getRequest();
+
+ // 获取当前用户ID
+ Long userId = null;
+ AuthPrincipalVo authPrincipalVo = UserContext.getUserHolder();
+ if (authPrincipalVo != null) {
+ userId = authPrincipalVo.getId();
+ }
+
+ // 获取请求参数
+ Map params = getRequestParams(joinPoint, request);
+
+ logger.info("=== 请求开始 ===");
+ logger.info("用户ID: {}", userId);
+ logger.info("请求URL: {}", request.getRequestURL().toString());
+ logger.info("请求方法: {}", request.getMethod());
+ logger.info("请求IP: {}", getClientIpAddress(request));
+ logger.info("调用方法: {}.{}", joinPoint.getSignature().getDeclaringType().getSimpleName(), joinPoint.getSignature().getName());
+ logger.info("请求参数: {}", params);
+ }
+ }
+
+ /**
+ * 获取请求参数
+ */
+ private Map getRequestParams(JoinPoint joinPoint, HttpServletRequest request) {
+ Map params = new HashMap<>();
+
+ // 1. 获取Query String参数
+ String queryString = request.getQueryString();
+ if (queryString != null && !queryString.isEmpty()) {
+ params.put("queryString", queryString);
+ }
+
+ // 2. 获取方法参数(包含 @PathVariable, @RequestParam, @RequestBody 等)
+ Object[] args = joinPoint.getArgs();
+
+ if (args != null && args.length > 0) {
+ Map methodParams = new HashMap<>();
+ for (int i = 0; i < args.length; i++) {
+ Object arg = args[i];
+ // 过滤掉不可序列化的参数
+ if (arg != null) {
+ if (isIgnorable(arg)) {
+ // 对于可忽略的类型,记录类型名
+ methodParams.put("arg" + i, "[" + arg.getClass().getSimpleName() + "]");
+ } else {
+ try {
+ methodParams.put("arg" + i, arg);
+ } catch (Exception e) {
+ methodParams.put("arg" + i, arg.toString());
+ }
+ }
+ }
+ }
+ if (!methodParams.isEmpty()) {
+ params.put("methodParams", methodParams);
+ }
+ }
+
+ return params;
+ }
+
+ /**
+ * 判断是否需要过滤的参数类型
+ */
+ private boolean isIgnorable(Object obj) {
+ return obj instanceof HttpServletRequest
+ || obj instanceof HttpServletResponse
+ || obj instanceof MultipartFile
+ || obj instanceof MultipartFile[];
+ }
+
+ /**
+ * Controller方法抛出异常时记录日志
+ */
+ @AfterThrowing(pointcut = "controllerMethods()", throwing = "exception")
+ public void logControllerAfterThrowing(JoinPoint joinPoint, Throwable exception) {
+ ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
+
+ Long userId = null;
+ AuthPrincipalVo authPrincipalVo = UserContext.getUserHolder();
+ if (authPrincipalVo != null) {
+ userId = authPrincipalVo.getId();
+ }
+
+ // 获取请求参数
+ Map params = new HashMap<>();
+ if (attributes != null) {
+ HttpServletRequest request = attributes.getRequest();
+ params = getRequestParams(joinPoint, request);
+ }
+
+ logger.error("=== 请求异常 ===");
+ logger.error("用户ID: {}", userId);
+ logger.error("调用方法: {}.{}", joinPoint.getSignature().getDeclaringType().getSimpleName(), joinPoint.getSignature().getName());
+ logger.error("请求参数: {}", params);
+ logger.error("异常信息: ", exception);
+ logger.error("=== 异常结束 ===");
+ }
+
+ /**
+ * 获取客户端真实IP地址
+ */
+ private String getClientIpAddress(HttpServletRequest request) {
+ String xForwardedFor = request.getHeader("X-Forwarded-For");
+ if (xForwardedFor != null && !xForwardedFor.isEmpty() && !"unknown".equalsIgnoreCase(xForwardedFor)) {
+ return xForwardedFor.split(",")[0];
+ }
+
+ String xRealIp = request.getHeader("X-Real-IP");
+ if (xRealIp != null && !xRealIp.isEmpty() && !"unknown".equalsIgnoreCase(xRealIp)) {
+ return xRealIp;
+ }
+
+ String proxyClientIp = request.getHeader("Proxy-Client-IP");
+ if (proxyClientIp != null && !proxyClientIp.isEmpty() && !"unknown".equalsIgnoreCase(proxyClientIp)) {
+ return proxyClientIp;
+ }
+
+ String wlProxyClientIp = request.getHeader("WL-Proxy-Client-IP");
+ if (wlProxyClientIp != null && !wlProxyClientIp.isEmpty() && !"unknown".equalsIgnoreCase(wlProxyClientIp)) {
+ return wlProxyClientIp;
+ }
+
+ return request.getRemoteAddr();
+ }
+}
diff --git a/src/main/java/com/ai/da/common/config/MyTaskScheduler.java b/src/main/java/com/ai/da/common/config/MyTaskScheduler.java
index be6bd63b..f7c10bcc 100644
--- a/src/main/java/com/ai/da/common/config/MyTaskScheduler.java
+++ b/src/main/java/com/ai/da/common/config/MyTaskScheduler.java
@@ -202,7 +202,7 @@ public class MyTaskScheduler {
}
}
-// @Scheduled(cron = "0 0 9 * * ?")
+ // @Scheduled(cron = "0 0 9 * * ?")
public void sendTrialOrderExcelToManagements() {
// 获取前一天日期
LocalDate yesterday = LocalDate.now().minusDays(1);
diff --git a/src/main/java/com/ai/da/common/constant/CommonConstant.java b/src/main/java/com/ai/da/common/constant/CommonConstant.java
index 4fb6ec16..9f7ac8bc 100644
--- a/src/main/java/com/ai/da/common/constant/CommonConstant.java
+++ b/src/main/java/com/ai/da/common/constant/CommonConstant.java
@@ -23,6 +23,7 @@ public class CommonConstant {
}
public static final String GENERATE_PATH = "/api/generate_image";
+ public static final String GENERATE_PATH_FLUX2_KLEIN = "/api/generate_image_flux2_klein";
public static final String GENERATE_SINGLE_LOGO = "/api/generate_single_logo";
diff --git a/src/main/java/com/ai/da/common/constant/ModelConstants.java b/src/main/java/com/ai/da/common/constant/ModelConstants.java
index bc0b1745..fab503bb 100644
--- a/src/main/java/com/ai/da/common/constant/ModelConstants.java
+++ b/src/main/java/com/ai/da/common/constant/ModelConstants.java
@@ -20,7 +20,7 @@ public class ModelConstants {
public static final String PRINTBOARD_ADVANCED_T2I = "qwen-image";
public static final String MOODBOARD_ADVANCED = "doubao-seedream-3-0-t2i-250415";
public static final String PRINTBOARD_HIGH_T2I = "doubao-seedream-3-0-t2i-250415";
- public static final String PRINTBOARD_HIGH_I2I = "doubao-seededit-3-0-i2i-250628";
+ public static final String PRINTBOARD_HIGH_I2I = "doubao-seedream-4-0-250828-fast";
public static final String PRINTBOARD_ADVANCED_I2I = "doubao-seedream-4-0-250828";
public static final String IMAGEN_MODEL = "imagen-4.0-generate-001";
public static final String NANO_BANANA = "gemini-2.5-flash-image";
diff --git a/src/main/java/com/ai/da/common/task/AccountTask.java b/src/main/java/com/ai/da/common/task/AccountTask.java
index ee573eaa..e0d90480 100644
--- a/src/main/java/com/ai/da/common/task/AccountTask.java
+++ b/src/main/java/com/ai/da/common/task/AccountTask.java
@@ -34,7 +34,7 @@ public class AccountTask {
accountService.refreshCreditsMonthly();
}
-// @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes
+ // @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes
public void getPaidUser() {
// 获取code-create 表中 指定日期之后 订单状态为wc-processing的订单
accountService.extendValidityForCC();
diff --git a/src/main/java/com/ai/da/common/task/PaymentTask.java b/src/main/java/com/ai/da/common/task/PaymentTask.java
index 52a6e2d6..e63e9388 100644
--- a/src/main/java/com/ai/da/common/task/PaymentTask.java
+++ b/src/main/java/com/ai/da/common/task/PaymentTask.java
@@ -45,7 +45,7 @@ public class PaymentTask {
@Resource
private PayPalCheckoutService payPalCheckoutService;
-// @Scheduled(cron = "0/30 * * * * ?")
+ // @Scheduled(cron = "0/30 * * * * ?")
public void orderConfirmForPaypal() throws SerializeException {
// log.info("PayPal orderConfirm 被执行......");
@@ -97,7 +97,7 @@ public class PaymentTask {
//
}
-// @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes
+ // @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes
public void updateAffiliateInfoWithPayment(){
// log.info("佣金计算定时器");
affiliateService.updateAffiliateInfoWithPayment();
@@ -109,7 +109,7 @@ public class PaymentTask {
affiliateService.syncLinkViewCountToDB();
}
-// @Scheduled(cron = "0 0 8 28-31 * ?")
+ // @Scheduled(cron = "0 0 8 28-31 * ?")
public void commissionSummaryReminder(){
// 每个月末的最后一天的早上八点执行
LocalDate today = LocalDate.now();
diff --git a/src/main/java/com/ai/da/common/task/SubscriptionReminderTask.java b/src/main/java/com/ai/da/common/task/SubscriptionReminderTask.java
index 0c6ada03..26478270 100644
--- a/src/main/java/com/ai/da/common/task/SubscriptionReminderTask.java
+++ b/src/main/java/com/ai/da/common/task/SubscriptionReminderTask.java
@@ -40,7 +40,7 @@ public class SubscriptionReminderTask {
REMINDER_DAYS_CONFIG.put("year", 14);
}
-// @Scheduled(cron = "0 0 9 * * ?")
+ // @Scheduled(cron = "0 0 9 * * ?")
public void subscriptionReminder() {
// 获取所有需要通知的订阅
List subscriptionInfos = getDueSubscriptions();
@@ -97,7 +97,7 @@ public class SubscriptionReminderTask {
return subscriptionInfoMapper.selectList(qw);
}
-// @Scheduled(cron = "0 0 9 * * ?")
+ // @Scheduled(cron = "0 0 9 * * ?")
public void trialReminder() {
// 今天的 00:00:00 和 23:59:59
LocalDateTime startOfDay = LocalDateTime.now().toLocalDate().atStartOfDay();
diff --git a/src/main/java/com/ai/da/common/utils/RedisUtil.java b/src/main/java/com/ai/da/common/utils/RedisUtil.java
index 1264a1e4..c91411fa 100644
--- a/src/main/java/com/ai/da/common/utils/RedisUtil.java
+++ b/src/main/java/com/ai/da/common/utils/RedisUtil.java
@@ -1,659 +1,659 @@
-package com.ai.da.common.utils;
-
-import com.ai.da.model.dto.ProgressDTO;
-import com.ai.da.python.vo.DesignPythonObject;
-import com.alibaba.fastjson.JSON;
-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.core.type.TypeReference;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import io.netty.util.internal.StringUtil;
-import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.lang3.StringUtils;
-import org.springframework.data.redis.core.RedisTemplate;
-import org.springframework.data.redis.core.ValueOperations;
-import org.springframework.data.redis.core.ZSetOperations;
-import org.springframework.data.redis.core.script.DefaultRedisScript;
-import org.springframework.stereotype.Component;
-import org.springframework.util.CollectionUtils;
-
-import jakarta.annotation.Resource;
-import java.math.BigDecimal;
-import java.math.RoundingMode;
-import java.time.Duration;
-import java.time.LocalDateTime;
-import java.time.format.DateTimeFormatter;
-import java.util.*;
-import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
-
-@Slf4j
-@Component
-public class RedisUtil {
-
- @Resource
- private RedisTemplate redisTemplate;
-
- public final static String FLUX_POLLING_URL = "Flux:";
- /**
- * 登录 token 在 Redis 中的前缀:
- * 最终 key 结构为 login:token:{userId}
- */
- public final static String LOGIN_TOKEN_KEY = "login:token:";
-
- public Boolean hasKey(String key){
- return redisTemplate.hasKey(key);
- }
-
- //- - - - - - - - - - - - - - - - - - - - - ZSet类型 - - - - - - - - - - - - - - - - - - - -
-
- /**
- * 向ZSet中添加元素
- */
- public void addToZSet(String key, String value, Double score) {
- redisTemplate.opsForZSet().add(key, value, score);
- }
-
- /**
- * 从ZSet中删除元素
- */
- public void removeFromZSet(String key, String value) {
- redisTemplate.opsForZSet().remove(key, value);
- }
-
- /**
- * 获取指定元素的当前排列顺序
- */
- public Long getRank(String key, String value) {
- return redisTemplate.opsForZSet().rank(key, value);
- }
-
- /**
- * 获取当前ZSet中的最大score
- */
- public Double getMaxScore(String key) {
- Set> set = redisTemplate.opsForZSet().reverseRangeWithScores(key, 0, 0);
-
- if (!CollectionUtils.isEmpty(set)) {
- Double score = set.iterator().next().getScore();
- return score + 1.0;
- } else {
- return 1.0;
- }
- }
-
- /**
- * 判断元素是否存在
- */
- public Boolean isElementExistsInZSet(String key, String value) {
- return redisTemplate.opsForZSet().score(key, value) != null;
- }
-
- /**
- * 获取当前ZSet中数据量的总和
- */
- public Long getZSetTotalCount(String key) {
- return redisTemplate.opsForZSet().zCard(key);
- }
-
-
- public Set getZSetTotalData(String key){
- return redisTemplate.opsForZSet().range(key, 0, -1);
- }
-
- //- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - -
-
- public final static String VIDEO_FINISHED_TASKS = "VideoFinishedTasks";
-
- /**
- * 将数据放入set缓存
- */
- public void addToSet(String key, String value, Long expiresIn) {
- redisTemplate.opsForSet().add(key, value);
- // 设置过期时间
- redisTemplate.expire(key, expiresIn, TimeUnit.SECONDS);
- }
-
- /**
- * 弹出变量中的元素
- */
- public void removeFromSet(String key, String value) {
- redisTemplate.opsForSet().remove(key, value);
- }
-
- /**
- * 检查给定的元素是否在变量中。
- */
- public Boolean isElementExistsInSet(String key, String obj) {
- return redisTemplate.opsForSet().isMember(key, obj);
- }
-
-
- //- - - - - - - - - - - - - - - - - - - - - hash类型 - - - - - - - - - - - - - - - - - - - -
-
- /**
- * 加入缓存
- */
- public void addToMap(String key, Map map) {
- redisTemplate.opsForHash().putAll(key, map);
- }
-
- /**
- * 验证指定 key 下 有没有指定的 hashkey
- */
- public Boolean isElementExistsInMap(String key, String hashKey) {
- return redisTemplate.opsForHash().hasKey(key, hashKey);
- }
-
- /**
- * 获取指定key的值string
- */
- public String getMapValue(String key1, String key2) {
- return String.valueOf(redisTemplate.opsForHash().get(key1, key2));
- }
-
- /**
- * 删除指定 hash 的 HashKey
- *
- * @return 删除成功的 数量
- */
- public Long removeFromMap(String key, String hashKeys) {
- return redisTemplate.opsForHash().delete(key, hashKeys);
- }
-
- //- - - - - - - - - - - - - - - - - - - - - String类型 - - - - - - - - - - - - - - - - - - - -
- public void addToString(String key, String value){
- redisTemplate.opsForValue().set(key,value);
- }
-
- public void addToString(String key, String value, Long expiresIn){
- redisTemplate.opsForValue().set(key,value,expiresIn, TimeUnit.SECONDS);
- }
-
- public String getFromString(String key){
- return redisTemplate.opsForValue().get(key);
- }
-
- public Set getKeysFromString(String key){
- return redisTemplate.keys(key);
- }
-
- public Long getSize(String key){return redisTemplate.opsForSet().size(key);}
-
- public List getMultiValue(Set keys){
- return redisTemplate.opsForValue().multiGet(keys);
- }
-
- public Long getExpire(String key){
- return redisTemplate.getExpire(key);
- }
-
- public void removeFromString(String key){
- redisTemplate.delete(key);
- }
-
- /**
- * 保存登录 token
- *
- * @param userId 用户 ID
- * @param token token 字符串
- * @param expireMillis 过期时间(毫秒,通常与 JWT 保持一致)
- */
- public void setLoginToken(Long userId, String token, long expireMillis) {
- if (expireMillis <= 0) {
- // 不设置过期时间,直到手动删除(不推荐)
- addToString(LOGIN_TOKEN_KEY + userId, token);
- return;
- }
- long expireSeconds = expireMillis / 1000;
- if (expireSeconds <= 0) {
- expireSeconds = 1;
- }
- addToString(LOGIN_TOKEN_KEY + userId, token, expireSeconds);
- }
-
- /**
- * 获取登录 token
- */
- public String getLoginToken(Long userId) {
- return getFromString(LOGIN_TOKEN_KEY + userId);
- }
-
- /**
- * 删除登录 token
- */
- public void deleteLoginToken(Long userId) {
- removeFromString(LOGIN_TOKEN_KEY + userId);
- }
-
- public final static String PORTFOLIO_LIKE_KEY = "portfolio:like:";
-
- public void likePost(Long portfolioId, Long userId) {
- redisTemplate.opsForSet().add(PORTFOLIO_LIKE_KEY + portfolioId, String.valueOf(userId));
- }
-
- public Long getLikeCount(Long portfolioId) {
- String key = PORTFOLIO_LIKE_KEY + portfolioId;
- return redisTemplate.opsForSet().size(key);
- }
-
- public List getLikedPortfolios(Long userId) {
- // 获取所有包含PORTFOLIO_LIKE_KEY的键
- Set likedPortfolios = redisTemplate.keys(PORTFOLIO_LIKE_KEY + "*");
-
- // 如果没有喜欢的,返回空列表
- if (likedPortfolios == null || likedPortfolios.isEmpty()) {
- return new ArrayList<>();
- }
-
- // 过滤出包含指定用户ID的键,并提取投资组合ID
- return likedPortfolios.stream()
- .filter(key -> redisTemplate.opsForSet().isMember(key, String.valueOf(userId)))
- .map(key -> Long.valueOf(key.replace(PORTFOLIO_LIKE_KEY, "")))
- .collect(Collectors.toList());
- }
-
- public void unLikePost(Long portfolioId, Long userId) {
- redisTemplate.opsForSet().remove(PORTFOLIO_LIKE_KEY + portfolioId, userId.toString());
- }
-
- // 检查用户是否喜欢某个作品
- public boolean isPostLikedByUser(Long portfolioId, Long userId) {
- String key = PORTFOLIO_LIKE_KEY + portfolioId;
- Boolean isMember = redisTemplate.opsForSet().isMember(key, userId.toString());
- return isMember != null && isMember;
- }
-
- public final static String PORTFOLIO_VIEW_KEY = "portfolio:view:";
-
- public void increaseViewCount(Long portfolioId) {
- String key = PORTFOLIO_VIEW_KEY + portfolioId;
- redisTemplate.opsForValue().increment(key);
- }
-
- public Long getViewCount(Long portfolioId) {
- String key = PORTFOLIO_VIEW_KEY + portfolioId;
- return redisTemplate.opsForValue().increment(key, 0);
- }
-
- public Long getViewCount(String key) {
- Object value = redisTemplate.opsForValue().get(key);
- if (value instanceof Integer) {
- return Long.valueOf((Integer) value);
- } else if (value instanceof Long) {
- return (Long) value;
- } else if (value instanceof String) {
- return Long.valueOf((String) value);
- } else {
- throw new IllegalArgumentException("Unexpected value type");
- }
- }
-
- public final static String PERSONAL_HOMEPAGE_VIEW_KEY = "PersonalHomepage:view:";
-
- public void increasePersonalHomepageViewCount(Long accountId) {
- String key = PERSONAL_HOMEPAGE_VIEW_KEY + accountId;
- redisTemplate.opsForValue().increment(key);
- }
-
- public Long getPersonalHomepageViewCount(Long accountId) {
- String key = PERSONAL_HOMEPAGE_VIEW_KEY + accountId;
- return redisTemplate.opsForValue().increment(key, 0);
- }
-
- public final static String MOODBOARD_POSITION_KEY = "moodboard:position:";
-
- public void saveMoodboardPosition(Long id, String moodboardPosition) {
- addToString(MOODBOARD_POSITION_KEY + id, moodboardPosition);
- }
-
- public String getMoodboardPosition(Long id) {
- return getFromString(MOODBOARD_POSITION_KEY + id);
- }
- public final static String NICKNAME_MODIFY_TIMES = "NicknameModifyTimes:";
- public final static String UNNAMED_PROJECT_SEQ = "Project:UnnamedProjectSeq:";
- public Long increaseCount(String key) {
- return redisTemplate.opsForValue().increment(key);
- }
-
- public Long getIncrementCount(String key) {
- return redisTemplate.opsForValue().increment(key, 0);
- }
-
- public void setKeyExpire(String key, Long expire) {
- redisTemplate.expire(key, expire, TimeUnit.DAYS);
- }
-
- public final static String CHANGE_MAILBOX = "ChangeMailbox:";
-
- // 每天允许通知3次
- public final static String UPLOAD_TIMEOUT_REMINDER_COUNTER = "UploadTimeoutReminderCounter";
-
- public void addProcessId(String processId, int progress) {
- // Redis 中的键,可以通过 processId 来唯一标识
- String redisKey = "process:progress:" + processId;
-
- // 将当前进度存储到 Redis
- redisTemplate.opsForValue().set(redisKey, String.valueOf(progress));
-
- // 设置过期时间为 5 分钟(300 秒)
- redisTemplate.expire(redisKey, 5, TimeUnit.MINUTES);
- }
-
- public void addPathToCache(Long collectionId, Long userId, String path) {
- // Redis 中的键,唯一标识由 collectionId 和 userId 组成
- String redisKey = "path:cache:" + collectionId + ":" + userId;
-
- // 增加路径的计数
- redisTemplate.opsForHash().increment(redisKey, path, 1);
-
- // 设置过期时间为 2 小时(7200 秒)
- redisTemplate.expire(redisKey, 8, TimeUnit.HOURS);
- }
-
- public int getPathUsageCount(Long collectionId, Long userId, String path) {
- String redisKey = "path:cache:" + collectionId + ":" + userId;
-
- // 获取路径的使用次数
- Object count = redisTemplate.opsForHash().get(redisKey, path);
- return count != null ? Integer.parseInt(count.toString()) : 0;
- }
-
- public void addAssembledObjects(Long collectionId, Set assembledObjects) {
- // Redis 中的键,使用 collectionId 来唯一标识
- String redisKey = "collection:assembledObjects:" + collectionId;
-
- // 将 assembledObjects 转换为 JSON 格式存储,避免直接存储对象
- String assembledObjectsJson = convertToJson(assembledObjects);
-
- // 使用 Redis 的 set 操作更新集合
- redisTemplate.opsForValue().set(redisKey, assembledObjectsJson);
-
- // 设置过期时间为 5 分钟(300 秒)
- redisTemplate.expire(redisKey, 30, TimeUnit.MINUTES);
- }
-
- // 将 Set 转换为 JSON 格式
- private String convertToJson(Set assembledObjects) {
- try {
- ObjectMapper objectMapper = new ObjectMapper();
- return objectMapper.writeValueAsString(assembledObjects);
- } catch (JsonProcessingException e) {
- e.printStackTrace();
- return null;
- }
- }
-
- public Set getAssembledObjects(Long collectionId) {
- // Redis 中的键,使用 collectionId 来唯一标识
- String redisKey = "collection:assembledObjects:" + collectionId;
-
- // 从 Redis 获取存储的 JSON 字符串
- String assembledObjectsJson = (String) redisTemplate.opsForValue().get(redisKey);
-
- if (assembledObjectsJson == null) {
- return new HashSet<>(); // 如果没有找到数据,返回一个空的 Set
- }
-
- // 将 JSON 字符串转换为 Set
- return convertFromJson(assembledObjectsJson);
- }
-
- // 将 JSON 字符串转换为 Set
- private Set convertFromJson(String json) {
- try {
- ObjectMapper objectMapper = new ObjectMapper();
- // 使用 TypeReference 来指定目标类型是 Set
- return objectMapper.readValue(json, new TypeReference>() {});
- } catch (JsonProcessingException e) {
- e.printStackTrace();
- return new HashSet<>(); // 如果转换失败,返回空的 Set
- }
- }
-
- public final static String PAYMENT_INFO_LAST_SCAN_TIME = "PaymentInfoLastScanTime";
-
- public final static String AFFILIATE_LINK_VIEW_KEY = "AffiliateLink:view:";
-
- public void increaseAffiliateLinkViewCount(Long accountId) {
- String key = AFFILIATE_LINK_VIEW_KEY + accountId;
- redisTemplate.opsForValue().increment(key);
- }
-
- public Long getAffiliateLinkViewCount(Long accountId) {
- String key = AFFILIATE_LINK_VIEW_KEY + accountId;
- return redisTemplate.opsForValue().increment(key, 0);
- }
-
- /**
- * 记录任务的耗时到Redis
- * @param taskKey 任务标识,如 "taskA"
- * @param elapsedTime 本次耗时,单位为毫秒
- */
- public void recordTaskElapsedTime(String taskKey, long elapsedTime) {
- String hashKey = "task:stats";
-
- // 累加总耗时
- redisTemplate.opsForHash().increment(hashKey, taskKey + ":totalTime", elapsedTime);
-
- // 增加计数器
- redisTemplate.opsForHash().increment(hashKey, taskKey + ":count", 1);
- }
-
- /**
- * 获取任务的平均耗时
- * @param taskKey 任务标识,如 "taskA"
- * @return 平均耗时(毫秒)
- */
- public double getTaskAverageTime(String taskKey) {
- String hashKey = "task:stats";
-
- // 获取总耗时和计数
- Object totalTime = redisTemplate.opsForHash().get(hashKey, taskKey + ":totalTime");
- Object count = redisTemplate.opsForHash().get(hashKey, taskKey + ":count");
-
- // 计算平均值
- if (totalTime == null || count == null) {
- return 0;
- }
- return Double.parseDouble(totalTime.toString()) / Long.parseLong(count.toString());
- }
-
- /**
- * 清除指定任务的统计数据
- * @param taskKey 任务标识,如 "taskA"
- */
- public void clearTaskStats(String taskKey) {
- String hashKey = "task:stats";
-
- // 删除总耗时和计数器
- redisTemplate.opsForHash().delete(hashKey, taskKey + ":totalTime", taskKey + ":count");
- }
-
- public void recordTaskElapsedTime(String taskKey, double elapsedTimeInSeconds) {
- // 将耗时转换为 BigDecimal,并四舍五入保留四位小数
- BigDecimal elapsedTime = new BigDecimal(elapsedTimeInSeconds).setScale(4, RoundingMode.HALF_UP);
-
- // 累加总耗时(以毫秒为单位)
- redisTemplate.opsForHash().increment("task:stats", taskKey + ":totalTime", elapsedTime.doubleValue());
-
- // 增加计数器
- redisTemplate.opsForHash().increment("task:stats", taskKey + ":count", 1);
- }
-
- // 获取第一部分(Sketch)耗时
- public double getFirstSketchTime() {
- // 获取 "firstSketchTime:totalTime" 对应的值,并返回(单位为秒)
- Object time = redisTemplate.opsForHash().get("task:stats", "firstSketchTime:totalTime");
- return time != null ? (double) time : 0.0;
- }
-
- // 获取第二部分(获取特征值)耗时
- public double getGetAttributeRecognitionTime() {
- // 获取 "getAttributeRecognitionTime:totalTime" 对应的值,并返回(单位为秒)
- Object time = redisTemplate.opsForHash().get("task:stats", "getAttributeRecognitionTime:totalTime");
- return time != null ? (double) time : 0.0;
- }
-
- // 获取第三部分(搭配 Sketch)耗时
- public double getOtherSketchTime() {
- // 获取 "otherSketchTime:totalTime" 对应的值,并返回(单位为秒)
- Object time = redisTemplate.opsForHash().get("task:stats", "otherSketchTime:totalTime");
- return time != null ? (double) time : 0.0;
- }
-
- // 清理三部分的缓存
- public void clearTaskElapsedTimeCache() {
- // 删除第一部分的缓存
- redisTemplate.opsForHash().delete("task:stats", "firstSketchTime:totalTime");
- redisTemplate.opsForHash().delete("task:stats", "firstSketchTime:count");
-
- // 删除第二部分的缓存
- redisTemplate.opsForHash().delete("task:stats", "getAttributeRecognitionTime:totalTime");
- redisTemplate.opsForHash().delete("task:stats", "getAttributeRecognitionTime:count");
-
- // 删除第三部分的缓存
- redisTemplate.opsForHash().delete("task:stats", "otherSketchTime:totalTime");
- redisTemplate.opsForHash().delete("task:stats", "otherSketchTime:count");
- }
-
- public boolean incrementLikeCount(Long userId, String sketchPath) {
- String redisKey = "user_like_count:" + userId;
- try {
- redisTemplate.opsForHash().increment(redisKey, sketchPath, 1);
- return true;
- } catch (Exception e) {
- log.error("Error incrementing like count for userId {} and sketchPath {}: {}", userId, sketchPath, e.getMessage());
- return false;
- }
- }
-
- public int getLikeCount(Long userId, String sketchPath) {
- String redisKey = "user_like_count:" + userId;
- Object count = redisTemplate.opsForHash().get(redisKey, sketchPath);
- return count != null ? Integer.parseInt(count.toString()) : 0;
- }
-
- public void storeMaxLikeCount(Long userId, int maxLikeCount) {
- String redisKey = "user_max_like_count:" + userId;
- redisTemplate.opsForValue().set(redisKey, String.valueOf(maxLikeCount));
- }
-
- public int getMaxLikeCount(Long userId) {
- String redisKey = "user_max_like_count:" + userId;
- String maxLikeCount = redisTemplate.opsForValue().get(redisKey);
- return maxLikeCount != null ? Integer.parseInt(maxLikeCount) : 0;
- }
-
- public final static String IMAGE_SEGMENTATION = "ImageSegmentation:";
-
- public final static String STRIPE_EXCEPTION_LOG = "StripeException:";
- public final static String SUBSCRIPTION_SENT_EMAIL_TYPE = "SubscriptionEmailSentType:";
-
- public void batchDeleteKeysWithSamePrefix(String prefix){
- Set keys = redisTemplate.keys(prefix + "*");
- assert keys != null;
- if (!keys.isEmpty()){
- redisTemplate.delete(keys);
- }
- }
-
- public void setTaskProgressDTO(String taskId, ProgressDTO dto) {
- String key = "task:progress:" + taskId;
- redisTemplate.opsForValue().set(key, JSON.toJSONString(dto), Duration.ofDays(1));
- }
-
- public ProgressDTO getTaskProgressDTO(String taskId) {
- String key = "task:progress:" + taskId;
- String json = redisTemplate.opsForValue().get(key);
- if (StringUtils.isBlank(json)) {
-// return new ProgressDTO(0, 0, false);
- return null;
- }
- try {
- return JSON.parseObject(json, ProgressDTO.class);
- } catch (Exception e) {
- log.warn("任务进度解析失败 key={}, json={}", key, json);
- return new ProgressDTO(0, 0, false, null);
- }
- }
-
- // Lua脚本(原子化操作)
- /*private static final String RATE_LIMIT_SCRIPT =
- "local current = redis.call('INCR', KEYS[1])\n" +
- "if tonumber(current) == 1 then\n" +
- " redis.call('EXPIRE', KEYS[1], ARGV[1])\n" +
- "end\n" +
- "return tonumber(current) <= tonumber(ARGV[2])";*/
- private static final String RATE_LIMIT_SCRIPT =
- "local current = redis.call('INCR', KEYS[1])\n" +
- "local ttl = redis.call('TTL', KEYS[1])\n" +
- "if tonumber(current) == 1 or tonumber(ttl) == -1 then\n" +
- " redis.call('EXPIRE', KEYS[1], ARGV[1])\n" +
- "end\n" +
- "return tonumber(current) <= tonumber(ARGV[2])";
-
- /**
- * 检查是否允许发送
- * @param userId 用户ID
- * @return true-允许发送,false-已超限
- */
- public boolean allowSend(Long userId) {
- String hourKey = getCurrentHourKey(userId);
-
- // 执行Lua脚本
- List keys = Collections.singletonList(hourKey);
- List args = Arrays.asList(
- 3600L, // 1小时过期
- 10L // 限制数量 一小时只能向普通用户发10封
- );
-
- Boolean result = redisTemplate.execute(
- new DefaultRedisScript<>(RATE_LIMIT_SCRIPT, Boolean.class),
- keys,
- args.toArray()
- );
-
- return Boolean.TRUE.equals(result);
- }
-
- /**
- * 获取当前小时的Key
- * 格式:email_limit:{userId}:{yyyyMMddHH}
- */
- private String getCurrentHourKey(Long userId) {
- String hour = LocalDateTime.now()
- .format(DateTimeFormatter.ofPattern("yyyyMMddHH"));
- return String.format("email_limit:%s:%s", userId, hour);
- }
-
- /**
- * 获取当前已发送数量
- */
- public int getCurrentCount(Long userId) {
- String key = getCurrentHourKey(userId);
- String val = redisTemplate.opsForValue().get(key);
- int count;
- if (StringUtils.isBlank(val)){
- count = 0;
- }else {
- count = Integer.parseInt(val);
- }
- return count;
- }
-
- public boolean allowRequest(String apiKey) {
- String key = "rate_limit:" + apiKey;
- ValueOperations ops = redisTemplate.opsForValue();
-
- // 使用Redis的INCR命令
- Long count = ops.increment(key, 1);
-
- if (count == 1) {
- // 第一次调用,设置过期时间
- redisTemplate.expire(key, 1, TimeUnit.MINUTES);
- }
-
- return count <= 3;
- }
-
+package com.ai.da.common.utils;
+
+import com.ai.da.model.dto.ProgressDTO;
+import com.ai.da.python.vo.DesignPythonObject;
+import com.alibaba.fastjson.JSON;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import io.netty.util.internal.StringUtil;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.data.redis.core.RedisTemplate;
+import org.springframework.data.redis.core.ValueOperations;
+import org.springframework.data.redis.core.ZSetOperations;
+import org.springframework.data.redis.core.script.DefaultRedisScript;
+import org.springframework.stereotype.Component;
+import org.springframework.util.CollectionUtils;
+
+import jakarta.annotation.Resource;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.time.Duration;
+import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
+import java.util.*;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+@Slf4j
+@Component
+public class RedisUtil {
+
+ @Resource
+ private RedisTemplate redisTemplate;
+
+ public final static String FLUX_POLLING_URL = "Flux:";
+ /**
+ * 登录 token 在 Redis 中的前缀:
+ * 最终 key 结构为 login:token:{userId}
+ */
+ public final static String LOGIN_TOKEN_KEY = "login:token:";
+
+ public Boolean hasKey(String key){
+ return redisTemplate.hasKey(key);
+ }
+
+ //- - - - - - - - - - - - - - - - - - - - - ZSet类型 - - - - - - - - - - - - - - - - - - - -
+
+ /**
+ * 向ZSet中添加元素
+ */
+ public void addToZSet(String key, String value, Double score) {
+ redisTemplate.opsForZSet().add(key, value, score);
+ }
+
+ /**
+ * 从ZSet中删除元素
+ */
+ public void removeFromZSet(String key, String value) {
+ redisTemplate.opsForZSet().remove(key, value);
+ }
+
+ /**
+ * 获取指定元素的当前排列顺序
+ */
+ public Long getRank(String key, String value) {
+ return redisTemplate.opsForZSet().rank(key, value);
+ }
+
+ /**
+ * 获取当前ZSet中的最大score
+ */
+ public Double getMaxScore(String key) {
+ Set> set = redisTemplate.opsForZSet().reverseRangeWithScores(key, 0, 0);
+
+ if (!CollectionUtils.isEmpty(set)) {
+ Double score = set.iterator().next().getScore();
+ return score + 1.0;
+ } else {
+ return 1.0;
+ }
+ }
+
+ /**
+ * 判断元素是否存在
+ */
+ public Boolean isElementExistsInZSet(String key, String value) {
+ return redisTemplate.opsForZSet().score(key, value) != null;
+ }
+
+ /**
+ * 获取当前ZSet中数据量的总和
+ */
+ public Long getZSetTotalCount(String key) {
+ return redisTemplate.opsForZSet().zCard(key);
+ }
+
+
+ public Set getZSetTotalData(String key){
+ return redisTemplate.opsForZSet().range(key, 0, -1);
+ }
+
+ //- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - -
+
+ public final static String VIDEO_FINISHED_TASKS = "VideoFinishedTasks";
+
+ /**
+ * 将数据放入set缓存
+ */
+ public void addToSet(String key, String value, Long expiresIn) {
+ redisTemplate.opsForSet().add(key, value);
+ // 设置过期时间
+ redisTemplate.expire(key, expiresIn, TimeUnit.SECONDS);
+ }
+
+ /**
+ * 弹出变量中的元素
+ */
+ public void removeFromSet(String key, String value) {
+ redisTemplate.opsForSet().remove(key, value);
+ }
+
+ /**
+ * 检查给定的元素是否在变量中。
+ */
+ public Boolean isElementExistsInSet(String key, String obj) {
+ return redisTemplate.opsForSet().isMember(key, obj);
+ }
+
+
+ //- - - - - - - - - - - - - - - - - - - - - hash类型 - - - - - - - - - - - - - - - - - - - -
+
+ /**
+ * 加入缓存
+ */
+ public void addToMap(String key, Map map) {
+ redisTemplate.opsForHash().putAll(key, map);
+ }
+
+ /**
+ * 验证指定 key 下 有没有指定的 hashkey
+ */
+ public Boolean isElementExistsInMap(String key, String hashKey) {
+ return redisTemplate.opsForHash().hasKey(key, hashKey);
+ }
+
+ /**
+ * 获取指定key的值string
+ */
+ public String getMapValue(String key1, String key2) {
+ return String.valueOf(redisTemplate.opsForHash().get(key1, key2));
+ }
+
+ /**
+ * 删除指定 hash 的 HashKey
+ *
+ * @return 删除成功的 数量
+ */
+ public Long removeFromMap(String key, String hashKeys) {
+ return redisTemplate.opsForHash().delete(key, hashKeys);
+ }
+
+ //- - - - - - - - - - - - - - - - - - - - - String类型 - - - - - - - - - - - - - - - - - - - -
+ public void addToString(String key, String value){
+ redisTemplate.opsForValue().set(key,value);
+ }
+
+ public void addToString(String key, String value, Long expiresIn){
+ redisTemplate.opsForValue().set(key,value,expiresIn, TimeUnit.SECONDS);
+ }
+
+ public String getFromString(String key){
+ return redisTemplate.opsForValue().get(key);
+ }
+
+ public Set getKeysFromString(String key){
+ return redisTemplate.keys(key);
+ }
+
+ public Long getSize(String key){return redisTemplate.opsForSet().size(key);}
+
+ public List getMultiValue(Set keys){
+ return redisTemplate.opsForValue().multiGet(keys);
+ }
+
+ public Long getExpire(String key){
+ return redisTemplate.getExpire(key);
+ }
+
+ public void removeFromString(String key){
+ redisTemplate.delete(key);
+ }
+
+ /**
+ * 保存登录 token
+ *
+ * @param userId 用户 ID
+ * @param token token 字符串
+ * @param expireMillis 过期时间(毫秒,通常与 JWT 保持一致)
+ */
+ public void setLoginToken(Long userId, String token, long expireMillis) {
+ if (expireMillis <= 0) {
+ // 不设置过期时间,直到手动删除(不推荐)
+ addToString(LOGIN_TOKEN_KEY + userId, token);
+ return;
+ }
+ long expireSeconds = expireMillis / 1000;
+ if (expireSeconds <= 0) {
+ expireSeconds = 1;
+ }
+ addToString(LOGIN_TOKEN_KEY + userId, token, expireSeconds);
+ }
+
+ /**
+ * 获取登录 token
+ */
+ public String getLoginToken(Long userId) {
+ return getFromString(LOGIN_TOKEN_KEY + userId);
+ }
+
+ /**
+ * 删除登录 token
+ */
+ public void deleteLoginToken(Long userId) {
+ removeFromString(LOGIN_TOKEN_KEY + userId);
+ }
+
+ public final static String PORTFOLIO_LIKE_KEY = "portfolio:like:";
+
+ public void likePost(Long portfolioId, Long userId) {
+ redisTemplate.opsForSet().add(PORTFOLIO_LIKE_KEY + portfolioId, String.valueOf(userId));
+ }
+
+ public Long getLikeCount(Long portfolioId) {
+ String key = PORTFOLIO_LIKE_KEY + portfolioId;
+ return redisTemplate.opsForSet().size(key);
+ }
+
+ public List getLikedPortfolios(Long userId) {
+ // 获取所有包含PORTFOLIO_LIKE_KEY的键
+ Set likedPortfolios = redisTemplate.keys(PORTFOLIO_LIKE_KEY + "*");
+
+ // 如果没有喜欢的,返回空列表
+ if (likedPortfolios == null || likedPortfolios.isEmpty()) {
+ return new ArrayList<>();
+ }
+
+ // 过滤出包含指定用户ID的键,并提取投资组合ID
+ return likedPortfolios.stream()
+ .filter(key -> redisTemplate.opsForSet().isMember(key, String.valueOf(userId)))
+ .map(key -> Long.valueOf(key.replace(PORTFOLIO_LIKE_KEY, "")))
+ .collect(Collectors.toList());
+ }
+
+ public void unLikePost(Long portfolioId, Long userId) {
+ redisTemplate.opsForSet().remove(PORTFOLIO_LIKE_KEY + portfolioId, userId.toString());
+ }
+
+ // 检查用户是否喜欢某个作品
+ public boolean isPostLikedByUser(Long portfolioId, Long userId) {
+ String key = PORTFOLIO_LIKE_KEY + portfolioId;
+ Boolean isMember = redisTemplate.opsForSet().isMember(key, userId.toString());
+ return isMember != null && isMember;
+ }
+
+ public final static String PORTFOLIO_VIEW_KEY = "portfolio:view:";
+
+ public void increaseViewCount(Long portfolioId) {
+ String key = PORTFOLIO_VIEW_KEY + portfolioId;
+ redisTemplate.opsForValue().increment(key);
+ }
+
+ public Long getViewCount(Long portfolioId) {
+ String key = PORTFOLIO_VIEW_KEY + portfolioId;
+ return redisTemplate.opsForValue().increment(key, 0);
+ }
+
+ public Long getViewCount(String key) {
+ Object value = redisTemplate.opsForValue().get(key);
+ if (value instanceof Integer) {
+ return Long.valueOf((Integer) value);
+ } else if (value instanceof Long) {
+ return (Long) value;
+ } else if (value instanceof String) {
+ return Long.valueOf((String) value);
+ } else {
+ throw new IllegalArgumentException("Unexpected value type");
+ }
+ }
+
+ public final static String PERSONAL_HOMEPAGE_VIEW_KEY = "PersonalHomepage:view:";
+
+ public void increasePersonalHomepageViewCount(Long accountId) {
+ String key = PERSONAL_HOMEPAGE_VIEW_KEY + accountId;
+ redisTemplate.opsForValue().increment(key);
+ }
+
+ public Long getPersonalHomepageViewCount(Long accountId) {
+ String key = PERSONAL_HOMEPAGE_VIEW_KEY + accountId;
+ return redisTemplate.opsForValue().increment(key, 0);
+ }
+
+ public final static String MOODBOARD_POSITION_KEY = "moodboard:position:";
+
+ public void saveMoodboardPosition(Long id, String moodboardPosition) {
+ addToString(MOODBOARD_POSITION_KEY + id, moodboardPosition);
+ }
+
+ public String getMoodboardPosition(Long id) {
+ return getFromString(MOODBOARD_POSITION_KEY + id);
+ }
+ public final static String NICKNAME_MODIFY_TIMES = "NicknameModifyTimes:";
+ public final static String UNNAMED_PROJECT_SEQ = "Project:UnnamedProjectSeq:";
+ public Long increaseCount(String key) {
+ return redisTemplate.opsForValue().increment(key);
+ }
+
+ public Long getIncrementCount(String key) {
+ return redisTemplate.opsForValue().increment(key, 0);
+ }
+
+ public void setKeyExpire(String key, Long expire) {
+ redisTemplate.expire(key, expire, TimeUnit.DAYS);
+ }
+
+ public final static String CHANGE_MAILBOX = "ChangeMailbox:";
+
+ // 每天允许通知3次
+ public final static String UPLOAD_TIMEOUT_REMINDER_COUNTER = "UploadTimeoutReminderCounter";
+
+ public void addProcessId(String processId, int progress) {
+ // Redis 中的键,可以通过 processId 来唯一标识
+ String redisKey = "process:progress:" + processId;
+
+ // 将当前进度存储到 Redis
+ redisTemplate.opsForValue().set(redisKey, String.valueOf(progress));
+
+ // 设置过期时间为 5 分钟(300 秒)
+ redisTemplate.expire(redisKey, 5, TimeUnit.MINUTES);
+ }
+
+ public void addPathToCache(Long collectionId, Long userId, String path) {
+ // Redis 中的键,唯一标识由 collectionId 和 userId 组成
+ String redisKey = "path:cache:" + collectionId + ":" + userId;
+
+ // 增加路径的计数
+ redisTemplate.opsForHash().increment(redisKey, path, 1);
+
+ // 设置过期时间为 2 小时(7200 秒)
+ redisTemplate.expire(redisKey, 8, TimeUnit.HOURS);
+ }
+
+ public int getPathUsageCount(Long collectionId, Long userId, String path) {
+ String redisKey = "path:cache:" + collectionId + ":" + userId;
+
+ // 获取路径的使用次数
+ Object count = redisTemplate.opsForHash().get(redisKey, path);
+ return count != null ? Integer.parseInt(count.toString()) : 0;
+ }
+
+ public void addAssembledObjects(Long collectionId, Set assembledObjects) {
+ // Redis 中的键,使用 collectionId 来唯一标识
+ String redisKey = "collection:assembledObjects:" + collectionId;
+
+ // 将 assembledObjects 转换为 JSON 格式存储,避免直接存储对象
+ String assembledObjectsJson = convertToJson(assembledObjects);
+
+ // 使用 Redis 的 set 操作更新集合
+ redisTemplate.opsForValue().set(redisKey, assembledObjectsJson);
+
+ // 设置过期时间为 5 分钟(300 秒)
+ redisTemplate.expire(redisKey, 30, TimeUnit.MINUTES);
+ }
+
+ // 将 Set 转换为 JSON 格式
+ private String convertToJson(Set assembledObjects) {
+ try {
+ ObjectMapper objectMapper = new ObjectMapper();
+ return objectMapper.writeValueAsString(assembledObjects);
+ } catch (JsonProcessingException e) {
+ e.printStackTrace();
+ return null;
+ }
+ }
+
+ public Set getAssembledObjects(Long collectionId) {
+ // Redis 中的键,使用 collectionId 来唯一标识
+ String redisKey = "collection:assembledObjects:" + collectionId;
+
+ // 从 Redis 获取存储的 JSON 字符串
+ String assembledObjectsJson = (String) redisTemplate.opsForValue().get(redisKey);
+
+ if (assembledObjectsJson == null) {
+ return new HashSet<>(); // 如果没有找到数据,返回一个空的 Set
+ }
+
+ // 将 JSON 字符串转换为 Set
+ return convertFromJson(assembledObjectsJson);
+ }
+
+ // 将 JSON 字符串转换为 Set
+ private Set convertFromJson(String json) {
+ try {
+ ObjectMapper objectMapper = new ObjectMapper();
+ // 使用 TypeReference 来指定目标类型是 Set
+ return objectMapper.readValue(json, new TypeReference>() {});
+ } catch (JsonProcessingException e) {
+ e.printStackTrace();
+ return new HashSet<>(); // 如果转换失败,返回空的 Set
+ }
+ }
+
+ public final static String PAYMENT_INFO_LAST_SCAN_TIME = "PaymentInfoLastScanTime";
+
+ public final static String AFFILIATE_LINK_VIEW_KEY = "AffiliateLink:view:";
+
+ public void increaseAffiliateLinkViewCount(Long accountId) {
+ String key = AFFILIATE_LINK_VIEW_KEY + accountId;
+ redisTemplate.opsForValue().increment(key);
+ }
+
+ public Long getAffiliateLinkViewCount(Long accountId) {
+ String key = AFFILIATE_LINK_VIEW_KEY + accountId;
+ return redisTemplate.opsForValue().increment(key, 0);
+ }
+
+ /**
+ * 记录任务的耗时到Redis
+ * @param taskKey 任务标识,如 "taskA"
+ * @param elapsedTime 本次耗时,单位为毫秒
+ */
+ public void recordTaskElapsedTime(String taskKey, long elapsedTime) {
+ String hashKey = "task:stats";
+
+ // 累加总耗时
+ redisTemplate.opsForHash().increment(hashKey, taskKey + ":totalTime", elapsedTime);
+
+ // 增加计数器
+ redisTemplate.opsForHash().increment(hashKey, taskKey + ":count", 1);
+ }
+
+ /**
+ * 获取任务的平均耗时
+ * @param taskKey 任务标识,如 "taskA"
+ * @return 平均耗时(毫秒)
+ */
+ public double getTaskAverageTime(String taskKey) {
+ String hashKey = "task:stats";
+
+ // 获取总耗时和计数
+ Object totalTime = redisTemplate.opsForHash().get(hashKey, taskKey + ":totalTime");
+ Object count = redisTemplate.opsForHash().get(hashKey, taskKey + ":count");
+
+ // 计算平均值
+ if (totalTime == null || count == null) {
+ return 0;
+ }
+ return Double.parseDouble(totalTime.toString()) / Long.parseLong(count.toString());
+ }
+
+ /**
+ * 清除指定任务的统计数据
+ * @param taskKey 任务标识,如 "taskA"
+ */
+ public void clearTaskStats(String taskKey) {
+ String hashKey = "task:stats";
+
+ // 删除总耗时和计数器
+ redisTemplate.opsForHash().delete(hashKey, taskKey + ":totalTime", taskKey + ":count");
+ }
+
+ public void recordTaskElapsedTime(String taskKey, double elapsedTimeInSeconds) {
+ // 将耗时转换为 BigDecimal,并四舍五入保留四位小数
+ BigDecimal elapsedTime = new BigDecimal(elapsedTimeInSeconds).setScale(4, RoundingMode.HALF_UP);
+
+ // 累加总耗时(以毫秒为单位)
+ redisTemplate.opsForHash().increment("task:stats", taskKey + ":totalTime", elapsedTime.doubleValue());
+
+ // 增加计数器
+ redisTemplate.opsForHash().increment("task:stats", taskKey + ":count", 1);
+ }
+
+ // 获取第一部分(Sketch)耗时
+ public double getFirstSketchTime() {
+ // 获取 "firstSketchTime:totalTime" 对应的值,并返回(单位为秒)
+ Object time = redisTemplate.opsForHash().get("task:stats", "firstSketchTime:totalTime");
+ return time != null ? (double) time : 0.0;
+ }
+
+ // 获取第二部分(获取特征值)耗时
+ public double getGetAttributeRecognitionTime() {
+ // 获取 "getAttributeRecognitionTime:totalTime" 对应的值,并返回(单位为秒)
+ Object time = redisTemplate.opsForHash().get("task:stats", "getAttributeRecognitionTime:totalTime");
+ return time != null ? (double) time : 0.0;
+ }
+
+ // 获取第三部分(搭配 Sketch)耗时
+ public double getOtherSketchTime() {
+ // 获取 "otherSketchTime:totalTime" 对应的值,并返回(单位为秒)
+ Object time = redisTemplate.opsForHash().get("task:stats", "otherSketchTime:totalTime");
+ return time != null ? (double) time : 0.0;
+ }
+
+ // 清理三部分的缓存
+ public void clearTaskElapsedTimeCache() {
+ // 删除第一部分的缓存
+ redisTemplate.opsForHash().delete("task:stats", "firstSketchTime:totalTime");
+ redisTemplate.opsForHash().delete("task:stats", "firstSketchTime:count");
+
+ // 删除第二部分的缓存
+ redisTemplate.opsForHash().delete("task:stats", "getAttributeRecognitionTime:totalTime");
+ redisTemplate.opsForHash().delete("task:stats", "getAttributeRecognitionTime:count");
+
+ // 删除第三部分的缓存
+ redisTemplate.opsForHash().delete("task:stats", "otherSketchTime:totalTime");
+ redisTemplate.opsForHash().delete("task:stats", "otherSketchTime:count");
+ }
+
+ public boolean incrementLikeCount(Long userId, String sketchPath) {
+ String redisKey = "user_like_count:" + userId;
+ try {
+ redisTemplate.opsForHash().increment(redisKey, sketchPath, 1);
+ return true;
+ } catch (Exception e) {
+ log.error("Error incrementing like count for userId {} and sketchPath {}: {}", userId, sketchPath, e.getMessage());
+ return false;
+ }
+ }
+
+ public int getLikeCount(Long userId, String sketchPath) {
+ String redisKey = "user_like_count:" + userId;
+ Object count = redisTemplate.opsForHash().get(redisKey, sketchPath);
+ return count != null ? Integer.parseInt(count.toString()) : 0;
+ }
+
+ public void storeMaxLikeCount(Long userId, int maxLikeCount) {
+ String redisKey = "user_max_like_count:" + userId;
+ redisTemplate.opsForValue().set(redisKey, String.valueOf(maxLikeCount));
+ }
+
+ public int getMaxLikeCount(Long userId) {
+ String redisKey = "user_max_like_count:" + userId;
+ String maxLikeCount = redisTemplate.opsForValue().get(redisKey);
+ return maxLikeCount != null ? Integer.parseInt(maxLikeCount) : 0;
+ }
+
+ public final static String IMAGE_SEGMENTATION = "ImageSegmentation:";
+
+ public final static String STRIPE_EXCEPTION_LOG = "StripeException:";
+ public final static String SUBSCRIPTION_SENT_EMAIL_TYPE = "SubscriptionEmailSentType:";
+
+ public void batchDeleteKeysWithSamePrefix(String prefix){
+ Set keys = redisTemplate.keys(prefix + "*");
+ assert keys != null;
+ if (!keys.isEmpty()){
+ redisTemplate.delete(keys);
+ }
+ }
+
+ public void setTaskProgressDTO(String taskId, ProgressDTO dto) {
+ String key = "task:progress:" + taskId;
+ redisTemplate.opsForValue().set(key, JSON.toJSONString(dto), Duration.ofDays(1));
+ }
+
+ public ProgressDTO getTaskProgressDTO(String taskId) {
+ String key = "task:progress:" + taskId;
+ String json = redisTemplate.opsForValue().get(key);
+ if (StringUtils.isBlank(json)) {
+// return new ProgressDTO(0, 0, false);
+ return null;
+ }
+ try {
+ return JSON.parseObject(json, ProgressDTO.class);
+ } catch (Exception e) {
+ log.warn("任务进度解析失败 key={}, json={}", key, json);
+ return new ProgressDTO(0, 0, false, null);
+ }
+ }
+
+ // Lua脚本(原子化操作)
+ /*private static final String RATE_LIMIT_SCRIPT =
+ "local current = redis.call('INCR', KEYS[1])\n" +
+ "if tonumber(current) == 1 then\n" +
+ " redis.call('EXPIRE', KEYS[1], ARGV[1])\n" +
+ "end\n" +
+ "return tonumber(current) <= tonumber(ARGV[2])";*/
+ private static final String RATE_LIMIT_SCRIPT =
+ "local current = redis.call('INCR', KEYS[1])\n" +
+ "local ttl = redis.call('TTL', KEYS[1])\n" +
+ "if tonumber(current) == 1 or tonumber(ttl) == -1 then\n" +
+ " redis.call('EXPIRE', KEYS[1], ARGV[1])\n" +
+ "end\n" +
+ "return tonumber(current) <= tonumber(ARGV[2])";
+
+ /**
+ * 检查是否允许发送
+ * @param userId 用户ID
+ * @return true-允许发送,false-已超限
+ */
+ public boolean allowSend(Long userId) {
+ String hourKey = getCurrentHourKey(userId);
+
+ // 执行Lua脚本
+ List keys = Collections.singletonList(hourKey);
+ List args = Arrays.asList(
+ 3600L, // 1小时过期
+ 10L // 限制数量 一小时只能向普通用户发10封
+ );
+
+ Boolean result = redisTemplate.execute(
+ new DefaultRedisScript<>(RATE_LIMIT_SCRIPT, Boolean.class),
+ keys,
+ args.toArray()
+ );
+
+ return Boolean.TRUE.equals(result);
+ }
+
+ /**
+ * 获取当前小时的Key
+ * 格式:email_limit:{userId}:{yyyyMMddHH}
+ */
+ private String getCurrentHourKey(Long userId) {
+ String hour = LocalDateTime.now()
+ .format(DateTimeFormatter.ofPattern("yyyyMMddHH"));
+ return String.format("email_limit:%s:%s", userId, hour);
+ }
+
+ /**
+ * 获取当前已发送数量
+ */
+ public int getCurrentCount(Long userId) {
+ String key = getCurrentHourKey(userId);
+ String val = redisTemplate.opsForValue().get(key);
+ int count;
+ if (StringUtils.isBlank(val)){
+ count = 0;
+ }else {
+ count = Integer.parseInt(val);
+ }
+ return count;
+ }
+
+ public boolean allowRequest(String apiKey) {
+ String key = "rate_limit:" + apiKey;
+ ValueOperations ops = redisTemplate.opsForValue();
+
+ // 使用Redis的INCR命令
+ Long count = ops.increment(key, 1);
+
+ if (count == 1) {
+ // 第一次调用,设置过期时间
+ redisTemplate.expire(key, 1, TimeUnit.MINUTES);
+ }
+
+ return count <= 3;
+ }
+
}
\ No newline at end of file
diff --git a/src/main/java/com/ai/da/model/dto/ImageProcessRequest.java b/src/main/java/com/ai/da/model/dto/ImageProcessRequest.java
new file mode 100644
index 00000000..79380472
--- /dev/null
+++ b/src/main/java/com/ai/da/model/dto/ImageProcessRequest.java
@@ -0,0 +1,55 @@
+package com.ai.da.model.dto;
+
+import lombok.Builder;
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * 图片处理请求体
+ */
+@Data
+@Builder
+public class ImageProcessRequest {
+
+ /**
+ * OSS桶名(bucket_name)
+ */
+ private String bucket_name;
+
+ /**
+ * OSS对象名(object_name)
+ */
+ private String object_name;
+
+ /**
+ * 输入图片路径列表(input_image_paths)
+ */
+ private List input_image_paths;
+
+ /**
+ * 图像宽度(width)
+ */
+ private Integer width;
+
+ /**
+ * 图像高度(height)
+ */
+ private Integer height;
+
+ /**
+ * 文本提示(prompt)
+ */
+ private String prompt;
+
+ /**
+ * 推理步数(steps)
+ */
+ private Integer steps;
+
+ /**
+ * 引导系数(guidance)
+ */
+ private Double guidance;
+
+}
\ No newline at end of file
diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java
index b95a2804..4ba86ce3 100644
--- a/src/main/java/com/ai/da/python/PythonService.java
+++ b/src/main/java/com/ai/da/python/PythonService.java
@@ -2,8 +2,10 @@ package com.ai.da.python;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.exceptions.ExceptionUtil;
+import com.ai.da.common.RabbitMQ.RabbitMQProperties;
import com.ai.da.common.config.FileProperties;
import com.ai.da.common.config.exception.BusinessException;
+import com.ai.da.common.constant.CommonConstant;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.enums.*;
import com.ai.da.common.utils.*;
@@ -20,6 +22,7 @@ import com.ai.da.model.vo.*;
import com.ai.da.python.vo.*;
import com.ai.da.service.DesignHistoryService;
import com.ai.da.service.PythonTAllInfoService;
+import com.ai.da.service.RabbitMQService;
import com.ai.da.service.SysFileService;
import com.alibaba.fastjson.*;
import com.alibaba.fastjson.serializer.SerializerFeature;
@@ -69,6 +72,8 @@ public class PythonService {
private String accessPythonPort;
@Value("${minio.bucketName.gradient}")
private String gradientBucketName;
+ @Value("${minio.bucketName.users}")
+ private String userBucketName;
@Value("${access.python.generate_sr_port}")
private String srServicePort;
@@ -84,6 +89,12 @@ public class PythonService {
@Resource
private RedisUtil redisUtil;
+ @Resource
+ private RabbitMQService rabbitMQService;
+
+ @Resource
+ private RabbitMQProperties rabbitMQProperties;
+
/**
* 生成打印的图片 二合一 (废弃于2024/01/02)
*
@@ -3334,7 +3345,7 @@ public class PythonService {
throw new BusinessException("system error!");
}
- public Boolean generateSketchOrPrint(String params, String port, String servicePath) {
+ public Boolean generateSketchOrPrint(String params, String port, String servicePath, String taskId) {
//限流校验
// AccessLimitUtils.validate("generateSketchOrPrint", 5);
OkHttpClient client = new OkHttpClient().newBuilder()
@@ -3396,12 +3407,36 @@ public class PythonService {
if (result && jsonObject.get("code").equals(200)) {
log.info("Generate##responseObject###{}", jsonObject);
// return setGenerateImageList(jsonObject.getJSONObject("data"));
+ if (servicePath == CommonConstant.GENERATE_PATH_FLUX2_KLEIN) {
+ //放入结果到mq
+ JSONObject data = jsonObject.getJSONObject("data");
+ String outputPath = data.getString("output_path");
+
+
+ Map mqMessage = new HashMap<>();
+ mqMessage.put("tasks_id", taskId);
+ mqMessage.put("status", "SUCCESS");
+ mqMessage.put("message", "success");
+ mqMessage.put("image_url", outputPath);
+ mqMessage.put("category", "");
+ String mqMessageBody = JSON.toJSONString(mqMessage);
+ rabbitMQService.publishMessageToGenerateResult(mqMessageBody);
+ }
return Boolean.TRUE;
} else {
log.info("generateSketchOrPrintPrint失败###{}", jsonObject);
log.info("Generate Exception! Code : " + jsonObject.get("code"));
+ Map mqMessage = new HashMap<>();
+ mqMessage.put("tasks_id", taskId);
+ mqMessage.put("status", "ERROR");
+ mqMessage.put("message", "");
+ mqMessage.put("image_url", "");
+ mqMessage.put("category", "");
+ String mqMessageBody = JSON.toJSONString(mqMessage);
+ rabbitMQService.publishMessageToGenerateResult(mqMessageBody);
return Boolean.FALSE;
}
+
}
public Response sendPostToModel(String content, String portAndRoute, String functionName) {
@@ -4139,6 +4174,9 @@ public class PythonService {
.writeTimeout(60, TimeUnit.SECONDS)
.build();
MediaType mediaType = MediaType.parse("application/json");
+ content.put("bucket", userBucketName);
+ content.put("object_name", content.get("user_id") + "/" + "segment" + "/" + UUID.randomUUID() + ".png");
+ content.remove("user_id");
RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content));
String url = accessPythonIp + ":" + accessPythonPort + "/api/seg_anything";
diff --git a/src/main/java/com/ai/da/service/RabbitMQService.java b/src/main/java/com/ai/da/service/RabbitMQService.java
index 7797c905..bf18b6eb 100644
--- a/src/main/java/com/ai/da/service/RabbitMQService.java
+++ b/src/main/java/com/ai/da/service/RabbitMQService.java
@@ -7,6 +7,8 @@ public interface RabbitMQService {
void publishMessageToGenerate(String message);
+ void publishMessageToGenerateResult(String message);
+
void publishMessageToSR(String message);
Integer getMessageCount(String queueUrl);
diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java
index fa080170..d23c7b6b 100644
--- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java
+++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java
@@ -48,6 +48,8 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
import org.apache.commons.io.IOUtils;
+import org.springframework.transaction.support.TransactionSynchronization;
+import org.springframework.transaction.support.TransactionSynchronizationManager;
import org.apache.http.HttpHost;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
@@ -59,10 +61,12 @@ import org.bytedeco.javacv.Java2DFrameConverter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.dao.DuplicateKeyException;
import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StringUtils;
import jakarta.annotation.Resource;
+
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.*;
@@ -194,10 +198,13 @@ public class GenerateServiceImpl extends ServiceImpl i
generate.setText(text);
Long elementId = generateThroughImageTextDTO.getCollectionElementId();
// validateGeneraType(generate, text, elementId);
- if (!StringUtil.isNullOrEmpty(text)) {
- text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type(), generateThroughImageTextDTO.getAgeGroup());
+ if (!(generateThroughImageTextDTO.getLevel1Type().equals(MOOD_BOARD.getRealName())&&generateThroughImageTextDTO.getModelName().equals("high"))){
+ if (!StringUtil.isNullOrEmpty(text)) {
+ text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type(), generateThroughImageTextDTO.getAgeGroup());
+ }
}
+
// todo 这一步现在还是有必要的吗?
// 2.1 sketch或print在t_collection_element表/t_library表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType());
@@ -218,6 +225,8 @@ public class GenerateServiceImpl extends ServiceImpl i
version = "fast";
params.put("version", "fast");
}
+ // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
+ saveGenerateImmediately(generate);
// 3.1 确定不同类型的印花分别调哪个接口
if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())) {
switch (generateThroughImageTextDTO.getLevel2Type()) {
@@ -243,15 +252,28 @@ public class GenerateServiceImpl extends ServiceImpl i
jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue);
}
} else {
- GenerateToPythonDTO generateToPythonDTO = new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
- mode, category, generateThroughImageTextDTO.getGender(), version);
- jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue);
+ if (Objects.equals(version, "fast")) {
+ GenerateToPythonDTO generateToPythonDTO = new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
+ mode, category, generateThroughImageTextDTO.getGender(), version);
+ jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue);
+ } else {
+
+
+ path = CommonConstant.GENERATE_PATH_FLUX2_KLEIN;
+ // 构建object_name: {userId}/{category}/{uuid}.png
+ String objectName = generateThroughImageTextDTO.getUserId() + "/" + category + "/" + UUID.randomUUID() + ".png";
+
+ ImageProcessRequest imageProcessRequest = ImageProcessRequest.builder()
+ .object_name(objectName)
+ .bucket_name(userBucket)
+ .prompt(text).build();
+ jsonString = JSON.toJSONString(imageProcessRequest);
+ }
+
}
- Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, port, path);
+ Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, port, path, generateThroughImageTextDTO.getUniqueId());
- // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
- save(generate);
// 5、将本次请求存入redis
String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId();
@@ -266,6 +288,40 @@ public class GenerateServiceImpl extends ServiceImpl i
}
+ public void saveGenerateImmediately(Generate generate) {
+ save(generate);
+ // 使用 TransactionSynchronizationManager 在事务真正提交后再设锁
+ // 否则 save() 完成后事务尚未 commit,MQ 消费者立即读到 null
+ TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {
+ @Override
+ public void afterCommit() {
+ String lockKey = "generate:lock:" + generate.getUniqueId();
+ redisUtil.addToString(lockKey, "1", 60L);
+ log.debug("Save lock set after commit for uniqueId: {}", generate.getUniqueId());
+ }
+ });
+ }
+
+ private void waitForSaveLock(String uniqueId) {
+ String lockKey = "generate:lock:" + uniqueId;
+ int maxRetries = 30;
+ int retryIntervalMs = 200;
+ for (int i = 0; i < maxRetries; i++) {
+ if (Boolean.TRUE.equals(redisUtil.hasKey(lockKey))) {
+ log.debug("Save lock acquired for uniqueId: {} after {} retries", uniqueId, i);
+ return;
+ }
+ try {
+ Thread.sleep(retryIntervalMs);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ log.warn("Interrupted while waiting for save lock: {}", uniqueId);
+ return;
+ }
+ }
+ log.warn("Save lock timeout for uniqueId: {}, proceeding anyway", uniqueId);
+ }
+
public GenerateModeEnum getMode(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getText())) {
if (Objects.nonNull(generateThroughImageTextDTO.getCollectionElementId())) {
@@ -284,11 +340,16 @@ public class GenerateServiceImpl extends ServiceImpl i
@Override
@Transactional(rollbackFor = Exception.class)
public void processGenerateResult(String taskId, String url, String category) {
+ log.info("============ProcessGenerateResult listening==========");
+ log.debug("taskId: " + taskId);
+ String status = null;
// 1、处理模型返回的数据
GenerateDetail generateDetail = new GenerateDetail();
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
Generate generate;
try {
+ // 等待 HTTP 线程写入完成后再查库
+ waitForSaveLock(taskId);
generate = selectByUniqueId(taskId);
} catch (MybatisPlusException e) {
log.error(e.getMessage());
@@ -311,14 +372,15 @@ public class GenerateServiceImpl extends ServiceImpl i
generateDetail.setUrl(url);
generateDetail.setGenerateId(generate.getId());
generateDetail.setCreateDate(LocalDateTime.now());
- generateDetail.setMd5(md5);
+ generateDetail.setMd5("");
// 将相应的url保存到数据库
generateDetailMapper.insert(generateDetail);
+ log.debug("generateDetail: " + generateDetail.toString());
// String uuid = taskId.substring(0, taskId.substring(0, taskId.lastIndexOf("-")).lastIndexOf("-"));
String key = generateResultKey + ":" + taskId;
String imageName = url.substring(url.lastIndexOf("/") + 1);
- String status = imageName.equals("white_image.jpg") ? "Invalid" : "Success";
+ status = imageName.equals("white_image.jpg") ? "Invalid" : "Success";
if (StringUtil.isNullOrEmpty(category)) {
Generate generateRecord = selectByUniqueId(taskId);
category = generateRecord.getLevel2Type();
@@ -326,6 +388,8 @@ public class GenerateServiceImpl extends ServiceImpl i
GenerateResultVO generateResultVO = new GenerateResultVO(taskId, generateDetail.getId(), url, status, category);
// 更新redis
redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
+ log.debug("generateResultVO: " + generateResultVO.toString());
+
// 执行积分扣除
// ** 注:如果生成的图片都是空白 则不扣积分
@@ -785,8 +849,9 @@ public class GenerateServiceImpl extends ServiceImpl i
long requestEndTime = System.currentTimeMillis();
log.info("HTTP请求完成 - 响应状态: {}, 耗时: {}ms, taskId: {}",
response.code(), (requestEndTime - requestStartTime), taskId);
+ String result = response.body().string();
if (!response.isSuccessful()) {
- log.warn("Google API响应失败,状态码: {} for taskId: {}", response.code(), taskId);
+ log.warn("Google API响应失败,状态码: {} for taskId: {},结果:{}", response.code(), taskId, result);
if (attempt < maxRetries) {
Thread.sleep(retryDelay * attempt); // 递增延迟
continue;
@@ -795,7 +860,7 @@ public class GenerateServiceImpl extends ServiceImpl i
}
}
- String result = response.body().string();
+
// log.info("Google 响应结果:{}", result);
com.alibaba.fastjson.JSONObject jsonResponse = JSON.parseObject(result);
@@ -1065,6 +1130,12 @@ public class GenerateServiceImpl extends ServiceImpl i
String result = response.body().string();
+ if (response.code() != 200) {
+ log.error("Google API 请求失败 - taskId: {}, 尝试: {}, URL: {}, 状态码: {}, 响应结果: {}",
+ taskId, attempt, endpoint, response.code(), result);
+ throw new BusinessException("system.error");
+ }
+
// log.info("Google 响应结果:{}", result);
com.alibaba.fastjson.JSONObject jsonResponse = JSON.parseObject(result);
@@ -1203,7 +1274,7 @@ public class GenerateServiceImpl extends ServiceImpl i
* @param modelName advanced high normal
*/
private HashMap chooseModelAndPrompt(GenerateThroughImageTextDTO generateDTO, String modelName) {
- if (StringUtil.isNullOrEmpty(modelName)){
+ if (StringUtil.isNullOrEmpty(modelName)) {
throw new BusinessException("system error");
}
HashMap modelAndPromptMap = new HashMap<>();
@@ -1221,7 +1292,7 @@ public class GenerateServiceImpl extends ServiceImpl i
String style = generateDTO.getText().substring(0, firstCommaIndex).trim();
String prompt = generateDTO.getText().substring(firstCommaIndex + 1).trim();
- prompt = getPrintboardPrompt(style, prompt,modelName,isUseImage);
+ prompt = getPrintboardPrompt(style, prompt, modelName, isUseImage);
modelAndPromptMap.put(ModelConstants.PROMPT, prompt);
@@ -1260,14 +1331,31 @@ public class GenerateServiceImpl extends ServiceImpl i
modelAndPromptMap.put(ModelConstants.USE_MODEL, ModelConstants.LOCAL_MODEL);
}
} else if (ModelConstants.SKETCHBOARD.equals(generateDTO.getLevel1Type())) {
- String[] split = generateDTO.getText().split(",");
- String style = split[0].trim();
- if ("Lolita".equals( style)){
- style = "洛丽塔";
+ String style = "";
+ String userPrompt = "";
+ // 找到第一个逗号的位置
+ int firstCommaIndex = generateDTO.getText().indexOf(",");
+ if (firstCommaIndex != -1) {
+ // 截取第一个逗号前的内容作为style
+ style = generateDTO.getText().substring(0, firstCommaIndex).trim();
+ // 截取第一个逗号后的所有内容作为userPrompt(去除首尾空格)
+ userPrompt = generateDTO.getText().substring(firstCommaIndex + 1).trim();
+
+ if ("Lolita".equals(style)) {
+ style = "洛丽塔";
+ }
+ } else {
+ // 兼容无逗号的情况:style为空,全部内容作为userPrompt
+ userPrompt = generateDTO.getText().trim();
}
- String userPrompt = split[1];
+
String prompt = userPrompt + "rules:front view sketch only,plain white background, single garment only, orthographic, centered on white background, borderless canvas, thin monochrome black line art.\n" +
- " No clothes hanger, no fake clothes hanger, no human-related lines, no color fill, no words, no text, no black background, no boundary or frame.sketch style:"+ style;
+ " No clothes hanger, no fake clothes hanger, no human-related lines, no color fill, no words, no text, no black background, no boundary or frame.";
+
+ if (!style.trim().isEmpty() && !"all".equalsIgnoreCase(style)) {
+ prompt += ".sketch style:" + style.trim();
+ }
+
modelAndPromptMap.put(ModelConstants.PROMPT, prompt);
if (isUseImage) {
if (ModelConstants.ADVANCED.equals(modelName)) {
@@ -1465,6 +1553,13 @@ public class GenerateServiceImpl extends ServiceImpl i
if (imagePath != null) {
requestBuilder.image(finalImagePath1);
}
+ if (useModel.equals(ModelConstants.PRINTBOARD_HIGH_I2I)) {
+ GenerateImagesRequest.OptimizePromptOptions optimizePromptOptions = new GenerateImagesRequest.OptimizePromptOptions();
+ optimizePromptOptions.setMode("fast");
+ requestBuilder.optimizePromptOptions(optimizePromptOptions);
+ //由于PRINTBOARD_HIGH_I2I与PRINTBOARD_ADVANCED_I2I使用模型一致,为了区别积分扣除,PRINTBOARD_HIGH_I2I加入了-fast,但传入模型时需要去掉-fast,用PRINTBOARD_ADVANCED_I2I的常量做替代
+ requestBuilder.model(ModelConstants.PRINTBOARD_ADVANCED_I2I);
+ }
// 保存生成记录到数据库
Generate generate = new Generate(
@@ -1602,14 +1697,14 @@ public class GenerateServiceImpl extends ServiceImpl i
"Flat textile pattern printed directly on fabric surface, no three-dimensional objects, no items placed on cloth. \n" +
"Real style: fabric print, realistic woven/printed pattern, detailed surface pattern only";
}
- }else {
- throw new BusinessException("style error:"+ style);
+ } else {
+ throw new BusinessException("style error:" + style);
}
if (userInput == null || userInput.trim().isEmpty()) {
- if (isUseImage){
+ if (isUseImage) {
prompt = "Theme: Image content" + "\nRequirement: " + systemPrompt;
- }else {
+ } else {
throw new BusinessException("prompt null");
}
} else {
@@ -2011,7 +2106,9 @@ public class GenerateServiceImpl extends ServiceImpl i
public Generate selectByUniqueId(String uniqueId) {
QueryWrapper qw = new QueryWrapper<>();
qw.eq("unique_id", uniqueId);
-
+ log.debug("selectByUniqueId: " + uniqueId);
+ Generate one = getOne(qw);
+ log.debug("Generate: " + one);
return getOne(qw);
}
@@ -4182,11 +4279,11 @@ public class GenerateServiceImpl extends ServiceImpl i
// 处理不同状态
switch (statusEnum) {
case TASK_NOT_FOUND:
- // 审核没过
+ // 审核没过
case REQUEST_MODERATED:
- // 审核没过
+ // 审核没过
case CONTENT_MODERATED:
- // 出错
+ // 出错
case ERROR:
return "Fail";
case PENDING_F:
diff --git a/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java b/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java
index 0f0e6dc4..42e920bf 100644
--- a/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java
+++ b/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java
@@ -31,6 +31,11 @@ public class RabbitMQServiceImpl implements RabbitMQService {
mqPublisher.sendGenerateMessage(message);
}
+ @Override
+ public void publishMessageToGenerateResult(String message) {
+ mqPublisher.sendGenerateResultMessage(message);
+ }
+
@Override
public void publishMessageToSR(String message) {
mqPublisher.sendSRMessage(message);