From ffeaac8c460bd00e89e700c77cf212ca5a606d9a Mon Sep 17 00:00:00 2001 From: xupei Date: Thu, 2 May 2024 10:33:11 +0800 Subject: [PATCH 1/2] =?UTF-8?q?generate=20=E4=BB=A3=E7=A0=81=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../da/common/RabbitMQ/GenerateConsumer.java | 4 +- .../com/ai/da/service/GenerateService.java | 2 +- .../da/service/impl/GenerateServiceImpl.java | 44 +++++++------------ .../impl/SuperResolutionServiceImpl.java | 13 +----- 4 files changed, 20 insertions(+), 43 deletions(-) diff --git a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java index f9edde27..3e8e6a6e 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/GenerateConsumer.java @@ -10,6 +10,7 @@ import com.alibaba.fastjson.JSONObject; import com.google.gson.Gson; import com.rabbitmq.client.Channel; import lombok.extern.slf4j.Slf4j; +import org.apache.tomcat.jni.Time; import org.springframework.amqp.core.Message; import org.springframework.amqp.rabbit.annotation.RabbitHandler; import org.springframework.amqp.rabbit.annotation.RabbitListener; @@ -41,9 +42,6 @@ public class GenerateConsumer { @Value("${redis.key.generateExceptionMap}") private String exceptionMapKey; - @Value("${redis.key.resultMap}") - private String resultMapKey; - @Value("${redis.key.generateResult}") private String generateResultKey; diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index c59513c8..a3f6ed7d 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -25,7 +25,7 @@ public interface GenerateService extends IService { List selectBatchByLibraryId(List libraryId); - GenerateCollectionVO getGenerateResult(String uniqueId); +// GenerateCollectionVO getGenerateResult(String uniqueId); List getGenerateResultList(List taskIdList); diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 1755eb60..1f56a4a9 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -82,9 +82,6 @@ public class GenerateServiceImpl extends ServiceImpl i @Value("${redis.key.generateExceptionMap}") private String exceptionMapKey; - @Value("${redis.key.resultMap}") - private String resultMapKey; - @Value("${redis.key.generateResult}") private String generateResultKey; @@ -428,27 +425,9 @@ public class GenerateServiceImpl extends ServiceImpl i Long elementId = generateThroughImageTextDTO.getCollectionElementId(); validateGeneraType(new Generate(), text, elementId, generateType); - // 2、生成唯一id 使用uuid + // 2、生成唯一id 使用uuid,由于uuid重复的几率很小,故取消对uuid重复性的校验 String uuid = UUID.randomUUID().toString(); - int num = 1; - // 判断已经正常生成结果的uuid或正在排队的uuid中是否有相同的id - while ((redisUtil.isElementExistsInMap(resultMapKey, uuid) || - redisUtil.isElementExistsInZSet(consumptionOrderKey, uuid)) - && num < 10) { - uuid = UUID.randomUUID().toString(); - num++; - } - // 无依据确定的数字 - if (num > 10) { - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - uuid = UUID.randomUUID().toString(); - } - ArrayList taskIdList = new ArrayList<>(); for (int i = 1; i <= 4; i++) { String temp = uuid; @@ -480,7 +459,7 @@ public class GenerateServiceImpl extends ServiceImpl i return redisUtil.getRank(consumptionOrderKey, uniqueId); } - @Override + /*@Override public GenerateCollectionVO getGenerateResult(String uniqueId) { // 1、判断该请求是否已经异常 Boolean isMember = redisUtil.isElementExistsInMap(exceptionMapKey, uniqueId); @@ -529,7 +508,7 @@ public class GenerateServiceImpl extends ServiceImpl i }); return new GenerateCollectionVO(generateId, null, generatedCollectionItems); - } + }*/ @Override public List getGenerateResultList(List taskIdList) { @@ -579,7 +558,10 @@ public class GenerateServiceImpl extends ServiceImpl i public void cancelGenerate(Long userId, List uniqueIdList, String timeZone) { // todo 取消待优化 uniqueIdList.forEach(uniqueId -> { - // 1、确认当前消息是否还在排队中 + // 1、将需要取消的唯一id加入redis,以便及时取消生成 + redisUtil.addToSet(cancelSetKey, uniqueId); + + /*// 1、确认当前消息是否还在排队中 Boolean exists = redisUtil.isElementExistsInZSet(consumptionOrderKey, uniqueId); Boolean flag = Boolean.FALSE; if (exists) flag = redisUtil.getRank(consumptionOrderKey, uniqueId) > 1L ? Boolean.TRUE : Boolean.FALSE; @@ -600,9 +582,17 @@ public class GenerateServiceImpl extends ServiceImpl i // 3、直接发送取消请求到python端 pythonService.cancelGenerateTask(uniqueId); } - } + }*/ + String key = generateResultKey + ":" + uniqueId; - redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(uniqueId, null, null, "Cancelled")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class); + // 判断当前task的状态是不是Fail + if (!generateResultVO.getStatus().equals("Fail")){ + // 2、不是,直接发送取消请求到python端 + pythonService.cancelGenerateTask(uniqueId); + // 3、更改result中当前taskId的状态 + redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(uniqueId, null, null, "Cancelled")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); + } // 3、考虑加一张表,专门用于记录哪些用户在什么时间进行了取消操作,包括已经异常的请求 GenerateCancel generateCancel = new GenerateCancel(userId, uniqueId, DateUtil.getByTimeZone(timeZone)); diff --git a/src/main/java/com/ai/da/service/impl/SuperResolutionServiceImpl.java b/src/main/java/com/ai/da/service/impl/SuperResolutionServiceImpl.java index eeb565e8..ee1fb6ff 100644 --- a/src/main/java/com/ai/da/service/impl/SuperResolutionServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/SuperResolutionServiceImpl.java @@ -57,9 +57,6 @@ public class SuperResolutionServiceImpl extends ServiceImpl Date: Tue, 7 May 2024 18:27:32 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=B0=86generate=E8=BE=93=E5=85=A5?= =?UTF-8?q?=E7=9A=84=E6=96=87=E6=9C=AC=E8=BF=9B=E8=A1=8C=E7=BF=BB=E8=AF=91?= =?UTF-8?q?=E6=88=96=E5=BE=AE=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../com/ai/da/common/RabbitMQ/MQConfig.java | 1 + .../java/com/ai/da/python/PythonService.java | 55 +++++++++++++++++++ .../da/service/impl/GenerateServiceImpl.java | 13 +++-- 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java index e9df0264..7079f02a 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java @@ -10,6 +10,7 @@ public class MQConfig { public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange"; // public static final String GENERATE_QUEUE = "generate-queue-prod"; // public static final String GENERATE_QUEUE = "generate-queue-test"; +// ================================================================== // public static final String GENERATE_QUEUE = "generate-queue-local"; public static final String GENERATE_QUEUE = "generate-queue-dev"; diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index 3f0acc73..d8220f51 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -3151,4 +3151,59 @@ public class PythonService { return Boolean.TRUE; } + public String promptTranslate(String text) throws BusinessException { + OkHttpClient client = new OkHttpClient().newBuilder() + .connectTimeout(30, TimeUnit.SECONDS) + .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒) + .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒) + .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) + .build(); + MediaType mediaType = MediaType.parse("application/json"); + + HashMap content = new HashMap<>(); + content.put("text", text); + + String jsonString = JSON.toJSONString(content, SerializerFeature.WriteNullStringAsEmpty); + RequestBody body = RequestBody.create(mediaType, jsonString); + Request request = new Request.Builder() + .url(accessPythonIp + ":" + accessPythonPort + "/api/translateToEN") + .method("POST", body) + .addHeader("Content-Type", "application/json") + .build(); + Response response = null; + try { + log.info("promptTranslation请求入参content###{}", jsonString); + response = client.newCall(request).execute(); + } catch (IOException ioException) { + log.error("PythonService##promptTranslation异常###{}", ExceptionUtil.getThrowableList(ioException)); + return text; + } + int responseCode = response.code(); + String bodyString; + try { + bodyString = response.body().string(); + if (responseCode != HttpURLConnection.HTTP_OK) { + // 基本不会有除200以外的code + log.info("promptTranslation 用户输入翻译失败。 Response code " + responseCode); + throw new BusinessException("promptTranslation 用户输入翻译失败。 Response code " + responseCode); + } + JSONObject jsonObject = JSON.parseObject(bodyString); + Boolean result = JSON.parseObject(JSON.toJSONString(response)).getBoolean("successful"); + if (result && jsonObject.get("msg").equals("OK!")) { + String translated = jsonObject.get("data").toString(); + log.info("翻译或处理后的文本 : {}", translated); + return translated; + } + } catch (IOException e) { + log.error("promptTranslation 用户输入翻译失败; error message => " + e.getMessage()); + throw new RuntimeException(e); + + } finally { + response.close(); + } + + log.info("promptTranslation 用户输入翻译失败,返回用户输入"); + return text; + } + } diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 1f56a4a9..48f3655a 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -121,7 +121,7 @@ public class GenerateServiceImpl extends ServiceImpl i String text = generateThroughImageTextDTO.getText(); Long elementId = generateThroughImageTextDTO.getCollectionElementId(); validateGeneraType(generate, text, elementId, generateType); - if (generateType.equals("text") || generateType.equals("text-image")) { + if (!StringUtil.isNullOrEmpty(text)) { text = modifyPrompt(text, generate, generateThroughImageTextDTO.getLevel1Type()); } @@ -255,24 +255,25 @@ public class GenerateServiceImpl extends ServiceImpl i private String modifyPrompt(String userInput, Generate generate, String level1Type) { String text = ""; + String translated = pythonService.promptTranslate(userInput); switch (level1Type) { case "Moodboard": - text = userInput + ",high quality"; + text = translated + ",high quality"; generate.setText(text); break; case "Printboard": if (userInput.contains("Painting Style")) { - userInput = "Picasso,increased color saturation,increased glossiness," + userInput; + userInput = "Picasso,increased color saturation,increased glossiness," + translated; } else if (userInput.contains("Illustration Style")) { - userInput = "Flat coating,romantic,soft,pencil strokes,accentuating and widening the depth of pencil strokes,paper patterns,block colors,crayons,reducing image contrast,and hand drawn painting marks," + userInput; + userInput = "Flat coating,romantic,soft,pencil strokes,accentuating and widening the depth of pencil strokes,paper patterns,block colors,crayons,reducing image contrast,and hand drawn painting marks," + translated; } else if (userInput.contains("Real Style")) { - userInput = "Still life photography,hyper realism,3d,deepened projection,increased permutation value,increased concavity and convexity value," + userInput; + userInput = "Still life photography,hyper realism,3d,deepened projection,increased permutation value,increased concavity and convexity value," + translated; } text = userInput + ", fabric print, high quality"; generate.setText(text); break; case "Sketchboard": - text = "clear lines, simple outlines monochrome white vector image of " + userInput + ", no background, sketch flat, front view display, best quality, ultra-high resolution 8k"; + text = "clear lines, simple outlines monochrome white vector image of " + translated + ", no background, sketch flat, front view display, best quality, ultra-high resolution 8k"; generate.setText(text); default: }