diff --git a/src/main/java/com/ai/da/common/enums/CreditsEventsEnum.java b/src/main/java/com/ai/da/common/enums/CreditsEventsEnum.java index 5d3e6873..5bd7279d 100644 --- a/src/main/java/com/ai/da/common/enums/CreditsEventsEnum.java +++ b/src/main/java/com/ai/da/common/enums/CreditsEventsEnum.java @@ -57,6 +57,8 @@ public enum CreditsEventsEnum { LOCAL_TEXT2IMG_HIGH("Local_text2img_high","5"), LOCAL_IMG2IMG_HIGH("Local_img2img_high","5"), LOCAL_ANIMATION("Local_Animation","15"), + + LLM_CONVERSATION("LLM_Conversation", "0") ; private final String name; diff --git a/src/main/java/com/ai/da/mapper/primary/entity/ChatMessage.java b/src/main/java/com/ai/da/mapper/primary/entity/ChatMessage.java index 5b946bc9..0d84e2dc 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/ChatMessage.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/ChatMessage.java @@ -39,6 +39,22 @@ public class ChatMessage implements Serializable { @ApiModelProperty("0对话内容1颜色2图片") private Integer isImage; + /** + * 输入 + */ + private Long inputTokens; + /** + * 输出 + */ + private Long outputTokens; + /** + * 思考 + */ + private Long reasoningTokens; + /** + * 本次输出消耗的总金额 + */ + private String totalCost; @ApiModelProperty("创建时间") private LocalDateTime createTime; diff --git a/src/main/java/com/ai/da/mapper/primary/entity/WorkspaceRelStyle.java b/src/main/java/com/ai/da/mapper/primary/entity/WorkspaceRelStyle.java index 3cd4abf0..d5a35efa 100644 --- a/src/main/java/com/ai/da/mapper/primary/entity/WorkspaceRelStyle.java +++ b/src/main/java/com/ai/da/mapper/primary/entity/WorkspaceRelStyle.java @@ -9,10 +9,10 @@ import java.io.Serializable; @Data @TableName("workspace_rel_style") -public class WorkspaceRelStyle implements Serializable { +public class WorkspaceRelStyle extends BaseEntity implements Serializable { private static final long serialVersionUID = 1L; - @TableId(value = "id", type = IdType.AUTO) - private Long id; +// @TableId(value = "id", type = IdType.AUTO) +// private Long id; private Long workspaceId; private Long styleId; } diff --git a/src/main/java/com/ai/da/service/CreditsService.java b/src/main/java/com/ai/da/service/CreditsService.java index 9f7cf8bf..ab8ac126 100644 --- a/src/main/java/com/ai/da/service/CreditsService.java +++ b/src/main/java/com/ai/da/service/CreditsService.java @@ -7,6 +7,8 @@ import com.ai.da.mapper.primary.entity.CreditsDetail; import com.ai.da.model.dto.QueryIncomeOrExpenditureDTO; import com.baomidou.mybatisplus.extension.service.IService; +import java.math.BigDecimal; + public interface CreditsService extends IService { Boolean buyCredits(Long accountId, Float quantity); @@ -40,4 +42,6 @@ public interface CreditsService extends IService { void updateChangedCredits(String accountId, String taskId); CreditsDetail queryDetailByTaskId(String taskId); + + void tokenUsage(Long accountId, BigDecimal totalTokenUsage); } diff --git a/src/main/java/com/ai/da/service/impl/CreditsServiceImpl.java b/src/main/java/com/ai/da/service/impl/CreditsServiceImpl.java index e56f8563..b6ca3287 100644 --- a/src/main/java/com/ai/da/service/impl/CreditsServiceImpl.java +++ b/src/main/java/com/ai/da/service/impl/CreditsServiceImpl.java @@ -26,12 +26,10 @@ import org.springframework.util.CollectionUtils; import javax.annotation.Resource; import java.math.BigDecimal; +import java.math.RoundingMode; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; +import java.util.*; @Service @Slf4j @@ -368,4 +366,23 @@ public class CreditsServiceImpl extends ServiceImpl imageUrlList, String token, Boolean enableThinking, String process) { @@ -295,6 +296,14 @@ public class LLMServiceImpl implements LLMService { systemMessage.setAccountId(accountId); StringBuilder responseContentBuilder = new StringBuilder(); String contentType = ""; + BigDecimal totalTokenUsage = BigDecimal.ZERO; + // 需要存input_token output_token reasoning_tokens total_cost + Long inputToken = 0L; + Long outputToken = 0L; + Long reasoningToken = 0L; + String totalCost = "0"; + boolean tokenUsageFlag = false; + try (BufferedReader reader = new BufferedReader(new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { @@ -334,6 +343,13 @@ public class LLMServiceImpl implements LLMService { systemMessage.setSeq(getNextSeq(projectId)); systemMessage.setCreateTime(LocalDateTime.now()); systemMessage.setContent(responseContentBuilder.toString()); + if (tokenUsageFlag){ + systemMessage.setInputTokens(inputToken); + systemMessage.setOutputTokens(outputToken); + systemMessage.setReasoningTokens(reasoningToken); + systemMessage.setTotalCost(totalCost); + tokenUsageFlag = false; + } chatMessageMapper.insert(systemMessage); systemMessage.setId(null); responseContentBuilder = new StringBuilder(); @@ -368,6 +384,13 @@ public class LLMServiceImpl implements LLMService { systemImage.setCreateTime(LocalDateTime.now()); systemImage.setContent(contentSave); systemImage.setAccountId(accountId); + if (tokenUsageFlag){ + systemImage.setInputTokens(inputToken); + systemImage.setOutputTokens(outputToken); + systemImage.setReasoningTokens(reasoningToken); + systemImage.setTotalCost(totalCost); + tokenUsageFlag = false; + } chatMessageMapper.insert(systemImage); } else if (Objects.nonNull(toolsName) && toolsName.equals("search_sketch_img")) { json.put("content", processSearchSketchToolCon(projectId, toolsData)); @@ -376,6 +399,16 @@ public class LLMServiceImpl implements LLMService { updateProjectParams(projectId, toolsData); } emitter.send(json.toJSONString()); + }else if ("cost".equals(type)) { + JSONObject contentObj = json.getJSONObject("content"); + log.info("token usage: {}", contentObj); + inputToken = contentObj.getLong("input_tokens"); + outputToken = contentObj.getLong("output_tokens"); + reasoningToken = contentObj.getLong("reasoning_tokens"); + totalCost = contentObj.getString("total_cost"); + totalTokenUsage = totalTokenUsage.add(new BigDecimal(totalCost).divide(new BigDecimal("0.03"), 6, RoundingMode.CEILING)); + log.info("totalTokenUsage: {}", totalTokenUsage); + tokenUsageFlag = true; } } } @@ -384,8 +417,16 @@ public class LLMServiceImpl implements LLMService { systemMessage.setSeq(getNextSeq(projectId)); systemMessage.setCreateTime(LocalDateTime.now()); systemMessage.setContent(responseContentBuilder.toString()); + if (tokenUsageFlag){ + systemMessage.setInputTokens(inputToken); + systemMessage.setOutputTokens(outputToken); + systemMessage.setReasoningTokens(reasoningToken); + systemMessage.setTotalCost(totalCost); + } chatMessageMapper.insert(systemMessage); } + // 扣积分 + creditsService.tokenUsage(accountId, totalTokenUsage); } } emitter.complete();