移除类变量self.stylist_agent_kwages 和 self.outfit_ids 避免出现同时调用模型导致参数复用
litserve 单实例复用setup() 方法只会被调用一次(在 Worker 启动时)
This commit is contained in:
@@ -44,7 +44,6 @@ class LCAgent(ls.LitAPI):
|
|||||||
'max_len': 9,
|
'max_len': 9,
|
||||||
'gemini_model_name': settings.LLM_MODEL_NAME
|
'gemini_model_name': settings.LLM_MODEL_NAME
|
||||||
}
|
}
|
||||||
self.outfit_ids = []
|
|
||||||
|
|
||||||
async def decode_request(self, request: AgentRequestModel):
|
async def decode_request(self, request: AgentRequestModel):
|
||||||
"""
|
"""
|
||||||
@@ -62,15 +61,17 @@ class LCAgent(ls.LitAPI):
|
|||||||
|
|
||||||
async def predict(self, request):
|
async def predict(self, request):
|
||||||
|
|
||||||
self.outfit_ids = [str(uuid.uuid4()) for _ in range(request.num_outfits)]
|
outfit_ids = [str(uuid.uuid4()) for _ in range(request.num_outfits)]
|
||||||
|
|
||||||
asyncio.create_task(self.background_run(request))
|
asyncio.create_task(self.background_run(request, outfit_ids))
|
||||||
return {"status": "Task initiated in background.", "outfit_ids": self.outfit_ids}
|
|
||||||
|
logger.info({"status": "Task initiated in background.", "outfit_ids": outfit_ids})
|
||||||
|
return {"status": "Task initiated in background.", "outfit_ids": outfit_ids}
|
||||||
|
|
||||||
async def encode_response(self, output):
|
async def encode_response(self, output):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
async def background_run(self, request: AgentRequestModel):
|
async def background_run(self, request: AgentRequestModel, outfit_ids):
|
||||||
# 1. 根据用户ID查询对话历史,总结对话内容
|
# 1. 根据用户ID查询对话历史,总结对话内容
|
||||||
request_summary = await self.get_conversation_summary(request.session_id)
|
request_summary = await self.get_conversation_summary(request.session_id)
|
||||||
logger.info(f"request_summary: {request_summary}")
|
logger.info(f"request_summary: {request_summary}")
|
||||||
@@ -83,7 +84,8 @@ class LCAgent(ls.LitAPI):
|
|||||||
user_id=request.user_id,
|
user_id=request.user_id,
|
||||||
gender=request.gender,
|
gender=request.gender,
|
||||||
callback_url=request.callback_url,
|
callback_url=request.callback_url,
|
||||||
max_len=request.max_len)
|
max_len=request.max_len,
|
||||||
|
outfit_ids=outfit_ids)
|
||||||
logger.info("--- Final Recommendation Results ---")
|
logger.info("--- Final Recommendation Results ---")
|
||||||
for i, path in enumerate(recommendation_results.get("successful_outfits", [])):
|
for i, path in enumerate(recommendation_results.get("successful_outfits", [])):
|
||||||
logger.info(f"✅ Outfit {i + 1} saved to: {path}")
|
logger.info(f"✅ Outfit {i + 1} saved to: {path}")
|
||||||
@@ -104,7 +106,7 @@ class LCAgent(ls.LitAPI):
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit=None, num_outfits: int = 1,
|
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit=None, num_outfits: int = 1,
|
||||||
user_id: str = "test", gender: str = "male", callback_url: str = None, max_len: int = 9):
|
user_id: str = "test", gender: str = "male", callback_url: str = None, max_len: int = 9, outfit_ids=None):
|
||||||
"""
|
"""
|
||||||
基于用户的对话历史和需求,推荐一套搭配。
|
基于用户的对话历史和需求,推荐一套搭配。
|
||||||
|
|
||||||
@@ -112,14 +114,18 @@ class LCAgent(ls.LitAPI):
|
|||||||
request_summary: 用户的request
|
request_summary: 用户的request
|
||||||
start_outfit: 可选的初始搭配列表,每个元素包含 'item_id' 和 'category'。
|
start_outfit: 可选的初始搭配列表,每个元素包含 'item_id' 和 'category'。
|
||||||
"""
|
"""
|
||||||
|
if outfit_ids is None:
|
||||||
|
outfit_ids = []
|
||||||
if start_outfit is None:
|
if start_outfit is None:
|
||||||
start_outfit = []
|
start_outfit = []
|
||||||
tasks = []
|
tasks = []
|
||||||
task_map = {}
|
task_map = {}
|
||||||
|
|
||||||
|
stylist_agent_kwages = self.stylist_agent_kwages.copy()
|
||||||
for i in range(num_outfits):
|
for i in range(num_outfits):
|
||||||
self.stylist_agent_kwages['outfit_id'] = self.outfit_ids[i]
|
stylist_agent_kwages['outfit_id'] = outfit_ids[i]
|
||||||
self.stylist_agent_kwages['max_len'] = max_len
|
stylist_agent_kwages['max_len'] = max_len
|
||||||
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
|
agent = AsyncStylistAgent(**stylist_agent_kwages)
|
||||||
task = agent.run_styling_process(
|
task = agent.run_styling_process(
|
||||||
request_summary=request_summary,
|
request_summary=request_summary,
|
||||||
stylist_path=stylist_name,
|
stylist_path=stylist_name,
|
||||||
@@ -129,7 +135,7 @@ class LCAgent(ls.LitAPI):
|
|||||||
gender=gender,
|
gender=gender,
|
||||||
)
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
task_map[task] = {"outfit_id": self.outfit_ids[i], "retries": 0}
|
task_map[task] = {"outfit_id": outfit_ids[i], "retries": 0}
|
||||||
logger.info(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")
|
logger.info(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")
|
||||||
|
|
||||||
# 2. 任务执行与重试循环
|
# 2. 任务执行与重试循环
|
||||||
@@ -157,8 +163,8 @@ class LCAgent(ls.LitAPI):
|
|||||||
logger.info(f"--- Retrying outfit {outfit_id} (Attempt {current_retries + 1}/{retry_limit}). ---")
|
logger.info(f"--- Retrying outfit {outfit_id} (Attempt {current_retries + 1}/{retry_limit}). ---")
|
||||||
|
|
||||||
# 重新创建任务 (可能需要短暂延迟,例如 time.sleep(1),但在此异步环境中,我们会通过重新创建 agent/task 来实现)
|
# 重新创建任务 (可能需要短暂延迟,例如 time.sleep(1),但在此异步环境中,我们会通过重新创建 agent/task 来实现)
|
||||||
self.stylist_agent_kwages['outfit_id'] = outfit_id
|
stylist_agent_kwages['outfit_id'] = outfit_id
|
||||||
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
|
agent = AsyncStylistAgent(**stylist_agent_kwages)
|
||||||
new_task = agent.run_styling_process(
|
new_task = agent.run_styling_process(
|
||||||
request_summary=request_summary,
|
request_summary=request_summary,
|
||||||
stylist_path=stylist_name,
|
stylist_path=stylist_name,
|
||||||
|
|||||||
Reference in New Issue
Block a user