generateSketch 功能

This commit is contained in:
徐佩
2023-08-17 11:59:19 +08:00
parent 11e9ff4e2c
commit 5b41b51859
13 changed files with 485 additions and 1 deletions

View File

@@ -0,0 +1,47 @@
package com.ai.da.controller;
import com.ai.da.common.response.Response;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.GenerateCaptionVO;
import com.ai.da.model.vo.GenerateCollectionVO;
import com.ai.da.service.GenerateService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import javax.validation.Valid;
/**
* @author XP
*/
@Api(tags = "Generate模块")
@Slf4j
@RestController
@RequestMapping("/api/generate")
public class GenerateController {
@Resource
private GenerateService generateService;
@ApiOperation("自动识别sketch的caption")
@PostMapping("/caption")
public Response<GenerateCaptionVO> generateCaption(@RequestParam Long sketchElementId){
return Response.success(generateService.generateCaption(sketchElementId));
}
@ApiOperation("通过文字、图片生成图片")
@PostMapping("/sketch")
public Response<GenerateCollectionVO> generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO){
return Response.success(generateService.generateSketchThroughImageText(generateThroughImageTextDTO));
}
}

View File

@@ -0,0 +1,7 @@
package com.ai.da.mapper;
import com.ai.da.common.config.mybatis.plus.CommonMapper;
import com.ai.da.mapper.entity.GenerateDetail;
public interface GenerateDetailMapper extends CommonMapper<GenerateDetail> {
}

View File

@@ -0,0 +1,7 @@
package com.ai.da.mapper;
import com.ai.da.common.config.mybatis.plus.CommonMapper;
import com.ai.da.mapper.entity.Generate;
public interface GenerateMapper extends CommonMapper<Generate> {
}

View File

@@ -0,0 +1,55 @@
package com.ai.da.mapper.entity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.Accessors;
import java.util.Date;
@Data
@EqualsAndHashCode(callSuper = false)
@Accessors(chain = true)
@TableName("t_generate")
public class Generate {
/**
* ID
*/
@TableId(value = "id", type = IdType.AUTO)
private Long id;
/**
* 用户ID
*/
private Long accountId;
/**
* 关联collection element id
*/
private Long collectionElementId;
/**
* caption的内容
*/
private String text;
/**
* 选择生成类型text、image、text-image
*/
private String generateType;
/**
* 创建时间
*/
private Date createDate;
/**
* 更新时间
*/
private Date updateDate;
}

View File

@@ -0,0 +1,49 @@
package com.ai.da.mapper.entity;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.Accessors;
import java.util.Date;
@Data
@EqualsAndHashCode(callSuper = false)
@Accessors(chain = true)
@TableName("t_generate_detail")
public class GenerateDetail {
/**
* ID
*/
@TableId(value = "id", type = IdType.AUTO)
private Long id;
/**
* 关联 generate ID
*/
private Long generateId;
/**
* 模型返回的图片url
*/
private String url;
/**
* 创建时间
*/
private Date createDate;
/**
* 更新时间
*/
private Date updateDate;
}

View File

@@ -0,0 +1,36 @@
package com.ai.da.model.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.NotBlank;
@Data
@ApiModel("GenerateThroughImageTextDTO")
public class GenerateThroughImageTextDTO {
@ApiModelProperty("caption")
String text;
@ApiModelProperty("图片在t_collection_element表中的id")
Long collectionElementId;
@NotBlank(message = "you have to choose the generate type")
@ApiModelProperty("text image text-image")
String generateType;
@ApiModelProperty("图片是update还是从library中选择")
String designType;
@NotBlank(message = "level1Type cannot be empty!")
@ApiModelProperty("Moodboard Printboard Sketchboard MarketingSketch")
String level1Type;
@ApiModelProperty("Outwear Dress Blouse Skirt Trousers")
String level2Type;
@NotBlank(message = "timeZone cannot be empty!")
@ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取")
String timeZone;
}

View File

@@ -0,0 +1,18 @@
package com.ai.da.model.vo;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@AllArgsConstructor
@NoArgsConstructor
@Data
@ApiModel("生成sketch的caption")
public class GenerateCaptionVO {
@ApiModelProperty("caption")
private String caption;
}

View File

@@ -0,0 +1,16 @@
package com.ai.da.model.vo;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
@Data
@ApiModel("生成 ConllectionItem响应")
public class GenerateCollectionItemVO {
@ApiModelProperty("generate生成图片的id")
private Long generateItemId;
@ApiModelProperty("generate生成图片的url")
private String generateItemUrl;
}

View File

@@ -0,0 +1,21 @@
package com.ai.da.model.vo;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import java.util.List;
@Data
@ApiModel("generate响应vo")
public class GenerateCollectionVO {
@ApiModelProperty("generateId")
private Long generateId;
@ApiModelProperty("collection")
private Long collectionId;
@ApiModelProperty("生成的图片信息")
private List<GenerateCollectionItemVO> generatedCollectionItems;
}

View File

@@ -1390,4 +1390,91 @@ public class PythonService {
originRatioList.set(1, originRatioList.get(1).multiply(BigDecimal.valueOf(high))); originRatioList.set(1, originRatioList.get(1).multiply(BigDecimal.valueOf(high)));
return originRatioList; return originRatioList;
} }
public String generateSketchCaption(String url) {
//限流校验
AccessLimitUtils.validate("generateSketchCaption",5);
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, url);
Request request = new Request.Builder()
.url(accessPythonIp+":2828/aida/interrogator")
.method("POST", body)
.addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
.addHeader("Content-Type", "application/json")
.build();
Response response = null;
String bodyStr = null;
try {
log.info("generateSketchCaption请求入参content###{}", url);
response = client.newCall(request).execute();
bodyStr = response.body().string();
} catch (IOException ioException) {
log.error("generateSketchCaption异常###{}", ExceptionUtil.getThrowableList(ioException));
}
//去除限流
AccessLimitUtils.validateOut("generateSketchCaption");
if (Objects.isNull(response)) {
log.error("generateSketchCaption异常###{}", "response or body is empty!");
throw new BusinessException("system error!");
}
JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response));
Boolean result = jsonObject.getBoolean("successful");
if (result) {
return bodyStr;
}
log.info("attribute_retrieval失败###{}", bodyStr);
//生成失败
throw new BusinessException("system error!");
}
public String generateSketch(String url,String text) {
//限流校验
AccessLimitUtils.validate("generateSketch",5);
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
MediaType mediaType = MediaType.parse("application/json");
Map<String, String> content = Maps.newHashMap();
content.put("img_url", url);
content.put("input", text);
RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content));
Request request = new Request.Builder()
.url(accessPythonIp + ":2828/aida/diffusion")
.method("POST", body)
.addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
.addHeader("Content-Type", "application/json")
.build();
Response response = null;
String bodyString = null;
try {
log.info("generateSketch请求入参content###{}", JSON.toJSONString(content));
response = client.newCall(request).execute();
bodyString = response.body().string();
} catch (IOException ioException) {
log.error("PythonService##generateSketch异常###{}", ExceptionUtil.getThrowableList(ioException));
}
//去除限流
AccessLimitUtils.validateOut("generateSketch");
if (Objects.isNull(response)) {
log.error("PythonService##generateSketch异常###{}", "response or body is empty!");
throw new BusinessException("generate sketch exception!");
}
JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response));
Boolean result = jsonObject.getBoolean("successful");
if (result) {
return bodyString;
}
log.info("generate sketch失败###{}", jsonObject);
//生成失败
throw new BusinessException("generate sketch exception!");
}
} }

View File

@@ -0,0 +1,12 @@
package com.ai.da.service;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.GenerateCaptionVO;
import com.ai.da.model.vo.GenerateCollectionVO;
public interface GenerateService {
GenerateCaptionVO generateCaption(Long sketchElementId);
GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO);
}

View File

@@ -0,0 +1,129 @@
package com.ai.da.service.impl;
import cn.hutool.core.lang.Assert;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.utils.DateUtil;
import com.ai.da.mapper.CollectionElementMapper;
import com.ai.da.mapper.GenerateDetailMapper;
import com.ai.da.mapper.GenerateMapper;
import com.ai.da.mapper.entity.CollectionElement;
import com.ai.da.mapper.entity.Generate;
import com.ai.da.mapper.entity.GenerateDetail;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.AuthPrincipalVo;
import com.ai.da.model.vo.GenerateCaptionVO;
import com.ai.da.model.vo.GenerateCollectionItemVO;
import com.ai.da.model.vo.GenerateCollectionVO;
import com.ai.da.python.PythonService;
import com.ai.da.service.GenerateService;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import io.netty.util.internal.StringUtil;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
@Service
public class GenerateServiceImpl extends ServiceImpl<GenerateMapper,Generate> implements GenerateService {
@Resource
private CollectionElementMapper collectionElementMapper;
@Resource
private GenerateDetailMapper generateDetailMapper;
@Resource
private PythonService pythonService;
@Override
public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
Assert.notNull(collectionElement,"System error!Please reselect the sketch");
Assert.isTrue("Sketchboard".equals(collectionElement.getLevel1Type()) && !StringUtil.isNullOrEmpty(collectionElement.getUrl())
,"System error!Please reselect the sketch");
// String url = collectionElement.getUrl();
// String caption = pythonService.generateSketchCaption(url);
GenerateCaptionVO recognized_caption = new GenerateCaptionVO("recognized caption");
return recognized_caption;
}
@Override
public GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、获取用户信息
AuthPrincipalVo userHolder = UserContext.getUserHolder();
Long accountId = userHolder.getId();
// 2、判断必须入参是否为非空
String generateType = generateThroughImageTextDTO.getGenerateType();
String text = generateThroughImageTextDTO.getText();
Long sketchId = generateThroughImageTextDTO.getCollectionElementId();
Generate generate = new Generate();
generate.setAccountId(accountId);
generate.setGenerateType(generateType);
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
switch(generateType){
case "text":
Assert.notNull(text,"Please input the caption");
generate.setText(text);
break;
case "image":
Assert.notNull(sketchId,"Please choose a sketch");
generate.setCollectionElementId(sketchId);
break;
case "text-image":
Assert.isTrue(!StringUtil.isNullOrEmpty(text) && Objects.nonNull(sketchId),
"Please input the caption and choose a sketch");
generate.setText(text);
generate.setCollectionElementId(sketchId);
break;
}
// 3、将请求信息落库
// 3.1 sketch在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementMapper.selectById(sketchId);
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(generateThroughImageTextDTO.getLevel2Type()) ){
collectionElement.setLevel2Type(generateThroughImageTextDTO.getLevel2Type());
QueryWrapper<CollectionElement> queryWrapper = new QueryWrapper<>();
queryWrapper.eq("id", sketchId);
collectionElementMapper.update(collectionElement,queryWrapper);
}
// 3.2 将本次generate的请求信息添加到t_generate表中
save(generate);
// 4、向模型发起请求
// String generatedSketchUrl = pythonService.generateSketch(collectionElement.getUrl(), text);
List<String> generatedSketchUrl = Arrays.asList("testUrl1","testUrl2","testUrl3","testUrl4");
GenerateCollectionVO generateCollectionVO = new GenerateCollectionVO();
List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
generateCollectionVO.setGenerateId(generate.getId());
generateCollectionVO.setCollectionId(collectionElement.getCollectionId());
generateCollectionVO.setGeneratedCollectionItems(generatedCollectionItems);
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
generatedSketchUrl.forEach(item -> {
GenerateDetail generateDetail = new GenerateDetail();
generateDetail.setUrl(item);
generateDetail.setGenerateId(generate.getId());
generateDetail.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
generateDetailMapper.insert(generateDetail);
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
generateCollectionItemVO.setGenerateItemId(generateDetail.getId());
generateCollectionItemVO.setGenerateItemUrl(item);
generatedCollectionItems.add(generateCollectionItemVO);
});
// 6、将模型返回的图片地址返回给前端
return generateCollectionVO;
}
}

View File

@@ -1,4 +1,4 @@
server.port=7788 server.port=5567
#datasource #datasource
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver