182 lines
6.2 KiB
Python
182 lines
6.2 KiB
Python
import chromadb
|
|
import os
|
|
import json
|
|
from copy import deepcopy
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
from PIL import Image
|
|
from transformers import CLIPProcessor, CLIPModel
|
|
|
|
from app.taxonomy import ALL_SUBCATEGORY_LIST, OCCASION
|
|
|
|
|
|
# Identifier of the product batch being ingested. It names the data
# sub-directory and is stored on every record as the `batch_source` metadata.
BATCH_SOURCE = '2025_q4'

# Root folder of this batch and the folder holding its `<item_id>.jpg` images.
DATA_DIR = './data/' + BATCH_SOURCE
IMAGE_DIR = DATA_DIR + '/image_data'

# Raw product dump and the per-item categorization results for this batch.
RAW_DATA_PATH = DATA_DIR + '/products-all.json'
CATEGORIZED_METADATA_PATH = DATA_DIR + '/metadata_extraction.json'

# When True, a second `<item_id>_txt` record carrying the CLIP text embedding
# of the description is stored next to the image record.
ADD_TEXT_EMBEDDING = False
|
|
|
|
# Read both JSON inputs fully into memory: the raw product dump (its
# 'products' entry is iterated below) and the categorization results, which
# are looked up by item id.
with open(RAW_DATA_PATH, encoding='utf-8') as raw_file:
    raw_data = json.load(raw_file)

with open(CATEGORIZED_METADATA_PATH, encoding='utf-8') as categorized_file:
    categorized_data = json.load(categorized_file)
|
|
|
|
|
|
# Open the on-disk ChromaDB store and fetch the embedding collection,
# creating it on first use.
client = chromadb.PersistentClient(path='./data/db')
collection = client.get_or_create_collection(name="lc_clothing_embedding")

# Remove every record previously stored for this batch so that the ingestion
# below starts from a clean slate. Comment this call out to keep the existing
# records of the batch instead.
results = collection.delete(where={"batch_source": BATCH_SOURCE})
|
|
|
|
# Load the CLIP processor/model pair from the same checkpoint and move the
# model to the GPU when one is available, otherwise keep it on the CPU.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
|
|
|
|
def format_product_info(product):
    """Render a raw product record as a multi-line text block.

    Args:
        product: Mapping of raw product fields. The keys read are ``name``,
            ``brand``, ``category``, ``deptName``, ``color``,
            ``description``, ``tags``, ``groupName`` and ``onlineBU``.
            All of them are optional.

    Returns:
        A string with one ``Label: value`` line per field, every line
        terminated by a newline. Missing fields are rendered as ``N/A``,
        except ``description`` (empty string) and ``tags`` (no tags).
    """
    # `tags` can be absent or explicitly null; treat both as "no tags" rather
    # than letting str.join fail on None.
    tags_str = ", ".join(product.get('tags') or [])
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Brand: {product.get('brand', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        # The trailing newline was missing here, which glued the "GroupName"
        # label onto the end of the tags line.
        f"Tags: {tags_str}\n"
        f"GroupName: {product.get('groupName', 'N/A')}\n"
        # Label previously misspelled as "DetpName".
        f"DeptName: {product.get('deptName', 'N/A')}\n"
        f"OnlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info
|
|
|
|
def raw_category_mapping(raw_category: str) -> str:
    """Translate a raw catalogue category into the category stored in the DB.

    'Fine Jewellery And Watches' is folded into 'accessories'; every other
    value is returned lower-cased.
    """
    if raw_category != 'Fine Jewellery And Watches':
        return raw_category.lower()
    return 'accessories'
|
|
|
|
# Combine each raw product with its categorized metadata, embed the product
# image with CLIP (and optionally the description text), and store the result
# in the ChromaDB collection.
VALID_RAW_CATEGORIES = ['Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches']
VALID_GENDERS = ['female', 'male', 'unisex']
# Built once: OCCASION does not change while the loop runs, and set
# membership keeps the per-item occasion filtering cheap.
valid_occasions = set(OCCASION)

valid_count = 0  # items that passed every check and were written to the DB
all_count = 0    # items whose raw record and image file are usable
for raw_item in tqdm(raw_data['products']):
    item_id = raw_item.get('id')
    if not item_id:
        print(f"This item {raw_item} did not have a valid item_id")
        continue

    raw_category = raw_item.get("category")
    if raw_category not in VALID_RAW_CATEGORIES:
        continue

    image_path = os.path.join(IMAGE_DIR, f"{item_id}.jpg")
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        continue

    # Everything rejected above is a raw-data problem and is deliberately
    # excluded from the valid-ratio statistics.
    all_count += 1

    processed_item = categorized_data.get(item_id, {})
    if not processed_item:
        print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
        continue

    category = raw_category_mapping(raw_category)
    subcategory = processed_item.get("subcategory")
    gender = processed_item.get("gender")
    # `or []` also covers keys that are present but explicitly null.
    applicable_occasions = processed_item.get("applicable_occasions") or []
    inappropriate_occasions = processed_item.get("inappropriate_occasions") or []

    if subcategory not in ALL_SUBCATEGORY_LIST:
        # Report the offending subcategory (the message used to print the
        # category) and skip the item like every other failed validation, so
        # it is neither stored nor counted as valid.
        print(f"{item_id}'s subcategory, {subcategory}, is not valid.")
        continue

    if gender not in VALID_GENDERS:
        print(f"{item_id}'s gender is not valid in {VALID_GENDERS}")
        continue

    # Unknown occasions are dropped silently instead of rejecting the item.
    applicable_occasions = [o for o in applicable_occasions if o in valid_occasions]
    inappropriate_occasions = [o for o in inappropriate_occasions if o in valid_occasions]

    description = raw_item.get('description', "")
    if not description:
        # These messages were bare f-string expressions and never printed.
        print(f"{item_id}'s description is lost.")
        continue

    url = raw_item.get('url', '')
    if not url:
        print(f"{item_id}'s url is lost.")
        continue

    valid_count += 1

    # Prepare metadata for db
    item_img_metadata = {
        "item_id": item_id,
        "category": category,
        "subcategory": subcategory,
        "description": description,
        "gender": gender,
        'brand': raw_item.get('brand', ''),
        'color': raw_item.get('color', ''),
        'price': raw_item.get('price', ''),
        # Tags are flattened to one comma-separated string; a null tag list
        # is treated as empty.
        'tags': ",".join(raw_item.get('tags') or []),
        'url': url,
        "modality": "image",
        "batch_source": BATCH_SOURCE,
    }
    # One metadata key per known occasion: 0 = not mentioned, 1 = applicable,
    # -1 = inappropriate. Inappropriate is applied last, so it wins when an
    # occasion appears in both lists.
    for occasion in OCCASION:
        item_img_metadata[occasion] = 0
    for occasion in applicable_occasions:
        item_img_metadata[occasion] = 1
    for occasion in inappropriate_occasions:
        item_img_metadata[occasion] = -1

    # Get image feature. The context manager releases the file handle as soon
    # as the pixel data has been converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_features = model.get_image_features(**inputs)
        # L2-normalize the embedding before storing it.
        img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
    img_embedding = img_features.cpu().numpy().flatten().tolist()

    product_info = format_product_info(raw_item)
    # Insert into ChromaDB
    collection.add(
        ids=[f'{item_id}_img'],
        documents=[product_info],
        embeddings=[img_embedding],
        metadatas=[item_img_metadata],
    )

    if ADD_TEXT_EMBEDDING:
        item_txt_metadata = deepcopy(item_img_metadata)
        item_txt_metadata["modality"] = "text"

        # Get text feature
        inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            txt_features = model.get_text_features(**inputs)
            txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
        txt_embedding = txt_features.cpu().numpy().flatten().tolist()
        collection.add(
            ids=[f'{item_id}_txt'],
            documents=[product_info],
            embeddings=[txt_embedding],
            metadatas=[item_txt_metadata],
        )

# Guard the ratio: with no usable raw items the division would raise
# ZeroDivisionError at the very end of the run.
if all_count:
    print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
else:
    print("No usable raw items were found, so no valid ratio can be reported.")