TASK:异步调用generate及取消generate

This commit is contained in:
2024-01-24 11:43:56 +08:00
parent a9ce35200c
commit 96858c2cc3
10 changed files with 159 additions and 112 deletions

View File

@@ -48,12 +48,11 @@ public class MQConsumer {
log.info("============start listening=========="); log.info("============start listening==========");
GenerateThroughImageTextDTO generateThroughImageTextDTO = JSONObject.parseObject(msg.getBody(), GenerateThroughImageTextDTO.class); GenerateThroughImageTextDTO generateThroughImageTextDTO = JSONObject.parseObject(msg.getBody(), GenerateThroughImageTextDTO.class);
Long uniqueId = generateThroughImageTextDTO.getUniqueId(); String uniqueId = generateThroughImageTextDTO.getUniqueId();
// 1、将消息从redis排队队列中删除
redisUtil.removeFromZSet(consumptionOrderKey, String.valueOf(uniqueId));
try { try {
// 2、判断当前消息是否在取消列表中 // 2、判断当前消息是否在取消列表中
Boolean isMember = redisUtil.isElementExistsInSet(cancelSetKey, String.valueOf(uniqueId)); Boolean isMember = redisUtil.isElementExistsInSet(cancelSetKey, uniqueId);
if (isMember) { if (isMember) {
try { try {
// 2.1 手动确认该消息 // 2.1 手动确认该消息
@@ -62,40 +61,43 @@ public class MQConsumer {
log.error("手动确认,不返回队列重新消费"); log.error("手动确认,不返回队列重新消费");
} }
// 2.2 将该消息从取消列表中删除 // 2.2 将该消息从取消列表中删除
redisUtil.removeFromSet(cancelSetKey, String.valueOf(uniqueId)); // redisUtil.removeFromSet(cancelSetKey, uniqueId);
} else { } else {
GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO); GenerateCollectionVO generateCollectionVO = generateService.generateThroughImageText(generateThroughImageTextDTO);
// try {
// Thread.sleep(15000);
// } catch (InterruptedException e) {
// throw new RuntimeException(e);
// }
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
if (!Objects.isNull(generateCollectionVO)){ if (!Objects.isNull(generateCollectionVO)){
HashMap<String, String> generateResult = new HashMap<>(); HashMap<String, String> generateResult = new HashMap<>();
generateResult.put(String.valueOf(uniqueId), JSONObject.toJSONString(generateCollectionVO)); generateResult.put(uniqueId, JSONObject.toJSONString(generateCollectionVO));
// 将结果存在redis中 ,为空时不要存 // 将结果存在redis中 ,为空时不要存
redisUtil.addToMap(resultMapKey, generateResult); redisUtil.addToMap(resultMapKey, generateResult);
} }
} }
} catch (BusinessException e) { } catch (BusinessException e) {
log.error(e.getMessage()); log.error(e.getMsg());
// channel.basicNack() 为不确认deliveryTag对应的消息第二个参数是否应用于多消息第三个参数是否requeue // channel.basicNack() 为不确认deliveryTag对应的消息第二个参数是否应用于多消息第三个参数是否requeue
try { try {
// 第二个参数是否批量确认消息当传false时只确认当前 deliveryTag对应的消息;当传true时会确认当前及之前所有未确认的消息。 // 第二个参数是否批量确认消息当传false时只确认当前 deliveryTag对应的消息;当传true时会确认当前及之前所有未确认的消息。
channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false); channel.basicAck(msg.getMessageProperties().getDeliveryTag(), false);
// 将消息从redis排队队列中删除,需保证被消费的消息存储到db之后再从redis删除
redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
} catch (IOException exception) { } catch (IOException exception) {
log.error("手动确认,取消返回队列,不再重新消费"); log.error("手动确认,取消返回队列,不再重新消费");
} }
// 将入参和错误信息存入数据库 // 将入参和错误信息存入数据库
String exceptionMessage = JSONObject.toJSONString(generateThroughImageTextDTO) + " Exception message " + e.getMessage(); String exceptionMessage = JSONObject.toJSONString(generateThroughImageTextDTO) +
" Exception message " + e.getMsg();
HashMap<String, String> exceptionInfo = new HashMap<>(); HashMap<String, String> exceptionInfo = new HashMap<>();
exceptionInfo.put(String.valueOf(uniqueId), exceptionMessage); exceptionInfo.put(String.valueOf(uniqueId), exceptionMessage);
// 存redis // 存redis
redisUtil.addToMap(exceptionMapKey, exceptionInfo); redisUtil.addToMap(exceptionMapKey, exceptionInfo);
} }
// log.info(JSONObject.parseObject(msg.getBody(), GenerateThroughImageTextDTO.class).toString());
// try {
// Thread.sleep(10000);
// } catch (InterruptedException e) {
// throw new RuntimeException(e);
// }
log.info("============end listening=========="); log.info("============end listening==========");
} }

View File

@@ -14,7 +14,7 @@ import java.util.concurrent.*;
@Component @Component
public class AsyncCallerUtil { public class AsyncCallerUtil {
public static Map<Long, Boolean> waitingStatus = new HashMap<>(); public static Map<String, Boolean> waitingStatus = new HashMap<>();
private static PythonService pythonService; private static PythonService pythonService;
@@ -27,9 +27,10 @@ public class AsyncCallerUtil {
return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython)); return CompletableFuture.supplyAsync(() -> pythonService.generateSketchOrPrint(generateToPython));
} }
public List<String> generate(GenerateToPythonDTO generateToPython, Long requestId) { public List<String> generate(GenerateToPythonDTO generateToPython) {
ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(1); ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(5);
waitingStatus.put(requestId, true); String taskId = generateToPython.getTasks_id();
waitingStatus.put(taskId, true);
ScheduledFuture<?> timeoutTask = null; ScheduledFuture<?> timeoutTask = null;
try { try {
@@ -37,10 +38,10 @@ public class AsyncCallerUtil {
// 10秒后第一次确认之后每隔10秒确认一次用户选择结果 // 10秒后第一次确认之后每隔10秒确认一次用户选择结果
timeoutTask = scheduledExecutorService.scheduleAtFixedRate(() -> { timeoutTask = scheduledExecutorService.scheduleAtFixedRate(() -> {
// 调用另一个接口获取用户的选择 // 调用另一个接口获取用户的选择
if (!waitingStatus.get(requestId)) { if (!waitingStatus.get(taskId)) {
// 如果用户选择取消则取消对generate的调用cancel判断是否成功取消 // 如果用户选择取消则取消对generate的调用cancel判断是否成功取消
generateResult.cancel(true); generateResult.cancel(true);
waitingStatus.remove(requestId); waitingStatus.remove(taskId);
} }
System.out.println("持续等待...... : " + DateUtil.getByTimeZone("Asia/Shanghai")); System.out.println("持续等待...... : " + DateUtil.getByTimeZone("Asia/Shanghai"));
}, 10, 10, TimeUnit.SECONDS); }, 10, 10, TimeUnit.SECONDS);
@@ -54,16 +55,15 @@ public class AsyncCallerUtil {
// 处理结果 // 处理结果
System.out.println("generate 响应: " + result); System.out.println("generate 响应: " + result);
System.out.println("schedule finish time : " + DateUtil.getByTimeZone("Asia/Shanghai")); System.out.println("schedule finish time : " + DateUtil.getByTimeZone("Asia/Shanghai"));
waitingStatus.remove(requestId); waitingStatus.remove(taskId);
return result; return result;
} catch (InterruptedException | ExecutionException | BusinessException e) { } catch (InterruptedException | ExecutionException | BusinessException e) {
// 处理异常 // 处理异常
log.error("发生错误 " + e); log.error("发生错误 " + e);
e.printStackTrace();
// 取消定时任务 // 取消定时任务
assert timeoutTask != null; assert timeoutTask != null;
timeoutTask.cancel(true); timeoutTask.cancel(true);
throw new BusinessException("generate.interface.error"); throw new BusinessException(e.getMessage());
} finally { } finally {
// 关闭线程池 // 关闭线程池
// executorService.shutdown(); // executorService.shutdown();

View File

@@ -55,22 +55,20 @@ public class GenerateController {
@ApiOperation(value = "发起生成请求,异步获取结果") @ApiOperation(value = "发起生成请求,异步获取结果")
@PostMapping("/prepare") @PostMapping("/prepare")
public Response<String> prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO){ public Response<String> prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) {
Long l = generateService.prepareForGenerate(generateThroughImageTextDTO); return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO));
// 防止long精度丢失这里转为String类型进行传输
return Response.success(String.valueOf(l));
} }
@ApiOperation(value = "取消继续生成") @ApiOperation(value = "取消继续生成")
@GetMapping("/stopWaiting") @GetMapping("/stopWaiting")
public Response<String> stopWaiting(@RequestParam("uniqueId") Long uniqueId){ public Response<String> stopWaiting(@RequestParam("uniqueId") String uniqueId) {
generateService.cancelGenerate(uniqueId); generateService.cancelGenerate(uniqueId);
return Response.success("stop waiting successfully"); return Response.success("stop waiting successfully");
} }
@ApiOperation(value = "获取生成结果") @ApiOperation(value = "获取生成结果")
@GetMapping("/result") @GetMapping("/result")
public Response<GenerateCollectionVO> getGenerateResult(@RequestParam("uniqueId") Long uniqueId){ public Response<GenerateCollectionVO> getGenerateResult(@RequestParam("uniqueId") String uniqueId) {
GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId); GenerateCollectionVO generateResult = generateService.getGenerateResult(uniqueId);
return Response.success(generateResult); return Response.success(generateResult);
} }

View File

@@ -28,9 +28,9 @@ public class Generate {
private Long accountId; private Long accountId;
/** /**
* 唯一id,用于保持消息的唯一性 * 唯一id
*/ */
private Long uniqueId; private String uniqueId;
/** /**
* Sketchboard Printboard * Sketchboard Printboard

View File

@@ -46,5 +46,5 @@ public class GenerateThroughImageTextDTO {
String timeZone; String timeZone;
@ApiModelProperty("唯一id用于保持消息唯一性") @ApiModelProperty("唯一id用于保持消息唯一性")
Long uniqueId; String uniqueId;
} }

View File

@@ -22,4 +22,6 @@ public class GenerateToPythonDTO {
private String version; private String version;
private String gender; private String gender;
private String tasks_id;
} }

View File

@@ -56,6 +56,9 @@ public class PythonService {
private String accessPythonIp; private String accessPythonIp;
@Value("${access.python.port:''}") @Value("${access.python.port:''}")
private String accessPythonPort; private String accessPythonPort;
@Value("${access.generate.port:''}")
private String accessGeneratePort;
@Resource @Resource
private PythonTAllInfoService pythonTAllInfoService; private PythonTAllInfoService pythonTAllInfoService;
@@ -2251,9 +2254,10 @@ public class PythonService {
Request request = new Request.Builder() Request request = new Request.Builder()
// .url("http://18.167.251.121:9992") // .url("http://18.167.251.121:9992")
// .url("http://127.0.0.1:5000/api/diffusion") // .url("http://127.0.0.1:5000/api/diffusion")
.url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion") // .url(accessPythonIp + ":" + accessPythonPort + "/api/diffusion")
.url(accessPythonIp + ":" + accessPythonPort + "/api/generate_image")
.method("POST", body) .method("POST", body)
.addHeader("Authorization", "Basic YWlkbGFiOjEyMw==") // .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
.addHeader("Content-Type", "application/json") .addHeader("Content-Type", "application/json")
.build(); .build();
Response response = null; Response response = null;
@@ -2263,23 +2267,28 @@ public class PythonService {
response = client.newCall(request).execute(); response = client.newCall(request).execute();
} catch (IOException ioException) { } catch (IOException ioException) {
log.error("PythonService##generateSketchOrPrint异常###{}", ExceptionUtil.getThrowableList(ioException)); log.error("PythonService##generateSketchOrPrint异常###{}", ExceptionUtil.getThrowableList(ioException));
throw new BusinessException("generate.interface.error"); // throw new BusinessException("generate.interface.error");
throw new BusinessException(ioException.getMessage());
} }
//去除限流 //去除限流
AccessLimitUtils.validateOut("generateSketchOrPrint"); AccessLimitUtils.validateOut("generateSketchOrPrint");
// 判断是否生成失败 // 判断是否生成失败
if (Objects.isNull(response) || Objects.isNull(response.body())) { if (Objects.isNull(response.body())) {
log.error("PythonService##generateSketchOrPrint异常###{}", "response or body is empty!"); log.error("PythonService##generateSketchOrPrint异常###{}", "response or body is empty!");
throw new BusinessException("generate.interface.error"); // throw new BusinessException("generate.interface.error");
throw new BusinessException("PythonService##generateSketchOrPrint异常###: response or body is empty!");
} else if (response.code() != HttpURLConnection.HTTP_OK){ } else if (response.code() != HttpURLConnection.HTTP_OK){
log.error("PythonService##generateSketchOrPrint异常###{}", "Response error!Response code ## " + response.code() + " ##"); log.error("PythonService##generateSketchOrPrint异常###{}", "Response error!Response code ## " + response.code() + " ##");
throw new BusinessException("generate.interface.error"); // throw new BusinessException("generate.interface.error");
throw new BusinessException("PythonService##generateSketchOrPrint异常### Response error!Response code ## " + response.code() + " ##");
} else { } else {
try { try {
bodyString = response.body().string(); bodyString = response.body().string();
} catch (IOException e) { } catch (IOException e) {
throw new BusinessException("generate.interface.error"); log.error(e.getMessage());
// throw new BusinessException("generate.interface.error");
throw new BusinessException(e.getMessage());
} }
} }
JSONObject jsonObject = JSON.parseObject(bodyString); JSONObject jsonObject = JSON.parseObject(bodyString);
@@ -2417,4 +2426,28 @@ public class PythonService {
//生成失败 //生成失败
throw new BusinessException("cloth-classification.interface.exception"); throw new BusinessException("cloth-classification.interface.exception");
} }
public void cancelGenerateTask(String taskId){
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
HttpUrl.Builder builder = HttpUrl.parse(accessPythonIp + ":" + accessGeneratePort + "/cancel_task").newBuilder();
builder.addQueryParameter("task_id",taskId);
Request request = new Request.Builder()
.url(builder.build().toString())
.addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
.addHeader("Content-Type", "application/json")
.build();
try {
log.info("getGenerateResult请求入参content###{}", taskId);
client.newCall(request).execute();
} catch (IOException ioException) {
log.error("PythonService##getGenerateResult异常###{}", ExceptionUtil.getThrowableList(ioException));
throw new BusinessException("generate.interface.error");
}
}
} }

View File

@@ -25,12 +25,12 @@ public interface GenerateService extends IService<Generate> {
List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId); List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId);
GenerateCollectionVO getGenerateResult(Long uniqueId); GenerateCollectionVO getGenerateResult(String uniqueId);
Long prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO); String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO);
Long getRankPosition(Long uniqueId); Long getRankPosition(String uniqueId);
void cancelGenerate(Long uniqueId); void cancelGenerate(String uniqueId);
} }

View File

@@ -24,6 +24,7 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import io.minio.errors.MinioException; import io.minio.errors.MinioException;
import io.netty.util.internal.StringUtil; import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@@ -35,6 +36,7 @@ import java.util.*;
import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*; import static com.ai.da.common.enums.CollectionLevel1TypeEnum.*;
@Slf4j
@Service @Service
public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> implements GenerateService { public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> implements GenerateService {
@@ -77,7 +79,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Override @Override
public GenerateCaptionVO generateCaption(Long sketchElementId) { public GenerateCaptionVO generateCaption(Long sketchElementId) {
CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId); CollectionElement collectionElement = collectionElementMapper.selectById(sketchElementId);
if (Objects.isNull(collectionElement)){ if (Objects.isNull(collectionElement)) {
throw new BusinessException("the.image.does.not.exist.please.reselect"); throw new BusinessException("the.image.does.not.exist.please.reselect");
} }
String url = collectionElement.getUrl(); String url = collectionElement.getUrl();
@@ -101,14 +103,14 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type()); generate.setLevel1Type(generateThroughImageTextDTO.getLevel1Type());
// 当level1type是sketchboard时存数据库需要加上当前性别 // 当level1type是sketchboard时存数据库需要加上当前性别
generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? generate.setGenerateType(generate.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ?
generateType + " (" +generateThroughImageTextDTO.getGender() + ")": generateType + " (" + generateThroughImageTextDTO.getGender() + ")" :
generateType); generateType);
generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion()); generate.setModelName(StringUtil.isNullOrEmpty(generateThroughImageTextDTO.getVersion()) ? ModelNameEnum.MODEL_0.getCode() : generateThroughImageTextDTO.getVersion());
generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone())); generate.setCreateDate(DateUtil.getByTimeZone(generateThroughImageTextDTO.getTimeZone()));
String text = generateThroughImageTextDTO.getText(); String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId(); Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(generate, text, elementId,generateType); validateGeneraType(generate, text, elementId, generateType);
// 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type // 2.1 sketch或print在t_collection_element表中的信息是否需要更新 如 level2Type
CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type()); CollectionElement collectionElement = collectionElementService.editLevel2Type(elementId, generateThroughImageTextDTO.getLevel2Type());
@@ -120,11 +122,11 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" : String category = generateThroughImageTextDTO.getLevel1Type().equals(SKETCH_BOARD.getRealName()) ? "sketch" :
generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard"; generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName()) ? "print" : "moodboard";
AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil(); AsyncCallerUtil asyncCallerUtil = new AsyncCallerUtil();
List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? null : collectionElement.getUrl(), List<String> generatedSketchUrl = asyncCallerUtil.generate(new GenerateToPythonDTO(accountId, Objects.isNull(collectionElement) ? "" : collectionElement.getUrl(),
category, text, mode, "1", generateThroughImageTextDTO.getGender()),generateThroughImageTextDTO.getUniqueId()); category, text, mode, "1", generateThroughImageTextDTO.getGender() ,generateThroughImageTextDTO.getUniqueId()));
// List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(), // List<String> generatedSketchUrl = pythonService.generateSketchOrPrint(new GenerateToPythonDTO(accountId, Objects.isNull(elementId) ? null : collectionElement.getUrl(),
// category, text, mode, "1", generateThroughImageTextDTO.getGender())); // category, text, mode, "1", generateThroughImageTextDTO.getGender()));
if (CollectionUtils.isEmpty(generatedSketchUrl)){ if (CollectionUtils.isEmpty(generatedSketchUrl)) {
return null; return null;
} }
@@ -139,8 +141,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO(); GenerateCollectionItemVO generateCollectionItemVO = new GenerateCollectionItemVO();
String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(item, 24 * 60), Boolean.FALSE); String md5 = MD5Utils.encryptFile(minioUtil.getPresignedUrl(item, 24 * 60), Boolean.FALSE);
// 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过 // 通过MD5值和level1Type,判断不同level1Type下相同的图片是否被like过
List<Map<String,Long>> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generateThroughImageTextDTO.getLevel1Type()); List<Map<String, Long>> libraryIdList = generateDetailMapper.getLibraryIdThroughMD5(md5, generateThroughImageTextDTO.getLevel1Type());
if (!libraryIdList.isEmpty()){ if (!libraryIdList.isEmpty()) {
generateDetail.setIsLike((byte) 1); generateDetail.setIsLike((byte) 1);
generateDetail.setLibraryId(libraryIdList.get(0).get("library_id")); generateDetail.setLibraryId(libraryIdList.get(0).get("library_id"));
generateCollectionItemVO.setIsLiked(Boolean.TRUE); generateCollectionItemVO.setIsLiked(Boolean.TRUE);
@@ -161,22 +163,22 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems); return new GenerateCollectionVO(generate.getId(), collectionId, generatedCollectionItems);
} }
private void validateGeneraType(Generate generate, String text, Long elementId,String generateType) { private void validateGeneraType(Generate generate, String text, Long elementId, String generateType) {
switch (generateType) { switch (generateType) {
case "text": case "text":
if (StringUtil.isNullOrEmpty(text)){ if (StringUtil.isNullOrEmpty(text)) {
throw new BusinessException("please.input.the.caption"); throw new BusinessException("please.input.the.caption");
} }
generate.setText(text); generate.setText(text);
break; break;
case "image": case "image":
if (Objects.isNull(elementId)){ if (Objects.isNull(elementId)) {
throw new BusinessException("please.choose.an.image"); throw new BusinessException("please.choose.an.image");
} }
generate.setCollectionElementId(elementId); generate.setCollectionElementId(elementId);
break; break;
case "text-image": case "text-image":
if (StringUtil.isNullOrEmpty(text) || Objects.isNull(elementId)){ if (StringUtil.isNullOrEmpty(text) || Objects.isNull(elementId)) {
throw new BusinessException("please.input.the.caption.and.choose.an.image"); throw new BusinessException("please.input.the.caption.and.choose.an.image");
} }
generate.setText(text); generate.setText(text);
@@ -191,21 +193,21 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 1、判断参数是否正确 // 1、判断参数是否正确
// 1.1 必须参数是否非空 // 1.1 必须参数是否非空
if (SKETCH_BOARD.getRealName().equals(generateLikeDTO.getLevel1Type())) { if (SKETCH_BOARD.getRealName().equals(generateLikeDTO.getLevel1Type())) {
if (StringUtil.isNullOrEmpty(generateLikeDTO.getLevel2Type())){ if (StringUtil.isNullOrEmpty(generateLikeDTO.getLevel2Type())) {
throw new BusinessException("level2Type.cannot.be.empty"); throw new BusinessException("level2Type.cannot.be.empty");
} }
if (StringUtil.isNullOrEmpty(generateLikeDTO.getGender())){ if (StringUtil.isNullOrEmpty(generateLikeDTO.getGender())) {
throw new BusinessException("gender.cannot.be.empty"); throw new BusinessException("gender.cannot.be.empty");
} }
} }
// 1.2 判断参数是否真实有效 // 1.2 判断参数是否真实有效
Long generateDetailId = generateLikeDTO.getGenerateDetailId(); Long generateDetailId = generateLikeDTO.getGenerateDetailId();
GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId); GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId);
if (Objects.isNull(generateDetail)){ if (Objects.isNull(generateDetail)) {
throw new BusinessException("generateItem.does.not.exist"); throw new BusinessException("generateItem.does.not.exist");
} }
Generate generate = getById(generateDetail.getGenerateId()); Generate generate = getById(generateDetail.getGenerateId());
if (!generateLikeDTO.getLevel1Type().equals(generate.getLevel1Type())){ if (!generateLikeDTO.getLevel1Type().equals(generate.getLevel1Type())) {
throw new BusinessException("level1Type.does.not.match"); throw new BusinessException("level1Type.does.not.match");
} }
@@ -213,8 +215,8 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
// 2.1、不能重复喜欢 // 2.1、不能重复喜欢
// 2.1.1 判断该图片是否被喜欢过 // 2.1.1 判断该图片是否被喜欢过
Library libraryDetail = libraryService.getById(generateDetail.getLibraryId()); Library libraryDetail = libraryService.getById(generateDetail.getLibraryId());
if ( (Objects.nonNull(generateDetail.getLibraryId()) && !generateDetail.getLibraryId().equals(0L)) if ((Objects.nonNull(generateDetail.getLibraryId()) && !generateDetail.getLibraryId().equals(0L))
|| Objects.nonNull(libraryDetail)){ || Objects.nonNull(libraryDetail)) {
throw new BusinessException("duplicate.likes.are.not.allowed"); throw new BusinessException("duplicate.likes.are.not.allowed");
} }
@@ -237,7 +239,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
public Boolean generateDislike(Long generateDetailId, String timeZone) { public Boolean generateDislike(Long generateDetailId, String timeZone) {
// 1、确定generateDetail中是否有这条记录 // 1、确定generateDetail中是否有这条记录
GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId); GenerateDetail generateDetail = generateDetailMapper.selectById(generateDetailId);
if (Objects.isNull(generateDetail)){ if (Objects.isNull(generateDetail)) {
throw new BusinessException("generateItem.does.not.exist"); throw new BusinessException("generateItem.does.not.exist");
} }
@@ -287,7 +289,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generateDetailMapper.update(generateDetail, queryWrapper); generateDetailMapper.update(generateDetail, queryWrapper);
} }
public void updateLikeStatusBatch(List<Long> generateDetailIdList, Byte hasLike, Long libraryId, String timeZone){ public void updateLikeStatusBatch(List<Long> generateDetailIdList, Byte hasLike, Long libraryId, String timeZone) {
QueryWrapper<GenerateDetail> queryWrapper = new QueryWrapper<>(); QueryWrapper<GenerateDetail> queryWrapper = new QueryWrapper<>();
queryWrapper.in("id", generateDetailIdList); queryWrapper.in("id", generateDetailIdList);
@@ -299,93 +301,101 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
generateDetailMapper.update(generateDetail, queryWrapper); generateDetailMapper.update(generateDetail, queryWrapper);
} }
public List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId){ public List<GenerateDetail> selectBatchByLibraryId(List<Long> libraryId) {
QueryWrapper<GenerateDetail> qw = new QueryWrapper<>(); QueryWrapper<GenerateDetail> qw = new QueryWrapper<>();
qw.in("library_id",libraryId); qw.in("library_id", libraryId);
return generateDetailMapper.selectList(qw); return generateDetailMapper.selectList(qw);
} }
@Override @Override
public Long prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { public String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) {
// 1、参数检查判断必须参数是否为空 // 1、参数检查判断必须参数是否为空
if (Objects.isNull(generateThroughImageTextDTO.getUserId())){ if (Objects.isNull(generateThroughImageTextDTO.getUserId())) {
throw new BusinessException("userId cannot be empty"); throw new BusinessException("userId cannot be empty");
} }
String generateType = generateThroughImageTextDTO.getGenerateType(); String generateType = generateThroughImageTextDTO.getGenerateType();
if (!GenerateModeEnum.getGenerateModeList().contains(generateType)){ if (!GenerateModeEnum.getGenerateModeList().contains(generateType)) {
throw new BusinessException("unknown.generate.type"); throw new BusinessException("unknown.generate.type");
} }
String text = generateThroughImageTextDTO.getText(); String text = generateThroughImageTextDTO.getText();
Long elementId = generateThroughImageTextDTO.getCollectionElementId(); Long elementId = generateThroughImageTextDTO.getCollectionElementId();
validateGeneraType(new Generate(), text, elementId,generateType); validateGeneraType(new Generate(), text, elementId, generateType);
// 2、确定当前排队人数总数超过15个暂停使用当前功能 // 2、生成唯一id 使用uuid
Long zSetTotal = redisUtil.getZSetTotal(consumptionOrderKey); String uuid = UUID.randomUUID().toString();
if (zSetTotal.equals(15L)){
return null; // SnowflakeUtil idWorker = new SnowflakeUtil(0, 0);
// long snowflakeId = idWorker.nextId();
int num = 1;
// 判断与已经正常生成结果的uuid中有相同的id
while (redisUtil.isElementExistsInMap(resultMapKey, uuid) && num < 10) {
uuid = UUID.randomUUID().toString();
num++;
} }
// 无依据确定的数字
// 3、生成唯一id if (num > 10){
SnowflakeUtil idWorker = new SnowflakeUtil(0, 0); try {
long snowflakeId = idWorker.nextId(); Thread.sleep(1000);
} catch (InterruptedException e) {
if (AsyncCallerUtil.waitingStatus.containsKey(snowflakeId)){ throw new RuntimeException(e);
snowflakeId = idWorker.nextId(); }
uuid = UUID.randomUUID().toString();
} }
generateThroughImageTextDTO.setUniqueId(snowflakeId); generateThroughImageTextDTO.setUniqueId(uuid);
String jsonString = JSON.toJSONString(generateThroughImageTextDTO); String jsonString = JSON.toJSONString(generateThroughImageTextDTO);
// 4、加入redis排队便于获取实时排队信息 // 3、加入redis排队便于获取实时排队信息
Double maxScore = redisUtil.getMaxScore(consumptionOrderKey); Double maxScore = redisUtil.getMaxScore(consumptionOrderKey);
redisUtil.addToZSet(consumptionOrderKey, String.valueOf(snowflakeId),maxScore); redisUtil.addToZSet(consumptionOrderKey, uuid, maxScore);
// 5、将消息发布到MQ消息队列 // 4、将消息发布到MQ消息队列
rabbitMQService.publishMessage(jsonString); rabbitMQService.publishMessage(jsonString);
// 6、返回唯一id // 5、返回唯一id
return snowflakeId; return uuid;
} }
@Override @Override
public Long getRankPosition(Long uniqueId) { public Long getRankPosition(String uniqueId) {
return redisUtil.getRank(consumptionOrderKey, String.valueOf(uniqueId)); return redisUtil.getRank(consumptionOrderKey, uniqueId);
} }
@Override @Override
public GenerateCollectionVO getGenerateResult(Long uniqueId) { public GenerateCollectionVO getGenerateResult(String uniqueId) {
// 1、判断该请求是否已经异常 // 1、判断该请求是否已经异常
Boolean isMember = redisUtil.isElementExistsInMap(exceptionMapKey, String.valueOf(uniqueId)); Boolean isMember = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId);
if (isMember){ if (isMember) {
throw new BusinessException("generate.interface.error"); throw new BusinessException("generate.interface.error");
} }
// 2、判断该请求是否还在排队 // 2、判断该请求是否还在排队
Boolean existsInZSet = redisUtil.isElementExistsInZSet(consumptionOrderKey, String.valueOf(uniqueId)); Boolean existsInZSet = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId);
if (existsInZSet){ if (existsInZSet) {
// 排队中,给出当前排序位置 // 排队中,给出当前排序位置
return new GenerateCollectionVO(getRankPosition(uniqueId) + 1L); return new GenerateCollectionVO(getRankPosition(uniqueId) + 1L);
} }
// 3、判断redis中有没有 // 3、判断redis中有没有
boolean hasHashKey = redisUtil.isElementExistsInMap(resultMapKey, String.valueOf(uniqueId)); boolean hasHashKey = redisUtil.isElementExistsInMap(resultMapKey, uniqueId);
if (hasHashKey){ if (hasHashKey) {
// 3.1 有直接从redis中拿 // 3.1 有直接从redis中拿
String resultString = redisUtil.getMapValue(resultMapKey, String.valueOf(uniqueId)); String resultString = redisUtil.getMapValue(resultMapKey, uniqueId);
return JSONObject.parseObject(resultString,GenerateCollectionVO.class); return JSONObject.parseObject(resultString, GenerateCollectionVO.class);
} }
// 3.2 判断数据库中有没有 // 3.2 判断数据库中有没有
Generate generate = selectByUniqueId(uniqueId); Generate generate = selectByUniqueId(uniqueId);
if (Objects.isNull(generate)){ if (Objects.isNull(generate)) {
// 3.3 还没执行完,给出当前位置 // 3.3 还没执行完,给出当前位置
return new GenerateCollectionVO(0L); return new GenerateCollectionVO(0L);
} }
Long generateId = generate.getId(); Long generateId = generate.getId();
QueryWrapper<GenerateDetail> qw = new QueryWrapper<>(); QueryWrapper<GenerateDetail> qw = new QueryWrapper<>();
qw.eq("generate_id",generateId); qw.eq("generate_id", generateId);
List<GenerateDetail> generateDetails = generateDetailMapper.selectList(qw); List<GenerateDetail> generateDetails = generateDetailMapper.selectList(qw);
if (CollectionUtils.isEmpty(generateDetails)){ if (CollectionUtils.isEmpty(generateDetails)) {
// 会有这种情况吗存到generate中但是还没存到generateDetail中 // 会有这种情况吗存到generate中但是还没存到generateDetail中
return new GenerateCollectionVO(0L); return new GenerateCollectionVO(0L);
} }
@@ -401,7 +411,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return new GenerateCollectionVO(generateId, null, generatedCollectionItems); return new GenerateCollectionVO(generateId, null, generatedCollectionItems);
} }
public Generate selectByUniqueId(Long uniqueId){ public Generate selectByUniqueId(String uniqueId){
QueryWrapper<Generate> qw = new QueryWrapper<>(); QueryWrapper<Generate> qw = new QueryWrapper<>();
qw.eq("unique_id",uniqueId); qw.eq("unique_id",uniqueId);
@@ -409,23 +419,24 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
} }
@Override @Override
public void cancelGenerate(Long uniqueId) { public void cancelGenerate(String uniqueId) {
// 1、确认当前消息是否还在排队中 // 1、确认当前消息是否还在排队中
Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, String.valueOf(uniqueId)); Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId);
if (exists){ if (exists) {
// 1.1、将需要取消的唯一id加入redis以便及时取消生成 // 1.1、将需要取消的唯一id加入redis以便及时取消生成
redisUtil.addToSet(cancelSetKey, String.valueOf(uniqueId)); redisUtil.addToSet(cancelSetKey, uniqueId);
// 1.2 将需要取消的id从redis的ConsumptionOrder中删除 // 1.2 将需要取消的id从redis的ConsumptionOrder中删除
redisUtil.removeFromZSet(consumptionOrderKey, String.valueOf(uniqueId)); redisUtil.removeFromZSet(consumptionOrderKey, uniqueId);
}else { }else {
// 2、判断该消息是否异常 // 2、判断该消息是否异常
boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, String.valueOf(uniqueId)); boolean hasKey = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId);
// 3、判断该消息是否已经消费结束 // 3、判断该消息是否已经消费结束
Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, String.valueOf(uniqueId)); Boolean existsInResult = redisUtil.isElementExistsInMap(resultMapKey, uniqueId);
if (!hasKey && !existsInResult){ if (!hasKey && !existsInResult){
// 设置取等待状态为false // 设置取等待状态为false
AsyncCallerUtil.waitingStatus.put(uniqueId,false); AsyncCallerUtil.waitingStatus.put(uniqueId,false);
// 3、直接发送取消请求到python端 // 3、直接发送取消请求到python端
pythonService.cancelGenerateTask(uniqueId);
} }
} }
} }

View File

@@ -50,8 +50,8 @@ spring.servlet.multipart.max-request-size= 5MB
#access.python.ip=http://43.198.80.117 #access.python.ip=http://43.198.80.117
access.python.ip=http://18.167.251.121 access.python.ip=http://18.167.251.121
#access.python.ip=http://18.167.251.121:9991/ #access.python.ip=http://18.167.251.121:9991/
#access.python.port=9992 access.python.port=9992
access.python.port=9991 #access.python.port=9991
# minIO服务配置之信息 # minIO服务配置之信息
minio.endpoint=https://www.minio.aida.com.hk:9000 minio.endpoint=https://www.minio.aida.com.hk:9000
@@ -70,7 +70,8 @@ spring.rabbitmq.username=rabbit
spring.rabbitmq.password=123456 spring.rabbitmq.password=123456
spring.rabbitmq.virtual-host=/ spring.rabbitmq.virtual-host=/
spring.redis.host=172.31.11.32 #spring.redis.host=172.31.11.32
spring.redis.host=18.167.251.121
spring.redis.port=6379 spring.redis.port=6379
spring.redis.database=1 spring.redis.database=1
spring.redis.password=Aidlab spring.redis.password=Aidlab