push gitignore
This commit is contained in:
@@ -34,7 +34,7 @@ class LCAgent(ls.LitAPI):
|
||||
key_prefix=settings.REDIS_HISTORY_KEY_PREFIX
|
||||
)
|
||||
self.vector_db = VectorDatabase(
|
||||
vector_db_dir=os.getenv('VECTOR_DB_DIR', '/app/app/db'),
|
||||
vector_db_dir=settings.VECTOR_DB_DIR,
|
||||
collection_name=settings.COLLECTION_NAME,
|
||||
embedding_model_name=settings.EMBEDDING_MODEL_NAME
|
||||
)
|
||||
@@ -43,6 +43,7 @@ class LCAgent(ls.LitAPI):
|
||||
'max_len': 5,
|
||||
'gemini_model_name': settings.LLM_MODEL_NAME
|
||||
}
|
||||
self.outfit_ids = []
|
||||
|
||||
async def decode_request(self, request: AgentRequestModel):
|
||||
"""
|
||||
@@ -59,8 +60,11 @@ class LCAgent(ls.LitAPI):
|
||||
return request
|
||||
|
||||
async def predict(self, request):
|
||||
|
||||
self.outfit_ids = [str(uuid.uuid4()) for _ in range(request.num_outfits)]
|
||||
|
||||
asyncio.create_task(self.background_run(request))
|
||||
return {"status": "Task initiated in background."}
|
||||
return {"status": "Task initiated in background.", "outfit_ids": self.outfit_ids}
|
||||
|
||||
async def encode_response(self, output):
|
||||
return output
|
||||
@@ -108,7 +112,8 @@ class LCAgent(ls.LitAPI):
|
||||
if start_outfit is None:
|
||||
start_outfit = []
|
||||
tasks = []
|
||||
for _ in range(num_outfits):
|
||||
for i in range(num_outfits):
|
||||
self.stylist_agent_kwages['outfit_id'] = self.outfit_ids[i]
|
||||
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
|
||||
task = agent.run_styling_process(
|
||||
request_summary=request_summary,
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.data_structure import Role, Message
|
||||
from app.core.llm_interface import AsyncGeminiLLM
|
||||
from app.core.redis_manager import RedisManager
|
||||
from app.core.stylist_agent import AsyncStylistAgent
|
||||
from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT
|
||||
from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT, MEN_BASIC_PROMPT, WOMEN_BASIC_PROMPT
|
||||
from app.core.vector_database import VectorDatabase
|
||||
from google.genai import types
|
||||
|
||||
@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
class PredictRequest(BaseModel):
|
||||
user_id: str # 用戶ID
|
||||
user_message: str # 用戶輸入
|
||||
gender: str # 服装类型
|
||||
|
||||
|
||||
class LCChatBot(ls.LitAPI):
|
||||
@@ -60,6 +61,10 @@ class LCChatBot(ls.LitAPI):
|
||||
user_msg = Message(role=Role.USER, content=user_message)
|
||||
chat_history = self.redis.get_history(user_id)
|
||||
chat_history.append(user_msg)
|
||||
if request.gender == 'male':
|
||||
BASIC_PROMPT = MEN_BASIC_PROMPT
|
||||
else:
|
||||
BASIC_PROMPT = WOMEN_BASIC_PROMPT
|
||||
|
||||
contents = []
|
||||
|
||||
@@ -103,86 +108,3 @@ class LCChatBot(ls.LitAPI):
|
||||
# The for-loop must have async keyword here since output is an AsyncGenerator
|
||||
async for out in output:
|
||||
yield {"output": out}
|
||||
|
||||
async def process_query(self, user_id: str, user_message: str) -> str:
|
||||
"""
|
||||
处理用户的最新输入,调用 LLM, 并更新历史记录。
|
||||
"""
|
||||
|
||||
# 添加用户消息到历史
|
||||
user_msg = Message(role=Role.USER, content=user_message)
|
||||
chat_history = self.redis.get_history(user_id)
|
||||
chat_history.append(user_msg)
|
||||
|
||||
# 生成 LLM 回复
|
||||
try:
|
||||
response_text = await self.llm.generate_response(chat_history, system_prompt=BASIC_PROMPT)
|
||||
except Exception as e:
|
||||
logger("\n--- Final Recommendation Results ---")
|
||||
|
||||
logger.error(f"LLM 调用失败: {e}")
|
||||
response_text = "抱歉,系统暂时无法响应,请稍后再试。"
|
||||
|
||||
# 添加助手消息到历史
|
||||
if response_text:
|
||||
assistant_msg = Message(role=Role.ASSISTANT, content=response_text)
|
||||
else:
|
||||
assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")
|
||||
|
||||
self.redis.save_message(user_id, user_msg)
|
||||
self.redis.save_message(user_id, assistant_msg)
|
||||
|
||||
return response_text
|
||||
|
||||
async def get_conversation_summary(self, user_id: str) -> str:
|
||||
"""
|
||||
分析用户的完整会话历史,并打包成一个简洁的需求总结。
|
||||
|
||||
这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。`
|
||||
"""
|
||||
history_messages = self.redis.get_history(user_id)
|
||||
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
|
||||
|
||||
# 临时调用 LLM 或使用本地逻辑生成总结
|
||||
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)], system_prompt=SUMMARY_PROMPT)
|
||||
|
||||
return summary
|
||||
|
||||
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit=None, num_outfits: int = 1):
|
||||
"""
|
||||
基于用户的对话历史和需求,推荐一套搭配。
|
||||
|
||||
Args:
|
||||
request_summary: 用户的request
|
||||
start_outfit: 可选的初始搭配列表,每个元素包含 'item_id' 和 'category'。
|
||||
"""
|
||||
if start_outfit is None:
|
||||
start_outfit = []
|
||||
tasks = []
|
||||
for _ in range(num_outfits):
|
||||
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
|
||||
task = agent.run_styling_process(request_summary, stylist_name, start_outfit)
|
||||
tasks.append(task)
|
||||
logger.info(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
successful_outfits = []
|
||||
failed_outfits = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
# 任务执行中发生异常
|
||||
failed_outfits.append(f"Failed: {result}")
|
||||
else:
|
||||
# 任务成功,result 是 run_styling_process 返回的图片路径
|
||||
successful_outfits.append(result)
|
||||
|
||||
return {
|
||||
"successful_outfits": successful_outfits,
|
||||
"failed_outfits": failed_outfits
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred during concurrent recommendation: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
Reference in New Issue
Block a user