diff --git a/src/main/java/com/ai/da/common/constant/CommonConstant.java b/src/main/java/com/ai/da/common/constant/CommonConstant.java index eb325fe6..09496110 100644 --- a/src/main/java/com/ai/da/common/constant/CommonConstant.java +++ b/src/main/java/com/ai/da/common/constant/CommonConstant.java @@ -23,6 +23,10 @@ public class CommonConstant { public static final String GENERATE_SLOGAN = "/api/slogan"; + public static final String GENERATE_CANCEL = "/api/generate_cancel/"; + + public static final String GENERATE_LOGO_SINGLE_CANCEL = "/api/generate_single_logo_cancel/"; + public static final String PYTHON_PORT_9996 = "9996"; public static final String PYTHON_PORT_9997 = "9997"; diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java index 74743856..caa80dcc 100644 --- a/src/main/java/com/ai/da/controller/GenerateController.java +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -63,8 +63,9 @@ public class GenerateController { @GetMapping("/stopWaiting") public Response stopWaiting(@RequestParam("userId") Long userId, @RequestParam("uniqueId") List uniqueId, - @RequestParam("timeZone") String timeZone) { - generateService.cancelGenerate(userId, uniqueId, timeZone); + @RequestParam("timeZone") String timeZone, + @RequestParam("type") String type) { + generateService.cancelGenerate(userId, uniqueId, timeZone, type); return Response.success("stop waiting successfully"); } diff --git a/src/main/java/com/ai/da/python/PythonService.java b/src/main/java/com/ai/da/python/PythonService.java index 66d18ba9..f64dc3ea 100644 --- a/src/main/java/com/ai/da/python/PythonService.java +++ b/src/main/java/com/ai/da/python/PythonService.java @@ -3174,15 +3174,15 @@ public class PythonService { throw new BusinessException("cloth-classification.interface.exception"); } - public Boolean cancelGenerateTask(String taskId) { + public Boolean cancelGenerateTask(String taskId, String path) { OkHttpClient client = new OkHttpClient().newBuilder() .connectTimeout(30, TimeUnit.SECONDS) .pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒) .readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒) .writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒) .build(); -// String url = accessPythonIp + ":" + accessPythonPort + "/api/generate_cancel/" + taskId; - String url = fastApiPythonAddress + "/api/generate_cancel/" + taskId; + String url = accessPythonIp + ":" + accessPythonPort + path + taskId; +// String url = fastApiPythonAddress + "/api/generate_cancel/" + taskId; Request request = new Request.Builder() .url(url) // .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==") @@ -3193,14 +3193,14 @@ public class PythonService { log.info("cancelGenerateTask请求入参content###{}", taskId); response = client.newCall(request).execute(); } catch (IOException ioException) { - log.error("PythonService##cancelGenerateTask异常###{}", ExceptionUtil.getThrowableList(ioException)); + log.error("PythonService##cancelGenerateTask异常###{}", response); return null; } int responseCode = response.code(); response.close(); if (responseCode != HttpURLConnection.HTTP_OK) { - log.info("generate-python 取消请求失败"); + log.info("generate-python 取消请求失败. {}", response); return Boolean.FALSE; } log.info("generate-python 取消请求成功"); diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index 89fd6267..e706a8e6 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -35,5 +35,5 @@ public interface GenerateService extends IService { Long getRankPosition(String uniqueId); - void cancelGenerate(Long userId, List uniqueId, String timeZone); + void cancelGenerate(Long userId, List uniqueId, String timeZone, String type); } diff --git a/src/main/java/com/ai/da/service/impl/ChatRobotServiceImpl.java b/src/main/java/com/ai/da/service/impl/ChatRobotServiceImpl.java index 9994bf73..b03e181f 100644 --- a/src/main/java/com/ai/da/service/impl/ChatRobotServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/ChatRobotServiceImpl.java @@ -152,8 +152,8 @@ public class ChatRobotServiceImpl implements ChatRobotService { RequestBody body = RequestBody.create(mediaType, param); Request request = new Request.Builder() // .url("http://127.0.0.1:5000/api/chat_stream_test") - .url(accessPythonIp + ":" + accessPythonPort + "/api/chat_stream_test") -// .url(fastApiPythonAddress + "/api/chat_robot") +// .url(accessPythonIp + ":" + accessPythonPort + "/api/chat_stream_test") + .url(accessPythonIp + ":" + accessPythonPort + "/api/chat_robot") // .url(accessPythonIp + ":10200/aifda/api/v1.0/generate") .method("POST", body) .addHeader("Content-Type", "application/json") diff --git a/src/main/java/com/ai/da/service/impl/CreditsServiceImpl.java b/src/main/java/com/ai/da/service/impl/CreditsServiceImpl.java index 61b67921..e245145f 100644 --- a/src/main/java/com/ai/da/service/impl/CreditsServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/CreditsServiceImpl.java @@ -17,6 +17,7 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import io.netty.util.internal.StringUtil; +import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -31,6 +32,7 @@ import java.util.Objects; import java.util.Set; @Service +@Slf4j public class CreditsServiceImpl extends ServiceImpl implements CreditsService { @Value("${redis.key.credits.pre-deduction}") @@ -224,13 +226,18 @@ public class CreditsServiceImpl extends ServiceImpl i return new PrepareForGenerateVO(0); } } - CreditsEventsEnum creditsEventsEnum = null; + CreditsEventsEnum creditsEventsEnum = CreditsEventsEnum.NORMAL_GENERATE; int times = 4; // 当level1Type为Print_board时,level2Type为pattern时需要确定generateType if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())){ @@ -493,7 +493,6 @@ public class GenerateServiceImpl extends ServiceImpl i validateGeneraType(generate, text, elementId); // 校验后获取 generateThroughImageTextDTO.setGenerateType(generate.getGenerateType()); - creditsEventsEnum = CreditsEventsEnum.NORMAL_GENERATE; } // Slogan 参数校验 if (generateThroughImageTextDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.SLOGAN.getRealName())){ @@ -546,7 +545,6 @@ public class GenerateServiceImpl extends ServiceImpl i if (seed < 0 || seed > 99999){ throw new BusinessException("the.value.range.of.seed"); } - creditsEventsEnum = CreditsEventsEnum.NORMAL_GENERATE; } } @@ -638,7 +636,7 @@ public class GenerateServiceImpl extends ServiceImpl i @Override @Transactional(rollbackFor = Exception.class) - public void cancelGenerate(Long userId, List uniqueIdList, String timeZone) { + public void cancelGenerate(Long userId, List uniqueIdList, String timeZone, String type) { // todo 取消待优化 uniqueIdList.forEach(uniqueId -> { // 1、将需要取消的唯一id加入redis,以便及时取消生成 @@ -666,13 +664,23 @@ public class GenerateServiceImpl extends ServiceImpl i pythonService.cancelGenerateTask(uniqueId); } }*/ + String path; + if (type.equals("Logo")){ + path = CommonConstant.GENERATE_LOGO_SINGLE_CANCEL; + }else { + path = CommonConstant.GENERATE_CANCEL; + } String key = generateResultKey + ":" + uniqueId; GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class); + if (Objects.isNull(generateResultVO)){ + log.warn("任务不存在,无法取消"); + return; + } // 判断当前task的状态是不是Fail if (!generateResultVO.getStatus().equals("Fail")) { // 2、不是,直接发送取消请求到python端 - pythonService.cancelGenerateTask(uniqueId); + pythonService.cancelGenerateTask(uniqueId, path); // 3、更改result中当前taskId的状态 redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(uniqueId, null, null, "Cancelled")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME); }