From 5b41b51859bbf038cb67c33056a5ca0b3e3c5baf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E4=BD=A9?= <1779019091@qq.com> Date: Thu, 17 Aug 2023 11:59:19 +0800 Subject: [PATCH] =?UTF-8?q?generateSketch=20=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ai/da/controller/GenerateController.java | 47 +++++++ .../ai/da/mapper/GenerateDetailMapper.java | 7 + .../java/com/ai/da/mapper/GenerateMapper.java | 7 + .../com/ai/da/mapper/entity/Generate.java | 55 ++++++++ .../ai/da/mapper/entity/GenerateDetail.java | 49 +++++++ .../dto/GenerateThroughImageTextDTO.java | 36 +++++ .../com/ai/da/model/vo/GenerateCaptionVO.java | 18 +++ .../da/model/vo/GenerateCollectionItemVO.java | 16 +++ .../ai/da/model/vo/GenerateCollectionVO.java | 21 +++ .../java/com/ai/da/python/PythonService.java | 87 ++++++++++++ .../com/ai/da/service/GenerateService.java | 12 ++ .../da/service/impl/GenerateServiceImpl.java | 129 ++++++++++++++++++ .../resources/application-test.properties | 2 +- 13 files changed, 485 insertions(+), 1 deletion(-) create mode 100644 src/main/java/com/ai/da/controller/GenerateController.java create mode 100644 src/main/java/com/ai/da/mapper/GenerateDetailMapper.java create mode 100644 src/main/java/com/ai/da/mapper/GenerateMapper.java create mode 100644 src/main/java/com/ai/da/mapper/entity/Generate.java create mode 100644 src/main/java/com/ai/da/mapper/entity/GenerateDetail.java create mode 100644 src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java create mode 100644 src/main/java/com/ai/da/model/vo/GenerateCaptionVO.java create mode 100644 src/main/java/com/ai/da/model/vo/GenerateCollectionItemVO.java create mode 100644 src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java create mode 100644 src/main/java/com/ai/da/service/GenerateService.java create mode 100644 src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java new file mode 100644 index 00000000..650ddabd --- /dev/null +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -0,0 +1,47 @@ +package com.ai.da.controller; + +import com.ai.da.common.response.Response; +import com.ai.da.model.dto.GenerateThroughImageTextDTO; +import com.ai.da.model.vo.GenerateCaptionVO; +import com.ai.da.model.vo.GenerateCollectionVO; +import com.ai.da.service.GenerateService; +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import lombok.extern.slf4j.Slf4j; +import org.springframework.web.bind.annotation.*; + +import javax.annotation.Resource; +import javax.validation.Valid; + +/** + * @author XP + */ +@Api(tags = "Generate模块") +@Slf4j +@RestController +@RequestMapping("/api/generate") +public class GenerateController { + + @Resource + private GenerateService generateService; + + @ApiOperation("自动识别sketch的caption") + @PostMapping("/caption") + public Response generateCaption(@RequestParam Long sketchElementId){ + return Response.success(generateService.generateCaption(sketchElementId)); + } + + + @ApiOperation("通过文字、图片生成图片") + @PostMapping("/sketch") + public Response generateThroughImageText(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO){ + return Response.success(generateService.generateSketchThroughImageText(generateThroughImageTextDTO)); + } + + + + + + + +} diff --git a/src/main/java/com/ai/da/mapper/GenerateDetailMapper.java b/src/main/java/com/ai/da/mapper/GenerateDetailMapper.java new file mode 100644 index 00000000..3b6cf30f --- /dev/null +++ b/src/main/java/com/ai/da/mapper/GenerateDetailMapper.java @@ -0,0 +1,7 @@ +package com.ai.da.mapper; + +import com.ai.da.common.config.mybatis.plus.CommonMapper; +import com.ai.da.mapper.entity.GenerateDetail; + +public interface GenerateDetailMapper extends CommonMapper { +} diff --git a/src/main/java/com/ai/da/mapper/GenerateMapper.java b/src/main/java/com/ai/da/mapper/GenerateMapper.java new file mode 100644 index 00000000..43436d81 --- /dev/null +++ b/src/main/java/com/ai/da/mapper/GenerateMapper.java @@ -0,0 +1,7 @@ +package com.ai.da.mapper; + +import com.ai.da.common.config.mybatis.plus.CommonMapper; +import com.ai.da.mapper.entity.Generate; + +public interface GenerateMapper extends CommonMapper { +} diff --git a/src/main/java/com/ai/da/mapper/entity/Generate.java b/src/main/java/com/ai/da/mapper/entity/Generate.java new file mode 100644 index 00000000..5e21ce85 --- /dev/null +++ b/src/main/java/com/ai/da/mapper/entity/Generate.java @@ -0,0 +1,55 @@ +package com.ai.da.mapper.entity; + + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.Accessors; + +import java.util.Date; + +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(chain = true) +@TableName("t_generate") +public class Generate { + + /** + * ID + */ + @TableId(value = "id", type = IdType.AUTO) + private Long id; + + /** + * 用户ID + */ + private Long accountId; + + /** + * 关联collection element id + */ + private Long collectionElementId; + + /** + * caption的内容 + */ + private String text; + + /** + * 选择生成类型:text、image、text-image + */ + private String generateType; + + /** + * 创建时间 + */ + private Date createDate; + + /** + * 更新时间 + */ + private Date updateDate; + +} diff --git a/src/main/java/com/ai/da/mapper/entity/GenerateDetail.java b/src/main/java/com/ai/da/mapper/entity/GenerateDetail.java new file mode 100644 index 00000000..535856ca --- /dev/null +++ b/src/main/java/com/ai/da/mapper/entity/GenerateDetail.java @@ -0,0 +1,49 @@ +package com.ai.da.mapper.entity; + + +import com.baomidou.mybatisplus.annotation.IdType; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.experimental.Accessors; + +import java.util.Date; + +@Data +@EqualsAndHashCode(callSuper = false) +@Accessors(chain = true) +@TableName("t_generate_detail") +public class GenerateDetail { + + /** + * ID + */ + @TableId(value = "id", type = IdType.AUTO) + private Long id; + + /** + * 关联 generate ID + */ + private Long generateId; + + /** + * 模型返回的图片url + */ + private String url; + + /** + * 创建时间 + */ + private Date createDate; + + /** + * 更新时间 + */ + private Date updateDate; + + + + + +} diff --git a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java new file mode 100644 index 00000000..02c38037 --- /dev/null +++ b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java @@ -0,0 +1,36 @@ +package com.ai.da.model.dto; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import javax.validation.constraints.NotBlank; + +@Data +@ApiModel("GenerateThroughImageTextDTO") +public class GenerateThroughImageTextDTO { + + @ApiModelProperty("caption") + String text; + + @ApiModelProperty("图片在t_collection_element表中的id") + Long collectionElementId; + + @NotBlank(message = "you have to choose the generate type") + @ApiModelProperty("text image text-image") + String generateType; + + @ApiModelProperty("图片是update,还是从library中选择") + String designType; + + @NotBlank(message = "level1Type cannot be empty!") + @ApiModelProperty("Moodboard Printboard Sketchboard MarketingSketch") + String level1Type; + + @ApiModelProperty("Outwear Dress Blouse Skirt Trousers") + String level2Type; + + @NotBlank(message = "timeZone cannot be empty!") + @ApiModelProperty("本地时区,比如 'Asia/Tokyo' 东京时间 , 'Asia/Shanghai' 北京时间 由js本地获取") + String timeZone; +} diff --git a/src/main/java/com/ai/da/model/vo/GenerateCaptionVO.java b/src/main/java/com/ai/da/model/vo/GenerateCaptionVO.java new file mode 100644 index 00000000..ad725488 --- /dev/null +++ b/src/main/java/com/ai/da/model/vo/GenerateCaptionVO.java @@ -0,0 +1,18 @@ +package com.ai.da.model.vo; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@AllArgsConstructor +@NoArgsConstructor +@Data +@ApiModel("生成sketch的caption") +public class GenerateCaptionVO { + + @ApiModelProperty("caption") + private String caption; + +} diff --git a/src/main/java/com/ai/da/model/vo/GenerateCollectionItemVO.java b/src/main/java/com/ai/da/model/vo/GenerateCollectionItemVO.java new file mode 100644 index 00000000..7079f6cb --- /dev/null +++ b/src/main/java/com/ai/da/model/vo/GenerateCollectionItemVO.java @@ -0,0 +1,16 @@ +package com.ai.da.model.vo; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +@Data +@ApiModel("生成 ConllectionItem响应") +public class GenerateCollectionItemVO { + + @ApiModelProperty("generate生成图片的id") + private Long generateItemId; + + @ApiModelProperty("generate生成图片的url") + private String generateItemUrl; +} diff --git a/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java b/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java new file mode 100644 index 00000000..7664ffe3 --- /dev/null +++ b/src/main/java/com/ai/da/model/vo/GenerateCollectionVO.java @@ -0,0 +1,21 @@ +package com.ai.da.model.vo; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.util.List; + +@Data +@ApiModel("generate响应vo") +public class GenerateCollectionVO { + + @ApiModelProperty("generateId") + private Long generateId; + + @ApiModelProperty("collection") + private Long collectionId; + + @ApiModelProperty("生成的图片信息") + private List generatedCollectionItems; +} diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index 60f32aa8..0fc80aeb 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -1390,4 +1390,91 @@ public class PythonService { originRatioList.set(1, originRatioList.get(1).multiply(BigDecimal.valueOf(high))); return originRatioList; } + + public String generateSketchCaption(String url) { + //限流校验 + AccessLimitUtils.validate("generateSketchCaption",5); + OkHttpClient client = new OkHttpClient().newBuilder() + .connectTimeout(30, TimeUnit.SECONDS) + .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒) + .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒) + .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) + .build(); + MediaType mediaType = MediaType.parse("application/json"); + RequestBody body = RequestBody.create(mediaType, url); + Request request = new Request.Builder() + .url(accessPythonIp+":2828/aida/interrogator") + .method("POST", body) + .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==") + .addHeader("Content-Type", "application/json") + .build(); + Response response = null; + String bodyStr = null; + try { + log.info("generateSketchCaption请求入参content###{}", url); + response = client.newCall(request).execute(); + bodyStr = response.body().string(); + } catch (IOException ioException) { + log.error("generateSketchCaption异常###{}", ExceptionUtil.getThrowableList(ioException)); + } + //去除限流 + AccessLimitUtils.validateOut("generateSketchCaption"); + if (Objects.isNull(response)) { + log.error("generateSketchCaption异常###{}", "response or body is empty!"); + throw new BusinessException("system error!"); + } + JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response)); + Boolean result = jsonObject.getBoolean("successful"); + if (result) { + return bodyStr; + } + log.info("attribute_retrieval失败###{}", bodyStr); + //生成失败 + throw new BusinessException("system error!"); + } + + public String generateSketch(String url,String text) { + //限流校验 + AccessLimitUtils.validate("generateSketch",5); + OkHttpClient client = new OkHttpClient().newBuilder() + .connectTimeout(30, TimeUnit.SECONDS) + .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒) + .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒) + .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) + .build(); + MediaType mediaType = MediaType.parse("application/json"); + Map content = Maps.newHashMap(); + content.put("img_url", url); + content.put("input", text); + RequestBody body = RequestBody.create(mediaType, JSON.toJSONString(content)); + Request request = new Request.Builder() + .url(accessPythonIp + ":2828/aida/diffusion") + .method("POST", body) + .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==") + .addHeader("Content-Type", "application/json") + .build(); + Response response = null; + String bodyString = null; + try { + log.info("generateSketch请求入参content###{}", JSON.toJSONString(content)); + response = client.newCall(request).execute(); + bodyString = response.body().string(); + } catch (IOException ioException) { + log.error("PythonService##generateSketch异常###{}", ExceptionUtil.getThrowableList(ioException)); + } + //去除限流 + AccessLimitUtils.validateOut("generateSketch"); + if (Objects.isNull(response)) { + log.error("PythonService##generateSketch异常###{}", "response or body is empty!"); + throw new BusinessException("generate sketch exception!"); + } + JSONObject jsonObject = JSON.parseObject(JSON.toJSONString(response)); + Boolean result = jsonObject.getBoolean("successful"); + if (result) { + return bodyString; + } + log.info("generate sketch失败###{}", jsonObject); + //生成失败 + throw new BusinessException("generate sketch exception!"); + } } diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java new file mode 100644 index 00000000..af8c75bb --- /dev/null +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -0,0 +1,12 @@ +package com.ai.da.service; + +import com.ai.da.model.dto.GenerateThroughImageTextDTO; +import com.ai.da.model.vo.GenerateCaptionVO; +import com.ai.da.model.vo.GenerateCollectionVO; + +public interface GenerateService { + + GenerateCaptionVO generateCaption(Long sketchElementId); + + GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO); +} diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java new file mode 100644 index 00000000..cb2dbda5 --- /dev/null +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -0,0 +1,129 @@ +package com.ai.da.service.impl; + +import cn.hutool.core.lang.Assert; +import com.ai.da.common.context.UserContext; +import com.ai.da.common.utils.DateUtil; +import com.ai.da.mapper.CollectionElementMapper; +import com.ai.da.mapper.GenerateDetailMapper; +import com.ai.da.mapper.GenerateMapper; +import com.ai.da.mapper.entity.CollectionElement; +import com.ai.da.mapper.entity.Generate; +import com.ai.da.mapper.entity.GenerateDetail; +import com.ai.da.model.dto.GenerateThroughImageTextDTO; +import com.ai.da.model.vo.AuthPrincipalVo; +import com.ai.da.model.vo.GenerateCaptionVO; +import com.ai.da.model.vo.GenerateCollectionItemVO; +import com.ai.da.model.vo.GenerateCollectionVO; +import com.ai.da.python.PythonService; +import com.ai.da.service.GenerateService; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import io.netty.util.internal.StringUtil; +import org.springframework.stereotype.Service; + +import javax.annotation.Resource; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +@Service +public class GenerateServiceImpl extends ServiceImpl implements GenerateService { + + @Resource + private CollectionElementMapper collectionElementMapper; + + @Resource + private GenerateDetailMapper generateDetailMapper; + + @Resource + private PythonService pythonService; + + @Override + public GenerateCaptionVO generateCaption(Long sketchElementId) { + CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId); + Assert.notNull(collectionElement,"System error!Please reselect the sketch"); + Assert.isTrue("Sketchboard".equals(collectionElement.getLevel1Type()) && !StringUtil.isNullOrEmpty(collectionElement.getUrl()) + ,"System error!Please reselect the sketch"); +// String url = collectionElement.getUrl(); +// String caption = pythonService.generateSketchCaption(url); + GenerateCaptionVO recognized_caption = new GenerateCaptionVO("recognized caption"); + + return recognized_caption; + } + + @Override + public GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + // 1、获取用户信息 + AuthPrincipalVo userHolder = UserContext.getUserHolder(); + Long accountId = userHolder.getId(); + + // 2、判断必须入参是否为非空 + String generateType = generateThroughImageTextDTO.getGenerateType(); + String text = generateThroughImageTextDTO.getText(); + Long sketchId = generateThroughImageTextDTO.getCollectionElementId(); + + Generate generate = new Generate(); + generate.setAccountId(accountId); + generate.setGenerateType(generateType); + generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone())); + + switch(generateType){ + case "text": + Assert.notNull(text,"Please input the caption"); + generate.setText(text); + break; + case "image": + Assert.notNull(sketchId,"Please choose a sketch"); + generate.setCollectionElementId(sketchId); + break; + case "text-image": + Assert.isTrue(!StringUtil.isNullOrEmpty(text) && Objects.nonNull(sketchId), + "Please input the caption and choose a sketch"); + generate.setText(text); + generate.setCollectionElementId(sketchId); + break; + } + + // 3、将请求信息落库 + // 3.1 sketch在t_collection_element表中的信息是否需要更新 如 level2Type + CollectionElement collectionElement = collectionElementMapper.selectById(sketchId); + if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(generateThroughImageTextDTO.getLevel2Type()) ){ + collectionElement.setLevel2Type(generateThroughImageTextDTO.getLevel2Type()); + QueryWrapper queryWrapper = new QueryWrapper<>(); + queryWrapper.eq("id", sketchId); + collectionElementMapper.update(collectionElement,queryWrapper); + } + + // 3.2 将本次generate的请求信息添加到t_generate表中 + save(generate); + + // 4、向模型发起请求 +// String generatedSketchUrl = pythonService.generateSketch(collectionElement.getUrl(), text); + + List generatedSketchUrl = Arrays.asList("testUrl1","testUrl2","testUrl3","testUrl4"); + + GenerateCollectionVO generateCollectionVO = new GenerateCollectionVO(); + List generatedCollectionItems = new ArrayList<>(); + generateCollectionVO.setGenerateId(generate.getId()); + generateCollectionVO.setCollectionId(collectionElement.getCollectionId()); + generateCollectionVO.setGeneratedCollectionItems(generatedCollectionItems); + // 5、处理模型返回的数据 + // 5.1 将相应的url保存到数据库 + generatedSketchUrl.forEach(item -> { + GenerateDetail generateDetail = new GenerateDetail(); + generateDetail.setUrl(item); + generateDetail.setGenerateId(generate.getId()); + generateDetail.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone())); + generateDetailMapper.insert(generateDetail); + + GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); + generateCollectionItemVO.setGenerateItemId(generateDetail.getId()); + generateCollectionItemVO.setGenerateItemUrl(item); + generatedCollectionItems.add(generateCollectionItemVO); + }); + + // 6、将模型返回的图片地址返回给前端 + return generateCollectionVO; + } +} diff --git a/src/main/resources/application-test.properties b/src/main/resources/application-test.properties index 4a364c6c..3ec14c9d 100644 --- a/src/main/resources/application-test.properties +++ b/src/main/resources/application-test.properties @@ -1,4 +1,4 @@ -server.port=7788 +server.port=5567 #datasource spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver