From 6aa1a3d167946b4eb838af6621add2a65fd18cd1 Mon Sep 17 00:00:00 2001 From: shahaibo <1023316923@qq.com> Date: Mon, 19 May 2025 10:00:19 +0800 Subject: [PATCH] =?UTF-8?q?TASK:LLM=EF=BC=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../com/ai/da/controller/LLMController.java | 6 +++--- .../com/ai/da/model/dto/ChatHistoryDTO.java | 7 +++++++ src/main/java/com/ai/da/service/LLMService.java | 3 ++- .../ai/da/service/impl/DesignServiceImpl.java | 17 +++++++++++++++++ .../com/ai/da/service/impl/LLMServiceImpl.java | 15 ++++++++++----- 5 files changed, 39 insertions(+), 9 deletions(-) create mode 100644 src/main/java/com/ai/da/model/dto/ChatHistoryDTO.java diff --git a/src/main/java/com/ai/da/controller/LLMController.java b/src/main/java/com/ai/da/controller/LLMController.java index 565f0b03..95283481 100644 --- a/src/main/java/com/ai/da/controller/LLMController.java +++ b/src/main/java/com/ai/da/controller/LLMController.java @@ -74,8 +74,8 @@ public class LLMController { } @ApiOperation(value = "获取历史聊天记录") - @GetMapping(value = "/getChatHistory") - public Response> getChatHistory(@RequestParam Long projectId) { - return Response.success(llmService.getChatHistory(projectId)); + @PostMapping(value = "/getChatHistory") + public Response> getChatHistory(@RequestBody ChatHistoryDTO chatHistoryDTO) { + return Response.success(llmService.getChatHistory(chatHistoryDTO)); } } diff --git a/src/main/java/com/ai/da/model/dto/ChatHistoryDTO.java b/src/main/java/com/ai/da/model/dto/ChatHistoryDTO.java new file mode 100644 index 00000000..1f608b6d --- /dev/null +++ b/src/main/java/com/ai/da/model/dto/ChatHistoryDTO.java @@ -0,0 +1,7 @@ +package com.ai.da.model.dto; + +import com.ai.da.model.vo.PageQueryBaseVo; + +public class ChatHistoryDTO extends PageQueryBaseVo { + private Long projectId; +} diff --git a/src/main/java/com/ai/da/service/LLMService.java b/src/main/java/com/ai/da/service/LLMService.java index 4dab01dd..0f4b400b 100644 --- a/src/main/java/com/ai/da/service/LLMService.java +++ b/src/main/java/com/ai/da/service/LLMService.java @@ -1,6 +1,7 @@ package com.ai.da.service; import com.ai.da.common.response.PageBaseResponse; +import com.ai.da.common.response.Response; import com.ai.da.mapper.primary.entity.Account; import com.ai.da.mapper.primary.entity.AccountExtend; import com.ai.da.mapper.primary.entity.ChatMessage; @@ -36,5 +37,5 @@ public interface LLMService { List uploadFile(MultipartFile file); - List getChatHistory(Long projectId); + PageBaseResponse getChatHistory(ChatHistoryDTO chatHistoryDTO); } diff --git a/src/main/java/com/ai/da/service/impl/DesignServiceImpl.java b/src/main/java/com/ai/da/service/impl/DesignServiceImpl.java index b94820ea..6f041687 100644 --- a/src/main/java/com/ai/da/service/impl/DesignServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/DesignServiceImpl.java @@ -48,6 +48,7 @@ import java.math.RoundingMode; import java.time.LocalDateTime; import java.util.*; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -1796,6 +1797,15 @@ public class DesignServiceImpl extends ServiceImpl impleme return result; } + public static BigDecimal getRandomFromRange(BigDecimal min, BigDecimal max) { + if (min == null || max == null || min.compareTo(max) > 0) { + throw new IllegalArgumentException("Invalid imageStrength range."); + } + + double random = ThreadLocalRandom.current().nextDouble(min.doubleValue(), max.doubleValue() + 0.0001); + return new BigDecimal(random).setScale(2, RoundingMode.HALF_UP); + } + @Override public String designCloud(CloudTaskDTO cloudTaskDTO) { if (cloudTaskDTO.getBuildType().equals(BuildType.DESIGN.getValue())) { @@ -1851,6 +1861,9 @@ public class DesignServiceImpl extends ServiceImpl impleme int remainder = toProductImageDTO.getToProductImageVOList().size() % cloudTaskDTO.getNums(); // 剩下的余数 for (int i1 = 0; i1 < fullBatches; i1++) { + // TODO prompt微调 +// String newPrompt = pythonService.getPrompt(prompt); +// BigDecimal randomFromRange = getRandomFromRange(toProductImageDTO.getImageStrengthMin(), toProductImageDTO.getImageStrengthMax()); for (ToProductImageVO toProductImageVO : toProductImageDTO.getToProductImageVOList()) { String taskId; if (toProductImageVO.getElementType().equals("DesignOutfit")) { @@ -1935,6 +1948,10 @@ public class DesignServiceImpl extends ServiceImpl impleme } if (remainder > 0) { + // TODO: 随机 +// String newPrompt = pythonService.getPrompt(prompt); +// BigDecimal randomFromRange = getRandomFromRange(toProductImageDTO.getImageStrengthMin(), toProductImageDTO.getImageStrengthMax()); + List tempList = new ArrayList<>(toProductImageDTO.getToProductImageVOList()); Collections.shuffle(tempList); // 打乱顺序 diff --git a/src/main/java/com/ai/da/service/impl/LLMServiceImpl.java b/src/main/java/com/ai/da/service/impl/LLMServiceImpl.java index 361ffc8a..71812ad8 100644 --- a/src/main/java/com/ai/da/service/impl/LLMServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/LLMServiceImpl.java @@ -2,13 +2,17 @@ package com.ai.da.service.impl; import com.ai.da.common.constant.CommonConstant; import com.ai.da.common.context.UserContext; +import com.ai.da.common.response.PageBaseResponse; +import com.ai.da.common.response.Response; import com.ai.da.common.security.jwt.JWTTokenHelper; import com.ai.da.common.utils.MinioUtil; import com.ai.da.mapper.primary.*; import com.ai.da.mapper.primary.entity.*; +import com.ai.da.model.dto.ChatHistoryDTO; import com.ai.da.model.dto.ReceiveDesignParam; import com.ai.da.model.enums.*; import com.ai.da.model.vo.AuthPrincipalVo; +import com.ai.da.model.vo.UserLikeGroupVO; import com.ai.da.python.PythonService; import com.ai.da.service.DesignService; import com.ai.da.service.LLMService; @@ -16,6 +20,7 @@ import com.ai.da.service.SysFileService; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; import org.springframework.web.multipart.MultipartFile; @@ -266,12 +271,12 @@ public class LLMServiceImpl implements LLMService { } @Override - public List getChatHistory(Long projectId) { + public PageBaseResponse getChatHistory(ChatHistoryDTO chatHistoryDTO) { QueryWrapper qw = new QueryWrapper<>(); - qw.lambda().eq(ChatMessage::getProjectId, projectId); - qw.lambda().orderByAsc(ChatMessage::getSeq); - List chatMessages = chatMessageMapper.selectList(qw); - return chatMessages; + qw.lambda().eq(ChatMessage::getProjectId, chatHistoryDTO); + qw.lambda().orderByDesc(ChatMessage::getSeq); + Page chatMessagePage = chatMessageMapper.selectPage(new Page<>(chatHistoryDTO.getPage(), chatHistoryDTO.getSize()), qw); + return PageBaseResponse.success(chatMessagePage); } private int getNextSeq(Long projectId) {