Compare commits
10 Commits
cde7357df8
...
78670b4210
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
78670b4210 | ||
|
|
266b23c97d | ||
|
|
affef482e6 | ||
|
|
38af75077d | ||
|
|
2d38e3fc0b | ||
|
|
e3d4939718 | ||
|
|
5f50f3ec1a | ||
|
|
3b685c34f0 | ||
|
|
36fc937f0c | ||
|
|
9579d498c9 |
@@ -40,8 +40,7 @@ COPY . /app
|
||||
# Install litserve and requirements
|
||||
RUN pip install --upgrade pip setuptools wheel
|
||||
RUN pip install --no-cache-dir litserve==0.2.16 -r requirements.txt
|
||||
RUN pip install torch torchvision
|
||||
|
||||
RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
EXPOSE 8000
|
||||
CMD ["python", "-m","app.main"]
|
||||
#CMD ["tail", "-f","/dev/null"]
|
||||
|
||||
@@ -44,7 +44,6 @@ class LCAgent(ls.LitAPI):
|
||||
'max_len': 9,
|
||||
'gemini_model_name': settings.LLM_MODEL_NAME
|
||||
}
|
||||
self.outfit_ids = []
|
||||
|
||||
async def decode_request(self, request: AgentRequestModel):
|
||||
"""
|
||||
@@ -62,15 +61,17 @@ class LCAgent(ls.LitAPI):
|
||||
|
||||
async def predict(self, request):
|
||||
|
||||
self.outfit_ids = [str(uuid.uuid4()) for _ in range(request.num_outfits)]
|
||||
outfit_ids = [str(uuid.uuid4()) for _ in range(request.num_outfits)]
|
||||
|
||||
asyncio.create_task(self.background_run(request))
|
||||
return {"status": "Task initiated in background.", "outfit_ids": self.outfit_ids}
|
||||
asyncio.create_task(self.background_run(request, outfit_ids))
|
||||
|
||||
logger.info({"status": "Task initiated in background.", "outfit_ids": outfit_ids})
|
||||
return {"status": "Task initiated in background.", "outfit_ids": outfit_ids}
|
||||
|
||||
async def encode_response(self, output):
|
||||
return output
|
||||
|
||||
async def background_run(self, request: AgentRequestModel):
|
||||
async def background_run(self, request: AgentRequestModel, outfit_ids):
|
||||
# 1. 根据用户ID查询对话历史,总结对话内容
|
||||
request_summary = await self.get_conversation_summary(request.session_id)
|
||||
logger.info(f"request_summary: {request_summary}")
|
||||
@@ -83,7 +84,8 @@ class LCAgent(ls.LitAPI):
|
||||
user_id=request.user_id,
|
||||
gender=request.gender,
|
||||
callback_url=request.callback_url,
|
||||
max_len=request.max_len)
|
||||
max_len=request.max_len,
|
||||
outfit_ids=outfit_ids)
|
||||
logger.info("--- Final Recommendation Results ---")
|
||||
for i, path in enumerate(recommendation_results.get("successful_outfits", [])):
|
||||
logger.info(f"✅ Outfit {i + 1} saved to: {path}")
|
||||
@@ -104,7 +106,7 @@ class LCAgent(ls.LitAPI):
|
||||
return summary
|
||||
|
||||
async def recommend_outfit(self, request_summary: str, stylist_name: str, start_outfit=None, num_outfits: int = 1,
|
||||
user_id: str = "test", gender: str = "male", callback_url: str = None, max_len: int = 9):
|
||||
user_id: str = "test", gender: str = "male", callback_url: str = None, max_len: int = 9, outfit_ids=None):
|
||||
"""
|
||||
基于用户的对话历史和需求,推荐一套搭配。
|
||||
|
||||
@@ -112,14 +114,18 @@ class LCAgent(ls.LitAPI):
|
||||
request_summary: 用户的request
|
||||
start_outfit: 可选的初始搭配列表,每个元素包含 'item_id' 和 'category'。
|
||||
"""
|
||||
if outfit_ids is None:
|
||||
outfit_ids = []
|
||||
if start_outfit is None:
|
||||
start_outfit = []
|
||||
tasks = []
|
||||
task_map = {}
|
||||
|
||||
stylist_agent_kwages = self.stylist_agent_kwages.copy()
|
||||
for i in range(num_outfits):
|
||||
self.stylist_agent_kwages['outfit_id'] = self.outfit_ids[i]
|
||||
self.stylist_agent_kwages['max_len'] = max_len
|
||||
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
|
||||
stylist_agent_kwages['outfit_id'] = outfit_ids[i]
|
||||
stylist_agent_kwages['max_len'] = max_len
|
||||
agent = AsyncStylistAgent(**stylist_agent_kwages)
|
||||
task = agent.run_styling_process(
|
||||
request_summary=request_summary,
|
||||
stylist_path=stylist_name,
|
||||
@@ -129,7 +135,7 @@ class LCAgent(ls.LitAPI):
|
||||
gender=gender,
|
||||
)
|
||||
tasks.append(task)
|
||||
task_map[task] = {"outfit_id": self.outfit_ids[i], "retries": 0}
|
||||
task_map[task] = {"outfit_id": outfit_ids[i], "retries": 0}
|
||||
logger.info(f"--- Starting {num_outfits} concurrent outfit generation tasks. ---")
|
||||
|
||||
# 2. 任务执行与重试循环
|
||||
@@ -157,8 +163,8 @@ class LCAgent(ls.LitAPI):
|
||||
logger.info(f"--- Retrying outfit {outfit_id} (Attempt {current_retries + 1}/{retry_limit}). ---")
|
||||
|
||||
# 重新创建任务 (可能需要短暂延迟,例如 time.sleep(1),但在此异步环境中,我们会通过重新创建 agent/task 来实现)
|
||||
self.stylist_agent_kwages['outfit_id'] = outfit_id
|
||||
agent = AsyncStylistAgent(**self.stylist_agent_kwages)
|
||||
stylist_agent_kwages['outfit_id'] = outfit_id
|
||||
agent = AsyncStylistAgent(**stylist_agent_kwages)
|
||||
new_task = agent.run_styling_process(
|
||||
request_summary=request_summary,
|
||||
stylist_path=stylist_name,
|
||||
|
||||
@@ -25,26 +25,26 @@ class PredictRequest(BaseModel):
|
||||
|
||||
class LCChatBot(ls.LitAPI):
|
||||
def setup(self, device):
|
||||
self.llm = AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME)
|
||||
# self.llm = AsyncGeminiLLM(model_name=settings.LLM_MODEL_NAME)
|
||||
self.redis = RedisManager(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
key_prefix=settings.REDIS_HISTORY_KEY_PREFIX
|
||||
)
|
||||
self.vector_db = VectorDatabase(
|
||||
vector_db_dir=settings.VECTOR_DB_DIR,
|
||||
collection_name=settings.COLLECTION_NAME,
|
||||
embedding_model_name=settings.EMBEDDING_MODEL_NAME
|
||||
)
|
||||
self.stylist_agent_kwages = {
|
||||
'local_db': self.vector_db,
|
||||
'max_len': 5,
|
||||
'outfits_root': settings.OUTFIT_OUTPUT_DIR,
|
||||
'image_dir': settings.IMAGE_DIR,
|
||||
'stylist_guide_dir': settings.STYLIST_GUIDE_DIR,
|
||||
'gemini_model_name': settings.LLM_MODEL_NAME
|
||||
}
|
||||
# self.vector_db = VectorDatabase(
|
||||
# vector_db_dir=settings.VECTOR_DB_DIR,
|
||||
# collection_name=settings.COLLECTION_NAME,
|
||||
# embedding_model_name=settings.EMBEDDING_MODEL_NAME
|
||||
# )
|
||||
# self.stylist_agent_kwages = {
|
||||
# 'local_db': self.vector_db,
|
||||
# 'max_len': 5,
|
||||
# 'outfits_root': settings.OUTFIT_OUTPUT_DIR,
|
||||
# 'image_dir': settings.IMAGE_DIR,
|
||||
# 'stylist_guide_dir': settings.STYLIST_GUIDE_DIR,
|
||||
# 'gemini_model_name': settings.LLM_MODEL_NAME
|
||||
# }
|
||||
self.gemini_client = genai.Client(
|
||||
vertexai=True, project='aida-461108', location='us-central1'
|
||||
)
|
||||
|
||||
@@ -19,7 +19,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncStylistAgent:
|
||||
CATEGORY_SET = {'Activewear', 'Watches', 'Shopping Totes', 'Underwear', 'Sunglasses', 'Dresses', 'Outerwear', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Skirts', 'Swimwear', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Pants', 'Suits', 'Shoes', 'Shirts & Tops', 'Scarves & Shawls'}
|
||||
CATEGORY_SET = {
|
||||
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Shoes',
|
||||
# 取消推荐配饰
|
||||
# 'Swimwear', 'Underwear',
|
||||
# , 'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry', 'Briefcases', 'Socks', 'Neckties', 'Scarves & Shawls'
|
||||
}
|
||||
CATEGORY_SET_ALL = {
|
||||
'Activewear', 'Dresses', 'Outerwear', 'Pants', 'Shirts & Tops', 'Skirts', 'Suits', 'Swimwear', 'Underwear',
|
||||
'Watches', 'Shopping Totes', 'Sunglasses', 'Handbags', 'Backpacks', 'Belts', 'Hats', 'Jewelry',
|
||||
'Briefcases', 'Socks', 'Neckties', 'Shoes', 'Scarves & Shawls'
|
||||
}
|
||||
|
||||
def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
|
||||
# self.outfit_items: List[Dict[str, str]] = []
|
||||
@@ -46,7 +56,7 @@ class AsyncStylistAgent:
|
||||
self.gcs_bucket = "lc_stylist_agent_outfit_items"
|
||||
self.minio_bucket = "lanecarford"
|
||||
|
||||
def _load_style_guide(self, path: str) -> str:
|
||||
def _load_style_guide(self, path: str):
|
||||
"""加载 markdown 风格指南内容。"""
|
||||
parts = path.split('/', 1)
|
||||
if len(parts) != 2:
|
||||
@@ -54,18 +64,24 @@ class AsyncStylistAgent:
|
||||
|
||||
bucket_name, object_name = parts
|
||||
try:
|
||||
# 1. 获取对象
|
||||
# 获取对象 读取内容
|
||||
response = minio_client.get_object(bucket_name, object_name)
|
||||
|
||||
# 2. 读取内容
|
||||
content_bytes = response.read()
|
||||
|
||||
# 3. 关闭连接
|
||||
response.close()
|
||||
response.release_conn()
|
||||
json_response = minio_client.get_object(bucket_name, object_name.replace('.md', '.json'))
|
||||
json_data = json_response.data
|
||||
|
||||
# 4. 解码并返回
|
||||
return content_bytes.decode('utf-8')
|
||||
# 关闭连接
|
||||
response.close()
|
||||
json_response.close()
|
||||
response.release_conn()
|
||||
json_response.release_conn()
|
||||
|
||||
# 4. 解析 JSON 字符串
|
||||
json_string = json_data.decode('utf-8')
|
||||
json_content = json.loads(json_string)
|
||||
|
||||
return content_bytes.decode('utf-8'), json_content
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load style guide from {path}: {e}")
|
||||
@@ -99,8 +115,8 @@ class AsyncStylistAgent:
|
||||
|
||||
## Your Workflow and Constraints
|
||||
|
||||
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions, accessory stacking, and shoe/bag coordination**.
|
||||
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses), then shoes and bags, and finally accessories.
|
||||
1. **Style Adherence**: You must strictly observe all rules in the Style Guide concerning **color palette, fit, layering principles, pattern restrictions , shoe coordination**.
|
||||
2. **Step Planning**: The styling sequence must follow a **top-down, inside-out** approach: First major garments (tops/outerwear/bottoms/dresses) then shoes.
|
||||
3. **Structured Output**: Every response must recommend the **next single item**. You must strictly use the **JSON format** for your output, as follows:
|
||||
|
||||
```json
|
||||
@@ -118,7 +134,6 @@ class AsyncStylistAgent:
|
||||
* **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
|
||||
* **Material/Detail** (e.g., 100% cotton, linen, gold clasp, thin stripe, checkered pattern)
|
||||
* **Role in the Outfit** (e.g., serves as the innermost base layer for layering; acts as the crucial tie accent for the smart casual look)
|
||||
* **[CRITICAL FOR JEWELRY] If recommending 'Jewelry' (especially Necklaces), the description must specify its distinction (length, thickness, pendant style) from all previously selected necklaces to ensure layered variety.**
|
||||
|
||||
4. **Termination Condition**: Only when you deem the entire outfit complete and **all mandatory elements stipulated in the Style Guide are met**, you must output the following JSON format to terminate the process:
|
||||
|
||||
@@ -156,7 +171,7 @@ class AsyncStylistAgent:
|
||||
# self._clear_uploaded_files()
|
||||
# 1. 添加图片内容
|
||||
if self.outfit_items:
|
||||
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len, add_text=False)
|
||||
merged_image = merge_images_to_square(self.outfit_items, max_len=self.max_len + 1, add_text=False)
|
||||
image_bytes_io = io.BytesIO()
|
||||
image_format = 'JPEG'
|
||||
mime_type = 'image/jpeg'
|
||||
@@ -206,6 +221,33 @@ class AsyncStylistAgent:
|
||||
# 返回一个停止信号以防止循环继续
|
||||
return json.dumps({"action": "stop", "reason": f"API_ERROR: {str(e)}"})
|
||||
|
||||
async def _merge_images(self, user_id: str):
|
||||
"""
|
||||
实际调用 Gemini API 的函数,接受文本和可选的图片路径列表。
|
||||
|
||||
Args:
|
||||
user_input: 发送给模型的主文本内容。
|
||||
image_paths: 待发送图片的本地路径列表。
|
||||
|
||||
Returns:
|
||||
模型的响应文本(预期为 JSON 字符串)。
|
||||
"""
|
||||
minio_path = ""
|
||||
if self.outfit_items:
|
||||
merged_image = merge_images_to_square(self.outfit_items, max_len=9, add_text=False)
|
||||
image_bytes_io = io.BytesIO()
|
||||
image_format = 'JPEG'
|
||||
|
||||
merged_image.save(image_bytes_io, format=image_format)
|
||||
image_bytes = image_bytes_io.getvalue()
|
||||
|
||||
file_name = uuid.uuid4()
|
||||
blob_name = f"lc_stylist_agent_outfit_items/{user_id}/{file_name}.jpg"
|
||||
responses = oss_upload_image(oss_client=minio_client, bucket=self.minio_bucket, object_name=blob_name, image_bytes=image_bytes)
|
||||
minio_path = f"{responses.bucket_name}/{responses.object_name}"
|
||||
|
||||
return minio_path
|
||||
|
||||
def _parse_gemini_response(self, response_text: str) -> Optional[Dict[str, Any]]:
|
||||
"""安全解析 Gemini 的 JSON 响应。"""
|
||||
try:
|
||||
@@ -252,6 +294,66 @@ class AsyncStylistAgent:
|
||||
print(f"An error occurred during item retrieval: {e}")
|
||||
return None
|
||||
|
||||
async def _get_random_accessories(self, stylist, item_count):
|
||||
stylist_item = []
|
||||
stylist_item_ids = []
|
||||
|
||||
filter_items = [
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"item_group_id": {"$ne": "Shoes"}},
|
||||
{"modality": "image"}
|
||||
]
|
||||
random_items = []
|
||||
|
||||
for i in stylist:
|
||||
# 1. 根据stylist要求抽取item
|
||||
query_embedding = self.local_db.get_clip_embedding(i['text'], is_image=False)
|
||||
stylist_results = self.local_db.query_local_db(query_embedding, i['category'], n_results=10)
|
||||
stylist_item += random.choices(stylist_results['metadatas'][0], k=i['count'])
|
||||
stylist_item_ids += [item_id['item_id'] for item_id in stylist_item]
|
||||
filter_items.append({"category": {"$ne": i["category"]}})
|
||||
|
||||
accessories_count = 9 - item_count - len(stylist_item)
|
||||
|
||||
if accessories_count > 0:
|
||||
if accessories_count > 4:
|
||||
accessories_count = 4
|
||||
for i in range(accessories_count):
|
||||
# 2. 在配饰池中过滤掉已经选中的item ,然后抽两个item
|
||||
random_poll = self.local_db.load_filtered_ids(filter_items)
|
||||
logger.info(f"random_poll 数量: {len(random_poll)}")
|
||||
|
||||
item = self.local_db.random_get_accessories(random.choice(random_poll))
|
||||
if item['metadatas'][0]['category'] in ['Shopping Totes', 'Handbags', 'Backpacks', 'Briefcases']:
|
||||
filter_items.append({"category": {"$ne": "Shopping Totes"}})
|
||||
filter_items.append({"category": {"$ne": "Handbags"}})
|
||||
filter_items.append({"category": {"$ne": "Backpacks"}})
|
||||
filter_items.append({"category": {"$ne": "Briefcases"}})
|
||||
else:
|
||||
filter_items.append({"category": {"$ne": item['metadatas'][0]['category']}})
|
||||
|
||||
random_items.append(item['metadatas'][0])
|
||||
|
||||
all_items = stylist_item + random_items
|
||||
|
||||
else:
|
||||
all_items = stylist_item
|
||||
|
||||
items_data = []
|
||||
|
||||
for best_meta in all_items:
|
||||
items_data.append({
|
||||
"item_id": best_meta['item_id'], # 从 metadata 字典中安全获取
|
||||
"category": best_meta['category'],
|
||||
"gpt_description": best_meta['description'],
|
||||
'description': best_meta['description'],
|
||||
# 假设 'item_path' 存储在 metadata 中,或从 'item_id' 推导
|
||||
# 这里假设 item_id 就是文件名的一部分
|
||||
"image_path": os.path.join(f"{best_meta['item_id']}.jpg")
|
||||
})
|
||||
|
||||
return items_data
|
||||
|
||||
def _build_user_input(self) -> str:
|
||||
"""构建发送给 Gemini 的用户输入,包含已选单品信息。"""
|
||||
if not self.outfit_items:
|
||||
@@ -271,7 +373,7 @@ class AsyncStylistAgent:
|
||||
"""主流程控制循环。"""
|
||||
print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
|
||||
|
||||
self.style_guide = self._load_style_guide(stylist_path)
|
||||
self.style_guide, self.style_accessories_guide = self._load_style_guide(stylist_path)
|
||||
self.system_prompt = self._build_system_prompt(request_summary, gender)
|
||||
response_data = {"status": "",
|
||||
"message": "",
|
||||
@@ -313,7 +415,21 @@ class AsyncStylistAgent:
|
||||
|
||||
# 3. 检查终止条件
|
||||
if gemini_data.get('action') == 'stop':
|
||||
print(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
|
||||
response_data['path'] = minio_path
|
||||
response_data['items'].append({"item_id": item_id, "category": item_category})
|
||||
response_data['status'] = "ok"
|
||||
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||
|
||||
# 根据stylist要求随机增加配饰 3-4个配饰
|
||||
new_item = await self._get_random_accessories(self.style_accessories_guide, len(self.outfit_items))
|
||||
for item in new_item:
|
||||
self.outfit_items.append(item)
|
||||
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
|
||||
|
||||
response_data['path'] = await self._merge_images(user_id)
|
||||
|
||||
logger.info(f"🛑 搭配完成,终止原因: {gemini_data.get('reason')}")
|
||||
self.stop_reason = "Finish reason: " + gemini_data.get('reason', 'No reason provided')
|
||||
response_data['status'] = "stop"
|
||||
response_data['message'] = self.stop_reason
|
||||
@@ -327,7 +443,7 @@ class AsyncStylistAgent:
|
||||
description = gemini_data.get('description')
|
||||
|
||||
# 4a. 检查类别是否有效 (重要步骤)
|
||||
if category not in self.CATEGORY_SET:
|
||||
if category not in self.CATEGORY_SET_ALL:
|
||||
print(f"❌ Agent 推荐了无效类别: {category}。要求 Agent 重新输出。")
|
||||
# 在实际应用中,这里需要将错误信息发回给 Agent,要求它更正
|
||||
# 这里简化为跳过本次循环
|
||||
@@ -382,6 +498,19 @@ class AsyncStylistAgent:
|
||||
break
|
||||
|
||||
if len(self.outfit_items) >= self.max_len: # 设置一个最大循环限制,防止无限循环
|
||||
gemini_response_text, response_data['path'] = await self._call_gemini(user_input, user_id)
|
||||
response_data['items'].append({"item_id": self.outfit_items[-1]['item_id'], "category": self.outfit_items[-1]['category']})
|
||||
response_data['status'] = "ok"
|
||||
response = post_request(url=url, data=json.dumps(response_data), headers=headers)
|
||||
logger.info(f"request data :{response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")
|
||||
|
||||
# 根据stylist要求随机增加配饰 3-4个配饰
|
||||
new_item = await self._get_random_accessories(self.style_accessories_guide, len(self.outfit_items))
|
||||
for item in new_item:
|
||||
self.outfit_items.append(item)
|
||||
response_data['items'].append({"item_id": item.get('item_id'), "category": item.get('category')})
|
||||
response_data['path'] = await self._merge_images(user_id)
|
||||
|
||||
logger.info("🚨 达到最大搭配数量限制,强制终止。")
|
||||
self.stop_reason = "Finish reason: Reached max outfit length."
|
||||
response_data['status'] = "stop"
|
||||
|
||||
@@ -53,4 +53,4 @@ JSON FIELD REQUIREMENTS:
|
||||
- **style (string):** The overall aesthetic description (e.g., "Classic elegance", "Modern minimalist", "Bohemian vibe", "Edgy and contemporary").
|
||||
- **color_preference (string or list):** User's preferred or excluded colors/tones (e.g., "Light colors only", "Avoid deep shades", "['Cream', 'Pale Blue']", "No preference").
|
||||
- **clothing_type (string):** User's preference for specific garment types, material, or silhouette (e.g., "Lightweight maxi dress", "Skirt with silk blouse", "Tailored wide-leg pants", "Floral print").
|
||||
- **vibe_or_details (string):** Any other details, mood requirements, or specific constraints (e.g., "Needs to be comfortable and breathable", "Accent on accessories", "Must cover shoulders")."""
|
||||
- **vibe_or_details (string):** Any other details, mood requirements, or specific constraints (e.g., "Needs to be comfortable and breathable", "Must cover shoulders")."""
|
||||
@@ -1,3 +1,6 @@
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
import chromadb
|
||||
from PIL import Image
|
||||
@@ -15,6 +18,12 @@ class VectorDatabase():
|
||||
|
||||
self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
|
||||
self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
||||
# self.cache_filtered_ids = self.load_filtered_ids([
|
||||
# {"item_group_id": {"$ne": "Clothing"}},
|
||||
# {"item_group_id": {"$ne": "Shoes"}},
|
||||
# {"modality": "image"}
|
||||
# ])
|
||||
# self.total_count = len(self.cache_filtered_ids)
|
||||
|
||||
def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
|
||||
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
||||
@@ -58,3 +67,79 @@ class VectorDatabase():
|
||||
include=['documents', 'metadatas', 'distances']
|
||||
)
|
||||
return results
|
||||
|
||||
def load_filtered_ids(self, filter_item):
|
||||
# print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
||||
start_time = time.time()
|
||||
FILTER_CRITERIA = {
|
||||
"$and": filter_item
|
||||
}
|
||||
MAX_LIMIT = 100000
|
||||
|
||||
try:
|
||||
# 获取所有符合条件的 ID
|
||||
all_ids_results = self.collection.get(
|
||||
where=FILTER_CRITERIA,
|
||||
limit=MAX_LIMIT,
|
||||
include=[]
|
||||
)
|
||||
all_matched_ids = all_ids_results['ids']
|
||||
# print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
||||
print(time.time() - start_time)
|
||||
return all_matched_ids
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
def random_get_accessories(self, ids):
|
||||
# 2. 调用 ChromaDB:只查询这一个 ID 的详细信息
|
||||
try:
|
||||
final_results = self.collection.get(
|
||||
ids=ids,
|
||||
include=["metadatas"] # 你只需要元数据
|
||||
)
|
||||
|
||||
# 提取结果
|
||||
if final_results['ids']:
|
||||
return final_results
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 获取最终记录时发生错误: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
stylist = {
|
||||
'text': "gold necklace",
|
||||
'count': 2,
|
||||
'category': "Jewelry"
|
||||
}
|
||||
|
||||
max_len = 5
|
||||
local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
|
||||
A = local_db.load_filtered_ids([
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"item_group_id": {"$ne": "Shoes"}},
|
||||
{"modality": "image"}
|
||||
])
|
||||
# print(db.random_get_accessories())
|
||||
start_time = time.time()
|
||||
X = local_db.random_get_accessories(['ELI699_img'])
|
||||
print(X)
|
||||
print(time.time() - start_time)
|
||||
# query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
|
||||
#
|
||||
# results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
|
||||
# # 2. 从结果集中抽 stylist['count'] 个item
|
||||
# stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
|
||||
# stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
|
||||
#
|
||||
# # 3. 从随机库中抽取配饰,总数达到9件 ,需过滤掉已经抽中的item
|
||||
# accessories_count = 9 - max_len - stylist['count']
|
||||
#
|
||||
# random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
|
||||
# random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
|
||||
# all_items = stylist_item + random_items
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
@@ -114,12 +115,19 @@ if __name__ == '__main__':
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/7fed1c7b-9efd-41fa-a335-182c310ea611.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/5de155d0-56a6-43e8-a2f1-7538fce86220.jpg"
|
||||
# url = "lanecarford/lc_stylist_agent_outfit_items/string/1cd1803c-5f51-4961-a4f2-2acd3e0d8294.jpg"
|
||||
url = 'lanecarford/lc_stylist_agent_outfit_items/string/99cd8cc0-856a-487d-bb21-5684855ef48f.jpg'
|
||||
url = [
|
||||
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251121/4b595d3b-5d3d-4617-ae09-5fca92d935f7.jpg',
|
||||
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251121/6d0d7540-5b61-45f2-a1fa-5cb1c7a3d0fa.jpg',
|
||||
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251121/a4e51ccb-9b95-4718-8153-92ee0a39d0c8.jpg',
|
||||
'lanecarford/lc_stylist_agent_outfit_items/zhhtest20251121/cbebbcf6-cca2-4460-9f9f-d0b1000dc2cd.jpg'
|
||||
]
|
||||
read_type = "1"
|
||||
img = oss_get_image(oss_client=minio_client, path=url, data_type=read_type)
|
||||
if read_type == "cv2":
|
||||
cv2.imshow("", img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
img.show()
|
||||
img.save("4.png")
|
||||
for id, i in enumerate(url):
|
||||
img = oss_get_image(minio_client, i, read_type)
|
||||
img = oss_get_image(oss_client=minio_client, path=i, data_type=read_type)
|
||||
if read_type == "cv2":
|
||||
cv2.imshow("", img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
img.show()
|
||||
img.save(f"{time.time()}.png")
|
||||
|
||||
90
app/test/chromadb/embedding_query.py
Normal file
90
app/test/chromadb/embedding_query.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import random
|
||||
|
||||
import chromadb
|
||||
from typing import Set, List, Dict, Union, Any
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import no_grad
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from app.server.utils.minio_client import oss_get_image, minio_client
|
||||
from app.server.utils.minio_config import MINIO_LC_DATA_PATH
|
||||
|
||||
# --- 你的配置 ---
|
||||
DB_PATH = "/workspace/lc_stylist_agent/db"
|
||||
COLLECTION_NAME = 'lc_clothing_embedding'
|
||||
# 设置一个足够大的限制来获取所有记录,或者使用分页(如果记录数非常庞大)
|
||||
MAX_LIMIT = 1000000
|
||||
|
||||
client = chromadb.PersistentClient(path=DB_PATH)
|
||||
try:
|
||||
collection = client.get_collection(name=COLLECTION_NAME)
|
||||
print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
|
||||
except ValueError:
|
||||
print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
|
||||
# 如果 collection 不存在,我们将跳过后续操作
|
||||
collection = None
|
||||
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
|
||||
def get_clip_embedding(data: str | Image.Image) -> List[float]:
|
||||
"""生成图像或文本的 CLIP 嵌入,并进行 L2 归一化。"""
|
||||
|
||||
embedding_model_name = "openai/clip-vit-base-patch32"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = CLIPModel.from_pretrained(embedding_model_name).to(device)
|
||||
processor = CLIPProcessor.from_pretrained(embedding_model_name)
|
||||
|
||||
# 强制截断,解决序列长度问题
|
||||
inputs = processor(
|
||||
text=[data],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True
|
||||
).to(device)
|
||||
with no_grad():
|
||||
features = model.get_text_features(**inputs)
|
||||
|
||||
# L2 归一化
|
||||
features = features / features.norm(p=2, dim=-1, keepdim=True)
|
||||
|
||||
return features.cpu().numpy().flatten().tolist()
|
||||
|
||||
|
||||
def query_local_db(embedding: List[float], category: str, n_results: int = 3) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
基于嵌入向量在本地数据库中查询相似单品。
|
||||
实际应执行 ChromaDB 查询,并根据 category 进行过滤(metadatas)。
|
||||
"""
|
||||
# 实际应执行向量查询
|
||||
# 为了演示流程,返回一个模拟结果
|
||||
results = collection.query(
|
||||
query_embeddings=[embedding],
|
||||
n_results=n_results,
|
||||
where={
|
||||
"$and": [
|
||||
{"category": category},
|
||||
{"modality": "image"},
|
||||
]
|
||||
},
|
||||
include=['documents', 'metadatas', 'distances']
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
embedding = get_clip_embedding("watch")
|
||||
print(embedding)
|
||||
result = query_local_db(embedding, "Watches", 20)
|
||||
print(result)
|
||||
ids = result['ids'][0]
|
||||
|
||||
random_single_id = random.choices(ids, k=2)
|
||||
print(random_single_id)
|
||||
|
||||
# for id in ids:
|
||||
# path = id.replace("_img", ".jpg")
|
||||
# img = oss_get_image(oss_client=minio_client, path=f"{MINIO_LC_DATA_PATH}/{path}", data_type="PIL").convert('RGB')
|
||||
# img.save(path)
|
||||
107
app/test/chromadb/random_accessories.py
Normal file
107
app/test/chromadb/random_accessories.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import time
|
||||
|
||||
import chromadb
|
||||
import random
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# --- 你的配置 ---
|
||||
DB_PATH = "/workspace/lc_stylist_agent/db"
|
||||
COLLECTION_NAME = 'lc_clothing_embedding'
|
||||
FILTER_CRITERIA = {
|
||||
"$and": [
|
||||
{"item_group_id": {"$ne": "Clothing"}},
|
||||
{"item_group_id": {"$ne": "Shoes"}}, # 新增:过滤 Shoes
|
||||
{"modality": "image"},
|
||||
]
|
||||
}
|
||||
MAX_LIMIT = 1000000 # 用于第一次获取所有ID的限制
|
||||
|
||||
client = chromadb.PersistentClient(path=DB_PATH)
|
||||
try:
|
||||
collection = client.get_collection(name=COLLECTION_NAME)
|
||||
print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
|
||||
except ValueError:
|
||||
print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
|
||||
exit()
|
||||
|
||||
|
||||
# -----------------------------------------------
|
||||
# 步骤 1: 应用程序启动时/初始化时执行(只执行一次)
|
||||
# -----------------------------------------------
|
||||
def load_filtered_ids(coll: chromadb.api.models.Collection.Collection, filter_criteria: Dict[str, Any]) -> List[str]:
|
||||
"""
|
||||
加载并缓存所有符合条件的记录ID。
|
||||
"""
|
||||
print("\n--- 初始化阶段:加载所有符合条件的 ID ---")
|
||||
|
||||
try:
|
||||
# 获取所有符合条件的 ID
|
||||
all_ids_results = coll.get(
|
||||
where=filter_criteria,
|
||||
limit=MAX_LIMIT,
|
||||
include=[]
|
||||
)
|
||||
all_matched_ids = all_ids_results['ids']
|
||||
print(f"🎉 成功加载 {len(all_matched_ids)} 个 ID 到缓存。")
|
||||
return all_matched_ids
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 初始化失败:获取 ID 列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# 存储所有符合条件的 ID 的全局变量 (缓存)
|
||||
start_time = time.time()
|
||||
CACHED_FILTERED_IDS = load_filtered_ids(collection, FILTER_CRITERIA)
|
||||
print(time.time() - start_time)
|
||||
|
||||
|
||||
# -----------------------------------------------
|
||||
# 步骤 2: 每次需要随机记录时调用 (高效重复执行)
|
||||
# -----------------------------------------------
|
||||
def get_random_record_from_cache(coll: chromadb.api.models.Collection.Collection, cached_ids: List[str]) -> Dict[str, Any] | None:
|
||||
"""
|
||||
从缓存的 ID 列表中随机选择一个 ID,然后查询其详细信息。
|
||||
"""
|
||||
total_count = len(cached_ids)
|
||||
|
||||
if total_count == 0:
|
||||
return None
|
||||
|
||||
# 1. 纯 Python 内存操作:从缓存中随机选择一个 ID
|
||||
random_single_id = random.choice(cached_ids)
|
||||
|
||||
# 2. 调用 ChromaDB:只查询这一个 ID 的详细信息
|
||||
try:
|
||||
final_results = coll.get(
|
||||
ids=[random_single_id],
|
||||
)
|
||||
|
||||
# 提取结果
|
||||
if final_results['ids']:
|
||||
return {
|
||||
"id": final_results['ids'][0],
|
||||
"metadata": final_results['metadatas'][0]
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 获取最终记录时发生错误: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# --- 执行并打印结果 (可以多次调用,每次都很快) ---
|
||||
print("\n--- 随机获取 1 ---")
|
||||
start_time = time.time()
|
||||
random_data_1 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS)
|
||||
print(time.time() - start_time)
|
||||
if random_data_1:
|
||||
print(f" ID: {random_data_1}")
|
||||
|
||||
print("\n--- 随机获取 2 ---")
|
||||
start_time = time.time()
|
||||
random_data_2 = get_random_record_from_cache(collection, CACHED_FILTERED_IDS)
|
||||
print(time.time() - start_time)
|
||||
if random_data_2:
|
||||
print(f" ID: {random_data_2}")
|
||||
101
app/test/chromadb/type_list.py
Normal file
101
app/test/chromadb/type_list.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import chromadb
|
||||
from typing import Set, List, Dict, Union, Any
|
||||
|
||||
# --- 你的配置 ---
|
||||
DB_PATH = "/workspace/lc_stylist_agent/db"
|
||||
COLLECTION_NAME = 'lc_clothing_embedding'
|
||||
# 设置一个足够大的限制来获取所有记录,或者使用分页(如果记录数非常庞大)
|
||||
MAX_LIMIT = 1000000
|
||||
|
||||
client = chromadb.PersistentClient(path=DB_PATH)
|
||||
try:
|
||||
collection = client.get_collection(name=COLLECTION_NAME)
|
||||
print(f"✅ 连接到 Collection: {COLLECTION_NAME}")
|
||||
except ValueError:
|
||||
print(f"⚠️ Collection '{COLLECTION_NAME}' 不存在。")
|
||||
# 如果 collection 不存在,我们将跳过后续操作
|
||||
collection = None
|
||||
|
||||
|
||||
def get_category_item_group_map(
|
||||
coll: chromadb.api.models.Collection.Collection
|
||||
) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
获取 Collection 中所有记录的 'category' 和 'item_group_id' 字段,
|
||||
并返回一个 Category 到其所有唯一 Item Group ID 集合的映射。
|
||||
"""
|
||||
print("\n--- 正在获取所有记录的元数据 (category 和 item_group_id)... ---")
|
||||
|
||||
# 1. 获取所有记录的元数据
|
||||
try:
|
||||
# 使用 .get() 方法获取所有 metadatas,不包含 embeddings 和 documents
|
||||
results = coll.get(
|
||||
limit=MAX_LIMIT,
|
||||
include=["metadatas"]
|
||||
)
|
||||
|
||||
all_metadatas: List[Dict[str, Any]] = results.get('metadatas', [])
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 获取元数据时发生错误: {e}")
|
||||
return {}
|
||||
|
||||
if not all_metadatas:
|
||||
print("❌ 集合中没有元数据记录。")
|
||||
return {}
|
||||
|
||||
# 2. 构建 Category 到 Item Group ID 集合的映射
|
||||
# 结构: { 'CategoryA': {'group_id_1', 'group_id_2'}, 'CategoryB': {'group_id_3'} }
|
||||
category_item_group_map: Dict[str, Set[str]] = {}
|
||||
|
||||
for metadata in all_metadatas:
|
||||
category_value: Union[str, None] = metadata.get('category')
|
||||
item_group_id_value: Union[str, None] = metadata.get('item_group_id')
|
||||
|
||||
# 确保两个元数据字段都存在
|
||||
if category_value is not None and item_group_id_value is not None:
|
||||
category = str(category_value)
|
||||
item_group_id = str(item_group_id_value)
|
||||
|
||||
# 使用 setdefault 来确保 category 键存在,并初始化为一个空的 set
|
||||
# 然后将 item_group_id 添加到对应的 set 中
|
||||
category_item_group_map.setdefault(category, set()).add(item_group_id)
|
||||
|
||||
return category_item_group_map
|
||||
|
||||
|
||||
# --- 执行并打印结果 ---
|
||||
if collection:
|
||||
category_item_group_mapping = get_category_item_group_map(collection)
|
||||
|
||||
if category_item_group_mapping:
|
||||
|
||||
# 统计总共发现多少种唯一 Category
|
||||
print(f"\n🎉 发现 {len(category_item_group_mapping)} 种唯一 Category 及其对应的 Item Group IDs:")
|
||||
|
||||
# 对类别名称进行排序,便于查看
|
||||
sorted_categories = sorted(category_item_group_mapping.keys())
|
||||
|
||||
# 打印详细结果
|
||||
for category in sorted_categories:
|
||||
item_group_ids = category_item_group_mapping[category]
|
||||
|
||||
# 打印类别名称和唯一的 Item Group ID 数量
|
||||
print(f"\n--- 👕 Category: **{category}** ---")
|
||||
print(f"**Item Group 总数:** {len(item_group_ids)} 个唯一 Item Group ID")
|
||||
|
||||
# 将 Item Group ID 转换为列表并排序,以便展示
|
||||
sorted_item_group_ids = sorted(list(item_group_ids))
|
||||
|
||||
# 打印前 10 个 Item Group ID 作为示例
|
||||
print(f"**部分 Item Group IDs (示例):**")
|
||||
|
||||
# 使用列表展示
|
||||
for i, item_group_id in enumerate(sorted_item_group_ids[:10]):
|
||||
print(f"* {item_group_id}")
|
||||
|
||||
if len(sorted_item_group_ids) > 10:
|
||||
print(f"* ... (还有 {len(sorted_item_group_ids) - 10} 个 Item Group ID 未列出)")
|
||||
|
||||
else:
|
||||
print("没有找到任何 Category 或 Item Group ID 数据。请检查元数据字段名称是否正确。")
|
||||
@@ -13,4 +13,12 @@ services:
|
||||
- ./db:/db
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
ports:
|
||||
- "10070:8000"
|
||||
- "10070:8000"
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
# 告诉 Docker 使用所有可用的 NVIDIA GPU
|
||||
- driver: nvidia
|
||||
device_ids: ['0']
|
||||
capabilities: [ gpu ]
|
||||
@@ -16,4 +16,5 @@ pytorch-fid==0.3.0
|
||||
open-clip-torch==2.24.0
|
||||
pytorch-fid==0.3.0
|
||||
litserve
|
||||
# pip install git+https://github.com/openai/CLIP.git
|
||||
# pip install git+https://github.com/openai/CLIP.git
|
||||
# pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
Reference in New Issue
Block a user