first commit
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
.env
|
||||
.vscode/
|
||||
7
README.md
Normal file
7
README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
Checklist
|
||||
1. Verify the vector database path (default: `./data/db`, see `app/core/config.py`)
|
||||
2. Set the `GOOGLE_API_KEY` environment variable (note: `app/core/config.py` requires `GEMINI_API_KEY` — confirm which variable name is actually read)
|
||||
```bash
|
||||
export GOOGLE_API_KEY="<your_API_KEY>"
|
||||
```
|
||||
3. Ensure the project root directory is added to `PYTHONPATH`
|
||||
103
app/api/chat_router.py
Normal file
103
app/api/chat_router.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import Dict, List

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

# 导入 ChatbotAgent 类和配置
# 假设 ChatbotAgent 是单例或在应用启动时创建
from app.core.chatbot_agent import ChatbotAgent
from app.core.config import settings
from app.core.stylist_agent import AsyncStylistAgent
|
||||
|
||||
# --- Pydantic Data Models ---

# Request body carrying a single user chat message.
class ChatRequest(BaseModel):
    """Incoming chat turn: who sent it and what they said."""
    user_id: str = Field(..., description="Unique identifier for the user.")
    user_message: str = Field(..., description="The user's text message.")
|
||||
|
||||
# Request body for starting an outfit recommendation run.
class OutfitStartRequest(BaseModel):
    """Parameters for kicking off the stylist agent."""
    user_id: str = Field(..., description="Unique identifier for the user.")
    stylist_name: str = Field("crystal", description="The name of the stylist guide to use (e.g., 'crystal').")
    # Optional: continue styling from items that were already selected.
    start_outfit: List[Dict[str, str]] = Field(default_factory=list, description="Optional list of items already selected.")
|
||||
|
||||
# Response body carrying the LLM/agent text reply.
class ChatResponse(BaseModel):
    """Plain-text chatbot reply."""
    response_text: str = Field(..., description="The chatbot's text response.")
|
||||
|
||||
# Response body for a styling run. (The StylistAgent's output is JSON; this is
# simplified here — a real project would return fully structured JSON.)
class OutfitResponse(BaseModel):
    """Summary used for styling plus the agent's recommendation payload."""
    summary: str = Field(..., description="The conversation summary used for styling.")
    next_item_json: Dict = Field(..., description="The next recommended item in JSON format.")
|
||||
|
||||
# --- Router Setup ---

router = APIRouter(
    prefix="/chat",
    tags=["Chatbot & Styling"]
)

# Instantiate ChatbotAgent at application startup so every route shares the
# same instance.
# NOTE: a real FastAPI app would normally manage this via dependency
# injection (Depends) rather than a module-level global.
try:
    # ChatbotAgent.__init__ is assumed to initialize all of its dependencies
    # (LLM client, Redis, vector DB).
    global_agent = ChatbotAgent()
except Exception as e:
    # If a dependency (e.g. Redis or the vector DB) fails to initialize,
    # log and fall back to None; routes then respond with HTTP 500.
    print(f"FATAL: ChatbotAgent failed to initialize. Error: {e}")
    global_agent = None  # keep None so routes raise 500
|
||||
|
||||
@router.post("/", response_model=ChatResponse, summary="Process user message and get chatbot response")
async def handle_chat_message(request: ChatRequest):
    """
    Handle one user chat message.

    Appends the message to the user's history and calls the LLM to generate
    a reply.

    Args:
        request: The user id and message text.

    Returns:
        ChatResponse: the chatbot's text reply.

    Raises:
        HTTPException: 500 if the agent is unavailable or processing fails.
    """
    if not global_agent:
        raise HTTPException(status_code=500, detail="Chatbot agent not initialized.")

    try:
        # BUG FIX: ChatbotAgent.process_query is a coroutine. The original
        # sync handler called it without awaiting, so `response` was a
        # coroutine object, not text, and ChatResponse validation would fail.
        response = await global_agent.process_query(request.user_id, request.user_message)
        return ChatResponse(response_text=response)
    except Exception as e:
        print(f"Error processing chat query: {e}")
        # Return a generic error message rather than leaking internals.
        raise HTTPException(status_code=500, detail="Internal server error while processing message.")
|
||||
|
||||
@router.post("/outfit/start", summary="Start outfit recommendation based on conversation history")
async def start_outfit_recommendation(request: OutfitStartRequest):
    """
    Summarize the user's conversation history and run the stylist agent once.

    Args:
        request: User id, stylist guide name, and optional pre-selected items.

    Returns:
        OutfitResponse: the conversation summary plus a dict holding the path
        of the generated outfit image.

    Raises:
        HTTPException: 500 if the agent is unavailable or styling fails.
    """
    if not global_agent:
        raise HTTPException(status_code=500, detail="Chatbot agent not initialized.")

    try:
        # 1. Build the conversation summary.
        # BUG FIX: get_conversation_summary is a coroutine; the original sync
        # handler never awaited it and would have passed a coroutine object on.
        request_summary = await global_agent.get_conversation_summary(request.user_id)

        # 2. Run the stylist agent.
        # BUG FIX: ChatbotAgent has no `stylist_agent` attribute (the original
        # referenced one); build a fresh AsyncStylistAgent from the agent's
        # prepared kwargs (each instance owns its own output paths), and await
        # the async styling process.
        stylist = AsyncStylistAgent(**global_agent.stylist_agent_kwages)
        # run_styling_process returns the path of the merged outfit image.
        outfit_image_path = await stylist.run_styling_process(
            request_summary,
            request.stylist_name,
            request.start_outfit
        )

        return OutfitResponse(
            summary=request_summary,
            # Wrapped in a dict because OutfitResponse.next_item_json is a Dict
            # while run_styling_process returns a file path string.
            next_item_json={"outfit_image_path": outfit_image_path}
        )

    except Exception as e:
        print(f"Error starting outfit recommendation: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to start styling process: {e}")
|
||||
|
||||
# (Optional) continuation endpoint: will process user feedback on a
# recommended item and ask the stylist agent for the next one.
@router.post("/outfit/continue", summary="Continue outfit recommendation with user feedback")
def continue_outfit_recommendation():
    """
    Placeholder endpoint.

    Intended to accept the user's feedback on an already-recommended item
    (e.g. its image) and have the Stylist Agent recommend the next item.
    """
    raise HTTPException(
        status_code=501,
        detail="Endpoint for outfit continuation is not yet implemented.",
    )
|
||||
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
BIN
app/core/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/config.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/config.cpython-310.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/data_structure.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/data_structure.cpython-310.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/llm_interface.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/llm_interface.cpython-310.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/redis_manager.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/redis_manager.cpython-310.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/stylist_agent.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/stylist_agent.cpython-310.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/system_prompt.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/system_prompt.cpython-310.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/utils.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
app/core/__pycache__/vector_database.cpython-310.pyc
Normal file
BIN
app/core/__pycache__/vector_database.cpython-310.pyc
Normal file
Binary file not shown.
135
app/core/chatbot_agent.py
Normal file
135
app/core/chatbot_agent.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import asyncio
from typing import Dict, List, Optional

from app.core.config import settings
from app.core.data_structure import Message, Role
from app.core.llm_interface import AsyncLLMInterface, AsyncGeminiLLM
from app.core.redis_manager import RedisManager
from app.core.stylist_agent import AsyncStylistAgent
from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT
from app.core.vector_database import VectorDatabase
|
||||
|
||||
class ChatbotAgent:
    """Conversational agent tying together the LLM, Redis history, and styling.

    Owns the async LLM client, the Redis-backed conversation history, and the
    vector database plus kwargs used to spin up AsyncStylistAgent instances.
    """

    def __init__(self, llm_model: Optional[AsyncLLMInterface] = None):
        """
        Args:
            llm_model: Optional pre-built LLM client; defaults to an
                AsyncGeminiLLM using the configured model name.
        """
        self.llm = llm_model if llm_model else AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME)
        self.redis = RedisManager(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB,
            key_prefix=settings.REDIS_HISTORY_KEY_PREFIX
        )
        self.vector_db = VectorDatabase(
            vector_db_dir=settings.VECTOR_DB_DIR,
            collection_name=settings.COLLECTION_NAME,
            embedding_model_name=settings.EMBEDDING_MODEL_NAME
        )
        # Keyword arguments used to construct AsyncStylistAgent instances on
        # demand (one fresh agent per outfit generation task).
        # NOTE: the attribute name keeps the historical misspelling ("kwages")
        # because other modules may reference it; rename file-wide later.
        self.stylist_agent_kwages = {
            'local_db': self.vector_db,
            'max_len': 5,
            'outfits_root': settings.OUTFIT_OUTPUT_DIR,
            'image_dir': settings.IMAGE_DIR,
            'stylist_guide_dir': settings.STYLIST_GUIDE_DIR,
            'gemini_model_name': settings.LLM_MODEL_NAME
        }

    async def process_query(self, user_id: str, user_message: str) -> str:
        """
        Process the user's latest message, call the LLM, and update history.

        Args:
            user_id: Unique user identifier.
            user_message: The user's text message.

        Returns:
            The assistant's reply text (a fallback apology on LLM failure).
        """
        # Append the user message to an in-memory copy of the history.
        user_msg = Message(role=Role.USER, content=user_message)
        chat_history = self.redis.get_history(user_id)
        chat_history.append(user_msg)

        # Generate the LLM reply; degrade gracefully if the call fails.
        try:
            response_text = await self.llm.generate_response(chat_history, system_prompt=BASIC_PROMPT)
        except Exception as e:
            print(f"LLM 调用失败: {e}")
            response_text = "抱歉,系统暂时无法响应,请稍后再试。"

        # Persist both sides of the exchange to Redis.
        if response_text:
            assistant_msg = Message(role=Role.ASSISTANT, content=response_text)
        else:
            assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")

        self.redis.save_message(user_id, user_msg)
        self.redis.save_message(user_id, assistant_msg)

        return response_text

    async def get_conversation_summary(self, user_id: str) -> str:
        """
        Condense the user's full conversation history into a short summary.

        The summary is designed to be fed directly to the Stylist Agent as its
        request prompt.

        Args:
            user_id: Unique user identifier.

        Returns:
            The LLM-generated summary string.
        """
        history_messages = self.redis.get_history(user_id)
        input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])

        # Single LLM call with the dedicated summary system prompt.
        summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)], system_prompt=SUMMARY_PROMPT)

        return summary

    async def recommend_outfit(self, user_id: str, stylist_name: str, start_outfit: Optional[List[Dict[str, str]]] = None, num_outfits: int = 1):
        """
        Recommend one or more outfits based on the user's conversation history.

        Args:
            user_id: Unique user identifier.
            stylist_name: Name of the stylist guide to use.
            start_outfit: Optional list of already-selected items, each with
                'item_id' and 'category'. Defaults to an empty list.
            num_outfits: Number of outfits to generate concurrently.

        Returns:
            Dict with 'successful_outfits' (image paths) and 'failed_outfits'
            (error strings), or {'error': ...} on an unexpected failure.
        """
        # BUG FIX: was a mutable default argument (start_outfit=[]), shared
        # across calls; use None and create a fresh list instead.
        if start_outfit is None:
            start_outfit = []

        request_summary = await self.get_conversation_summary(user_id)
        print(f"Conversation Summary:\n{request_summary}")

        # One independent stylist agent (with its own output paths) per outfit.
        tasks = []
        for _ in range(num_outfits):
            agent = AsyncStylistAgent(**self.stylist_agent_kwages)
            # BUG FIX: pass a copy — run_styling_process stores and mutates
            # the list, so concurrent tasks must not share one instance.
            task = agent.run_styling_process(request_summary, stylist_name, list(start_outfit))
            tasks.append(task)
        print(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")

        try:
            results = await asyncio.gather(*tasks, return_exceptions=True)

            successful_outfits = []
            failed_outfits = []
            for result in results:
                if isinstance(result, Exception):
                    # The task raised during execution.
                    failed_outfits.append(f"Failed: {result}")
                else:
                    # Success: result is the image path from run_styling_process.
                    successful_outfits.append(result)

            return {
                "successful_outfits": successful_outfits,
                "failed_outfits": failed_outfits
            }

        except Exception as e:
            print(f"An unexpected error occurred during concurrent recommendation: {e}")
            return {"error": str(e)}
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    async def test():
        """Ad-hoc smoke test: one chat turn, then two concurrent outfits."""
        agent = ChatbotAgent()
        user_id = "user123"
        agent.redis.clear_history(user_id)  # clear history so the test is repeatable
        print(await agent.process_query(user_id, "I need a dress for a summer wedding. I prefer something floral and light."))
        # print(agent.process_query(user_id, "I prefer something floral and light."))
        recommendation_results = await agent.recommend_outfit(user_id, stylist_name="crystal", start_outfit=[], num_outfits=2)

        print("\n--- Final Recommendation Results ---")
        for i, path in enumerate(recommendation_results.get("successful_outfits", [])):
            print(f"✅ Outfit {i+1} saved to: {path}")

        for error in recommendation_results.get("failed_outfits", []):
            print(f"❌ {error}")

    asyncio.run(test())
|
||||
38
app/core/config.py
Normal file
38
app/core/config.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
# ⚠️ NOTE: requires pydantic-settings: pip install pydantic-settings

class Settings(BaseSettings):
    """
    Application configuration. pydantic-settings automatically loads these
    values from environment variables and the .env file.
    """
    model_config = SettingsConfigDict(
        env_file='.env',
        env_file_encoding='utf-8',
        extra='ignore'  # ignore extra keys present in the environment
    )

    # Redis connection settings
    REDIS_HOST: str = Field(default='localhost', description="Redis服务器地址")
    REDIS_PORT: int = Field(default=6379, description="Redis服务器端口")
    REDIS_DB: int = Field(default=0, description="Redis数据库编号")
    REDIS_HISTORY_KEY_PREFIX: str = Field(default="chat:history:", description="Redis会话历史键的前缀")

    # LLM settings
    # NOTE(review): the README tells users to export GOOGLE_API_KEY, but this
    # required field reads GEMINI_API_KEY — confirm which variable name the
    # deployment and SDK actually expect.
    GEMINI_API_KEY: str = Field(..., description="Google Gemini API 密钥。必须设置。")
    LLM_MODEL_NAME: str = Field(default="gemini-2.5-flash", description="使用的 LLM 模型名称")

    # Filesystem path settings
    DATA_ROOT: str = Field(default="./data", description="数据根目录")
    IMAGE_DIR: str = Field(default="./data/image_data", description="图片数据目录")
    OUTFIT_OUTPUT_DIR: str = Field(default="./data/outfit_output", description="生成的搭配图片输出目录")
    STYLIST_GUIDE_DIR: str = Field(default="./data/stylist_guide", description="风格指南文本目录")

    # Vector database settings
    VECTOR_DB_DIR: str = Field(default="./data/db", description="向量数据库目录")
    COLLECTION_NAME: str = Field(default="lc_clothing_embedding", description="向量数据库集合名称")
    EMBEDDING_MODEL_NAME: str = Field(default="openai/clip-vit-base-patch32", description="CLIP嵌入模型名称")

# Shared configuration instance used by the rest of the application.
settings = Settings()
|
||||
16
app/core/data_structure.py
Normal file
16
app/core/data_structure.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import List, Dict, Any
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
import datetime
|
||||
|
||||
# Role enum distinguishing user, assistant, and system messages.
class Role(str, Enum):
    """Sender role of a chat message; str-valued for easy JSON serialization."""
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
|
||||
|
||||
# Data model for a single chat message.
class Message(BaseModel):
    """One chat message; serialized to/from JSON for Redis storage."""
    role: Role = Field(..., description="Role of message sender")
    content: str = Field(..., description="Content of the message")
    # timestamp: str = Field(default_factory=lambda: datetime.datetime.now().isoformat())  # record a timestamp
|
||||
55
app/core/llm_interface.py
Normal file
55
app/core/llm_interface.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from app.core.data_structure import Message, Role
|
||||
|
||||
class AsyncLLMInterface(ABC):
    """Abstract async interface for chat-completion LLM backends."""

    @abstractmethod
    async def generate_response(self, history: List[Message], system_prompt: str) -> str:
        """
        Generate a reply from the conversation history and a system instruction.

        Args:
            history: List of Message objects (the conversation so far).
            system_prompt: System instruction passed to the model.

        Returns:
            The reply string produced by the LLM.
        """
        raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
class AsyncGeminiLLM(AsyncLLMInterface):
    """Gemini-backed implementation of AsyncLLMInterface using google-genai."""

    def __init__(self, model_name: str = "gemini-2.5-flash"):
        """
        Args:
            model_name: Gemini model identifier.

        Raises:
            RuntimeError: if the Gemini client cannot be created (e.g. the
                GEMINI_API_KEY environment variable is missing).
        """
        self.model_name = model_name
        try:
            self.gemini_client = genai.Client()
        except Exception as e:
            # BUG FIX: the original used `raise type(e)(msg)`, which itself
            # raises TypeError for exception classes whose constructors need
            # other arguments, and it discarded the cause chain. Raise a
            # RuntimeError chained to the original instead.
            raise RuntimeError(f"Failed to initialize Gemini Client. Check if GEMINI_API_KEY is set. Original error: {e}") from e

    async def generate_response(self, history: List[Message], system_prompt: str) -> str:
        """
        Convert the history to Gemini `Content` objects and call the model.

        Args:
            history: Conversation so far; USER maps to Gemini role "user",
                every other role maps to "model".
            system_prompt: System instruction for this request.

        Returns:
            The model's text reply.

        Raises:
            RuntimeError: if the API call fails (chained to the original error).
        """
        contents = []

        for msg in history:
            # Gemini only understands the "user" and "model" roles.
            gemini_role = "user" if msg.role == Role.USER else "model"
            content = types.Content(
                role=gemini_role,
                parts=[types.Part.from_text(text=msg.content)]
            )
            contents.append(content)

        try:
            response = await self.gemini_client.aio.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=types.GenerateContentConfig(
                    system_instruction=system_prompt,
                    # temperature=0.3,
                )
            )
            return response.text
        except Exception as e:
            # Same fix as __init__: stable exception type, preserved cause.
            raise RuntimeError(f"Gemini API call failed: {e}") from e
|
||||
|
||||
|
||||
63
app/core/redis_manager.py
Normal file
63
app/core/redis_manager.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import redis
|
||||
from typing import List, Optional
|
||||
from app.core.data_structure import Message, Role
|
||||
|
||||
# Synchronous Redis client, used here to demonstrate replacing in-memory
# storage. In production with an async web framework, swap in an async
# client such as aioredis.
class RedisManager:
    """Synchronous manager that stores and retrieves chat history in Redis."""

    def __init__(self, host: str = 'localhost', port: int = 6379, db: int = 0, key_prefix: str = "chat:history:"):
        """Connect to Redis; on failure, degrade to a no-op manager."""
        self.r: Optional[redis.Redis] = None
        self.key_prefix = key_prefix
        try:
            # Open the connection and confirm it with a ping.
            self.r = redis.Redis(host=host, port=port, db=db, decode_responses=True)
            self.r.ping()
            print("Successfully connected to Redis at {}:{}".format(host, port))
        except Exception as e:
            print(f"⚠️ Failed to connect to Redis: {e}. Falling back to No-Op.")
            self.r = None  # left unset so later operations become no-ops

    def _get_key(self, user_id: str) -> str:
        """Build the Redis key that holds this user's history list."""
        return f"{self.key_prefix}{user_id}"

    def _message_to_json(self, message: Message) -> str:
        """Serialize a Message into a JSON string for storage."""
        return message.model_dump_json()

    def _json_to_message(self, data: str) -> Message:
        """Deserialize a JSON string back into a Message object."""
        try:
            return Message.model_validate_json(data)
        except Exception as e:
            print(f"Error deserializing message data: {data[:50]}... Error: {e}")
            return Message(role=Role.ASSISTANT, content="[Deserialization Error]")

    def save_message(self, user_id: str, message: Message):
        """Append a single message to the end of the user's history list."""
        if not self.r:
            return
        # RPUSH appends to the tail of the Redis list.
        self.r.rpush(self._get_key(user_id), self._message_to_json(message))

    def get_history(self, user_id: str) -> List[Message]:
        """Fetch the user's complete conversation history."""
        if not self.r:
            return []
        # LRANGE 0..-1 returns every element of the list; convert each JSON
        # string back into a Message object.
        raw_entries = self.r.lrange(self._get_key(user_id), 0, -1)
        return [self._json_to_message(entry) for entry in raw_entries]

    def clear_history(self, user_id: str):
        """Delete the user's entire history."""
        if self.r:
            self.r.delete(self._get_key(user_id))
            print(f"History cleared for {user_id}")
|
||||
300
app/core/stylist_agent.py
Normal file
300
app/core/stylist_agent.py
Normal file
@@ -0,0 +1,300 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from google import genai
|
||||
|
||||
from app.core.utils import merge_images_to_square
|
||||
|
||||
|
||||
class AsyncStylistAgent:
    """Async agent that iteratively builds one outfit using Gemini + a local vector DB.

    Each instance owns a single outfit attempt: a unique outfit id and its own
    output image/JSON paths. `run_styling_process` drives the
    recommend -> retrieve loop until the model stops or a limit is reached.
    """

    # Closed set of item categories the model is allowed to recommend.
    CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'}

    def __init__(self, local_db, max_len: int, outfits_root: str, image_dir: str, stylist_guide_dir: str, gemini_model_name: str):
        """
        Args:
            local_db: Vector database exposing `get_clip_embedding` and
                `query_local_db` (presumably VectorDatabase — confirm).
            max_len: Maximum number of items allowed in the outfit.
            outfits_root: Directory for the merged image / JSON outputs.
            stylist_guide_dir: Directory holding `<stylist>_en.md` style guides.
            image_dir: Directory containing per-item images.
            gemini_model_name: Gemini model identifier.
        """
        # self.outfit_items: List[Dict[str, str]] = []
        self.outfit_id = str(uuid.uuid4())
        self.gemini_client = genai.Client()
        self.local_db = local_db
        self.max_len = max_len
        # Per-instance output paths keyed by the random outfit id.
        self.output_outfit_path = os.path.join(outfits_root, f"{self.outfit_id}.jpg")
        self.output_json_path = os.path.join(outfits_root, f"{self.outfit_id}_items.json")
        self.image_dir = image_dir
        self.stylist_guide_dir = stylist_guide_dir
        self.gemini_model_name = gemini_model_name
        self.stop_reason = ""

    def _load_style_guide(self, path: str) -> str:
        """Load the markdown style-guide content from disk.

        Raises:
            FileNotFoundError: wrapping any read error.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            raise FileNotFoundError(f"Failed to load style guide from {path}: {e}")

    def _build_system_prompt(self, request_summary: str = "") -> str:
        """Constructs the complete System Prompt."""
        # Insert the style_guide content into the template
        template = f"""
You are a professional fashion stylist Agent, specialized in creating complete outfits for the user.

Your task is to **create a cohesive and complete outfit**, strictly adhering to **BOTH** the user's explicit **Request Summary** and the **Outfit Style Guide**. You must decide the next logical item to add to the outfit based on the currently selected items (if any).

---

## Request from the User:

{request_summary}

## Core Guidance Document: Outfit Style Guide

{self.style_guide}

---

## Your Workflow and Constraints

1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**.
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories.
3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows:

```json
{{
    "action": "recommend_item",
    "category": "YOUR_ITEM_CATEGORY",
    "description": "YOUR_DETAILED_DESCRIPTION"
}}
```

* `action`: Must always be `"recommend_item"` until the outfit is complete.
* `category`: Must be the category of the item you are recommending, strictly selected from the following list: {list(self.CATEGORY_SET)}.
* `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include:
    * **Color** (e.g., milk tea, pure white, dark gray)
    * **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
    * **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern)
    * **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look)
    * **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.**

4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process:

```json
{{
    "action": "stop",
    "reason": "OUTFIT_COMPLETE_AND_MEETS_ALL_MINI_GUIDELINES"
}}
```
Normally, five or six items are totally enough for an outfit.

5. **Context Dependency**: The user's next input (if not `Start`) will contain the **image and description of the selected item**. When recommending the next item, you must consider the coordination between the **already selected items** and the Style Guide.

**Now, please start building an outfit and output the JSON for the first item.**
"""
        return template.strip()

    def _clear_uploaded_files(self):
        """Delete every file previously uploaded via the Gemini Files API.

        NOTE(review): this deletes ALL uploaded files for the API key, not
        just this agent's. With several agents running concurrently this can
        remove a sibling's in-flight upload — confirm this is intended.
        """
        for f in self.gemini_client.files.list():
            self.gemini_client.files.delete(name=f.name)

    async def _call_gemini(self, user_input: str) -> str:
        """
        Perform the actual Gemini API call with text plus the merged outfit image.

        Args:
            user_input: Main text content sent to the model.

        Returns:
            The model's response text (expected to be a JSON string).
        """
        content_parts = []
        self._clear_uploaded_files()
        # 1. Attach the merged image of items selected so far (if any).
        if self.outfit_items:
            merged_image_path = merge_images_to_square(self.outfit_items, max_len=self.max_len, output_path=self.output_outfit_path)
            try:
                myfile = await self.gemini_client.aio.files.upload(file=merged_image_path)
                content_parts.append(myfile)
            except Exception as e:
                print(f"Error loading image {merged_image_path}: {e}")

        # 2. Add the text content.
        content_parts.append(user_input)

        print(f"\n--- Calling Gemini with {len(self.outfit_items) if self.outfit_items else 0} images and query:\n{user_input}")

        try:
            # 3. Actual API call; the response schema constrains output to JSON.
            response = await self.gemini_client.aio.models.generate_content(
                model=self.gemini_model_name,
                contents=content_parts,
                config={
                    "system_instruction": self.system_prompt,
                    # Force the model to return JSON.
                    "response_mime_type": "application/json",
                    "response_schema": {
                        "type": "object",
                        "properties": {
                            "action": {"type": "string", "enum": ["recommend_item", "stop"]},
                            "category": {"type": "string"},
                            "description": {"type": "string"},
                            "reason": {"type": "string"}
                        },
                        "required": ["action"]
                    }
                }
            )

            # response.text contains a JSON string.
            return response.text

        except Exception as e:
            print(f"Gemini API Call failed: {e}")
            # Return a stop signal so the control loop terminates cleanly.
            return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"})

    def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
        """Safely parse Gemini's JSON response; return None on failure."""
        try:
            # Gemini sometimes wraps the JSON in markdown fences; strip them.
            response_text = response_text.strip().replace('```json', '').replace('```', '')
            data = json.loads(response_text)
            print(f"The agent response is: {data}")
            return data
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON from Gemini: {e}")
            print(f"Raw response: {response_text}")
            return None

    def _get_next_item(self, item_description: str, category: str) -> Optional[Dict[str, str]]:
        """
        Retrieve the best-matching item from the local vector database.

        1. Build an embedding from the description.
        2. Query the local DB for the best match within the category.
        3. Simulated agent review of the match (currently always passes).
        """
        try:
            # 1. Build the query embedding from the text description.
            query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)

            # 2. Run the query, filtered by category.
            results = self.local_db.query_local_db(query_embedding, category, n_results=1)

            if not results:
                print(f"❌ 数据库中未找到符合 '{category}' 和描述的单品。")
                return None

            # 3. Simulated review (a real app would send the image back to the
            # agent for confirmation).
            best_meta = results['metadatas'][0][0]  # first metadata of the first batch
            return {
                "item_id": best_meta['item_id'],  # read from the metadata dict
                "category": category,
                "gpt_description": item_description,
                'description': best_meta['description'],
                # Assumes 'item_id' is the image file stem under image_dir —
                # TODO confirm against the ingestion pipeline.
                "image_path": os.path.join(self.image_dir, f"{best_meta['item_id']}.jpg")
            }

        except Exception as e:
            print(f"An error occurred during item retrieval: {e}")
            return None

    def _build_user_input(self) -> str:
        """Build the user turn sent to Gemini, listing already-selected items."""
        if not self.outfit_items:
            return "Start"

        # Feed the already-selected items back to the agent as context.
        context = "Selected fashion items:\n"
        for ii, item in enumerate(self.outfit_items):
            context += f"{ii + 1}. Category: {item['category']}. Description: {item['description']}\n"
        context += "\nPlease recommend the next single item based on the selected items, user's request, and style guide."
        return context

    async def run_styling_process(self, request_summary, stylist_name, start_outfit=[]):
        # Main control loop: recommend -> retrieve -> append, until stop.
        # NOTE(review): mutable default argument (start_outfit=[]) — harmless
        # here only because it is never mutated through the default; still an
        # anti-pattern. The string statement below was presumably meant to be
        # the docstring but sits after the first statement, so it is inert.
        self.outfit_items = start_outfit if start_outfit else []
        """主流程控制循环。"""
        print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")

        self.style_guide = self._load_style_guide(os.path.join(self.stylist_guide_dir, f"{stylist_name}_en.md"))
        self.system_prompt = self._build_system_prompt(request_summary)

        while True:
            # 1. Prepare the user input (context of selected items).
            user_input = self._build_user_input()

            # 2. Call the Gemini agent and parse its JSON reply.
            gemini_response_text = await self._call_gemini(user_input)
            gemini_data = self._parse_gemini_response(gemini_response_text)

            if not gemini_data:
                print("🚨 Agent 返回无效响应,终止流程。")
                self.stop_reason = "Agent failed to return response"
                break

            # 3. Check the termination condition.
            if gemini_data.get('action') == 'stop':
                print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
                self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
                break

            # 4. Handle a recommended item.
            if gemini_data.get('action') == 'recommend_item':
                category = gemini_data.get('category')
                description = gemini_data.get('description')

                # 4a. Validate the category (important guard).
                if category not in self.CATEGORY_SET:
                    print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。")
                    # A real app would feed the error back to the agent for
                    # correction; simplified here to skipping this iteration.
                    continue

                # 4b. Look the item up in the local DB.
                new_item = self._get_next_item(description, category)

                if new_item:
                    # 4c. (Real flow) the selected item's image and description
                    # would be sent back to the agent for a final review; that
                    # round trip is omitted and the match is accepted directly.
                    # e.g. new_user_input = f"Check this item: {new_item['description']}, path: {new_item['image_path']}"
                    # call_gemini_agent(...) -> append to outfit_items only on "pass"

                    # Stop if retrieval returned an item we already selected.
                    if new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
                        print("This item exists. Stop here.")
                        self.stop_reason = "Finish reason: Duplicate item selected."
                        break

                    # HACK: item ELG383 is apparently over-selected by
                    # retrieval; randomly stop 70% of the time when it comes
                    # up — TODO fix the underlying retrieval bias instead.
                    if new_item['item_id'] == "ELG383":
                        if random.random() < 0.70:
                            self.stop_reason = "Finish reason: ELG383 is seleced repeatly."
                            break

                    self.outfit_items.append(new_item)
                    print(f"➕ 成功添加单品: {new_item['category']} ({new_item['item_id']}). 当前搭配数量: {len(self.outfit_items)}")

                else:
                    print("⚠️ 未找到匹配单品,无法继续搭配。终止。")
                    self.stop_reason = "Finish reason: No matching item found in local database."
                    break

            if len(self.outfit_items) >= self.max_len:  # hard cap to prevent an infinite loop
                print("🚨 达到最大搭配数量限制,强制终止。")
                self.stop_reason = "Finish reason: Reached max outfit length."
                break

        # 5. Persist the results once the loop ends.
        self._save_outfit_results()
        return self.output_outfit_path

    def _save_outfit_results(self):
        """Save the final item list (JSON) and merged image to the output folder."""
        if not self.outfit_items:
            raise ValueError("No outfit items to save.")

        # 1. Write the JSON file: one entry per item, plus the stop reason last.
        results_list = [{'item_id': item['item_id'], 'category': item['category'], 'description': item['description'], 'gpt_description': item['gpt_description']} for item in self.outfit_items]
        results_list.append({'stop_reason': self.stop_reason})
        with open(self.output_json_path, 'w', encoding='utf-8') as f:
            json.dump(results_list, f, ensure_ascii=False, indent=4)

        # 2. Write the merged outfit image (without text labels).
        merge_images_to_square(self.outfit_items, max_len=self.max_len, output_path=self.output_outfit_path, add_text=False)
|
||||
3
app/core/system_prompt.py
Normal file
3
app/core/system_prompt.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# System prompt for ordinary chat turns (used by ChatbotAgent.process_query).
BASIC_PROMPT = """You are a fashion stylist AI named StylistGPT. Your task is to assist users in creating personalized fashion looks based on their preferences, occasions, and current fashion trends."""

# System prompt for condensing a conversation into a one-sentence style
# profile (used by ChatbotAgent.get_conversation_summary).
SUMMARY_PROMPT = """Given conversation history, summarize the user's fashion preferences, occasions, and any specific requirements they have mentioned. Provide a concise overview of their style profile in one sentence."""
|
||||
163
app/core/utils.py
Normal file
163
app/core/utils.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from typing import List, Dict
|
||||
import shutil
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
# Nine ~341x341 grid cells tiling a 1024x1024 canvas (ALL_9_CELLS).
# Layout order: top-to-bottom, left-to-right (1 -> 9).
# Each tuple is (x_start, y_start, width, height); the right column and
# bottom row are 1px larger (342) so the cells sum exactly to 1024.
ALL_9_CELLS = [
    # Top Row (Y=0, H=341)
    (0, 0, 341, 341),    # 1. Top-Left (341x341)
    (341, 0, 341, 341),  # 2. Top-Middle (341x341)
    (682, 0, 342, 341),  # 3. Top-Right (342x341)
    # Middle Row (Y=341, H=341)
    (0, 341, 341, 341),  # 4. Mid-Left (341x341)
    (341, 341, 341, 341),# 5. Center (341x341)
    (682, 341, 342, 341),# 6. Mid-Right (342x341)
    # Bottom Row (Y=682, H=342)
    (0, 682, 341, 342),  # 7. Bottom-Left (341x342)
    (341, 682, 341, 342),# 8. Bottom-Middle (341x342)
    (682, 682, 342, 342) # 9. Bottom-Right (342x342)
]
|
||||
|
||||
def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, output_path="temp.jpg", add_text=True) -> str:
    """
    Load up to ``max_len`` (at most 9) item images, resize each while
    maintaining aspect ratio, and merge them onto a 1024x1024 white JPG.

    The layout depends on the number of successfully loaded images:
        1: the single image fills the whole canvas.
        2: side-by-side halves (512x1024 each).
        3: top-left, top-right and bottom-left 512x512 quadrants.
        4: all four 512x512 quadrants.
        5-9: the first N cells of the 3x3 grid (ALL_9_CELLS).

    Args:
        outfit_items: Item metadata dicts; each must provide 'image_path',
            'item_id' and 'category'.
        max_len: Maximum number of images allowed in one collage (<= 9).
        output_path: Destination path for the merged JPG.
        add_text: If True, draw an "ID / Category" caption under each image.

    Returns:
        The file path of the merged JPG image.

    Raises:
        ValueError: If no image could be loaded, or more than ``max_len``
            images were loaded.
    """
    # Final canvas size; 'RGB' mode is required for JPEG output.
    CANVAS_SIZE = 1024
    canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white')
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default()

    # Target areas (x, y, w, h) for each possible image count.
    quadrants = {
        1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)],                       # single full-size placement
        2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)],   # left, right
        3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)],
        4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)],
        5: ALL_9_CELLS[:5],
        6: ALL_9_CELLS[:6],
        7: ALL_9_CELLS[:7],
        8: ALL_9_CELLS[:8],
        9: ALL_9_CELLS[:9],
    }

    # Load images, keeping each image paired with its own item metadata.
    # (Pairing at load time fixes a bug: the original zipped the filtered
    # image list against the full item list, so one failed load shifted
    # every later caption onto the wrong image.)
    valid_pairs = []
    for item in outfit_items:
        path = item['image_path']
        try:
            # Convert to 'RGB' to flatten any alpha channel so the image is
            # compatible with the RGB canvas and JPEG output.
            img = Image.open(path).convert('RGB')
            valid_pairs.append((img, item))
        except Exception as e:
            print(f"Error loading image {path}. Skipping: {e}")

    num_images = len(valid_pairs)
    if num_images == 0:
        raise ValueError("No valid images were loaded.")
    if num_images > max_len:
        # Original message contained a spurious hard-coded "32,970" limit.
        raise ValueError(f"Valid item number {num_images} exceeds max limit {max_len}")

    target_areas = quadrants.get(num_images, [])

    # Resize and paste each image into its cell.
    for (img, item), (x_start, y_start, target_w, target_h) in zip(valid_pairs, target_areas):
        item_id = item['item_id']
        category = item['category']

        # Scale to fit entirely inside the target area, preserving aspect ratio.
        original_w, original_h = img.size
        resize_ratio = min(target_w / original_w, target_h / original_h)
        new_w = int(original_w * resize_ratio)
        new_h = int(original_h * resize_ratio)
        # LANCZOS gives high-quality downscaling (Image.Resampling.* on Pillow >= 9.1).
        resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

        # Center horizontally; vertically, reserve room below for the caption,
        # never pasting above the cell's top edge.
        paste_x = (target_w - new_w) // 2 + x_start
        TEXT_RESERVE_HEIGHT = 30
        paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start
        paste_y = max(paste_y, y_start)
        canvas.paste(resized_img, (paste_x, paste_y))

        full_text = f"ID: {item_id}, Category: {category}"
        try:
            # Measure the rendered text (Pillow >= 8 provides textbbox).
            bbox = draw.textbbox((0, 0), full_text, font=font)
            text_w = bbox[2] - bbox[0]
            text_h = bbox[3] - bbox[1]
        except AttributeError:
            # Fallback for older Pillow versions.
            text_w, text_h = draw.textsize(full_text, font=font)

        # Caption centered horizontally, anchored to the cell bottom (5px margin).
        text_x_start = x_start + target_w // 2 - text_w // 2
        text_y_start = y_start + target_h - text_h - 5
        if add_text:
            draw.text((text_x_start, text_y_start),
                      full_text,
                      fill='black',
                      font=font)

    # Save as a high-quality JPG (quality=90 is a good balance).
    canvas.save(output_path, 'JPEG', quality=90)
    return output_path
|
||||
|
||||
63
app/core/vector_database.py
Normal file
63
app/core/vector_database.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import torch
|
||||
import chromadb
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class VectorDatabase:
    """Persistent ChromaDB collection plus a CLIP model for embedding
    images/text and retrieving similar fashion items."""

    def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str):
        """
        Args:
            vector_db_dir: Directory holding the persistent ChromaDB store.
            collection_name: Name of an existing collection; get_collection
                raises if it is missing, so the DB must be pre-built.
            embedding_model_name: HuggingFace id of the CLIP checkpoint.
        """
        # (Removed a leftover debug print of vector_db_dir.)
        self.client = chromadb.PersistentClient(path=vector_db_dir)
        self.collection = self.client.get_collection(name=collection_name)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(embedding_model_name)

    def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
        """Generate a CLIP embedding for an image or a text string.

        Args:
            data: A PIL image (when ``is_image``) or a text string.
            is_image: Selects the image vs. text encoder.

        Returns:
            The L2-normalized embedding as a flat list of floats.
        """
        if is_image:
            inputs = self.processor(images=data, return_tensors="pt").to(self.device)
            with torch.no_grad():
                features = self.model.get_image_features(**inputs)
        else:
            # Truncate so long descriptions fit CLIP's max sequence length.
            inputs = self.processor(
                text=[data],
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.device)
            with torch.no_grad():
                features = self.model.get_text_features(**inputs)

        # L2-normalize so dot products correspond to cosine similarity.
        features = features / features.norm(p=2, dim=-1, keepdim=True)

        return features.cpu().numpy().flatten().tolist()

    def query_local_db(self, embedding: List[float], category: str, n_results: int = 3) -> List[Dict[str, Any]]:
        """
        Query the local collection for items similar to ``embedding``,
        filtered to the given category and to image-modality entries.

        Args:
            embedding: Query embedding (same dimensionality as the collection).
            category: Metadata category to filter on.
            n_results: Number of nearest neighbours to return.

        Returns:
            The raw ChromaDB query result (documents, metadatas, distances).
        """
        print(f"--- Querying DB for Category: {category} ---")
        results = self.collection.query(
            query_embeddings=[embedding],
            n_results=n_results,
            where={
                "$and": [
                    {"category": category},
                    {"modality": "image"},
                ]
            },
            include=['documents', 'metadatas', 'distances']
        )
        return results
|
||||
231
requirements.txt
Normal file
231
requirements.txt
Normal file
@@ -0,0 +1,231 @@
|
||||
accelerate==0.21.0
|
||||
aiohttp==3.9.5
|
||||
aiosignal==1.3.1
|
||||
albumentations==0.3.2
|
||||
annotated-types==0.7.0
|
||||
antlr4-python3-runtime==4.9.3
|
||||
anykeystore==0.2
|
||||
asn1crypto==1.5.1
|
||||
asttokens==2.4.1
|
||||
async-timeout==4.0.3
|
||||
attrs==21.2.0
|
||||
bidict==0.23.1
|
||||
blessed==1.20.0
|
||||
boto3==1.34.113
|
||||
botocore==1.34.113
|
||||
braceexpand==0.1.7
|
||||
cachetools==5.3.3
|
||||
certifi==2024.2.2
|
||||
cffi==1.16.0
|
||||
chardet==5.2.0
|
||||
charset-normalizer==3.3.2
|
||||
click==8.1.7
|
||||
clip==0.2.0
|
||||
clip-openai==1.0.post20230121
|
||||
cmake==3.29.3
|
||||
cramjam==2.8.3
|
||||
crcmod==1.7
|
||||
cryptacular==1.6.2
|
||||
cryptography==39.0.2
|
||||
cycler==0.12.1
|
||||
datasets==2.2.1
|
||||
diffusers==0.30.1
|
||||
decorator==5.1.1
|
||||
decord==0.6.0
|
||||
deepspeed==0.14.2
|
||||
defusedxml==0.7.1
|
||||
Deprecated==1.2.14
|
||||
descartes==1.1.0
|
||||
dill==0.3.8
|
||||
distlib==0.3.8
|
||||
distro-info==1.0
|
||||
dnspython==2.6.1
|
||||
docker-pycreds==0.4.0
|
||||
docstring_parser==0.16
|
||||
ecdsa==0.19.0
|
||||
einops==0.6.0
|
||||
exceptiongroup==1.2.1
|
||||
executing==2.0.1
|
||||
fairscale==0.4.13
|
||||
fastparquet==2024.5.0
|
||||
ffmpegcv==0.3.13
|
||||
filelock==3.14.0
|
||||
fire==0.6.0
|
||||
fonttools==4.51.0
|
||||
frozenlist==1.4.1
|
||||
fsspec==2023.6.0
|
||||
ftfy==6.2.0
|
||||
gitdb==4.0.11
|
||||
GitPython==3.1.43
|
||||
gpustat==1.1.1
|
||||
greenlet==3.0.3
|
||||
grpcio==1.64.0
|
||||
h11==0.14.0
|
||||
hjson==3.1.0
|
||||
hupper==1.12.1
|
||||
idna==3.7
|
||||
imageio==2.34.1
|
||||
imgaug==0.2.6
|
||||
iniconfig==2.0.0
|
||||
ipaddress==1.0.23
|
||||
ipdb==0.13.13
|
||||
ipython==8.18.1
|
||||
jaxtyping==0.2.28
|
||||
jedi==0.19.1
|
||||
Jinja2==3.1.4
|
||||
jmespath==1.0.1
|
||||
joblib==1.4.2
|
||||
jsonargparse==4.14.1
|
||||
jsonlines==4.0.0
|
||||
kiwisolver==1.4.5
|
||||
kornia==0.7.2
|
||||
kornia_rs==0.1.3
|
||||
lazy_loader==0.4
|
||||
lightning==2.2.3
|
||||
lightning-utilities==0.11.2
|
||||
lit==18.1.6
|
||||
MarkupSafe==2.1.5
|
||||
matplotlib==3.5.3
|
||||
matplotlib-inline==0.1.7
|
||||
miscreant==0.3.0
|
||||
mpmath==1.3.0
|
||||
msgpack==1.0.8
|
||||
multidict==6.0.5
|
||||
multiprocess==0.70.16
|
||||
natsort==8.4.0
|
||||
networkx==3.2.1
|
||||
ninja==1.11.1.1
|
||||
numpy==1.24.4
|
||||
nuscenes-devkit==1.1.11
|
||||
oauthlib==3.2.2
|
||||
omegaconf==2.3.0
|
||||
open-clip-torch==2.24.0
|
||||
openai-clip
|
||||
opencv-python==4.9.0.80
|
||||
opencv-python-headless==3.4.18.65
|
||||
packaging==22.0
|
||||
pandas==1.5.3
|
||||
parquet==1.3.1
|
||||
parso==0.8.4
|
||||
PasteDeploy==3.1.0
|
||||
pathlib2==2.3.7.post1
|
||||
pathtools==0.1.2
|
||||
pbkdf2==1.3
|
||||
pexpect==4.9.0
|
||||
pillow==10.3.0
|
||||
plaster==1.1.2
|
||||
plaster-pastedeploy==1.0.1
|
||||
platformdirs==4.2.2
|
||||
plotly==5.22.0
|
||||
pluggy==1.5.0
|
||||
ply==3.11
|
||||
promise==2.3
|
||||
prompt-toolkit==3.0.43
|
||||
protobuf==3.20.3
|
||||
psutil==5.9.8
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
py==1.11.0
|
||||
py-cpuinfo==9.0.0
|
||||
py-spy==0.3.14
|
||||
pyarrow==11.0.0
|
||||
pyarrow-hotfix==0.6
|
||||
pyasn1==0.6.0
|
||||
pycocotools==2.0.7
|
||||
pycparser==2.22
|
||||
pycryptodomex==3.20.0
|
||||
pycurl==7.43.0.6
|
||||
ollama==0.4.4
|
||||
pydantic==2.9.2
|
||||
pydantic_core==2.23.4
|
||||
Pygments==2.18.0
|
||||
PyJWT==2.8.0
|
||||
pynvml==11.5.0
|
||||
pyope==0.2.2
|
||||
pyOpenSSL==23.2.0
|
||||
pyparsing==3.1.2
|
||||
pyquaternion==0.9.9
|
||||
pyramid==2.0.2
|
||||
pyramid-mailer==0.15.1
|
||||
pytest==6.2.5
|
||||
python-consul==1.1.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-engineio==4.9.1
|
||||
python-etcd==0.4.5
|
||||
python-jose==3.3.0
|
||||
python-socketio==5.11.2
|
||||
python3-openid==3.2.0
|
||||
pytorch-extension==0.2
|
||||
pytorch-lightning==2.2.3
|
||||
pytz==2024.1
|
||||
PyYAML==6.0.1
|
||||
regex==2024.5.15
|
||||
repoze.sendmail==4.4.1
|
||||
requests==2.31.0
|
||||
requests-oauthlib==2.0.0
|
||||
rsa==4.9
|
||||
s3transfer==0.10.1
|
||||
safetensors==0.4.3
|
||||
schedule==1.2.2
|
||||
scikit-image==0.22.0
|
||||
scikit-learn==1.5.0
|
||||
scipy==1.13.1
|
||||
sentencepiece==0.2.0
|
||||
sentry-sdk==2.3.1
|
||||
setproctitle==1.3.3
|
||||
Shapely==1.8.5.post1
|
||||
shortuuid==1.0.13
|
||||
simple-websocket==1.0.0
|
||||
six==1.16.0
|
||||
smmap==5.0.1
|
||||
SQLAlchemy==2.0.30
|
||||
stack-data==0.6.3
|
||||
sympy==1.12
|
||||
taming-transformers-rom1504==0.0.6
|
||||
tenacity==8.3.0
|
||||
tensorboardX==2.6.2.2
|
||||
termcolor==2.4.0
|
||||
threadpoolctl==3.5.0
|
||||
thriftpy2==0.5.0
|
||||
tifffile==2024.5.22
|
||||
timm==1.0.3
|
||||
tokenizers==0.19.1
|
||||
toml==0.10.2
|
||||
tomli==2.0.1
|
||||
torch==2.2.1
|
||||
torch-fidelity==0.3.0
|
||||
torchmetrics==1.4.0.post0
|
||||
torchvision==0.17.1
|
||||
tox==3.28.0
|
||||
tqdm==4.66.4
|
||||
traitlets==5.14.3
|
||||
transaction==4.0
|
||||
transformers==4.41.1
|
||||
translationstring==1.4
|
||||
triton==2.2.0
|
||||
typeguard==2.13.3
|
||||
typing_extensions==4.12.0
|
||||
tzdata==2024.1
|
||||
urllib3==1.26.18
|
||||
velruse==1.1.1
|
||||
venusian==3.1.0
|
||||
virtualenv==20.26.2
|
||||
wandb==0.17.2
|
||||
watchdog==4.0.1
|
||||
wcwidth==0.2.13
|
||||
webdataset==0.2.86
|
||||
WebOb==1.8.7
|
||||
websocket-client==1.8.0
|
||||
wrapt==1.16.0
|
||||
wsproto==1.2.0
|
||||
WTForms==3.1.2
|
||||
wtforms-recaptcha==0.3.2
|
||||
xformers==0.0.25
|
||||
xxhash==3.4.1
|
||||
yarl==1.9.4
|
||||
zope.deprecation==5.0
|
||||
zope.interface==6.4.post2
|
||||
zope.sqlalchemy==3.1
|
||||
pytorch-fid==0.3.0
|
||||
lpips==0.1.4
|
||||
huggingface_hub==0.24.6
|
||||
Reference in New Issue
Block a user