# lc_stylist_agent/data_ingestion/run_ingestion.py
# Ingests a batch of product metadata + images into ChromaDB as CLIP embeddings.
import chromadb
import os
import json
from copy import deepcopy
import torch
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from app.taxonomy import ALL_SUBCATEGORY_LIST, OCCASION
# Batch identifier: selects the data directory and tags every DB record so a
# batch can later be deleted/re-ingested as a unit.
BATCH_SOURCE = '2025_q4'
DATA_DIR = f'./data/{BATCH_SOURCE}'
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'
RAW_DATA_PATH = f'{DATA_DIR}/products-all.json'
CATEGORIZED_METADATA_PATH = f'{DATA_DIR}/metadata_extraction.json'
# When True, a second record per item is stored with a CLIP text embedding of
# the product description (modality="text").
ADD_TEXT_EMBEDDING = False
## Load data
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
    raw_data = json.load(file)
with open(CATEGORIZED_METADATA_PATH, 'r', encoding='utf-8') as file:
    categorized_data = json.load(file)
# Create Collection
client = chromadb.PersistentClient(path='./data/db')
collection = client.get_or_create_collection(
    name="lc_clothing_embedding"
)
# if you wish to delete some item, uncomment following
# NOTE(review): despite the comment above, this delete is LIVE — every run
# first purges all records previously ingested under this BATCH_SOURCE
# (presumably intentional, to make re-ingestion idempotent — confirm).
results = collection.delete(
    where={
        "batch_source": BATCH_SOURCE
    }
)
# Load model
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def format_product_info(product):
    """Render a raw product dict as a multi-line human-readable summary.

    Used as the ``documents`` payload stored alongside each embedding.
    Missing fields fall back to 'N/A' (description falls back to '').

    Args:
        product: raw product dict from products-all.json.

    Returns:
        str: one "Label: value" line per field, each newline-terminated.
    """
    tags_str = ", ".join(product.get('tags', []))
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Brand: {product.get('brand', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        # BUG FIX: the Tags line was missing its trailing '\n', fusing it with
        # the GroupName line; also fixed the 'DetpName' label typo below.
        f"Tags: {tags_str}\n"
        f"GroupName: {product.get('groupName', 'N/A')}\n"
        f"DeptName: {product.get('deptName', 'N/A')}\n"
        f"OnlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info
def raw_category_mapping(raw_category: str) -> str:
    """Map a raw catalogue category name to its internal lowercase form.

    'Fine Jewellery And Watches' is folded into the 'accessories' bucket;
    every other category is simply lowercased.
    """
    special_cases = {'Fine Jewellery And Watches': 'accessories'}
    return special_cases.get(raw_category, raw_category.lower())
# Combine all data together: screen each raw product, merge it with its
# LLM-categorized metadata, then embed and upsert into ChromaDB.
valid_count = 0
all_count = 0
for raw_item in tqdm(raw_data['products']):
    item_id = raw_item.get('id')
    if not item_id:
        print(f"This item {raw_item} did not have a valid item_id")
        continue
    raw_category = raw_item.get("category")
    if raw_category not in ['Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches']:
        continue
    image_path = os.path.join(IMAGE_DIR, f"{item_id}.jpg")
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        continue
    # All above is raw data error, it's not our business.
    all_count += 1
    processed_item = categorized_data.get(item_id, {})
    if not processed_item:
        print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
        continue
    category = raw_category_mapping(raw_category)
    subcategory = processed_item.get("subcategory")
    gender = processed_item.get("gender")
    applicable_occasions = processed_item.get("applicable_occasions", [])
    inappropriate_occasions = processed_item.get("inappropriate_occasions", [])
    if subcategory not in ALL_SUBCATEGORY_LIST:
        # BUG FIX: this branch used to print the (already mapped) category and
        # fall through, so items with an unknown subcategory were still
        # ingested with broken metadata. Report the subcategory and skip,
        # mirroring the gender check below.
        print(f"{item_id}'s subcategory, {subcategory}, is not valid.")
        continue
    if gender not in ['female', 'male', 'unisex']:
        print(f"{item_id}'s gender is not valid in {['female', 'male', 'unisex']}")
        continue
    occasions = applicable_occasions + inappropriate_occasions
    if not set(occasions).issubset(set(OCCASION)):
        # Unknown occasion labels are dropped rather than rejecting the item.
        applicable_occasions = [o for o in applicable_occasions if o in OCCASION]
        inappropriate_occasions = [o for o in inappropriate_occasions if o in OCCASION]
    description = raw_item.get('description', "")
    if not description:
        # BUG FIX: this message was a bare f-string (a no-op expression), so
        # description-less items were skipped silently.
        print(f"{item_id}'s description is lost.")
        continue
    url = raw_item.get('url', '')
    if not url:
        # BUG FIX: same as above — the message was never printed.
        print(f"{item_id}'s url is lost.")
        continue
    valid_count += 1
    # Prepare metadata for db
    item_img_metadata = {
        "item_id": item_id,
        "category": category,
        "subcategory": subcategory,
        "description": description,
        "gender": gender,
        'brand': raw_item.get('brand', ''),
        'color': raw_item.get('color', ''),
        'price': raw_item.get('price', ''),
        'tags': ",".join(raw_item.get('tags', [])),
        'url': url,
        "modality": "image",
        "batch_source": BATCH_SOURCE
    }
    # Tri-state occasion flags: 0 = unknown, 1 = applicable, -1 = inappropriate.
    for occasion in OCCASION:
        item_img_metadata[occasion] = 0
    for occasion in applicable_occasions:
        item_img_metadata[occasion] = 1
    for occasion in inappropriate_occasions:
        item_img_metadata[occasion] = -1
    # Get image feature (L2-normalized CLIP image embedding)
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_features = model.get_image_features(**inputs)
    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
    img_embedding = img_features.cpu().numpy().flatten().tolist()
    product_info = format_product_info(raw_item)
    # Insert into ChromaDB
    collection.add(
        ids=[f'{item_id}_img'],
        documents=[product_info],
        embeddings=[img_embedding],
        metadatas=[item_img_metadata],
    )
    if ADD_TEXT_EMBEDDING:
        item_txt_metadata = deepcopy(item_img_metadata)
        item_txt_metadata["modality"] = "text"
        # Get text feature (L2-normalized CLIP text embedding of description)
        inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            txt_features = model.get_text_features(**inputs)
        txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
        txt_embedding = txt_features.cpu().numpy().flatten().tolist()
        collection.add(
            ids=[f'{item_id}_txt'],
            documents=[product_info],
            embeddings=[txt_embedding],
            metadatas=[item_txt_metadata],
        )
# BUG FIX: guard against ZeroDivisionError when no item survives screening.
if all_count:
    print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
else:
    print("No items passed raw-data screening; nothing was ingested.")