chatbot 和 agent 增加session_id 作会话控制,弃用user_id 作为会话控制

This commit is contained in:
zhh
2025-11-03 15:21:43 +08:00
parent d1e2e56f42
commit cde7357df8
3 changed files with 12 additions and 7 deletions

1
.gitignore vendored
View File

@@ -9,3 +9,4 @@ app/core/data/
*.log *.log
db db
*.sqlite *.sqlite
*.png

View File

@@ -17,6 +17,7 @@ logger = logging.getLogger(__name__)
class AgentRequestModel(BaseModel): class AgentRequestModel(BaseModel):
user_id: str user_id: str
session_id: str
num_outfits: int num_outfits: int
stylist_path: str stylist_path: str
callback_url: str callback_url: str
@@ -71,7 +72,7 @@ class LCAgent(ls.LitAPI):
async def background_run(self, request: AgentRequestModel): async def background_run(self, request: AgentRequestModel):
# 1. 根据用户ID查询对话历史总结对话内容 # 1. 根据用户ID查询对话历史总结对话内容
request_summary = await self.get_conversation_summary(request.user_id) request_summary = await self.get_conversation_summary(request.session_id)
logger.info(f"request_summary: {request_summary}") logger.info(f"request_summary: {request_summary}")
# 2.根据对话总结推荐搭配 # 2.根据对话总结推荐搭配
@@ -89,13 +90,13 @@ class LCAgent(ls.LitAPI):
for failed in recommendation_results.get("failed_outfits", []): for failed in recommendation_results.get("failed_outfits", []):
logger.error(f"{failed}") logger.error(f"{failed}")
async def get_conversation_summary(self, user_id: str) -> str: async def get_conversation_summary(self, session_id: str) -> str:
""" """
分析用户的完整会话历史,并打包成一个简洁的需求总结。 分析用户的完整会话历史,并打包成一个简洁的需求总结。
这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。` 这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。`
""" """
history_messages = self.redis.get_history(user_id) history_messages = self.redis.get_history(session_id)
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages]) input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
# 临时调用 LLM 或使用本地逻辑生成总结 # 临时调用 LLM 或使用本地逻辑生成总结
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)], summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)],

View File

@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
class PredictRequest(BaseModel): class PredictRequest(BaseModel):
user_id: str # 用戶ID user_id: str # 用戶ID
session_id: str
user_message: str # 用戶輸入 user_message: str # 用戶輸入
gender: str # 服装类型 gender: str # 服装类型
@@ -55,8 +56,10 @@ class LCChatBot(ls.LitAPI):
# 添加用户消息到历史 # 添加用户消息到历史
user_message = request.user_message user_message = request.user_message
user_id = request.user_id user_id = request.user_id
session_id = request.session_id
user_msg = Message(role=Role.USER, content=user_message) user_msg = Message(role=Role.USER, content=user_message)
chat_history = self.redis.get_history(user_id) chat_history = self.redis.get_history(session_id)
chat_history.append(user_msg) chat_history.append(user_msg)
if request.gender == 'male': if request.gender == 'male':
BASIC_PROMPT = MEN_BASIC_PROMPT BASIC_PROMPT = MEN_BASIC_PROMPT
@@ -98,8 +101,8 @@ class LCChatBot(ls.LitAPI):
else: else:
assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.") assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")
self.redis.save_message(user_id, user_msg) self.redis.save_message(session_id, user_msg)
self.redis.save_message(user_id, assistant_msg) self.redis.save_message(session_id, assistant_msg)
async def encode_response(self, output): async def encode_response(self, output):
# The for-loop must have async keyword here since output is an AsyncGenerator # The for-loop must have async keyword here since output is an AsyncGenerator