Reconstruct the whole recommendation pipeline and add a new rec mode, one-ask-for-all
This commit is contained in:
@@ -40,7 +40,7 @@
|
||||
## Example in `metadata_extraction.json`
|
||||
```json
|
||||
"EOJ367": {
|
||||
"category": "shoes",
|
||||
"subcategory": "necklaces",
|
||||
"gender": "female",
|
||||
"applicable_occasions": [
|
||||
"Casual",
|
||||
@@ -60,33 +60,34 @@
|
||||
## Metadata in Vector Database
|
||||
```json
|
||||
{
|
||||
'item_id': 'EOJ128',
|
||||
'category': 'sunglasses',
|
||||
'gender': 'unisex',
|
||||
'modality': 'image',
|
||||
'brand': 'CELINE',
|
||||
'color': 'BROWN',
|
||||
'description': "Immerse yourself in the depth of classic style with CELINE\'s Tortoiseshell Logo Sunglasses. Featuring a rich, tortoiseshell acetate frame and adorned with the iconic CELINE logo in gold, these sunglasses are a testament to timeless elegance and luxury. Perfect for those who appreciate a sophisticated aesthetic, they offer optimal UV protection while ensuring you remain at the forefront of fashion.",
|
||||
'tags': 'celine,accessories,in-stock,new,maxi,triomphe,acetate,round',
|
||||
'price': 4500,
|
||||
'url': 'https://www.lanecrawford.com.hk/product/celine/maxi-triomphe-acetate-round-sunglasses/_/EOJ128/product.lc?utm_medium=embed&utm_source=ai-recommended&utm_campaign=2025-christmas_lc_ai-recommended',
|
||||
'batch_source': '2025_q4',
|
||||
'Outdoor': 0,
|
||||
'Ski / Snow / Mountain': 0,
|
||||
'Festival / Concert': 0,
|
||||
'Activewear': 0,
|
||||
'Casual': 1,
|
||||
'Cocktail / Semi-Formal': -1,
|
||||
'Formal': -1,
|
||||
'Party / Clubbing': 0,
|
||||
'Evening': 0,
|
||||
'Travel / Transit': 0,
|
||||
'Beach / Swim': 0,
|
||||
'Garden Party / Daytime Event': 1,
|
||||
'Black Tie / White Tie': -1,
|
||||
'Resort': 1,
|
||||
'Athleisure': 0,
|
||||
'Business / workwear': -1,
|
||||
'Bridal / Wedding': -1,
|
||||
"item_id": "EOJ128",
|
||||
"category": "accessories",
|
||||
"subcategory": "eyewear",
|
||||
"gender": "unisex",
|
||||
"modality": "image",
|
||||
"brand": "CELINE",
|
||||
"color": "BROWN",
|
||||
"description": "Immerse yourself in the depth of classic style with CELINE's Tortoiseshell Logo Sunglasses. Featuring a rich, tortoiseshell acetate frame and adorned with the iconic CELINE logo in gold, these sunglasses are a testament to timeless elegance and luxury. Perfect for those who appreciate a sophisticated aesthetic, they offer optimal UV protection while ensuring you remain at the forefront of fashion.",
|
||||
"tags": "celine,accessories,in-stock,new,maxi,triomphe,acetate,round",
|
||||
"price": 4500,
|
||||
"url": "https://www.lanecrawford.com.hk/product/celine/maxi-triomphe-acetate-round-sunglasses/_/EOJ128/product.lc?utm_medium=embed&utm_source=ai-recommended&utm_campaign=2025-christmas_lc_ai-recommended",
|
||||
"batch_source": "2025_q4",
|
||||
"Outdoor": 0,
|
||||
"Ski / Snow / Mountain": 0,
|
||||
"Festival / Concert": 0,
|
||||
"Activewear": 0,
|
||||
"Casual": 1,
|
||||
"Cocktail / Semi-Formal": -1,
|
||||
"Formal": -1,
|
||||
"Party / Clubbing": 0,
|
||||
"Evening": 0,
|
||||
"Travel / Transit": 0,
|
||||
"Beach / Swim": 0,
|
||||
"Garden Party / Daytime Event": 1,
|
||||
"Black Tie / White Tie": -1,
|
||||
"Resort": 1,
|
||||
"Athleisure": 0,
|
||||
"Business / workwear": -1,
|
||||
"Bridal / Wedding": -1,
|
||||
}
|
||||
```
|
||||
@@ -5,7 +5,7 @@ from PIL import Image
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
|
||||
from app.taxonomy import OCCASION, CATEGORY, ALL_CATEGORY
|
||||
from app.taxonomy import OCCASION, FASHION_TAXONOMY, ALL_SUBCATEGORY_LIST
|
||||
|
||||
|
||||
# data config
|
||||
@@ -42,7 +42,7 @@ Description: Cut from cardinal-red virgin wool, Armarium's Loren skirt wields ta
|
||||
Tags: armarium, clothing, in-stock, new, loren, wool, blend, tube
|
||||
"""
|
||||
EXAMPLE_1_JSON = json.dumps({
|
||||
"category": "skirts",
|
||||
"subcategory": "skirts",
|
||||
"gender": "female",
|
||||
"applicable_occasions": [
|
||||
"Business/workwear", "Evening", "Cocktail / Semi-Formal", "Party / Clubbing", "Formal"
|
||||
@@ -61,7 +61,7 @@ Description: Crafted from 18k yellow gold and rhodium-plated sterling silver, th
|
||||
Tags: tateossian, accessories, in-stock, new, mayfair, yellow, gold, rhodium
|
||||
"""
|
||||
EXAMPLE_2_JSON = json.dumps({
|
||||
"category": "jewelry",
|
||||
"subcategory": "jewelry",
|
||||
"gender": "female",
|
||||
"applicable_occasions": [
|
||||
"Formal", "Black Tie / White Tie", "Bridal / Wedding", "Business/workwear", "Cocktail / Semi-Formal"
|
||||
@@ -94,20 +94,24 @@ def format_product_info(product):
|
||||
return info
|
||||
|
||||
|
||||
def generate_full_prompt(product_info, raw_category):
|
||||
def raw_category_mapping(raw_category: str) -> str:
    """Map a raw catalogue category name to its taxonomy category key.

    The raw label 'Fine Jewellery And Watches' is mapped to 'accessories';
    every other label is returned lower-cased. The comparison with the
    special label is exact (case-sensitive).

    Args:
        raw_category: Category name as it appears in the raw product data.

    Returns:
        The lower-case category key.
    """
    # Only this one raw label has no same-named category; it is folded
    # into 'accessories'.
    if raw_category == 'Fine Jewellery And Watches':
        return 'accessories'
    return raw_category.lower()
|
||||
|
||||
|
||||
def generate_full_prompt(product_info, raw_category):
|
||||
category = raw_category_mapping(raw_category)
|
||||
subcategory_list = FASHION_TAXONOMY.get(category)
|
||||
|
||||
SYSTEM_PROMPT = f"""You are an expert fashion AI assistant. Your task is to analyze the provided product image and product details to:
|
||||
1. determine the suitable occasions for wearing or using the item. You must choose occasions ONLY from the following strict list: {json.dumps(OCCASION, indent=4)}. Only relevant suitable or inappropriate occasions should be selected.
|
||||
2. categorize it into suitable category in strict list: {json.dumps(subcategory_list)}.
|
||||
2. categorize it into suitable subcategory in strict list: {json.dumps(subcategory_list)}.
|
||||
3. categorize it into appropriate gender in ["female", "male", "unisex"]
|
||||
|
||||
Output Format:
|
||||
Return ONLY a valid JSON object with four keys: "category", "gender", "applicable_occasions" and "inappropriate_occasions". Do not include any analysis or extra text outside of the final JSON object.
|
||||
Return ONLY a valid JSON object with four keys: "subcategory", "gender", "applicable_occasions" and "inappropriate_occasions". Do not include any analysis or extra text outside of the final JSON object.
|
||||
"""
|
||||
|
||||
# Assemble the conversation sequence
|
||||
@@ -140,37 +144,36 @@ product_list = [
|
||||
]
|
||||
|
||||
|
||||
def validate_results():
|
||||
if os.path.exists(OUTPUT_FILE):
|
||||
with open(OUTPUT_FILE, 'r') as f:
|
||||
final_results = json.load(f)
|
||||
else:
|
||||
final_results = {}
|
||||
def validate_result(result_dict):
    """Check whether one parsed model result is usable.

    A result is accepted only when all of the following hold:
      * both "subcategory" and "gender" are present and non-empty,
      * "subcategory" is a member of ALL_SUBCATEGORY_LIST,
      * "gender" is one of 'female', 'male' or 'unisex'.

    Args:
        result_dict: Mapping parsed from the model's JSON output; it is
            read through ``.get`` for the "subcategory" and "gender" keys.

    Returns:
        True when every check passes, otherwise False.
    """
    subcategory = result_dict.get("subcategory")
    gender = result_dict.get("gender")

    # Missing or empty fields are rejected before the membership checks.
    if not subcategory or not gender:
        return False

    if subcategory not in ALL_SUBCATEGORY_LIST:
        return False

    if gender not in ['female', 'male', 'unisex']:
        return False

    return True
|
||||
|
||||
|
||||
if os.path.exists(OUTPUT_FILE):
|
||||
with open(OUTPUT_FILE, 'r') as f:
|
||||
final_results = json.load(f)
|
||||
else:
|
||||
final_results = {}
|
||||
|
||||
if gender not in ['female', 'male', 'unisex']:
|
||||
unfinished_ids.append(product)
|
||||
return unfinished_ids, final_results
|
||||
|
||||
attemps = 0
|
||||
while attemps < 3:
|
||||
unfinished_products = [product for product in product_list if product.get('id') not in final_results.keys()]
|
||||
attemps += 1
|
||||
unfinished_products, final_results = validate_results()
|
||||
completion_ratio = len(unfinished_products) / len(product_list)
|
||||
if (completion_ratio > 0.95):
|
||||
print("valid results surpass 95%. Finish Now.")
|
||||
completion_ratio = len(final_results) / len(product_list)
|
||||
if (completion_ratio > 0.85):
|
||||
print("valid results surpass 85%. Finish Now.")
|
||||
break
|
||||
else:
|
||||
print(f"Start {attemps} categorization process. Current ratio: {completion_ratio * 100}%")
|
||||
@@ -252,11 +255,11 @@ while attemps < 3:
|
||||
json_str = generated_text[start_idx:end_idx]
|
||||
result_dict = json.loads(json_str)
|
||||
|
||||
final_results[product_id] = result_dict
|
||||
if validate_result(result_dict):
|
||||
final_results[product_id] = result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"ID {product_id}: FAILED to parse JSON. Raw Output: {generated_text.strip()}")
|
||||
final_results[product_id] = {"error": str(e), "raw_output": generated_text.strip()}
|
||||
|
||||
# Free GPU memory (optional, but recommended for long-running jobs)
|
||||
del inputs, outputs
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
|
||||
|
||||
|
||||
import chromadb
|
||||
import os
|
||||
import json
|
||||
@@ -11,7 +8,7 @@ from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
from app.taxonomy import CATEGORY, ALL_CATEGORY, OCCASION
|
||||
from app.taxonomy import ALL_SUBCATEGORY_LIST, OCCASION
|
||||
|
||||
|
||||
BATCH_SOURCE = '2025_q4'
|
||||
@@ -20,6 +17,7 @@ IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'
|
||||
|
||||
RAW_DATA_PATH = f'{DATA_DIR}/products-all.json'
|
||||
CATEGORIZED_METADATA_PATH = f'{DATA_DIR}/metadata_extraction.json'
|
||||
ADD_TEXT_EMBEDDING = False
|
||||
|
||||
## Load data
|
||||
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
|
||||
@@ -36,11 +34,11 @@ collection = client.get_or_create_collection(
|
||||
)
|
||||
|
||||
# If you wish to delete the items of this batch, uncomment the following
|
||||
# results = collection.delete(
|
||||
# where={
|
||||
# "batch_source": BATCH_SOURCE
|
||||
# }
|
||||
# )
|
||||
results = collection.delete(
|
||||
where={
|
||||
"batch_source": BATCH_SOURCE
|
||||
}
|
||||
)
|
||||
|
||||
# Load model
|
||||
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
@@ -63,9 +61,13 @@ def format_product_info(product):
|
||||
)
|
||||
return info
|
||||
|
||||
def raw_category_mapping(raw_category: str) -> str:
    """Map a raw catalogue category name to its taxonomy category key.

    The raw label 'Fine Jewellery And Watches' is mapped to 'accessories';
    every other label is returned lower-cased. The comparison with the
    special label is exact (case-sensitive).

    Args:
        raw_category: Category name as it appears in the raw product data.

    Returns:
        The lower-case category key.
    """
    # Only this one raw label has no same-named category; it is folded
    # into 'accessories'.
    if raw_category == 'Fine Jewellery And Watches':
        return 'accessories'
    return raw_category.lower()
|
||||
|
||||
# Combine all data together
|
||||
new_category = {}
|
||||
valid_count = 0
|
||||
all_count = 0
|
||||
for raw_item in tqdm(raw_data['products']):
|
||||
@@ -91,18 +93,14 @@ for raw_item in tqdm(raw_data['products']):
|
||||
print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
|
||||
continue
|
||||
|
||||
category = processed_item.get("category")
|
||||
category = raw_category_mapping(raw_category)
|
||||
subcategory = processed_item.get("subcategory")
|
||||
gender = processed_item.get("gender")
|
||||
applicable_occasions = processed_item.get("applicable_occasions", [])
|
||||
inappropriate_occasions = processed_item.get("inappropriate_occasions", [])
|
||||
|
||||
if category not in ALL_CATEGORY:
|
||||
if subcategory not in ALL_SUBCATEGORY_LIST:
|
||||
print(f"{item_id}'s category, {category}, does not valid.")
|
||||
if category not in new_category:
|
||||
new_category[category] = [item_id]
|
||||
else:
|
||||
new_category[category].append(item_id)
|
||||
continue
|
||||
|
||||
if gender not in ['female', 'male', 'unisex']:
|
||||
print(f"{item_id}'s gender is not valid in {['female', 'male', 'unisex']}")
|
||||
@@ -129,6 +127,7 @@ for raw_item in tqdm(raw_data['products']):
|
||||
item_img_metadata = {
|
||||
"item_id": item_id,
|
||||
"category": category,
|
||||
"subcategory": subcategory,
|
||||
"description": description,
|
||||
"gender": gender,
|
||||
'brand': raw_item.get('brand', ''),
|
||||
@@ -146,10 +145,6 @@ for raw_item in tqdm(raw_data['products']):
|
||||
for occasion in inappropriate_occasions:
|
||||
item_img_metadata[occasion] = -1
|
||||
|
||||
item_txt_metadata = deepcopy(item_img_metadata)
|
||||
item_txt_metadata["modality"] = "text"
|
||||
|
||||
|
||||
# Get image feature
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
inputs = processor(images=image, return_tensors="pt").to(device)
|
||||
@@ -158,21 +153,30 @@ for raw_item in tqdm(raw_data['products']):
|
||||
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
|
||||
img_embedding = img_features.cpu().numpy().flatten().tolist()
|
||||
|
||||
# Get text feature
|
||||
inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
|
||||
with torch.no_grad():
|
||||
txt_features = model.get_text_features(**inputs)
|
||||
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
|
||||
txt_embedding = txt_features.cpu().numpy().flatten().tolist()
|
||||
|
||||
product_info = format_product_info(raw_item)
|
||||
# Insert into ChromaDB
|
||||
collection.add(
|
||||
ids=[f'{item_id}_img', f'{item_id}_txt'],
|
||||
documents=[product_info, product_info],
|
||||
embeddings=[img_embedding, txt_embedding],
|
||||
metadatas=[item_img_metadata, item_txt_metadata],
|
||||
ids=[f'{item_id}_img'],
|
||||
documents=[product_info],
|
||||
embeddings=[img_embedding],
|
||||
metadatas=[item_img_metadata],
|
||||
)
|
||||
|
||||
print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
|
||||
print(f'Found new category for consideration: {new_category}')
|
||||
if ADD_TEXT_EMBEDDING:
|
||||
item_txt_metadata = deepcopy(item_img_metadata)
|
||||
item_txt_metadata["modality"] = "text"
|
||||
|
||||
# Get text feature
|
||||
inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
|
||||
with torch.no_grad():
|
||||
txt_features = model.get_text_features(**inputs)
|
||||
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
|
||||
txt_embedding = txt_features.cpu().numpy().flatten().tolist()
|
||||
collection.add(
|
||||
ids=[f'{item_id}_txt'],
|
||||
documents=[product_info],
|
||||
embeddings=[txt_embedding],
|
||||
metadatas=[item_txt_metadata],
|
||||
)
|
||||
|
||||
print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
|
||||
Reference in New Issue
Block a user