TASK:LLM;

This commit is contained in:
shahaibo
2025-05-19 10:00:19 +08:00
parent 59ffa38ff7
commit 6aa1a3d167
5 changed files with 39 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
package com.ai.da.service;
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.common.response.Response;
import com.ai.da.mapper.primary.entity.Account;
import com.ai.da.mapper.primary.entity.AccountExtend;
import com.ai.da.mapper.primary.entity.ChatMessage;
@@ -36,5 +37,5 @@ public interface LLMService {
List<String> uploadFile(MultipartFile file);
List<ChatMessage> getChatHistory(Long projectId);
PageBaseResponse<ChatMessage> getChatHistory(ChatHistoryDTO chatHistoryDTO);
}

View File

@@ -48,6 +48,7 @@ import java.math.RoundingMode;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@@ -1796,6 +1797,15 @@ public class DesignServiceImpl extends ServiceImpl<DesignMapper, Design> impleme
return result;
}
public static BigDecimal getRandomFromRange(BigDecimal min, BigDecimal max) {
if (min == null || max == null || min.compareTo(max) > 0) {
throw new IllegalArgumentException("Invalid imageStrength range.");
}
double random = ThreadLocalRandom.current().nextDouble(min.doubleValue(), max.doubleValue() + 0.0001);
return new BigDecimal(random).setScale(2, RoundingMode.HALF_UP);
}
@Override
public String designCloud(CloudTaskDTO cloudTaskDTO) {
if (cloudTaskDTO.getBuildType().equals(BuildType.DESIGN.getValue())) {
@@ -1851,6 +1861,9 @@ public class DesignServiceImpl extends ServiceImpl<DesignMapper, Design> impleme
int remainder = toProductImageDTO.getToProductImageVOList().size() % cloudTaskDTO.getNums(); // 剩下的余数
for (int i1 = 0; i1 < fullBatches; i1++) {
// TODO prompt微调
// String newPrompt = pythonService.getPrompt(prompt);
// BigDecimal randomFromRange = getRandomFromRange(toProductImageDTO.getImageStrengthMin(), toProductImageDTO.getImageStrengthMax());
for (ToProductImageVO toProductImageVO : toProductImageDTO.getToProductImageVOList()) {
String taskId;
if (toProductImageVO.getElementType().equals("DesignOutfit")) {
@@ -1935,6 +1948,10 @@ public class DesignServiceImpl extends ServiceImpl<DesignMapper, Design> impleme
}
if (remainder > 0) {
// TODO: 随机
// String newPrompt = pythonService.getPrompt(prompt);
// BigDecimal randomFromRange = getRandomFromRange(toProductImageDTO.getImageStrengthMin(), toProductImageDTO.getImageStrengthMax());
List<ToProductImageVO> tempList = new ArrayList<>(toProductImageDTO.getToProductImageVOList());
Collections.shuffle(tempList); // 打乱顺序

View File

@@ -2,13 +2,17 @@ package com.ai.da.service.impl;
import com.ai.da.common.constant.CommonConstant;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.response.PageBaseResponse;
import com.ai.da.common.response.Response;
import com.ai.da.common.security.jwt.JWTTokenHelper;
import com.ai.da.common.utils.MinioUtil;
import com.ai.da.mapper.primary.*;
import com.ai.da.mapper.primary.entity.*;
import com.ai.da.model.dto.ChatHistoryDTO;
import com.ai.da.model.dto.ReceiveDesignParam;
import com.ai.da.model.enums.*;
import com.ai.da.model.vo.AuthPrincipalVo;
import com.ai.da.model.vo.UserLikeGroupVO;
import com.ai.da.python.PythonService;
import com.ai.da.service.DesignService;
import com.ai.da.service.LLMService;
@@ -16,6 +20,7 @@ import com.ai.da.service.SysFileService;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.springframework.web.multipart.MultipartFile;
@@ -266,12 +271,12 @@ public class LLMServiceImpl implements LLMService {
}
@Override
public List<ChatMessage> getChatHistory(Long projectId) {
public PageBaseResponse<ChatMessage> getChatHistory(ChatHistoryDTO chatHistoryDTO) {
QueryWrapper<ChatMessage> qw = new QueryWrapper<>();
qw.lambda().eq(ChatMessage::getProjectId, projectId);
qw.lambda().orderByAsc(ChatMessage::getSeq);
List<ChatMessage> chatMessages = chatMessageMapper.selectList(qw);
return chatMessages;
qw.lambda().eq(ChatMessage::getProjectId, chatHistoryDTO);
qw.lambda().orderByDesc(ChatMessage::getSeq);
Page<ChatMessage> chatMessagePage = chatMessageMapper.selectPage(new Page<>(chatHistoryDTO.getPage(), chatHistoryDTO.getSize()), qw);
return PageBaseResponse.success(chatMessagePage);
}
private int getNextSeq(Long projectId) {