push gitignore

This commit is contained in:
zhh
2025-10-24 10:37:19 +08:00
parent ce00630bb2
commit 1adf29a3f6
36 changed files with 6089 additions and 334 deletions

View File

@@ -1,14 +1,16 @@
import time
from typing import Dict, List
import asyncio
from app.core.data_structure import Message, Role
from app.core.llm_interface import AsyncLLMInterface, AsyncGeminiLLM
from app.core.redis_manager import RedisManager
from app.core.redis_manager import RedisManager
from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT
from app.core.stylist_agent import AsyncStylistAgent
from app.core.vector_database import VectorDatabase
from app.core.config import settings
class ChatbotAgent:
def __init__(self, llm_model: AsyncLLMInterface = None):
self.llm = llm_model if llm_model else AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME)
@@ -27,7 +29,7 @@ class ChatbotAgent:
'local_db': self.vector_db,
'max_len': 5,
'outfits_root': settings.OUTFIT_OUTPUT_DIR,
'image_dir': settings.IMAGE_DIR,
'image_dir': settings.IMAGE_DIR,
'stylist_guide_dir': settings.STYLIST_GUIDE_DIR,
'gemini_model_name': settings.LLM_MODEL_NAME
}
@@ -54,7 +56,7 @@ class ChatbotAgent:
assistant_msg = Message(role=Role.ASSISTANT, content=response_text)
else:
assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")
self.redis.save_message(user_id, user_msg)
self.redis.save_message(user_id, assistant_msg)
@@ -68,12 +70,12 @@ class ChatbotAgent:
"""
history_messages = self.redis.get_history(user_id)
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
# 临时调用 LLM 或使用本地逻辑生成总结
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)], system_prompt=SUMMARY_PROMPT)
return summary
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit: List[Dict[str, str]] = [], num_outfits: int = 1):
"""
基于用户的对话历史和需求,推荐一套搭配。
@@ -111,25 +113,27 @@ class ChatbotAgent:
print(f"An unexpected error occurred during concurrent recommendation: {e}")
return {"error": str(e)}
if __name__ == "__main__":
async def test():
async def run():
start_time = time.time()
agent = ChatbotAgent()
user_id = "user123"
agent.redis.clear_history(user_id) # 清除历史,便于测试
print(await agent.process_query(user_id, "I want a chic outfit for a summer party."))
print(await agent.process_query(user_id, "I prefer something floral and light."))
user_id = "string"
# agent.redis.clear_history(user_id) # 清除历史,便于测试
# print(await agent.process_query(user_id, "I want a chic outfit for a summer party."))
# print(await agent.process_query(user_id, "I prefer something floral and light."))
request_summary = await agent.get_conversation_summary(user_id)
print(f"Conversation Summary:\n{request_summary}")
recommendation_results = await agent.recommend_outfit(request_summary, stylist_name="crystal", start_outfit=[], num_outfits=2)
recommendation_results = await agent.recommend_outfit(request_summary, stylist_name="crystal", start_outfit=[], num_outfits=4)
print("\n--- Final Recommendation Results ---")
for i, path in enumerate(recommendation_results.get("successful_outfits", [])):
print(f"✅ Outfit {i+1} saved to: {path}")
print(f"✅ Outfit {i + 1} saved to: {path}")
for error in recommendation_results.get("failed_outfits", []):
print(f"{error}")
print(time.time() - start_time)
asyncio.run(test())
asyncio.run(run())

View File

@@ -0,0 +1,163 @@
from google.genai import types
from typing import Dict, List
import asyncio
from google import genai
from app.core import system_prompt
from app.core.data_structure import Message, Role
from app.core.llm_interface_stream import AsyncLLMInterface, AsyncGeminiLLM
from app.core.redis_manager import RedisManager
from app.core.system_prompt import BASIC_PROMPT, SUMMARY_PROMPT
from app.core.stylist_agent import AsyncStylistAgent
from app.core.vector_database import VectorDatabase
from app.core.config import settings
class ChatbotAgent:
def __init__(self, llm_model: AsyncLLMInterface = None):
self.llm = llm_model if llm_model else AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME)
self.redis = RedisManager(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
key_prefix=settings.REDIS_HISTORY_KEY_PREFIX
)
self.vector_db = VectorDatabase(
vector_db_dir=settings.VECTOR_DB_DIR,
collection_name=settings.COLLECTION_NAME,
embedding_model_name=settings.EMBEDDING_MODEL_NAME
)
self.stylist_agent_kwages = {
'local_db': self.vector_db,
'max_len': 5,
'outfits_root': settings.OUTFIT_OUTPUT_DIR,
'image_dir': settings.IMAGE_DIR,
'stylist_guide_dir': settings.STYLIST_GUIDE_DIR,
'gemini_model_name': settings.LLM_MODEL_NAME
}
self.gemini_client = genai.Client(
vertexai=True, project='aida-461108', location='us-central1'
)
async def process_query(self, user_id: str, user_message: str) -> str:
"""
处理用户的最新输入,调用 LLM, 并更新历史记录。
"""
# 添加用户消息到历史
user_msg = Message(role=Role.USER, content=user_message)
chat_history = self.redis.get_history(user_id)
chat_history.append(user_msg)
contents = []
for msg in chat_history:
gemini_role = "user" if msg.role == Role.USER else "model"
content = types.Content(
role=gemini_role,
parts=[types.Part.from_text(text=msg.content)]
)
contents.append(content)
response_parts = []
response_stream = await self.gemini_client.aio.models.generate_content_stream(
model='gemini-2.5-flash',
contents=contents,
config=types.GenerateContentConfig(
system_instruction=BASIC_PROMPT,
# temperature=0.3,
)
)
async for chunk in response_stream:
# 您可以在这里处理每一个文本块,例如发送给前端
print(chunk.text, end="", flush=True)
response_parts.append(chunk.text)
# 3. 将所有文本块合并成最终的字符串
response_text = "".join(response_parts)
# 添加助手消息到历史
if response_text:
assistant_msg = Message(role=Role.ASSISTANT, content=response_text)
else:
assistant_msg = Message(role=Role.ASSISTANT, content="No response generated. Try again later.")
self.redis.save_message(user_id, user_msg)
self.redis.save_message(user_id, assistant_msg)
return response_text
async def get_conversation_summary(self, user_id: str) -> str:
"""
分析用户的完整会话历史,并打包成一个简洁的需求总结。
这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。`
"""
history_messages = self.redis.get_history(user_id)
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
# 临时调用 LLM 或使用本地逻辑生成总结
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)], system_prompt=SUMMARY_PROMPT)
return summary
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit: List[Dict[str, str]] = [], num_outfits: int = 1):
"""
基于用户的对话历史和需求,推荐一套搭配。
Args:
request_summary: 用户的request
start_outfit: 可选的初始搭配列表,每个元素包含 'item_id''category'
"""
tasks = []
for _ in range(num_outfits):
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
task = agent.run_styling_process(request_summary, stylist_name, start_outfit)
tasks.append(task)
print(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")
try:
results = await asyncio.gather(*tasks, return_exceptions=True)
successful_outfits = []
failed_outfits = []
for result in results:
if isinstance(result, Exception):
# 任务执行中发生异常
failed_outfits.append(f"Failed: {result}")
else:
# 任务成功result 是 run_styling_process 返回的图片路径
successful_outfits.append(result)
return {
"successful_outfits": successful_outfits,
"failed_outfits": failed_outfits
}
except Exception as e:
print(f"An unexpected error occurred during concurrent recommendation: {e}")
return {"error": str(e)}
if __name__ == "__main__":
async def run():
# 阶段一:用户对话
agent = ChatbotAgent()
user_id = "string"
# agent.redis.clear_history(user_id) # 清除历史,便于测试
# await agent.process_query(user_id, "I want a chic outfit for a summer party.")
# print(await agent.process_query(user_id, "I prefer something floral and light."))
# 阶段二:读取聊天记录,生成推荐搭配
request_summary = await agent.get_conversation_summary(user_id)
print(f"Conversation Summary:\n{request_summary}")
recommendation_results = await agent.recommend_outfit(request_summary, stylist_name="crystal", start_outfit=[], num_outfits=1)
print("\n--- Final Recommendation Results ---")
for i, path in enumerate(recommendation_results.get("successful_outfits", [])):
print(f"✅ Outfit {i + 1} saved to: {path}")
for error in recommendation_results.get("failed_outfits", []):
print(f"{error}")
asyncio.run(run())

View File

@@ -1,6 +1,7 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field
# ⚠️ 注意: 您需要安装 pydantic-settings: pip install pydantic-settings
class Settings(BaseSettings):
@@ -8,31 +9,32 @@ class Settings(BaseSettings):
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
"""
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
extra='ignore' # 忽略环境变量中多余的键
env_file='.env',
env_file_encoding='utf-8',
extra='ignore' # 忽略环境变量中多余的键
)
# Redis 配置
REDIS_HOST: str = Field(default='localhost', description="Redis服务器地址")
REDIS_HOST: str = Field(default='10.1.1.240', description="Redis服务器地址")
REDIS_PORT: int = Field(default=6379, description="Redis服务器端口")
REDIS_DB: int = Field(default=0, description="Redis数据库编号")
REDIS_DB: int = Field(default=3, description="Redis数据库编号")
REDIS_HISTORY_KEY_PREFIX: str = Field(default="chat:history:", description="Redis会话历史键的前缀")
# LLM 配置
GEMINI_API_KEY: str = Field(..., description="Google Gemini API 密钥。必须设置。")
# GEMINI_API_KEY: str = Field(..., description="Google Gemini API 密钥。必须设置。")
LLM_MODEL_NAME: str = Field(default="gemini-2.5-flash", description="使用的 LLM 模型名称")
# 路径配置参数
DATA_ROOT: str = Field(default="./data", description="数据根目录")
IMAGE_DIR: str = Field(default="./data/image_data", description="图片数据目录")
OUTFIT_OUTPUT_DIR: str = Field(default="./data/outfit_output", description="生成的搭配图片输出目录")
STYLIST_GUIDE_DIR: str = Field(default="./data/stylist_guide", description="风格指南文本目录")
DATA_ROOT: str = Field(default="/workspace/lc_stylist_agent/app/core/data", description="数据根目录")
IMAGE_DIR: str = Field(default="/workspace/lc_stylist_agent/app/core/data/image_data", description="图片数据目录")
OUTFIT_OUTPUT_DIR: str = Field(default="/workspace/lc_stylist_agent/app/core/data/outfit_output", description="生成的搭配图片输出目录")
STYLIST_GUIDE_DIR: str = Field(default="/workspace/lc_stylist_agent/app/core/data/stylist_guide", description="风格指南文本目录")
# 向量数据库配置参数
VECTOR_DB_DIR: str = Field(default="./data/db", description="向量数据库目录")
VECTOR_DB_DIR: str = Field(default="./app/core/data/db", description="向量数据库目录")
COLLECTION_NAME: str = Field(default="lc_clothing_embedding", description="向量数据库集合名称")
EMBEDDING_MODEL_NAME: str = Field(default="openai/clip-vit-base-patch32", description="CLIP嵌入模型名称")
# 创建配置实例,供应用其他部分使用
settings = Settings()

View File

@@ -24,13 +24,15 @@ class AsyncGeminiLLM(AsyncLLMInterface):
def __init__(self, model_name: str = "gemini-2.5-flash"):
self.model_name = model_name
try:
self.gemini_client = genai.Client()
self.gemini_client = genai.Client(
vertexai=True, project='aida-461108', location='us-central1'
)
except Exception as e:
raise type(e)(f"Failed to initialize Gemini Client. Check if GEMINI_API_KEY is set. Original error: {e}")
async def generate_response(self, history: List[Message], system_prompt: str) -> str:
contents = []
for msg in history:
gemini_role = "user" if msg.role == Role.USER else "model"
content = types.Content(
@@ -52,4 +54,3 @@ class AsyncGeminiLLM(AsyncLLMInterface):
except Exception as e:
raise type(e)(f"Gemini API call failed: {e}")

View File

@@ -0,0 +1,61 @@
from abc import ABC, abstractmethod
from typing import List
from google import genai
from google.genai import types
from app.core.data_structure import Message, Role
class AsyncLLMInterface(ABC):
@abstractmethod
async def generate_response(self, history: List[Message], system_prompt: str) -> str:
"""
根据对话历史和系统指令生成回复.
Args:
history: 包含多条 Message 的列表。
Returns:
LLM 生成的回复字符串。
"""
raise NotImplementedError("Subclasses must implement this method")
class AsyncGeminiLLM(AsyncLLMInterface):
def __init__(self, model_name: str = "gemini-2.5-flash"):
self.model_name = model_name
try:
self.gemini_client = genai.Client(
vertexai=True, project='aida-461108', location='us-central1'
)
except Exception as e:
raise type(e)(f"Failed to initialize Gemini Client. Check if GEMINI_API_KEY is set. Original error: {e}")
async def generate_response(self, history: List[Message], system_prompt: str):
contents = []
for msg in history:
gemini_role = "user" if msg.role == Role.USER else "model"
content = types.Content(
role=gemini_role,
parts=[types.Part.from_text(text=msg.content)]
)
contents.append(content)
return contents
# response_stream = await self.gemini_client.aio.models.generate_content_stream(
# model=self.model_name,
# contents=contents,
# config=types.GenerateContentConfig(
# system_instruction=system_prompt,
# # temperature=0.3,
# )
# )
#
# # 3. 异步迭代流,并 yield 每个块的文本
# async for chunk in response_stream:
# # 确保 chunk 中有可用的文本
# if chunk.text:
# print(chunk.text)
# yield chunk.text

View File

@@ -1,12 +1,17 @@
import logging
import redis
from typing import List, Optional
from app.core.data_structure import Message, Role
logger = logging.getLogger(__name__)
# 这是一个同步 Redis 客户端,用于演示如何替换内存存储。
# 在生产环境和异步 Web 框架中,应替换为 aioredis 等异步客户端。
class RedisManager:
"""同步管理器,用于在 Redis 中存储和检索对话历史。"""
def __init__(self, host: str = 'localhost', port: int = 6379, db: int = 0, key_prefix: str = "chat:history:"):
self.r: Optional[redis.Redis] = None
self.key_prefix = key_prefix
@@ -14,10 +19,10 @@ class RedisManager:
# 尝试连接 Redis
self.r = redis.Redis(host=host, port=port, db=db, decode_responses=True)
self.r.ping()
print("Successfully connected to Redis at {}:{}".format(host, port))
logger.info("Successfully connected to Redis at {}:{}".format(host, port))
except Exception as e:
print(f"⚠️ Failed to connect to Redis: {e}. Falling back to No-Op.")
self.r = None # 连接失败时设置为 None避免后续操作报错
logger.error(f"⚠️ Failed to connect to Redis: {e}. Falling back to No-Op.")
self.r = None # 连接失败时设置为 None避免后续操作报错
def _get_key(self, user_id: str) -> str:
"""生成用户历史记录的 Redis 键名。"""
@@ -32,32 +37,32 @@ class RedisManager:
try:
return Message.model_validate_json(data)
except Exception as e:
print(f"Error deserializing message data: {data[:50]}... Error: {e}")
logger.error(f"Error deserializing message data: {data[:50]}... Error: {e}")
return Message(role=Role.ASSISTANT, content="[Deserialization Error]")
def save_message(self, user_id: str, message: Message):
"""将单条消息保存到用户历史记录列表的末尾。"""
if not self.r:
return
message_json = self._message_to_json(message)
# RPUSH将元素添加到列表的尾部
self.r.rpush(self._get_key(user_id), message_json)
def get_history(self, user_id: str) -> List[Message]:
"""检索用户的完整会话历史记录。"""
if not self.r:
return []
# LRANGE获取列表的所有元素 (0 到 -1)
raw_history = self.r.lrange(self._get_key(user_id), 0, -1)
# 将 JSON 字符串列表转换为 Message 对象列表
messages = [self._json_to_message(data) for data in raw_history]
return messages
def clear_history(self, user_id: str):
"""删除用户的完整历史记录。"""
if self.r:
self.r.delete(self._get_key(user_id))
print(f"History cleared for {user_id}")
logger.info(f"History cleared for {user_id}")

View File

@@ -15,7 +15,9 @@ class AsyncStylistAgent:
def __init__(self, local_db, max_len: int, outfits_root: str, image_dir: str, stylist_guide_dir: str, gemini_model_name: str):
# self.outfit_items: List[Dict[str, str]] = []
self.outfit_id = str(uuid.uuid4())
self.gemini_client = genai.Client()
self.gemini_client = genai.Client(
vertexai=True, project='aida-461108', location='us-central1'
)
self.local_db = local_db
self.max_len = max_len
self.output_outfit_path = os.path.join(outfits_root, f"{self.outfit_id}.jpg")
@@ -108,15 +110,15 @@ class AsyncStylistAgent:
模型的响应文本(预期为 JSON 字符串)。
"""
content_parts = []
self._clear_uploaded_files()
# self._clear_uploaded_files()
# 1. 添加图片内容
if self.outfit_items:
merged_image_path = merge_images_to_square(self.outfit_items, max_len=self.max_len, output_path=self.output_outfit_path)
try:
myfile = await self.gemini_client.aio.files.upload(file=merged_image_path)
content_parts.append(myfile)
except Exception as e:
print(f"Error loading image {merged_image_path}: {e}")
# if self.outfit_items:
# merged_image_path = merge_images_to_square(self.outfit_items, max_len=self.max_len, output_path=self.output_outfit_path)
# try:
# myfile = await self.gemini_client.aio.files.upload(file=merged_image_path)
# content_parts.append(myfile)
# except Exception as e:
# print(f"Error loading image {merged_image_path}: {e}")
# 2. 添加文本内容
content_parts.append(user_input)

View File

@@ -0,0 +1,436 @@
import asyncio
import io
import json
import logging
import os
import random
import uuid
from typing import List, Dict, Any, Optional
from google import genai
from google.cloud import storage
from google.oauth2 import service_account
from app.core.utils_litserve import merge_images_to_square
from app.server.utils.minio_client import minio_client, oss_upload_image
from app.server.utils.request_post import post_request
logger = logging.getLogger(__name__)
class AsyncStylistAgent:
CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'}
def __init__(self, local_db, max_len: int, gemini_model_name: str):
# self.outfit_items: List[Dict[str, str]] = []
self.outfit_id = str(uuid.uuid4())
self.gemini_client = genai.Client(
vertexai=True, project='aida-461108', location='us-central1'
)
self.local_db = local_db
self.max_len = max_len
self.gemini_model_name = gemini_model_name
self.stop_reason = ""
# 存储桶配置
try:
# TODO 目前写死路径 生产环境切换路径
self.credentials = service_account.Credentials.from_service_account_file(os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))
except Exception as e:
# 这里的异常处理应根据实际情况调整
raise RuntimeError(f"Failed to load credentials from file {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}: {e}")
self.gcs_client = storage.Client(
project=self.credentials.project_id,
credentials=self.credentials
)
self.gcs_bucket = "lc_stylist_agent_outfit_items"
self.minio_bucket = "lanecarford"
def _load_style_guide(self, path: str) -> str:
"""加载 markdown 风格指南内容。"""
parts = path.split('/', 1)
if len(parts) != 2:
raise ValueError("MinIO path must be in 'bucket_name/object_name' format.")
bucket_name, object_name = parts
try:
# 1. 获取对象
response = minio_client.get_object(bucket_name, object_name)
# 2. 读取内容
content_bytes = response.read()
# 3. 关闭连接
response.close()
response.release_conn()
# 4. 解码并返回
return content_bytes.decode('utf-8')
except Exception as e:
raise Exception(f"Failed to load style guide from {path}: {e}")
def _build_system_prompt(self, request_summary: str = "") -> str:
"""Constructs the complete System Prompt."""
# Insert the style_guide content into the template
template = f"""
You are a professional fashion stylist Agent, specialized in creating complete outfits for the user.
Your task is to **create a cohesive and complete outfit**, strictly adhering to **BOTH** the user's explicit **Request Summary** and the **Outfit Style Guide**. You must decide the next logical item to add to the outfit based on the currently selected items (if any).
---
## Request from the User:
{request_summary}
## Core Guidance Document: Outfit Style Guide
{self.style_guide}
---
## Your Workflow and Constraints
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**.
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories.
3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows:
```json
{{
"action": "recommend_item",
"category": "YOUR_ITEM_CATEGORY",
"description": "YOUR_DETAILED_DESCRIPTION"
}}
```
* `action`: Must always be `"recommend_item"` until the outfit is complete.
* `category`: Must be the category of the item you are recommending, strictly selected from the following list: {list(self.CATEGORY_SET)}.
* `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include:
* **Color** (e.g., milk tea, pure white, dark gray)
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
* **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern)
* **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look)
* **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.**
4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process:
```json
{{
"action": "stop",
"reason": "OUTFIT_COMPLETE_AND_MEETS_ALL_MINI_GUIDELINES"
}}
```
Normally, five or six items are totally enough for an outfit.
5. **Context Dependency**: The user's next input (if not `Start`) will contain the **image and description of the selected item**. When recommending the next item, you must consider the coordination between the **already selected items** and the Style Guide.
**Now, please start building an outfit and output the JSON for the first item.**
"""
return template.strip()
def _clear_uploaded_files(self):
for f in self.gemini_client.files.list():
self.gemini_client.files.delete(name=f.name)
async def _call_gemini(self, user_input: str, user_id: str):
"""
实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。
Args:
user_input: 发送给模型的主文本内容。
image_paths: 待发送图片的本地路径列表。
Returns:
模型的响应文本(预期为 JSON 字符串)。
"""
minio_path = ""
content_parts = []
# self._clear_uploaded_files()
# 1. 添加图片内容
if self.outfit_items:
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len)
image_bytes_io = io.BytesIO()
image_format = 'JPEG'
mime_type = 'image/jpeg'
merged_image.save(image_bytes_io, format=image_format)
image_bytes = image_bytes_io.getvalue()
file_name = uuid.uuid4()
blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg"
gcs_path = self._upload_to_gcs(bucket_name=self.gcs_bucket, blob_name=blob_name, mime_type=mime_type, image_bytes=image_bytes)
responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes)
minio_path = f"{responses.bucket_name}/{responses.object_name}"
content_parts.append(gcs_path)
# 2. 添加文本内容
content_parts.append(user_input)
# print(f"\n--- Calling Gemini with {len(self.outfit_items) if self.outfit_items else 0} images and query:\n{user_input}")
try:
# 3. 实际 API 调用
response = await self.gemini_client.aio.models.generate_content(
model=self.gemini_model_name,
contents=content_parts,
config={
"system_instruction": self.system_prompt,
# 确保模型返回 JSON 格式
"response_mime_type": "application/json",
"response_schema": {
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["recommend_item", "stop"]},
"category": {"type": "string"},
"description": {"type": "string"},
"reason": {"type": "string"}
},
"required": ["action"]
}
}
)
# response.text 将包含一个 JSON 字符串
return response.text, minio_path
except Exception as e:
print(f"Gemini API Call failed: {e}")
# 返回一个停止信号以防止循环继续
return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"})
def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
"""安全解析 Gemini 的 JSON 响应。"""
try:
# 有时 Gemini 可能会在 JSON 外面添加文字,尝试清理
response_text = response_text.strip().replace('```json', '').replace('```', '')
data = json.loads(response_text)
# print(f"The agent response is: {data}")
return data
except json.JSONDecodeError as e:
print(f"Error parsing JSON from Gemini: {e}")
print(f"Raw response: {response_text}")
return None
def _get_next_item(self, item_description: str, category: str) -> Optional[Dict[str, str]]:
"""
1. 根据描述生成嵌入。
2. 查询本地数据库以找到最佳匹配项。
3. 模拟 Agent 审核匹配项(这里简化为总是通过)。
"""
try:
# 1. 生成查询嵌入
query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)
# 2. 执行查询,并过滤类别
results = self.local_db.query_local_db(query_embedding, category, n_results=1)
if not results:
print(f"❌ 数据库中未找到符合 '{category}' 和描述的单品。")
return None
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
best_meta = results['metadatas'][0][0] # 第一个 batch 的第一个 metadata
return {
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
"category": category,
"gpt_description": item_description,
'description': best_meta['description'],
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
# 这里假设 item_id 就是文件名的一部分
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
}
except Exception as e:
print(f"An error occurred during item retrieval: {e}")
return None
def _build_user_input(self) -> str:
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
if not self.outfit_items:
return "Start"
# 将已选单品的信息作为上下文发回给 Agent
context = "Selected fashion items:\n"
for ii, item in enumerate(self.outfit_items):
context += f"{ii + 1}. Category: {item['category']}. Description: {item['description']}\n"
context += "\nPlease recommend the next single item based on the selected items, user's request, and style guide."
return context
async def run_styling_process(self, request_summary, stylist_path, start_outfit=None, user_id="test"):
if start_outfit is None:
start_outfit = []
self.outfit_items = start_outfit if start_outfit else []
"""主流程控制循环。"""
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
self.style_guide = self._load_style_guide(stylist_path)
self.system_prompt = self._build_system_prompt(request_summary)
response_data = {"status": "",
"message": "",
"path": "",
"outfit_id": self.outfit_id,
"items": []
}
logger.info(response_data)
item_id = ""
while True:
# 1. 准备用户输入(上下文)
user_input = self._build_user_input()
# 2. 调用 Gemini Agent
gemini_response_text, minio_path = await self._call_gemini(user_input, user_id)
gemini_data = self._parse_gemini_response(gemini_response_text)
response_data['path'] = minio_path
if item_id:
response_data['items'].append(item_id)
if not gemini_data:
print("🚨 Agent 返回无效响应,终止流程。")
self.stop_reason = "Agent failed to return response"
response_data['status'] = "failed"
response_data['message'] = self.stop_reason
break
# 3. 检查终止条件
if gemini_data.get('action') == 'stop':
print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
break
# 4. 处理推荐单品
if gemini_data.get('action') == 'recommend_item':
category = gemini_data.get('category')
description = gemini_data.get('description')
# 4a. 检查类别是否有效 (重要步骤)
if category not in self.CATEGORY_SET:
print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。")
# 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正
# 这里简化为跳过本次循环
response_data['status'] = "continue"
response_data['message'] = f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。",
continue
# 4b. 在本地 DB 中查询单品
new_item = self._get_next_item(description, category)
item_id = new_item.get('item_id')
if new_item:
# 4c. (实际步骤) 将选中的单品图片和描述发回给 Agent 进行最终审核
# 这里的代码框架省略了图片回传和二次审核的步骤,直接视为通过
# 实际你需要: new_user_input = f"Check this item: {new_item['description']}, path: {new_item['image_path']}"
# call_gemini_agent(...) -> 如果返回"pass",则添加到outfit_items
if new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
print("This item exists. Stop here.")
self.stop_reason = "Finish reason: Duplicate item selected."
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
break
if new_item['item_id'] == "ELG383":
if random.random() < 0.70:
self.stop_reason = "Finish reason: ELG383 is seleced repeatly."
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
break
self.outfit_items.append(new_item)
# print(f" 成功添加单品: {new_item['category']} ({new_item['item_id']}). 当前搭配数量: {len(self.outfit_items)}")
response_data['status'] = "ok"
response_data['message'] = self.stop_reason
else:
print("⚠️ 未找到匹配单品,无法继续搭配。终止。")
self.stop_reason = "Finish reason: No matching item found in local database."
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
break
if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环
print("🚨 达到最大搭配数量限制,强制终止。")
self.stop_reason = "Finish reason: Reached max outfit length."
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
logger.info(response_data)
break
logger.info(response_data)
headers = {
'Accept': "*/*",
'Accept-Encoding': "gzip, deflate, br",
'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0",
'Connection': "keep-alive",
'Content-Type': "application/json"
}
url = 'https://83aa2db8e006.ngrok-free.app/api/style/callback'
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(response.text)
return response_data
# def _save_outfit_results(self, user_id):
# """保存最终的 JSON 列表和图片到指定文件夹。"""
# if not self.outfit_items:
# raise ValueError("No outfit items to save.")
#
# # 1. 保存 JSON 文件
# results_list = [{'item_id': item['item_id'], 'category': item['category'], 'description': item['description'], 'gpt_description': item['gpt_description']} for item in self.outfit_items]
# results_list.append({'stop_reason': self.stop_reason})
#
# return upload_json_to_minio_sync(
# minio_client=minio_client,
# bucket_name=f"lanecarford",
# object_name=f"lc_stylist_agent_outfit_items/{user_id}/{uuid.uuid4()}.json",
# data=results_list
# )
def _upload_to_gcs(self, bucket_name: str, blob_name: str, mime_type, image_bytes) -> str:
"""同步方法:将文件上传到 GCS 并返回 GCS URI。"""
bucket = self.gcs_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
blob.upload_from_string(
data=image_bytes,
content_type=mime_type
)
gcs_uri = f"gs://{bucket_name}/{blob_name}"
return gcs_uri
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit: List[Dict[str, str]] = [], num_outfits: int = 1):
"""
基于用户的对话历史和需求,推荐一套搭配。
Args:
request_summary: 用户的request
start_outfit: 可选的初始搭配列表,每个元素包含 'item_id''category'
"""
tasks = []
for _ in range(num_outfits):
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
task = agent.run_styling_process(request_summary, stylist_name, start_outfit)
tasks.append(task)
print(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")
try:
results = await asyncio.gather(*tasks, return_exceptions=True)
successful_outfits = []
failed_outfits = []
for result in results:
if isinstance(result, Exception):
# 任务执行中发生异常
failed_outfits.append(f"Failed: {result}")
else:
# 任务成功result 是 run_styling_process 返回的图片路径
successful_outfits.append(result)
return {
"successful_outfits": successful_outfits,
"failed_outfits": failed_outfits
}
except Exception as e:
print(f"An unexpected error occurred during concurrent recommendation: {e}")
return {"error": str(e)}

View File

@@ -7,20 +7,21 @@ from PIL import Image, ImageDraw, ImageFont
# 布局顺序: 从上到下,从左到右 (1 -> 9)
ALL_9_CELLS = [
# Top Row (Y=0, H=341)
(0, 0, 341, 341), # 1. Top-Left (341x341)
(0, 0, 341, 341), # 1. Top-Left (341x341)
(341, 0, 341, 341), # 2. Top-Middle (341x341)
(682, 0, 342, 341), # 3. Top-Right (342x341)
# Middle Row (Y=341, H=341)
(0, 341, 341, 341), # 4. Mid-Left (341x341)
(341, 341, 341, 341),# 5. Center (341x341)
(682, 341, 342, 341),# 6. Mid-Right (342x341)
(341, 341, 341, 341), # 5. Center (341x341)
(682, 341, 342, 341), # 6. Mid-Right (342x341)
# Bottom Row (Y=682, H=342)
(0, 682, 341, 342), # 7. Bottom-Left (341x342)
(341, 682, 341, 342),# 8. Bottom-Middle (341x342)
(682, 682, 342, 342) # 9. Bottom-Right (342x342)
(341, 682, 341, 342), # 8. Bottom-Middle (341x342)
(682, 682, 342, 342) # 9. Bottom-Right (342x342)
]
def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, output_path="temp.jpg", add_text=True) -> str:
def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, output_path="temp.jpg", add_text=True):
"""
Loads up to 4 images from the given paths, resizes them while maintaining
aspect ratio, and merges them onto a 1024x1024 white background JPG.
@@ -37,31 +38,30 @@ def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, output
Returns:
The file path of the temporary merged JPG image.
"""
# Define the final canvas size
CANVAS_SIZE = 1024
# 1. Create the final white canvas
# Using 'RGB' mode for JPG output
canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white')
draw = ImageDraw.Draw(canvas)
font = ImageFont.load_default()
# 2. Define the quadrants/target areas (x, y, w, h)
# The positions are based on a 512x512 quadrant size
quadrants = {
1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)], # Single full-size placement
2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)], # Left, Right
3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)], # Top-Left, Top-Right, Bottom-Left
4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)], # All Four
5: ALL_9_CELLS[:5], # 布局前5个单元格 (1-5)
6: ALL_9_CELLS[:6], # 布局前6个单元格 (1-6)
7: ALL_9_CELLS[:7], # 布局前7个单元格 (1-7)
8: ALL_9_CELLS[:8], # 布局前8个单元格 (1-8)
1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)], # Single full-size placement
2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)], # Left, Right
3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)], # Top-Left, Top-Right, Bottom-Left
4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)], # All Four
5: ALL_9_CELLS[:5], # 布局前5个单元格 (1-5)
6: ALL_9_CELLS[:6], # 布局前6个单元格 (1-6)
7: ALL_9_CELLS[:7], # 布局前7个单元格 (1-7)
8: ALL_9_CELLS[:8], # 布局前8个单元格 (1-8)
9: ALL_9_CELLS[:9] # 布局全部9个单元格 (1-9)
}
# 3. Load and Filter Images
valid_images = []
image_paths = [item['image_path'] for item in outfit_items]
@@ -75,17 +75,16 @@ def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, output
print(f"Error loading image {path}. Skipping: {e}")
num_images = len(valid_images)
if num_images == 0:
raise ValueError("No valid images were loaded.")
if num_images > max_len:
raise ValueError(f"Valid item number {num_images} exceed max limit {max_len}")
# Get the correct list of target areas based on the number of valid images
target_areas = quadrants.get(num_images, [])
# 4. Resize and Paste
for i, (img, item) in enumerate(zip(valid_images, outfit_items)):
item_id = item['item_id']
@@ -93,40 +92,40 @@ def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, output
if i >= len(target_areas):
# This should not happen if num_images <= 4
break
# Target area dimensions (x_start, y_start, width, height)
x_start, y_start, target_w, target_h = target_areas[i]
# Calculate new size while maintaining aspect ratio
original_w, original_h = img.size
# Calculate the ratio needed to fit within the target area
ratio_w = target_w / original_w
ratio_h = target_h / original_h
# Use the *smaller* of the two ratios to ensure the image fits entirely
resize_ratio = min(ratio_w, ratio_h)
# Calculate the new dimensions
new_w = int(original_w * resize_ratio)
new_h = int(original_h * resize_ratio)
# Resize the image. Image.Resampling.LANCZOS provides high-quality scaling.
# Pillow documentation recommends ANTIALIAS or BICUBIC for downscaling,
# but LANCZOS is a good general high-quality filter.
# Note: In Pillow versions > 9.0.0, Image.LANCZOS is now Image.Resampling.LANCZOS
resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
# Calculate the paste position to center the resized image within its target area
# Center X: (Target Width - New Width) / 2 + X Start
paste_x = (target_w - new_w) // 2 + x_start
# Center Y: (Target Height - New Height) / 2 + Y Start
# paste_y = (target_h - new_h) // 2 + y_start
TEXT_RESERVE_HEIGHT = 30
paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start
paste_y = max(paste_y, y_start)
TEXT_RESERVE_HEIGHT = 30
paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start
paste_y = max(paste_y, y_start)
# Paste the resized image onto the canvas
canvas.paste(resized_img, (paste_x, paste_y))
@@ -140,24 +139,22 @@ def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, output
# 兼容旧版本 Pillow
text_w, text_h = draw.textsize(full_text, font=font)
# 计算 X 轴起始位置:使其在目标区域 (target_w) 中居中
text_x_center = x_start + target_w // 2
text_x_start = text_x_center - text_w // 2
# 计算 Y 轴起始位置:将其放在目标区域的底部
# (目标区域的起始Y + 目标区域的高度 - 文本行的高度)
text_y_start = y_start + target_h - text_h - 5 # 减去 5 像素作为边距
text_y_start = y_start + target_h - text_h - 5 # 减去 5 像素作为边距
# 3. 绘制合并后的文本
if add_text:
draw.text((text_x_start, text_y_start),
full_text,
fill='black',
font=font)
draw.text((text_x_start, text_y_start),
full_text,
fill='black',
font=font)
# Save as a high-quality JPG (quality=90 is a good balance)
canvas.save(output_path, 'JPEG', quality=90)
return output_path
# canvas.save(output_path, 'JPEG', quality=90)
return canvas

163
app/core/utils_litserve.py Normal file
View File

@@ -0,0 +1,163 @@
import logging
from typing import List, Dict
from PIL import Image, ImageDraw, ImageFont
from app.server.utils.minio_client import oss_get_image, minio_client
from app.server.utils.minio_config import MINIO_LC_DATA_PATH
logger = logging.getLogger(__name__)
# 9个 341x341 左右的单元格 (ALL_9_CELLS)
# 布局顺序: 从上到下,从左到右 (1 -> 9)
ALL_9_CELLS = [
# Top Row (Y=0, H=341)
(0, 0, 341, 341), # 1. Top-Left (341x341)
(341, 0, 341, 341), # 2. Top-Middle (341x341)
(682, 0, 342, 341), # 3. Top-Right (342x341)
# Middle Row (Y=341, H=341)
(0, 341, 341, 341), # 4. Mid-Left (341x341)
(341, 341, 341, 341), # 5. Center (341x341)
(682, 341, 342, 341), # 6. Mid-Right (342x341)
# Bottom Row (Y=682, H=342)
(0, 682, 341, 342), # 7. Bottom-Left (341x342)
(341, 682, 341, 342), # 8. Bottom-Middle (341x342)
(682, 682, 342, 342) # 9. Bottom-Right (342x342)
]
def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, add_text=True):
"""
Loads up to 4 images from the given paths, resizes them while maintaining
aspect ratio, and merges them onto a 1024x1024 white background JPG.
The layout depends on the number of images:
1: Center the single image on the 1024x1024 canvas.
2: Place side-by-side, each scaled to fit a 512x1024 half.
3: Place in top-left (512x512), top-right (512x512), and bottom-left (512x512).
4: Place in all four 512x512 quadrants.
Args:
outfit_items: A list of item metadata (max length 9).
Returns:
The file path of the temporary merged JPG image.
"""
# Define the final canvas size
CANVAS_SIZE = 1024
# 1. Create the final white canvas
# Using 'RGB' mode for JPG output
canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white')
draw = ImageDraw.Draw(canvas)
font = ImageFont.load_default()
# 2. Define the quadrants/target areas (x, y, w, h)
# The positions are based on a 512x512 quadrant size
quadrants = {
1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)], # Single full-size placement
2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)], # Left, Right
3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)], # Top-Left, Top-Right, Bottom-Left
4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)], # All Four
5: ALL_9_CELLS[:5], # 布局前5个单元格 (1-5)
6: ALL_9_CELLS[:6], # 布局前6个单元格 (1-6)
7: ALL_9_CELLS[:7], # 布局前7个单元格 (1-7)
8: ALL_9_CELLS[:8], # 布局前8个单元格 (1-8)
9: ALL_9_CELLS[:9] # 布局全部9个单元格 (1-9)
}
# 3. Load and Filter Images
valid_images = []
image_paths = [item['image_path'] for item in outfit_items]
for path in image_paths:
try:
# We use Image.open() and convert to 'RGB' to handle potential transparency (RGBA)
# and ensure compatibility with the final 'RGB' canvas and JPG output.
img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
# img = Image.open(path).convert('RGB')
valid_images.append(img)
except Exception as e:
logger.error(f"Error loading image {path}. Skipping: {e}")
num_images = len(valid_images)
if num_images == 0:
raise ValueError("No valid images were loaded.")
if num_images > max_len:
raise ValueError(f"Valid item number {num_images} exceed max limit {max_len}")
# Get the correct list of target areas based on the number of valid images
target_areas = quadrants.get(num_images, [])
# 4. Resize and Paste
for i, (img, item) in enumerate(zip(valid_images, outfit_items)):
item_id = item['item_id']
category = item['category']
if i >= len(target_areas):
# This should not happen if num_images <= 4
break
# Target area dimensions (x_start, y_start, width, height)
x_start, y_start, target_w, target_h = target_areas[i]
# Calculate new size while maintaining aspect ratio
original_w, original_h = img.size
# Calculate the ratio needed to fit within the target area
ratio_w = target_w / original_w
ratio_h = target_h / original_h
# Use the *smaller* of the two ratios to ensure the image fits entirely
resize_ratio = min(ratio_w, ratio_h)
# Calculate the new dimensions
new_w = int(original_w * resize_ratio)
new_h = int(original_h * resize_ratio)
# Resize the image. Image.Resampling.LANCZOS provides high-quality scaling.
# Pillow documentation recommends ANTIALIAS or BICUBIC for downscaling,
# but LANCZOS is a good general high-quality filter.
# Note: In Pillow versions > 9.0.0, Image.LANCZOS is now Image.Resampling.LANCZOS
resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
# Calculate the paste position to center the resized image within its target area
# Center X: (Target Width - New Width) / 2 + X Start
paste_x = (target_w - new_w) // 2 + x_start
# Center Y: (Target Height - New Height) / 2 + Y Start
# paste_y = (target_h - new_h) // 2 + y_start
TEXT_RESERVE_HEIGHT = 30
paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start
paste_y = max(paste_y, y_start)
# Paste the resized image onto the canvas
canvas.paste(resized_img, (paste_x, paste_y))
full_text = f"ID: {item_id}, Category: {category}"
try:
# 推荐使用:计算文本的实际尺寸 (width, height)
bbox = draw.textbbox((0, 0), full_text, font=font)
text_w = bbox[2] - bbox[0]
text_h = bbox[3] - bbox[1]
except AttributeError:
# 兼容旧版本 Pillow
text_w, text_h = draw.textsize(full_text, font=font)
# 计算 X 轴起始位置:使其在目标区域 (target_w) 中居中
text_x_center = x_start + target_w // 2
text_x_start = text_x_center - text_w // 2
# 计算 Y 轴起始位置:将其放在目标区域的底部
# (目标区域的起始Y + 目标区域的高度 - 文本行的高度)
text_y_start = y_start + target_h - text_h - 5 # 减去 5 像素作为边距
# 3. 绘制合并后的文本
if add_text:
draw.text((text_x_start, text_y_start),
full_text,
fill='black',
font=font)
# Save as a high-quality JPG (quality=90 is a good balance)
# canvas.save(output_path, 'JPEG', quality=90)
return canvas

View File

@@ -9,9 +9,9 @@ from PIL import Image
class VectorDatabase():
def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str):
print(vector_db_dir)
self.client = chromadb.PersistentClient(path=vector_db_dir)
self.collection = self.client.get_collection(name=collection_name)
self.collection = self.client.get_or_create_collection(name=collection_name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -20,7 +20,7 @@ class VectorDatabase():
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
if is_image:
inputs = self.processor(images=data, return_tensors="pt").to(self.device)
with torch.no_grad():
@@ -28,25 +28,24 @@ class VectorDatabase():
else:
# 强制截断,解决序列长度问题
inputs = self.processor(
text=[data],
return_tensors="pt",
text=[data],
return_tensors="pt",
padding=True,
truncation=True
truncation=True
).to(self.device)
with torch.no_grad():
features = self.model.get_text_features(**inputs)
# L2 归一化
features = features / features.norm(p=2, dim=-1, keepdim=True)
return features.cpu().numpy().flatten().tolist()
def query_local_db(self, embedding: List[float], category: str, n_results: int = 3) -> List[Dict[str, Any]]:
"""
基于嵌入向量在本地数据库中查询相似单品。
实际应执行 ChromaDB 查询,并根据 category 进行过滤(metadatas)。
"""
print(f"--- Querying DB for Category: {category} ---")
# 实际应执行向量查询
# 为了演示流程,返回一个模拟结果
results = self.collection.query(
@@ -60,4 +59,4 @@ class VectorDatabase():
},
include=['documents', 'metadatas', 'distances']
)
return results
return results