diff --git a/app/core/stylist_agent_server.py b/app/core/stylist_agent_server.py index 5366f07..baaf1c2 100644 --- a/app/core/stylist_agent_server.py +++ b/app/core/stylist_agent_server.py @@ -150,7 +150,7 @@ class AsyncStylistAgent: # self._clear_uploaded_files() # 1. 添加图片内容 if self.outfit_items: - merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len) + merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len, add_text=False) image_bytes_io = io.BytesIO() image_format = 'JPEG' mime_type = 'image/jpeg' @@ -275,6 +275,7 @@ class AsyncStylistAgent: } logger.info(response_data) item_id = "" + item_category = "" while True: # 1. 准备用户输入(上下文) user_input = self._build_user_input() @@ -284,7 +285,7 @@ class AsyncStylistAgent: gemini_data = self._parse_gemini_response(gemini_response_text) response_data['path'] = minio_path if item_id: - response_data['items'].append(item_id) + response_data['items'].append({"item_id": item_id, "category": item_category}) if not gemini_data: print("🚨 Agent 返回无效响应,终止流程。") self.stop_reason = "Agent failed to return response" @@ -316,6 +317,7 @@ class AsyncStylistAgent: # 4b. 在本地 DB 中查询单品 new_item = self._get_next_item(description, category) item_id = new_item.get('item_id') + item_category = new_item.get('category') if new_item: # 4c. (实际步骤) 将选中的单品图片和描述发回给 Agent 进行最终审核 diff --git a/app/main.py b/app/main.py index c94c4f2..64ecfa9 100644 --- a/app/main.py +++ b/app/main.py @@ -4,6 +4,7 @@ import litserve as ls from app.core.config import DEBUG, settings from app.server.ChatbotAgent.agent_server import LCAgent from app.server.ChatbotAgent.chatbot_server import LCChatBot +from app.server.ReFace.server import ReFace from logging_env import LOGGER_CONFIG_DICT logger = logging.getLogger(__name__) @@ -24,5 +25,6 @@ if __name__ == "__main__": logger.info(f"VECTOR_DB_DIR -> :{settings.VECTOR_DB_DIR}") chat_boot_api = LCChatBot(enable_async=True, stream=True, api_path='/api/v1/chatbot') agent_api = LCAgent(enable_async=True, api_path='/api/v1/agent') - server = ls.LitServer([chat_boot_api, agent_api]) + reface_api = ReFace(api_path='/api/v1/reface') + server = ls.LitServer([chat_boot_api, agent_api, reface_api]) server.run(port=8000) diff --git a/app/server/ChatbotAgent/agent_server.py b/app/server/ChatbotAgent/agent_server.py index 1605cd4..c2a3eb0 100644 --- a/app/server/ChatbotAgent/agent_server.py +++ b/app/server/ChatbotAgent/agent_server.py @@ -1,9 +1,6 @@ import asyncio -import json import logging -import os import uuid -from typing import List, Dict import litserve as ls from pydantic import BaseModel diff --git a/app/server/ReFace/server.py b/app/server/ReFace/server.py new file mode 100644 index 0000000..413e5f1 --- /dev/null +++ b/app/server/ReFace/server.py @@ -0,0 +1,59 @@ +import json + +import litserve as ls +import requests +from pydantic import BaseModel + + +class PredictRequest(BaseModel): + input_image_list: list[str] # 待换脸图片 + input_face: str # 目标脸图片 + threshold: float = 0.2 # 相似度 max:0.5 + + +class ReFace(ls.LitAPI): + def decode_request(self, request: PredictRequest): + return request + + def predict(self, request): + # 服务的 URL + url = "http://10.1.1.240:10071/predict" + + # 请求头 + headers = { + "accept": "application/json", + "Content-Type": "application/json" + } + + # 请求体数据 + # 这里的结构要和你的 LitServe 服务的 LitAPI.decode_request 预期的一致 + data = { + "input_image_list": request.input_image_list, + "input_face": request.input_face, + "threshold": request.threshold + } + + try: + # 使用 requests.post 发送请求 + # 使用 json= 参数可以自动将 Python 字典转换为 JSON 格式,并设置 Content-Type 头部 + response = requests.post(url, headers=headers, json=data) + + # 检查 HTTP 响应状态码,如果不是 200/201 等成功状态,将抛出异常 + response.raise_for_status() + + # 打印返回的 JSON 结果 + print("成功调用 LitServe 接口,返回结果:") + # .json() 方法将响应体解析为 Python 字典 + print(json.dumps(response.json(), indent=4, ensure_ascii=False)) + return response.json() + + except requests.exceptions.RequestException as e: + # 处理请求失败、连接错误或 HTTP 错误状态 + print(f"请求发生错误: {e}") + if 'response' in locals() and response is not None: + print(f"响应状态码: {response.status_code}") + try: + # 尝试打印服务器返回的错误详情 + print(f"服务器错误详情: {response.json()}") + except: + print(f"服务器错误详情 (非 JSON 格式): {response.text}")