diff --git a/src/main/java/com/ai/da/common/security/filter/AuthenticationFilter.java b/src/main/java/com/ai/da/common/security/filter/AuthenticationFilter.java index b9b586f1..957c17e7 100644 --- a/src/main/java/com/ai/da/common/security/filter/AuthenticationFilter.java +++ b/src/main/java/com/ai/da/common/security/filter/AuthenticationFilter.java @@ -53,7 +53,8 @@ public class AuthenticationFilter extends OncePerRequestFilter { "/api/portfolio/page", "/api/portfolio/detail", "/api/portfolio/commentPage", "/api/portfolio/viewsIncrease", "/api/account/designWorksRegister","/api/account/questionnaire","/api/stripe/trade/notify", "/notification","/api/account/activateNewEmail","/api/third/party/auth/google_callback","/api/third/party/parseGoogleCredential","/api/third/party/receiveDesignResults","/api/third/party/parseWeChatCode","/api/third/party/receiveDesignParams" - , "api/account/schoolLogin", "api/account/enterpriseLogin", "api/account/organizationNameSearch" + , "api/account/schoolLogin", "api/account/enterpriseLogin", "api/account/organizationNameSearch", + "/api/llm/stream" ); @Override diff --git a/src/main/java/com/ai/da/controller/LLMController.java b/src/main/java/com/ai/da/controller/LLMController.java new file mode 100644 index 00000000..565f0b03 --- /dev/null +++ b/src/main/java/com/ai/da/controller/LLMController.java @@ -0,0 +1,81 @@ +package com.ai.da.controller; + +import com.ai.da.common.config.exception.BusinessException; +import com.ai.da.common.response.PageBaseResponse; +import com.ai.da.common.response.Response; +import com.ai.da.mapper.primary.entity.Account; +import com.ai.da.mapper.primary.entity.AccountExtend; +import com.ai.da.mapper.primary.entity.ChatMessage; +import com.ai.da.mapper.primary.entity.TrialOrder; +import com.ai.da.model.dto.*; +import com.ai.da.model.vo.AccountLoginVO; +import com.ai.da.model.vo.AccountPreLoginVO; +import com.ai.da.model.vo.BindEmailVO; +import com.ai.da.model.vo.PersonalHomepageVO; +import com.ai.da.service.AccountService; +import com.ai.da.service.LLMService; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.util.StringUtils; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import javax.annotation.Resource; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.validation.Valid; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + + +@Api(tags = "llm模块") +@Slf4j +@RestController +@RequestMapping("/api/llm") +public class LLMController { + + @Resource + private LLMService llmService; + + @ApiOperation(value = "对话") + @CrossOrigin + @GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter streamPrompt(@RequestParam String prompt, + @RequestParam Long projectId, + @RequestParam(required = false) String fileUrl, + @RequestParam String token) { + return llmService.stream(prompt, projectId, fileUrl, token); + } + + @ApiOperation(value = "对话创建项目") + @GetMapping(value = "/chatCreateProject") + public Response chatCreateProject(@RequestParam String prompt) { + return Response.success(llmService.chatCreateProject(prompt)); + } + + @ApiOperation(value = "上传文件") + @GetMapping(value = "/uploadFile") + public Response> uploadFile(@RequestParam MultipartFile file) { + return Response.success(llmService.uploadFile(file)); + } + + @ApiOperation(value = "获取历史聊天记录") + @GetMapping(value = "/getChatHistory") + public Response> getChatHistory(@RequestParam Long projectId) { + return Response.success(llmService.getChatHistory(projectId)); + } +} diff --git a/src/main/java/com/ai/da/mapper/primary/ChatMessageMapper.java b/src/main/java/com/ai/da/mapper/primary/ChatMessageMapper.java new file mode 100644 index 00000000..f852bd41 --- /dev/null +++ b/src/main/java/com/ai/da/mapper/primary/ChatMessageMapper.java @@ -0,0 +1,7 @@ +package com.ai.da.mapper.primary; + +import com.ai.da.common.config.mybatis.plus.CommonMapper; +import com.ai.da.mapper.primary.entity.ChatMessage; + +public interface ChatMessageMapper extends CommonMapper { +} diff --git a/src/main/java/com/ai/da/mapper/primary/entity/ChatMessage.java b/src/main/java/com/ai/da/mapper/primary/entity/ChatMessage.java new file mode 100644 index 00000000..b1886773 --- /dev/null +++ b/src/main/java/com/ai/da/mapper/primary/entity/ChatMessage.java @@ -0,0 +1,35 @@ +package com.ai.da.mapper.primary.entity; + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.Accessors; + +import java.io.Serializable; +import java.time.LocalDateTime; + +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(chain = true) +@TableName("chat_message") +public class ChatMessage implements Serializable { + + private static final long serialVersionUID = 1L; + + @TableId(value = "id", type = IdType.AUTO) + private Long id; + + private Long projectId; + + private String role; + + private Integer seq; + + private String content; + + private Long accountId; + + private LocalDateTime createTime; +} diff --git a/src/main/java/com/ai/da/mapper/primary/entity/Workspace.java b/src/main/java/com/ai/da/mapper/primary/entity/Workspace.java index 1fea7108..e4277c63 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/Workspace.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/Workspace.java @@ -101,5 +101,9 @@ public class Workspace implements Serializable { private Long projectId; + private Integer userBrandDna; + + private Integer brandPercentage; + } diff --git a/src/main/java/com/ai/da/model/dto/DesignCollectionDTO.java b/src/main/java/com/ai/da/model/dto/DesignCollectionDTO.java index dc92ad3f..45da3381 100644 --- a/src/main/java/com/ai/da/model/dto/DesignCollectionDTO.java +++ b/src/main/java/com/ai/da/model/dto/DesignCollectionDTO.java @@ -69,4 +69,8 @@ public class DesignCollectionDTO { private Integer designNum; + private Long brandId; + + private Double brandScale; + } diff --git a/src/main/java/com/ai/da/model/dto/ToProductImageDTO.java b/src/main/java/com/ai/da/model/dto/ToProductImageDTO.java index 5fa93cc7..769a0519 100644 --- a/src/main/java/com/ai/da/model/dto/ToProductImageDTO.java +++ b/src/main/java/com/ai/da/model/dto/ToProductImageDTO.java @@ -14,4 +14,6 @@ public class ToProductImageDTO { private BigDecimal imageStrength; private String direction; private Double brightenValue; + private BigDecimal imageStrengthMin; + private BigDecimal imageStrengthMax; } diff --git a/src/main/java/com/ai/da/model/enums/AgeGroup.java b/src/main/java/com/ai/da/model/enums/AgeGroup.java index e5f576ae..fec476e2 100644 --- a/src/main/java/com/ai/da/model/enums/AgeGroup.java +++ b/src/main/java/com/ai/da/model/enums/AgeGroup.java @@ -28,4 +28,13 @@ public enum AgeGroup implements IEnumDisplay { } throw new IllegalArgumentException("No matching constant for [" + value + "]"); } + + public static boolean isValidName(String name) { + for (AgeGroup ageGroup : AgeGroup.values()) { + if (ageGroup.name().equalsIgnoreCase(name)) { + return true; + } + } + return false; + } } diff --git a/src/main/java/com/ai/da/model/enums/DesignProcess.java b/src/main/java/com/ai/da/model/enums/DesignProcess.java index e41e865a..4a0d3c33 100644 --- a/src/main/java/com/ai/da/model/enums/DesignProcess.java +++ b/src/main/java/com/ai/da/model/enums/DesignProcess.java @@ -31,4 +31,13 @@ public enum DesignProcess implements IEnumDisplay { } throw new IllegalArgumentException("No matching constant for [" + value + "]"); } + + public static boolean isValidName(String name) { + for (DesignProcess process : DesignProcess.values()) { + if (process.name().equalsIgnoreCase(name)) { + return true; + } + } + return false; + } } diff --git a/src/main/java/com/ai/da/model/enums/Position.java b/src/main/java/com/ai/da/model/enums/Position.java index 93ca5b63..a81f41a5 100644 --- a/src/main/java/com/ai/da/model/enums/Position.java +++ b/src/main/java/com/ai/da/model/enums/Position.java @@ -27,6 +27,15 @@ public enum Position implements IEnumDisplay { this.value = value; } + public static boolean isValidName(String name) { + for (Position position : Position.values()) { + if (position.name().equalsIgnoreCase(name)) { + return true; + } + } + return false; + } + @Override @JsonValue public String getValue() { diff --git a/src/main/java/com/ai/da/model/enums/Sex.java b/src/main/java/com/ai/da/model/enums/Sex.java index 9fef3460..3e4c7f57 100644 --- a/src/main/java/com/ai/da/model/enums/Sex.java +++ b/src/main/java/com/ai/da/model/enums/Sex.java @@ -22,6 +22,15 @@ public enum Sex implements IEnumDisplay { this.value = value; } + public static boolean isValidName(String name) { + for (Sex sex : Sex.values()) { + if (sex.name().equalsIgnoreCase(name)) { + return true; + } + } + return false; + } + @Override @JsonValue public String getValue() { diff --git a/src/main/java/com/ai/da/model/enums/StyleEnum.java b/src/main/java/com/ai/da/model/enums/StyleEnum.java index 1b059f3f..dcd62fb4 100644 --- a/src/main/java/com/ai/da/model/enums/StyleEnum.java +++ b/src/main/java/com/ai/da/model/enums/StyleEnum.java @@ -31,6 +31,15 @@ public enum StyleEnum { this.english = english; } + public static boolean isValidName(String name) { + for (StyleEnum styleEnum : StyleEnum.values()) { + if (styleEnum.name().equalsIgnoreCase(name)) { + return true; + } + } + return false; + } + // 获取中文描述 public String getChinese() { return chinese; diff --git a/src/main/java/com/ai/da/model/vo/ValidateElementVO.java b/src/main/java/com/ai/da/model/vo/ValidateElementVO.java index 2c0f6c2f..85dd456f 100644 --- a/src/main/java/com/ai/da/model/vo/ValidateElementVO.java +++ b/src/main/java/com/ai/da/model/vo/ValidateElementVO.java @@ -55,4 +55,8 @@ public class ValidateElementVO { private Long collectionId; private Long accountId; + + private Long brandId; + + private Double brandScale; } diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index 7e915a95..18623875 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -4382,4 +4382,46 @@ public class PythonService { log.error("PythonService##poseTransferBatch接口调用失败###{}", response); throw new BusinessException("poseTransferBatch.interface.exception"); } + + public JSONObject getProjectParam(String prompt) { + OkHttpClient client = new OkHttpClient().newBuilder() + .connectTimeout(30, TimeUnit.SECONDS) + .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒) + .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒) + .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) + .build(); + MediaType mediaType = MediaType.parse("application/json"); + Map content = Maps.newHashMap(); + content.put("prompt", prompt); + RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content)); + + log.info("getProjectParam 请求地址: {}", accessPythonIp + ":" + accessPythonPort + "/api/extraction_project_info"); + Request request = new Request.Builder() + .url(accessPythonIp + ":" + accessPythonPort + "/api/extraction_project_info") + .method("POST", body) + .addHeader("Content-Type", "application/json") + .build(); + Response response = null; + try { + response = client.newCall(request).execute(); + } catch (IOException ioException) { + log.error("PythonService##getProjectParam异常###{}", ExceptionUtil.getThrowableList(ioException)); + throw new BusinessException("getProjectParam.interface.exception"); + } + String responseBody; + if (response.isSuccessful() && response.body() != null) { + try { + responseBody = response.body().string(); + JSONObject responseObject = JSON.parseObject(responseBody); + log.info("PythonService##responseObject###{}", responseObject); + return responseObject; + } catch (IOException | JSONException e) { + log.error("PythonService##getProjectParam异常###{}", e.getMessage()); + throw new BusinessException("getProjectParam.interface.exception"); + } + } + + log.error("PythonService##getProjectParam接口调用失败###{}", response); + throw new BusinessException("getProjectParam.interface.exception"); + } } diff --git a/src/main/java/com/ai/da/service/LLMService.java b/src/main/java/com/ai/da/service/LLMService.java new file mode 100644 index 00000000..4dab01dd --- /dev/null +++ b/src/main/java/com/ai/da/service/LLMService.java @@ -0,0 +1,40 @@ +package com.ai.da.service; + +import com.ai.da.common.response.PageBaseResponse; +import com.ai.da.mapper.primary.entity.Account; +import com.ai.da.mapper.primary.entity.AccountExtend; +import com.ai.da.mapper.primary.entity.ChatMessage; +import com.ai.da.mapper.primary.entity.TrialOrder; +import com.ai.da.model.dto.*; +import com.ai.da.model.vo.AccountLoginVO; +import com.ai.da.model.vo.AccountPreLoginVO; +import com.ai.da.model.vo.BindEmailVO; +import com.ai.da.model.vo.PersonalHomepageVO; +import com.baomidou.mybatisplus.core.metadata.IPage; +import com.baomidou.mybatisplus.extension.service.IService; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * 服务类 + * + * @author easy-generator + * @since 2022-08-11 + */ +public interface LLMService { + + SseEmitter stream(String prompt, Long projectId, String fileUrl, String token); + + Long chatCreateProject(String prompt); + + List uploadFile(MultipartFile file); + + List getChatHistory(Long projectId); +} diff --git a/src/main/java/com/ai/da/service/ProductImageService.java b/src/main/java/com/ai/da/service/ProductImageService.java index 8727aa15..29cbe0b2 100644 --- a/src/main/java/com/ai/da/service/ProductImageService.java +++ b/src/main/java/com/ai/da/service/ProductImageService.java @@ -1,8 +1,9 @@ package com.ai.da.service; import com.ai.da.model.dto.ProgressDTO; +import com.ai.da.model.vo.AuthPrincipalVo; public interface ProductImageService { - void asyncInitialize(Long brandId); + void asyncInitialize(Long brandId, AuthPrincipalVo userHolder); // double getInitializeProgress(Long brandId); } diff --git a/src/main/java/com/ai/da/service/SysFileService.java b/src/main/java/com/ai/da/service/SysFileService.java index a2a04ad5..6d93bef7 100644 --- a/src/main/java/com/ai/da/service/SysFileService.java +++ b/src/main/java/com/ai/da/service/SysFileService.java @@ -62,4 +62,6 @@ public interface SysFileService extends IService { List getByIds(List ids); SysFile getOneBySex(Long styleId, String sex, String ageGroup); + + SysFile getOneBySex(String styleName, String sex, String ageGroup); } diff --git a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java index 8bef0851..7b1ae53f 100644 --- a/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/CollectionElementServiceImpl.java @@ -622,6 +622,10 @@ public class CollectionElementServiceImpl extends ServiceImpl { + try { + boolean validate = jwtTokenHelper.validateToken(token); + if (validate) { + AuthPrincipalVo principal = jwtTokenHelper.parserToUser(token); + Long accountId = principal.getId(); + int userSeq = getNextSeq(projectId); // 获取当前session下一条消息序号 + String url = "http://18.167.251.121:10002/chat-stream"; + HttpURLConnection conn = (HttpURLConnection) new URL(url).openConnection(); + conn.setRequestMethod("POST"); + conn.setDoOutput(true); + conn.setRequestProperty("Content-Type", "application/json"); + + JSONObject jsonBodyObject = new JSONObject(); + jsonBodyObject.put("session_id", projectId.toString()); + jsonBodyObject.put("role", "user"); + jsonBodyObject.put("image", ""); // 可扩展 + jsonBodyObject.put("file", fileUrl != null ? fileUrl : ""); + jsonBodyObject.put("message", prompt); + jsonBodyObject.put("enable_thinking", false); + + // 1. 存储用户输入 + ChatMessage userMessage = new ChatMessage(); + userMessage.setRole("user"); + userMessage.setProjectId(projectId); + userMessage.setSeq(userSeq); + userMessage.setCreateTime(LocalDateTime.now()); + userMessage.setContent(jsonBodyObject.toJSONString()); + userMessage.setAccountId(accountId); + chatMessageMapper.insert(userMessage); + + try (OutputStream os = conn.getOutputStream()) { + byte[] input = jsonBodyObject.toJSONString().getBytes(StandardCharsets.UTF_8); + os.write(input, 0, input.length); + } + + // 2. 流式接收并累积内容 + StringBuilder responseBuilder = new StringBuilder(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + line = line.trim(); + if (!line.isEmpty() && line.startsWith("data: ")) { + String jsonStr = line.substring(6); + System.out.println(jsonStr); + JSONObject json = JSON.parseObject(jsonStr); + String status = json.getString("status"); + + if ("[DONE]".equals(status)) { + break; + } + + if (!StringUtils.isEmpty(status)) { + String content = json.getString("content"); + if (!status.equals("[RUNNING]") && !status.equals("[DESIGN_SIGNAL]")) { + JSONObject toolsData = json.getJSONObject("tools_data"); + ReceiveDesignParam receiveDesignParam = JSONObject.parseObject(JSONObject.toJSONString(toolsData), ReceiveDesignParam.class); + receiveDesignParam.setProjectId(projectId); + designService.receiveDesignParams(receiveDesignParam); + } + if (content != null) { + responseBuilder.append(content); + emitter.send(json.toJSONString()); + } + } + } + } + } + + // 3. 存储系统回复 + int systemSeq = getNextSeq(projectId); + ChatMessage systemMessage = new ChatMessage(); + systemMessage.setRole("user"); + systemMessage.setProjectId(projectId); + systemMessage.setSeq(systemSeq); + systemMessage.setCreateTime(LocalDateTime.now()); + systemMessage.setContent(responseBuilder.toString()); + systemMessage.setAccountId(accountId); + chatMessageMapper.insert(systemMessage); + } + + emitter.complete(); + } catch (Exception e) { + emitter.completeWithError(e); + } + }); + + return emitter; + } + + @Override + public Long chatCreateProject(String prompt) { + AuthPrincipalVo userHolder = UserContext.getUserHolder(); + JSONObject jsonObject = pythonService.getProjectParam(prompt); + JSONObject data = jsonObject.getJSONObject("data"); + Project project = new Project(); + LocalDateTime now = LocalDateTime.now(); + project.setUpdateTime(now); + project.setCreateTime(now); + project.setAccountId(userHolder.getId()); + project.setName(data.getString("project_name")); + project.setOriginal(1); + String process = data.getString("process"); + if (StringUtils.isEmpty(process)) { + project.setProcess(DesignProcess.SERIES_DESIGN.name()); + }else { + if (DesignProcess.isValidName(process)) { + project.setProcess(process); + }else { + project.setProcess(DesignProcess.SERIES_DESIGN.name()); + } + } + projectMapper.insert(project); + + Workspace workspace = new Workspace(); + workspace.setAccountId(userHolder.getId()); + workspace.setCreateTime(now); + String ageGroup = data.getString("ageGroup"); + if (StringUtils.isEmpty(ageGroup)) { + workspace.setAgeGroup("Adult"); + }else { + if (AgeGroup.isValidName(process)) { + workspace.setAgeGroup(ageGroup); + }else { + workspace.setAgeGroup("Adult"); + } + } + String gender = data.getString("gender"); + if (StringUtils.isEmpty(gender)) { + workspace.setSex("Female"); + }else { + if (Sex.isValidName(gender)) { + workspace.setSex(gender); + }else { + workspace.setSex("Female"); + } + } + String position = data.getString("position"); + if (StringUtils.isEmpty(position)) { + workspace.setPosition("Overall"); + }else { + if (Position.isValidName(position)) { + workspace.setPosition(position); + }else { + workspace.setPosition("Overall"); + } + } + workspace.setSystemDesignerPercentage(30); + workspace.setProjectId(project.getId()); + + String style = data.getString("style"); + String styleName = null; + if (StringUtils.isEmpty(style)) { + styleName = StyleEnum.NEW_CHINESE.name(); + }else { + if (StyleEnum.isValidName(style)) { + styleName = style; + }else { + styleName = StyleEnum.NEW_CHINESE.name(); + } + } + + SysFile sysFile = sysFileService.getOneBySex(styleName, workspace.getSex(), workspace.getAgeGroup()); + + if (workspace.getSex().equals(Sex.FEMALE.getValue())) { + workspace.setMannequinFemaleId(sysFile.getId()); + workspace.setMannequinFemaleType("System"); + }else { + workspace.setMannequinMaleId(sysFile.getId()); + workspace.setMannequinMaleType("System"); + } + + workspaceMapper.insert(workspace); + + if (!StringUtils.isEmpty(styleName)) { + QueryWrapper