generateSketch 功能

This commit is contained in:
徐佩
2023-08-17 11:59:19 +08:00
parent 11e9ff4e2c
commit 5b41b51859
13 changed files with 485 additions and 1 deletions

View File

@@ -0,0 +1,12 @@
package com.ai.da.service;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.GenerateCaptionVO;
import com.ai.da.model.vo.GenerateCollectionVO;
public interface GenerateService {
GenerateCaptionVO generateCaption(Long sketchElementId);
GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO);
}

View File

@@ -0,0 +1,129 @@
package com.ai.da.service.impl;
import cn.hutool.core.lang.Assert;
import com.ai.da.common.context.UserContext;
import com.ai.da.common.utils.DateUtil;
import com.ai.da.mapper.CollectionElementMapper;
import com.ai.da.mapper.GenerateDetailMapper;
import com.ai.da.mapper.GenerateMapper;
import com.ai.da.mapper.entity.CollectionElement;
import com.ai.da.mapper.entity.Generate;
import com.ai.da.mapper.entity.GenerateDetail;
import com.ai.da.model.dto.GenerateThroughImageTextDTO;
import com.ai.da.model.vo.AuthPrincipalVo;
import com.ai.da.model.vo.GenerateCaptionVO;
import com.ai.da.model.vo.GenerateCollectionItemVO;
import com.ai.da.model.vo.GenerateCollectionVO;
import com.ai.da.python.PythonService;
import com.ai.da.service.GenerateService;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import io.netty.util.internal.StringUtil;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
@Service
public class GenerateServiceImpl extends ServiceImpl<GenerateMapper,Generate> implements GenerateService {
@Resource
private CollectionElementMapper collectionElementMapper;
@Resource
private GenerateDetailMapper generateDetailMapper;
@Resource
private PythonService pythonService;
@Override
public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
Assert.notNull(collectionElement,"System error!Please reselect the sketch");
Assert.isTrue("Sketchboard".equals(collectionElement.getLevel1Type()) && !StringUtil.isNullOrEmpty(collectionElement.getUrl())
,"System error!Please reselect the sketch");
// String url = collectionElement.getUrl();
// String caption = pythonService.generateSketchCaption(url);
GenerateCaptionVO recognized_caption = new GenerateCaptionVO("recognized caption");
return recognized_caption;
}
@Override
public GenerateCollectionVO generateSketchThroughImageText(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、获取用户信息
AuthPrincipalVo userHolder = UserContext.getUserHolder();
Long accountId = userHolder.getId();
// 2、判断必须入参是否为非空
String generateType = generateThroughImageTextDTO.getGenerateType();
String text = generateThroughImageTextDTO.getText();
Long sketchId = generateThroughImageTextDTO.getCollectionElementId();
Generate generate = new Generate();
generate.setAccountId(accountId);
generate.setGenerateType(generateType);
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
switch(generateType){
case "text":
Assert.notNull(text,"Please input the caption");
generate.setText(text);
break;
case "image":
Assert.notNull(sketchId,"Please choose a sketch");
generate.setCollectionElementId(sketchId);
break;
case "text-image":
Assert.isTrue(!StringUtil.isNullOrEmpty(text) && Objects.nonNull(sketchId),
"Please input the caption and choose a sketch");
generate.setText(text);
generate.setCollectionElementId(sketchId);
break;
}
// 3、将请求信息落库
// 3.1 sketch在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementMapper.selectById(sketchId);
if (StringUtil.isNullOrEmpty(collectionElement.getLevel2Type()) || !(collectionElement.getLevel2Type()).equals(generateThroughImageTextDTO.getLevel2Type()) ){
collectionElement.setLevel2Type(generateThroughImageTextDTO.getLevel2Type());
QueryWrapper<CollectionElement> queryWrapper = new QueryWrapper<>();
queryWrapper.eq("id", sketchId);
collectionElementMapper.update(collectionElement,queryWrapper);
}
// 3.2 将本次generate的请求信息添加到t_generate表中
save(generate);
// 4、向模型发起请求
// String generatedSketchUrl = pythonService.generateSketch(collectionElement.getUrl(), text);
List<String> generatedSketchUrl = Arrays.asList("testUrl1","testUrl2","testUrl3","testUrl4");
GenerateCollectionVO generateCollectionVO = new GenerateCollectionVO();
List<GenerateCollectionItemVO> generatedCollectionItems = new ArrayList<>();
generateCollectionVO.setGenerateId(generate.getId());
generateCollectionVO.setCollectionId(collectionElement.getCollectionId());
generateCollectionVO.setGeneratedCollectionItems(generatedCollectionItems);
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
generatedSketchUrl.forEach(item -> {
GenerateDetail generateDetail = new GenerateDetail();
generateDetail.setUrl(item);
generateDetail.setGenerateId(generate.getId());
generateDetail.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
generateDetailMapper.insert(generateDetail);
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
generateCollectionItemVO.setGenerateItemId(generateDetail.getId());
generateCollectionItemVO.setGenerateItemUrl(item);
generatedCollectionItems.add(generateCollectionItemVO);
});
// 6、将模型返回的图片地址返回给前端
return generateCollectionVO;
}
}