diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..e69de29b diff --git a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java index d4e24508..208075fd 100644 --- a/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java +++ b/src/main/java/com/ai/da/common/RabbitMQ/MQConfig.java @@ -16,7 +16,8 @@ import org.springframework.beans.factory.annotation.Value; public class MQConfig { public static final String GENERATE_EXCHANGE_FANOUT = "generate-exchange"; - public static final String GENERATE_QUEUE = "generate-queue-prod"; +// public static final String GENERATE_QUEUE = "generate-queue-prod"; + public static final String GENERATE_QUEUE = "generate-queue-test"; public MQConfig() { } diff --git a/src/main/java/com/ai/da/common/security/filter/AuthenticationFilter.java b/src/main/java/com/ai/da/common/security/filter/AuthenticationFilter.java index c2625612..6efefd84 100644 --- a/src/main/java/com/ai/da/common/security/filter/AuthenticationFilter.java +++ b/src/main/java/com/ai/da/common/security/filter/AuthenticationFilter.java @@ -47,7 +47,7 @@ public class AuthenticationFilter extends OncePerRequestFilter { "/api/third/party/addUser","/api/third/party/addTrialUser", "/api/third/party/editUser", "/api/element/initDefaultSysFile", "/api/third/party/addNoLoginRequiredNew","/api/third/party/deleteNoLoginRequiredNew", "/api/third/party/existNoLoginRequired","/api/third/party/getRedirectUrl", - "/api/python/chatStream", +// "/api/python/chatStream", "/api/python/flush", "/api/account/healthy" ); diff --git a/src/main/java/com/ai/da/controller/GenerateController.java b/src/main/java/com/ai/da/controller/GenerateController.java index b3f6fb5b..c752fbc0 100644 --- a/src/main/java/com/ai/da/controller/GenerateController.java +++ b/src/main/java/com/ai/da/controller/GenerateController.java @@ -6,6 +6,7 @@ import com.ai.da.model.dto.GenerateThroughImageTextDTO; import com.ai.da.model.vo.GenerateCaptionVO; import com.ai.da.model.vo.GenerateCollectionVO; import com.ai.da.model.vo.GenerateLikeVO; +import com.ai.da.model.vo.PrepareForGenerateVO; import com.ai.da.service.GenerateService; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; @@ -55,7 +56,7 @@ public class GenerateController { @ApiOperation(value = "发起生成请求,异步获取结果") @PostMapping("/prepare") - public Response prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public Response prepareForGenerate(@Valid @RequestBody GenerateThroughImageTextDTO generateThroughImageTextDTO) { return Response.success(generateService.prepareForGenerate(generateThroughImageTextDTO)); } diff --git a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java index fd45533c..1adfdbef 100644 --- a/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java +++ b/src/main/java/com/ai/da/model/dto/GenerateThroughImageTextDTO.java @@ -47,4 +47,7 @@ public class GenerateThroughImageTextDTO { @ApiModelProperty("唯一id,用于保持消息唯一性") String uniqueId; + + @ApiModelProperty("是否是测试用户") + Boolean isTestUser; } diff --git a/src/main/java/com/ai/da/model/vo/PrepareForGenerateVO.java b/src/main/java/com/ai/da/model/vo/PrepareForGenerateVO.java new file mode 100644 index 00000000..438d94db --- /dev/null +++ b/src/main/java/com/ai/da/model/vo/PrepareForGenerateVO.java @@ -0,0 +1,25 @@ +package com.ai.da.model.vo; + +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +@Data +@ApiModel("prepare for generate响应vo") +public class PrepareForGenerateVO { + + @ApiModelProperty("uniqueId") + private String uniqueId; + + @ApiModelProperty("剩余使用次数") + private Integer leftUsageCount; + + public PrepareForGenerateVO(String uniqueId, Integer leftUsageCount) { + this.uniqueId = uniqueId; + this.leftUsageCount = leftUsageCount; + } + + public PrepareForGenerateVO(Integer leftUsageCount) { + this.leftUsageCount = leftUsageCount; + } +} diff --git a/src/main/java/com/ai/da/service/GenerateService.java b/src/main/java/com/ai/da/service/GenerateService.java index d1a327bd..06ca4122 100644 --- a/src/main/java/com/ai/da/service/GenerateService.java +++ b/src/main/java/com/ai/da/service/GenerateService.java @@ -7,6 +7,7 @@ import com.ai.da.model.dto.GenerateThroughImageTextDTO; import com.ai.da.model.vo.GenerateCaptionVO; import com.ai.da.model.vo.GenerateCollectionVO; import com.ai.da.model.vo.GenerateLikeVO; +import com.ai.da.model.vo.PrepareForGenerateVO; import com.baomidou.mybatisplus.extension.service.IService; import java.util.List; @@ -27,7 +28,7 @@ public interface GenerateService extends IService { GenerateCollectionVO getGenerateResult(String uniqueId); - String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO); + PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO); Long getRankPosition(String uniqueId); diff --git a/src/main/java/com/ai/da/service/impl/ChatRobotServiceImpl.java b/src/main/java/com/ai/da/service/impl/ChatRobotServiceImpl.java index 5f06a194..9cfffaf1 100644 --- a/src/main/java/com/ai/da/service/impl/ChatRobotServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/ChatRobotServiceImpl.java @@ -9,7 +9,9 @@ import com.ai.da.common.enums.LibraryLevel1TypeEnum; import com.ai.da.common.utils.CopyUtil; import com.ai.da.common.utils.MD5Utils; import com.ai.da.common.utils.MinioUtil; +import com.ai.da.mapper.AccountMapper; import com.ai.da.mapper.LibraryMapper; +import com.ai.da.mapper.entity.Account; import com.ai.da.mapper.entity.ChatRobot; import com.ai.da.mapper.ChatRobotMapper; import com.ai.da.mapper.entity.Library; @@ -74,6 +76,8 @@ public class ChatRobotServiceImpl implements ChatRobotService { @Value("${minio.bucketName.sysImage}") private String sysImage; + @Resource + private AccountMapper accountMapper; Gson gson = new GsonBuilder().create(); private final ExecutorService executorService = Executors.newSingleThreadExecutor(); @@ -178,7 +182,8 @@ public class ChatRobotServiceImpl implements ChatRobotService { chatRobot.setSessionId(data.getString("session_id")); BigDecimal totalCost = data.getBigDecimal("total_cost"); // 校验本次余额够不够 - checkBalance(totalCost, chatSendDTO.getUser_id()); + ChatRobotVO balance = checkBalance(totalCost, chatSendDTO.getUser_id()); + if (!Objects.isNull(balance)) return balance; chatRobot.setTotalCost(totalCost); chatRobot.setTotalTokens(data.getLong("total_tokens")); chatRobot.setUserId(chatSendDTO.getUser_id()); @@ -248,7 +253,7 @@ public class ChatRobotServiceImpl implements ChatRobotService { throw new BusinessException("chat-bot.interface.exception"); } - private void checkBalance(BigDecimal totalCost, Long userId) { + private ChatRobotVO checkBalance(BigDecimal totalCost, Long userId) { QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.eq("user_id", userId); queryWrapper.ge("create_time", LocalDateTime.now().withDayOfMonth(1).withHour(0).withMinute(0).withSecond(0)); @@ -257,10 +262,25 @@ public class ChatRobotServiceImpl implements ChatRobotService { List chatRobots = chatRobotMapper.selectList(queryWrapper); if (!CollectionUtils.isEmpty(chatRobots)) { BigDecimal totalCostUsed = chatRobots.get(0).getTotalCost(); - if (totalCostUsed.add(totalCost).compareTo(BigDecimal.valueOf(5)) > 0) { - throw new BusinessException("Your balance is insufficient"); + Account account = accountMapper.selectById(userId); + ChatRobot chatRobot = new ChatRobot(); + if (account.getIsTrial() == 1) { + if (totalCostUsed.add(totalCost).compareTo(BigDecimal.valueOf(0.1)) > 0) { + String messageFromResource = BusinessException.getMessageFromResource("balance.insufficient.for.trial"); + chatRobot.setOutput(messageFromResource); +// throw new BusinessException("Your balance is insufficient"); + return CopyUtil.copyObject(chatRobot, ChatRobotVO.class); + } + }else { + if (totalCostUsed.add(totalCost).compareTo(BigDecimal.valueOf(5)) > 0) { + String messageFromResource = BusinessException.getMessageFromResource("balance.insufficient.for.paying"); + chatRobot.setOutput(messageFromResource); +// throw new BusinessException("Your balance is insufficient"); + return CopyUtil.copyObject(chatRobot, ChatRobotVO.class); + } } } + return null; } @Override @@ -323,9 +343,14 @@ public class ChatRobotServiceImpl implements ChatRobotService { if (CollectionUtils.isEmpty(chatRobots)) { return BigDecimal.ONE; } else { + Account account = accountMapper.selectById(userId); BigDecimal totalCost = BigDecimal.valueOf(5).subtract(chatRobots.get(0).getTotalCost()); - BigDecimal result = totalCost.divide(BigDecimal.valueOf(5), 4, RoundingMode.HALF_UP); - return result; + if (account.getIsTrial() == 1) { + totalCost = BigDecimal.valueOf(0.1).subtract(chatRobots.get(0).getTotalCost()); + return totalCost.divide(BigDecimal.valueOf(0.1), 4, RoundingMode.HALF_UP); + }else { + return totalCost.divide(BigDecimal.valueOf(5), 4, RoundingMode.HALF_UP); + } } } diff --git a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java index 095db937..0ea8e713 100644 --- a/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/GenerateServiceImpl.java @@ -314,7 +314,7 @@ public class GenerateServiceImpl extends ServiceImpl i } @Override - public String prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { + public PrepareForGenerateVO prepareForGenerate(GenerateThroughImageTextDTO generateThroughImageTextDTO) { // 1、参数检查,判断必须参数是否为空 if (Objects.isNull(generateThroughImageTextDTO.getUserId())) { throw new BusinessException("userId cannot be empty"); @@ -323,6 +323,16 @@ public class GenerateServiceImpl extends ServiceImpl i if (!GenerateModeEnum.getGenerateModeList().contains(generateType)) { throw new BusinessException("unknown.generate.type"); } + + // 判断试用用户是否还有剩余试用机会 + int trialsCount = 0; + if (generateThroughImageTextDTO.getIsTestUser()){ + trialsCount = getTrialsCount(generateThroughImageTextDTO.getUserId(), generateThroughImageTextDTO.getLevel1Type()); + if (trialsCount >= 2){ + return new PrepareForGenerateVO(0); + } + } + String text = generateThroughImageTextDTO.getText(); Long elementId = generateThroughImageTextDTO.getCollectionElementId(); validateGeneraType(new Generate(), text, elementId, generateType); @@ -361,7 +371,7 @@ public class GenerateServiceImpl extends ServiceImpl i rabbitMQService.publishMessage(jsonString); // 5、返回唯一id - return uuid; + return new PrepareForGenerateVO(uuid, 2 - trialsCount); } @Override @@ -458,4 +468,28 @@ public class GenerateServiceImpl extends ServiceImpl i GenerateCancel generateCancel = new GenerateCancel(userId, uniqueId, DateUtil.getByTimeZone(timeZone)); generateCancelMapper.insert(generateCancel); } + + // 判断试用用户试用generate机会是否使用完毕 + private int getTrialsCount(Long userId, String level1Type){ + List getGenerateList = getGenerateByAccountId(userId, level1Type); + int trialsCount ; + if (getGenerateList.isEmpty()){ + trialsCount = 0; + } else if (getGenerateList.size() == 1) { + trialsCount = 1; + } else if (getGenerateList.size() == 2) { + trialsCount = 2; + }else { + trialsCount = 2; + } + return trialsCount; + } + + public List getGenerateByAccountId(Long accountId, String level1Type){ + QueryWrapper qw = new QueryWrapper<>(); + qw.eq("account_id",accountId); + qw.eq("level1_type", level1Type); + + return baseMapper.selectList(qw); + } } diff --git a/src/main/resources/application-prod.properties b/src/main/resources/application-prod.properties index e4df2d00..6a2b9d4d 100644 --- a/src/main/resources/application-prod.properties +++ b/src/main/resources/application-prod.properties @@ -1,10 +1,11 @@ -server.port=5566 +server.port=5567 #datasource spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver spring.datasource.url=jdbc:mysql://18.167.251.121:3306/aida?useUnicode=true&characterEncoding=UTF-8&useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true spring.datasource.username=root spring.datasource.password=QWa998345 +#spring.datasource.password=QWa998345 #security spring.security.jwtSecret=JWTSECRET diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index f00b793a..ec94939b 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -1,8 +1,8 @@ #����application-test�ļ�(���Ի���) -#spring.profiles.active=test +spring.profiles.active=test #����application-prod�ļ�(��������) -spring.profiles.active=dev +#spring.profiles.active=prod #����application-dev�ļ�(��������) #spring.profiles.active=dev diff --git a/src/main/resources/messages_en.properties b/src/main/resources/messages_en.properties index 3fa4295a..ca17378d 100644 --- a/src/main/resources/messages_en.properties +++ b/src/main/resources/messages_en.properties @@ -152,6 +152,8 @@ classificationName.already.exists=The label name you've entered already exists. # 用来提醒用户可能会导致不良后果的操作,但不一定是错误。用户需要认真考虑是否继续当前操作。 the.classification.you.deleted.has.associated.library=The label you are attempting to delete is associated with existing data. Are you sure you wish to proceed with deletion? the.model.has.been.referenced.by.the.workspace=This model is currently in use by a workspace. Deleting it might affect the workspace. Confirm deletion only if you are sure. +balance.insufficient.for.trial=Want to continue using it immediately? Please consider upgrading to our subscription plan to get more quota. +balance.insufficient.for.paying=You have reached your usage limit for this month. # Errors: # 这类错误是由系统内部错误引起的,用户通常无法自行解决,需要联系支持或等待系统管理员介入。 diff --git a/src/main/resources/messages_zh.properties b/src/main/resources/messages_zh.properties index 6f330d93..8fecb5f9 100644 --- a/src/main/resources/messages_zh.properties +++ b/src/main/resources/messages_zh.properties @@ -150,6 +150,8 @@ classificationName.already.exists=您输入的标签名已存在。请输入不 # 用来提醒用户可能会导致不良后果的操作,但不一定是错误。用户需要认真考虑是否继续当前操作。 the.classification.you.deleted.has.associated.library=您正在尝试删除的标签与现有数据相关联。您确定要继续删除吗? the.model.has.been.referenced.by.the.workspace=此模型当前正在工作区中使用。删除它可能会影响工作区。仅在确信后再确认删除。 +balance.insufficient.for.trial=想要立即继续使用?请考虑升级到我们的订阅计划,以获得更多额度。 +balance.insufficient.for.paying=您已达到本月的使用额度限制。 # Errors: # 这类错误是由系统内部错误引起的,用户通常无法自行解决,需要联系支持或等待系统管理员介入。