chatbot 和 agent 增加session_id 作会话控制,弃用user_id 作为会话控制
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -9,3 +9,4 @@ app/core/data/
|
|||||||
*.log
|
*.log
|
||||||
db
|
db
|
||||||
*.sqlite
|
*.sqlite
|
||||||
|
*.png
|
||||||
@@ -17,6 +17,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class AgentRequestModel(BaseModel):
|
class AgentRequestModel(BaseModel):
|
||||||
user_id: str
|
user_id: str
|
||||||
|
session_id: str
|
||||||
num_outfits: int
|
num_outfits: int
|
||||||
stylist_path: str
|
stylist_path: str
|
||||||
callback_url: str
|
callback_url: str
|
||||||
@@ -71,7 +72,7 @@ class LCAgent(ls.LitAPI):
|
|||||||
|
|
||||||
async def background_run(self, request: AgentRequestModel):
|
async def background_run(self, request: AgentRequestModel):
|
||||||
# 1. 根据用户ID查询对话历史,总结对话内容
|
# 1. 根据用户ID查询对话历史,总结对话内容
|
||||||
request_summary = await self.get_conversation_summary(request.user_id)
|
request_summary = await self.get_conversation_summary(request.session_id)
|
||||||
logger.info(f"request_summary: {request_summary}")
|
logger.info(f"request_summary: {request_summary}")
|
||||||
|
|
||||||
# 2.根据对话总结推荐搭配
|
# 2.根据对话总结推荐搭配
|
||||||
@@ -89,13 +90,13 @@ class LCAgent(ls.LitAPI):
|
|||||||
for failed in recommendation_results.get("failed_outfits", []):
|
for failed in recommendation_results.get("failed_outfits", []):
|
||||||
logger.error(f"❌ {failed}")
|
logger.error(f"❌ {failed}")
|
||||||
|
|
||||||
async def get_conversation_summary(self, user_id: str) -> str:
|
async def get_conversation_summary(self, session_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
分析用户的完整会话历史,并打包成一个简洁的需求总结。
|
分析用户的完整会话历史,并打包成一个简洁的需求总结。
|
||||||
|
|
||||||
这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。`
|
这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。`
|
||||||
"""
|
"""
|
||||||
history_messages = self.redis.get_history(user_id)
|
history_messages = self.redis.get_history(session_id)
|
||||||
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
|
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
|
||||||
# 临时调用 LLM 或使用本地逻辑生成总结
|
# 临时调用 LLM 或使用本地逻辑生成总结
|
||||||
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)],
|
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)],
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class PredictRequest(BaseModel):
|
class PredictRequest(BaseModel):
|
||||||
user_id: str # 用戶ID
|
user_id: str # 用戶ID
|
||||||
|
session_id: str
|
||||||
user_message: str # 用戶輸入
|
user_message: str # 用戶輸入
|
||||||
gender: str # 服装类型
|
gender: str # 服装类型
|
||||||
|
|
||||||
@@ -55,8 +56,10 @@ class LCChatBot(ls.LitAPI):
|
|||||||
# 添加用户消息到历史
|
# 添加用户消息到历史
|
||||||
user_message = request.user_message
|
user_message = request.user_message
|
||||||
user_id = request.user_id
|
user_id = request.user_id
|
||||||
|
session_id = request.session_id
|
||||||
|
|
||||||
user_msg = Message(role=Role.USER, content=user_message)
|
user_msg = Message(role=Role.USER, content=user_message)
|
||||||
chat_history = self.redis.get_history(user_id)
|
chat_history = self.redis.get_history(session_id)
|
||||||
chat_history.append(user_msg)
|
chat_history.append(user_msg)
|
||||||
if request.gender == 'male':
|
if request.gender == 'male':
|
||||||
BASIC_PROMPT = MEN_BASIC_PROMPT
|
BASIC_PROMPT = MEN_BASIC_PROMPT
|
||||||
@@ -98,8 +101,8 @@ class LCChatBot(ls.LitAPI):
|
|||||||
else:
|
else:
|
||||||
assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")
|
assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")
|
||||||
|
|
||||||
self.redis.save_message(user_id, user_msg)
|
self.redis.save_message(session_id, user_msg)
|
||||||
self.redis.save_message(user_id, assistant_msg)
|
self.redis.save_message(session_id, assistant_msg)
|
||||||
|
|
||||||
async def encode_response(self, output):
|
async def encode_response(self, output):
|
||||||
# The for-loop must have async keyword here since output is an AsyncGenerator
|
# The for-loop must have async keyword here since output is an AsyncGenerator
|
||||||
|
|||||||
Reference in New Issue
Block a user