diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java index 650ddabd..e3c67b02 100644 --- a/src/main/java/com/ai/da/controller/GenerateController.java +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -33,9 +33,9 @@ public class GenerateController { @ApiOperation("通过文字、图片生成图片") - @PostMapping("/sketch") + @PostMapping("/sketchAndPrint") public Response generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO){ - return Response.success(generateService.generateSketchThroughImageText(generateThroughImageTextDTO)); + return Response.success(generateService.generateThroughImageText(generateThroughImageTextDTO)); } diff --git a/src/main/java/com/ai/da/mapper/entity/Generate.java b/src/main/java/com/ai/da/mapper/entity/Generate.java index 5e21ce85..603e621f 100644 --- a/src/main/java/com/ai/da/mapper/entity/Generate.java +++ b/src/main/java/com/ai/da/mapper/entity/Generate.java @@ -42,6 +42,11 @@ public class Generate { */ private String generateType; + /** + * 模型名 + */ + private String modelName; + /** * 创建时间 */ diff --git a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java index 02c38037..6883c64f 100644 --- a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java +++ b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java @@ -30,6 +30,9 @@ public class GenerateThroughImageTextDTO { @ApiModelProperty("Outwear Dress Blouse Skirt Trousers") String level2Type; + @ApiModelProperty("选择的模型名") + String version; + @NotBlank(message = "timeZone cannot be empty!") @ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取") String timeZone; diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index 0fc80aeb..5fa22e20 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -1433,7 +1433,7 @@ public class PythonService { throw new BusinessException("system error!"); } - public String generateSketch(String url,String text) { + public String generateSketchOrPrint(String url, String text, int mode, String modelName) { //限流校验 AccessLimitUtils.validate("generateSketch",5); OkHttpClient client = new OkHttpClient().newBuilder() @@ -1443,9 +1443,11 @@ public class PythonService { .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) .build(); MediaType mediaType = MediaType.parse("application/json"); - Map content = Maps.newHashMap(); + Map content = Maps.newHashMap(); content.put("img_url", url); - content.put("input", text); + content.put("str", text); + content.put("mode",mode); + content.put("version",modelName); RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content)); Request request = new Request.Builder() .url(accessPythonIp + ":2828/aida/diffusion") @@ -1456,25 +1458,25 @@ public class PythonService { Response response = null; String bodyString = null; try { - log.info("generateSketch请求入参content###{}", JSON.toJSONString(content)); + log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(content)); response = client.newCall(request).execute(); bodyString = response.body().string(); } catch (IOException ioException) { - log.error("PythonService##generateSketch异常###{}", ExceptionUtil.getThrowableList(ioException)); + log.error("PythonService##generateSketchOrPrint异常###{}", ExceptionUtil.getThrowableList(ioException)); } //去除限流 - AccessLimitUtils.validateOut("generateSketch"); + AccessLimitUtils.validateOut("generateSketchOrPrint"); if (Objects.isNull(response)) { - log.error("PythonService##generateSketch异常###{}", "response or body is empty!"); - throw new BusinessException("generate sketch exception!"); + log.error("PythonService##generateSketchOrPrint异常###{}", "response or body is empty!"); + throw new BusinessException("generateSketchOrPrint exception!"); } JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response)); Boolean result = jsonObject.getBoolean("successful"); if (result) { return bodyString; } - log.info("generate sketch失败###{}", jsonObject); + log.info("generateSketchOrPrintPrint失败###{}", jsonObject); //生成失败 - throw new BusinessException("generate sketch exception!"); + throw new BusinessException("generateSketchOrPrint exception!"); } } diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index af8c75bb..d7bc029a 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -8,5 +8,5 @@ public interface GenerateService { GenerateCaptionVO generateCaption(Long sketchElementId); - GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO); + GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO); } diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 1cfa5865..50947a08 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -53,7 +53,7 @@ public class GenerateServiceImpl extends ServiceImpl im } @Override - public GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、获取用户信息 AuthPrincipalVo userHolder = UserContext.getUserHolder(); Long accountId = userHolder.getId(); @@ -61,49 +61,53 @@ public class GenerateServiceImpl extends ServiceImpl im // 2、判断必须入参是否为非空 String generateType = generateThroughImageTextDTO.getGenerateType(); String text = generateThroughImageTextDTO.getText(); - Long sketchId = generateThroughImageTextDTO.getCollectionElementId(); + Long elementId = generateThroughImageTextDTO.getCollectionElementId(); + String modelName = generateThroughImageTextDTO.getVersion(); Generate generate = new Generate(); generate.setAccountId(accountId); generate.setGenerateType(generateType); + generate.setModelName(StringUtil.isNullOrEmpty(modelName) ? "0" : modelName); generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone())); + int mode = 2; switch(generateType){ case "text": Assert.notNull(text,"Please input the caption"); generate.setText(text); + mode = 1; break; case "image": - Assert.notNull(sketchId,"Please choose a sketch"); - generate.setCollectionElementId(sketchId); + Assert.notNull(elementId,"Please choose a image"); + generate.setCollectionElementId(elementId); break; case "text-image": - Assert.isTrue(!StringUtil.isNullOrEmpty(text) && Objects.nonNull(sketchId), - "Please input the caption and choose a sketch"); + Assert.isTrue(!StringUtil.isNullOrEmpty(text) && Objects.nonNull(elementId), + "Please input the caption and choose a image"); generate.setText(text); - generate.setCollectionElementId(sketchId); + generate.setCollectionElementId(elementId); break; } // 3、将请求信息落库 // 3.1 sketch在t_collection_element表中的信息是否需要更新 如 level2Type CollectionElement collectionElement = null; - if(!Objects.isNull(sketchId)){ - collectionElement = collectionElementMapper.selectById(sketchId); + if(!Objects.isNull(elementId)){ + collectionElement = collectionElementMapper.selectById(elementId); if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(generateThroughImageTextDTO.getLevel2Type()) ){ collectionElement.setLevel2Type(generateThroughImageTextDTO.getLevel2Type()); QueryWrapper queryWrapper = new QueryWrapper<>(); - queryWrapper.eq("id", sketchId); + queryWrapper.eq("id", elementId); collectionElementMapper.update(collectionElement,queryWrapper); } } - // 3.2 将本次generate的请求信息添加到t_generate表中 save(generate); // 4、向模型发起请求 -// String generatedSketchUrl = pythonService.generateSketch(collectionElement.getUrl(), text); +// String generatedSketchUrl = pythonService.generateSketchOrPrint(collectionElement.getUrl(),text +// ,mode,generateThroughImageTextDTO.getVersion()); List generatedSketchUrl = Arrays.asList("testUrl1","testUrl2","testUrl3","testUrl4"); diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 464081e6..e8510ce9 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -1,5 +1,8 @@ -#application-testļ(Ի) -#spring.profiles.active=test - -#application-prodļ() +#����application-test�ļ�(���Ի���) spring.profiles.active=test + +#����application-prod�ļ�(��������) +#spring.profiles.active=prod + +#����application-dev�ļ�(��������) +#spring.profiles.active=dev \ No newline at end of file