generateSketch 功能
This commit is contained in:
47
src/main/java/com/ai/da/controller/GenerateController.java
Normal file
47
src/main/java/com/ai/da/controller/GenerateController.java
Normal file
@@ -0,0 +1,47 @@
|
||||
package com.ai.da.controller;
|
||||
|
||||
import com.ai.da.common.response.Response;
|
||||
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
||||
import com.ai.da.model.vo.GenerateCaptionVO;
|
||||
import com.ai.da.model.vo.GenerateCollectionVO;
|
||||
import com.ai.da.service.GenerateService;
|
||||
import io.swagger.annotations.Api;
|
||||
import io.swagger.annotations.ApiOperation;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import javax.validation.Valid;
|
||||
|
||||
/**
|
||||
* @author XP
|
||||
*/
|
||||
@Api(tags = "Generate模块")
|
||||
@Slf4j
|
||||
@RestController
|
||||
@RequestMapping("/api/generate")
|
||||
public class GenerateController {
|
||||
|
||||
@Resource
|
||||
private GenerateService generateService;
|
||||
|
||||
@ApiOperation("自动识别sketch的caption")
|
||||
@PostMapping("/caption")
|
||||
public Response<GenerateCaptionVO> generateCaption(@RequestParam Long sketchElementId){
|
||||
return Response.success(generateService.generateCaption(sketchElementId));
|
||||
}
|
||||
|
||||
|
||||
@ApiOperation("通过文字、图片生成图片")
|
||||
@PostMapping("/sketch")
|
||||
public Response<GenerateCollectionVO> generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO){
|
||||
return Response.success(generateService.generateSketchThroughImageText(generateThroughImageTextDTO));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
7
src/main/java/com/ai/da/mapper/GenerateDetailMapper.java
Normal file
7
src/main/java/com/ai/da/mapper/GenerateDetailMapper.java
Normal file
@@ -0,0 +1,7 @@
|
||||
package com.ai.da.mapper;
|
||||
|
||||
import com.ai.da.common.config.mybatis.plus.CommonMapper;
|
||||
import com.ai.da.mapper.entity.GenerateDetail;
|
||||
|
||||
public interface GenerateDetailMapper extends CommonMapper<GenerateDetail> {
|
||||
}
|
||||
7
src/main/java/com/ai/da/mapper/GenerateMapper.java
Normal file
7
src/main/java/com/ai/da/mapper/GenerateMapper.java
Normal file
@@ -0,0 +1,7 @@
|
||||
package com.ai.da.mapper;
|
||||
|
||||
import com.ai.da.common.config.mybatis.plus.CommonMapper;
|
||||
import com.ai.da.mapper.entity.Generate;
|
||||
|
||||
public interface GenerateMapper extends CommonMapper<Generate> {
|
||||
}
|
||||
55
src/main/java/com/ai/da/mapper/entity/Generate.java
Normal file
55
src/main/java/com/ai/da/mapper/entity/Generate.java
Normal file
@@ -0,0 +1,55 @@
|
||||
package com.ai.da.mapper.entity;
|
||||
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Accessors(chain = true)
|
||||
@TableName("t_generate")
|
||||
public class Generate {
|
||||
|
||||
/**
|
||||
* ID
|
||||
*/
|
||||
@TableId(value = "id", type = IdType.AUTO)
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* 用户ID
|
||||
*/
|
||||
private Long accountId;
|
||||
|
||||
/**
|
||||
* 关联collection element id
|
||||
*/
|
||||
private Long collectionElementId;
|
||||
|
||||
/**
|
||||
* caption的内容
|
||||
*/
|
||||
private String text;
|
||||
|
||||
/**
|
||||
* 选择生成类型:text、image、text-image
|
||||
*/
|
||||
private String generateType;
|
||||
|
||||
/**
|
||||
* 创建时间
|
||||
*/
|
||||
private Date createDate;
|
||||
|
||||
/**
|
||||
* 更新时间
|
||||
*/
|
||||
private Date updateDate;
|
||||
|
||||
}
|
||||
49
src/main/java/com/ai/da/mapper/entity/GenerateDetail.java
Normal file
49
src/main/java/com/ai/da/mapper/entity/GenerateDetail.java
Normal file
@@ -0,0 +1,49 @@
|
||||
package com.ai.da.mapper.entity;
|
||||
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@Accessors(chain = true)
|
||||
@TableName("t_generate_detail")
|
||||
public class GenerateDetail {
|
||||
|
||||
/**
|
||||
* ID
|
||||
*/
|
||||
@TableId(value = "id", type = IdType.AUTO)
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* 关联 generate ID
|
||||
*/
|
||||
private Long generateId;
|
||||
|
||||
/**
|
||||
* 模型返回的图片url
|
||||
*/
|
||||
private String url;
|
||||
|
||||
/**
|
||||
* 创建时间
|
||||
*/
|
||||
private Date createDate;
|
||||
|
||||
/**
|
||||
* 更新时间
|
||||
*/
|
||||
private Date updateDate;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.ai.da.model.dto;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import javax.validation.constraints.NotBlank;
|
||||
|
||||
@Data
|
||||
@ApiModel("GenerateThroughImageTextDTO")
|
||||
public class GenerateThroughImageTextDTO {
|
||||
|
||||
@ApiModelProperty("caption")
|
||||
String text;
|
||||
|
||||
@ApiModelProperty("图片在t_collection_element表中的id")
|
||||
Long collectionElementId;
|
||||
|
||||
@NotBlank(message = "you have to choose the generate type")
|
||||
@ApiModelProperty("text image text-image")
|
||||
String generateType;
|
||||
|
||||
@ApiModelProperty("图片是update,还是从library中选择")
|
||||
String designType;
|
||||
|
||||
@NotBlank(message = "level1Type cannot be empty!")
|
||||
@ApiModelProperty("Moodboard Printboard Sketchboard MarketingSketch")
|
||||
String level1Type;
|
||||
|
||||
@ApiModelProperty("Outwear Dress Blouse Skirt Trousers")
|
||||
String level2Type;
|
||||
|
||||
@NotBlank(message = "timeZone cannot be empty!")
|
||||
@ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取")
|
||||
String timeZone;
|
||||
}
|
||||
18
src/main/java/com/ai/da/model/vo/GenerateCaptionVO.java
Normal file
18
src/main/java/com/ai/da/model/vo/GenerateCaptionVO.java
Normal file
@@ -0,0 +1,18 @@
|
||||
package com.ai.da.model.vo;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
@ApiModel("生成sketch的caption")
|
||||
public class GenerateCaptionVO {
|
||||
|
||||
@ApiModelProperty("caption")
|
||||
private String caption;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.ai.da.model.vo;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@ApiModel("生成 ConllectionItem响应")
|
||||
public class GenerateCollectionItemVO {
|
||||
|
||||
@ApiModelProperty("generate生成图片的id")
|
||||
private Long generateItemId;
|
||||
|
||||
@ApiModelProperty("generate生成图片的url")
|
||||
private String generateItemUrl;
|
||||
}
|
||||
21
src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java
Normal file
21
src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java
Normal file
@@ -0,0 +1,21 @@
|
||||
package com.ai.da.model.vo;
|
||||
|
||||
import io.swagger.annotations.ApiModel;
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@ApiModel("generate响应vo")
|
||||
public class GenerateCollectionVO {
|
||||
|
||||
@ApiModelProperty("generateId")
|
||||
private Long generateId;
|
||||
|
||||
@ApiModelProperty("collection")
|
||||
private Long collectionId;
|
||||
|
||||
@ApiModelProperty("生成的图片信息")
|
||||
private List<GenerateCollectionItemVO> generatedCollectionItems;
|
||||
}
|
||||
@@ -1390,4 +1390,91 @@ public class PythonService {
|
||||
originRatioList.set(1, originRatioList.get(1).multiply(BigDecimal.valueOf(high)));
|
||||
return originRatioList;
|
||||
}
|
||||
|
||||
public String generateSketchCaption(String url) {
|
||||
//限流校验
|
||||
AccessLimitUtils.validate("generateSketchCaption",5);
|
||||
OkHttpClient client = new OkHttpClient().newBuilder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
|
||||
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
|
||||
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
|
||||
.build();
|
||||
MediaType mediaType = MediaType.parse("application/json");
|
||||
RequestBody body = RequestBody.create(mediaType, url);
|
||||
Request request = new Request.Builder()
|
||||
.url(accessPythonIp+":2828/aida/interrogator")
|
||||
.method("POST", body)
|
||||
.addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
|
||||
.addHeader("Content-Type", "application/json")
|
||||
.build();
|
||||
Response response = null;
|
||||
String bodyStr = null;
|
||||
try {
|
||||
log.info("generateSketchCaption请求入参content###{}", url);
|
||||
response = client.newCall(request).execute();
|
||||
bodyStr = response.body().string();
|
||||
} catch (IOException ioException) {
|
||||
log.error("generateSketchCaption异常###{}", ExceptionUtil.getThrowableList(ioException));
|
||||
}
|
||||
//去除限流
|
||||
AccessLimitUtils.validateOut("generateSketchCaption");
|
||||
if (Objects.isNull(response)) {
|
||||
log.error("generateSketchCaption异常###{}", "response or body is empty!");
|
||||
throw new BusinessException("system error!");
|
||||
}
|
||||
JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response));
|
||||
Boolean result = jsonObject.getBoolean("successful");
|
||||
if (result) {
|
||||
return bodyStr;
|
||||
}
|
||||
log.info("attribute_retrieval失败###{}", bodyStr);
|
||||
//生成失败
|
||||
throw new BusinessException("system error!");
|
||||
}
|
||||
|
||||
public String generateSketch(String url,String text) {
|
||||
//限流校验
|
||||
AccessLimitUtils.validate("generateSketch",5);
|
||||
OkHttpClient client = new OkHttpClient().newBuilder()
|
||||
.connectTimeout(30, TimeUnit.SECONDS)
|
||||
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
|
||||
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
|
||||
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
|
||||
.build();
|
||||
MediaType mediaType = MediaType.parse("application/json");
|
||||
Map<String, String> content = Maps.newHashMap();
|
||||
content.put("img_url", url);
|
||||
content.put("input", text);
|
||||
RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content));
|
||||
Request request = new Request.Builder()
|
||||
.url(accessPythonIp + ":2828/aida/diffusion")
|
||||
.method("POST", body)
|
||||
.addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
|
||||
.addHeader("Content-Type", "application/json")
|
||||
.build();
|
||||
Response response = null;
|
||||
String bodyString = null;
|
||||
try {
|
||||
log.info("generateSketch请求入参content###{}", JSON.toJSONString(content));
|
||||
response = client.newCall(request).execute();
|
||||
bodyString = response.body().string();
|
||||
} catch (IOException ioException) {
|
||||
log.error("PythonService##generateSketch异常###{}", ExceptionUtil.getThrowableList(ioException));
|
||||
}
|
||||
//去除限流
|
||||
AccessLimitUtils.validateOut("generateSketch");
|
||||
if (Objects.isNull(response)) {
|
||||
log.error("PythonService##generateSketch异常###{}", "response or body is empty!");
|
||||
throw new BusinessException("generate sketch exception!");
|
||||
}
|
||||
JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response));
|
||||
Boolean result = jsonObject.getBoolean("successful");
|
||||
if (result) {
|
||||
return bodyString;
|
||||
}
|
||||
log.info("generate sketch失败###{}", jsonObject);
|
||||
//生成失败
|
||||
throw new BusinessException("generate sketch exception!");
|
||||
}
|
||||
}
|
||||
|
||||
12
src/main/java/com/ai/da/service/GenerateService.java
Normal file
12
src/main/java/com/ai/da/service/GenerateService.java
Normal file
@@ -0,0 +1,12 @@
|
||||
package com.ai.da.service;
|
||||
|
||||
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
||||
import com.ai.da.model.vo.GenerateCaptionVO;
|
||||
import com.ai.da.model.vo.GenerateCollectionVO;
|
||||
|
||||
public interface GenerateService {
|
||||
|
||||
GenerateCaptionVO generateCaption(Long sketchElementId);
|
||||
|
||||
GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO);
|
||||
}
|
||||
129
src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java
Normal file
129
src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java
Normal file
@@ -0,0 +1,129 @@
|
||||
package com.ai.da.service.impl;
|
||||
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import com.ai.da.common.context.UserContext;
|
||||
import com.ai.da.common.utils.DateUtil;
|
||||
import com.ai.da.mapper.CollectionElementMapper;
|
||||
import com.ai.da.mapper.GenerateDetailMapper;
|
||||
import com.ai.da.mapper.GenerateMapper;
|
||||
import com.ai.da.mapper.entity.CollectionElement;
|
||||
import com.ai.da.mapper.entity.Generate;
|
||||
import com.ai.da.mapper.entity.GenerateDetail;
|
||||
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
|
||||
import com.ai.da.model.vo.AuthPrincipalVo;
|
||||
import com.ai.da.model.vo.GenerateCaptionVO;
|
||||
import com.ai.da.model.vo.GenerateCollectionItemVO;
|
||||
import com.ai.da.model.vo.GenerateCollectionVO;
|
||||
import com.ai.da.python.PythonService;
|
||||
import com.ai.da.service.GenerateService;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import io.netty.util.internal.StringUtil;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@Service
|
||||
public class GenerateServiceImpl extends ServiceImpl<GenerateMapper,Generate> implements GenerateService {
|
||||
|
||||
@Resource
|
||||
private CollectionElementMapper collectionElementMapper;
|
||||
|
||||
@Resource
|
||||
private GenerateDetailMapper generateDetailMapper;
|
||||
|
||||
@Resource
|
||||
private PythonService pythonService;
|
||||
|
||||
@Override
|
||||
public GenerateCaptionVO generateCaption(Long sketchElementId) {
|
||||
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
|
||||
Assert.notNull(collectionElement,"System error!Please reselect the sketch");
|
||||
Assert.isTrue("Sketchboard".equals(collectionElement.getLevel1Type()) && !StringUtil.isNullOrEmpty(collectionElement.getUrl())
|
||||
,"System error!Please reselect the sketch");
|
||||
// String url = collectionElement.getUrl();
|
||||
// String caption = pythonService.generateSketchCaption(url);
|
||||
GenerateCaptionVO recognized_caption = new GenerateCaptionVO("recognized caption");
|
||||
|
||||
return recognized_caption;
|
||||
}
|
||||
|
||||
@Override
|
||||
public GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
|
||||
// 1、获取用户信息
|
||||
AuthPrincipalVo userHolder = UserContext.getUserHolder();
|
||||
Long accountId = userHolder.getId();
|
||||
|
||||
// 2、判断必须入参是否为非空
|
||||
String generateType = generateThroughImageTextDTO.getGenerateType();
|
||||
String text = generateThroughImageTextDTO.getText();
|
||||
Long sketchId = generateThroughImageTextDTO.getCollectionElementId();
|
||||
|
||||
Generate generate = new Generate();
|
||||
generate.setAccountId(accountId);
|
||||
generate.setGenerateType(generateType);
|
||||
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
|
||||
|
||||
switch(generateType){
|
||||
case "text":
|
||||
Assert.notNull(text,"Please input the caption");
|
||||
generate.setText(text);
|
||||
break;
|
||||
case "image":
|
||||
Assert.notNull(sketchId,"Please choose a sketch");
|
||||
generate.setCollectionElementId(sketchId);
|
||||
break;
|
||||
case "text-image":
|
||||
Assert.isTrue(!StringUtil.isNullOrEmpty(text) && Objects.nonNull(sketchId),
|
||||
"Please input the caption and choose a sketch");
|
||||
generate.setText(text);
|
||||
generate.setCollectionElementId(sketchId);
|
||||
break;
|
||||
}
|
||||
|
||||
// 3、将请求信息落库
|
||||
// 3.1 sketch在t_collection_element表中的信息是否需要更新 如 level2Type
|
||||
CollectionElement collectionElement = collectionElementMapper.selectById(sketchId);
|
||||
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(generateThroughImageTextDTO.getLevel2Type()) ){
|
||||
collectionElement.setLevel2Type(generateThroughImageTextDTO.getLevel2Type());
|
||||
QueryWrapper<CollectionElement> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.eq("id", sketchId);
|
||||
collectionElementMapper.update(collectionElement,queryWrapper);
|
||||
}
|
||||
|
||||
// 3.2 将本次generate的请求信息添加到t_generate表中
|
||||
save(generate);
|
||||
|
||||
// 4、向模型发起请求
|
||||
// String generatedSketchUrl = pythonService.generateSketch(collectionElement.getUrl(), text);
|
||||
|
||||
List<String> generatedSketchUrl = Arrays.asList("testUrl1","testUrl2","testUrl3","testUrl4");
|
||||
|
||||
GenerateCollectionVO generateCollectionVO = new GenerateCollectionVO();
|
||||
List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
|
||||
generateCollectionVO.setGenerateId(generate.getId());
|
||||
generateCollectionVO.setCollectionId(collectionElement.getCollectionId());
|
||||
generateCollectionVO.setGeneratedCollectionItems(generatedCollectionItems);
|
||||
// 5、处理模型返回的数据
|
||||
// 5.1 将相应的url保存到数据库
|
||||
generatedSketchUrl.forEach(item -> {
|
||||
GenerateDetail generateDetail = new GenerateDetail();
|
||||
generateDetail.setUrl(item);
|
||||
generateDetail.setGenerateId(generate.getId());
|
||||
generateDetail.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
|
||||
generateDetailMapper.insert(generateDetail);
|
||||
|
||||
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
|
||||
generateCollectionItemVO.setGenerateItemId(generateDetail.getId());
|
||||
generateCollectionItemVO.setGenerateItemUrl(item);
|
||||
generatedCollectionItems.add(generateCollectionItemVO);
|
||||
});
|
||||
|
||||
// 6、将模型返回的图片地址返回给前端
|
||||
return generateCollectionVO;
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
server.port=7788
|
||||
server.port=5567
|
||||
|
||||
#datasource
|
||||
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver
|
||||
|
||||
Reference in New Issue
Block a user