TASK:LLM;
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.ai.da.service;
|
||||
|
||||
import com.ai.da.common.response.PageBaseResponse;
|
||||
import com.ai.da.common.response.Response;
|
||||
import com.ai.da.mapper.primary.entity.Account;
|
||||
import com.ai.da.mapper.primary.entity.AccountExtend;
|
||||
import com.ai.da.mapper.primary.entity.ChatMessage;
|
||||
@@ -36,5 +37,5 @@ public interface LLMService {
|
||||
|
||||
List<String> uploadFile(MultipartFile file);
|
||||
|
||||
List<ChatMessage> getChatHistory(Long projectId);
|
||||
PageBaseResponse<ChatMessage> getChatHistory(ChatHistoryDTO chatHistoryDTO);
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ import java.math.RoundingMode;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -1796,6 +1797,15 @@ public class DesignServiceImpl extends ServiceImpl<DesignMapper, Design> impleme
|
||||
return result;
|
||||
}
|
||||
|
||||
public static BigDecimal getRandomFromRange(BigDecimal min, BigDecimal max) {
|
||||
if (min == null || max == null || min.compareTo(max) > 0) {
|
||||
throw new IllegalArgumentException("Invalid imageStrength range.");
|
||||
}
|
||||
|
||||
double random = ThreadLocalRandom.current().nextDouble(min.doubleValue(), max.doubleValue() + 0.0001);
|
||||
return new BigDecimal(random).setScale(2, RoundingMode.HALF_UP);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String designCloud(CloudTaskDTO cloudTaskDTO) {
|
||||
if (cloudTaskDTO.getBuildType().equals(BuildType.DESIGN.getValue())) {
|
||||
@@ -1851,6 +1861,9 @@ public class DesignServiceImpl extends ServiceImpl<DesignMapper, Design> impleme
|
||||
int remainder = toProductImageDTO.getToProductImageVOList().size() % cloudTaskDTO.getNums(); // 剩下的余数
|
||||
|
||||
for (int i1 = 0; i1 < fullBatches; i1++) {
|
||||
// TODO prompt微调
|
||||
// String newPrompt = pythonService.getPrompt(prompt);
|
||||
// BigDecimal randomFromRange = getRandomFromRange(toProductImageDTO.getImageStrengthMin(), toProductImageDTO.getImageStrengthMax());
|
||||
for (ToProductImageVO toProductImageVO : toProductImageDTO.getToProductImageVOList()) {
|
||||
String taskId;
|
||||
if (toProductImageVO.getElementType().equals("DesignOutfit")) {
|
||||
@@ -1935,6 +1948,10 @@ public class DesignServiceImpl extends ServiceImpl<DesignMapper, Design> impleme
|
||||
}
|
||||
|
||||
if (remainder > 0) {
|
||||
// TODO: 随机
|
||||
// String newPrompt = pythonService.getPrompt(prompt);
|
||||
// BigDecimal randomFromRange = getRandomFromRange(toProductImageDTO.getImageStrengthMin(), toProductImageDTO.getImageStrengthMax());
|
||||
|
||||
List<ToProductImageVO> tempList = new ArrayList<>(toProductImageDTO.getToProductImageVOList());
|
||||
Collections.shuffle(tempList); // 打乱顺序
|
||||
|
||||
|
||||
@@ -2,13 +2,17 @@ package com.ai.da.service.impl;
|
||||
|
||||
import com.ai.da.common.constant.CommonConstant;
|
||||
import com.ai.da.common.context.UserContext;
|
||||
import com.ai.da.common.response.PageBaseResponse;
|
||||
import com.ai.da.common.response.Response;
|
||||
import com.ai.da.common.security.jwt.JWTTokenHelper;
|
||||
import com.ai.da.common.utils.MinioUtil;
|
||||
import com.ai.da.mapper.primary.*;
|
||||
import com.ai.da.mapper.primary.entity.*;
|
||||
import com.ai.da.model.dto.ChatHistoryDTO;
|
||||
import com.ai.da.model.dto.ReceiveDesignParam;
|
||||
import com.ai.da.model.enums.*;
|
||||
import com.ai.da.model.vo.AuthPrincipalVo;
|
||||
import com.ai.da.model.vo.UserLikeGroupVO;
|
||||
import com.ai.da.python.PythonService;
|
||||
import com.ai.da.service.DesignService;
|
||||
import com.ai.da.service.LLMService;
|
||||
@@ -16,6 +20,7 @@ import com.ai.da.service.SysFileService;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.StringUtils;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
@@ -266,12 +271,12 @@ public class LLMServiceImpl implements LLMService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatMessage> getChatHistory(Long projectId) {
|
||||
public PageBaseResponse<ChatMessage> getChatHistory(ChatHistoryDTO chatHistoryDTO) {
|
||||
QueryWrapper<ChatMessage> qw = new QueryWrapper<>();
|
||||
qw.lambda().eq(ChatMessage::getProjectId, projectId);
|
||||
qw.lambda().orderByAsc(ChatMessage::getSeq);
|
||||
List<ChatMessage> chatMessages = chatMessageMapper.selectList(qw);
|
||||
return chatMessages;
|
||||
qw.lambda().eq(ChatMessage::getProjectId, chatHistoryDTO);
|
||||
qw.lambda().orderByDesc(ChatMessage::getSeq);
|
||||
Page<ChatMessage> chatMessagePage = chatMessageMapper.selectPage(new Page<>(chatHistoryDTO.getPage(), chatHistoryDTO.getSize()), qw);
|
||||
return PageBaseResponse.success(chatMessagePage);
|
||||
}
|
||||
|
||||
private int getNextSeq(Long projectId) {
|
||||
|
||||
Reference in New Issue
Block a user