diff --git a/pom.xml b/pom.xml index db8b8b4d..9c14ce4b 100644 --- a/pom.xml +++ b/pom.xml @@ -427,6 +427,11 @@ bcpkix-jdk18on 1.78.1 + + + org.springframework.boot + spring-boot-starter-aop + diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java index 1312ba83..833df8ca 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/MQPublisher.java @@ -28,6 +28,11 @@ public class MQPublisher { amqpTemplate.convertAndSend(rabbitMQProperties.getQueues().getSr(), mm); } + public void sendGenerateResultMessage(String mm) { + log.info("send generate result message: {}", mm); + amqpTemplate.convertAndSend(rabbitMQProperties.getQueues().getGenerateResult(), mm); + } + /** * * @param mailParams 含有的字段 diff --git a/src/main/java/com/ai/da/common/aspect/ControllerLoggingAspect.java b/src/main/java/com/ai/da/common/aspect/ControllerLoggingAspect.java new file mode 100644 index 00000000..391a0dda --- /dev/null +++ b/src/main/java/com/ai/da/common/aspect/ControllerLoggingAspect.java @@ -0,0 +1,170 @@ +package com.ai.da.common.aspect; + +import com.ai.da.common.context.UserContext; +import com.ai.da.model.vo.AuthPrincipalVo; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.aspectj.lang.JoinPoint; +import org.aspectj.lang.annotation.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; +import org.springframework.web.multipart.MultipartFile; + +import java.util.HashMap; +import java.util.Map; + +/** + * Controller日志切面 + * 记录所有Controller接口的请求参数和用户信息 + */ +@Aspect +@Component +public class ControllerLoggingAspect { + + private static final Logger logger = LoggerFactory.getLogger(ControllerLoggingAspect.class); + + /** + * 定义切点:所有Controller方法 + */ + @Pointcut("execution(* com.ai.da.controller..*(..))") + public void controllerMethods() { + } + + /** + * Controller方法执行前记录日志 + */ +// @Before("controllerMethods()") + public void logControllerBefore(JoinPoint joinPoint) { + ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + if (attributes != null) { + HttpServletRequest request = attributes.getRequest(); + + // 获取当前用户ID + Long userId = null; + AuthPrincipalVo authPrincipalVo = UserContext.getUserHolder(); + if (authPrincipalVo != null) { + userId = authPrincipalVo.getId(); + } + + // 获取请求参数 + Map params = getRequestParams(joinPoint, request); + + logger.info("=== 请求开始 ==="); + logger.info("用户ID: {}", userId); + logger.info("请求URL: {}", request.getRequestURL().toString()); + logger.info("请求方法: {}", request.getMethod()); + logger.info("请求IP: {}", getClientIpAddress(request)); + logger.info("调用方法: {}.{}", joinPoint.getSignature().getDeclaringType().getSimpleName(), joinPoint.getSignature().getName()); + logger.info("请求参数: {}", params); + } + } + + /** + * 获取请求参数 + */ + private Map getRequestParams(JoinPoint joinPoint, HttpServletRequest request) { + Map params = new HashMap<>(); + + // 1. 获取Query String参数 + String queryString = request.getQueryString(); + if (queryString != null && !queryString.isEmpty()) { + params.put("queryString", queryString); + } + + // 2. 获取方法参数(包含 @PathVariable, @RequestParam, @RequestBody 等) + Object[] args = joinPoint.getArgs(); + + if (args != null && args.length > 0) { + Map methodParams = new HashMap<>(); + for (int i = 0; i < args.length; i++) { + Object arg = args[i]; + // 过滤掉不可序列化的参数 + if (arg != null) { + if (isIgnorable(arg)) { + // 对于可忽略的类型,记录类型名 + methodParams.put("arg" + i, "[" + arg.getClass().getSimpleName() + "]"); + } else { + try { + methodParams.put("arg" + i, arg); + } catch (Exception e) { + methodParams.put("arg" + i, arg.toString()); + } + } + } + } + if (!methodParams.isEmpty()) { + params.put("methodParams", methodParams); + } + } + + return params; + } + + /** + * 判断是否需要过滤的参数类型 + */ + private boolean isIgnorable(Object obj) { + return obj instanceof HttpServletRequest + || obj instanceof HttpServletResponse + || obj instanceof MultipartFile + || obj instanceof MultipartFile[]; + } + + /** + * Controller方法抛出异常时记录日志 + */ + @AfterThrowing(pointcut = "controllerMethods()", throwing = "exception") + public void logControllerAfterThrowing(JoinPoint joinPoint, Throwable exception) { + ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + + Long userId = null; + AuthPrincipalVo authPrincipalVo = UserContext.getUserHolder(); + if (authPrincipalVo != null) { + userId = authPrincipalVo.getId(); + } + + // 获取请求参数 + Map params = new HashMap<>(); + if (attributes != null) { + HttpServletRequest request = attributes.getRequest(); + params = getRequestParams(joinPoint, request); + } + + logger.error("=== 请求异常 ==="); + logger.error("用户ID: {}", userId); + logger.error("调用方法: {}.{}", joinPoint.getSignature().getDeclaringType().getSimpleName(), joinPoint.getSignature().getName()); + logger.error("请求参数: {}", params); + logger.error("异常信息: ", exception); + logger.error("=== 异常结束 ==="); + } + + /** + * 获取客户端真实IP地址 + */ + private String getClientIpAddress(HttpServletRequest request) { + String xForwardedFor = request.getHeader("X-Forwarded-For"); + if (xForwardedFor != null && !xForwardedFor.isEmpty() && !"unknown".equalsIgnoreCase(xForwardedFor)) { + return xForwardedFor.split(",")[0]; + } + + String xRealIp = request.getHeader("X-Real-IP"); + if (xRealIp != null && !xRealIp.isEmpty() && !"unknown".equalsIgnoreCase(xRealIp)) { + return xRealIp; + } + + String proxyClientIp = request.getHeader("Proxy-Client-IP"); + if (proxyClientIp != null && !proxyClientIp.isEmpty() && !"unknown".equalsIgnoreCase(proxyClientIp)) { + return proxyClientIp; + } + + String wlProxyClientIp = request.getHeader("WL-Proxy-Client-IP"); + if (wlProxyClientIp != null && !wlProxyClientIp.isEmpty() && !"unknown".equalsIgnoreCase(wlProxyClientIp)) { + return wlProxyClientIp; + } + + return request.getRemoteAddr(); + } +} diff --git a/src/main/java/com/ai/da/common/config/MyTaskScheduler.java b/src/main/java/com/ai/da/common/config/MyTaskScheduler.java index be6bd63b..f7c10bcc 100644 --- a/src/main/java/com/ai/da/common/config/MyTaskScheduler.java +++ b/src/main/java/com/ai/da/common/config/MyTaskScheduler.java @@ -202,7 +202,7 @@ public class MyTaskScheduler { } } -// @Scheduled(cron = "0 0 9 * * ?") + // @Scheduled(cron = "0 0 9 * * ?") public void sendTrialOrderExcelToManagements() { // 获取前一天日期 LocalDate yesterday = LocalDate.now().minusDays(1); diff --git a/src/main/java/com/ai/da/common/constant/CommonConstant.java b/src/main/java/com/ai/da/common/constant/CommonConstant.java index 4fb6ec16..9f7ac8bc 100644 --- a/src/main/java/com/ai/da/common/constant/CommonConstant.java +++ b/src/main/java/com/ai/da/common/constant/CommonConstant.java @@ -23,6 +23,7 @@ public class CommonConstant { } public static final String GENERATE_PATH = "/api/generate_image"; + public static final String GENERATE_PATH_FLUX2_KLEIN = "/api/generate_image_flux2_klein"; public static final String GENERATE_SINGLE_LOGO = "/api/generate_single_logo"; diff --git a/src/main/java/com/ai/da/common/constant/ModelConstants.java b/src/main/java/com/ai/da/common/constant/ModelConstants.java index bc0b1745..fab503bb 100644 --- a/src/main/java/com/ai/da/common/constant/ModelConstants.java +++ b/src/main/java/com/ai/da/common/constant/ModelConstants.java @@ -20,7 +20,7 @@ public class ModelConstants { public static final String PRINTBOARD_ADVANCED_T2I = "qwen-image"; public static final String MOODBOARD_ADVANCED = "doubao-seedream-3-0-t2i-250415"; public static final String PRINTBOARD_HIGH_T2I = "doubao-seedream-3-0-t2i-250415"; - public static final String PRINTBOARD_HIGH_I2I = "doubao-seededit-3-0-i2i-250628"; + public static final String PRINTBOARD_HIGH_I2I = "doubao-seedream-4-0-250828-fast"; public static final String PRINTBOARD_ADVANCED_I2I = "doubao-seedream-4-0-250828"; public static final String IMAGEN_MODEL = "imagen-4.0-generate-001"; public static final String NANO_BANANA = "gemini-2.5-flash-image"; diff --git a/src/main/java/com/ai/da/common/task/AccountTask.java b/src/main/java/com/ai/da/common/task/AccountTask.java index ee573eaa..e0d90480 100644 --- a/src/main/java/com/ai/da/common/task/AccountTask.java +++ b/src/main/java/com/ai/da/common/task/AccountTask.java @@ -34,7 +34,7 @@ public class AccountTask { accountService.refreshCreditsMonthly(); } -// @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes + // @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes public void getPaidUser() { // 获取code-create 表中 指定日期之后 订单状态为wc-processing的订单 accountService.extendValidityForCC(); diff --git a/src/main/java/com/ai/da/common/task/PaymentTask.java b/src/main/java/com/ai/da/common/task/PaymentTask.java index 52a6e2d6..e63e9388 100644 --- a/src/main/java/com/ai/da/common/task/PaymentTask.java +++ b/src/main/java/com/ai/da/common/task/PaymentTask.java @@ -45,7 +45,7 @@ public class PaymentTask { @Resource private PayPalCheckoutService payPalCheckoutService; -// @Scheduled(cron = "0/30 * * * * ?") + // @Scheduled(cron = "0/30 * * * * ?") public void orderConfirmForPaypal() throws SerializeException { // log.info("PayPal orderConfirm 被执行......"); @@ -97,7 +97,7 @@ public class PaymentTask { // } -// @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes + // @Scheduled(cron = "0 */5 * * * *") // Run every 5 minutes public void updateAffiliateInfoWithPayment(){ // log.info("佣金计算定时器"); affiliateService.updateAffiliateInfoWithPayment(); @@ -109,7 +109,7 @@ public class PaymentTask { affiliateService.syncLinkViewCountToDB(); } -// @Scheduled(cron = "0 0 8 28-31 * ?") + // @Scheduled(cron = "0 0 8 28-31 * ?") public void commissionSummaryReminder(){ // 每个月末的最后一天的早上八点执行 LocalDate today = LocalDate.now(); diff --git a/src/main/java/com/ai/da/common/task/SubscriptionReminderTask.java b/src/main/java/com/ai/da/common/task/SubscriptionReminderTask.java index 0c6ada03..26478270 100644 --- a/src/main/java/com/ai/da/common/task/SubscriptionReminderTask.java +++ b/src/main/java/com/ai/da/common/task/SubscriptionReminderTask.java @@ -40,7 +40,7 @@ public class SubscriptionReminderTask { REMINDER_DAYS_CONFIG.put("year", 14); } -// @Scheduled(cron = "0 0 9 * * ?") + // @Scheduled(cron = "0 0 9 * * ?") public void subscriptionReminder() { // 获取所有需要通知的订阅 List subscriptionInfos = getDueSubscriptions(); @@ -97,7 +97,7 @@ public class SubscriptionReminderTask { return subscriptionInfoMapper.selectList(qw); } -// @Scheduled(cron = "0 0 9 * * ?") + // @Scheduled(cron = "0 0 9 * * ?") public void trialReminder() { // 今天的 00:00:00 和 23:59:59 LocalDateTime startOfDay = LocalDateTime.now().toLocalDate().atStartOfDay(); diff --git a/src/main/java/com/ai/da/common/utils/RedisUtil.java b/src/main/java/com/ai/da/common/utils/RedisUtil.java index 1264a1e4..c91411fa 100644 --- a/src/main/java/com/ai/da/common/utils/RedisUtil.java +++ b/src/main/java/com/ai/da/common/utils/RedisUtil.java @@ -1,659 +1,659 @@ -package com.ai.da.common.utils; - -import com.ai.da.model.dto.ProgressDTO; -import com.ai.da.python.vo.DesignPythonObject; -import com.alibaba.fastjson.JSON; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.netty.util.internal.StringUtil; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.data.redis.core.RedisTemplate; -import org.springframework.data.redis.core.ValueOperations; -import org.springframework.data.redis.core.ZSetOperations; -import org.springframework.data.redis.core.script.DefaultRedisScript; -import org.springframework.stereotype.Component; -import org.springframework.util.CollectionUtils; - -import jakarta.annotation.Resource; -import java.math.BigDecimal; -import java.math.RoundingMode; -import java.time.Duration; -import java.time.LocalDateTime; -import java.time.format.DateTimeFormatter; -import java.util.*; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; - -@Slf4j -@Component -public class RedisUtil { - - @Resource - private RedisTemplate redisTemplate; - - public final static String FLUX_POLLING_URL = "Flux:"; - /** - * 登录 token 在 Redis 中的前缀: - * 最终 key 结构为 login:token:{userId} - */ - public final static String LOGIN_TOKEN_KEY = "login:token:"; - - public Boolean hasKey(String key){ - return redisTemplate.hasKey(key); - } - - //- - - - - - - - - - - - - - - - - - - - - ZSet类型 - - - - - - - - - - - - - - - - - - - - - - /** - * 向ZSet中添加元素 - */ - public void addToZSet(String key, String value, Double score) { - redisTemplate.opsForZSet().add(key, value, score); - } - - /** - * 从ZSet中删除元素 - */ - public void removeFromZSet(String key, String value) { - redisTemplate.opsForZSet().remove(key, value); - } - - /** - * 获取指定元素的当前排列顺序 - */ - public Long getRank(String key, String value) { - return redisTemplate.opsForZSet().rank(key, value); - } - - /** - * 获取当前ZSet中的最大score - */ - public Double getMaxScore(String key) { - Set> set = redisTemplate.opsForZSet().reverseRangeWithScores(key, 0, 0); - - if (!CollectionUtils.isEmpty(set)) { - Double score = set.iterator().next().getScore(); - return score + 1.0; - } else { - return 1.0; - } - } - - /** - * 判断元素是否存在 - */ - public Boolean isElementExistsInZSet(String key, String value) { - return redisTemplate.opsForZSet().score(key, value) != null; - } - - /** - * 获取当前ZSet中数据量的总和 - */ - public Long getZSetTotalCount(String key) { - return redisTemplate.opsForZSet().zCard(key); - } - - - public Set getZSetTotalData(String key){ - return redisTemplate.opsForZSet().range(key, 0, -1); - } - - //- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - - - - public final static String VIDEO_FINISHED_TASKS = "VideoFinishedTasks"; - - /** - * 将数据放入set缓存 - */ - public void addToSet(String key, String value, Long expiresIn) { - redisTemplate.opsForSet().add(key, value); - // 设置过期时间 - redisTemplate.expire(key, expiresIn, TimeUnit.SECONDS); - } - - /** - * 弹出变量中的元素 - */ - public void removeFromSet(String key, String value) { - redisTemplate.opsForSet().remove(key, value); - } - - /** - * 检查给定的元素是否在变量中。 - */ - public Boolean isElementExistsInSet(String key, String obj) { - return redisTemplate.opsForSet().isMember(key, obj); - } - - - //- - - - - - - - - - - - - - - - - - - - - hash类型 - - - - - - - - - - - - - - - - - - - - - - /** - * 加入缓存 - */ - public void addToMap(String key, Map map) { - redisTemplate.opsForHash().putAll(key, map); - } - - /** - * 验证指定 key 下 有没有指定的 hashkey - */ - public Boolean isElementExistsInMap(String key, String hashKey) { - return redisTemplate.opsForHash().hasKey(key, hashKey); - } - - /** - * 获取指定key的值string - */ - public String getMapValue(String key1, String key2) { - return String.valueOf(redisTemplate.opsForHash().get(key1, key2)); - } - - /** - * 删除指定 hash 的 HashKey - * - * @return 删除成功的 数量 - */ - public Long removeFromMap(String key, String hashKeys) { - return redisTemplate.opsForHash().delete(key, hashKeys); - } - - //- - - - - - - - - - - - - - - - - - - - - String类型 - - - - - - - - - - - - - - - - - - - - - public void addToString(String key, String value){ - redisTemplate.opsForValue().set(key,value); - } - - public void addToString(String key, String value, Long expiresIn){ - redisTemplate.opsForValue().set(key,value,expiresIn, TimeUnit.SECONDS); - } - - public String getFromString(String key){ - return redisTemplate.opsForValue().get(key); - } - - public Set getKeysFromString(String key){ - return redisTemplate.keys(key); - } - - public Long getSize(String key){return redisTemplate.opsForSet().size(key);} - - public List getMultiValue(Set keys){ - return redisTemplate.opsForValue().multiGet(keys); - } - - public Long getExpire(String key){ - return redisTemplate.getExpire(key); - } - - public void removeFromString(String key){ - redisTemplate.delete(key); - } - - /** - * 保存登录 token - * - * @param userId 用户 ID - * @param token token 字符串 - * @param expireMillis 过期时间(毫秒,通常与 JWT 保持一致) - */ - public void setLoginToken(Long userId, String token, long expireMillis) { - if (expireMillis <= 0) { - // 不设置过期时间,直到手动删除(不推荐) - addToString(LOGIN_TOKEN_KEY + userId, token); - return; - } - long expireSeconds = expireMillis / 1000; - if (expireSeconds <= 0) { - expireSeconds = 1; - } - addToString(LOGIN_TOKEN_KEY + userId, token, expireSeconds); - } - - /** - * 获取登录 token - */ - public String getLoginToken(Long userId) { - return getFromString(LOGIN_TOKEN_KEY + userId); - } - - /** - * 删除登录 token - */ - public void deleteLoginToken(Long userId) { - removeFromString(LOGIN_TOKEN_KEY + userId); - } - - public final static String PORTFOLIO_LIKE_KEY = "portfolio:like:"; - - public void likePost(Long portfolioId, Long userId) { - redisTemplate.opsForSet().add(PORTFOLIO_LIKE_KEY + portfolioId, String.valueOf(userId)); - } - - public Long getLikeCount(Long portfolioId) { - String key = PORTFOLIO_LIKE_KEY + portfolioId; - return redisTemplate.opsForSet().size(key); - } - - public List getLikedPortfolios(Long userId) { - // 获取所有包含PORTFOLIO_LIKE_KEY的键 - Set likedPortfolios = redisTemplate.keys(PORTFOLIO_LIKE_KEY + "*"); - - // 如果没有喜欢的,返回空列表 - if (likedPortfolios == null || likedPortfolios.isEmpty()) { - return new ArrayList<>(); - } - - // 过滤出包含指定用户ID的键,并提取投资组合ID - return likedPortfolios.stream() - .filter(key -> redisTemplate.opsForSet().isMember(key, String.valueOf(userId))) - .map(key -> Long.valueOf(key.replace(PORTFOLIO_LIKE_KEY, ""))) - .collect(Collectors.toList()); - } - - public void unLikePost(Long portfolioId, Long userId) { - redisTemplate.opsForSet().remove(PORTFOLIO_LIKE_KEY + portfolioId, userId.toString()); - } - - // 检查用户是否喜欢某个作品 - public boolean isPostLikedByUser(Long portfolioId, Long userId) { - String key = PORTFOLIO_LIKE_KEY + portfolioId; - Boolean isMember = redisTemplate.opsForSet().isMember(key, userId.toString()); - return isMember != null && isMember; - } - - public final static String PORTFOLIO_VIEW_KEY = "portfolio:view:"; - - public void increaseViewCount(Long portfolioId) { - String key = PORTFOLIO_VIEW_KEY + portfolioId; - redisTemplate.opsForValue().increment(key); - } - - public Long getViewCount(Long portfolioId) { - String key = PORTFOLIO_VIEW_KEY + portfolioId; - return redisTemplate.opsForValue().increment(key, 0); - } - - public Long getViewCount(String key) { - Object value = redisTemplate.opsForValue().get(key); - if (value instanceof Integer) { - return Long.valueOf((Integer) value); - } else if (value instanceof Long) { - return (Long) value; - } else if (value instanceof String) { - return Long.valueOf((String) value); - } else { - throw new IllegalArgumentException("Unexpected value type"); - } - } - - public final static String PERSONAL_HOMEPAGE_VIEW_KEY = "PersonalHomepage:view:"; - - public void increasePersonalHomepageViewCount(Long accountId) { - String key = PERSONAL_HOMEPAGE_VIEW_KEY + accountId; - redisTemplate.opsForValue().increment(key); - } - - public Long getPersonalHomepageViewCount(Long accountId) { - String key = PERSONAL_HOMEPAGE_VIEW_KEY + accountId; - return redisTemplate.opsForValue().increment(key, 0); - } - - public final static String MOODBOARD_POSITION_KEY = "moodboard:position:"; - - public void saveMoodboardPosition(Long id, String moodboardPosition) { - addToString(MOODBOARD_POSITION_KEY + id, moodboardPosition); - } - - public String getMoodboardPosition(Long id) { - return getFromString(MOODBOARD_POSITION_KEY + id); - } - public final static String NICKNAME_MODIFY_TIMES = "NicknameModifyTimes:"; - public final static String UNNAMED_PROJECT_SEQ = "Project:UnnamedProjectSeq:"; - public Long increaseCount(String key) { - return redisTemplate.opsForValue().increment(key); - } - - public Long getIncrementCount(String key) { - return redisTemplate.opsForValue().increment(key, 0); - } - - public void setKeyExpire(String key, Long expire) { - redisTemplate.expire(key, expire, TimeUnit.DAYS); - } - - public final static String CHANGE_MAILBOX = "ChangeMailbox:"; - - // 每天允许通知3次 - public final static String UPLOAD_TIMEOUT_REMINDER_COUNTER = "UploadTimeoutReminderCounter"; - - public void addProcessId(String processId, int progress) { - // Redis 中的键,可以通过 processId 来唯一标识 - String redisKey = "process:progress:" + processId; - - // 将当前进度存储到 Redis - redisTemplate.opsForValue().set(redisKey, String.valueOf(progress)); - - // 设置过期时间为 5 分钟(300 秒) - redisTemplate.expire(redisKey, 5, TimeUnit.MINUTES); - } - - public void addPathToCache(Long collectionId, Long userId, String path) { - // Redis 中的键,唯一标识由 collectionId 和 userId 组成 - String redisKey = "path:cache:" + collectionId + ":" + userId; - - // 增加路径的计数 - redisTemplate.opsForHash().increment(redisKey, path, 1); - - // 设置过期时间为 2 小时(7200 秒) - redisTemplate.expire(redisKey, 8, TimeUnit.HOURS); - } - - public int getPathUsageCount(Long collectionId, Long userId, String path) { - String redisKey = "path:cache:" + collectionId + ":" + userId; - - // 获取路径的使用次数 - Object count = redisTemplate.opsForHash().get(redisKey, path); - return count != null ? Integer.parseInt(count.toString()) : 0; - } - - public void addAssembledObjects(Long collectionId, Set assembledObjects) { - // Redis 中的键,使用 collectionId 来唯一标识 - String redisKey = "collection:assembledObjects:" + collectionId; - - // 将 assembledObjects 转换为 JSON 格式存储,避免直接存储对象 - String assembledObjectsJson = convertToJson(assembledObjects); - - // 使用 Redis 的 set 操作更新集合 - redisTemplate.opsForValue().set(redisKey, assembledObjectsJson); - - // 设置过期时间为 5 分钟(300 秒) - redisTemplate.expire(redisKey, 30, TimeUnit.MINUTES); - } - - // 将 Set 转换为 JSON 格式 - private String convertToJson(Set assembledObjects) { - try { - ObjectMapper objectMapper = new ObjectMapper(); - return objectMapper.writeValueAsString(assembledObjects); - } catch (JsonProcessingException e) { - e.printStackTrace(); - return null; - } - } - - public Set getAssembledObjects(Long collectionId) { - // Redis 中的键,使用 collectionId 来唯一标识 - String redisKey = "collection:assembledObjects:" + collectionId; - - // 从 Redis 获取存储的 JSON 字符串 - String assembledObjectsJson = (String) redisTemplate.opsForValue().get(redisKey); - - if (assembledObjectsJson == null) { - return new HashSet<>(); // 如果没有找到数据,返回一个空的 Set - } - - // 将 JSON 字符串转换为 Set - return convertFromJson(assembledObjectsJson); - } - - // 将 JSON 字符串转换为 Set - private Set convertFromJson(String json) { - try { - ObjectMapper objectMapper = new ObjectMapper(); - // 使用 TypeReference 来指定目标类型是 Set - return objectMapper.readValue(json, new TypeReference>() {}); - } catch (JsonProcessingException e) { - e.printStackTrace(); - return new HashSet<>(); // 如果转换失败,返回空的 Set - } - } - - public final static String PAYMENT_INFO_LAST_SCAN_TIME = "PaymentInfoLastScanTime"; - - public final static String AFFILIATE_LINK_VIEW_KEY = "AffiliateLink:view:"; - - public void increaseAffiliateLinkViewCount(Long accountId) { - String key = AFFILIATE_LINK_VIEW_KEY + accountId; - redisTemplate.opsForValue().increment(key); - } - - public Long getAffiliateLinkViewCount(Long accountId) { - String key = AFFILIATE_LINK_VIEW_KEY + accountId; - return redisTemplate.opsForValue().increment(key, 0); - } - - /** - * 记录任务的耗时到Redis - * @param taskKey 任务标识,如 "taskA" - * @param elapsedTime 本次耗时,单位为毫秒 - */ - public void recordTaskElapsedTime(String taskKey, long elapsedTime) { - String hashKey = "task:stats"; - - // 累加总耗时 - redisTemplate.opsForHash().increment(hashKey, taskKey + ":totalTime", elapsedTime); - - // 增加计数器 - redisTemplate.opsForHash().increment(hashKey, taskKey + ":count", 1); - } - - /** - * 获取任务的平均耗时 - * @param taskKey 任务标识,如 "taskA" - * @return 平均耗时(毫秒) - */ - public double getTaskAverageTime(String taskKey) { - String hashKey = "task:stats"; - - // 获取总耗时和计数 - Object totalTime = redisTemplate.opsForHash().get(hashKey, taskKey + ":totalTime"); - Object count = redisTemplate.opsForHash().get(hashKey, taskKey + ":count"); - - // 计算平均值 - if (totalTime == null || count == null) { - return 0; - } - return Double.parseDouble(totalTime.toString()) / Long.parseLong(count.toString()); - } - - /** - * 清除指定任务的统计数据 - * @param taskKey 任务标识,如 "taskA" - */ - public void clearTaskStats(String taskKey) { - String hashKey = "task:stats"; - - // 删除总耗时和计数器 - redisTemplate.opsForHash().delete(hashKey, taskKey + ":totalTime", taskKey + ":count"); - } - - public void recordTaskElapsedTime(String taskKey, double elapsedTimeInSeconds) { - // 将耗时转换为 BigDecimal,并四舍五入保留四位小数 - BigDecimal elapsedTime = new BigDecimal(elapsedTimeInSeconds).setScale(4, RoundingMode.HALF_UP); - - // 累加总耗时(以毫秒为单位) - redisTemplate.opsForHash().increment("task:stats", taskKey + ":totalTime", elapsedTime.doubleValue()); - - // 增加计数器 - redisTemplate.opsForHash().increment("task:stats", taskKey + ":count", 1); - } - - // 获取第一部分(Sketch)耗时 - public double getFirstSketchTime() { - // 获取 "firstSketchTime:totalTime" 对应的值,并返回(单位为秒) - Object time = redisTemplate.opsForHash().get("task:stats", "firstSketchTime:totalTime"); - return time != null ? (double) time : 0.0; - } - - // 获取第二部分(获取特征值)耗时 - public double getGetAttributeRecognitionTime() { - // 获取 "getAttributeRecognitionTime:totalTime" 对应的值,并返回(单位为秒) - Object time = redisTemplate.opsForHash().get("task:stats", "getAttributeRecognitionTime:totalTime"); - return time != null ? (double) time : 0.0; - } - - // 获取第三部分(搭配 Sketch)耗时 - public double getOtherSketchTime() { - // 获取 "otherSketchTime:totalTime" 对应的值,并返回(单位为秒) - Object time = redisTemplate.opsForHash().get("task:stats", "otherSketchTime:totalTime"); - return time != null ? (double) time : 0.0; - } - - // 清理三部分的缓存 - public void clearTaskElapsedTimeCache() { - // 删除第一部分的缓存 - redisTemplate.opsForHash().delete("task:stats", "firstSketchTime:totalTime"); - redisTemplate.opsForHash().delete("task:stats", "firstSketchTime:count"); - - // 删除第二部分的缓存 - redisTemplate.opsForHash().delete("task:stats", "getAttributeRecognitionTime:totalTime"); - redisTemplate.opsForHash().delete("task:stats", "getAttributeRecognitionTime:count"); - - // 删除第三部分的缓存 - redisTemplate.opsForHash().delete("task:stats", "otherSketchTime:totalTime"); - redisTemplate.opsForHash().delete("task:stats", "otherSketchTime:count"); - } - - public boolean incrementLikeCount(Long userId, String sketchPath) { - String redisKey = "user_like_count:" + userId; - try { - redisTemplate.opsForHash().increment(redisKey, sketchPath, 1); - return true; - } catch (Exception e) { - log.error("Error incrementing like count for userId {} and sketchPath {}: {}", userId, sketchPath, e.getMessage()); - return false; - } - } - - public int getLikeCount(Long userId, String sketchPath) { - String redisKey = "user_like_count:" + userId; - Object count = redisTemplate.opsForHash().get(redisKey, sketchPath); - return count != null ? Integer.parseInt(count.toString()) : 0; - } - - public void storeMaxLikeCount(Long userId, int maxLikeCount) { - String redisKey = "user_max_like_count:" + userId; - redisTemplate.opsForValue().set(redisKey, String.valueOf(maxLikeCount)); - } - - public int getMaxLikeCount(Long userId) { - String redisKey = "user_max_like_count:" + userId; - String maxLikeCount = redisTemplate.opsForValue().get(redisKey); - return maxLikeCount != null ? Integer.parseInt(maxLikeCount) : 0; - } - - public final static String IMAGE_SEGMENTATION = "ImageSegmentation:"; - - public final static String STRIPE_EXCEPTION_LOG = "StripeException:"; - public final static String SUBSCRIPTION_SENT_EMAIL_TYPE = "SubscriptionEmailSentType:"; - - public void batchDeleteKeysWithSamePrefix(String prefix){ - Set keys = redisTemplate.keys(prefix + "*"); - assert keys != null; - if (!keys.isEmpty()){ - redisTemplate.delete(keys); - } - } - - public void setTaskProgressDTO(String taskId, ProgressDTO dto) { - String key = "task:progress:" + taskId; - redisTemplate.opsForValue().set(key, JSON.toJSONString(dto), Duration.ofDays(1)); - } - - public ProgressDTO getTaskProgressDTO(String taskId) { - String key = "task:progress:" + taskId; - String json = redisTemplate.opsForValue().get(key); - if (StringUtils.isBlank(json)) { -// return new ProgressDTO(0, 0, false); - return null; - } - try { - return JSON.parseObject(json, ProgressDTO.class); - } catch (Exception e) { - log.warn("任务进度解析失败 key={}, json={}", key, json); - return new ProgressDTO(0, 0, false, null); - } - } - - // Lua脚本(原子化操作) - /*private static final String RATE_LIMIT_SCRIPT = - "local current = redis.call('INCR', KEYS[1])\n" + - "if tonumber(current) == 1 then\n" + - " redis.call('EXPIRE', KEYS[1], ARGV[1])\n" + - "end\n" + - "return tonumber(current) <= tonumber(ARGV[2])";*/ - private static final String RATE_LIMIT_SCRIPT = - "local current = redis.call('INCR', KEYS[1])\n" + - "local ttl = redis.call('TTL', KEYS[1])\n" + - "if tonumber(current) == 1 or tonumber(ttl) == -1 then\n" + - " redis.call('EXPIRE', KEYS[1], ARGV[1])\n" + - "end\n" + - "return tonumber(current) <= tonumber(ARGV[2])"; - - /** - * 检查是否允许发送 - * @param userId 用户ID - * @return true-允许发送,false-已超限 - */ - public boolean allowSend(Long userId) { - String hourKey = getCurrentHourKey(userId); - - // 执行Lua脚本 - List keys = Collections.singletonList(hourKey); - List args = Arrays.asList( - 3600L, // 1小时过期 - 10L // 限制数量 一小时只能向普通用户发10封 - ); - - Boolean result = redisTemplate.execute( - new DefaultRedisScript<>(RATE_LIMIT_SCRIPT, Boolean.class), - keys, - args.toArray() - ); - - return Boolean.TRUE.equals(result); - } - - /** - * 获取当前小时的Key - * 格式:email_limit:{userId}:{yyyyMMddHH} - */ - private String getCurrentHourKey(Long userId) { - String hour = LocalDateTime.now() - .format(DateTimeFormatter.ofPattern("yyyyMMddHH")); - return String.format("email_limit:%s:%s", userId, hour); - } - - /** - * 获取当前已发送数量 - */ - public int getCurrentCount(Long userId) { - String key = getCurrentHourKey(userId); - String val = redisTemplate.opsForValue().get(key); - int count; - if (StringUtils.isBlank(val)){ - count = 0; - }else { - count = Integer.parseInt(val); - } - return count; - } - - public boolean allowRequest(String apiKey) { - String key = "rate_limit:" + apiKey; - ValueOperations ops = redisTemplate.opsForValue(); - - // 使用Redis的INCR命令 - Long count = ops.increment(key, 1); - - if (count == 1) { - // 第一次调用,设置过期时间 - redisTemplate.expire(key, 1, TimeUnit.MINUTES); - } - - return count <= 3; - } - +package com.ai.da.common.utils; + +import com.ai.da.model.dto.ProgressDTO; +import com.ai.da.python.vo.DesignPythonObject; +import com.alibaba.fastjson.JSON; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.netty.util.internal.StringUtil; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.core.ValueOperations; +import org.springframework.data.redis.core.ZSetOperations; +import org.springframework.data.redis.core.script.DefaultRedisScript; +import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; + +import jakarta.annotation.Resource; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.time.Duration; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.*; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +@Slf4j +@Component +public class RedisUtil { + + @Resource + private RedisTemplate redisTemplate; + + public final static String FLUX_POLLING_URL = "Flux:"; + /** + * 登录 token 在 Redis 中的前缀: + * 最终 key 结构为 login:token:{userId} + */ + public final static String LOGIN_TOKEN_KEY = "login:token:"; + + public Boolean hasKey(String key){ + return redisTemplate.hasKey(key); + } + + //- - - - - - - - - - - - - - - - - - - - - ZSet类型 - - - - - - - - - - - - - - - - - - - - + + /** + * 向ZSet中添加元素 + */ + public void addToZSet(String key, String value, Double score) { + redisTemplate.opsForZSet().add(key, value, score); + } + + /** + * 从ZSet中删除元素 + */ + public void removeFromZSet(String key, String value) { + redisTemplate.opsForZSet().remove(key, value); + } + + /** + * 获取指定元素的当前排列顺序 + */ + public Long getRank(String key, String value) { + return redisTemplate.opsForZSet().rank(key, value); + } + + /** + * 获取当前ZSet中的最大score + */ + public Double getMaxScore(String key) { + Set> set = redisTemplate.opsForZSet().reverseRangeWithScores(key, 0, 0); + + if (!CollectionUtils.isEmpty(set)) { + Double score = set.iterator().next().getScore(); + return score + 1.0; + } else { + return 1.0; + } + } + + /** + * 判断元素是否存在 + */ + public Boolean isElementExistsInZSet(String key, String value) { + return redisTemplate.opsForZSet().score(key, value) != null; + } + + /** + * 获取当前ZSet中数据量的总和 + */ + public Long getZSetTotalCount(String key) { + return redisTemplate.opsForZSet().zCard(key); + } + + + public Set getZSetTotalData(String key){ + return redisTemplate.opsForZSet().range(key, 0, -1); + } + + //- - - - - - - - - - - - - - - - - - - - - set类型 - - - - - - - - - - - - - - - - - - - - + + public final static String VIDEO_FINISHED_TASKS = "VideoFinishedTasks"; + + /** + * 将数据放入set缓存 + */ + public void addToSet(String key, String value, Long expiresIn) { + redisTemplate.opsForSet().add(key, value); + // 设置过期时间 + redisTemplate.expire(key, expiresIn, TimeUnit.SECONDS); + } + + /** + * 弹出变量中的元素 + */ + public void removeFromSet(String key, String value) { + redisTemplate.opsForSet().remove(key, value); + } + + /** + * 检查给定的元素是否在变量中。 + */ + public Boolean isElementExistsInSet(String key, String obj) { + return redisTemplate.opsForSet().isMember(key, obj); + } + + + //- - - - - - - - - - - - - - - - - - - - - hash类型 - - - - - - - - - - - - - - - - - - - - + + /** + * 加入缓存 + */ + public void addToMap(String key, Map map) { + redisTemplate.opsForHash().putAll(key, map); + } + + /** + * 验证指定 key 下 有没有指定的 hashkey + */ + public Boolean isElementExistsInMap(String key, String hashKey) { + return redisTemplate.opsForHash().hasKey(key, hashKey); + } + + /** + * 获取指定key的值string + */ + public String getMapValue(String key1, String key2) { + return String.valueOf(redisTemplate.opsForHash().get(key1, key2)); + } + + /** + * 删除指定 hash 的 HashKey + * + * @return 删除成功的 数量 + */ + public Long removeFromMap(String key, String hashKeys) { + return redisTemplate.opsForHash().delete(key, hashKeys); + } + + //- - - - - - - - - - - - - - - - - - - - - String类型 - - - - - - - - - - - - - - - - - - - - + public void addToString(String key, String value){ + redisTemplate.opsForValue().set(key,value); + } + + public void addToString(String key, String value, Long expiresIn){ + redisTemplate.opsForValue().set(key,value,expiresIn, TimeUnit.SECONDS); + } + + public String getFromString(String key){ + return redisTemplate.opsForValue().get(key); + } + + public Set getKeysFromString(String key){ + return redisTemplate.keys(key); + } + + public Long getSize(String key){return redisTemplate.opsForSet().size(key);} + + public List getMultiValue(Set keys){ + return redisTemplate.opsForValue().multiGet(keys); + } + + public Long getExpire(String key){ + return redisTemplate.getExpire(key); + } + + public void removeFromString(String key){ + redisTemplate.delete(key); + } + + /** + * 保存登录 token + * + * @param userId 用户 ID + * @param token token 字符串 + * @param expireMillis 过期时间(毫秒,通常与 JWT 保持一致) + */ + public void setLoginToken(Long userId, String token, long expireMillis) { + if (expireMillis <= 0) { + // 不设置过期时间,直到手动删除(不推荐) + addToString(LOGIN_TOKEN_KEY + userId, token); + return; + } + long expireSeconds = expireMillis / 1000; + if (expireSeconds <= 0) { + expireSeconds = 1; + } + addToString(LOGIN_TOKEN_KEY + userId, token, expireSeconds); + } + + /** + * 获取登录 token + */ + public String getLoginToken(Long userId) { + return getFromString(LOGIN_TOKEN_KEY + userId); + } + + /** + * 删除登录 token + */ + public void deleteLoginToken(Long userId) { + removeFromString(LOGIN_TOKEN_KEY + userId); + } + + public final static String PORTFOLIO_LIKE_KEY = "portfolio:like:"; + + public void likePost(Long portfolioId, Long userId) { + redisTemplate.opsForSet().add(PORTFOLIO_LIKE_KEY + portfolioId, String.valueOf(userId)); + } + + public Long getLikeCount(Long portfolioId) { + String key = PORTFOLIO_LIKE_KEY + portfolioId; + return redisTemplate.opsForSet().size(key); + } + + public List getLikedPortfolios(Long userId) { + // 获取所有包含PORTFOLIO_LIKE_KEY的键 + Set likedPortfolios = redisTemplate.keys(PORTFOLIO_LIKE_KEY + "*"); + + // 如果没有喜欢的,返回空列表 + if (likedPortfolios == null || likedPortfolios.isEmpty()) { + return new ArrayList<>(); + } + + // 过滤出包含指定用户ID的键,并提取投资组合ID + return likedPortfolios.stream() + .filter(key -> redisTemplate.opsForSet().isMember(key, String.valueOf(userId))) + .map(key -> Long.valueOf(key.replace(PORTFOLIO_LIKE_KEY, ""))) + .collect(Collectors.toList()); + } + + public void unLikePost(Long portfolioId, Long userId) { + redisTemplate.opsForSet().remove(PORTFOLIO_LIKE_KEY + portfolioId, userId.toString()); + } + + // 检查用户是否喜欢某个作品 + public boolean isPostLikedByUser(Long portfolioId, Long userId) { + String key = PORTFOLIO_LIKE_KEY + portfolioId; + Boolean isMember = redisTemplate.opsForSet().isMember(key, userId.toString()); + return isMember != null && isMember; + } + + public final static String PORTFOLIO_VIEW_KEY = "portfolio:view:"; + + public void increaseViewCount(Long portfolioId) { + String key = PORTFOLIO_VIEW_KEY + portfolioId; + redisTemplate.opsForValue().increment(key); + } + + public Long getViewCount(Long portfolioId) { + String key = PORTFOLIO_VIEW_KEY + portfolioId; + return redisTemplate.opsForValue().increment(key, 0); + } + + public Long getViewCount(String key) { + Object value = redisTemplate.opsForValue().get(key); + if (value instanceof Integer) { + return Long.valueOf((Integer) value); + } else if (value instanceof Long) { + return (Long) value; + } else if (value instanceof String) { + return Long.valueOf((String) value); + } else { + throw new IllegalArgumentException("Unexpected value type"); + } + } + + public final static String PERSONAL_HOMEPAGE_VIEW_KEY = "PersonalHomepage:view:"; + + public void increasePersonalHomepageViewCount(Long accountId) { + String key = PERSONAL_HOMEPAGE_VIEW_KEY + accountId; + redisTemplate.opsForValue().increment(key); + } + + public Long getPersonalHomepageViewCount(Long accountId) { + String key = PERSONAL_HOMEPAGE_VIEW_KEY + accountId; + return redisTemplate.opsForValue().increment(key, 0); + } + + public final static String MOODBOARD_POSITION_KEY = "moodboard:position:"; + + public void saveMoodboardPosition(Long id, String moodboardPosition) { + addToString(MOODBOARD_POSITION_KEY + id, moodboardPosition); + } + + public String getMoodboardPosition(Long id) { + return getFromString(MOODBOARD_POSITION_KEY + id); + } + public final static String NICKNAME_MODIFY_TIMES = "NicknameModifyTimes:"; + public final static String UNNAMED_PROJECT_SEQ = "Project:UnnamedProjectSeq:"; + public Long increaseCount(String key) { + return redisTemplate.opsForValue().increment(key); + } + + public Long getIncrementCount(String key) { + return redisTemplate.opsForValue().increment(key, 0); + } + + public void setKeyExpire(String key, Long expire) { + redisTemplate.expire(key, expire, TimeUnit.DAYS); + } + + public final static String CHANGE_MAILBOX = "ChangeMailbox:"; + + // 每天允许通知3次 + public final static String UPLOAD_TIMEOUT_REMINDER_COUNTER = "UploadTimeoutReminderCounter"; + + public void addProcessId(String processId, int progress) { + // Redis 中的键,可以通过 processId 来唯一标识 + String redisKey = "process:progress:" + processId; + + // 将当前进度存储到 Redis + redisTemplate.opsForValue().set(redisKey, String.valueOf(progress)); + + // 设置过期时间为 5 分钟(300 秒) + redisTemplate.expire(redisKey, 5, TimeUnit.MINUTES); + } + + public void addPathToCache(Long collectionId, Long userId, String path) { + // Redis 中的键,唯一标识由 collectionId 和 userId 组成 + String redisKey = "path:cache:" + collectionId + ":" + userId; + + // 增加路径的计数 + redisTemplate.opsForHash().increment(redisKey, path, 1); + + // 设置过期时间为 2 小时(7200 秒) + redisTemplate.expire(redisKey, 8, TimeUnit.HOURS); + } + + public int getPathUsageCount(Long collectionId, Long userId, String path) { + String redisKey = "path:cache:" + collectionId + ":" + userId; + + // 获取路径的使用次数 + Object count = redisTemplate.opsForHash().get(redisKey, path); + return count != null ? Integer.parseInt(count.toString()) : 0; + } + + public void addAssembledObjects(Long collectionId, Set assembledObjects) { + // Redis 中的键,使用 collectionId 来唯一标识 + String redisKey = "collection:assembledObjects:" + collectionId; + + // 将 assembledObjects 转换为 JSON 格式存储,避免直接存储对象 + String assembledObjectsJson = convertToJson(assembledObjects); + + // 使用 Redis 的 set 操作更新集合 + redisTemplate.opsForValue().set(redisKey, assembledObjectsJson); + + // 设置过期时间为 5 分钟(300 秒) + redisTemplate.expire(redisKey, 30, TimeUnit.MINUTES); + } + + // 将 Set 转换为 JSON 格式 + private String convertToJson(Set assembledObjects) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + return objectMapper.writeValueAsString(assembledObjects); + } catch (JsonProcessingException e) { + e.printStackTrace(); + return null; + } + } + + public Set getAssembledObjects(Long collectionId) { + // Redis 中的键,使用 collectionId 来唯一标识 + String redisKey = "collection:assembledObjects:" + collectionId; + + // 从 Redis 获取存储的 JSON 字符串 + String assembledObjectsJson = (String) redisTemplate.opsForValue().get(redisKey); + + if (assembledObjectsJson == null) { + return new HashSet<>(); // 如果没有找到数据,返回一个空的 Set + } + + // 将 JSON 字符串转换为 Set + return convertFromJson(assembledObjectsJson); + } + + // 将 JSON 字符串转换为 Set + private Set convertFromJson(String json) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + // 使用 TypeReference 来指定目标类型是 Set + return objectMapper.readValue(json, new TypeReference>() {}); + } catch (JsonProcessingException e) { + e.printStackTrace(); + return new HashSet<>(); // 如果转换失败,返回空的 Set + } + } + + public final static String PAYMENT_INFO_LAST_SCAN_TIME = "PaymentInfoLastScanTime"; + + public final static String AFFILIATE_LINK_VIEW_KEY = "AffiliateLink:view:"; + + public void increaseAffiliateLinkViewCount(Long accountId) { + String key = AFFILIATE_LINK_VIEW_KEY + accountId; + redisTemplate.opsForValue().increment(key); + } + + public Long getAffiliateLinkViewCount(Long accountId) { + String key = AFFILIATE_LINK_VIEW_KEY + accountId; + return redisTemplate.opsForValue().increment(key, 0); + } + + /** + * 记录任务的耗时到Redis + * @param taskKey 任务标识,如 "taskA" + * @param elapsedTime 本次耗时,单位为毫秒 + */ + public void recordTaskElapsedTime(String taskKey, long elapsedTime) { + String hashKey = "task:stats"; + + // 累加总耗时 + redisTemplate.opsForHash().increment(hashKey, taskKey + ":totalTime", elapsedTime); + + // 增加计数器 + redisTemplate.opsForHash().increment(hashKey, taskKey + ":count", 1); + } + + /** + * 获取任务的平均耗时 + * @param taskKey 任务标识,如 "taskA" + * @return 平均耗时(毫秒) + */ + public double getTaskAverageTime(String taskKey) { + String hashKey = "task:stats"; + + // 获取总耗时和计数 + Object totalTime = redisTemplate.opsForHash().get(hashKey, taskKey + ":totalTime"); + Object count = redisTemplate.opsForHash().get(hashKey, taskKey + ":count"); + + // 计算平均值 + if (totalTime == null || count == null) { + return 0; + } + return Double.parseDouble(totalTime.toString()) / Long.parseLong(count.toString()); + } + + /** + * 清除指定任务的统计数据 + * @param taskKey 任务标识,如 "taskA" + */ + public void clearTaskStats(String taskKey) { + String hashKey = "task:stats"; + + // 删除总耗时和计数器 + redisTemplate.opsForHash().delete(hashKey, taskKey + ":totalTime", taskKey + ":count"); + } + + public void recordTaskElapsedTime(String taskKey, double elapsedTimeInSeconds) { + // 将耗时转换为 BigDecimal,并四舍五入保留四位小数 + BigDecimal elapsedTime = new BigDecimal(elapsedTimeInSeconds).setScale(4, RoundingMode.HALF_UP); + + // 累加总耗时(以毫秒为单位) + redisTemplate.opsForHash().increment("task:stats", taskKey + ":totalTime", elapsedTime.doubleValue()); + + // 增加计数器 + redisTemplate.opsForHash().increment("task:stats", taskKey + ":count", 1); + } + + // 获取第一部分(Sketch)耗时 + public double getFirstSketchTime() { + // 获取 "firstSketchTime:totalTime" 对应的值,并返回(单位为秒) + Object time = redisTemplate.opsForHash().get("task:stats", "firstSketchTime:totalTime"); + return time != null ? (double) time : 0.0; + } + + // 获取第二部分(获取特征值)耗时 + public double getGetAttributeRecognitionTime() { + // 获取 "getAttributeRecognitionTime:totalTime" 对应的值,并返回(单位为秒) + Object time = redisTemplate.opsForHash().get("task:stats", "getAttributeRecognitionTime:totalTime"); + return time != null ? (double) time : 0.0; + } + + // 获取第三部分(搭配 Sketch)耗时 + public double getOtherSketchTime() { + // 获取 "otherSketchTime:totalTime" 对应的值,并返回(单位为秒) + Object time = redisTemplate.opsForHash().get("task:stats", "otherSketchTime:totalTime"); + return time != null ? (double) time : 0.0; + } + + // 清理三部分的缓存 + public void clearTaskElapsedTimeCache() { + // 删除第一部分的缓存 + redisTemplate.opsForHash().delete("task:stats", "firstSketchTime:totalTime"); + redisTemplate.opsForHash().delete("task:stats", "firstSketchTime:count"); + + // 删除第二部分的缓存 + redisTemplate.opsForHash().delete("task:stats", "getAttributeRecognitionTime:totalTime"); + redisTemplate.opsForHash().delete("task:stats", "getAttributeRecognitionTime:count"); + + // 删除第三部分的缓存 + redisTemplate.opsForHash().delete("task:stats", "otherSketchTime:totalTime"); + redisTemplate.opsForHash().delete("task:stats", "otherSketchTime:count"); + } + + public boolean incrementLikeCount(Long userId, String sketchPath) { + String redisKey = "user_like_count:" + userId; + try { + redisTemplate.opsForHash().increment(redisKey, sketchPath, 1); + return true; + } catch (Exception e) { + log.error("Error incrementing like count for userId {} and sketchPath {}: {}", userId, sketchPath, e.getMessage()); + return false; + } + } + + public int getLikeCount(Long userId, String sketchPath) { + String redisKey = "user_like_count:" + userId; + Object count = redisTemplate.opsForHash().get(redisKey, sketchPath); + return count != null ? Integer.parseInt(count.toString()) : 0; + } + + public void storeMaxLikeCount(Long userId, int maxLikeCount) { + String redisKey = "user_max_like_count:" + userId; + redisTemplate.opsForValue().set(redisKey, String.valueOf(maxLikeCount)); + } + + public int getMaxLikeCount(Long userId) { + String redisKey = "user_max_like_count:" + userId; + String maxLikeCount = redisTemplate.opsForValue().get(redisKey); + return maxLikeCount != null ? Integer.parseInt(maxLikeCount) : 0; + } + + public final static String IMAGE_SEGMENTATION = "ImageSegmentation:"; + + public final static String STRIPE_EXCEPTION_LOG = "StripeException:"; + public final static String SUBSCRIPTION_SENT_EMAIL_TYPE = "SubscriptionEmailSentType:"; + + public void batchDeleteKeysWithSamePrefix(String prefix){ + Set keys = redisTemplate.keys(prefix + "*"); + assert keys != null; + if (!keys.isEmpty()){ + redisTemplate.delete(keys); + } + } + + public void setTaskProgressDTO(String taskId, ProgressDTO dto) { + String key = "task:progress:" + taskId; + redisTemplate.opsForValue().set(key, JSON.toJSONString(dto), Duration.ofDays(1)); + } + + public ProgressDTO getTaskProgressDTO(String taskId) { + String key = "task:progress:" + taskId; + String json = redisTemplate.opsForValue().get(key); + if (StringUtils.isBlank(json)) { +// return new ProgressDTO(0, 0, false); + return null; + } + try { + return JSON.parseObject(json, ProgressDTO.class); + } catch (Exception e) { + log.warn("任务进度解析失败 key={}, json={}", key, json); + return new ProgressDTO(0, 0, false, null); + } + } + + // Lua脚本(原子化操作) + /*private static final String RATE_LIMIT_SCRIPT = + "local current = redis.call('INCR', KEYS[1])\n" + + "if tonumber(current) == 1 then\n" + + " redis.call('EXPIRE', KEYS[1], ARGV[1])\n" + + "end\n" + + "return tonumber(current) <= tonumber(ARGV[2])";*/ + private static final String RATE_LIMIT_SCRIPT = + "local current = redis.call('INCR', KEYS[1])\n" + + "local ttl = redis.call('TTL', KEYS[1])\n" + + "if tonumber(current) == 1 or tonumber(ttl) == -1 then\n" + + " redis.call('EXPIRE', KEYS[1], ARGV[1])\n" + + "end\n" + + "return tonumber(current) <= tonumber(ARGV[2])"; + + /** + * 检查是否允许发送 + * @param userId 用户ID + * @return true-允许发送,false-已超限 + */ + public boolean allowSend(Long userId) { + String hourKey = getCurrentHourKey(userId); + + // 执行Lua脚本 + List keys = Collections.singletonList(hourKey); + List args = Arrays.asList( + 3600L, // 1小时过期 + 10L // 限制数量 一小时只能向普通用户发10封 + ); + + Boolean result = redisTemplate.execute( + new DefaultRedisScript<>(RATE_LIMIT_SCRIPT, Boolean.class), + keys, + args.toArray() + ); + + return Boolean.TRUE.equals(result); + } + + /** + * 获取当前小时的Key + * 格式:email_limit:{userId}:{yyyyMMddHH} + */ + private String getCurrentHourKey(Long userId) { + String hour = LocalDateTime.now() + .format(DateTimeFormatter.ofPattern("yyyyMMddHH")); + return String.format("email_limit:%s:%s", userId, hour); + } + + /** + * 获取当前已发送数量 + */ + public int getCurrentCount(Long userId) { + String key = getCurrentHourKey(userId); + String val = redisTemplate.opsForValue().get(key); + int count; + if (StringUtils.isBlank(val)){ + count = 0; + }else { + count = Integer.parseInt(val); + } + return count; + } + + public boolean allowRequest(String apiKey) { + String key = "rate_limit:" + apiKey; + ValueOperations ops = redisTemplate.opsForValue(); + + // 使用Redis的INCR命令 + Long count = ops.increment(key, 1); + + if (count == 1) { + // 第一次调用,设置过期时间 + redisTemplate.expire(key, 1, TimeUnit.MINUTES); + } + + return count <= 3; + } + } \ No newline at end of file diff --git a/src/main/java/com/ai/da/model/dto/ImageProcessRequest.java b/src/main/java/com/ai/da/model/dto/ImageProcessRequest.java new file mode 100644 index 00000000..79380472 --- /dev/null +++ b/src/main/java/com/ai/da/model/dto/ImageProcessRequest.java @@ -0,0 +1,55 @@ +package com.ai.da.model.dto; + +import lombok.Builder; +import lombok.Data; + +import java.util.List; + +/** + * 图片处理请求体 + */ +@Data +@Builder +public class ImageProcessRequest { + + /** + * OSS桶名(bucket_name) + */ + private String bucket_name; + + /** + * OSS对象名(object_name) + */ + private String object_name; + + /** + * 输入图片路径列表(input_image_paths) + */ + private List input_image_paths; + + /** + * 图像宽度(width) + */ + private Integer width; + + /** + * 图像高度(height) + */ + private Integer height; + + /** + * 文本提示(prompt) + */ + private String prompt; + + /** + * 推理步数(steps) + */ + private Integer steps; + + /** + * 引导系数(guidance) + */ + private Double guidance; + +} \ No newline at end of file diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index b95a2804..4ba86ce3 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -2,8 +2,10 @@ package com.ai.da.python; import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.exceptions.ExceptionUtil; +import com.ai.da.common.RabbitMQ.RabbitMQProperties; import com.ai.da.common.config.FileProperties; import com.ai.da.common.config.exception.BusinessException; +import com.ai.da.common.constant.CommonConstant; import com.ai.da.common.context.UserContext; import com.ai.da.common.enums.*; import com.ai.da.common.utils.*; @@ -20,6 +22,7 @@ import com.ai.da.model.vo.*; import com.ai.da.python.vo.*; import com.ai.da.service.DesignHistoryService; import com.ai.da.service.PythonTAllInfoService; +import com.ai.da.service.RabbitMQService; import com.ai.da.service.SysFileService; import com.alibaba.fastjson.*; import com.alibaba.fastjson.serializer.SerializerFeature; @@ -69,6 +72,8 @@ public class PythonService { private String accessPythonPort; @Value("${minio.bucketName.gradient}") private String gradientBucketName; + @Value("${minio.bucketName.users}") + private String userBucketName; @Value("${access.python.generate_sr_port}") private String srServicePort; @@ -84,6 +89,12 @@ public class PythonService { @Resource private RedisUtil redisUtil; + @Resource + private RabbitMQService rabbitMQService; + + @Resource + private RabbitMQProperties rabbitMQProperties; + /** * 生成打印的图片 二合一 (废弃于2024/01/02) * @@ -3334,7 +3345,7 @@ public class PythonService { throw new BusinessException("system error!"); } - public Boolean generateSketchOrPrint(String params, String port, String servicePath) { + public Boolean generateSketchOrPrint(String params, String port, String servicePath, String taskId) { //限流校验 // AccessLimitUtils.validate("generateSketchOrPrint", 5); OkHttpClient client = new OkHttpClient().newBuilder() @@ -3396,12 +3407,36 @@ public class PythonService { if (result && jsonObject.get("code").equals(200)) { log.info("Generate##responseObject###{}", jsonObject); // return setGenerateImageList(jsonObject.getJSONObject("data")); + if (servicePath == CommonConstant.GENERATE_PATH_FLUX2_KLEIN) { + //放入结果到mq + JSONObject data = jsonObject.getJSONObject("data"); + String outputPath = data.getString("output_path"); + + + Map mqMessage = new HashMap<>(); + mqMessage.put("tasks_id", taskId); + mqMessage.put("status", "SUCCESS"); + mqMessage.put("message", "success"); + mqMessage.put("image_url", outputPath); + mqMessage.put("category", ""); + String mqMessageBody = JSON.toJSONString(mqMessage); + rabbitMQService.publishMessageToGenerateResult(mqMessageBody); + } return Boolean.TRUE; } else { log.info("generateSketchOrPrintPrint失败###{}", jsonObject); log.info("Generate Exception! Code : " + jsonObject.get("code")); + Map mqMessage = new HashMap<>(); + mqMessage.put("tasks_id", taskId); + mqMessage.put("status", "ERROR"); + mqMessage.put("message", ""); + mqMessage.put("image_url", ""); + mqMessage.put("category", ""); + String mqMessageBody = JSON.toJSONString(mqMessage); + rabbitMQService.publishMessageToGenerateResult(mqMessageBody); return Boolean.FALSE; } + } public Response sendPostToModel(String content, String portAndRoute, String functionName) { @@ -4139,6 +4174,9 @@ public class PythonService { .writeTimeout(60, TimeUnit.SECONDS) .build(); MediaType mediaType = MediaType.parse("application/json"); + content.put("bucket", userBucketName); + content.put("object_name", content.get("user_id") + "/" + "segment" + "/" + UUID.randomUUID() + ".png"); + content.remove("user_id"); RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content)); String url = accessPythonIp + ":" + accessPythonPort + "/api/seg_anything"; diff --git a/src/main/java/com/ai/da/service/RabbitMQService.java b/src/main/java/com/ai/da/service/RabbitMQService.java index 7797c905..bf18b6eb 100644 --- a/src/main/java/com/ai/da/service/RabbitMQService.java +++ b/src/main/java/com/ai/da/service/RabbitMQService.java @@ -7,6 +7,8 @@ public interface RabbitMQService { void publishMessageToGenerate(String message); + void publishMessageToGenerateResult(String message); + void publishMessageToSR(String message); Integer getMessageCount(String queueUrl); diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index fa080170..d23c7b6b 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -48,6 +48,8 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import okhttp3.*; import org.apache.commons.io.IOUtils; +import org.springframework.transaction.support.TransactionSynchronization; +import org.springframework.transaction.support.TransactionSynchronizationManager; import org.apache.http.HttpHost; import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.CloseableHttpResponse; @@ -59,10 +61,12 @@ import org.bytedeco.javacv.Java2DFrameConverter; import org.springframework.beans.factory.annotation.Value; import org.springframework.dao.DuplicateKeyException; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Propagation; import org.springframework.transaction.annotation.Transactional; import org.springframework.util.StringUtils; import jakarta.annotation.Resource; + import javax.imageio.ImageIO; import java.awt.image.BufferedImage; import java.io.*; @@ -194,10 +198,13 @@ public class GenerateServiceImpl extends ServiceImpl i generate.setText(text); Long elementId = generateThroughImageTextDTO.getCollectionElementId(); // validateGeneraType(generate, text, elementId); - if (!StringUtil.isNullOrEmpty(text)) { - text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type(), generateThroughImageTextDTO.getAgeGroup()); + if (!(generateThroughImageTextDTO.getLevel1Type().equals(MOOD_BOARD.getRealName())&&generateThroughImageTextDTO.getModelName().equals("high"))){ + if (!StringUtil.isNullOrEmpty(text)) { + text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type(), generateThroughImageTextDTO.getAgeGroup()); + } } + // todo 这一步现在还是有必要的吗? // 2.1 sketch或print在t_collection_element表/t_library表中的信息是否需要更新 如 level2Type CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type(), generateThroughImageTextDTO.getDesignType()); @@ -218,6 +225,8 @@ public class GenerateServiceImpl extends ServiceImpl i version = "fast"; params.put("version", "fast"); } + // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中 + saveGenerateImmediately(generate); // 3.1 确定不同类型的印花分别调哪个接口 if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())) { switch (generateThroughImageTextDTO.getLevel2Type()) { @@ -243,15 +252,28 @@ public class GenerateServiceImpl extends ServiceImpl i jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue); } } else { - GenerateToPythonDTO generateToPythonDTO = new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), - mode, category, generateThroughImageTextDTO.getGender(), version); - jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue); + if (Objects.equals(version, "fast")) { + GenerateToPythonDTO generateToPythonDTO = new GenerateToPythonDTO(generateThroughImageTextDTO.getUniqueId(), text, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(), + mode, category, generateThroughImageTextDTO.getGender(), version); + jsonString = JSON.toJSONString(generateToPythonDTO, SerializerFeature.WriteMapNullValue); + } else { + + + path = CommonConstant.GENERATE_PATH_FLUX2_KLEIN; + // 构建object_name: {userId}/{category}/{uuid}.png + String objectName = generateThroughImageTextDTO.getUserId() + "/" + category + "/" + UUID.randomUUID() + ".png"; + + ImageProcessRequest imageProcessRequest = ImageProcessRequest.builder() + .object_name(objectName) + .bucket_name(userBucket) + .prompt(text).build(); + jsonString = JSON.toJSONString(imageProcessRequest); + } + } - Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, port, path); + Boolean requestResult = pythonService.generateSketchOrPrint(jsonString, port, path, generateThroughImageTextDTO.getUniqueId()); - // 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中 - save(generate); // 5、将本次请求存入redis String key = generateResultKey + ":" + generateThroughImageTextDTO.getUniqueId(); @@ -266,6 +288,40 @@ public class GenerateServiceImpl extends ServiceImpl i } + public void saveGenerateImmediately(Generate generate) { + save(generate); + // 使用 TransactionSynchronizationManager 在事务真正提交后再设锁 + // 否则 save() 完成后事务尚未 commit,MQ 消费者立即读到 null + TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() { + @Override + public void afterCommit() { + String lockKey = "generate:lock:" + generate.getUniqueId(); + redisUtil.addToString(lockKey, "1", 60L); + log.debug("Save lock set after commit for uniqueId: {}", generate.getUniqueId()); + } + }); + } + + private void waitForSaveLock(String uniqueId) { + String lockKey = "generate:lock:" + uniqueId; + int maxRetries = 30; + int retryIntervalMs = 200; + for (int i = 0; i < maxRetries; i++) { + if (Boolean.TRUE.equals(redisUtil.hasKey(lockKey))) { + log.debug("Save lock acquired for uniqueId: {} after {} retries", uniqueId, i); + return; + } + try { + Thread.sleep(retryIntervalMs); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + log.warn("Interrupted while waiting for save lock: {}", uniqueId); + return; + } + } + log.warn("Save lock timeout for uniqueId: {}, proceeding anyway", uniqueId); + } + public GenerateModeEnum getMode(GenerateThroughImageTextDTO generateThroughImageTextDTO) { if (!StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getText())) { if (Objects.nonNull(generateThroughImageTextDTO.getCollectionElementId())) { @@ -284,11 +340,16 @@ public class GenerateServiceImpl extends ServiceImpl i @Override @Transactional(rollbackFor = Exception.class) public void processGenerateResult(String taskId, String url, String category) { + log.info("============ProcessGenerateResult listening=========="); + log.debug("taskId: " + taskId); + String status = null; // 1、处理模型返回的数据 GenerateDetail generateDetail = new GenerateDetail(); GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); Generate generate; try { + // 等待 HTTP 线程写入完成后再查库 + waitForSaveLock(taskId); generate = selectByUniqueId(taskId); } catch (MybatisPlusException e) { log.error(e.getMessage()); @@ -311,14 +372,15 @@ public class GenerateServiceImpl extends ServiceImpl i generateDetail.setUrl(url); generateDetail.setGenerateId(generate.getId()); generateDetail.setCreateDate(LocalDateTime.now()); - generateDetail.setMd5(md5); + generateDetail.setMd5(""); // 将相应的url保存到数据库 generateDetailMapper.insert(generateDetail); + log.debug("generateDetail: " + generateDetail.toString()); // String uuid = taskId.substring(0, taskId.substring(0, taskId.lastIndexOf("-")).lastIndexOf("-")); String key = generateResultKey + ":" + taskId; String imageName = url.substring(url.lastIndexOf("/") + 1); - String status = imageName.equals("white_image.jpg") ? "Invalid" : "Success"; + status = imageName.equals("white_image.jpg") ? "Invalid" : "Success"; if (StringUtil.isNullOrEmpty(category)) { Generate generateRecord = selectByUniqueId(taskId); category = generateRecord.getLevel2Type(); @@ -326,6 +388,8 @@ public class GenerateServiceImpl extends ServiceImpl i GenerateResultVO generateResultVO = new GenerateResultVO(taskId, generateDetail.getId(), url, status, category); // 更新redis redisUtil.addToString(key, new Gson().toJson(generateResultVO), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + log.debug("generateResultVO: " + generateResultVO.toString()); + // 执行积分扣除 // ** 注:如果生成的图片都是空白 则不扣积分 @@ -785,8 +849,9 @@ public class GenerateServiceImpl extends ServiceImpl i long requestEndTime = System.currentTimeMillis(); log.info("HTTP请求完成 - 响应状态: {}, 耗时: {}ms, taskId: {}", response.code(), (requestEndTime - requestStartTime), taskId); + String result = response.body().string(); if (!response.isSuccessful()) { - log.warn("Google API响应失败,状态码: {} for taskId: {}", response.code(), taskId); + log.warn("Google API响应失败,状态码: {} for taskId: {},结果:{}", response.code(), taskId, result); if (attempt < maxRetries) { Thread.sleep(retryDelay * attempt); // 递增延迟 continue; @@ -795,7 +860,7 @@ public class GenerateServiceImpl extends ServiceImpl i } } - String result = response.body().string(); + // log.info("Google 响应结果:{}", result); com.alibaba.fastjson.JSONObject jsonResponse = JSON.parseObject(result); @@ -1065,6 +1130,12 @@ public class GenerateServiceImpl extends ServiceImpl i String result = response.body().string(); + if (response.code() != 200) { + log.error("Google API 请求失败 - taskId: {}, 尝试: {}, URL: {}, 状态码: {}, 响应结果: {}", + taskId, attempt, endpoint, response.code(), result); + throw new BusinessException("system.error"); + } + // log.info("Google 响应结果:{}", result); com.alibaba.fastjson.JSONObject jsonResponse = JSON.parseObject(result); @@ -1203,7 +1274,7 @@ public class GenerateServiceImpl extends ServiceImpl i * @param modelName advanced high normal */ private HashMap chooseModelAndPrompt(GenerateThroughImageTextDTO generateDTO, String modelName) { - if (StringUtil.isNullOrEmpty(modelName)){ + if (StringUtil.isNullOrEmpty(modelName)) { throw new BusinessException("system error"); } HashMap modelAndPromptMap = new HashMap<>(); @@ -1221,7 +1292,7 @@ public class GenerateServiceImpl extends ServiceImpl i String style = generateDTO.getText().substring(0, firstCommaIndex).trim(); String prompt = generateDTO.getText().substring(firstCommaIndex + 1).trim(); - prompt = getPrintboardPrompt(style, prompt,modelName,isUseImage); + prompt = getPrintboardPrompt(style, prompt, modelName, isUseImage); modelAndPromptMap.put(ModelConstants.PROMPT, prompt); @@ -1260,14 +1331,31 @@ public class GenerateServiceImpl extends ServiceImpl i modelAndPromptMap.put(ModelConstants.USE_MODEL, ModelConstants.LOCAL_MODEL); } } else if (ModelConstants.SKETCHBOARD.equals(generateDTO.getLevel1Type())) { - String[] split = generateDTO.getText().split(","); - String style = split[0].trim(); - if ("Lolita".equals( style)){ - style = "洛丽塔"; + String style = ""; + String userPrompt = ""; + // 找到第一个逗号的位置 + int firstCommaIndex = generateDTO.getText().indexOf(","); + if (firstCommaIndex != -1) { + // 截取第一个逗号前的内容作为style + style = generateDTO.getText().substring(0, firstCommaIndex).trim(); + // 截取第一个逗号后的所有内容作为userPrompt(去除首尾空格) + userPrompt = generateDTO.getText().substring(firstCommaIndex + 1).trim(); + + if ("Lolita".equals(style)) { + style = "洛丽塔"; + } + } else { + // 兼容无逗号的情况:style为空,全部内容作为userPrompt + userPrompt = generateDTO.getText().trim(); } - String userPrompt = split[1]; + String prompt = userPrompt + "rules:front view sketch only,plain white background, single garment only, orthographic, centered on white background, borderless canvas, thin monochrome black line art.\n" + - " No clothes hanger, no fake clothes hanger, no human-related lines, no color fill, no words, no text, no black background, no boundary or frame.sketch style:"+ style; + " No clothes hanger, no fake clothes hanger, no human-related lines, no color fill, no words, no text, no black background, no boundary or frame."; + + if (!style.trim().isEmpty() && !"all".equalsIgnoreCase(style)) { + prompt += ".sketch style:" + style.trim(); + } + modelAndPromptMap.put(ModelConstants.PROMPT, prompt); if (isUseImage) { if (ModelConstants.ADVANCED.equals(modelName)) { @@ -1465,6 +1553,13 @@ public class GenerateServiceImpl extends ServiceImpl i if (imagePath != null) { requestBuilder.image(finalImagePath1); } + if (useModel.equals(ModelConstants.PRINTBOARD_HIGH_I2I)) { + GenerateImagesRequest.OptimizePromptOptions optimizePromptOptions = new GenerateImagesRequest.OptimizePromptOptions(); + optimizePromptOptions.setMode("fast"); + requestBuilder.optimizePromptOptions(optimizePromptOptions); + //由于PRINTBOARD_HIGH_I2I与PRINTBOARD_ADVANCED_I2I使用模型一致,为了区别积分扣除,PRINTBOARD_HIGH_I2I加入了-fast,但传入模型时需要去掉-fast,用PRINTBOARD_ADVANCED_I2I的常量做替代 + requestBuilder.model(ModelConstants.PRINTBOARD_ADVANCED_I2I); + } // 保存生成记录到数据库 Generate generate = new Generate( @@ -1602,14 +1697,14 @@ public class GenerateServiceImpl extends ServiceImpl i "Flat textile pattern printed directly on fabric surface, no three-dimensional objects, no items placed on cloth. \n" + "Real style: fabric print, realistic woven/printed pattern, detailed surface pattern only"; } - }else { - throw new BusinessException("style error:"+ style); + } else { + throw new BusinessException("style error:" + style); } if (userInput == null || userInput.trim().isEmpty()) { - if (isUseImage){ + if (isUseImage) { prompt = "Theme: Image content" + "\nRequirement: " + systemPrompt; - }else { + } else { throw new BusinessException("prompt null"); } } else { @@ -2011,7 +2106,9 @@ public class GenerateServiceImpl extends ServiceImpl i public Generate selectByUniqueId(String uniqueId) { QueryWrapper qw = new QueryWrapper<>(); qw.eq("unique_id", uniqueId); - + log.debug("selectByUniqueId: " + uniqueId); + Generate one = getOne(qw); + log.debug("Generate: " + one); return getOne(qw); } @@ -4182,11 +4279,11 @@ public class GenerateServiceImpl extends ServiceImpl i // 处理不同状态 switch (statusEnum) { case TASK_NOT_FOUND: - // 审核没过 + // 审核没过 case REQUEST_MODERATED: - // 审核没过 + // 审核没过 case CONTENT_MODERATED: - // 出错 + // 出错 case ERROR: return "Fail"; case PENDING_F: diff --git a/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java b/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java index 0f0e6dc4..42e920bf 100644 --- a/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/RabbitMQServiceImpl.java @@ -31,6 +31,11 @@ public class RabbitMQServiceImpl implements RabbitMQService { mqPublisher.sendGenerateMessage(message); } + @Override + public void publishMessageToGenerateResult(String message) { + mqPublisher.sendGenerateResultMessage(message); + } + @Override public void publishMessageToSR(String message) { mqPublisher.sendSRMessage(message);