diff --git a/app/config.py b/app/config.py index 6a8a8e4..0d83b4e 100644 --- a/app/config.py +++ b/app/config.py @@ -16,6 +16,8 @@ class Settings(BaseSettings): env_file_encoding='utf-8', extra='ignore' # 忽略环境变量中多余的键 ) + # 启动端口 + SERVE_PROD: int = Field(default=8000, description='') # 调试配饰 LOCAL: int = Field(default=0, description="是否在本地运行,1表示本地运行,0表示生产环境运行") diff --git a/app/main.py b/app/main.py index bdfbf36..0895586 100644 --- a/app/main.py +++ b/app/main.py @@ -27,4 +27,4 @@ if __name__ == "__main__": agent_api = LCAgent(enable_async=True, api_path='/api/v1/agent') reface_api = ReFace(api_path='/api/v1/reface') server = ls.LitServer([chat_boot_api, agent_api, reface_api]) - server.run(port=8000) + server.run(port=settings.SERVE_PROD) diff --git a/app/server/ChatbotAgent/agent_server.py b/app/server/ChatbotAgent/agent_server.py index 75e8340..87be7df 100644 --- a/app/server/ChatbotAgent/agent_server.py +++ b/app/server/ChatbotAgent/agent_server.py @@ -1,4 +1,5 @@ import asyncio +import json import logging import uuid from enum import Enum @@ -15,6 +16,7 @@ from app.server.ChatbotAgent.core.redis_manager import RedisManager from app.server.ChatbotAgent.core.stylist_agent_server import AsyncStylistAgent from app.server.ChatbotAgent.core.prompt import SUMMARY_PROMPT from app.server.ChatbotAgent.core.vector_database import VectorDatabase +from app.server.utils.request_post import post_request logger = logging.getLogger(__name__) @@ -262,7 +264,14 @@ class LCAgent(ls.LitAPI): logger.error(f"Outfit {outfit_id} failed with error: {result}. Current retries: {current_retries}.") if current_retries < retry_limit: - # 尚未达到重试上限,准备重试 + # 尚未达到重试上限,准备重试 并通知前端 + object_data = { + 'outfit_id': outfit_id, + "status": "retrying", + "path": "", + } + post_request(url=f'{callback_url}/api/style/callback', data=json.dumps(object_data)) + task_info["retries"] += 1 logger.info(f"--- Retrying outfit {outfit_id} (Attempt {current_retries + 1}/{retry_limit}). ---") @@ -286,7 +295,13 @@ class LCAgent(ls.LitAPI): # 清理旧任务(可选,但推荐,以防内存泄漏或混淆) del task_map[task] else: - # 达到重试上限,最终失败 + # 达到重试上限,最终失败 , 并通知前端 + object_data = { + 'outfit_id': outfit_id, + "status": "failed", + "path": "", + } + post_request(url=f'{callback_url}/api/style/callback', data=json.dumps(object_data)) failed_outfits.append(f"Outfit {outfit_id} ultimately failed after {retry_limit} retries: {result}") del task_map[task] diff --git a/app/server/ChatbotAgent/core/stylist_agent_server.py b/app/server/ChatbotAgent/core/stylist_agent_server.py index 8c10cbc..84e0fe9 100644 --- a/app/server/ChatbotAgent/core/stylist_agent_server.py +++ b/app/server/ChatbotAgent/core/stylist_agent_server.py @@ -38,13 +38,6 @@ class AsyncStylistAgent: self.local_db = local_db self.gemini_model_name = gemini_model_name self.stop_reason = "" - self.headers = { - 'Accept': "*/*", - 'Accept-Encoding': "gzip, deflate, br", - 'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0", - 'Connection': "keep-alive", - 'Content-Type': "application/json" - } # 存储桶配置 try: @@ -262,7 +255,7 @@ class AsyncStylistAgent: "request_summary": request_summary, "occasions": occasions } - response = post_request(url=callback_url, data=json.dumps(response_data), headers=self.headers) + response = post_request(url=callback_url, data=json.dumps(response_data)) logger.info(f"request data :{json.dumps(response_data, ensure_ascii=False, indent=2)} | JAVA callback info -> status:{response.status_code} | message:{response.text}") return response_data else: @@ -426,7 +419,7 @@ class AsyncStylistAgent: else: self.outfit_items.append(new_item) print(f"Item {idx + 1}: ({subcategory}) {rec_item}, found item: {new_item}") - + # 如果没有找到的item过于多,需要重试 if failed_found_item_count / len(recommended_items) > 0.5: self.post_operation( @@ -520,7 +513,13 @@ class AsyncStylistAgent: user_id, url ) - + # 推荐即将完成 回调通知前端 + self.post_operation( + status="almost_done", + message="Recommendation has been completed and the outfit is being assembled", + callback_url=url, + img_path="", + ) final_image_path, _ = await self._merge_images(self.outfit_id, user_id, self.stylist_name) response_data = self.post_operation( status="stop", diff --git a/app/server/utils/request_post.py b/app/server/utils/request_post.py index 1679588..61d9a47 100644 --- a/app/server/utils/request_post.py +++ b/app/server/utils/request_post.py @@ -4,18 +4,24 @@ import time import requests -def post_request(url, data=None, json_data=None, headers=None, auth=None, timeout=5): +def post_request(url, data=None, json_data=None, auth=None, timeout=5): """ 发送POST请求的封装函数 :param url: 接口的URL地址 :param data: 要发送的数据(字典形式,用于表单数据等,会自动编码) :param json_data: 要发送的JSON数据(字典形式,会自动转换为JSON字符串) - :param headers: 请求头字典 :param auth: 认证信息(如 ('username', 'password') 形式用于基本认证) :param timeout: 超时时间,单位为秒 :return: 返回接口的响应对象 """ + headers = { + 'Accept': "*/*", + 'Accept-Encoding': "gzip, deflate, br", + 'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0", + 'Connection': "keep-alive", + 'Content-Type': "application/json" + } try: response = requests.post( url, @@ -52,6 +58,6 @@ if __name__ == '__main__': 'Content-Type': "application/json" } start_time = time.time() - X = post_request(url=url, data=json.dumps(object_data), headers=headers) + X = post_request(url=url, data=json.dumps(object_data)) print(time.time() - start_time) print(X)