commit 4c0d8817e322297a6274f52248d9410735cf3a49 Author: pangkaicheng <924366729@qq.com> Date: Thu Oct 16 14:04:42 2025 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..74c6dfe --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.env +.vscode/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..bcef194 --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +Checklist +1. Vector database path +2. set GOOGLE_API_KEY in env variable +```bash +export GOOGLE_API_KEY="" +``` +3. Ensure root path added to PYTHONPATH \ No newline at end of file diff --git a/app/api/chat_router.py b/app/api/chat_router.py new file mode 100644 index 0000000..e482c5c --- /dev/null +++ b/app/api/chat_router.py @@ -0,0 +1,103 @@ +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, Field +from typing import List, Dict + +# 导入 ChatbotAgent 类和配置 +# 假设 ChatbotAgent 是单例或在应用启动时创建 +from app.core.chatbot_agent import ChatbotAgent +from app.core.config import settings + +# --- Pydantic Data Models --- +# 用于接收用户消息的请求体 +class ChatRequest(BaseModel): + user_id: str = Field(..., description="Unique identifier for the user.") + user_message: str = Field(..., description="The user's text message.") + +# 用于启动搭配推荐的请求体 +class OutfitStartRequest(BaseModel): + user_id: str = Field(..., description="Unique identifier for the user.") + stylist_name: str = Field("crystal", description="The name of the stylist guide to use (e.g., 'crystal').") + # 用于从已选单品继续搭配,可选 + start_outfit: List[Dict[str, str]] = Field(default_factory=list, description="Optional list of items already selected.") + +# 用于返回 LLM/Agent 的文本回复 +class ChatResponse(BaseModel): + response_text: str = Field(..., description="The chatbot's text response.") + +# 搭配推荐的响应体 (StylistAgent的输出是JSON,这里简化为字符串,实际项目中会返回结构化的JSON) +class OutfitResponse(BaseModel): + summary: str = Field(..., description="The conversation summary used for styling.") + next_item_json: Dict = Field(..., description="The next recommended item in JSON format.") + +# --- Router Setup --- + +router = APIRouter( + prefix="/chat", + tags=["Chatbot & Styling"] +) + +# 在应用启动时实例化 ChatbotAgent,以便在路由中使用同一个实例 +# 注意:在真实的 FastAPI 应用中,我们通常使用依赖注入 (Depends) 来管理实例 +try: + # 假设 ChatbotAgent 的 __init__ 已经包含了所有依赖的初始化 + global_agent = ChatbotAgent() +except Exception as e: + # 如果依赖(如 Redis 或 VectorDB)初始化失败,抛出错误 + print(f"FATAL: ChatbotAgent failed to initialize. Error: {e}") + global_agent = None # 保持 None 状态,路由会抛出 500 + +@router.post("/", response_model=ChatResponse, summary="Process user message and get chatbot response") +def handle_chat_message(request: ChatRequest): + """ + 处理用户的聊天消息,将消息添加到历史记录,并调用 LLM 生成回复。 + """ + if not global_agent: + raise HTTPException(status_code=500, detail="Chatbot agent not initialized.") + + try: + response = global_agent.process_query(request.user_id, request.user_message) + return ChatResponse(response_text=response) + except Exception as e: + print(f"Error processing chat query: {e}") + # 返回一个通用错误信息,而不是内部错误 + raise HTTPException(status_code=500, detail="Internal server error while processing message.") + +@router.post("/outfit/start", summary="Start outfit recommendation based on conversation history") +def start_outfit_recommendation(request: OutfitStartRequest): + """ + 基于用户的对话历史,生成搭配总结,并启动 Stylist Agent 推荐第一个单品。 + + 返回:对话总结和 Stylist Agent 推荐的第一个单品的 JSON。 + """ + if not global_agent: + raise HTTPException(status_code=500, detail="Chatbot agent not initialized.") + + try: + # 1. 获取对话总结 + request_summary = global_agent.get_conversation_summary(request.user_id) + + # 2. 调用 Stylist Agent 运行搭配流程 + # run_styling_process 应该返回 Stylist Agent 的第一个 JSON 输出 (recommend_item) + # 假设 StylistAgent.run_styling_process 返回一个 JSON 字典 + next_item_json = global_agent.stylist_agent.run_styling_process( + request_summary, + request.stylist_name, + request.start_outfit + ) + + return OutfitResponse( + summary=request_summary, + next_item_json=next_item_json + ) + + except Exception as e: + print(f"Error starting outfit recommendation: {e}") + raise HTTPException(status_code=500, detail=f"Failed to start styling process: {e}") + +# (可选) 搭配继续接口:用于处理用户的反馈和继续推荐 +@router.post("/outfit/continue", summary="Continue outfit recommendation with user feedback") +def continue_outfit_recommendation(): + """ + 这是一个占位符,用于处理用户对已推荐单品(如图片)的反馈,并让 Stylist Agent 推荐下一个单品。 + """ + raise HTTPException(status_code=501, detail="Endpoint for outfit continuation is not yet implemented.") diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/core/__pycache__/__init__.cpython-310.pyc b/app/core/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..0033aaf Binary files /dev/null and b/app/core/__pycache__/__init__.cpython-310.pyc differ diff --git a/app/core/__pycache__/config.cpython-310.pyc b/app/core/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..5973a14 Binary files /dev/null and b/app/core/__pycache__/config.cpython-310.pyc differ diff --git a/app/core/__pycache__/data_structure.cpython-310.pyc b/app/core/__pycache__/data_structure.cpython-310.pyc new file mode 100644 index 0000000..9999a55 Binary files /dev/null and b/app/core/__pycache__/data_structure.cpython-310.pyc differ diff --git a/app/core/__pycache__/llm_interface.cpython-310.pyc b/app/core/__pycache__/llm_interface.cpython-310.pyc new file mode 100644 index 0000000..75b60f5 Binary files /dev/null and b/app/core/__pycache__/llm_interface.cpython-310.pyc differ diff --git a/app/core/__pycache__/redis_manager.cpython-310.pyc b/app/core/__pycache__/redis_manager.cpython-310.pyc new file mode 100644 index 0000000..ddc653c Binary files /dev/null and b/app/core/__pycache__/redis_manager.cpython-310.pyc differ diff --git a/app/core/__pycache__/stylist_agent.cpython-310.pyc b/app/core/__pycache__/stylist_agent.cpython-310.pyc new file mode 100644 index 0000000..94bc308 Binary files /dev/null and b/app/core/__pycache__/stylist_agent.cpython-310.pyc differ diff --git a/app/core/__pycache__/system_prompt.cpython-310.pyc b/app/core/__pycache__/system_prompt.cpython-310.pyc new file mode 100644 index 0000000..2bd9fda Binary files /dev/null and b/app/core/__pycache__/system_prompt.cpython-310.pyc differ diff --git a/app/core/__pycache__/utils.cpython-310.pyc b/app/core/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..372dd6a Binary files /dev/null and b/app/core/__pycache__/utils.cpython-310.pyc differ diff --git a/app/core/__pycache__/vector_database.cpython-310.pyc b/app/core/__pycache__/vector_database.cpython-310.pyc new file mode 100644 index 0000000..21aab9d Binary files /dev/null and b/app/core/__pycache__/vector_database.cpython-310.pyc differ diff --git a/app/core/chatbot_agent.py b/app/core/chatbot_agent.py new file mode 100644 index 0000000..3be2c31 --- /dev/null +++ b/app/core/chatbot_agent.py @@ -0,0 +1,135 @@ +from typing import Dict, List +import asyncio + +from app.core.data_structure import Message, Role +from app.core.llm_interface import AsyncLLMInterface, AsyncGeminiLLM +from app.core.redis_manager import RedisManager +from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT +from app.core.stylist_agent import AsyncStylistAgent +from app.core.vector_database import VectorDatabase +from app.core.config import settings + +class ChatbotAgent: + def __init__(self, llm_model: AsyncLLMInterface = None): + self.llm = llm_model if llm_model else AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME) + self.redis = RedisManager( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + key_prefix=settings.REDIS_HISTORY_KEY_PREFIX + ) + self.vector_db = VectorDatabase( + vector_db_dir=settings.VECTOR_DB_DIR, + collection_name=settings.COLLECTION_NAME, + embedding_model_name=settings.EMBEDDING_MODEL_NAME + ) + self.stylist_agent_kwages = { + 'local_db': self.vector_db, + 'max_len': 5, + 'outfits_root': settings.OUTFIT_OUTPUT_DIR, + 'image_dir': settings.IMAGE_DIR, + 'stylist_guide_dir': settings.STYLIST_GUIDE_DIR, + 'gemini_model_name': settings.LLM_MODEL_NAME + } + + async def process_query(self, user_id: str, user_message: str) -> str: + """ + 处理用户的最新输入,调用 LLM, 并更新历史记录。 + """ + + # 添加用户消息到历史 + user_msg = Message(role=Role.USER, content=user_message) + chat_history = self.redis.get_history(user_id) + chat_history.append(user_msg) + + # 生成 LLM 回复 + try: + response_text = await self.llm.generate_response(chat_history, system_prompt=BASIC_PROMPT) + except Exception as e: + print(f"LLM 调用失败: {e}") + response_text = "抱歉,系统暂时无法响应,请稍后再试。" + + # 添加助手消息到历史 + if response_text: + assistant_msg = Message(role=Role.ASSISTANT, content=response_text) + else: + assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.") + + self.redis.save_message(user_id, user_msg) + self.redis.save_message(user_id, assistant_msg) + + return response_text + + async def get_conversation_summary(self, user_id: str) -> str: + """ + 分析用户的完整会话历史,并打包成一个简洁的需求总结。 + + 这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。` + """ + history_messages = self.redis.get_history(user_id) + input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages]) + + # 临时调用 LLM 或使用本地逻辑生成总结 + summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)], system_prompt=SUMMARY_PROMPT) + + return summary + + async def recommend_outfit(self, user_id: str, stylist_name: str, start_outfit: List[Dict[str, str]] = [], num_outfits: int = 1): + """ + 基于用户的对话历史和需求,推荐一套搭配。 + + Args: + user_id: 用户唯一标识符。 + start_outfit: 可选的初始搭配列表,每个元素包含 'item_id' 和 'category'。 + """ + request_summary = await self.get_conversation_summary(user_id) + print(f"Conversation Summary:\n{request_summary}") + + tasks = [] + for _ in range(num_outfits): + agent = AsyncStylistAgent(**self.stylist_agent_kwages) + task = agent.run_styling_process(request_summary, stylist_name, start_outfit) + tasks.append(task) + print(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---") + + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + + successful_outfits = [] + failed_outfits = [] + for result in results: + if isinstance(result, Exception): + # 任务执行中发生异常 + failed_outfits.append(f"Failed: {result}") + else: + # 任务成功,result 是 run_styling_process 返回的图片路径 + successful_outfits.append(result) + + return { + "successful_outfits": successful_outfits, + "failed_outfits": failed_outfits + } + + except Exception as e: + print(f"An unexpected error occurred during concurrent recommendation: {e}") + return {"error": str(e)} + + + +if __name__ == "__main__": + async def test(): + agent = ChatbotAgent() + user_id = "user123" + agent.redis.clear_history(user_id) # 清除历史,便于测试 + print(await agent.process_query(user_id, "I need a dress for a summer wedding. I prefer something floral and light.")) + # print(agent.process_query(user_id, "I prefer something floral and light.")) + recommendation_results = await agent.recommend_outfit(user_id, stylist_name="crystal", start_outfit=[], num_outfits=2) + + print("\n--- Final Recommendation Results ---") + for i, path in enumerate(recommendation_results.get("successful_outfits", [])): + print(f"✅ Outfit {i+1} saved to: {path}") + + for error in recommendation_results.get("failed_outfits", []): + print(f"❌ {error}") + + asyncio.run(test()) \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000..7e3d098 --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,38 @@ +from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic import Field + +# ⚠️ 注意: 您需要安装 pydantic-settings: pip install pydantic-settings + +class Settings(BaseSettings): + """ + 应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。 + """ + model_config = SettingsConfigDict( + env_file='.env', + env_file_encoding='utf-8', + extra='ignore' # 忽略环境变量中多余的键 + ) + + # Redis 配置 + REDIS_HOST: str = Field(default='localhost', description="Redis服务器地址") + REDIS_PORT: int = Field(default=6379, description="Redis服务器端口") + REDIS_DB: int = Field(default=0, description="Redis数据库编号") + REDIS_HISTORY_KEY_PREFIX: str = Field(default="chat:history:", description="Redis会话历史键的前缀") + + # LLM 配置 + GEMINI_API_KEY: str = Field(..., description="Google Gemini API 密钥。必须设置。") + LLM_MODEL_NAME: str = Field(default="gemini-2.5-flash", description="使用的 LLM 模型名称") + + # 路径配置参数 + DATA_ROOT: str = Field(default="./data", description="数据根目录") + IMAGE_DIR: str = Field(default="./data/image_data", description="图片数据目录") + OUTFIT_OUTPUT_DIR: str = Field(default="./data/outfit_output", description="生成的搭配图片输出目录") + STYLIST_GUIDE_DIR: str = Field(default="./data/stylist_guide", description="风格指南文本目录") + + # 向量数据库配置参数 + VECTOR_DB_DIR: str = Field(default="./data/db", description="向量数据库目录") + COLLECTION_NAME: str = Field(default="lc_clothing_embedding", description="向量数据库集合名称") + EMBEDDING_MODEL_NAME: str = Field(default="openai/clip-vit-base-patch32", description="CLIP嵌入模型名称") + +# 创建配置实例,供应用其他部分使用 +settings = Settings() diff --git a/app/core/data_structure.py b/app/core/data_structure.py new file mode 100644 index 0000000..241ec6a --- /dev/null +++ b/app/core/data_structure.py @@ -0,0 +1,16 @@ +from typing import List, Dict, Any +from enum import Enum +from pydantic import BaseModel, Field +import datetime + +# 角色枚举,用于区分用户和系统的消息 +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + +# 单条消息的数据模型 +class Message(BaseModel): + role: Role = Field(..., description="Role of message sender") + content: str = Field(..., description="Content of the message") + # timestamp: str = Field(default_factory=lambda: datetime.datetime.now().isoformat()) # 记录时间戳 \ No newline at end of file diff --git a/app/core/llm_interface.py b/app/core/llm_interface.py new file mode 100644 index 0000000..7a0a96c --- /dev/null +++ b/app/core/llm_interface.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod +from typing import List + +from google import genai +from google.genai import types + +from app.core.data_structure import Message, Role + +class AsyncLLMInterface(ABC): + @abstractmethod + async def generate_response(self, history: List[Message], system_prompt: str) -> str: + """ + 根据对话历史和系统指令生成回复. + + Args: + history: 包含多条 Message 的列表。 + + Returns: + LLM 生成的回复字符串。 + """ + raise NotImplementedError("Subclasses must implement this method") + +class AsyncGeminiLLM(AsyncLLMInterface): + def __init__(self, model_name: str = "gemini-2.5-flash"): + self.model_name = model_name + try: + self.gemini_client = genai.Client() + except Exception as e: + raise type(e)(f"Failed to initialize Gemini Client. Check if GEMINI_API_KEY is set. Original error: {e}") + + async def generate_response(self, history: List[Message], system_prompt: str) -> str: + contents = [] + + for msg in history: + gemini_role = "user" if msg.role == Role.USER else "model" + content = types.Content( + role=gemini_role, + parts=[types.Part.from_text(text=msg.content)] + ) + contents.append(content) + + try: + response = await self.gemini_client.aio.models.generate_content( + model=self.model_name, + contents=contents, + config=types.GenerateContentConfig( + system_instruction=system_prompt, + # temperature=0.3, + ) + ) + return response.text + except Exception as e: + raise type(e)(f"Gemini API call failed: {e}") + + \ No newline at end of file diff --git a/app/core/redis_manager.py b/app/core/redis_manager.py new file mode 100644 index 0000000..c6c3768 --- /dev/null +++ b/app/core/redis_manager.py @@ -0,0 +1,63 @@ +import redis +from typing import List, Optional +from app.core.data_structure import Message, Role + +# 这是一个同步 Redis 客户端,用于演示如何替换内存存储。 +# 在生产环境和异步 Web 框架中,应替换为 aioredis 等异步客户端。 +class RedisManager: + """同步管理器,用于在 Redis 中存储和检索对话历史。""" + + def __init__(self, host: str = 'localhost', port: int = 6379, db: int = 0, key_prefix: str = "chat:history:"): + self.r: Optional[redis.Redis] = None + self.key_prefix = key_prefix + try: + # 尝试连接 Redis + self.r = redis.Redis(host=host, port=port, db=db, decode_responses=True) + self.r.ping() + print("Successfully connected to Redis at {}:{}".format(host, port)) + except Exception as e: + print(f"⚠️ Failed to connect to Redis: {e}. Falling back to No-Op.") + self.r = None # 连接失败时设置为 None,避免后续操作报错 + + def _get_key(self, user_id: str) -> str: + """生成用户历史记录的 Redis 键名。""" + return f"{self.key_prefix}{user_id}" + + def _message_to_json(self, message: Message) -> str: + """将 Message 对象序列化为 JSON 字符串以便存储。""" + return message.model_dump_json() + + def _json_to_message(self, data: str) -> Message: + """将 JSON 字符串反序列化回 Message 对象。""" + try: + return Message.model_validate_json(data) + except Exception as e: + print(f"Error deserializing message data: {data[:50]}... Error: {e}") + return Message(role=Role.ASSISTANT, content="[Deserialization Error]") + + def save_message(self, user_id: str, message: Message): + """将单条消息保存到用户历史记录列表的末尾。""" + if not self.r: + return + + message_json = self._message_to_json(message) + # RPUSH:将元素添加到列表的尾部 + self.r.rpush(self._get_key(user_id), message_json) + + def get_history(self, user_id: str) -> List[Message]: + """检索用户的完整会话历史记录。""" + if not self.r: + return [] + + # LRANGE:获取列表的所有元素 (0 到 -1) + raw_history = self.r.lrange(self._get_key(user_id), 0, -1) + + # 将 JSON 字符串列表转换为 Message 对象列表 + messages = [self._json_to_message(data) for data in raw_history] + return messages + + def clear_history(self, user_id: str): + """删除用户的完整历史记录。""" + if self.r: + self.r.delete(self._get_key(user_id)) + print(f"History cleared for {user_id}") diff --git a/app/core/stylist_agent.py b/app/core/stylist_agent.py new file mode 100644 index 0000000..67818f8 --- /dev/null +++ b/app/core/stylist_agent.py @@ -0,0 +1,300 @@ +import json +import os +import random +import uuid +from typing import List, Dict, Any, Optional + +from google import genai + +from app.core.utils import merge_images_to_square + + +class AsyncStylistAgent: + CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'} + + def __init__(self, local_db, max_len: int, outfits_root: str, image_dir: str, stylist_guide_dir: str, gemini_model_name: str): + # self.outfit_items: List[Dict[str, str]] = [] + self.outfit_id = str(uuid.uuid4()) + self.gemini_client = genai.Client() + self.local_db = local_db + self.max_len = max_len + self.output_outfit_path = os.path.join(outfits_root, f"{self.outfit_id}.jpg") + self.output_json_path = os.path.join(outfits_root, f"{self.outfit_id}_items.json") + self.image_dir = image_dir + self.stylist_guide_dir = stylist_guide_dir + self.gemini_model_name = gemini_model_name + self.stop_reason = "" + + def _load_style_guide(self, path: str) -> str: + """加载 markdown 风格指南内容。""" + try: + with open(path, 'r', encoding='utf-8') as f: + return f.read() + except Exception as e: + raise FileNotFoundError(f"Failed to load style guide from {path}: {e}") + + def _build_system_prompt(self, request_summary: str = "") -> str: + """Constructs the complete System Prompt.""" + # Insert the style_guide content into the template + template = f""" + You are a professional fashion stylist Agent, specialized in creating complete outfits for the user. + + Your task is to **create a cohesive and complete outfit**, strictly adhering to **BOTH** the user's explicit **Request Summary** and the **Outfit Style Guide**. You must decide the next logical item to add to the outfit based on the currently selected items (if any). + + --- + + ## Request from the User: + + {request_summary} + + ## Core Guidance Document: Outfit Style Guide + + {self.style_guide} + + --- + + ## Your Workflow and Constraints + + 1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**. + 2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories. + 3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows: + + ```json + {{ + "action": "recommend_item", + "category": "YOUR_ITEM_CATEGORY", + "description": "YOUR_DETAILED_DESCRIPTION" + }} + ``` + + * `action`: Must always be `"recommend_item"` until the outfit is complete. + * `category`: Must be the category of the item you are recommending, strictly selected from the following list: {list(self.CATEGORY_SET)}. + * `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include: + * **Color** (e.g., milk tea, pure white, dark gray) + * **Fit/Silhouette** (e.g., Oversize, loose, slim-fit) + * **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern) + * **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look) + * **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.** + + 4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process: + + ```json + {{ + "action": "stop", + "reason": "OUTFIT_COMPLETE_AND_MEETS_ALL_MINI_GUIDELINES" + }} + ``` + Normally, five or six items are totally enough for an outfit. + + 5. **Context Dependency**: The user's next input (if not `Start`) will contain the **image and description of the selected item**. When recommending the next item, you must consider the coordination between the **already selected items** and the Style Guide. + + **Now, please start building an outfit and output the JSON for the first item.** + """ + return template.strip() + + def _clear_uploaded_files(self): + for f in self.gemini_client.files.list(): + self.gemini_client.files.delete(name=f.name) + + async def _call_gemini(self, user_input: str) -> str: + """ + 实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。 + + Args: + user_input: 发送给模型的主文本内容。 + image_paths: 待发送图片的本地路径列表。 + + Returns: + 模型的响应文本(预期为 JSON 字符串)。 + """ + content_parts = [] + self._clear_uploaded_files() + # 1. 添加图片内容 + if self.outfit_items: + merged_image_path = merge_images_to_square(self.outfit_items, max_len=self.max_len, output_path=self.output_outfit_path) + try: + myfile = await self.gemini_client.aio.files.upload(file=merged_image_path) + content_parts.append(myfile) + except Exception as e: + print(f"Error loading image {merged_image_path}: {e}") + + # 2. 添加文本内容 + content_parts.append(user_input) + + print(f"\n--- Calling Gemini with {len(self.outfit_items) if self.outfit_items else 0} images and query:\n{user_input}") + + try: + # 3. 实际 API 调用 + response = await self.gemini_client.aio.models.generate_content( + model=self.gemini_model_name, + contents=content_parts, + config={ + "system_instruction": self.system_prompt, + # 确保模型返回 JSON 格式 + "response_mime_type": "application/json", + "response_schema": { + "type": "object", + "properties": { + "action": {"type": "string", "enum": ["recommend_item", "stop"]}, + "category": {"type": "string"}, + "description": {"type": "string"}, + "reason": {"type": "string"} + }, + "required": ["action"] + } + } + ) + + # response.text 将包含一个 JSON 字符串 + return response.text + + except Exception as e: + print(f"Gemini API Call failed: {e}") + # 返回一个停止信号以防止循环继续 + return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"}) + + def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]: + """安全解析 Gemini 的 JSON 响应。""" + try: + # 有时 Gemini 可能会在 JSON 外面添加文字,尝试清理 + response_text = response_text.strip().replace('```json', '').replace('```', '') + data = json.loads(response_text) + print(f"The agent response is: {data}") + return data + except json.JSONDecodeError as e: + print(f"Error parsing JSON from Gemini: {e}") + print(f"Raw response: {response_text}") + return None + + def _get_next_item(self, item_description: str, category: str) -> Optional[Dict[str, str]]: + """ + 1. 根据描述生成嵌入。 + 2. 查询本地数据库以找到最佳匹配项。 + 3. 模拟 Agent 审核匹配项(这里简化为总是通过)。 + """ + try: + # 1. 生成查询嵌入 + query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False) + + # 2. 执行查询,并过滤类别 + results = self.local_db.query_local_db(query_embedding, category, n_results=1) + + if not results: + print(f"❌ 数据库中未找到符合 '{category}' 和描述的单品。") + return None + + # 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核) + best_meta = results['metadatas'][0][0] # 第一个 batch 的第一个 metadata + return { + "item_id": best_meta['item_id'], # 从 metadata 字典中安全获取 + "category": category, + "gpt_description": item_description, + 'description': best_meta['description'], + # 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导 + # 这里假设 item_id 就是文件名的一部分 + "image_path": os.path.join(self.image_dir, f"{best_meta['item_id']}.jpg") + } + + except Exception as e: + print(f"An error occurred during item retrieval: {e}") + return None + + def _build_user_input(self) -> str: + """构建发送给 Gemini 的用户输入,包含已选单品信息。""" + if not self.outfit_items: + return "Start" + + # 将已选单品的信息作为上下文发回给 Agent + context = "Selected fashion items:\n" + for ii, item in enumerate(self.outfit_items): + context += f"{ii + 1}. Category: {item['category']}. Description: {item['description']}\n" + context += "\nPlease recommend the next single item based on the selected items, user's request, and style guide." + return context + + async def run_styling_process(self, request_summary, stylist_name, start_outfit=[]): + self.outfit_items = start_outfit if start_outfit else [] + """主流程控制循环。""" + print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---") + + self.style_guide = self._load_style_guide(os.path.join(self.stylist_guide_dir, f"{stylist_name}_en.md")) + self.system_prompt = self._build_system_prompt(request_summary) + + while True: + # 1. 准备用户输入(上下文) + user_input = self._build_user_input() + + # 2. 调用 Gemini Agent + gemini_response_text = await self._call_gemini(user_input) + gemini_data = self._parse_gemini_response(gemini_response_text) + + if not gemini_data: + print("🚨 Agent 返回无效响应,终止流程。") + self.stop_reason = "Agent failed to return response" + break + + # 3. 检查终止条件 + if gemini_data.get('action') == 'stop': + print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}") + self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided') + break + + # 4. 处理推荐单品 + if gemini_data.get('action') == 'recommend_item': + category = gemini_data.get('category') + description = gemini_data.get('description') + + # 4a. 检查类别是否有效 (重要步骤) + if category not in self.CATEGORY_SET: + print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。") + # 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正 + # 这里简化为跳过本次循环 + continue + + # 4b. 在本地 DB 中查询单品 + new_item = self._get_next_item(description, category) + + if new_item: + # 4c. (实际步骤) 将选中的单品图片和描述发回给 Agent 进行最终审核 + # 这里的代码框架省略了图片回传和二次审核的步骤,直接视为通过 + # 实际你需要: new_user_input = f"Check this item: {new_item['description']}, path: {new_item['image_path']}" + # call_gemini_agent(...) -> 如果返回"pass",则添加到outfit_items + + if new_item['item_id'] in [x['item_id'] for x in self.outfit_items]: + print("This item exists. Stop here.") + self.stop_reason = "Finish reason: Duplicate item selected." + break + + if new_item['item_id'] == "ELG383": + if random.random() < 0.70: + self.stop_reason = "Finish reason: ELG383 is seleced repeatly." + break + + self.outfit_items.append(new_item) + print(f"➕ 成功添加单品: {new_item['category']} ({new_item['item_id']}). 当前搭配数量: {len(self.outfit_items)}") + + else: + print("⚠️ 未找到匹配单品,无法继续搭配。终止。") + self.stop_reason = "Finish reason: No matching item found in local database." + break + + if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环 + print("🚨 达到最大搭配数量限制,强制终止。") + self.stop_reason = "Finish reason: Reached max outfit length." + break + + # 5. 流程结束后保存结果 + self._save_outfit_results() + return self.output_outfit_path + + def _save_outfit_results(self): + """保存最终的 JSON 列表和图片到指定文件夹。""" + if not self.outfit_items: + raise ValueError("No outfit items to save.") + + # 1. 保存 JSON 文件 + results_list = [{'item_id': item['item_id'], 'category': item['category'], 'description': item['description'], 'gpt_description': item['gpt_description']} for item in self.outfit_items] + results_list.append({'stop_reason': self.stop_reason}) + with open(self.output_json_path, 'w', encoding='utf-8') as f: + json.dump(results_list, f, ensure_ascii=False, indent=4) + + merge_images_to_square(self.outfit_items, max_len=self.max_len, output_path=self.output_outfit_path, add_text=False) \ No newline at end of file diff --git a/app/core/system_prompt.py b/app/core/system_prompt.py new file mode 100644 index 0000000..dc93ca9 --- /dev/null +++ b/app/core/system_prompt.py @@ -0,0 +1,3 @@ +BASIC_PROMPT = """You are a fashion stylist AI named StylistGPT. Your task is to assist users in creating personalized fashion looks based on their preferences, occasions, and current fashion trends.""" + +SUMMARY_PROMPT = """Given conversation history, summarize the user's fashion preferences, occasions, and any specific requirements they have mentioned. Provide a concise overview of their style profile in one sentence.""" \ No newline at end of file diff --git a/app/core/utils.py b/app/core/utils.py new file mode 100644 index 0000000..f2a10cf --- /dev/null +++ b/app/core/utils.py @@ -0,0 +1,163 @@ +from typing import List, Dict +import shutil + +from PIL import Image, ImageDraw, ImageFont + +# 9个 341x341 左右的单元格 (ALL_9_CELLS) +# 布局顺序: 从上到下,从左到右 (1 -> 9) +ALL_9_CELLS = [ + # Top Row (Y=0, H=341) + (0, 0, 341, 341), # 1. Top-Left (341x341) + (341, 0, 341, 341), # 2. Top-Middle (341x341) + (682, 0, 342, 341), # 3. Top-Right (342x341) + # Middle Row (Y=341, H=341) + (0, 341, 341, 341), # 4. Mid-Left (341x341) + (341, 341, 341, 341),# 5. Center (341x341) + (682, 341, 342, 341),# 6. Mid-Right (342x341) + # Bottom Row (Y=682, H=342) + (0, 682, 341, 342), # 7. Bottom-Left (341x342) + (341, 682, 341, 342),# 8. Bottom-Middle (341x342) + (682, 682, 342, 342) # 9. Bottom-Right (342x342) +] + +def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, output_path="temp.jpg", add_text=True) -> str: + """ + Loads up to 4 images from the given paths, resizes them while maintaining + aspect ratio, and merges them onto a 1024x1024 white background JPG. + + The layout depends on the number of images: + 1: Center the single image on the 1024x1024 canvas. + 2: Place side-by-side, each scaled to fit a 512x1024 half. + 3: Place in top-left (512x512), top-right (512x512), and bottom-left (512x512). + 4: Place in all four 512x512 quadrants. + + Args: + outfit_items: A list of item metadata (max length 9). + + Returns: + The file path of the temporary merged JPG image. + """ + + # Define the final canvas size + CANVAS_SIZE = 1024 + + # 1. Create the final white canvas + # Using 'RGB' mode for JPG output + canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white') + draw = ImageDraw.Draw(canvas) + font = ImageFont.load_default() + + + # 2. Define the quadrants/target areas (x, y, w, h) + # The positions are based on a 512x512 quadrant size + quadrants = { + 1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)], # Single full-size placement + 2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)], # Left, Right + 3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)], # Top-Left, Top-Right, Bottom-Left + 4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)], # All Four + 5: ALL_9_CELLS[:5], # 布局前5个单元格 (1-5) + 6: ALL_9_CELLS[:6], # 布局前6个单元格 (1-6) + 7: ALL_9_CELLS[:7], # 布局前7个单元格 (1-7) + 8: ALL_9_CELLS[:8], # 布局前8个单元格 (1-8) + 9: ALL_9_CELLS[:9] # 布局全部9个单元格 (1-9) + } + + # 3. Load and Filter Images + valid_images = [] + image_paths = [item['image_path'] for item in outfit_items] + for path in image_paths: + try: + # We use Image.open() and convert to 'RGB' to handle potential transparency (RGBA) + # and ensure compatibility with the final 'RGB' canvas and JPG output. + img = Image.open(path).convert('RGB') + valid_images.append(img) + except Exception as e: + print(f"Error loading image {path}. Skipping: {e}") + + num_images = len(valid_images) + + if num_images == 0: + raise ValueError("No valid images were loaded.") + + if num_images > max_len: + raise ValueError(f"Valid item number {num_images} exceed max limit {max_len}") + + + # Get the correct list of target areas based on the number of valid images + target_areas = quadrants.get(num_images, []) + + # 4. Resize and Paste + for i, (img, item) in enumerate(zip(valid_images, outfit_items)): + item_id = item['item_id'] + category = item['category'] + if i >= len(target_areas): + # This should not happen if num_images <= 4 + break + + # Target area dimensions (x_start, y_start, width, height) + x_start, y_start, target_w, target_h = target_areas[i] + + # Calculate new size while maintaining aspect ratio + original_w, original_h = img.size + + # Calculate the ratio needed to fit within the target area + ratio_w = target_w / original_w + ratio_h = target_h / original_h + + # Use the *smaller* of the two ratios to ensure the image fits entirely + resize_ratio = min(ratio_w, ratio_h) + + # Calculate the new dimensions + new_w = int(original_w * resize_ratio) + new_h = int(original_h * resize_ratio) + + # Resize the image. Image.Resampling.LANCZOS provides high-quality scaling. + # Pillow documentation recommends ANTIALIAS or BICUBIC for downscaling, + # but LANCZOS is a good general high-quality filter. + # Note: In Pillow versions > 9.0.0, Image.LANCZOS is now Image.Resampling.LANCZOS + resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + + # Calculate the paste position to center the resized image within its target area + # Center X: (Target Width - New Width) / 2 + X Start + paste_x = (target_w - new_w) // 2 + x_start + # Center Y: (Target Height - New Height) / 2 + Y Start + # paste_y = (target_h - new_h) // 2 + y_start + + TEXT_RESERVE_HEIGHT = 30 + paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start + paste_y = max(paste_y, y_start) + + # Paste the resized image onto the canvas + canvas.paste(resized_img, (paste_x, paste_y)) + + full_text = f"ID: {item_id}, Category: {category}" + try: + # 推荐使用:计算文本的实际尺寸 (width, height) + bbox = draw.textbbox((0, 0), full_text, font=font) + text_w = bbox[2] - bbox[0] + text_h = bbox[3] - bbox[1] + except AttributeError: + # 兼容旧版本 Pillow + text_w, text_h = draw.textsize(full_text, font=font) + + + # 计算 X 轴起始位置:使其在目标区域 (target_w) 中居中 + text_x_center = x_start + target_w // 2 + text_x_start = text_x_center - text_w // 2 + + # 计算 Y 轴起始位置:将其放在目标区域的底部 + # (目标区域的起始Y + 目标区域的高度 - 文本行的高度) + text_y_start = y_start + target_h - text_h - 5 # 减去 5 像素作为边距 + + # 3. 绘制合并后的文本 + if add_text: + draw.text((text_x_start, text_y_start), + full_text, + fill='black', + font=font) + + # Save as a high-quality JPG (quality=90 is a good balance) + canvas.save(output_path, 'JPEG', quality=90) + + return output_path + diff --git a/app/core/vector_database.py b/app/core/vector_database.py new file mode 100644 index 0000000..f5f1a57 --- /dev/null +++ b/app/core/vector_database.py @@ -0,0 +1,63 @@ +import os +from typing import List, Dict, Any, Optional + +import torch +import chromadb +from transformers import CLIPProcessor, CLIPModel +from PIL import Image + + +class VectorDatabase(): + def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str): + print(vector_db_dir) + self.client = chromadb.PersistentClient(path=vector_db_dir) + self.collection = self.client.get_collection(name=collection_name) + + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device) + self.processor = CLIPProcessor.from_pretrained(embedding_model_name) + + def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]: + """生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。""" + + if is_image: + inputs = self.processor(images=data, return_tensors="pt").to(self.device) + with torch.no_grad(): + features = self.model.get_image_features(**inputs) + else: + # 强制截断,解决序列长度问题 + inputs = self.processor( + text=[data], + return_tensors="pt", + padding=True, + truncation=True + ).to(self.device) + with torch.no_grad(): + features = self.model.get_text_features(**inputs) + + # L2 归一化 + features = features / features.norm(p=2, dim=-1, keepdim=True) + + return features.cpu().numpy().flatten().tolist() + + def query_local_db(self, embedding: List[float], category: str, n_results: int = 3) -> List[Dict[str, Any]]: + """ + 基于嵌入向量在本地数据库中查询相似单品。 + 实际应执行 ChromaDB 查询,并根据 category 进行过滤(metadatas)。 + """ + print(f"--- Querying DB for Category: {category} ---") + # 实际应执行向量查询 + # 为了演示流程,返回一个模拟结果 + results = self.collection.query( + query_embeddings=[embedding], + n_results=n_results, + where={ + "$and": [ + {"category": category}, + {"modality": "image"}, + ] + }, + include=['documents', 'metadatas', 'distances'] + ) + return results \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..760d79c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,231 @@ +accelerate==0.21.0 +aiohttp==3.9.5 +aiosignal==1.3.1 +albumentations==0.3.2 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anykeystore==0.2 +asn1crypto==1.5.1 +asttokens==2.4.1 +async-timeout==4.0.3 +attrs==21.2.0 +bidict==0.23.1 +blessed==1.20.0 +boto3==1.34.113 +botocore==1.34.113 +braceexpand==0.1.7 +cachetools==5.3.3 +certifi==2024.2.2 +cffi==1.16.0 +chardet==5.2.0 +charset-normalizer==3.3.2 +click==8.1.7 +clip==0.2.0 +clip-openai==1.0.post20230121 +cmake==3.29.3 +cramjam==2.8.3 +crcmod==1.7 +cryptacular==1.6.2 +cryptography==39.0.2 +cycler==0.12.1 +datasets==2.2.1 +diffusers==0.30.1 +decorator==5.1.1 +decord==0.6.0 +deepspeed==0.14.2 +defusedxml==0.7.1 +Deprecated==1.2.14 +descartes==1.1.0 +dill==0.3.8 +distlib==0.3.8 +distro-info==1.0 +dnspython==2.6.1 +docker-pycreds==0.4.0 +docstring_parser==0.16 +ecdsa==0.19.0 +einops==0.6.0 +exceptiongroup==1.2.1 +executing==2.0.1 +fairscale==0.4.13 +fastparquet==2024.5.0 +ffmpegcv==0.3.13 +filelock==3.14.0 +fire==0.6.0 +fonttools==4.51.0 +frozenlist==1.4.1 +fsspec==2023.6.0 +ftfy==6.2.0 +gitdb==4.0.11 +GitPython==3.1.43 +gpustat==1.1.1 +greenlet==3.0.3 +grpcio==1.64.0 +h11==0.14.0 +hjson==3.1.0 +hupper==1.12.1 +idna==3.7 +imageio==2.34.1 +imgaug==0.2.6 +iniconfig==2.0.0 +ipaddress==1.0.23 +ipdb==0.13.13 +ipython==8.18.1 +jaxtyping==0.2.28 +jedi==0.19.1 +Jinja2==3.1.4 +jmespath==1.0.1 +joblib==1.4.2 +jsonargparse==4.14.1 +jsonlines==4.0.0 +kiwisolver==1.4.5 +kornia==0.7.2 +kornia_rs==0.1.3 +lazy_loader==0.4 +lightning==2.2.3 +lightning-utilities==0.11.2 +lit==18.1.6 +MarkupSafe==2.1.5 +matplotlib==3.5.3 +matplotlib-inline==0.1.7 +miscreant==0.3.0 +mpmath==1.3.0 +msgpack==1.0.8 +multidict==6.0.5 +multiprocess==0.70.16 +natsort==8.4.0 +networkx==3.2.1 +ninja==1.11.1.1 +numpy==1.24.4 +nuscenes-devkit==1.1.11 +oauthlib==3.2.2 +omegaconf==2.3.0 +open-clip-torch==2.24.0 +openai-clip +opencv-python==4.9.0.80 +opencv-python-headless==3.4.18.65 +packaging==22.0 +pandas==1.5.3 +parquet==1.3.1 +parso==0.8.4 +PasteDeploy==3.1.0 +pathlib2==2.3.7.post1 +pathtools==0.1.2 +pbkdf2==1.3 +pexpect==4.9.0 +pillow==10.3.0 +plaster==1.1.2 +plaster-pastedeploy==1.0.1 +platformdirs==4.2.2 +plotly==5.22.0 +pluggy==1.5.0 +ply==3.11 +promise==2.3 +prompt-toolkit==3.0.43 +protobuf==3.20.3 +psutil==5.9.8 +ptyprocess==0.7.0 +pure-eval==0.2.2 +py==1.11.0 +py-cpuinfo==9.0.0 +py-spy==0.3.14 +pyarrow==11.0.0 +pyarrow-hotfix==0.6 +pyasn1==0.6.0 +pycocotools==2.0.7 +pycparser==2.22 +pycryptodomex==3.20.0 +pycurl==7.43.0.6 +ollama==0.4.4 +pydantic==2.9.2 +pydantic_core==2.23.4 +Pygments==2.18.0 +PyJWT==2.8.0 +pynvml==11.5.0 +pyope==0.2.2 +pyOpenSSL==23.2.0 +pyparsing==3.1.2 +pyquaternion==0.9.9 +pyramid==2.0.2 +pyramid-mailer==0.15.1 +pytest==6.2.5 +python-consul==1.1.0 +python-dateutil==2.9.0.post0 +python-engineio==4.9.1 +python-etcd==0.4.5 +python-jose==3.3.0 +python-socketio==5.11.2 +python3-openid==3.2.0 +pytorch-extension==0.2 +pytorch-lightning==2.2.3 +pytz==2024.1 +PyYAML==6.0.1 +regex==2024.5.15 +repoze.sendmail==4.4.1 +requests==2.31.0 +requests-oauthlib==2.0.0 +rsa==4.9 +s3transfer==0.10.1 +safetensors==0.4.3 +schedule==1.2.2 +scikit-image==0.22.0 +scikit-learn==1.5.0 +scipy==1.13.1 +sentencepiece==0.2.0 +sentry-sdk==2.3.1 +setproctitle==1.3.3 +Shapely==1.8.5.post1 +shortuuid==1.0.13 +simple-websocket==1.0.0 +six==1.16.0 +smmap==5.0.1 +SQLAlchemy==2.0.30 +stack-data==0.6.3 +sympy==1.12 +taming-transformers-rom1504==0.0.6 +tenacity==8.3.0 +tensorboardX==2.6.2.2 +termcolor==2.4.0 +threadpoolctl==3.5.0 +thriftpy2==0.5.0 +tifffile==2024.5.22 +timm==1.0.3 +tokenizers==0.19.1 +toml==0.10.2 +tomli==2.0.1 +torch==2.2.1 +torch-fidelity==0.3.0 +torchmetrics==1.4.0.post0 +torchvision==0.17.1 +tox==3.28.0 +tqdm==4.66.4 +traitlets==5.14.3 +transaction==4.0 +transformers==4.41.1 +translationstring==1.4 +triton==2.2.0 +typeguard==2.13.3 +typing_extensions==4.12.0 +tzdata==2024.1 +urllib3==1.26.18 +velruse==1.1.1 +venusian==3.1.0 +virtualenv==20.26.2 +wandb==0.17.2 +watchdog==4.0.1 +wcwidth==0.2.13 +webdataset==0.2.86 +WebOb==1.8.7 +websocket-client==1.8.0 +wrapt==1.16.0 +wsproto==1.2.0 +WTForms==3.1.2 +wtforms-recaptcha==0.3.2 +xformers==0.0.25 +xxhash==3.4.1 +yarl==1.9.4 +zope.deprecation==5.0 +zope.interface==6.4.post2 +zope.sqlalchemy==3.1 +pytorch-fid==0.3.0 +lpips==0.1.4 +huggingface_hub==0.24.6 \ No newline at end of file