@@ -4,26 +4,32 @@ import com.ai.da.common.config.exception.BusinessException;
import com.ai.da.common.context.UserContext ;
import com.ai.da.common.enums.GenerateModeEnum ;
import com.ai.da.common.enums.ModelNameEnum ;
import com.ai.da.common.utils.DateUtil ;
import com.ai.da.common.utils.MD5Utils ;
import com.ai.da.common.utils.MinioUtil ;
import com.ai.da.common.utils.* ;
import com.ai.da.mapper.CollectionElementMapper ;
import com.ai.da.mapper.GenerateCancelMapper ;
import com.ai.da.mapper.GenerateDetailMapper ;
import com.ai.da.mapper.GenerateMapper ;
import com.ai.da.mapper.entity.* ;
import com.ai.da.model.dto.GenerateLikeDTO ;
import com.ai.da.model.dto.GenerateThroughImageTextDTO ;
import com.ai.da.model.dto.GenerateToPythonDTO ;
import com.ai.da.model.vo.* ;
import com.ai.da.python.PythonService ;
import com.ai.da.service.CollectionElementService ;
import com.ai.da.service.GenerateService ;
import com.ai.da.service.LibraryService ;
import com.ai.da.service.RabbitMQService ;
import com.alibaba.fastjson.JSON ;
import com.alibaba.fastjson.JSONObject ;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper ;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl ;
import io.minio.errors.MinioException ;
import io.netty.util.internal.StringUtil ;
import lombok.extern.slf4j.Slf4j ;
import org.springframework.beans.factory.annotation.Value ;
import org.springframework.stereotype.Service ;
import org.springframework.transaction.annotation.Transactional ;
import org.springframework.util.CollectionUtils ;
import javax.annotation.Resource ;
import java.io.IOException ;
@@ -31,6 +37,7 @@ import java.util.*;
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.* ;
@Slf4j
@Service
public class GenerateServiceImpl extends ServiceImpl < GenerateMapper , Generate > implements GenerateService {
@@ -52,10 +59,31 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Resource
private MinioUtil minioUtil ;
@Resource
private RabbitMQService rabbitMQService ;
@Resource
private RedisUtil redisUtil ;
@Resource
private GenerateCancelMapper generateCancelMapper ;
@Value ( " ${redis.key.consumptionOrder} " )
private String consumptionOrderKey ;
@Value ( " ${redis.key.cancelSet} " )
private String cancelSetKey ;
@Value ( " ${redis.key.exceptionMap} " )
private String exceptionMapKey ;
@Value ( " ${redis.key.resultMap} " )
private String resultMapKey ;
@Override
public GenerateCaptionVO generateCaption ( Long sketchElementId ) {
CollectionElement collectionElement = collectionElementMapper . selectById ( sketchElementId ) ;
if ( Objects . isNull ( collectionElement ) ) {
if ( Objects . isNull ( collectionElement ) ) {
throw new BusinessException ( " the.image.does.not.exist.please.reselect " ) ;
}
String url = collectionElement . getUrl ( ) ;
@@ -69,45 +97,46 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Transactional ( rollbackFor = Exception . class )
public GenerateCollectionVO generateThroughImageText ( GenerateThroughImageTextDTO generateThroughImageTextDTO ) {
// 1、获取用户信息
AuthPrincipalVo userHolder = UserContext . getUserHolder ( ) ;
Long accountId = generateThroughImageTextDTO . getUserId ( ) ;
String generateType = generateThroughImageTextDTO . getGenerateType ( ) ;
Long accountId = userHolder . getId ( ) ;
if ( ! GenerateModeEnum . getGenerateModeList ( ) . contains ( generateType ) ) {
throw new BusinessException ( " unknown.generate.type " ) ;
}
// 2、判断必须入参是否为非空
Generate generate = new Generate ( ) ;
generate . setAccountId ( accountId ) ;
generate . setUniqueId ( generateThroughImageTextDTO . getUniqueId ( ) ) ;
generate . setLevel1Type ( generateThroughImageTextDTO . getLevel1Type ( ) ) ;
// 当level1type是sketchboard时, 存数据库需要加上当前性别
generate . setGenerateType ( generate . getLevel1Type ( ) . equals ( SKETCH_BOARD . getRealName ( ) ) ?
generateType + " ( " + generateThroughImageTextDTO . getGender ( ) + " ) " :
generateType + " ( " + generateThroughImageTextDTO . getGender ( ) + " ) " :
generateType ) ;
generate . setModelName ( StringUtil . isNullOrEmpty ( generateThroughImageTextDTO . getVersion ( ) ) ? ModelNameEnum . MODEL_0 . getCode ( ) : generateThroughImageTextDTO . getVersion ( ) ) ;
generate . setCreateDate ( DateUtil . getByTimeZone ( generateThroughImageTextDTO . getTimeZone ( ) ) ) ;
String text = generateThroughImageTextDTO . getText ( ) ;
Long elementId = generateThroughImageTextDTO . getCollectionElementId ( ) ;
validateGeneraType ( generate , text , elementId , generateType ) ;
validateGeneraType ( generate , text , elementId , generateType ) ;
// 3、将请求信息落库
// 3.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
// 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementService . editLevel2Type ( elementId , generateThroughImageTextDTO . getLevel2Type ( ) ) ;
// 3.2 将本次generate的请求信息添加到t_generate表中
save ( generate ) ;
// 4、向模型发起请求
// 3、向模型发起请求
int mode = GenerateModeEnum . TEXT . getValue ( ) . equals ( generateType ) ?
GenerateModeEnum . TEXT . getCode ( ) :
GenerateModeEnum . TEXT_IMAGE . getCode ( ) ;
String category = generateThroughImageTextDTO . getLevel1Type ( ) . equals ( SKETCH_BOARD . getRealName ( ) ) ? " sketch " :
generateThroughImageTextDTO . getLevel1Type ( ) . equals ( PRINT_BOARD . getRealName ( ) ) ? " print " : " moodboard " ;
// text = !StringUtil.isNullOrEmpty(text) && generateThroughImageTextDTO.getVersion().equals("1") ? "painting style, " + text : text ;
List < String > generatedSketchUrl = pythonService . generateSketchOrPrint ( accountId , Objects . isNull ( e lementId ) ? null : collectionElement . getUrl ( ) ,
category , text , mode , generateThroughImageTextDTO . getVersion ( ) , generateThroughImageTextDTO . getGender ( ) ) ;
AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil ( ) ;
List < String > generatedSketchUrl = asyncCallerUtil . generate ( new GenerateToPythonDTO ( accountId , Objects . isNull ( collectionE lement) ? " " : collectionElement . getUrl ( ) ,
category , text , mode , " 1 " , generateThroughImageTextDTO . getGender ( ) , generateThroughImageTextDTO . getUniqueId ( ) ) ) ;
// List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
// category, text, mode, "1", generateThroughImageTextDTO.getGender()));
log . info ( " generate 响应 : " + generatedSketchUrl ) ;
if ( CollectionUtils . isEmpty ( generatedSketchUrl ) ) {
return null ;
}
// 4、将请求信息落库,将本次generate的请求信息添加到t_generate表中
save ( generate ) ;
// 5、处理模型返回的数据
// 5.1 将相应的url保存到数据库
@@ -117,8 +146,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO ( ) ;
String md5 = MD5Utils . encryptFile ( minioUtil . getPresignedUrl ( item , 24 * 60 ) , Boolean . FALSE ) ;
// 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过
List < Map < String , Long > > libraryIdList = generateDetailMapper . getLibraryIdThroughMD5 ( md5 , generateThroughImageTextDTO . getLevel1Type ( ) ) ;
if ( ! libraryIdList . isEmpty ( ) ) {
List < Map < String , Long > > libraryIdList = generateDetailMapper . getLibraryIdThroughMD5 ( md5 , generateThroughImageTextDTO . getLevel1Type ( ) ) ;
if ( ! libraryIdList . isEmpty ( ) ) {
generateDetail . setIsLike ( ( byte ) 1 ) ;
generateDetail . setLibraryId ( libraryIdList . get ( 0 ) . get ( " library_id " ) ) ;
generateCollectionItemVO . setIsLiked ( Boolean . TRUE ) ;
@@ -139,22 +168,22 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return new GenerateCollectionVO ( generate . getId ( ) , collectionId , generatedCollectionItems ) ;
}
private void validateGeneraType ( Generate generate , String text , Long elementId , String generateType ) {
private void validateGeneraType ( Generate generate , String text , Long elementId , String generateType ) {
switch ( generateType ) {
case " text " :
if ( StringUtil . isNullOrEmpty ( text ) ) {
if ( StringUtil . isNullOrEmpty ( text ) ) {
throw new BusinessException ( " please.input.the.caption " ) ;
}
generate . setText ( text ) ;
break ;
case " image " :
if ( Objects . isNull ( elementId ) ) {
if ( Objects . isNull ( elementId ) ) {
throw new BusinessException ( " please.choose.an.image " ) ;
}
generate . setCollectionElementId ( elementId ) ;
break ;
case " text-image " :
if ( StringUtil . isNullOrEmpty ( text ) | | Objects . isNull ( elementId ) ) {
if ( StringUtil . isNullOrEmpty ( text ) | | Objects . isNull ( elementId ) ) {
throw new BusinessException ( " please.input.the.caption.and.choose.an.image " ) ;
}
generate . setText ( text ) ;
@@ -169,21 +198,21 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 1、判断参数是否正确
// 1.1 必须参数是否非空
if ( SKETCH_BOARD . getRealName ( ) . equals ( generateLikeDTO . getLevel1Type ( ) ) ) {
if ( StringUtil . isNullOrEmpty ( generateLikeDTO . getLevel2Type ( ) ) ) {
if ( StringUtil . isNullOrEmpty ( generateLikeDTO . getLevel2Type ( ) ) ) {
throw new BusinessException ( " level2Type.cannot.be.empty " ) ;
}
if ( StringUtil . isNullOrEmpty ( generateLikeDTO . getGender ( ) ) ) {
if ( StringUtil . isNullOrEmpty ( generateLikeDTO . getGender ( ) ) ) {
throw new BusinessException ( " gender.cannot.be.empty " ) ;
}
}
// 1.2 判断参数是否真实有效
Long generateDetailId = generateLikeDTO . getGenerateDetailId ( ) ;
GenerateDetail generateDetail = generateDetailMapper . selectById ( generateDetailId ) ;
if ( Objects . isNull ( generateDetail ) ) {
if ( Objects . isNull ( generateDetail ) ) {
throw new BusinessException ( " generateItem.does.not.exist " ) ;
}
Generate generate = getById ( generateDetail . getGenerateId ( ) ) ;
if ( ! generateLikeDTO . getLevel1Type ( ) . equals ( generate . getLevel1Type ( ) ) ) {
if ( ! generateLikeDTO . getLevel1Type ( ) . equals ( generate . getLevel1Type ( ) ) ) {
throw new BusinessException ( " level1Type.does.not.match " ) ;
}
@@ -191,8 +220,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 2.1、不能重复喜欢
// 2.1.1 判断该图片是否被喜欢过
Library libraryDetail = libraryService . getById ( generateDetail . getLibraryId ( ) ) ;
if ( ( Objects . nonNull ( generateDetail . getLibraryId ( ) ) & & ! generateDetail . getLibraryId ( ) . equals ( 0L ) )
| | Objects . nonNull ( libraryDetail ) ) {
if ( ( Objects . nonNull ( generateDetail . getLibraryId ( ) ) & & ! generateDetail . getLibraryId ( ) . equals ( 0L ) )
| | Objects . nonNull ( libraryDetail ) ) {
throw new BusinessException ( " duplicate.likes.are.not.allowed " ) ;
}
@@ -215,7 +244,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
public Boolean generateDislike ( Long generateDetailId , String timeZone ) {
// 1、确定generateDetail中是否有这条记录
GenerateDetail generateDetail = generateDetailMapper . selectById ( generateDetailId ) ;
if ( Objects . isNull ( generateDetail ) ) {
if ( Objects . isNull ( generateDetail ) ) {
throw new BusinessException ( " generateItem.does.not.exist " ) ;
}
@@ -265,7 +294,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generateDetailMapper . update ( generateDetail , queryWrapper ) ;
}
public void updateLikeStatusBatch ( List < Long > generateDetailIdList , Byte hasLike , Long libraryId , String timeZone ) {
public void updateLikeStatusBatch ( List < Long > generateDetailIdList , Byte hasLike , Long libraryId , String timeZone ) {
QueryWrapper < GenerateDetail > queryWrapper = new QueryWrapper < > ( ) ;
queryWrapper . in ( " id " , generateDetailIdList ) ;
@@ -277,10 +306,156 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generateDetailMapper . update ( generateDetail , queryWrapper ) ;
}
public List < GenerateDetail > selectBatchByLibraryId ( List < Long > libraryId ) {
public List < GenerateDetail > selectBatchByLibraryId ( List < Long > libraryId ) {
QueryWrapper < GenerateDetail > qw = new QueryWrapper < > ( ) ;
qw . in ( " library_id " , libraryId ) ;
qw . in ( " library_id " , libraryId ) ;
return generateDetailMapper . selectList ( qw ) ;
}
@Override
public String prepareForGenerate ( GenerateThroughImageTextDTO generateThroughImageTextDTO ) {
// 1、参数检查, 判断必须参数是否为空
if ( Objects . isNull ( generateThroughImageTextDTO . getUserId ( ) ) ) {
throw new BusinessException ( " userId cannot be empty " ) ;
}
String generateType = generateThroughImageTextDTO . getGenerateType ( ) ;
if ( ! GenerateModeEnum . getGenerateModeList ( ) . contains ( generateType ) ) {
throw new BusinessException ( " unknown.generate.type " ) ;
}
String text = generateThroughImageTextDTO . getText ( ) ;
Long elementId = generateThroughImageTextDTO . getCollectionElementId ( ) ;
validateGeneraType ( new Generate ( ) , text , elementId , generateType ) ;
// 2、生成唯一id 使用uuid
String uuid = UUID . randomUUID ( ) . toString ( ) ;
// SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
// long snowflakeId = idWorker.nextId();
int num = 1 ;
// 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id
while ( ( redisUtil . isElementExistsInMap ( resultMapKey , uuid ) | |
redisUtil . isElementExistsInZSet ( consumptionOrderKey , uuid ) )
& & num < 10 ) {
uuid = UUID . randomUUID ( ) . toString ( ) ;
num + + ;
}
// 无依据确定的数字
if ( num > 10 ) {
try {
Thread . sleep ( 1000 ) ;
} catch ( InterruptedException e ) {
throw new RuntimeException ( e ) ;
}
uuid = UUID . randomUUID ( ) . toString ( ) ;
}
generateThroughImageTextDTO . setUniqueId ( uuid ) ;
String jsonString = JSON . toJSONString ( generateThroughImageTextDTO ) ;
// 3、加入redis排队, 便于获取实时排队信息
Double maxScore = redisUtil . getMaxScore ( consumptionOrderKey ) ;
redisUtil . addToZSet ( consumptionOrderKey , uuid , maxScore ) ;
// 4、将消息发布到MQ消息队列
rabbitMQService . publishMessage ( jsonString ) ;
// 5、返回唯一id
return uuid ;
}
@Override
public Long getRankPosition ( String uniqueId ) {
// rank 从0开始
return redisUtil . getRank ( consumptionOrderKey , uniqueId ) ;
}
@Override
public GenerateCollectionVO getGenerateResult ( String uniqueId ) {
// 1、判断该请求是否已经异常
Boolean isMember = redisUtil . isElementExistsInMap ( exceptionMapKey , uniqueId ) ;
if ( isMember ) {
throw new BusinessException ( " generate.interface.error " ) ;
}
// 2、判断该请求是否还在排队
Boolean existsInZSet = redisUtil . isElementExistsInZSet ( consumptionOrderKey , uniqueId ) ;
if ( existsInZSet ) {
// 排队中,给出当前排序位置,rank从0开始
Long rankPosition = getRankPosition ( uniqueId ) ;
// 有9个消费者, 所以当rank>8即当前请求至少排在第九位时, 其实际排队位置为9-8+1, 当rank <=8, 请求均在处理中
return new GenerateCollectionVO ( rankPosition > 8L ? rankPosition - 8 + 1 : 1L ) ;
}
// 3、判断redis中有没有
boolean hasHashKey = redisUtil . isElementExistsInMap ( resultMapKey , uniqueId ) ;
if ( hasHashKey ) {
// 3.1 有直接从redis中拿
String resultString = redisUtil . getMapValue ( resultMapKey , uniqueId ) ;
return JSONObject . parseObject ( resultString , GenerateCollectionVO . class ) ;
}
// 3.2 判断数据库中有没有
Generate generate = selectByUniqueId ( uniqueId ) ;
if ( Objects . isNull ( generate ) ) {
// 3.3 还没执行完,给出当前位置
return new GenerateCollectionVO ( 1L ) ;
}
Long generateId = generate . getId ( ) ;
QueryWrapper < GenerateDetail > qw = new QueryWrapper < > ( ) ;
qw . eq ( " generate_id " , generateId ) ;
List < GenerateDetail > generateDetails = generateDetailMapper . selectList ( qw ) ;
if ( CollectionUtils . isEmpty ( generateDetails ) ) {
// 会有这种情况吗? 存到generate中, 但是还没存到generateDetail中
return new GenerateCollectionVO ( 1L ) ;
}
List < GenerateCollectionItemVO > generatedCollectionItems = new ArrayList < > ( ) ;
generateDetails . forEach ( item - > {
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO ( ) ;
generateCollectionItemVO . setGenerateItemId ( item . getId ( ) ) ;
generateCollectionItemVO . setGenerateItemUrl ( minioUtil . getPresignedUrl ( item . getUrl ( ) , 24 * 60 ) ) ;
generatedCollectionItems . add ( generateCollectionItemVO ) ;
} ) ;
return new GenerateCollectionVO ( generateId , null , generatedCollectionItems ) ;
}
public Generate selectByUniqueId ( String uniqueId ) {
QueryWrapper < Generate > qw = new QueryWrapper < > ( ) ;
qw . eq ( " unique_id " , uniqueId ) ;
return getOne ( qw ) ;
}
@Override
@Transactional ( rollbackFor = Exception . class )
public void cancelGenerate ( Long userId , String uniqueId , String timeZone ) {
// 1、确认当前消息是否还在排队中
Boolean exists = redisUtil . isElementExistsInZSet ( consumptionOrderKey , uniqueId ) ;
Boolean flag = Boolean . FALSE ;
if ( exists ) flag = redisUtil . getRank ( consumptionOrderKey , uniqueId ) > 1L ? Boolean . TRUE : Boolean . FALSE ;
// 不管flag的默认值是true还是false, 只要exists为false, && 将短路
if ( exists & & flag ) {
// 1.1、将需要取消的唯一id加入redis, 以便及时取消生成
redisUtil . addToSet ( cancelSetKey , uniqueId ) ;
// 1.2 将需要取消的id从redis的ConsumptionOrder中删除
redisUtil . removeFromZSet ( consumptionOrderKey , uniqueId ) ;
} else {
// 2、判断该消息是否异常
boolean hasKey = redisUtil . isElementExistsInMap ( exceptionMapKey , uniqueId ) ;
// 3、判断该消息是否已经消费结束
Boolean existsInResult = redisUtil . isElementExistsInMap ( resultMapKey , uniqueId ) ;
if ( ! hasKey & & ! existsInResult ) {
// 设置取等待状态为false
AsyncCallerUtil . waitingStatus . put ( uniqueId , false ) ;
// 3、直接发送取消请求到python端
pythonService . cancelGenerateTask ( uniqueId ) ;
}
}
// 3、考虑加一张表, 专门用于记录哪些用户在什么时间进行了取消操作
GenerateCancel generateCancel = new GenerateCancel ( userId , uniqueId , DateUtil . getByTimeZone ( timeZone ) ) ;
generateCancelMapper . insert ( generateCancel ) ;
}
}