tons of modification for occasion filtering

This commit is contained in:
pangkaicheng
2025-12-09 16:06:07 +08:00
parent ee695e7511
commit 0b1d948f77
35 changed files with 728 additions and 2186 deletions

View File

@@ -1,10 +1,13 @@
import asyncio
import logging
import uuid
from enum import Enum
from typing import List
from pydantic import Field
import litserve as ls
from pydantic import BaseModel
from app.core.config import settings
from app.config import settings
from app.server.ChatbotAgent.core.data_structure import Message, Role
from app.server.ChatbotAgent.core.llm_interface import AsyncGeminiLLM
from app.server.ChatbotAgent.core.redis_manager import RedisManager
@@ -15,11 +18,40 @@ from app.server.ChatbotAgent.core.vector_database import VectorDatabase
logger = logging.getLogger(__name__)
class OccasionEnum(str, Enum):
CASUAL = "Casual"
FORMAL = "Formal"
ACTIVEWEAR = "Activewear"
RESORT = "Resort"
EVENING = "Evening"
OUTDOOR = "Outdoor"
BUSINESS_WORKWEAR = "Business / workwear"
COCKTAIL_SEMI_FORMAL = "Cocktail / Semi-Formal"
BLACK_TIE_WHITE_TIE = "Black Tie / White Tie"
BRIDAL_WEDDING = "Bridal / Wedding"
FESTIVAL_CONCERT = "Festival / Concert"
PARTY_CLUBBING = "Party / Clubbing"
TRAVEL_TRANSIT = "Travel / Transit"
ATHLEISURE = "Athleisure"
BEACH_SWIM = "Beach / Swim"
SKI_SNOW_MOUNTAIN = "Ski / Snow / Mountain"
GARDEN_PARTY_DAYTIME = "Garden Party / Daytime Event"
class StylistResponse(BaseModel):
occasions: List[OccasionEnum] = Field(
description="A list of **applicable** occasions that are most strongly implied or explicitly requested by the user's conversation history. These occasions are used later in item retrieval for filtering and must strictly match the predefined OccasionEnum list."
)
summary: str = Field(
description="A detailed summary of the user's styling requirements, preferences, constraints, and specific item requests."
)
class AgentRequestModel(BaseModel):
user_id: str
session_id: str
num_outfits: int
stylist_path: str
batch_source: str
callback_url: str
gender: str
max_len: int = 9
@@ -41,7 +73,6 @@ class LCAgent(ls.LitAPI):
)
self.stylist_agent_kwages = {
'local_db': self.vector_db,
'max_len': 9,
'gemini_model_name': settings.LLM_MODEL_NAME
}
@@ -73,40 +104,68 @@ class LCAgent(ls.LitAPI):
async def background_run(self, request: AgentRequestModel, outfit_ids):
# 1. 根据用户ID查询对话历史总结对话内容
request_summary = await self.get_conversation_summary(request.session_id)
request_summary, occasions = await self.get_conversation_summary(request.session_id)
logger.info(f"request_summary: {request_summary}")
# 2.根据对话总结推荐搭配
recommendation_results = await self.recommend_outfit(request_summary=request_summary,
stylist_name=request.stylist_path,
start_outfit=[],
num_outfits=request.num_outfits,
user_id=request.user_id,
gender=request.gender,
callback_url=request.callback_url,
max_len=request.max_len,
outfit_ids=outfit_ids)
recommendation_results = await self.recommend_outfit(
request_summary=request_summary,
occasions=occasions,
stylist_name=request.stylist_path,
batch_source=request.batch_source,
start_outfit=[],
num_outfits=request.num_outfits,
user_id=request.user_id,
gender=request.gender,
callback_url=request.callback_url,
max_len=request.max_len,
outfit_ids=outfit_ids
)
logger.info("--- Final Recommendation Results ---")
for i, path in enumerate(recommendation_results.get("successful_outfits", [])):
logger.info(f"✅ Outfit {i + 1} saved to: {path}")
for failed in recommendation_results.get("failed_outfits", []):
logger.error(f"{failed}")
async def get_conversation_summary(self, session_id: str) -> str:
async def get_conversation_summary(self, session_id: str) -> dict:
"""
分析用户的完整会话历史,并打包成一个简洁的需求总结
这个总结可以直接作为输入 Prompt 传递给 Stylist Agent。`
分析用户的完整会话历史,并返回结构化的需求数据
Returns:
occasions: List[str], # 用于 Vector DB 筛选
summary: str # 用于 recommendation 的输入
"""
history_messages = self.redis.get_history(session_id)
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
# 临时调用 LLM 或使用本地逻辑生成总结
summary = await self.llm.generate_response(history=[Message(role=Role.USER, content=input_message)],
system_prompt=SUMMARY_PROMPT)
return summary
if not history_messages:
# 处理无历史记录的情况
return {"occasions": [], "summary": "User has no history provided."}
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit=None, num_outfits: int = 1,
user_id: str = "test", gender: str = "male", callback_url: str = None, max_len: int = 9, outfit_ids=None):
input_message = "\n".join([f"{msg.role.value}: {msg.content}" for msg in history_messages])
json_schema = StylistResponse.model_json_schema()
raw_response = await self.llm.generate_response(
history=[Message(role=Role.USER, content=input_message)],
system_prompt=SUMMARY_PROMPT,
json_schema=json_schema
)
try:
# 验证并解析 JSON
parsed_result = StylistResponse.model_validate_json(raw_response)
print(f"Occasions: {[occ.value for occ in parsed_result.occasions]}")
print(f"Summary: {parsed_result.summary}") # 这是一个 string
except Exception as e:
logger.error(f"Schema validation failed: {e}")
return str(parsed_result.summary), [occ.value for occ in parsed_result.occasions]
async def recommend_outfit(
self, request_summary: str, occasions: List[str], batch_source: str, stylist_name: str, start_outfit=[],
num_outfits: int = 1, user_id: str = "test", gender: str = "male",
callback_url: str = None, max_len: int = 9, outfit_ids=None
):
"""
基于用户的对话历史和需求,推荐一套搭配。
@@ -116,8 +175,6 @@ class LCAgent(ls.LitAPI):
"""
if outfit_ids is None:
outfit_ids = []
if start_outfit is None:
start_outfit = []
tasks = []
task_map = {}
@@ -128,7 +185,9 @@ class LCAgent(ls.LitAPI):
agent = AsyncStylistAgent(**stylist_agent_kwages)
task = agent.run_styling_process(
request_summary=request_summary,
stylist_path=stylist_name,
occasions=occasions,
batch_source=batch_source,
stylist_name=stylist_name,
start_outfit=start_outfit,
user_id=user_id,
callback_url=callback_url,
@@ -167,7 +226,9 @@ class LCAgent(ls.LitAPI):
agent = AsyncStylistAgent(**stylist_agent_kwages)
new_task = agent.run_styling_process(
request_summary=request_summary,
stylist_path=stylist_name,
occasions=occasions,
batch_source=batch_source,
stylist_name=stylist_name,
start_outfit=start_outfit,
user_id=user_id,
callback_url=callback_url
@@ -209,3 +270,48 @@ class LCAgent(ls.LitAPI):
"failed_outfits": failed_outfits,
"error": ""
}
if __name__ == "__main__":
async def test():
# 1. 准备测试实例
agent_api = LCAgent()
agent_api.setup(device='cpu')
# 2. 准备请求数据
import json
stylist_agent_kwages = agent_api.stylist_agent_kwages.copy()
with open("./data/2025_q4/request_test.json", "r") as f:
request_data = json.load(f)
tasks = []
for test_content in request_data[:30]:
occasions = test_content['occasions']
request_summary = test_content['request_summary']
stylist_agent_kwages['max_len'] = 5
for stylist_name in ["edi", "vera"]:
stylist_agent_kwages['outfit_id'] = test_content['test_case_id'] + "_" + "_".join(occasions) + f"_{stylist_name}"
agent = AsyncStylistAgent(**stylist_agent_kwages)
task = agent.run_styling_process(
request_summary=request_summary,
occasions=occasions,
batch_source="2025_q4",
stylist_name=stylist_name,
start_outfit=[],
user_id=test_content['test_case_id'],
callback_url="http://mock-callback.com/result",
gender="female",
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
print(f"❌ 任务失败: {type(result).__name__} - {str(result)}")
continue
try:
# 使用 asyncio.run() 来执行顶层异步函数
asyncio.run(test())
except Exception as e:
logger.error(f"Test failed due to an unexpected error: {e}")

View File

@@ -4,13 +4,13 @@ import litserve as ls
from typing import AsyncGenerator
from google import genai
from pydantic import BaseModel
from app.core.config import settings
from app.config import settings
from google.genai import types
from app.server.ChatbotAgent.core.data_structure import Message, Role
from app.server.ChatbotAgent.core.llm_interface import AsyncGeminiLLM
from app.server.ChatbotAgent.core.redis_manager import RedisManager
from app.server.ChatbotAgent.core.system_prompt import MEN_BASIC_PROMPT, WOMEN_BASIC_PROMPT
from app.server.ChatbotAgent.core.system_prompt import BASIC_PROMPT
from app.server.ChatbotAgent.core.vector_database import VectorDatabase
logger = logging.getLogger(__name__)
@@ -25,26 +25,12 @@ class PredictRequest(BaseModel):
class LCChatBot(ls.LitAPI):
def setup(self, device):
# self.llm = AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME)
self.redis = RedisManager(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
key_prefix=settings.REDIS_HISTORY_KEY_PREFIX
)
# self.vector_db = VectorDatabase(
# vector_db_dir=settings.VECTOR_DB_DIR,
# collection_name=settings.COLLECTION_NAME,
# embedding_model_name=settings.EMBEDDING_MODEL_NAME
# )
# self.stylist_agent_kwages = {
# 'local_db': self.vector_db,
# 'max_len': 5,
# 'outfits_root': settings.OUTFIT_OUTPUT_DIR,
# 'image_dir': settings.IMAGE_DIR,
# 'stylist_guide_dir': settings.STYLIST_GUIDE_DIR,
# 'gemini_model_name': settings.LLM_MODEL_NAME
# }
self.gemini_client = genai.Client(
vertexai=True, project='aida-461108', location='us-central1'
)
@@ -62,9 +48,9 @@ class LCChatBot(ls.LitAPI):
chat_history = self.redis.get_history(session_id)
chat_history.append(user_msg)
if request.gender == 'male':
BASIC_PROMPT = MEN_BASIC_PROMPT
prompt = BASIC_PROMPT.format(gender='men')
else:
BASIC_PROMPT = WOMEN_BASIC_PROMPT
prompt = BASIC_PROMPT.format(gender='women')
contents = []
@@ -80,7 +66,7 @@ class LCChatBot(ls.LitAPI):
model='gemini-2.5-flash',
contents=contents,
config=types.GenerateContentConfig(
system_instruction=BASIC_PROMPT,
system_instruction=prompt,
# temperature=0.3,
)
)
@@ -108,3 +94,45 @@ class LCChatBot(ls.LitAPI):
# The for-loop must have async keyword here since output is an AsyncGenerator
async for out in output:
yield {"output": out}
if __name__ == "__main__":
import asyncio
async def run_simple_test():
"""
一个简单的异步测试用例,用于测试 LCChatBot 的流式输出。
"""
print("\n" + "=" * 50)
print("--- 🔬 开始 LCChatBot 简单流式测试 ---")
# 1. 初始化 LitAPI 和其依赖
chatbot_api = LCChatBot()
chatbot_api.setup(device="cpu")
print("✅ Setup complete. Mock services initialized.")
# 2. 构造请求数据
request_data = PredictRequest(
user_id="simple_user",
session_id="simple_session",
user_message="I want an outfit. I am going to a evening party with friends. Suggest something stylish yet comfortable.",
gender="female"
)
chatbot_api.redis.clear_history(request_data.session_id)
print(f"-> 正在发送查询: {request_data.user_message}")
# 3. 调用 predict 方法并处理流
response_generator = chatbot_api.predict(request_data)
print("\n<- 接收流式响应:")
# 4. 异步迭代生成器,实时打印输出
async for chunk in response_generator:
print(chunk, end="", flush=True)
print("\n" + "=" * 50)
# 启动异步事件循环
try:
asyncio.run(run_simple_test())
except Exception as e:
print(f"\n发生致命错误: {e}")

View File

@@ -32,7 +32,7 @@ class AsyncGeminiLLM(AsyncLLMInterface):
except Exception as e:
raise type(e)(f"Failed to initialize Gemini Client. Check if GEMINI_API_KEY is set. Original error: {e}")
async def generate_response(self, history: List[Message], system_prompt: str) -> str:
async def generate_response(self, history: List[Message], system_prompt: str, json_schema=None) -> str:
contents = []
for msg in history:
@@ -44,14 +44,27 @@ class AsyncGeminiLLM(AsyncLLMInterface):
contents.append(content)
try:
response = await self.gemini_client.aio.models.generate_content(
model=self.model_name,
contents=contents,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
# temperature=0.3,
if json_schema:
response = await self.gemini_client.aio.models.generate_content(
model=self.model_name,
contents=contents,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
response_mime_type="application/json",
response_schema=json_schema
)
)
)
return response.text
return response.text
else:
response = await self.gemini_client.aio.models.generate_content(
model=self.model_name,
contents=contents,
config=types.GenerateContentConfig(
system_instruction=system_prompt,
# temperature=0.3,
)
)
return response.text
except Exception as e:
raise type(e)(f"Gemini API call failed: {e}")

View File

@@ -6,32 +6,22 @@ import os
import random
import uuid
from typing import List, Dict, Any, Optional
from copy import deepcopy
from google import genai
from google.cloud import storage
from google.oauth2 import service_account
from app.core.utils_litserve import merge_images_to_square
from app.server.utils.img_operation import merge_images_to_square
from app.server.utils.minio_client import minio_client, oss_upload_image
from app.server.utils.request_post import post_request
from app.config import settings
from app.taxonomy import CLOTHING_CATEGORY, ACCESSORY_CATEGORY
logger = logging.getLogger(__name__)
class AsyncStylistAgent:
CATEGORY_SET = {
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Shoes',
# 取消推荐配饰
# 'Swimwear', 'Underwear',
# , 'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Scarves & Shawls'
}
CATEGORY_SET_ALL = {
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear',
'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry',
'Briefcases', 'Neckties', 'Shoes', 'Scarves & Shawls',
# 'Socks',
}
def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
# self.outfit_items: List[Dict[str, str]] = []
self.outfit_id = outfit_id
@@ -42,6 +32,56 @@ class AsyncStylistAgent:
self.max_len = max_len
self.gemini_model_name = gemini_model_name
self.stop_reason = ""
self.headers = {
'Accept': "*/*",
'Accept-Encoding': "gzip, deflate, br",
'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0",
'Connection': "keep-alive",
'Content-Type': "application/json"
}
self.main_clothing_schema = {
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["recommend_item", "stop"]},
"category": {
"type": "string",
"description": "The category of the single clothing item being recommended in this step (e.g., 'outerwear', 'bottoms'). Only present if action is 'recommend_item'.",
"enum": CLOTHING_CATEGORY
},
"description": {
"type": "string",
"description": "an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database. It should include Color, Fit/Silhouette, Material/Detail, Role in the Outfit."
},
"reason": {"type": "string", "description": "The reason for the current action. Required if action is 'stop' (to summarize the final outfit)."}
},
"required": ["action"]
}
self.accessory_schema = {
"type": "object",
"properties": {
"reason": {
"type": "string",
"description": "The justification for completing the recommendation and the summary of the final outfit."
},
"recommended_accessories": {
"type": "array",
"description": "A list of accessories recommended to complete the outfit.",
"items": {
"type": "object",
"properties": {
"category": {
"type": "string",
"description": "The category of the accessory (e.g., jewelry, watches, bags).",
"enum": ACCESSORY_CATEGORY
},
"description": {"type": "string", "description": "The detailed description for this accessory item."}
},
"required": ["category", "description"]
}
}
},
"required": ["recommended_accessories", "reason"]
}
# 存储桶配置
try:
@@ -57,67 +97,42 @@ class AsyncStylistAgent:
self.gcs_bucket = "lc_stylist_agent_outfit_items"
self.minio_bucket = "lanecarford"
def _load_style_guide(self, path: str):
def _load_style_guide(self, stylist_name: str):
"""加载 markdown 风格指南内容。"""
parts = path.split('/', 1)
if len(parts) != 2:
raise ValueError("MinIO path must be in 'bucket_name/object_name' format.")
bucket_name, object_name = parts
guide_path = os.path.join(settings.STYLIST_GUIDE_DIR, f"{stylist_name}_en.md")
acc_guide_path = os.path.join(settings.STYLIST_GUIDE_DIR, f"{stylist_name}_acc.md")
try:
# 获取对象 读取内容
response = minio_client.get_object(bucket_name, object_name)
content_bytes = response.read()
json_response = minio_client.get_object(bucket_name, object_name.replace('.md', '.json'))
json_data = json_response.data
# 关闭连接
response.close()
json_response.close()
response.release_conn()
json_response.release_conn()
# 4. 解析 JSON 字符串
json_string = json_data.decode('utf-8')
json_content = json.loads(json_string)
return content_bytes.decode('utf-8'), json_content
with open(guide_path, 'r', encoding='utf-8') as file:
stylist_guide = file.read()
with open(acc_guide_path, 'r', encoding='utf-8') as file:
accessories_guide = file.read()
return stylist_guide, accessories_guide
except Exception as e:
raise Exception(f"Failed to load style guide from {path}: {e}")
raise Exception(f"Failed to load style guide from {guide_path}, {acc_guide_path}: {e}")
def _build_system_prompt(self, request_summary: str = "", gender: str = "male") -> str:
def _build_main_clothing_prompt(self, request_summary: str = "", gender: str = "male", stylist_guide: str = "") -> str:
"""Constructs the complete System Prompt."""
clothing_gender = "women's clothing"
if gender == "male":
clothing_gender = "men's clothing"
elif gender == "female":
clothing_gender = "women's clothing"
clothing_gender = "men's clothing" if gender == "male" else "women's clothing"
# Insert the style_guide content into the template
template = template = f"""
You are a professional fashion stylist Agent, specialized in creating complete, tailored outfits exclusively for {clothing_gender}.
You are a professional fashion stylist Agent, specialized in creating complete, tailored outfits for {clothing_gender}. Only main clothing including 'bags' is needed, excluding accessories like 'jewelry', 'hats', 'belts', etc.
Your task is to **create a cohesive and complete outfit**, strictly adhering to **BOTH** the user's explicit **Request Summary** and the **Outfit Style Guide**. You must decide the next logical item to add to the outfit based on the currently selected items (if any).
---
## Request from the User:
{request_summary}
## Core Guidance Document: Outfit Style Guide
{self.style_guide}
{stylist_guide}
---
## Your Workflow and Constraints
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions , shoe coordination**.
2. **Category Uniqueness Mandate**: Every outfit must follow the **absolute no-repeat rule for clothing categories** — each category from the allowed list ({list(self.CATEGORY_SET)}) can appear **exactly once** in the entire outfit. This rule is non-negotiable, even if the user explicitly requests repeating a category.
2. **Category Uniqueness Mandate**: Every outfit must follow the **absolute no-repeat rule for clothing categories** — each category from the allowed list can appear **exactly once** in the entire outfit. This rule is non-negotiable, even if the user explicitly requests repeating a category. Furthermore, the categories 'dresses' and 'pants' and 'skirts' are mutually exclusive; they NORMALLY cannot be included in the same outfit.
3. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses) then shoes. When selecting the next item, prioritize unused categories from the allowed list to avoid repetition.
4. **Structured Output**: Every response must recommend the **next single item** (from an unused category). You must strictly use the **JSON format** for your output, as follows:
@@ -130,7 +145,7 @@ class AsyncStylistAgent:
```
* `action`: Must always be `"recommend_item"` until the outfit is complete.
* `category`: Must be an unused category from the following list: {list(self.CATEGORY_SET)} (strictly no repeats, per the Category Uniqueness Mandate).
* `category`: Must be an unused category from the following list: {CLOTHING_CATEGORY} (strictly no repeats, per the Category Uniqueness Mandate).
* `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include:
* **Color** (e.g., milk tea, pure white, dark gray)
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
@@ -147,112 +162,125 @@ class AsyncStylistAgent:
"reason": "OUTFIT_COMPLETE_AND_MEETS_ALL_MINI_GUIDELINES"
}}
```
Normally, five or six items are totally enough for an outfit.
Normally, {self.max_len} items are totally enough for an outfit.
6. **Context Dependency**: The user's next input (if not Start) will contain the **image and description of the selected item**. When recommending the next item:
a) First verify the categories of all already selected items to ensure no duplicates;
b) Select an unused category from the allowed list ({list(self.CATEGORY_SET)}) as the priority;
b) Select an unused category from the allowed list as the priority;
c) Ensure the recommended item coordinates with the already selected items and complies with all rules in the Style Guide.
Now, please start building an outfit (with strictly unique categories for all items) and output the JSON for the first item.
"""
return template.strip()
def _clear_uploaded_files(self):
for f in self.gemini_client.files.list():
self.gemini_client.files.delete(name=f.name)
async def _call_gemini(self, user_input: str, user_id: str):
def _build_accessory_prompt(self, request_summary: str, gender: str, accessories_guide: str) -> str:
"""
实际调用 Gemini API 的函数,接受文本和可选的图片路径列表
构建配饰推荐 (Accessories) 的 System Prompt
特点:强调基于现有穿搭 (Context Aware),批量推荐 (Batch Recommendation),做最后的点缀。
"""
clothing_gender = "men's clothing" if gender == "male" else "women's clothing"
template = f"""
You are an expert Accessories Stylist for {clothing_gender}.
Your task is to select the perfect set of accessories to complete an existing outfit.
---
## CONTEXT
[User Request]: {request_summary}
[Accessories Style Guide]:
{accessories_guide}
---
## STRICT RULES
1. **Batch Recommendation**: Do NOT recommend items one by one. You must output the **COMPLETE LIST** of accessories (e.g., jewelry, bag, watch, hat) in a single response using the 'recommended_accessories' list.
2. **Allowed Categories**: Select only from: {ACCESSORY_CATEGORY}.
3. **Harmony & Constraints**:
- The accessories must complement the [Current Outfit Base].
- Strictly follow the [Accessories Style Guide] regarding metals (gold/silver), numbers, and prohibited items.
- If the guide mandates a watch or specific jewelry layering, ensure they are included.
4. **Quantity**: Typically recommend 2-4 distinct accessory items to complete the look.
Generate the final accessories list now.
"""
return template.strip()
async def _call_gemini(self, user_input: str, user_id: str, file_name: str, output_schema: Dict[str, Any], image_bytes: bytes = None, system_prompt: str = "") -> str:
"""
实际调用 Gemini API 的函数接受文本和用户的id。
会在这个函数中merge图片然后上传到google cloud供gemini参考。
Args:
user_input: 发送给模型的主文本内容。
image_paths: 待发送图片的本地路径列表
user_id: 用户id
file_name: 用于存储图片的文件名。
image_bytes: 可选的图片字节数据。
Returns:
模型的响应文本(预期为 JSON 字符串)。
"""
minio_path = ""
content_parts = []
# self._clear_uploaded_files()
# 1. 添加图片内容
if self.outfit_items:
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len + 1, add_text=False)
image_bytes_io = io.BytesIO()
image_format = 'JPEG'
mime_type = 'image/jpeg'
merged_image.save(image_bytes_io, format=image_format)
image_bytes = image_bytes_io.getvalue()
file_name = uuid.uuid4()
if image_bytes:
blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg"
gcs_path = self._upload_to_gcs(bucket_name=self.gcs_bucket, blob_name=blob_name, mime_type=mime_type, image_bytes=image_bytes)
responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes)
minio_path = f"{responses.bucket_name}/{responses.object_name}"
gcs_path = self._upload_to_gcs(bucket_name=self.gcs_bucket, blob_name=blob_name, mime_type='image/jpeg', image_bytes=image_bytes)
content_parts.append(gcs_path)
# 2. 添加文本内容
content_parts.append(user_input)
# print(f"\n--- Calling Gemini with {len(self.outfit_items) if self.outfit_items else 0} images and query:\n{user_input}")
try:
# 3. 实际 API 调用
response = await self.gemini_client.aio.models.generate_content(
model=self.gemini_model_name,
contents=content_parts,
config={
"system_instruction": self.system_prompt,
"system_instruction": system_prompt,
# 确保模型返回 JSON 格式
"response_mime_type": "application/json",
"response_schema": {
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["recommend_item", "stop"]},
"category": {"type": "string"},
"description": {"type": "string"},
"reason": {"type": "string"}
},
"required": ["action"]
}
"response_schema": output_schema
}
)
# response.text 将包含一个 JSON 字符串
return response.text, minio_path
return response.text
except Exception as e:
print(f"Gemini API Call failed: {e}")
# 返回一个停止信号以防止循环继续
return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"})
async def _merge_images(self, user_id: str):
async def _merge_images(self, file_name: str, user_id: str, stylist_name: str):
"""
实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。
把所有的item图片组成一张图片并保存到jpg文件
Args:
user_input: 发送给模型的主文本内容。
image_paths: 待发送图片的本地路径列表。
user_id: 用户的id
stylist_name: 造型师的name
Returns:
模型的响应文本(预期为 JSON 字符串)。
"""
minio_path = ""
if self.outfit_items:
merged_image = merge_images_to_square(self.outfit_items, max_len=9, add_text=False)
image_bytes_io = io.BytesIO()
image_format = 'JPEG'
(存储的路径, 内存图片数据)
"""
if not self.outfit_items:
return "", None
merged_image.save(image_bytes_io, format=image_format)
image_bytes = image_bytes_io.getvalue()
file_name = uuid.uuid4()
merged_image = merge_images_to_square(self.outfit_items, max_len=9, add_text=False)
image_bytes_io = io.BytesIO()
image_format = 'JPEG'
merged_image.save(image_bytes_io, format=image_format)
image_bytes = image_bytes_io.getvalue()
if settings.LOCAL == 1:
local_dir = os.path.join(settings.OUTFIT_OUTPUT_DIR, stylist_name)
os.makedirs(local_dir, exist_ok=True)
local_file_path = os.path.join(local_dir, f"{file_name}.jpg")
with open(local_file_path, 'wb') as f:
f.write(image_bytes)
return local_file_path, image_bytes
else:
blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg"
responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes)
minio_path = f"{responses.bucket_name}/{responses.object_name}"
return minio_path
return minio_path, image_bytes
def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
"""安全解析 Gemini 的 JSON 响应。"""
@@ -267,7 +295,7 @@ class AsyncStylistAgent:
print(f"Raw response: {response_text}")
return None
def _get_next_item(self, item_description: str, category: str) -> Optional[Dict[str, str]]:
def _get_next_item(self, item_description: str, category: str, occasions: List[str], batch_source: str = "2025_q4", gender: str = "female") -> Optional[Dict[str, str]]:
"""
1. 根据描述生成嵌入。
2. 查询本地数据库以找到最佳匹配项。
@@ -278,92 +306,30 @@ class AsyncStylistAgent:
query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)
# 2. 执行查询,并过滤类别
results = self.local_db.query_local_db(query_embedding, category, n_results=1)
results = self.local_db.get_matched_item(query_embedding, category, occasions=occasions, batch_source=batch_source, gender=gender, n_results=1)
if not results:
print(f"❌ 数据库中未找到符合 '{category}' 和描述的单品。")
return None
# 3. 模拟 Agent 审核(实际应用中,你需要将图片发回给 Agent进行审核)
best_meta = results['metadatas'][0][0] # 第一个 batch 的第一个 metadata
best_meta = results[0] # 第一个 batch 的第一个 metadata
item_id = best_meta['item_id'].replace("_img", "")
return {
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
"category": category,
"item_id": item_id, # 从 metadata 字典中安全获取
"category": best_meta['category'],
"gpt_description": item_description,
'description': best_meta['description'],
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
# 这里假设 item_id 就是文件名的一部分
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
"image_path": os.path.join(f"{item_id}.jpg")
}
except Exception as e:
print(f"An error occurred during item retrieval: {e}")
return None
async def _get_random_accessories(self, stylist, item_count):
stylist_item = []
stylist_item_ids = []
# 初始过滤类别
filter_items = [
{"item_group_id": {"$ne": "Clothing"}},
{"item_group_id": {"$ne": "Shoes"}},
{"category": {"$ne": "Socks"}},
{"modality": "image"}
]
random_items = []
for i in stylist:
# 1. 根据stylist要求抽取item
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
filter_items.append({"category": {"$ne": i["category"]}})
accessories_count = 9 - item_count - len(stylist_item)
if accessories_count > 0:
if accessories_count > 4:
accessories_count = 4
for i in range(accessories_count):
# 2. 在配饰池中过滤掉已经选中的item 然后抽两个item
random_poll = self.local_db.load_filtered_ids(filter_items)
logger.info(f"random_poll 数量: {len(random_poll)}")
item = self.local_db.random_get_accessories(random.choice(random_poll))
# 如果随机选中了包类 则所有包类别都过滤掉
if item['metadatas'][0]['category'] in ['Shopping Totes', 'Handbags', 'Backpacks', 'Briefcases']:
filter_items.append({"category": {"$ne": "Shopping Totes"}})
filter_items.append({"category": {"$ne": "Handbags"}})
filter_items.append({"category": {"$ne": "Backpacks"}})
filter_items.append({"category": {"$ne": "Briefcases"}})
else:
filter_items.append({"category": {"$ne": item['metadatas'][0]['category']}})
random_items.append(item['metadatas'][0])
all_items = stylist_item + random_items
else:
all_items = stylist_item
items_data = []
for best_meta in all_items:
items_data.append({
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
"category": best_meta['category'],
"gpt_description": best_meta['description'],
'description': best_meta['description'],
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
# 这里假设 item_id 就是文件名的一部分
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
})
return items_data
def _build_user_input(self) -> str:
def _build_user_input(self, recommend_acc=False) -> str:
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
if not self.outfit_items:
return "Start"
@@ -372,164 +338,145 @@ class AsyncStylistAgent:
context = "Selected fashion items:\n"
for ii, item in enumerate(self.outfit_items):
context += f"{ii + 1}. Category: {item['category']}. Description: {item['description']}\n"
context += "\nPlease recommend the next single item based on the selected items, user's request, and style guide."
if not recommend_acc:
context += "\nPlease recommend the next single item based on the selected items, user's request, and style guide."
else:
context += "\nPlease recommend a complete list of accessories to complement the selected outfit based on the user's request and accessories style guide."
return context
def post_operation(self, response_data: Dict[str, Any], status: str, message: str, callback_url: str):
"""处理完成后的回调操作。"""
if settings.LOCAL == 0:
response_data['items'] = deepcopy(self.outfit_items)
response_data['status'] = status
response_data['message'] = message
response = post_request(url=callback_url, data=json.dumps(response_data), headers=self.headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
async def run_styling_process(self, request_summary, stylist_path, start_outfit=None, user_id="test", callback_url="", gender: str = "male"):
if start_outfit is None:
start_outfit = []
self.outfit_items = start_outfit if start_outfit else []
async def run_styling_process(self, request_summary, occasions, stylist_name, batch_source="2025_q4", start_outfit=[], user_id="test", callback_url="", gender: str = "male"):
self.outfit_items = start_outfit
"""主流程控制循环。"""
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
self.style_guide, self.style_accessories_guide = self._load_style_guide(stylist_path)
self.system_prompt = self._build_system_prompt(request_summary, gender)
response_data = {"status": "",
"message": "",
"path": "",
"outfit_id": self.outfit_id,
"items": []
}
logger.info(response_data)
item_id = ""
item_category = ""
headers = {
'Accept': "*/*",
'Accept-Encoding': "gzip, deflate, br",
'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0",
'Connection': "keep-alive",
'Content-Type': "application/json"
stylist_guide, accessories_guide = self._load_style_guide(stylist_name)
system_prompt = self._build_main_clothing_prompt(request_summary, gender, stylist_guide)
response_data = {
"status": "",
"message": "",
"path": "",
"outfit_id": self.outfit_id,
"items": []
}
logger.info(response_data)
url = f'{callback_url}/api/style/callback'
while True:
file_name = self.outfit_id
recommend_timestep = 0
gemini_data = {'action': 'start'}
while recommend_timestep < self.max_len and gemini_data.get('action') != 'stop':
recommend_timestep += 1
# 1. 准备用户输入(上下文)
user_input = self._build_user_input()
# 2. 调用 Gemini Agent
gemini_response_text, minio_path = await self._call_gemini(user_input, user_id)
# 2. 把图片组装起来供api调用
response_data['path'], image_bytes = await self._merge_images(file_name, user_id, stylist_name)
# 3. 调用 Gemini Agent
gemini_response_text = await self._call_gemini(user_input, user_id, file_name, self.main_clothing_schema, image_bytes, system_prompt)
gemini_data = self._parse_gemini_response(gemini_response_text)
response_data['path'] = minio_path
if item_id:
response_data['items'].append({"item_id": item_id, "category": item_category})
if not gemini_data:
# if gemini_data:
print("🚨 Agent 返回无效响应,终止流程。")
self.stop_reason = "Agent failed to return response"
response_data['status'] = "failed"
response_data['message'] = self.stop_reason
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
print("Agent 返回无效响应,终止流程。")
self.post_operation(
response_data,
status="failed",
message="Agent returned invalid response, terminating process.",
callback_url=url
)
break
# 3. 检查终止条件
if gemini_data.get('action') == 'stop':
if is_duplicate_by_key(response_data['items'], {"item_id": item_id, "category": item_category}):
print("重复按item_id判断不插入")
else:
response_data['path'] = minio_path
response_data['items'].append({"item_id": item_id, "category": item_category})
response_data['status'] = "ok"
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
# 根据stylist要求随机增加配饰 3-4个配饰
new_item = await self._get_random_accessories(self.style_accessories_guide, len(self.outfit_items))
for item in new_item:
self.outfit_items.append(item)
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
response_data['path'] = await self._merge_images(user_id)
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
break
# 4. 处理推荐单品
# 处理推荐单品
if gemini_data.get('action') == 'recommend_item':
category = gemini_data.get('category')
description = gemini_data.get('description')
# 4a. 检查类别是否有效 (重要步骤)
if category not in self.CATEGORY_SET_ALL:
print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。")
# 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正
# 这里简化为跳过本次循环
response_data['status'] = "continue"
response_data['message'] = f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。",
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
if category not in CLOTHING_CATEGORY:
self.post_operation(
response_data,
status="continue",
message=f"Invalid category recommended by Agent: {category}. Requesting Agent to re-output.",
callback_url=url
)
continue
# 4b. 在本地 DB 中查询单品
new_item = self._get_next_item(description, category)
item_id = new_item.get('item_id')
item_category = new_item.get('category')
if new_item:
# 4c. (实际步骤) 将选中的单品图片和描述发回给 Agent 进行最终审核
# 这里的代码框架省略了图片回传和二次审核的步骤,直接视为通过
# 实际你需要: new_user_input = f"Check this item: {new_item['description']}, path: {new_item['image_path']}"
# call_gemini_agent(...) -> 如果返回"pass",则添加到outfit_items
if new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
print("This item exists. Stop here.")
self.stop_reason = "Finish reason: Duplicate item selected."
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
break
if new_item['item_id'] == "ELG383":
if random.random() < 0.70:
self.stop_reason = "Finish reason: ELG383 is seleced repeatly."
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
break
self.outfit_items.append(new_item)
# print(f" 成功添加单品: {new_item['category']} ({new_item['item_id']}). 当前搭配数量: {len(self.outfit_items)}")
response_data['status'] = "ok"
response_data['message'] = self.stop_reason
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
new_item = self._get_next_item(description, category, occasions, batch_source, gender)
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
self.post_operation(
response_data,
status="continue",
message=f"No matching item is found or item duplicated. Ask Gemini to re-output.",
callback_url=url
)
continue
else:
print("⚠️ 未找到匹配单品,无法继续搭配。终止。")
self.stop_reason = "Finish reason: No matching item found in local database."
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
break
self.outfit_items.append(new_item)
self.post_operation(
response_data,
status="ok",
message=f"Add new item {new_item['item_id']} in category {new_item['category']} successfully.",
callback_url=url
)
print(f"Step {recommend_timestep}: {gemini_data}, found item: {new_item}")
if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
response_data['items'].append({"item_id": self.outfit_items[-1]['item_id'], "category": self.outfit_items[-1]['category']})
response_data['status'] = "ok"
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
# When action is stop or timestep limit reached
logger.info(f"Main clothing stylist process finished: {gemini_data.get('reason')}")
# 根据stylist要求随机增加配饰 3-4个配饰
response_data['path'], image_bytes = await self._merge_images(file_name, user_id, stylist_name)
accessory_system_prompt = self._build_accessory_prompt(request_summary, gender, accessories_guide)
user_input = self._build_user_input(recommend_acc=True)
gemini_response_text = await self._call_gemini(user_input, user_id, file_name, self.accessory_schema, image_bytes, accessory_system_prompt)
gemini_data = self._parse_gemini_response(gemini_response_text)
# 根据stylist要求随机增加配饰 3-4个配饰
new_item = await self._get_random_accessories(self.style_accessories_guide, len(self.outfit_items))
for item in new_item:
self.outfit_items.append(item)
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
response_data['path'] = await self._merge_images(user_id)
recommended_accessories = gemini_data.get('recommended_accessories', [])
reason = gemini_data.get('reason', '')
if not recommended_accessories or not isinstance(recommended_accessories, List):
print("No accessory data from Gemini, terminating process.")
self.post_operation(
response_data,
status="failed",
message="Agent returned invalid response, terminating process.",
callback_url=url
)
else:
for idx, rec_accessory in enumerate(recommended_accessories):
category = rec_accessory.get('category')
description = rec_accessory.get('description')
logger.info("🚨 达到最大搭配数量限制,强制终止。")
self.stop_reason = "Finish reason: Reached max outfit length."
response_data['status'] = "stop"
response_data['message'] = self.stop_reason
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
break
# 4a. 检查类别是否有效 (重要步骤)
if category not in ACCESSORY_CATEGORY:
continue
# 4b. 在本地 DB 中查询单品
new_item = self._get_next_item(description, category, occasions, batch_source, gender)
if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
continue
else:
self.outfit_items.append(new_item)
print(f"Accessory {idx + 1}: {rec_accessory}, found item: {new_item}")
response_data['path'] = await self._merge_images(file_name, user_id, stylist_name)
self.post_operation(
response_data,
status="stop",
message=reason,
callback_url=url
)
with open(os.path.join(settings.OUTFIT_OUTPUT_DIR, stylist_name, f'{file_name}.json'), 'w') as f:
json.dump(self.outfit_items, f, indent=2)
return response_data
def _upload_to_gcs(self, bucket_name: str, blob_name: str, mime_type, image_bytes) -> str:
@@ -543,11 +490,3 @@ class AsyncStylistAgent:
gcs_uri = f"gs://{bucket_name}/{blob_name}"
return gcs_uri
def is_duplicate_by_key(data, target_item):
"""基于item_id快速判断重复"""
# 提取所有item_id到集合
existing_ids = {item['item_id'] for item in data}
# 判断目标item_id是否在集合中
return target_item['item_id'] in existing_ids

View File

@@ -1,27 +1,4 @@
BASIC_PROMPT = """"""
WOMEN_BASIC_PROMPT = """You are a professional, friendly, and insightful AI women's styling assistant.
Your primary mission is to engage in a multi-turn conversation with the user to fully understand their dressing intent. You must adopt a professional yet approachable tone.
CONVERSATION GOALS:
1. **Occasion:** Determine the specific event (e.g., romantic dinner, summer wedding, business meeting).
2. **Style:** Pinpoint the desired aesthetic (e.g., classic elegance, edgy, minimalist, bohemian).
3. **Vibe/Details:** Gather any mood or specific constraints (e.g., needs to be comfortable, requires light colors, no bare shoulders).
4. **Item Preference:** Ask the user if they have any specific preferences for an item type or silhouette (e.g., preference for a dress, skirt, tailored pants, or a particular neckline/length).
GUIDANCE FOR RESPONSE GENERATION:
- After the user's initial request (e.g., "I want a chic outfit for dinner."), immediately reply with a friendly, targeted follow-up question to elicit the most crucial missing information (usually a combination of **Occasion** and **Style**).
- Be concise. Ask only 1 to 2 essential questions per turn.
- You must gather sufficient, clear intent before proceeding to actual clothing recommendations.
OUTPUT FORMAT INSTRUCTION:
- **DO NOT** use any Markdown formatting whatsoever (e.g., do not use asterisks (*), bold text (**), lists, or code blocks).
- **ONLY** output the plain text response spoken by the AI Assistant.
Example Follow-up (mimicking a conversational flow):
User: I want a chic outfit for dinner.
Your Response: Hey there! A chic dinner outfit, I love that! To give you the perfect recommendations, tell me: is this a romantic date, business dinner, or celebration with friends? And what's your go-to style vibe: classic elegance or something with more edge?"""
MEN_BASIC_PROMPT = """You are a professional, friendly, and insightful AI men's styling assistant.
BASIC_PROMPT = """You are a professional, friendly, and insightful AI {gender}'s styling assistant.
Your primary mission is to engage in a multi-turn conversation with the user to fully understand their dressing intent. You must adopt a professional yet approachable tone.
@@ -44,13 +21,12 @@ Example Follow-up (mimicking a conversational flow):
User: I want a chic outfit for dinner.
Your Response: Hey there! A chic dinner outfit, I love that! To give you the perfect recommendations, tell me: is this a romantic date, business dinner, or celebration with friends? And what's your go-to style vibe: classic elegance or something with more edge?"""
SUMMARY_PROMPT = """Analyze the following chat history. Your task is to extract all user intentions, scenarios, style preferences, and constraints expressed during the conversation, and distill them into a concise, structured JSON object.
SUMMARY_PROMPT = """
You are an expert fashion request analyzer. Analyze the conversation history provided by the user.
Your task is to:
**YOUR OUTPUT MUST BE A JSON OBJECT ONLY, WITH NO SURROUNDING TEXT, MARKDOWN, OR EXPLANATION.**
1. Identify the most appropriate occasions from the allowed list based on the user's intent.
2. Write a detailed summary string that captures the user's style preferences, specific item requests, disliked items, body concerns, and color preferences. This summary will be used by a stylist to recommend outfits.
JSON FIELD REQUIREMENTS:
- **occasion (string):** The specific event and purpose (e.g., "Romantic date dinner", "Summer outdoor wedding", "Casual Friday at office").
- **style (string):** The overall aesthetic description (e.g., "Classic elegance", "Modern minimalist", "Bohemian vibe", "Edgy and contemporary").
- **color_preference (string or list):** User's preferred or excluded colors/tones (e.g., "Light colors only", "Avoid deep shades", "['Cream', 'Pale Blue']", "No preference").
- **clothing_type (string):** User's preference for specific garment types, material, or silhouette (e.g., "Lightweight maxi dress", "Skirt with silk blouse", "Tailored wide-leg pants", "Floral print").
- **vibe_or_details (string):** Any other details, mood requirements, or specific constraints (e.g., "Needs to be comfortable and breathable", "Must cover shoulders")."""
Extract this information accurately from the chat history.
"""

View File

@@ -1,163 +0,0 @@
import logging
from typing import List, Dict
from PIL import Image, ImageDraw, ImageFont
from app.server.utils.minio_client import oss_get_image, minio_client
from app.server.utils.minio_config import MINIO_LC_DATA_PATH
logger = logging.getLogger(__name__)
# 9个 341x341 左右的单元格 (ALL_9_CELLS)
# 布局顺序: 从上到下,从左到右 (1 -> 9)
ALL_9_CELLS = [
# Top Row (Y=0, H=341)
(0, 0, 341, 341), # 1. Top-Left (341x341)
(341, 0, 341, 341), # 2. Top-Middle (341x341)
(682, 0, 342, 341), # 3. Top-Right (342x341)
# Middle Row (Y=341, H=341)
(0, 341, 341, 341), # 4. Mid-Left (341x341)
(341, 341, 341, 341), # 5. Center (341x341)
(682, 341, 342, 341), # 6. Mid-Right (342x341)
# Bottom Row (Y=682, H=342)
(0, 682, 341, 342), # 7. Bottom-Left (341x342)
(341, 682, 341, 342), # 8. Bottom-Middle (341x342)
(682, 682, 342, 342) # 9. Bottom-Right (342x342)
]
def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, add_text=True):
"""
Loads up to 4 images from the given paths, resizes them while maintaining
aspect ratio, and merges them onto a 1024x1024 white background JPG.
The layout depends on the number of images:
1: Center the single image on the 1024x1024 canvas.
2: Place side-by-side, each scaled to fit a 512x1024 half.
3: Place in top-left (512x512), top-right (512x512), and bottom-left (512x512).
4: Place in all four 512x512 quadrants.
Args:
outfit_items: A list of item metadata (max length 9).
Returns:
The file path of the temporary merged JPG image.
"""
# Define the final canvas size
CANVAS_SIZE = 1024
# 1. Create the final white canvas
# Using 'RGB' mode for JPG output
canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white')
draw = ImageDraw.Draw(canvas)
font = ImageFont.load_default()
# 2. Define the quadrants/target areas (x, y, w, h)
# The positions are based on a 512x512 quadrant size
quadrants = {
1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)], # Single full-size placement
2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)], # Left, Right
3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)], # Top-Left, Top-Right, Bottom-Left
4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)], # All Four
5: ALL_9_CELLS[:5], # 布局前5个单元格 (1-5)
6: ALL_9_CELLS[:6], # 布局前6个单元格 (1-6)
7: ALL_9_CELLS[:7], # 布局前7个单元格 (1-7)
8: ALL_9_CELLS[:8], # 布局前8个单元格 (1-8)
9: ALL_9_CELLS[:9] # 布局全部9个单元格 (1-9)
}
# 3. Load and Filter Images
valid_images = []
image_paths = [item['image_path'] for item in outfit_items]
for path in image_paths:
try:
# We use Image.open() and convert to 'RGB' to handle potential transparency (RGBA)
# and ensure compatibility with the final 'RGB' canvas and JPG output.
img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
# img = Image.open(path).convert('RGB')
valid_images.append(img)
except Exception as e:
logger.error(f"Error loading image {path}. Skipping: {e}")
num_images = len(valid_images)
if num_images == 0:
raise ValueError("No valid images were loaded.")
if num_images > max_len:
raise ValueError(f"Valid item number {num_images} exceed max limit {max_len}")
# Get the correct list of target areas based on the number of valid images
target_areas = quadrants.get(num_images, [])
# 4. Resize and Paste
for i, (img, item) in enumerate(zip(valid_images, outfit_items)):
item_id = item['item_id']
category = item['category']
if i >= len(target_areas):
# This should not happen if num_images <= 4
break
# Target area dimensions (x_start, y_start, width, height)
x_start, y_start, target_w, target_h = target_areas[i]
# Calculate new size while maintaining aspect ratio
original_w, original_h = img.size
# Calculate the ratio needed to fit within the target area
ratio_w = target_w / original_w
ratio_h = target_h / original_h
# Use the *smaller* of the two ratios to ensure the image fits entirely
resize_ratio = min(ratio_w, ratio_h)
# Calculate the new dimensions
new_w = int(original_w * resize_ratio)
new_h = int(original_h * resize_ratio)
# Resize the image. Image.Resampling.LANCZOS provides high-quality scaling.
# Pillow documentation recommends ANTIALIAS or BICUBIC for downscaling,
# but LANCZOS is a good general high-quality filter.
# Note: In Pillow versions > 9.0.0, Image.LANCZOS is now Image.Resampling.LANCZOS
resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
# Calculate the paste position to center the resized image within its target area
# Center X: (Target Width - New Width) / 2 + X Start
paste_x = (target_w - new_w) // 2 + x_start
# Center Y: (Target Height - New Height) / 2 + Y Start
# paste_y = (target_h - new_h) // 2 + y_start
TEXT_RESERVE_HEIGHT = 30
paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start
paste_y = max(paste_y, y_start)
# Paste the resized image onto the canvas
canvas.paste(resized_img, (paste_x, paste_y))
full_text = f"ID: {item_id}, Category: {category}"
try:
# 推荐使用:计算文本的实际尺寸 (width, height)
bbox = draw.textbbox((0, 0), full_text, font=font)
text_w = bbox[2] - bbox[0]
text_h = bbox[3] - bbox[1]
except AttributeError:
# 兼容旧版本 Pillow
text_w, text_h = draw.textsize(full_text, font=font)
# 计算 X 轴起始位置:使其在目标区域 (target_w) 中居中
text_x_center = x_start + target_w // 2
text_x_start = text_x_center - text_w // 2
# 计算 Y 轴起始位置:将其放在目标区域的底部
# (目标区域的起始Y + 目标区域的高度 - 文本行的高度)
text_y_start = y_start + target_h - text_h - 5 # 减去 5 像素作为边距
# 3. 绘制合并后的文本
if add_text:
draw.text((text_x_start, text_y_start),
full_text,
fill='black',
font=font)
# Save as a high-quality JPG (quality=90 is a good balance)
# canvas.save(output_path, 'JPEG', quality=90)
return canvas

View File

@@ -1,18 +1,28 @@
import random
import time
import numpy as np
import torch
import chromadb
from PIL import Image
from typing import List, Dict, Any
from transformers import CLIPProcessor, CLIPModel
from app.taxonomy import CATEGORY, OCCASION
class VectorDatabase():
def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str):
self.client = chromadb.PersistentClient(path=vector_db_dir)
self.collection = self.client.get_or_create_collection(name=collection_name)
self.collection = self.client.get_or_create_collection(
name=collection_name,
configuration={
"hnsw": {
"space": "cosine",
}
}
)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -48,25 +58,87 @@ class VectorDatabase():
return features.cpu().numpy().flatten().tolist()
def query_local_db(self, embedding: List[float], category: str, n_results: int = 3) -> List[Dict[str, Any]]:
def query_local_db(self, embedding: List[float], category: str, occasions: List[str] = [], n_results: int = 3) -> List[Dict[str, Any]]:
"""
基于嵌入向量在本地数据库中查询相似单品。
实际应执行 ChromaDB 查询,并根据 category 进行过滤(metadatas)。
实际应执行 ChromaDB 查询并根据 category 进行过滤(metadatas)。
"""
# 实际应执行向量查询
# 为了演示流程,返回一个模拟结果
for occasion in occasions:
where_clauses = {
"$and": [
{"category": category},
{"modality": "image"},
{"batch_source": '2025_q4'}
]
}
if occasion not in OCCASION:
continue
else:
where_clauses['$and'].append({occasion: 1})
results = self.collection.query(
query_embeddings=[embedding],
n_results=n_results,
where=where_clauses,
include=['metadatas', 'distances']
)
return results
def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_source: str = "2025_q4", gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
results = self.collection.query(
query_embeddings=[embedding],
n_results=n_results,
n_results=500,
where={
"$and": [
{"category": category},
{"modality": "image"},
{"gender": gender},
{"batch_source": batch_source}
]
},
include=['documents', 'metadatas', 'distances']
include=['metadatas', 'distances']
)
return results
if not results['ids'][0]:
return []
metadatas = results['metadatas'][0] # List[Dict[str, Any]]
final_scores = []
for idx, metadata in enumerate(metadatas):
dist_img = results['distances'][0][idx]
score_vec = 1 - dist_img # cosine similarity range: [-1, 1]
score_occ = 0.0
if occasions:
count = 0
for occ in occasions:
if occ not in OCCASION:
continue
count += 1
status_val = metadata.get(occ, -1)
if status_val == 1:
score_occ += 1.0
elif status_val == 0:
score_occ += 0.0
else:
score_occ -= 100.0
score_occ = score_occ / count if count else 0.0
final_score = 0.6 * score_vec + 0.3 * score_occ
final_scores.append(final_score)
scores_arr = np.array(final_scores)
temperature = 0.5
scores_arr = scores_arr / temperature
# Softmax: 将分数转换为概率
exp_scores = np.exp(scores_arr - np.max(scores_arr))
probabilities = exp_scores / np.sum(exp_scores)
# 采样 (或直接取 Top 1)
sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False) # 不重复采样
sampled_items = [metadatas[i] for i in sampled_index]
return sampled_items
def load_filtered_ids(self, filter_item):
# print("\n--- 初始化阶段:加载所有符合条件的 ID ---")

View File

@@ -0,0 +1,326 @@
import logging
import os
from typing import List, Dict
from PIL import Image, ImageDraw, ImageFont
from app.server.utils.minio_client import oss_get_image, minio_client
from app.server.utils.minio_config import MINIO_LC_DATA_PATH
from app.config import settings
logger = logging.getLogger(__name__)
# 9个 341x341 左右的单元格 (ALL_9_CELLS)
# 布局顺序: 从上到下,从左到右 (1 -> 9)
ALL_9_CELLS = [
# Top Row (Y=0, H=341)
(0, 0, 341, 341), # 1. Top-Left (341x341)
(341, 0, 341, 341), # 2. Top-Middle (341x341)
(682, 0, 342, 341), # 3. Top-Right (342x341)
# Middle Row (Y=341, H=341)
(0, 341, 341, 341), # 4. Mid-Left (341x341)
(341, 341, 341, 341), # 5. Center (341x341)
(682, 341, 342, 341), # 6. Mid-Right (342x341)
# Bottom Row (Y=682, H=342)
(0, 682, 341, 342), # 7. Bottom-Left (341x342)
(341, 682, 341, 342), # 8. Bottom-Middle (341x342)
(682, 682, 342, 342) # 9. Bottom-Right (342x342)
]
def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, add_text=True):
"""
Loads up to 4 images from the given paths, resizes them while maintaining
aspect ratio, and merges them onto a 1024x1024 white background JPG.
The layout depends on the number of images:
1: Center the single image on the 1024x1024 canvas.
2: Place side-by-side, each scaled to fit a 512x1024 half.
3: Place in top-left (512x512), top-right (512x512), and bottom-left (512x512).
4: Place in all four 512x512 quadrants.
Args:
outfit_items: A list of item metadata (max length 9).
Returns:
The file path of the temporary merged JPG image.
"""
# Define the final canvas size
CANVAS_SIZE = 1024
# 定义每个 item 的外边距
MARGIN = 5 # 5像素外边距
# 1. Create the final white canvas
# Using 'RGB' mode for JPG output
canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white')
draw = ImageDraw.Draw(canvas)
font = ImageFont.load_default()
# 2. Define the quadrants/target areas (x, y, w, h)
# The positions are based on a 512x512 quadrant size
quadrants = {
1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)], # Single full-size placement
2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)], # Left, Right
3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)], # Top-Left, Top-Right, Bottom-Left
4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)], # All Four
5: ALL_9_CELLS[:5], # 布局前5个单元格 (1-5)
6: ALL_9_CELLS[:6], # 布局前6个单元格 (1-6)
7: ALL_9_CELLS[:7], # 布局前7个单元格 (1-7)
8: ALL_9_CELLS[:8], # 布局前8个单元格 (1-8)
9: ALL_9_CELLS[:9] # 布局全部9个单元格 (1-9)
}
# 3. Load and Filter Images
valid_images = []
image_paths = [item['image_path'] for item in outfit_items]
for path in image_paths:
try:
# We use Image.open() and convert to 'RGB' to handle potential transparency (RGBA)
# and ensure compatibility with the final 'RGB' canvas and JPG output.
if settings.LOCAL == 1:
image_file_path = os.path.join(settings.LOCAL_IMAGE_DIR, path)
img = Image.open(image_file_path).convert('RGB')
else:
img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
# img = Image.open(path).convert('RGB')
valid_images.append(img)
except Exception as e:
logger.error(f"Error loading image {path}. Skipping: {e}")
num_images = len(valid_images)
if num_images == 0:
raise ValueError("No valid images were loaded.")
if num_images > max_len:
raise ValueError(f"Valid item number {num_images} exceed max limit {max_len}")
# Get the correct list of target areas based on the number of valid images
target_areas = quadrants.get(num_images, [])
if not target_areas:
raise ValueError(f"No layout defined for {num_images} images.")
# 4. Resize and Paste
for i, (img, item) in enumerate(zip(valid_images, outfit_items)):
item_id = item['item_id']
category = item['category']
if i >= len(target_areas):
# This should not happen if num_images <= 4
break
# 原始目标区域 (x_start, y_start, width, height)
orig_x_start, orig_y_start, orig_w, orig_h = target_areas[i]
# 📢 应用边距:实际用于图像和文本的区域
# 新的起始点:向内移动 MARGIN
x_start = orig_x_start + MARGIN
y_start = orig_y_start + MARGIN
# 新的宽高:减去两倍的 MARGIN (左右/上下)
target_w = orig_w - 2 * MARGIN
target_h = orig_h - 2 * MARGIN
# --- 图像缩放与居中 ---
# Calculate new size while maintaining aspect ratio
original_w, original_h = img.size
# Calculate the ratio needed to fit within the *带边距的* 目标区域
ratio_w = target_w / original_w
ratio_h = target_h / original_h
# Use the *smaller* of the two ratios to ensure the image fits entirely
resize_ratio = min(ratio_w, ratio_h)
# Calculate the new dimensions
new_w = int(original_w * resize_ratio)
new_h = int(original_h * resize_ratio)
# Resize the image.
resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
# Calculate the paste position to center the resized image within its target area
# Center X: (Target Width - New Width) / 2 + X Start (带边距的 X_start)
paste_x = (target_w - new_w) // 2 + x_start
# 预留文本高度 ( TEXT_RESERVE_HEIGHT )
TEXT_RESERVE_HEIGHT = 30
# Center Y: (Target Height - New Height - 预留文本高度) / 2 + Y Start (带边距的 Y_start)
paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start
# 确保图片顶部不超出目标区域的 Y_start
paste_y = max(paste_y, y_start)
# Paste the resized image onto the canvas
canvas.paste(resized_img, (paste_x, paste_y))
# --- 文本居中与定位 ---
full_text = f"ID: {item_id}, Category: {category}"
if add_text:
try:
# 推荐使用:计算文本的实际尺寸 (width, height)
bbox = draw.textbbox((0, 0), full_text, font=font)
text_w = bbox[2] - bbox[0]
text_h = bbox[3] - bbox[1]
except AttributeError:
# 兼容旧版本 Pillow
text_w, text_h = draw.textsize(full_text, font=font)
# 计算 X 轴起始位置:使其在 *带边距的目标区域* (target_w) 中居中
text_x_center = x_start + target_w // 2
text_x_start = text_x_center - text_w // 2
# 计算 Y 轴起始位置:将其放在 *带边距的目标区域* 的底部
# (带边距的起始Y + 带边距的高度 - 文本行的高度)
# 📢 在带边距的目标区域底部再减去 5 像素作为与底部的边距
text_y_start = y_start + target_h - text_h
draw.text((text_x_start, text_y_start),
full_text,
fill='black',
font=font)
# Save as a high-quality JPG (quality=90 is a good balance)
# canvas.save(output_path, 'JPEG', quality=90)
return canvas
# def merge_images_to_square(outfit_items: List[Dict[str, str]], max_len=9, add_text=True):
# """
# Loads up to 4 images from the given paths, resizes them while maintaining
# aspect ratio, and merges them onto a 1024x1024 white background JPG.
#
# The layout depends on the number of images:
# 1: Center the single image on the 1024x1024 canvas.
# 2: Place side-by-side, each scaled to fit a 512x1024 half.
# 3: Place in top-left (512x512), top-right (512x512), and bottom-left (512x512).
# 4: Place in all four 512x512 quadrants.
#
# Args:
# outfit_items: A list of item metadata (max length 9).
#
# Returns:
# The file path of the temporary merged JPG image.
# """
#
# # Define the final canvas size
# CANVAS_SIZE = 1024
#
# # 1. Create the final white canvas
# # Using 'RGB' mode for JPG output
# canvas = Image.new('RGB', (CANVAS_SIZE, CANVAS_SIZE), 'white')
# draw = ImageDraw.Draw(canvas)
# font = ImageFont.load_default()
#
# # 2. Define the quadrants/target areas (x, y, w, h)
# # The positions are based on a 512x512 quadrant size
# quadrants = {
# 1: [(0, 0, CANVAS_SIZE, CANVAS_SIZE)], # Single full-size placement
# 2: [(0, 0, 512, CANVAS_SIZE), (512, 0, 512, CANVAS_SIZE)], # Left, Right
# 3: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512)], # Top-Left, Top-Right, Bottom-Left
# 4: [(0, 0, 512, 512), (512, 0, 512, 512), (0, 512, 512, 512), (512, 512, 512, 512)], # All Four
# 5: ALL_9_CELLS[:5], # 布局前5个单元格 (1-5)
# 6: ALL_9_CELLS[:6], # 布局前6个单元格 (1-6)
# 7: ALL_9_CELLS[:7], # 布局前7个单元格 (1-7)
# 8: ALL_9_CELLS[:8], # 布局前8个单元格 (1-8)
# 9: ALL_9_CELLS[:9] # 布局全部9个单元格 (1-9)
# }
#
# # 3. Load and Filter Images
# valid_images = []
# image_paths = [item['image_path'] for item in outfit_items]
# for path in image_paths:
# try:
# # We use Image.open() and convert to 'RGB' to handle potential transparency (RGBA)
# # and ensure compatibility with the final 'RGB' canvas and JPG output.
# img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
# # img = Image.open(path).convert('RGB')
# valid_images.append(img)
# except Exception as e:
# logger.error(f"Error loading image {path}. Skipping: {e}")
#
# num_images = len(valid_images)
#
# if num_images == 0:
# raise ValueError("No valid images were loaded.")
#
# if num_images > max_len:
# raise ValueError(f"Valid item number {num_images} exceed max limit {max_len}")
#
# # Get the correct list of target areas based on the number of valid images
# target_areas = quadrants.get(num_images, [])
#
# # 4. Resize and Paste
# for i, (img, item) in enumerate(zip(valid_images, outfit_items)):
# item_id = item['item_id']
# category = item['category']
# if i >= len(target_areas):
# # This should not happen if num_images <= 4
# break
#
# # Target area dimensions (x_start, y_start, width, height)
# x_start, y_start, target_w, target_h = target_areas[i]
#
# # Calculate new size while maintaining aspect ratio
# original_w, original_h = img.size
#
# # Calculate the ratio needed to fit within the target area
# ratio_w = target_w / original_w
# ratio_h = target_h / original_h
#
# # Use the *smaller* of the two ratios to ensure the image fits entirely
# resize_ratio = min(ratio_w, ratio_h)
#
# # Calculate the new dimensions
# new_w = int(original_w * resize_ratio)
# new_h = int(original_h * resize_ratio)
#
# # Resize the image. Image.Resampling.LANCZOS provides high-quality scaling.
# # Pillow documentation recommends ANTIALIAS or BICUBIC for downscaling,
# # but LANCZOS is a good general high-quality filter.
# # Note: In Pillow versions > 9.0.0, Image.LANCZOS is now Image.Resampling.LANCZOS
# resized_img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
#
# # Calculate the paste position to center the resized image within its target area
# # Center X: (Target Width - New Width) / 2 + X Start
# paste_x = (target_w - new_w) // 2 + x_start
# # Center Y: (Target Height - New Height) / 2 + Y Start
# # paste_y = (target_h - new_h) // 2 + y_start
#
# TEXT_RESERVE_HEIGHT = 30
# paste_y = (target_h - new_h - TEXT_RESERVE_HEIGHT) // 2 + y_start
# paste_y = max(paste_y, y_start)
#
# # Paste the resized image onto the canvas
# canvas.paste(resized_img, (paste_x, paste_y))
#
# full_text = f"ID: {item_id}, Category: {category}"
# try:
# # 推荐使用:计算文本的实际尺寸 (width, height)
# bbox = draw.textbbox((0, 0), full_text, font=font)
# text_w = bbox[2] - bbox[0]
# text_h = bbox[3] - bbox[1]
# except AttributeError:
# # 兼容旧版本 Pillow
# text_w, text_h = draw.textsize(full_text, font=font)
#
# # 计算 X 轴起始位置:使其在目标区域 (target_w) 中居中
# text_x_center = x_start + target_w // 2
# text_x_start = text_x_center - text_w // 2
#
# # 计算 Y 轴起始位置:将其放在目标区域的底部
# # (目标区域的起始Y + 目标区域的高度 - 文本行的高度)
# text_y_start = y_start + target_h - text_h - 5 # 减去 5 像素作为边距
#
# # 3. 绘制合并后的文本
# if add_text:
# draw.text((text_x_start, text_y_start),
# full_text,
# fill='black',
# font=font)
#
# # Save as a high-quality JPG (quality=90 is a good balance)
# # canvas.save(output_path, 'JPEG', quality=90)
#
# return canvas