Files

182 lines
6.2 KiB
Python
Raw Permalink Normal View History

2025-12-10 17:27:56 +08:00
import chromadb
import os
import json
from copy import deepcopy
import torch
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from app.taxonomy import ALL_SUBCATEGORY_LIST, OCCASION
# When True, a CLIP text embedding of each product description is indexed too,
# in addition to the image embedding.
ADD_TEXT_EMBEDDING = False

# Which data drop to index; every input path below is derived from it.
BATCH_SOURCE = '2025_q4'
DATA_DIR = f'./data/{BATCH_SOURCE}'
# Product images (looked up as <item_id>.jpg) live inside the batch directory.
IMAGE_DIR = f'{DATA_DIR}/image_data'
RAW_DATA_PATH = f'{DATA_DIR}/products-all.json'
CATEGORIZED_METADATA_PATH = f'{DATA_DIR}/metadata_extraction.json'
## Load data
# Raw product dump and the per-item categorized metadata (keyed by item id).
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
    raw_data = json.load(file)
with open(CATEGORIZED_METADATA_PATH, 'r', encoding='utf-8') as file:
    categorized_data = json.load(file)

# Create (or reopen) the persistent collection that holds the embeddings.
client = chromadb.PersistentClient(path='./data/db')
collection = client.get_or_create_collection(
    name="lc_clothing_embedding"
)

# Remove every entry previously indexed for this batch before re-indexing it.
# This runs unconditionally (it is not opt-in): re-running the script rebuilds
# BATCH_SOURCE from scratch instead of re-adding ids (`<item_id>_img` /
# `<item_id>_txt`) that already exist or leaving stale entries behind.
# Comment it out only if you deliberately want to keep the existing entries.
collection.delete(
    where={
        "batch_source": BATCH_SOURCE
    }
)
# Load model
# Choose the execution device first: the GPU when CUDA is available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The same CLIP checkpoint provides both the preprocessing (processor) and the
# image/text encoders (model) used by the indexing loop.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.to(device)
def format_product_info(product):
    """Render a raw product record as the text document stored alongside its embedding.

    Every field is emitted as its own ``"Label: value"`` line terminated by a
    newline. Missing fields fall back to ``'N/A'``; a missing description or
    tag list falls back to an empty value.

    Args:
        product: A raw product dict from the products dump.

    Returns:
        The multi-line product description string.
    """
    # `or []` also covers a "tags" key that is present but null.
    tags_str = ", ".join(product.get('tags') or [])
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Brand: {product.get('brand', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        # Each field must end with "\n"; without it Tags and GroupName merge
        # onto a single line.
        f"Tags: {tags_str}\n"
        f"GroupName: {product.get('groupName', 'N/A')}\n"
        f"DeptName: {product.get('deptName', 'N/A')}\n"
        f"OnlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info
def raw_category_mapping(raw_category: str) -> str:
    """Translate a raw catalogue category into the category stored in metadata.

    'Fine Jewellery And Watches' is folded into 'accessories'; any other
    category name is returned lower-cased (e.g. 'Shoes' -> 'shoes').
    """
    special_cases = {'Fine Jewellery And Watches': 'accessories'}
    return special_cases.get(raw_category, raw_category.lower())
# Combine all data together: join every raw product with its categorized
# metadata, embed it with CLIP and insert it into the Chroma collection.

# Raw categories covered by this index; products in any other category are
# skipped without a message.
SUPPORTED_RAW_CATEGORIES = ('Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches')
VALID_GENDERS = ['female', 'male', 'unisex']
# Built once so filtering occasion labels is a set lookup per label.
OCCASION_SET = set(OCCASION)


def _to_unit_vector_list(features):
    """L2-normalise *features* along the last dimension and flatten them to a list of floats."""
    features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu().numpy().flatten().tolist()


def _embed_image(image_path):
    """Return the normalised CLIP image embedding of the file at *image_path*.

    Raises:
        OSError: if the file cannot be opened or decoded (PIL's
            UnidentifiedImageError is an OSError subclass).
    """
    # `with` closes the underlying file handle; convert() returns a new, fully
    # loaded image that remains usable after the source is closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    return _to_unit_vector_list(features)


def _embed_text(text):
    """Return the normalised CLIP text embedding of *text*, truncated to the model's max input length."""
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    return _to_unit_vector_list(features)


valid_count = 0  # items actually inserted into the collection
all_count = 0    # items that passed the raw-data checks below
for raw_item in tqdm(raw_data['products']):
    item_id = raw_item.get('id')
    if not item_id:
        print(f"This item {raw_item} did not have a valid item_id")
        continue
    raw_category = raw_item.get("category")
    if raw_category not in SUPPORTED_RAW_CATEGORIES:
        continue
    image_path = os.path.join(IMAGE_DIR, f"{item_id}.jpg")
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        continue
    # All above is raw data error, it's not our business.
    all_count += 1

    # NOTE(review): json.load always produces string object keys, so a
    # non-string raw id can never match here -- confirm ids are strings.
    processed_item = categorized_data.get(item_id, {})
    if not processed_item:
        print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
        continue
    category = raw_category_mapping(raw_category)
    subcategory = processed_item.get("subcategory")
    gender = processed_item.get("gender")
    # `or []` also covers keys that are present but null.
    applicable_occasions = processed_item.get("applicable_occasions") or []
    inappropriate_occasions = processed_item.get("inappropriate_occasions") or []

    if subcategory not in ALL_SUBCATEGORY_LIST:
        # Skip like every other failed check so an unknown subcategory never
        # reaches the index.
        print(f"{item_id}'s subcategory, {subcategory}, is not valid.")
        continue
    if gender not in VALID_GENDERS:
        print(f"{item_id}'s gender is not valid in {VALID_GENDERS}")
        continue
    # Occasion labels outside the taxonomy are dropped rather than rejecting
    # the whole item.
    applicable_occasions = [o for o in applicable_occasions if o in OCCASION_SET]
    inappropriate_occasions = [o for o in inappropriate_occasions if o in OCCASION_SET]

    description = raw_item.get('description', "")
    if not description:
        print(f"{item_id}'s description is lost.")
        continue
    url = raw_item.get('url', '')
    if not url:
        print(f"{item_id}'s url is lost.")
        continue

    # Prepare metadata for db
    item_img_metadata = {
        "item_id": item_id,
        "category": category,
        "subcategory": subcategory,
        "description": description,
        "gender": gender,
        'brand': raw_item.get('brand', ''),
        'color': raw_item.get('color', ''),
        'price': raw_item.get('price', ''),
        # `or []` also covers a "tags" key that is present but null.
        'tags': ",".join(raw_item.get('tags') or []),
        'url': url,
        "modality": "image",
        "batch_source": BATCH_SOURCE,
    }
    # One flag per taxonomy occasion: 1 = applicable, -1 = inappropriate,
    # 0 = in neither list. Inappropriate is applied last, so it wins when a
    # label appears in both lists.
    item_img_metadata.update(dict.fromkeys(OCCASION, 0))
    item_img_metadata.update(dict.fromkeys(applicable_occasions, 1))
    item_img_metadata.update(dict.fromkeys(inappropriate_occasions, -1))

    # Get image feature; an unreadable/corrupt file skips this item instead of
    # aborting the whole run.
    try:
        img_embedding = _embed_image(image_path)
    except OSError as err:
        print(f"Failed to load image {image_path}: {err}")
        continue
    product_info = format_product_info(raw_item)

    # Insert into ChromaDB
    collection.add(
        ids=[f'{item_id}_img'],
        documents=[product_info],
        embeddings=[img_embedding],
        metadatas=[item_img_metadata],
    )
    if ADD_TEXT_EMBEDDING:
        item_txt_metadata = deepcopy(item_img_metadata)
        item_txt_metadata["modality"] = "text"
        collection.add(
            ids=[f'{item_id}_txt'],
            documents=[product_info],
            embeddings=[_embed_text(description)],
            metadatas=[item_txt_metadata],
        )
    valid_count += 1

# Guard the ratio: all_count is 0 when no product passed the raw-data checks.
if all_count:
    print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
else:
    print(f"No candidate items found in {RAW_DATA_PATH}; nothing was indexed.")