"""Ingest one product batch into ChromaDB as CLIP image (and optionally text)
embeddings.

Reads the raw product dump plus the pre-extracted categorization metadata,
filters out invalid items, and stores an L2-normalized CLIP embedding with
rich metadata per item in the ``lc_clothing_embedding`` collection.
"""

import json
import os
from copy import deepcopy

import chromadb
import torch
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

from app.taxonomy import ALL_SUBCATEGORY_LIST, OCCASION

BATCH_SOURCE = '2025_q4'
DATA_DIR = f'./data/{BATCH_SOURCE}'
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'
RAW_DATA_PATH = f'{DATA_DIR}/products-all.json'
CATEGORIZED_METADATA_PATH = f'{DATA_DIR}/metadata_extraction.json'
ADD_TEXT_EMBEDDING = False

# Load the raw product dump and the pre-extracted categorization metadata.
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
    raw_data = json.load(file)
with open(CATEGORIZED_METADATA_PATH, 'r', encoding='utf-8') as file:
    categorized_data = json.load(file)

# Create (or reuse) the persistent collection.
client = chromadb.PersistentClient(path='./data/db')
collection = client.get_or_create_collection(
    name="lc_clothing_embedding"
)

# Drop any previously ingested items from this batch so re-runs are idempotent.
collection.delete(
    where={
        "batch_source": BATCH_SOURCE
    }
)

# Load the CLIP model used for both image and text embeddings.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def format_product_info(product):
    """Render a product dict as a human-readable multi-line summary.

    The result is stored as the ``documents`` payload next to each embedding.
    Missing fields fall back to ``'N/A'`` (or ``''`` for the description).
    """
    tags_str = ", ".join(product.get('tags', []))
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Brand: {product.get('brand', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        # Bug fix: the original omitted this newline, fusing the Tags and
        # GroupName fields onto a single line.
        f"Tags: {tags_str}\n"
        f"GroupName: {product.get('groupName', 'N/A')}\n"
        f"DeptName: {product.get('deptName', 'N/A')}\n"  # bug fix: label typo was "DetpName"
        f"OnlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info


def raw_category_mapping(raw_category: str) -> str:
    """Map a raw catalogue category name to the internal lowercase category."""
    if raw_category == 'Fine Jewellery And Watches':
        return 'accessories'
    return raw_category.lower()


# Combine all data together and ingest every valid item.
valid_count = 0
all_count = 0
for raw_item in tqdm(raw_data['products']):
    item_id = raw_item.get('id')
    if not item_id:
        print(f"This item {raw_item} did not have a valid item_id")
        continue
    raw_category = raw_item.get("category")
    if raw_category not in ['Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches']:
        continue
    image_path = os.path.join(IMAGE_DIR, f"{item_id}.jpg")
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        continue
    # All above is raw data error, it's not our business.
    all_count += 1

    processed_item = categorized_data.get(item_id, {})
    if not processed_item:
        print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
        continue
    category = raw_category_mapping(raw_category)
    subcategory = processed_item.get("subcategory")
    gender = processed_item.get("gender")
    applicable_occasions = processed_item.get("applicable_occasions", [])
    inappropriate_occasions = processed_item.get("inappropriate_occasions", [])
    if subcategory not in ALL_SUBCATEGORY_LIST:
        # Bug fix: the message previously reported `category` (and read
        # "does not valid") although the check is on `subcategory`.
        # NOTE(review): this only warns and still ingests the item, unlike the
        # gender check below — confirm whether it should `continue` instead.
        print(f"{item_id}'s subcategory, {subcategory}, is not valid.")
    if gender not in ['female', 'male', 'unisex']:
        print(f"{item_id}'s gender is not valid in {['female', 'male', 'unisex']}")
        continue
    occasions = applicable_occasions + inappropriate_occasions
    if not set(occasions).issubset(set(OCCASION)):
        # Silently drop any occasion labels outside the known taxonomy.
        applicable_occasions = [o for o in applicable_occasions if o in OCCASION]
        inappropriate_occasions = [o for o in inappropriate_occasions if o in OCCASION]
    description = raw_item.get('description', "")
    if not description:
        # Bug fix: this message was a bare f-string expression and was never printed.
        print(f"{item_id}'s description is lost.")
        continue
    url = raw_item.get('url', '')
    if not url:
        print(f"{item_id}'s url is lost.")  # bug fix: was a bare f-string, never printed
        continue
    valid_count += 1

    # Prepare metadata for db
    item_img_metadata = {
        "item_id": item_id,
        "category": category,
        "subcategory": subcategory,
        "description": description,
        "gender": gender,
        'brand': raw_item.get('brand', ''),
        'color': raw_item.get('color', ''),
        'price': raw_item.get('price', ''),
        'tags': ",".join(raw_item.get('tags', [])),
        'url': url,
        "modality": "image",
        "batch_source": BATCH_SOURCE,
    }
    # Tri-state occasion flags: 1 = applicable, -1 = inappropriate, 0 = unknown.
    for occasion in OCCASION:
        item_img_metadata[occasion] = 0
    for occasion in applicable_occasions:
        item_img_metadata[occasion] = 1
    for occasion in inappropriate_occasions:
        item_img_metadata[occasion] = -1

    # Compute the L2-normalized CLIP image embedding.
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_features = model.get_image_features(**inputs)
    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
    img_embedding = img_features.cpu().numpy().flatten().tolist()

    product_info = format_product_info(raw_item)
    # Insert into ChromaDB.
    collection.add(
        ids=[f'{item_id}_img'],
        documents=[product_info],
        embeddings=[img_embedding],
        metadatas=[item_img_metadata],
    )

    if ADD_TEXT_EMBEDDING:
        item_txt_metadata = deepcopy(item_img_metadata)
        item_txt_metadata["modality"] = "text"
        # Compute the L2-normalized CLIP text embedding of the description.
        inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            txt_features = model.get_text_features(**inputs)
        txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
        txt_embedding = txt_features.cpu().numpy().flatten().tolist()
        collection.add(
            ids=[f'{item_id}_txt'],
            documents=[product_info],
            embeddings=[txt_embedding],
            metadatas=[item_txt_metadata],
        )

# Bug fix: guard against ZeroDivisionError when no item survived the raw-data filters.
if all_count:
    print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
else:
    print("No items passed the raw-data filters; nothing ingested.")