diff --git a/.env_local b/.env_local index 101b05e..c6b4704 100644 --- a/.env_local +++ b/.env_local @@ -1,2 +1,2 @@ GEMINI_API_KEY=AIzaSyAO4zXFke1bqyrXd9-RGfKJTLerwLSFKww - GOOGLE_APPLICATION_CREDENTIALS="/workspace/lc_stylist_agent/app/request.json" \ No newline at end of file +GOOGLE_APPLICATION_CREDENTIALS="/workspace/lc_stylist_agent/app/request.json" \ No newline at end of file diff --git a/app/server/ChatbotAgent/agent_server.py b/app/server/ChatbotAgent/agent_server.py index c2a3eb0..c5e7042 100644 --- a/app/server/ChatbotAgent/agent_server.py +++ b/app/server/ChatbotAgent/agent_server.py @@ -5,12 +5,12 @@ import uuid import litserve as ls from pydantic import BaseModel from app.core.config import settings -from app.core.data_structure import Message, Role -from app.core.llm_interface import AsyncGeminiLLM -from app.core.redis_manager import RedisManager -from app.core.stylist_agent_server import AsyncStylistAgent -from app.core.system_prompt import SUMMARY_PROMPT -from app.core.vector_database import VectorDatabase +from app.server.ChatbotAgent.core.data_structure import Message, Role +from app.server.ChatbotAgent.core.llm_interface import AsyncGeminiLLM +from app.server.ChatbotAgent.core.redis_manager import RedisManager +from app.server.ChatbotAgent.core.stylist_agent_server import AsyncStylistAgent +from app.server.ChatbotAgent.core.system_prompt import SUMMARY_PROMPT +from app.server.ChatbotAgent.core.vector_database import VectorDatabase logger = logging.getLogger(__name__) diff --git a/app/server/ChatbotAgent/chatbot_server.py b/app/server/ChatbotAgent/chatbot_server.py index 2dd6293..518d0df 100644 --- a/app/server/ChatbotAgent/chatbot_server.py +++ b/app/server/ChatbotAgent/chatbot_server.py @@ -1,21 +1,18 @@ -# server.py -import asyncio import logging -import time -from typing import AsyncGenerator - -from google import genai import litserve as ls + +from typing import AsyncGenerator +from google import genai from pydantic import BaseModel from app.core.config import settings -from app.core.data_structure import Role, Message -from app.core.llm_interface import AsyncGeminiLLM -from app.core.redis_manager import RedisManager -from app.core.stylist_agent import AsyncStylistAgent -from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT, MEN_BASIC_PROMPT, WOMEN_BASIC_PROMPT -from app.core.vector_database import VectorDatabase from google.genai import types +from app.server.ChatbotAgent.core.data_structure import Message, Role +from app.server.ChatbotAgent.core.llm_interface import AsyncGeminiLLM +from app.server.ChatbotAgent.core.redis_manager import RedisManager +from app.server.ChatbotAgent.core.system_prompt import MEN_BASIC_PROMPT, WOMEN_BASIC_PROMPT +from app.server.ChatbotAgent.core.vector_database import VectorDatabase + logger = logging.getLogger(__name__) diff --git a/app/server/ChatbotAgent/core/data_structure.py b/app/server/ChatbotAgent/core/data_structure.py new file mode 100644 index 0000000..76170dd --- /dev/null +++ b/app/server/ChatbotAgent/core/data_structure.py @@ -0,0 +1,16 @@ +from enum import Enum +from pydantic import BaseModel, Field + + +# 角色枚举,用于区分用户和系统的消息 +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + + +# 单条消息的数据模型 +class Message(BaseModel): + role: Role = Field(..., description="Role of message sender") + content: str = Field(..., description="Content of the message") + # timestamp: str = Field(default_factory=lambda: datetime.datetime.now().isoformat()) # 记录时间戳 diff --git a/app/server/ChatbotAgent/core/llm_interface.py b/app/server/ChatbotAgent/core/llm_interface.py new file mode 100644 index 0000000..ca071ce --- /dev/null +++ b/app/server/ChatbotAgent/core/llm_interface.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import List + +from google import genai +from google.genai import types + +from app.server.ChatbotAgent.core.data_structure import Message, Role + + +class AsyncLLMInterface(ABC): + @abstractmethod + async def generate_response(self, history: List[Message], system_prompt: str) -> str: + """ + 根据对话历史和系统指令生成回复. + + Args: + history: 包含多条 Message 的列表。 + + Returns: + LLM 生成的回复字符串。 + """ + raise NotImplementedError("Subclasses must implement this method") + + +class AsyncGeminiLLM(AsyncLLMInterface): + def __init__(self, model_name: str = "gemini-2.5-flash"): + self.model_name = model_name + try: + self.gemini_client = genai.Client( + vertexai=True, project='aida-461108', location='us-central1' + ) + except Exception as e: + raise type(e)(f"Failed to initialize Gemini Client. Check if GEMINI_API_KEY is set. Original error: {e}") + + async def generate_response(self, history: List[Message], system_prompt: str) -> str: + contents = [] + + for msg in history: + gemini_role = "user" if msg.role == Role.USER else "model" + content = types.Content( + role=gemini_role, + parts=[types.Part.from_text(text=msg.content)] + ) + contents.append(content) + + try: + response = await self.gemini_client.aio.models.generate_content( + model=self.model_name, + contents=contents, + config=types.GenerateContentConfig( + system_instruction=system_prompt, + # temperature=0.3, + ) + ) + return response.text + except Exception as e: + raise type(e)(f"Gemini API call failed: {e}") diff --git a/app/server/ChatbotAgent/core/redis_manager.py b/app/server/ChatbotAgent/core/redis_manager.py new file mode 100644 index 0000000..f7a7323 --- /dev/null +++ b/app/server/ChatbotAgent/core/redis_manager.py @@ -0,0 +1,69 @@ +import logging + +import redis +from typing import List, Optional + +from app.server.ChatbotAgent.core.data_structure import Message, Role + +logger = logging.getLogger(__name__) + + +# 这是一个同步 Redis 客户端,用于演示如何替换内存存储。 +# TODO 在生产环境和异步 Web 框架中,应替换为 aioredis 等异步客户端。 +class RedisManager: + """同步管理器,用于在 Redis 中存储和检索对话历史。""" + + def __init__(self, host: str = 'localhost', port: int = 6379, db: int = 0, key_prefix: str = "chat:history:"): + self.r: Optional[redis.Redis] = None + self.key_prefix = key_prefix + try: + # 尝试连接 Redis + self.r = redis.Redis(host=host, port=port, db=db, decode_responses=True) + self.r.ping() + logger.info("Successfully connected to Redis at {}:{}".format(host, port)) + except Exception as e: + logger.error(f"⚠️ Failed to connect to Redis: {e}. Falling back to No-Op.") + self.r = None # 连接失败时设置为 None,避免后续操作报错 + + def _get_key(self, user_id: str) -> str: + """生成用户历史记录的 Redis 键名。""" + return f"{self.key_prefix}{user_id}" + + def _message_to_json(self, message: Message) -> str: + """将 Message 对象序列化为 JSON 字符串以便存储。""" + return message.model_dump_json() + + def _json_to_message(self, data: str) -> Message: + """将 JSON 字符串反序列化回 Message 对象。""" + try: + return Message.model_validate_json(data) + except Exception as e: + logger.error(f"Error deserializing message data: {data[:50]}... Error: {e}") + return Message(role=Role.ASSISTANT, content="[Deserialization Error]") + + def save_message(self, user_id: str, message: Message): + """将单条消息保存到用户历史记录列表的末尾。""" + if not self.r: + return + + message_json = self._message_to_json(message) + # RPUSH:将元素添加到列表的尾部 + self.r.rpush(self._get_key(user_id), message_json) + + def get_history(self, user_id: str) -> List[Message]: + """检索用户的完整会话历史记录。""" + if not self.r: + return [] + + # LRANGE:获取列表的所有元素 (0 到 -1) + raw_history = self.r.lrange(self._get_key(user_id), 0, -1) + + # 将 JSON 字符串列表转换为 Message 对象列表 + messages = [self._json_to_message(data) for data in raw_history] + return messages + + def clear_history(self, user_id: str): + """删除用户的完整历史记录。""" + if self.r: + self.r.delete(self._get_key(user_id)) + logger.info(f"History cleared for {user_id}") diff --git a/app/server/ChatbotAgent/core/stylist_agent_server.py b/app/server/ChatbotAgent/core/stylist_agent_server.py new file mode 100644 index 0000000..baaf1c2 --- /dev/null +++ b/app/server/ChatbotAgent/core/stylist_agent_server.py @@ -0,0 +1,435 @@ +import asyncio +import io +import json +import logging +import os +import random +import uuid +from typing import List, Dict, Any, Optional + +from google import genai +from google.cloud import storage +from google.oauth2 import service_account + +from app.core.utils_litserve import merge_images_to_square +from app.server.utils.minio_client import minio_client, oss_upload_image +from app.server.utils.request_post import post_request + +logger = logging.getLogger(__name__) + + +class AsyncStylistAgent: + CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'} + + def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str): + # self.outfit_items: List[Dict[str, str]] = [] + self.outfit_id = outfit_id + self.gemini_client = genai.Client( + vertexai=True, project='aida-461108', location='us-central1' + ) + self.local_db = local_db + self.max_len = max_len + self.gemini_model_name = gemini_model_name + self.stop_reason = "" + + # 存储桶配置 + try: + # TODO 目前写死路径 生产环境切换路径 + self.credentials = service_account.Credentials.from_service_account_file(os.getenv("GOOGLE_APPLICATION_CREDENTIALS")) + except Exception as e: + # 这里的异常处理应根据实际情况调整 + raise RuntimeError(f"Failed to load credentials from file {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}: {e}") + + self.gcs_client = storage.Client( + project=self.credentials.project_id, + credentials=self.credentials + ) + self.gcs_bucket = "lc_stylist_agent_outfit_items" + self.minio_bucket = "lanecarford" + + def _load_style_guide(self, path: str) -> str: + """加载 markdown 风格指南内容。""" + parts = path.split('/', 1) + if len(parts) != 2: + raise ValueError("MinIO path must be in 'bucket_name/object_name' format.") + + bucket_name, object_name = parts + try: + # 1. 获取对象 + response = minio_client.get_object(bucket_name, object_name) + + # 2. 读取内容 + content_bytes = response.read() + + # 3. 关闭连接 + response.close() + response.release_conn() + + # 4. 解码并返回 + return content_bytes.decode('utf-8') + + except Exception as e: + raise Exception(f"Failed to load style guide from {path}: {e}") + + def _build_system_prompt(self, request_summary: str = "") -> str: + """Constructs the complete System Prompt.""" + # Insert the style_guide content into the template + template = f""" + You are a professional fashion stylist Agent, specialized in creating complete outfits for the user. + + Your task is to **create a cohesive and complete outfit**, strictly adhering to **BOTH** the user's explicit **Request Summary** and the **Outfit Style Guide**. You must decide the next logical item to add to the outfit based on the currently selected items (if any). + + --- + + ## Request from the User: + + {request_summary} + + ## Core Guidance Document: Outfit Style Guide + + {self.style_guide} + + --- + + ## Your Workflow and Constraints + + 1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**. + 2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories. + 3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows: + + ```json + {{ + "action": "recommend_item", + "category": "YOUR_ITEM_CATEGORY", + "description": "YOUR_DETAILED_DESCRIPTION" + }} + ``` + + * `action`: Must always be `"recommend_item"` until the outfit is complete. + * `category`: Must be the category of the item you are recommending, strictly selected from the following list: {list(self.CATEGORY_SET)}. + * `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include: + * **Color** (e.g., milk tea, pure white, dark gray) + * **Fit/Silhouette** (e.g., Oversize, loose, slim-fit) + * **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern) + * **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look) + * **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.** + + 4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process: + + ```json + {{ + "action": "stop", + "reason": "OUTFIT_COMPLETE_AND_MEETS_ALL_MINI_GUIDELINES" + }} + ``` + Normally, five or six items are totally enough for an outfit. + + 5. **Context Dependency**: The user's next input (if not `Start`) will contain the **image and description of the selected item**. When recommending the next item, you must consider the coordination between the **already selected items** and the Style Guide. + + **Now, please start building an outfit and output the JSON for the first item.** + """ + return template.strip() + + def _clear_uploaded_files(self): + for f in self.gemini_client.files.list(): + self.gemini_client.files.delete(name=f.name) + + async def _call_gemini(self, user_input: str, user_id: str): + """ + 实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。 + + Args: + user_input: 发送给模型的主文本内容。 + image_paths: 待发送图片的本地路径列表。 + + Returns: + 模型的响应文本(预期为 JSON 字符串)。 + """ + minio_path = "" + content_parts = [] + # self._clear_uploaded_files() + # 1. 添加图片内容 + if self.outfit_items: + merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len, add_text=False) + image_bytes_io = io.BytesIO() + image_format = 'JPEG' + mime_type = 'image/jpeg' + + merged_image.save(image_bytes_io, format=image_format) + image_bytes = image_bytes_io.getvalue() + + file_name = uuid.uuid4() + blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg" + gcs_path = self._upload_to_gcs(bucket_name=self.gcs_bucket, blob_name=blob_name, mime_type=mime_type, image_bytes=image_bytes) + responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes) + minio_path = f"{responses.bucket_name}/{responses.object_name}" + content_parts.append(gcs_path) + + # 2. 添加文本内容 + content_parts.append(user_input) + + # print(f"\n--- Calling Gemini with {len(self.outfit_items) if self.outfit_items else 0} images and query:\n{user_input}") + + try: + # 3. 实际 API 调用 + response = await self.gemini_client.aio.models.generate_content( + model=self.gemini_model_name, + contents=content_parts, + config={ + "system_instruction": self.system_prompt, + # 确保模型返回 JSON 格式 + "response_mime_type": "application/json", + "response_schema": { + "type": "object", + "properties": { + "action": {"type": "string", "enum": ["recommend_item", "stop"]}, + "category": {"type": "string"}, + "description": {"type": "string"}, + "reason": {"type": "string"} + }, + "required": ["action"] + } + } + ) + + # response.text 将包含一个 JSON 字符串 + return response.text, minio_path + + except Exception as e: + print(f"Gemini API Call failed: {e}") + # 返回一个停止信号以防止循环继续 + return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"}) + + def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]: + """安全解析 Gemini 的 JSON 响应。""" + try: + # 有时 Gemini 可能会在 JSON 外面添加文字,尝试清理 + response_text = response_text.strip().replace('```json', '').replace('```', '') + data = json.loads(response_text) + # print(f"The agent response is: {data}") + return data + except json.JSONDecodeError as e: + print(f"Error parsing JSON from Gemini: {e}") + print(f"Raw response: {response_text}") + return None + + def _get_next_item(self, item_description: str, category: str) -> Optional[Dict[str, str]]: + """ + 1. 根据描述生成嵌入。 + 2. 查询本地数据库以找到最佳匹配项。 + 3. 模拟 Agent 审核匹配项(这里简化为总是通过)。 + """ + try: + # 1. 生成查询嵌入 + query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False) + + # 2. 执行查询,并过滤类别 + results = self.local_db.query_local_db(query_embedding, category, n_results=1) + + if not results: + print(f"❌ 数据库中未找到符合 '{category}' 和描述的单品。") + return None + + # 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核) + best_meta = results['metadatas'][0][0] # 第一个 batch 的第一个 metadata + return { + "item_id": best_meta['item_id'], # 从 metadata 字典中安全获取 + "category": category, + "gpt_description": item_description, + 'description': best_meta['description'], + # 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导 + # 这里假设 item_id 就是文件名的一部分 + "image_path": os.path.join(f"{best_meta['item_id']}.jpg") + } + + except Exception as e: + print(f"An error occurred during item retrieval: {e}") + return None + + def _build_user_input(self) -> str: + """构建发送给 Gemini 的用户输入,包含已选单品信息。""" + if not self.outfit_items: + return "Start" + + # 将已选单品的信息作为上下文发回给 Agent + context = "Selected fashion items:\n" + for ii, item in enumerate(self.outfit_items): + context += f"{ii + 1}. Category: {item['category']}. Description: {item['description']}\n" + context += "\nPlease recommend the next single item based on the selected items, user's request, and style guide." + return context + + async def run_styling_process(self, request_summary, stylist_path, start_outfit=None, user_id="test", callback_url=""): + if start_outfit is None: + start_outfit = [] + self.outfit_items = start_outfit if start_outfit else [] + """主流程控制循环。""" + print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---") + + self.style_guide = self._load_style_guide(stylist_path) + self.system_prompt = self._build_system_prompt(request_summary) + response_data = {"status": "", + "message": "", + "path": "", + "outfit_id": self.outfit_id, + "items": [] + } + logger.info(response_data) + item_id = "" + item_category = "" + while True: + # 1. 准备用户输入(上下文) + user_input = self._build_user_input() + + # 2. 调用 Gemini Agent + gemini_response_text, minio_path = await self._call_gemini(user_input, user_id) + gemini_data = self._parse_gemini_response(gemini_response_text) + response_data['path'] = minio_path + if item_id: + response_data['items'].append({"item_id": item_id, "category": item_category}) + if not gemini_data: + print("🚨 Agent 返回无效响应,终止流程。") + self.stop_reason = "Agent failed to return response" + response_data['status'] = "failed" + response_data['message'] = self.stop_reason + break + + # 3. 检查终止条件 + if gemini_data.get('action') == 'stop': + print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}") + self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided') + response_data['status'] = "stop" + response_data['message'] = self.stop_reason + + # 4. 处理推荐单品 + if gemini_data.get('action') == 'recommend_item': + category = gemini_data.get('category') + description = gemini_data.get('description') + + # 4a. 检查类别是否有效 (重要步骤) + if category not in self.CATEGORY_SET: + print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。") + # 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正 + # 这里简化为跳过本次循环 + response_data['status'] = "continue" + response_data['message'] = f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。", + continue + + # 4b. 在本地 DB 中查询单品 + new_item = self._get_next_item(description, category) + item_id = new_item.get('item_id') + item_category = new_item.get('category') + + if new_item: + # 4c. (实际步骤) 将选中的单品图片和描述发回给 Agent 进行最终审核 + # 这里的代码框架省略了图片回传和二次审核的步骤,直接视为通过 + # 实际你需要: new_user_input = f"Check this item: {new_item['description']}, path: {new_item['image_path']}" + # call_gemini_agent(...) -> 如果返回"pass",则添加到outfit_items + + if new_item['item_id'] in [x['item_id'] for x in self.outfit_items]: + print("This item exists. Stop here.") + self.stop_reason = "Finish reason: Duplicate item selected." + response_data['status'] = "stop" + response_data['message'] = self.stop_reason + break + + if new_item['item_id'] == "ELG383": + if random.random() < 0.70: + self.stop_reason = "Finish reason: ELG383 is seleced repeatly." + response_data['status'] = "stop" + response_data['message'] = self.stop_reason + break + + self.outfit_items.append(new_item) + # print(f"➕ 成功添加单品: {new_item['category']} ({new_item['item_id']}). 当前搭配数量: {len(self.outfit_items)}") + response_data['status'] = "ok" + response_data['message'] = self.stop_reason + else: + print("⚠️ 未找到匹配单品,无法继续搭配。终止。") + self.stop_reason = "Finish reason: No matching item found in local database." + response_data['status'] = "stop" + response_data['message'] = self.stop_reason + break + + if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环 + logger.info("🚨 达到最大搭配数量限制,强制终止。") + self.stop_reason = "Finish reason: Reached max outfit length." + response_data['status'] = "stop" + response_data['message'] = self.stop_reason + + logger.info(f"request data :{response_data}") + headers = { + 'Accept': "*/*", + 'Accept-Encoding': "gzip, deflate, br", + 'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0", + 'Connection': "keep-alive", + 'Content-Type': "application/json" + } + url = f'{callback_url}/api/style/callback' + response = post_request(url=url, data=json.dumps(response_data), headers=headers) + logger.info(response.text) + return response_data + + # def _save_outfit_results(self, user_id): + # """保存最终的 JSON 列表和图片到指定文件夹。""" + # if not self.outfit_items: + # raise ValueError("No outfit items to save.") + # + # # 1. 保存 JSON 文件 + # results_list = [{'item_id': item['item_id'], 'category': item['category'], 'description': item['description'], 'gpt_description': item['gpt_description']} for item in self.outfit_items] + # results_list.append({'stop_reason': self.stop_reason}) + # + # return upload_json_to_minio_sync( + # minio_client=minio_client, + # bucket_name=f"lanecarford", + # object_name=f"lc_stylist_agent_outfit_items/{user_id}/{uuid.uuid4()}.json", + # data=results_list + # ) + + def _upload_to_gcs(self, bucket_name: str, blob_name: str, mime_type, image_bytes) -> str: + """同步方法:将文件上传到 GCS 并返回 GCS URI。""" + bucket = self.gcs_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + blob.upload_from_string( + data=image_bytes, + content_type=mime_type + ) + + gcs_uri = f"gs://{bucket_name}/{blob_name}" + return gcs_uri + + async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit: List[Dict[str, str]] = [], num_outfits: int = 1): + """ + 基于用户的对话历史和需求,推荐一套搭配。 + + Args: + request_summary: 用户的request + start_outfit: 可选的初始搭配列表,每个元素包含 'item_id' 和 'category'。 + """ + tasks = [] + for _ in range(num_outfits): + agent = AsyncStylistAgent(**self.stylist_agent_kwages) + task = agent.run_styling_process(request_summary, stylist_name, start_outfit) + tasks.append(task) + print(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---") + + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + + successful_outfits = [] + failed_outfits = [] + for result in results: + if isinstance(result, Exception): + # 任务执行中发生异常 + failed_outfits.append(f"Failed: {result}") + else: + # 任务成功,result 是 run_styling_process 返回的图片路径 + successful_outfits.append(result) + + return { + "successful_outfits": successful_outfits, + "failed_outfits": failed_outfits + } + + except Exception as e: + print(f"An unexpected error occurred during concurrent recommendation: {e}") + return {"error": str(e)} diff --git a/app/server/ChatbotAgent/core/system_prompt.py b/app/server/ChatbotAgent/core/system_prompt.py new file mode 100644 index 0000000..7156586 --- /dev/null +++ b/app/server/ChatbotAgent/core/system_prompt.py @@ -0,0 +1,56 @@ +BASIC_PROMPT = """""" +WOMEN_BASIC_PROMPT = """You are a professional, friendly, and insightful AI women's styling assistant. + +Your primary mission is to engage in a multi-turn conversation with the user to fully understand their dressing intent. You must adopt a professional yet approachable tone. + +CONVERSATION GOALS: +1. **Occasion:** Determine the specific event (e.g., romantic dinner, summer wedding, business meeting). +2. **Style:** Pinpoint the desired aesthetic (e.g., classic elegance, edgy, minimalist, bohemian). +3. **Vibe/Details:** Gather any mood or specific constraints (e.g., needs to be comfortable, requires light colors, no bare shoulders). +4. **Item Preference:** Ask the user if they have any specific preferences for an item type or silhouette (e.g., preference for a dress, skirt, tailored pants, or a particular neckline/length). + +GUIDANCE FOR RESPONSE GENERATION: +- After the user's initial request (e.g., "I want a chic outfit for dinner."), immediately reply with a friendly, targeted follow-up question to elicit the most crucial missing information (usually a combination of **Occasion** and **Style**). +- Be concise. Ask only 1 to 2 essential questions per turn. +- You must gather sufficient, clear intent before proceeding to actual clothing recommendations. + +OUTPUT FORMAT INSTRUCTION: +- **DO NOT** use any Markdown formatting whatsoever (e.g., do not use asterisks (*), bold text (**), lists, or code blocks). +- **ONLY** output the plain text response spoken by the AI Assistant. + +Example Follow-up (mimicking a conversational flow): +User: I want a chic outfit for dinner. +Your Response: Hey there! A chic dinner outfit, I love that! To give you the perfect recommendations, tell me: is this a romantic date, business dinner, or celebration with friends? And what's your go-to style vibe: classic elegance or something with more edge?""" +MEN_BASIC_PROMPT = """You are a professional, friendly, and insightful AI men's styling assistant. + +Your primary mission is to engage in a multi-turn conversation with the user to fully understand their dressing intent. You must adopt a professional yet approachable tone. + +CONVERSATION GOALS: +1. **Occasion:** Determine the specific event (e.g., romantic dinner, summer wedding, business meeting). +2. **Style:** Pinpoint the desired aesthetic (e.g., classic elegance, edgy, minimalist, bohemian). +3. **Vibe/Details:** Gather any mood or specific constraints (e.g., needs to be comfortable, requires light colors, no bare shoulders). +4. **Item Preference:** Ask the user if they have any specific preferences for an item type or silhouette (e.g., preference for a dress, skirt, tailored pants, or a particular neckline/length). + +GUIDANCE FOR RESPONSE GENERATION: +- After the user's initial request (e.g., "I want a chic outfit for dinner."), immediately reply with a friendly, targeted follow-up question to elicit the most crucial missing information (usually a combination of **Occasion** and **Style**). +- Be concise. Ask only 1 to 2 essential questions per turn. +- You must gather sufficient, clear intent before proceeding to actual clothing recommendations. + +OUTPUT FORMAT INSTRUCTION: +- **DO NOT** use any Markdown formatting whatsoever (e.g., do not use asterisks (*), bold text (**), lists, or code blocks). +- **ONLY** output the plain text response spoken by the AI Assistant. + +Example Follow-up (mimicking a conversational flow): +User: I want a chic outfit for dinner. +Your Response: Hey there! A chic dinner outfit, I love that! To give you the perfect recommendations, tell me: is this a romantic date, business dinner, or celebration with friends? And what's your go-to style vibe: classic elegance or something with more edge?""" + +SUMMARY_PROMPT = """Analyze the following chat history. Your task is to extract all user intentions, scenarios, style preferences, and constraints expressed during the conversation, and distill them into a concise, structured JSON object. + +**YOUR OUTPUT MUST BE A JSON OBJECT ONLY, WITH NO SURROUNDING TEXT, MARKDOWN, OR EXPLANATION.** + +JSON FIELD REQUIREMENTS: +- **occasion (string):** The specific event and purpose (e.g., "Romantic date dinner", "Summer outdoor wedding", "Casual Friday at office"). +- **style (string):** The overall aesthetic description (e.g., "Classic elegance", "Modern minimalist", "Bohemian vibe", "Edgy and contemporary"). +- **color_preference (string or list):** User's preferred or excluded colors/tones (e.g., "Light colors only", "Avoid deep shades", "['Cream', 'Pale Blue']", "No preference"). +- **clothing_type (string):** User's preference for specific garment types, material, or silhouette (e.g., "Lightweight maxi dress", "Skirt with silk blouse", "Tailored wide-leg pants", "Floral print"). +- **vibe_or_details (string):** Any other details, mood requirements, or specific constraints (e.g., "Needs to be comfortable and breathable", "Accent on accessories", "Must cover shoulders").""" diff --git a/app/server/ChatbotAgent/core/utils_litserve.py.py b/app/server/ChatbotAgent/core/utils_litserve.py.py new file mode 100644 index 0000000..fcaa4a4 --- /dev/null +++ b/app/server/ChatbotAgent/core/utils_litserve.py.py @@ -0,0 +1,163 @@ +import logging +from typing import List, Dict +from PIL import Image, ImageDraw, ImageFont +from app.server.utils.minio_client import oss_get_image, minio_client +from app.server.utils.minio_config import MINIO_LC_DATA_PATH + +logger = logging.getLogger(__name__) +# 9个 341x341 左右的单元格 (ALL_9_CELLS) +# 布局顺序: 从上到下,从左到右 (1 -> 9) +ALL_9_CELLS = [ + # Top Row (Y=0, H=341) + (0, 0, 341, 341), # 1. Top-Left (341x341) + (341, 0, 341, 341), # 2. Top-Middle (341x341) + (682, 0, 342, 341), # 3. Top-Right (342x341) + # Middle Row (Y=341, H=341) + (0, 341, 341, 341), # 4. Mid-Left (341x341) + (341, 341, 341, 341), # 5. Center (341x341) + (682, 341, 342, 341), # 6. Mid-Right (342x341) + # Bottom Row (Y=682, H=342) + (0, 682, 341, 342), # 7. Bottom-Left (341x342) + (341, 682, 341, 342), # 8. Bottom-Middle (341x342) + (682, 682, 342, 342) # 9. Bottom-Right (342x342) +] + + +def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, add_text=True): + """ + Loads up to 4 images from the given paths, resizes them while maintaining + aspect ratio, and merges them onto a 1024x1024 white background JPG. + + The layout depends on the number of images: + 1: Center the single image on the 1024x1024 canvas. + 2: Place side-by-side, each scaled to fit a 512x1024 half. + 3: Place in top-left (512x512), top-right (512x512), and bottom-left (512x512). + 4: Place in all four 512x512 quadrants. + + Args: + outfit_items: A list of item metadata (max length 9). + + Returns: + The file path of the temporary merged JPG image. + """ + + # Define the final canvas size + CANVAS_SIZE = 1024 + + # 1. Create the final white canvas + # Using 'RGB' mode for JPG output + canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white') + draw = ImageDraw.Draw(canvas) + font = ImageFont.load_default() + + # 2. Define the quadrants/target areas (x, y, w, h) + # The positions are based on a 512x512 quadrant size + quadrants = { + 1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)], # Single full-size placement + 2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)], # Left, Right + 3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)], # Top-Left, Top-Right, Bottom-Left + 4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)], # All Four + 5: ALL_9_CELLS[:5], # 布局前5个单元格 (1-5) + 6: ALL_9_CELLS[:6], # 布局前6个单元格 (1-6) + 7: ALL_9_CELLS[:7], # 布局前7个单元格 (1-7) + 8: ALL_9_CELLS[:8], # 布局前8个单元格 (1-8) + 9: ALL_9_CELLS[:9] # 布局全部9个单元格 (1-9) + } + + # 3. Load and Filter Images + valid_images = [] + image_paths = [item['image_path'] for item in outfit_items] + for path in image_paths: + try: + # We use Image.open() and convert to 'RGB' to handle potential transparency (RGBA) + # and ensure compatibility with the final 'RGB' canvas and JPG output. + img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB') + # img = Image.open(path).convert('RGB') + valid_images.append(img) + except Exception as e: + logger.error(f"Error loading image {path}. Skipping: {e}") + + num_images = len(valid_images) + + if num_images == 0: + raise ValueError("No valid images were loaded.") + + if num_images > max_len: + raise ValueError(f"Valid item number {num_images} exceed max limit {max_len}") + + # Get the correct list of target areas based on the number of valid images + target_areas = quadrants.get(num_images, []) + + # 4. Resize and Paste + for i, (img, item) in enumerate(zip(valid_images, outfit_items)): + item_id = item['item_id'] + category = item['category'] + if i >= len(target_areas): + # This should not happen if num_images <= 4 + break + + # Target area dimensions (x_start, y_start, width, height) + x_start, y_start, target_w, target_h = target_areas[i] + + # Calculate new size while maintaining aspect ratio + original_w, original_h = img.size + + # Calculate the ratio needed to fit within the target area + ratio_w = target_w / original_w + ratio_h = target_h / original_h + + # Use the *smaller* of the two ratios to ensure the image fits entirely + resize_ratio = min(ratio_w, ratio_h) + + # Calculate the new dimensions + new_w = int(original_w * resize_ratio) + new_h = int(original_h * resize_ratio) + + # Resize the image. Image.Resampling.LANCZOS provides high-quality scaling. + # Pillow documentation recommends ANTIALIAS or BICUBIC for downscaling, + # but LANCZOS is a good general high-quality filter. + # Note: In Pillow versions > 9.0.0, Image.LANCZOS is now Image.Resampling.LANCZOS + resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + + # Calculate the paste position to center the resized image within its target area + # Center X: (Target Width - New Width) / 2 + X Start + paste_x = (target_w - new_w) // 2 + x_start + # Center Y: (Target Height - New Height) / 2 + Y Start + # paste_y = (target_h - new_h) // 2 + y_start + + TEXT_RESERVE_HEIGHT = 30 + paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start + paste_y = max(paste_y, y_start) + + # Paste the resized image onto the canvas + canvas.paste(resized_img, (paste_x, paste_y)) + + full_text = f"ID: {item_id}, Category: {category}" + try: + # 推荐使用:计算文本的实际尺寸 (width, height) + bbox = draw.textbbox((0, 0), full_text, font=font) + text_w = bbox[2] - bbox[0] + text_h = bbox[3] - bbox[1] + except AttributeError: + # 兼容旧版本 Pillow + text_w, text_h = draw.textsize(full_text, font=font) + + # 计算 X 轴起始位置:使其在目标区域 (target_w) 中居中 + text_x_center = x_start + target_w // 2 + text_x_start = text_x_center - text_w // 2 + + # 计算 Y 轴起始位置:将其放在目标区域的底部 + # (目标区域的起始Y + 目标区域的高度 - 文本行的高度) + text_y_start = y_start + target_h - text_h - 5 # 减去 5 像素作为边距 + + # 3. 绘制合并后的文本 + if add_text: + draw.text((text_x_start, text_y_start), + full_text, + fill='black', + font=font) + + # Save as a high-quality JPG (quality=90 is a good balance) + # canvas.save(output_path, 'JPEG', quality=90) + + return canvas diff --git a/app/server/ChatbotAgent/core/vector_database.py b/app/server/ChatbotAgent/core/vector_database.py new file mode 100644 index 0000000..10c4f4a --- /dev/null +++ b/app/server/ChatbotAgent/core/vector_database.py @@ -0,0 +1,60 @@ +import torch +import chromadb +from PIL import Image +from typing import List, Dict, Any +from transformers import CLIPProcessor, CLIPModel + + +class VectorDatabase(): + def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str): + self.client = chromadb.PersistentClient(path=vector_db_dir) + + self.collection = self.client.get_or_create_collection(name=collection_name) + + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device) + self.processor = CLIPProcessor.from_pretrained(embedding_model_name) + + def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]: + """生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。""" + + if is_image: + inputs = self.processor(images=data, return_tensors="pt").to(self.device) + with torch.no_grad(): + features = self.model.get_image_features(**inputs) + else: + # 强制截断,解决序列长度问题 + inputs = self.processor( + text=[data], + return_tensors="pt", + padding=True, + truncation=True + ).to(self.device) + with torch.no_grad(): + features = self.model.get_text_features(**inputs) + + # L2 归一化 + features = features / features.norm(p=2, dim=-1, keepdim=True) + + return features.cpu().numpy().flatten().tolist() + + def query_local_db(self, embedding: List[float], category: str, n_results: int = 3) -> List[Dict[str, Any]]: + """ + 基于嵌入向量在本地数据库中查询相似单品。 + 实际应执行 ChromaDB 查询,并根据 category 进行过滤(metadatas)。 + """ + # 实际应执行向量查询 + # 为了演示流程,返回一个模拟结果 + results = self.collection.query( + query_embeddings=[embedding], + n_results=n_results, + where={ + "$and": [ + {"category": category}, + {"modality": "image"}, + ] + }, + include=['documents', 'metadatas', 'distances'] + ) + return results