Enable automatic data processing for new data

This commit is contained in:
pangkaicheng
2025-12-10 17:27:56 +08:00
parent 0b1d948f77
commit 0e9546aa1a
12 changed files with 936 additions and 171 deletions

View File

@@ -51,7 +51,7 @@ class AgentRequestModel(BaseModel):
     session_id: str
     num_outfits: int
     stylist_path: str
-    batch_source: str
+    batch_sources: List[str]
     callback_url: str
     gender: str
     max_len: int = 9
@@ -112,8 +112,8 @@ class LCAgent(ls.LitAPI):
             request_summary=request_summary,
             occasions=occasions,
             stylist_name=request.stylist_path,
-            batch_source=request.batch_source,
             start_outfit=[],
+            batch_sources=request.batch_sources,
             num_outfits=request.num_outfits,
             user_id=request.user_id,
             gender=request.gender,
@@ -162,7 +162,7 @@ class LCAgent(ls.LitAPI):
         return str(parsed_result.summary), [occ.value for occ in parsed_result.occasions]

     async def recommend_outfit(
-        self, request_summary: str, occasions: List[str], batch_source: str, stylist_name: str, start_outfit=[],
+        self, request_summary: str, occasions: List[str], stylist_name: str, start_outfit: List = [], batch_sources: List[str] = [],
         num_outfits: int = 1, user_id: str = "test", gender: str = "male",
         callback_url: str = None, max_len: int = 9, outfit_ids=None
     ):
@@ -186,9 +186,9 @@ class LCAgent(ls.LitAPI):
         task = agent.run_styling_process(
             request_summary=request_summary,
             occasions=occasions,
-            batch_source=batch_source,
             stylist_name=stylist_name,
             start_outfit=start_outfit,
+            batch_sources=batch_sources,
             user_id=user_id,
             callback_url=callback_url,
             gender=gender,
@@ -227,9 +227,9 @@ class LCAgent(ls.LitAPI):
         new_task = agent.run_styling_process(
             request_summary=request_summary,
             occasions=occasions,
-            batch_source=batch_source,
             stylist_name=stylist_name,
             start_outfit=start_outfit,
+            batch_sources=batch_sources,
             user_id=user_id,
             callback_url=callback_url
         )
@@ -295,9 +295,9 @@ if __name__ == "__main__":
         task = agent.run_styling_process(
             request_summary=request_summary,
             occasions=occasions,
-            batch_source="2025_q4",
             stylist_name=stylist_name,
             start_outfit=[],
+            batch_sources=["2025_q4"],
             user_id=test_content['test_case_id'],
             callback_url="http://mock-callback.com/result",
             gender="female",

View File

@@ -16,11 +16,16 @@ from app.server.utils.img_operation import merge_images_to_square
 from app.server.utils.minio_client import minio_client, oss_upload_image
 from app.server.utils.request_post import post_request
 from app.config import settings
-from app.taxonomy import CLOTHING_CATEGORY, ACCESSORY_CATEGORY

 logger = logging.getLogger(__name__)
+from app.taxonomy import CATEGORY, ALL_CATEGORY, IGNORE_CATEGORY
+
+IGNORE_CATEGORY = set(IGNORE_CATEGORY)
+CLOTHING_CATEGORY = set(CATEGORY['clothing'] + CATEGORY['shoes'] + CATEGORY['bags']) - IGNORE_CATEGORY
+ACCESSORY_CATEGORY = set(CATEGORY['accessories']) - IGNORE_CATEGORY

 class AsyncStylistAgent:
     def __init__(self, local_db, max_len: int, gemini_model_name: str, outfit_id=str):
         # self.outfit_items: List[Dict[str, str]] = []
@@ -145,7 +150,7 @@ class AsyncStylistAgent:
     ```
     * `action`: Must always be `"recommend_item"` until the outfit is complete.
-    * `category`: Must be an unused category from the following list: {CLOTHING_CATEGORY} (strictly no repeats, per the Category Uniqueness Mandate).
+    * `category`: Must be an unused category from the following list: {list(CLOTHING_CATEGORY)} (strictly no repeats, per the Category Uniqueness Mandate).
     * `description`: This must be an **extremely detailed and precise** description of the item. This description is used for **high-accuracy vector search** in the database and must include:
         * **Color** (e.g., milk tea, pure white, dark gray)
         * **Fit/Silhouette** (e.g., Oversize, loose, slim-fit)
@@ -193,7 +198,7 @@ class AsyncStylistAgent:
     ---
     ## STRICT RULES
     1. **Batch Recommendation**: Do NOT recommend items one by one. You must output the **COMPLETE LIST** of accessories (e.g., jewelry, bag, watch, hat) in a single response using the 'recommended_accessories' list.
-    2. **Allowed Categories**: Select only from: {ACCESSORY_CATEGORY}.
+    2. **Allowed Categories**: Select only from: {list(ACCESSORY_CATEGORY)}.
     3. **Harmony & Constraints**:
         - The accessories must complement the [Current Outfit Base].
         - Strictly follow the [Accessories Style Guide] regarding metals (gold/silver), numbers, and prohibited items.
@@ -295,40 +300,46 @@ class AsyncStylistAgent:
             print(f"Raw response: {response_text}")
             return None

-    def _get_next_item(self, item_description: str, category: str, occasions: List[str], batch_source: str = "2025_q4", gender: str = "female") -> Optional[Dict[str, str]]:
+    def _get_next_item(self, item_description: str, category: str, occasions: List[str], batch_sources: List[str] = [], gender: str = "female") -> Optional[Dict[str, str]]:
         """
         1. Generate an embedding from the description.
         2. Query the local database for the best match.
         3. Simulate agent review of the match (simplified here to always pass).
         """
+        # 1. Generate the query embedding
+        query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)
+        # 2. Run the query, filtering by category
         try:
-            # 1. Generate the query embedding
-            query_embedding = self.local_db.get_clip_embedding(item_description, is_image=False)
-
-            # 2. Run the query, filtering by category
-            results = self.local_db.get_matched_item(query_embedding, category, occasions=occasions, batch_source=batch_source, gender=gender, n_results=1)
-            if not results:
-                print(f"❌ No item in the database matches '{category}' and the description.")
-                return None
-            # 3. Simulate agent review (in a real application, the image should be sent back to the agent for review)
-            best_meta = results[0]  # first metadata of the first batch
-            item_id = best_meta['item_id'].replace("_img", "")
-            return {
-                "item_id": item_id,  # fetched safely from the metadata dict
-                "category": best_meta['category'],
-                "gpt_description": item_description,
-                'description': best_meta['description'],
-                # assume 'item_path' is stored in the metadata, or derived from 'item_id'
-                # here item_id is assumed to be part of the file name
-                "image_path": os.path.join(f"{item_id}.jpg")
-            }
-        except Exception as e:
-            print(f"An error occurred during item retrieval: {e}")
+            results = self.local_db.get_matched_item(
+                query_embedding,
+                category,
+                occasions=occasions,
+                batch_sources=batch_sources,
+                gender=gender,
+                n_results=1
+            )
+        except ValueError as e:
+            print(f"Invalid argument detected: {e}")
+            results = []
+        if not results:
+            print(f"No item in the database matches '{category}' and the description.")
             return None
+        # 3. Simulate agent review (in a real application, the image should be sent back to the agent for review)
+        best_meta = results[0]  # first metadata of the first batch
+        item_id = best_meta['item_id'].replace("_img", "")
+        return {
+            "item_id": item_id,  # fetched safely from the metadata dict
+            "category": best_meta['category'],
+            "gpt_description": item_description,
+            'description': best_meta['description'],
+            # assume 'item_path' is stored in the metadata, or derived from 'item_id'
+            # here item_id is assumed to be part of the file name
+            "image_path": os.path.join(f"{item_id}.jpg")
+        }

     def _build_user_input(self, recommend_acc=False) -> str:
         """Build the user input sent to Gemini, including the items already selected."""
         if not self.outfit_items:
@@ -353,7 +364,7 @@ class AsyncStylistAgent:
         response = post_request(url=callback_url, data=json.dumps(response_data), headers=self.headers)
         logger.info(f"request data {response_data} | JAVA callback info -> status:{response.status_code} | message:{response.text}")

-    async def run_styling_process(self, request_summary, occasions, stylist_name, batch_source="2025_q4", start_outfit=[], user_id="test", callback_url="", gender: str = "male"):
+    async def run_styling_process(self, request_summary, occasions, stylist_name, start_outfit=[], batch_sources=[], user_id="test", callback_url="", gender: str = "male"):
         self.outfit_items = start_outfit
         """Main process-control loop."""
         print(f"--- Starting Agent (Outfit ID: {self.outfit_id}) ---")
@@ -412,7 +423,7 @@ class AsyncStylistAgent:
                 continue

             # 4b. Look up the item in the local DB
-            new_item = self._get_next_item(description, category, occasions, batch_source, gender)
+            new_item = self._get_next_item(description, category, occasions, batch_sources, gender)
             if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
                 self.post_operation(
                     response_data,
@@ -460,7 +471,7 @@ class AsyncStylistAgent:
                 continue

             # 4b. Look up the item in the local DB
-            new_item = self._get_next_item(description, category, occasions, batch_source, gender)
+            new_item = self._get_next_item(description, category, occasions, batch_sources, gender)
             if not new_item or new_item['item_id'] in [x['item_id'] for x in self.outfit_items]:
                 continue
             else:
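
A side note on the new signatures: `start_outfit: List = []` and `batch_sources: List[str] = []` are mutable default arguments, which Python evaluates once and shares across every call that omits them; since `run_styling_process` assigns `self.outfit_items = start_outfit`, any append to `self.outfit_items` could accumulate items across requests. A hedged sketch of the safer sentinel pattern (a hardening suggestion, not part of this commit):

```python
from typing import List, Optional

# Sketch only: a None sentinel avoids sharing one list object across calls.
def get_next_item(description: str, batch_sources: Optional[List[str]] = None):
    batch_sources = batch_sources if batch_sources is not None else []
    ...
```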

View File

@@ -8,13 +8,12 @@ from PIL import Image
 from typing import List, Dict, Any
 from transformers import CLIPProcessor, CLIPModel
-from app.taxonomy import CATEGORY, OCCASION
+from app.taxonomy import OCCASION, ALL_CATEGORY

 class VectorDatabase():
     def __init__(self, vector_db_dir: str, collection_name: str, embedding_model_name: str):
         self.client = chromadb.PersistentClient(path=vector_db_dir)
         self.collection = self.client.get_or_create_collection(
             name=collection_name,
             configuration={
@@ -23,17 +22,9 @@ class VectorDatabase():
                 }
             }
         )
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = CLIPModel.from_pretrained(embedding_model_name).to(self.device)
         self.processor = CLIPProcessor.from_pretrained(embedding_model_name)
-        # self.cache_filtered_ids = self.load_filtered_ids([
-        #     {"item_group_id": {"$ne": "Clothing"}},
-        #     {"item_group_id": {"$ne": "Shoes"}},
-        #     {"modality": "image"}
-        # ])
-        # self.total_count = len(self.cache_filtered_ids)

     def get_clip_embedding(self, data: str | Image.Image, is_image: bool) -> List[float]:
         """Generate a CLIP embedding for an image or text, with L2 normalization."""
@@ -57,46 +48,32 @@ class VectorDatabase():
         features = features / features.norm(p=2, dim=-1, keepdim=True)
         return features.cpu().numpy().flatten().tolist()

-    def query_local_db(self, embedding: List[float], category: str, occasions: List[str] = [], n_results: int = 3) -> List[Dict[str, Any]]:
-        """
-        Query the local database for similar items based on the embedding vector.
-        Runs a ChromaDB query and filters by category via metadatas.
-        """
-        for occasion in occasions:
-            where_clauses = {
-                "$and": [
-                    {"category": category},
-                    {"modality": "image"},
-                    {"batch_source": '2025_q4'}
-                ]
-            }
-            if occasion not in OCCASION:
-                continue
-            else:
-                where_clauses['$and'].append({occasion: 1})
-        results = self.collection.query(
-            query_embeddings=[embedding],
-            n_results=n_results,
-            where=where_clauses,
-            include=['metadatas', 'distances']
-        )
-        return results
-
-    def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_source: str = "2025_q4", gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
+    def get_matched_item(self, embedding: List[float], category: str, occasions: List[str] = [], batch_sources: List[str] = [], gender: str = 'female', n_results: int = 1) -> List[Dict[str, Any]]:
+        if category not in ALL_CATEGORY:
+            raise ValueError(f"Recommended {category} is not valid.")
+        and_conditions = [
+            {"category": category},
+            {"modality": "image"},
+            {"$or": [
+                {"gender": gender},
+                {"gender": "unisex"},
+            ]}
+        ]
+        if batch_sources and len(batch_sources) > 0:
+            source_conditions = []
+            for source in batch_sources:
+                source_conditions.append({"batch_source": source})
+            # Add the batch-source OR clause to the main AND conditions
+            and_conditions.append({"$or": source_conditions})
         results = self.collection.query(
             query_embeddings=[embedding],
             n_results=500,
-            where={
-                "$and": [
-                    {"category": category},
-                    {"modality": "image"},
-                    {"gender": gender},
-                    {"batch_source": batch_source}
-                ]
-            },
-            include=['metadatas', 'distances']
+            where={"$and": and_conditions},
+            include=['metadatas', 'distances'],
         )
         if not results['ids'][0]:
             return []
@@ -124,7 +101,7 @@ class VectorDatabase():
         score_occ = score_occ / count if count else 0.0

-        final_score = 0.6 * score_vec + 0.3 * score_occ
+        final_score = 0.6 * score_vec + 0.4 * score_occ
         final_scores.append(final_score)

     scores_arr = np.array(final_scores)
@@ -139,79 +116,3 @@ class VectorDatabase():
         sampled_index = np.random.choice(a=len(results['ids'][0]), p=probabilities, size=n_results, replace=False)  # sample without replacement
         sampled_items = [metadatas[i] for i in sampled_index]
         return sampled_items
-
-    def load_filtered_ids(self, filter_item):
-        # print("\n--- Initialization: load all matching IDs ---")
-        start_time = time.time()
-        FILTER_CRITERIA = {
-            "$and": filter_item
-        }
-        MAX_LIMIT = 100000
-        try:
-            # Fetch all matching IDs
-            all_ids_results = self.collection.get(
-                where=FILTER_CRITERIA,
-                limit=MAX_LIMIT,
-                include=[]
-            )
-            all_matched_ids = all_ids_results['ids']
-            # print(f"🎉 Loaded {len(all_matched_ids)} IDs into the cache.")
-            print(time.time() - start_time)
-            return all_matched_ids
-        except Exception as e:
-            print(f"❌ Initialization failed: error while fetching the ID list: {e}")
-            return []
-
-    def random_get_accessories(self, ids):
-        # Query ChromaDB for the details of just these IDs
-        try:
-            final_results = self.collection.get(
-                ids=ids,
-                include=["metadatas"]  # only the metadata is needed
-            )
-            # Extract the results
-            if final_results['ids']:
-                return final_results
-            else:
-                return None
-        except Exception as e:
-            print(f"❌ Error while fetching the final records: {e}")
-            return None
-
-if __name__ == '__main__':
-    stylist = {
-        'text': "gold necklace",
-        'count': 2,
-        'category': "Jewelry"
-    }
-    max_len = 5
-    local_db = VectorDatabase(vector_db_dir="/workspace/lc_stylist_agent/db", collection_name="lc_clothing_embedding", embedding_model_name="openai/clip-vit-base-patch32")
-    A = local_db.load_filtered_ids([
-        {"item_group_id": {"$ne": "Clothing"}},
-        {"item_group_id": {"$ne": "Shoes"}},
-        {"modality": "image"}
-    ])
-    # print(db.random_get_accessories())
-    start_time = time.time()
-    X = local_db.random_get_accessories(['ELI699_img'])
-    print(X)
-    print(time.time() - start_time)
-    # query_embedding = local_db.get_clip_embedding(stylist['text'], is_image=False)
-    #
-    # results = local_db.query_local_db(query_embedding, stylist['category'], n_results=10)
-    # # 2. Sample stylist['count'] items from the result set
-    # stylist_item = random.choices(results['metadatas'][0], k=stylist['count'])
-    # stylist_item_ids = [item_id['item_id'] for item_id in stylist_item]
-    #
-    # # 3. Draw random accessories until there are 9 items in total, filtering out items already drawn
-    # accessories_count = 9 - max_len - stylist['count']
-    #
-    # random_single_ids = random.choices(list(set(local_db.cache_filtered_ids) - set([f"{i}_img" for i in stylist_item_ids])), k=accessories_count)
-    # random_items = local_db.random_get_accessories(random_single_ids)['metadatas']
-    # all_items = stylist_item + random_items
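
The reworked `get_matched_item` composes its ChromaDB `where` filter from nested `$and`/`$or` operators; when `batch_sources` is empty, the batch clause is omitted and all batches are searched. For two batches the filter comes out as below (values illustrative; note that some ChromaDB versions reject `$and`/`$or` lists with fewer than two clauses, so a single-source list may deserve special-casing):

```python
# Filter produced for category="sneakers", gender="female",
# batch_sources=["2025_q4", "2026_q1"] (illustrative values).
where = {
    "$and": [
        {"category": "sneakers"},
        {"modality": "image"},
        {"$or": [{"gender": "female"}, {"gender": "unisex"}]},
        {"$or": [{"batch_source": "2025_q4"}, {"batch_source": "2026_q1"}]},
    ]
}
```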

View File

@@ -1,15 +1,4 @@
 # This file stores all categories and occasions; it is the canonical reference.
-CATEGORY = [
-    'shoes', 'bags', 'dresses', 'tops', 'pants', 'skirts', 'outerwear', 'swimwear', 'suits',
-    'watches', 'sunglasses', 'belts', 'hats', 'jewelry', 'neckties', 'scarves & shawls'
-]
-CLOTHING_CATEGORY = [
-    'shoes', 'bags', 'dresses', 'tops', 'pants', 'skirts', 'outerwear', 'swimwear'
-]
-ACCESSORY_CATEGORY = [
-    'watches', 'sunglasses', 'belts', 'hats', 'jewelry', 'neckties', 'scarves & shawls'
-]
 OCCASION = [
     "Casual", "Formal", "Activewear", "Resort", "Evening", "Outdoor",
     "Business / workwear", "Cocktail / Semi-Formal", "Black Tie / White Tie",
@@ -17,3 +6,54 @@ OCCASION = [
     "Travel / Transit", "Athleisure", "Beach / Swim", "Ski / Snow / Mountain",
     "Garden Party / Daytime Event"
 ]
+CATEGORY = {
+    'clothing': [
+        'coats',
+        'jackets',
+        'blazers',
+        'puffer',
+        'cardigan',
+        'sweater',
+        'shirts',
+        't-shirts',
+        'pullover',
+        'polos',
+        'bodysuits',
+        'dresses',
+        'skirts',
+        'jeans',
+        'shorts',
+        'leggings',
+        'jumpsuits',
+        'swimwear',
+    ],
+    'shoes': [
+        'sneakers',
+        'formal shoes',
+        'heels',
+        'flats',
+        'sandals',
+        'slides',
+        'boots',
+    ],
+    'bags': [
+        'bags'
+    ],
+    'accessories': [
+        'necklaces',
+        'bracelets',
+        'jewellery',
+        'eyewear',
+        'scarves',
+        'hats',
+        'gloves',
+        'belts',
+        'socks',
+        'watches',
+        'ties',
+    ]
+}
+ALL_CATEGORY = sum(CATEGORY.values(), [])
+IGNORE_CATEGORY = ['socks']
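
One consequence worth noting: `IGNORE_CATEGORY` is subtracted only from the derived agent-side sets, while `ALL_CATEGORY` (used for validation in `get_matched_item`) still contains the ignored entries. A minimal sketch with an abbreviated dict (mirrors the definitions above and in the agent module):

```python
from itertools import chain

CATEGORY = {'clothing': ['coats'], 'shoes': ['boots'], 'bags': ['bags'],
            'accessories': ['socks', 'watches']}  # abbreviated for illustration
IGNORE_CATEGORY = {'socks'}

ALL_CATEGORY = list(chain.from_iterable(CATEGORY.values()))
ACCESSORY_CATEGORY = set(CATEGORY['accessories']) - IGNORE_CATEGORY
assert 'socks' in ALL_CATEGORY             # still valid at the DB layer
assert 'socks' not in ACCESSORY_CATEGORY   # never recommended by the agent
```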

92
data_ingestion/README.md Normal file
View File

@@ -0,0 +1,92 @@
## Steps
1. Prepare `products-all.json` and the `image_data` folder (downloaded via the JavaScript downloader) and save them in a new `./data/{BATCH_SOURCE}` folder. Assign a new `batch_source` id to each new batch of incoming data (a pre-flight sketch follows this list).
2. Run `process_item.py` to extract the category, gender, and occasions of each item. Output goes to `./data/{BATCH_SOURCE}/metadata_extraction.json`. This step should be run on an H200 device.
3. Organize all the data and embed it into the local db with `run_ingestion.py`.
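
A quick pre-flight check of the expected layout before running the pipeline (hedged sketch; paths follow the steps above):

```python
from pathlib import Path

BATCH_SOURCE = "2025_q4"  # illustrative batch id
root = Path("./data") / BATCH_SOURCE
assert (root / "products-all.json").is_file(), "raw product dump missing"
assert (root / "image_data").is_dir(), "image folder missing"
```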
## Raw Data Structure
`products-all.json`:

```json
{
"id": "BUL808",
"name": "SARAH ZHUANG - 'Click & Link' diamond 18k gold earrings",
"brand": "SARAH ZHUANG",
"category": "Fine Jewellery And Watches",
"subcategory": "General",
"price": 17500,
"currency": "HKD",
"description": "Sarah Zhuang's Click & Link earrings embrace the allure of geometry. Forged into elegant rectangles with one side encrusted with diamonds, this gold pair will certainly elevate your cocktail ensembles.",
"tags": [
"sarah zhuang",
"fine jewellery and watches",
"in-stock",
"new",
"sarah",
"zhuang",
"'click",
"link'",
"diamond"
],
"imageUrl": "https://media.lanecrawford.com/B/U/L/BUL808_in_xl.jpg",
"url": "https://www.lanecrawford.com.hk/product/sarah-zhuang/-click-link-diamond-18k-gold-earrings/_/BUL808/product.lc?utm_medium=embed&utm_source=ai-recommended&utm_campaign=2025-christmas_lc_ai-recommended",
"color": "YELLOW GOLD",
"groupName": "Fine Jewellery",
"deptName": "Women's Fine Jewellery",
"onlineBU": "Fine Jewellery",
"stockAvailability": true
}
```
## Example in `metadata_extraction.json`
```json
"EOJ367": {
"category": "shoes",
"gender": "female",
"applicable_occasions": [
"Casual",
"Outdoor",
"Travel / Transit"
],
"inappropriate_occasions": [
"Formal",
"Black Tie / White Tie",
"Bridal / Wedding",
"Business / workwear",
"Cocktail / Semi-Formal"
]
}
```
## Metadata in Vector Database
```python
{
'item_id': 'EOJ128',
'category': 'sunglasses',
'gender': 'unisex',
'modality': 'image',
'brand': 'CELINE',
'color': 'BROWN',
'description': "Immerse yourself in the depth of classic style with CELINE\'s Tortoiseshell Logo Sunglasses. Featuring a rich, tortoiseshell acetate frame and adorned with the iconic CELINE logo in gold, these sunglasses are a testament to timeless elegance and luxury. Perfect for those who appreciate a sophisticated aesthetic, they offer optimal UV protection while ensuring you remain at the forefront of fashion.",
'tags': 'celine,accessories,in-stock,new,maxi,triomphe,acetate,round',
'price': 4500,
'url': 'https://www.lanecrawford.com.hk/product/celine/maxi-triomphe-acetate-round-sunglasses/_/EOJ128/product.lc?utm_medium=embed&utm_source=ai-recommended&utm_campaign=2025-christmas_lc_ai-recommended',
'batch_source': '2025_q4',
'Outdoor': 0,
'Ski / Snow / Mountain': 0,
'Festival / Concert': 0,
'Activewear': 0,
'Casual': 1,
'Cocktail / Semi-Formal': -1,
'Formal': -1,
'Party / Clubbing': 0,
'Evening': 0,
'Travel / Transit': 0,
'Beach / Swim': 0,
'Garden Party / Daytime Event': 1,
'Black Tie / White Tie': -1,
'Resort': 1,
'Athleisure': 0,
'Business / workwear': -1,
'Bridal / Wedding': -1,
}
```
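
The per-occasion integers (1 = applicable, 0 = neutral, -1 = inappropriate) feed retrieval scoring in `VectorDatabase.get_matched_item`, which blends vector similarity with an averaged occasion score at 0.6/0.4 weights after this commit. A simplified sketch of that blend (the function shape is illustrative; the real code averages only over the occasions it counts):

```python
# Illustrative scoring blend; metadata holds {-1, 0, 1} per occasion.
def blended_score(score_vec: float, metadata: dict, occasions: list[str]) -> float:
    hits = [metadata.get(o, 0) for o in occasions]
    score_occ = sum(hits) / len(hits) if hits else 0.0
    return 0.6 * score_vec + 0.4 * score_occ
```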

View File

@@ -0,0 +1,280 @@
import torch
import os
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import json
from tqdm import tqdm
from app.taxonomy import OCCASION, CATEGORY, ALL_CATEGORY
# data config
BATCH_SOURCE = '2025_q4'
RAW_DATA_PATH = f'./data/{BATCH_SOURCE}/products-all.json'
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'
# MLLM config
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"
DEVICE = "cuda:0" # 确保设备设置正确,与您的 Traceback 匹配
BATCH_SIZE = 50
OUTPUT_FILE = f'./data/{BATCH_SOURCE}/metadata_extraction.json'
# Load Model
processor = AutoProcessor.from_pretrained(MODEL_NAME)
if processor.tokenizer.padding_side != 'left':
processor.tokenizer.padding_side = 'left'
print(f"Set tokenizer padding_side to '{processor.tokenizer.padding_side}' for correct generation.")
model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to(DEVICE)
model.eval()
# Load Data
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
data = json.load(file)
EXAMPLE_1_INFO = """
Product Name: ARMARIUM - Loren Wool Blend Tube Skirt
Category: Clothing / Bottoms
Color: RED
Description: Cut from cardinal-red virgin wool, Armarium's Loren skirt wields tailoring's exactitude in a column of colour. The low-slung waist and clean tube line are punctuated by a razor back slit—stride from boardroom to candlelit bar with modern hauteur.
Tags: armarium, clothing, in-stock, new, loren, wool, blend, tube
"""
EXAMPLE_1_JSON = json.dumps({
"category": "skirts",
"gender": "female",
"applicable_occasions": [
"Business/workwear", "Evening", "Cocktail / Semi-Formal", "Party / Clubbing", "Formal"
],
"inappropriate_occasions": [
"Activewear", "Beach / Swim", "Athleisure", "Ski / Snow / Mountain", "Casual"
]
}, indent=4)
# Example 2: brooch (pin)
EXAMPLE_2_INFO = """
Product Name: TATEOSSIAN - Mayfair 18K Yellow Gold Rhodium Plated Sterling Silver Peg Pin
Category: Accessories / Accessories
Color: MULTI
Description: Crafted from 18k yellow gold and rhodium-plated sterling silver, this unique pins has been artfully finished with Tateossian's signature diamond engraving pattern.
Tags: tateossian, accessories, in-stock, new, mayfair, yellow, gold, rhodium
"""
EXAMPLE_2_JSON = json.dumps({
"category": "jewelry",
"gender": "female",
"applicable_occasions": [
"Formal", "Black Tie / White Tie", "Bridal / Wedding", "Business/workwear", "Cocktail / Semi-Formal"
],
"inappropriate_occasions": [
"Casual", "Activewear", "Beach / Swim", "Outdoor", "Athleisure", "Ski / Snow / Mountain"
]
}, indent=4)
# --- 2. Build the dialogue-format prompt ---
BOS_TOKEN = "<|begin_of_text|>"
EOS_TOKEN = "<|eot_id|>"
SYSTEM_HEADER = "<|start_header_id|>system<|end_header_id|>\n"
USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n"
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n"
IMAGE_TOKEN = "<|image|>"
def format_product_info(product):
tags_str = ", ".join(product.get('tags', []))
info = (
f"Product Name: {product.get('name', 'N/A')}\n"
f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
f"Color: {product.get('color', 'N/A')}\n"
f"Description: {product.get('description', '')}\n"
f"Tags: {tags_str}",
f"groupName: {product.get('groupName', 'N/A')}\n"
f"onlineBU: {product.get('onlineBU', 'N/A')}\n"
)
return info
def generate_full_prompt(product_info, raw_category):
if raw_category == 'Fine Jewellery And Watches':
category = 'accessories'
else:
category = raw_category.lower()
subcategory_list = CATEGORY.get(category)
SYSTEM_PROMPT = f"""You are an expert fashion AI assistant. Your task is to analyze the provided product image and product details to:
1. determine the suitable occasions for wearing or using the item. You must choose occasions ONLY from the following strict list: {json.dumps(OCCASION, indent=4)}. Select only the clearly applicable or clearly inappropriate occasions.
2. categorize it into a suitable category from the strict list: {json.dumps(subcategory_list)}.
3. categorize it into an appropriate gender from ["female", "male", "unisex"].
Output Format:
Return ONLY a valid JSON object with four keys: "category", "gender", "applicable_occasions" and "inappropriate_occasions". Do not include any analysis or extra text outside of the final JSON object.
"""
# Assemble the dialogue sequence
dialogue_prompt = (
# 1. System Instruction
f"{BOS_TOKEN}{SYSTEM_HEADER}{SYSTEM_PROMPT}{EOS_TOKEN}"
# 2. Example 1 (Few-Shot Round 1)
# Format: <|start_header_id|>user<|end_header_id|>\n<|image|>\n{Text Instruction}<|eot_id|>
f"{USER_HEADER}\n{EXAMPLE_1_INFO}{EOS_TOKEN}"
f"{ASSISTANT_HEADER}{EXAMPLE_1_JSON}{EOS_TOKEN}"
# 3. Example 2 (Few-Shot Round 2)
f"{USER_HEADER}\n{EXAMPLE_2_INFO}{EOS_TOKEN}"
f"{ASSISTANT_HEADER}{EXAMPLE_2_JSON}{EOS_TOKEN}"
# 4. Target Item (Target Query)
f"{USER_HEADER}{IMAGE_TOKEN}\nInput Data:\n{product_info}{EOS_TOKEN}"
f"{ASSISTANT_HEADER}" # 最后的 Assistant Header 告诉模型从这里开始生成
)
return dialogue_prompt
# 2. Load and filter the product list
products = data['products']
product_list = [
product for product in products
if product.get('category') in ['Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches']
and os.path.exists(os.path.join(IMAGE_DIR, f"{product.get('id')}.jpg"))
]
def validate_results():
if os.path.exists(OUTPUT_FILE):
with open(OUTPUT_FILE, 'r') as f:
final_results = json.load(f)
else:
final_results = {}
unfinished_ids = []
for product in product_list:
item_id = product.get('id')
if item_id not in final_results:
unfinished_ids.append(product)
else:
processed_item = final_results[item_id]
category = processed_item.get("category")
gender = processed_item.get("gender")
# re-queue the product once if either field failed validation
if category not in ALL_CATEGORY or gender not in ['female', 'male', 'unisex']:
unfinished_ids.append(product)
return unfinished_ids, final_results
attempts = 0
while attempts < 3:
attempts += 1
unfinished_products, final_results = validate_results()
completion_ratio = 1 - len(unfinished_products) / len(product_list)
if completion_ratio > 0.95:
print("Valid results surpass 95%. Finishing now.")
break
else:
print(f"Starting categorization pass {attempts}. Current completion ratio: {completion_ratio * 100:.1f}%")
try:
# Iterate over the data in slices of BATCH_SIZE
for i in tqdm(range(0, len(unfinished_products), BATCH_SIZE)):
batch_samples = unfinished_products[i:i + BATCH_SIZE]
target_images = []
target_prompts = []
target_products_in_batch = []
# Prepare the inputs for the current batch
for product in batch_samples:
product_id = product['id']
raw_category = product.get('category')
image_path = os.path.join(IMAGE_DIR, f"{product_id}.jpg")
try:
# Collect the image, prompt, and product record
image = Image.open(image_path).convert("RGB")
product_info = format_product_info(product)
full_prompt = generate_full_prompt(product_info, raw_category)
target_images.append(image)
target_prompts.append(full_prompt)
target_products_in_batch.append(product)
except Exception as e:
# Skip any individual sample that fails to load
print(f"Skipping product {product_id} due to loading error: {e}")
continue
if not target_images:
continue # skip if the whole batch has no valid images
# 4. Batched inference
print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{int(len(unfinished_products)/BATCH_SIZE)+1} (Size: {len(target_images)})...")
# Processor input: nested image lists [[img1], [img2], ...]
inputs = processor(
images=[[img] for img in target_images],
text=target_prompts,
return_tensors="pt",
padding=True,
truncation=True
).to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=150,
do_sample=False
)
# 5. Batch-decode and parse the results
input_lengths = inputs.input_ids.size(1)
for j in range(len(target_products_in_batch)):
product = target_products_in_batch[j]
product_id = product['id']
# Extract the generated tokens for the current item
# Note: outputs has shape [batch_size, sequence_length]
newly_generated_tokens = outputs[j, input_lengths:]
generated_text = processor.decode(newly_generated_tokens, skip_special_tokens=True)
# Clean up and parse
if generated_text.endswith(processor.tokenizer.eos_token):
generated_text = generated_text[:-len(processor.tokenizer.eos_token)]
try:
start_idx = generated_text.find('{')
end_idx = generated_text.rfind('}') + 1
if start_idx == -1 or end_idx == 0:
raise ValueError("JSON start or end delimiter not found.")
json_str = generated_text[start_idx:end_idx]
result_dict = json.loads(json_str)
final_results[product_id] = result_dict
except Exception as e:
print(f"ID {product_id}: FAILED to parse JSON. Raw Output: {generated_text.strip()}")
final_results[product_id] = {"error": str(e), "raw_output": generated_text.strip()}
# Free GPU memory (optional, but recommended for long runs)
del inputs, outputs
torch.cuda.empty_cache()
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
json.dump(final_results, f, indent=4, ensure_ascii=False)
# 6. Save the final results
print("\n\n=== ALL BATCHES COMPLETE ===")
# Save the final results to a JSON file
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
json.dump(final_results, f, indent=4, ensure_ascii=False)
print(f"Results saved to {OUTPUT_FILE}")
except Exception as e:
print(f"\n--- Execution Error ---")
print(f"An unexpected error occurred: {e}")

View File

@@ -0,0 +1,178 @@
import chromadb
import os
import json
from copy import deepcopy
import torch
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from app.taxonomy import CATEGORY, ALL_CATEGORY, OCCASION
BATCH_SOURCE = '2025_q4'
DATA_DIR = f'./data/{BATCH_SOURCE}'
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'
RAW_DATA_PATH = f'{DATA_DIR}/products-all.json'
CATEGORIZED_METADATA_PATH = f'{DATA_DIR}/metadata_extraction.json'
## Load data
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
raw_data = json.load(file)
with open(CATEGORIZED_METADATA_PATH, 'r', encoding='utf-8') as file:
categorized_data = json.load(file)
# Create Collection
client = chromadb.PersistentClient(path='./data/db')
collection = client.get_or_create_collection(
name="lc_clothing_embedding"
)
# To delete items from a batch, uncomment the following:
# results = collection.delete(
# where={
# "batch_source": BATCH_SOURCE
# }
# )
# Load model
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def format_product_info(product):
tags_str = ", ".join(product.get('tags', []))
info = (
f"Product Name: {product.get('name', 'N/A')}\n"
f"Brand: {product.get('brand', 'N/A')}\n"
f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
f"Color: {product.get('color', 'N/A')}\n"
f"Description: {product.get('description', '')}\n"
f"Tags: {tags_str}"
f"GroupName: {product.get('groupName', 'N/A')}\n"
f"DetpName: {product.get('deptName', 'N/A')}\n"
f"OnlineBU: {product.get('onlineBU', 'N/A')}\n"
)
return info
# Combine all data together
new_category = {}
valid_count = 0
all_count = 0
for raw_item in tqdm(raw_data['products']):
item_id = raw_item.get('id')
if not item_id:
print(f"This item {raw_item} did not have a valid item_id")
continue
raw_category = raw_item.get("category")
if raw_category not in ['Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches']:
continue
image_path = os.path.join(IMAGE_DIR, f"{item_id}.jpg")
if not os.path.exists(image_path):
print(f"Image not found: {image_path}")
continue
# Everything above is a raw-data problem, not a pipeline failure.
all_count += 1
processed_item = categorized_data.get(item_id, {})
if not processed_item:
print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
continue
category = processed_item.get("category")
gender = processed_item.get("gender")
applicable_occasions = processed_item.get("applicable_occasions", [])
inappropriate_occasions = processed_item.get("inappropriate_occasions", [])
if category not in ALL_CATEGORY:
print(f"{item_id}'s category, {category}, is not valid.")
if category not in new_category:
new_category[category] = [item_id]
else:
new_category[category].append(item_id)
continue
if gender not in ['female', 'male', 'unisex']:
print(f"{item_id}'s gender is not one of ['female', 'male', 'unisex'].")
continue
occasions = applicable_occasions + inappropriate_occasions
if not set(occasions).issubset(set(OCCASION)):
# print(f"{item_id}'s some occasions is not vaild. \n Invalid occasion is {set(occasions).difference(set(OCCASION))}")
applicable_occasions = [o for o in applicable_occasions if o in OCCASION]
inappropriate_occasions = [o for o in inappropriate_occasions if o in OCCASION]
description = raw_item.get('description', "")
if not description:
f"{item_id}'s description is lost."
continue
url = raw_item.get('url', '')
if not url:
f"{item_id}'s url is lost."
continue
valid_count += 1
# Prepare metadata for db
item_img_metadata = {
"item_id": item_id,
"category": category,
"description": description,
"gender": gender,
'brand': raw_item.get('brand', ''),
'color': raw_item.get('color', ''),
'price': raw_item.get('price', ''),
'tags': ",".join(raw_item.get('tags', [])),
'url': url,
"modality": "image",
"batch_source": BATCH_SOURCE
}
for occasion in OCCASION:
item_img_metadata[occasion] = 0
for occasion in applicable_occasions:
item_img_metadata[occasion] = 1
for occasion in inappropriate_occasions:
item_img_metadata[occasion] = -1
item_txt_metadata = deepcopy(item_img_metadata)
item_txt_metadata["modality"] = "text"
# Get image feature
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
img_features = model.get_image_features(**inputs)
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
img_embedding = img_features.cpu().numpy().flatten().tolist()
# Get text feature
inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
with torch.no_grad():
txt_features = model.get_text_features(**inputs)
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
txt_embedding = txt_features.cpu().numpy().flatten().tolist()
product_info = format_product_info(raw_item)
# Insert into ChromaDB
collection.add(
ids=[f'{item_id}_img', f'{item_id}_txt'],
documents=[product_info, product_info],
embeddings=[img_embedding, txt_embedding],
metadatas=[item_img_metadata, item_txt_metadata],
)
print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
print(f'Found new category for consideration: {new_category}')
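
After ingestion it may be worth sanity-checking what landed in the collection for this batch; a hedged sketch reusing the `collection.get` filter API seen elsewhere in this commit (the limit is arbitrary):

```python
# Count the records stored for this batch; ids only, no payload.
ingested = collection.get(
    where={"batch_source": BATCH_SOURCE},
    limit=100000,
    include=[],
)
print(f"{len(ingested['ids'])} records stored for batch {BATCH_SOURCE}")
# Each valid product contributes two records: '<item_id>_img' and '<item_id>_txt'.
```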

BIN
docs/Edi.docx Normal file

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

BIN
docs/vera.docx Normal file

Binary file not shown.