generate cancel

This commit is contained in:
2024-06-24 17:02:25 +08:00
parent ffaef2ff6c
commit aacbe92cdc
7 changed files with 36 additions and 16 deletions

View File

@@ -23,6 +23,10 @@ public class CommonConstant {
public static final String GENERATE_SLOGAN = "/api/slogan";
public static final String GENERATE_CANCEL = "/api/generate_cancel/";
public static final String GENERATE_LOGO_SINGLE_CANCEL = "/api/generate_single_logo_cancel/";
public static final String PYTHON_PORT_9996 = "9996";
public static final String PYTHON_PORT_9997 = "9997";

View File

@@ -63,8 +63,9 @@ public class GenerateController {
@GetMapping("/stopWaiting")
public Response<String> stopWaiting(@RequestParam("userId") Long userId,
@RequestParam("uniqueId") List<String> uniqueId,
@RequestParam("timeZone") String timeZone) {
generateService.cancelGenerate(userId, uniqueId, timeZone);
@RequestParam("timeZone") String timeZone,
@RequestParam("type") String type) {
generateService.cancelGenerate(userId, uniqueId, timeZone, type);
return Response.success("stop waiting successfully");
}

View File

@@ -3174,15 +3174,15 @@ public class PythonService {
throw new BusinessException("cloth-classification.interface.exception");
}
public Boolean cancelGenerateTask(String taskId) {
public Boolean cancelGenerateTask(String taskId, String path) {
OkHttpClient client = new OkHttpClient().newBuilder()
.connectTimeout(30, TimeUnit.SECONDS)
.pingInterval(5, TimeUnit.SECONDS)//websocket轮训间隔(单位:秒)
.readTimeout(60, TimeUnit.SECONDS)//读取超时(单位:秒)
.writeTimeout(60, TimeUnit.SECONDS)//写入超时(单位:秒)
.build();
// String url = accessPythonIp + ":" + accessPythonPort + "/api/generate_cancel/" + taskId;
String url = fastApiPythonAddress + "/api/generate_cancel/" + taskId;
String url = accessPythonIp + ":" + accessPythonPort + path + taskId;
// String url = fastApiPythonAddress + "/api/generate_cancel/" + taskId;
Request request = new Request.Builder()
.url(url)
// .addHeader("Authorization", "Basic YWlkbGFiOjEyMw==")
@@ -3193,14 +3193,14 @@ public class PythonService {
log.info("cancelGenerateTask请求入参content###{}", taskId);
response = client.newCall(request).execute();
} catch (IOException ioException) {
log.error("PythonService##cancelGenerateTask异常###{}", ExceptionUtil.getThrowableList(ioException));
log.error("PythonService##cancelGenerateTask异常###{}", response);
return null;
}
int responseCode = response.code();
response.close();
if (responseCode != HttpURLConnection.HTTP_OK) {
log.info("generate-python 取消请求失败");
log.info("generate-python 取消请求失败. {}", response);
return Boolean.FALSE;
}
log.info("generate-python 取消请求成功");

View File

@@ -35,5 +35,5 @@ public interface GenerateService extends IService<Generate> {
Long getRankPosition(String uniqueId);
void cancelGenerate(Long userId, List<String> uniqueId, String timeZone);
void cancelGenerate(Long userId, List<String> uniqueId, String timeZone, String type);
}

View File

@@ -152,8 +152,8 @@ public class ChatRobotServiceImpl implements ChatRobotService {
RequestBody body = RequestBody.create(mediaType, param);
Request request = new Request.Builder()
// .url("http://127.0.0.1:5000/api/chat_stream_test")
.url(accessPythonIp + ":" + accessPythonPort + "/api/chat_stream_test")
// .url(fastApiPythonAddress + "/api/chat_robot")
// .url(accessPythonIp + ":" + accessPythonPort + "/api/chat_stream_test")
.url(accessPythonIp + ":" + accessPythonPort + "/api/chat_robot")
// .url(accessPythonIp + ":10200/aifda/api/v1.0/generate")
.method("POST", body)
.addHeader("Content-Type", "application/json")

View File

@@ -17,6 +17,7 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import io.netty.util.internal.StringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -31,6 +32,7 @@ import java.util.Objects;
import java.util.Set;
@Service
@Slf4j
public class CreditsServiceImpl extends ServiceImpl<CreditsDetailMapper, CreditsDetail> implements CreditsService {
@Value("${redis.key.credits.pre-deduction}")
@@ -224,13 +226,18 @@ public class CreditsServiceImpl extends ServiceImpl<CreditsDetailMapper, Credits
@Override
@Transactional(rollbackFor = Exception.class)
public void taskCreditsDeduction(Long accountId, String taskId){
log.info("指定任务的积分扣除 {}",taskId);
String key = creditsDeduction + ":" + accountId + ":" + taskId;
// 1、获取当前任务id对应的积分
String value = redisUtil.getFromString(key);
// 1.1 没有。返回,报错,未找到当前任务
if (StringUtil.isNullOrEmpty(value)){
throw new BusinessException("当前任务不存在,无法扣除积分");
log.info("当前任务 {} 不存在,或当前任务已完成积分扣除", taskId);
return;
// throw new BusinessException("当前任务不存在,无法扣除积分");
}else {
log.info("指定任务 {} 扣除积分 {}",taskId, value);
}
// 2、操作数据库扣除积分

View File

@@ -475,7 +475,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
return new PrepareForGenerateVO(0);
}
}
CreditsEventsEnum creditsEventsEnum = null;
CreditsEventsEnum creditsEventsEnum = CreditsEventsEnum.NORMAL_GENERATE;
int times = 4;
// 当level1Type为Print_board时level2Type为pattern时需要确定generateType
if (generateThroughImageTextDTO.getLevel1Type().equals(PRINT_BOARD.getRealName())){
@@ -493,7 +493,6 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
validateGeneraType(generate, text, elementId);
// 校验后获取
generateThroughImageTextDTO.setGenerateType(generate.getGenerateType());
creditsEventsEnum = CreditsEventsEnum.NORMAL_GENERATE;
}
// Slogan 参数校验
if (generateThroughImageTextDTO.getLevel2Type().equals(CollectionLevel2TypeEnum.SLOGAN.getRealName())){
@@ -546,7 +545,6 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
if (seed < 0 || seed > 99999){
throw new BusinessException("the.value.range.of.seed");
}
creditsEventsEnum = CreditsEventsEnum.NORMAL_GENERATE;
}
}
@@ -638,7 +636,7 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
@Override
@Transactional(rollbackFor = Exception.class)
public void cancelGenerate(Long userId, List<String> uniqueIdList, String timeZone) {
public void cancelGenerate(Long userId, List<String> uniqueIdList, String timeZone, String type) {
// todo 取消待优化
uniqueIdList.forEach(uniqueId -> {
// 1、将需要取消的唯一id加入redis以便及时取消生成
@@ -666,13 +664,23 @@ public class GenerateServiceImpl extends ServiceImpl<GenerateMapper, Generate> i
pythonService.cancelGenerateTask(uniqueId);
}
}*/
String path;
if (type.equals("Logo")){
path = CommonConstant.GENERATE_LOGO_SINGLE_CANCEL;
}else {
path = CommonConstant.GENERATE_CANCEL;
}
String key = generateResultKey + ":" + uniqueId;
GenerateResultVO generateResultVO = new Gson().fromJson(redisUtil.getFromString(key), GenerateResultVO.class);
if (Objects.isNull(generateResultVO)){
log.warn("任务不存在,无法取消");
return;
}
// 判断当前task的状态是不是Fail
if (!generateResultVO.getStatus().equals("Fail")) {
// 2、不是直接发送取消请求到python端
pythonService.cancelGenerateTask(uniqueId);
pythonService.cancelGenerateTask(uniqueId, path);
// 3、更改result中当前taskId的状态
redisUtil.addToString(key, new Gson().toJson(new GenerateResultVO(uniqueId, null, null, "Cancelled")), CommonConstant.GENERATE_RESULT_EXPIRE_TIME);
}