Sketchboard与Printboard的generate功能 修改

This commit is contained in:
徐佩
2023-08-18 10:42:08 +08:00
parent c5e5b51852
commit b13feb8f1f
7 changed files with 46 additions and 29 deletions

View File

@@ -33,9 +33,9 @@ public class GenerateController {
@ApiOperation("通过文字、图片生成图片") @ApiOperation("通过文字、图片生成图片")
@PostMapping("/sketch") @PostMapping("/sketchAndPrint")
public Response<GenerateCollectionVO> generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO){ public Response<GenerateCollectionVO> generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO){
return Response.success(generateService.generateSketchThroughImageText(generateThroughImageTextDTO)); return Response.success(generateService.generateThroughImageText(generateThroughImageTextDTO));
} }

View File

@@ -42,6 +42,11 @@ public class Generate {
*/ */
private String generateType; private String generateType;
/**
* 模型名
*/
private String modelName;
/** /**
* 创建时间 * 创建时间
*/ */

View File

@@ -30,6 +30,9 @@ public class GenerateThroughImageTextDTO {
@ApiModelProperty("Outwear Dress Blouse Skirt Trousers") @ApiModelProperty("Outwear Dress Blouse Skirt Trousers")
String level2Type; String level2Type;
@ApiModelProperty("选择的模型名")
String version;
@NotBlank(message = "timeZone cannot be empty!") @NotBlank(message = "timeZone cannot be empty!")
@ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取") @ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取")
String timeZone; String timeZone;

View File

@@ -1433,7 +1433,7 @@ public class PythonService {
throw new BusinessException("system error!"); throw new BusinessException("system error!");
} }
public String generateSketch(String url,String text) { public String generateSketchOrPrint(String url, String text, int mode, String modelName) {
//限流校验 //限流校验
AccessLimitUtils.validate("generateSketch",5); AccessLimitUtils.validate("generateSketch",5);
OkHttpClient client = new OkHttpClient().newBuilder() OkHttpClient client = new OkHttpClient().newBuilder()
@@ -1443,9 +1443,11 @@ public class PythonService {
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build(); .build();
MediaType mediaType = MediaType.parse("application/json"); MediaType mediaType = MediaType.parse("application/json");
Map<String, String> content = Maps.newHashMap(); Map<String, Object> content = Maps.newHashMap();
content.put("img_url", url); content.put("img_url", url);
content.put("input", text); content.put("str", text);
content.put("mode",mode);
content.put("version",modelName);
RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content)); RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content));
Request request = new Request.Builder() Request request = new Request.Builder()
.url(accessPythonIp + ":2828/aida/diffusion") .url(accessPythonIp + ":2828/aida/diffusion")
@@ -1456,25 +1458,25 @@ public class PythonService {
Response response = null; Response response = null;
String bodyString = null; String bodyString = null;
try { try {
log.info("generateSketch请求入参content###{}", JSON.toJSONString(content)); log.info("generateSketchOrPrint请求入参content###{}", JSON.toJSONString(content));
response = client.newCall(request).execute(); response = client.newCall(request).execute();
bodyString = response.body().string(); bodyString = response.body().string();
} catch (IOException ioException) { } catch (IOException ioException) {
log.error("PythonService##generateSketch异常###{}", ExceptionUtil.getThrowableList(ioException)); log.error("PythonService##generateSketchOrPrint异常###{}", ExceptionUtil.getThrowableList(ioException));
} }
//去除限流 //去除限流
AccessLimitUtils.validateOut("generateSketch"); AccessLimitUtils.validateOut("generateSketchOrPrint");
if (Objects.isNull(response)) { if (Objects.isNull(response)) {
log.error("PythonService##generateSketch异常###{}", "response or body is empty!"); log.error("PythonService##generateSketchOrPrint异常###{}", "response or body is empty!");
throw new BusinessException("generate sketch exception!"); throw new BusinessException("generateSketchOrPrint exception!");
} }
JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response)); JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response));
Boolean result = jsonObject.getBoolean("successful"); Boolean result = jsonObject.getBoolean("successful");
if (result) { if (result) {
return bodyString; return bodyString;
} }
log.info("generate sketch失败###{}", jsonObject); log.info("generateSketchOrPrintPrint失败###{}", jsonObject);
//生成失败 //生成失败
throw new BusinessException("generate sketch exception!"); throw new BusinessException("generateSketchOrPrint exception!");
} }
} }

View File

@@ -8,5 +8,5 @@ public interface GenerateService {
GenerateCaptionVO generateCaption(Long sketchElementId); GenerateCaptionVO generateCaption(Long sketchElementId);
GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO); GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO);
} }

View File

@@ -53,7 +53,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper,Generate> im
} }
@Override @Override
public GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { public GenerateCollectionVO generateThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、获取用户信息 // 1、获取用户信息
AuthPrincipalVo userHolder = UserContext.getUserHolder(); AuthPrincipalVo userHolder = UserContext.getUserHolder();
Long accountId = userHolder.getId(); Long accountId = userHolder.getId();
@@ -61,49 +61,53 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper,Generate> im
// 2、判断必须入参是否为非空 // 2、判断必须入参是否为非空
String generateType = generateThroughImageTextDTO.getGenerateType(); String generateType = generateThroughImageTextDTO.getGenerateType();
String text = generateThroughImageTextDTO.getText(); String text = generateThroughImageTextDTO.getText();
Long sketchId = generateThroughImageTextDTO.getCollectionElementId(); Long elementId = generateThroughImageTextDTO.getCollectionElementId();
String modelName = generateThroughImageTextDTO.getVersion();
Generate generate = new Generate(); Generate generate = new Generate();
generate.setAccountId(accountId); generate.setAccountId(accountId);
generate.setGenerateType(generateType); generate.setGenerateType(generateType);
generate.setModelName(StringUtil.isNullOrEmpty(modelName) ? "0" : modelName);
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone())); generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
int mode = 2;
switch(generateType){ switch(generateType){
case "text": case "text":
Assert.notNull(text,"Please input the caption"); Assert.notNull(text,"Please input the caption");
generate.setText(text); generate.setText(text);
mode = 1;
break; break;
case "image": case "image":
Assert.notNull(sketchId,"Please choose a sketch"); Assert.notNull(elementId,"Please choose a image");
generate.setCollectionElementId(sketchId); generate.setCollectionElementId(elementId);
break; break;
case "text-image": case "text-image":
Assert.isTrue(!StringUtil.isNullOrEmpty(text) && Objects.nonNull(sketchId), Assert.isTrue(!StringUtil.isNullOrEmpty(text) && Objects.nonNull(elementId),
"Please input the caption and choose a sketch"); "Please input the caption and choose a image");
generate.setText(text); generate.setText(text);
generate.setCollectionElementId(sketchId); generate.setCollectionElementId(elementId);
break; break;
} }
// 3、将请求信息落库 // 3、将请求信息落库
// 3.1 sketch在t_collection_element表中的信息是否需要更新 如 level2Type // 3.1 sketch在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = null; CollectionElement collectionElement = null;
if(!Objects.isNull(sketchId)){ if(!Objects.isNull(elementId)){
collectionElement = collectionElementMapper.selectById(sketchId); collectionElement = collectionElementMapper.selectById(elementId);
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(generateThroughImageTextDTO.getLevel2Type()) ){ if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(generateThroughImageTextDTO.getLevel2Type()) ){
collectionElement.setLevel2Type(generateThroughImageTextDTO.getLevel2Type()); collectionElement.setLevel2Type(generateThroughImageTextDTO.getLevel2Type());
QueryWrapper<CollectionElement> queryWrapper = new QueryWrapper<>(); QueryWrapper<CollectionElement> queryWrapper = new QueryWrapper<>();
queryWrapper.eq("id", sketchId); queryWrapper.eq("id", elementId);
collectionElementMapper.update(collectionElement,queryWrapper); collectionElementMapper.update(collectionElement,queryWrapper);
} }
} }
// 3.2 将本次generate的请求信息添加到t_generate表中 // 3.2 将本次generate的请求信息添加到t_generate表中
save(generate); save(generate);
// 4、向模型发起请求 // 4、向模型发起请求
// String generatedSketchUrl = pythonService.generateSketch(collectionElement.getUrl(), text); // String generatedSketchUrl = pythonService.generateSketchOrPrint(collectionElement.getUrl(),text
// ,mode,generateThroughImageTextDTO.getVersion());
List<String> generatedSketchUrl = Arrays.asList("testUrl1","testUrl2","testUrl3","testUrl4"); List<String> generatedSketchUrl = Arrays.asList("testUrl1","testUrl2","testUrl3","testUrl4");

View File

@@ -1,5 +1,8 @@
#<23><><EFBFBD><EFBFBD>application-test<73>ļ<EFBFBD>(<28><><EFBFBD>Ի<EFBFBD><D4BB><EFBFBD>) #<23><><EFBFBD><EFBFBD>application-test<73>ļ<EFBFBD>(<28><><EFBFBD>Ի<EFBFBD><D4BB><EFBFBD>)
#spring.profiles.active=test
#<23><><EFBFBD><EFBFBD>application-prod<6F>ļ<EFBFBD>(<28><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>)
spring.profiles.active=test spring.profiles.active=test
#<23><><EFBFBD><EFBFBD>application-prod<6F>ļ<EFBFBD>(<28><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>)
#spring.profiles.active=prod
#<23><><EFBFBD><EFBFBD>application-dev<65>ļ<EFBFBD>(<28><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>)
#spring.profiles.active=dev