Enable data auto process for new data
This commit is contained in:
178
data_ingestion/run_ingestion.py
Normal file
178
data_ingestion/run_ingestion.py
Normal file
@@ -0,0 +1,178 @@
|
||||
|
||||
|
||||
|
||||
import chromadb
|
||||
import os
|
||||
import json
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
|
||||
from app.taxonomy import CATEGORY, ALL_CATEGORY, OCCASION
|
||||
|
||||
|
||||
# Name of the data batch to ingest; every path below is derived from it.
BATCH_SOURCE = '2025_q4'
DATA_DIR = './data/' + BATCH_SOURCE
IMAGE_DIR = DATA_DIR + '/image_data'

# Input files expected inside the batch directory.
RAW_DATA_PATH = DATA_DIR + '/products-all.json'
CATEGORIZED_METADATA_PATH = DATA_DIR + '/metadata_extraction.json'
|
||||
|
||||
## Load data
# Raw product dump for the batch.
with open(RAW_DATA_PATH, mode='r', encoding='utf-8') as raw_fp:
    raw_data = json.load(raw_fp)

# Categorized metadata for the batch (looked up by item id further below).
with open(CATEGORIZED_METADATA_PATH, mode='r', encoding='utf-8') as categorized_fp:
    categorized_data = json.load(categorized_fp)
|
||||
|
||||
|
||||
# Create Collection
# Open the persistent Chroma client stored under ./data/db and get (or
# create) the collection that receives the embeddings.
client = chromadb.PersistentClient(path='./data/db')
collection = client.get_or_create_collection(name="lc_clothing_embedding")
|
||||
|
||||
# To delete every item previously ingested from this batch, uncomment the following:
|
||||
# results = collection.delete(
|
||||
# where={
|
||||
# "batch_source": BATCH_SOURCE
|
||||
# }
|
||||
# )
|
||||
|
||||
# Load model
# Pretrained CLIP processor/model pair; the model is moved to the GPU when
# CUDA is available and stays on the CPU otherwise.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
|
||||
|
||||
def format_product_info(product):
    """Render a raw product record as a multi-line text block.

    Args:
        product: Mapping of raw product fields. The keys read are ``name``,
            ``brand``, ``category``, ``deptName``, ``color``,
            ``description``, ``tags``, ``groupName`` and ``onlineBU``.
            Missing keys fall back to ``'N/A'`` (an empty string for the
            description, no tags for ``tags``).

    Returns:
        A string with one ``Label: value`` field per line, every line
        terminated by a newline.
    """
    # ``or []`` also covers a ``tags`` key that is present but null, which
    # would otherwise make ``join`` raise TypeError.
    tags_str = ", ".join(product.get('tags') or [])
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Brand: {product.get('brand', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        # Fix: this line used to lack its trailing newline, which glued the
        # "GroupName" label directly onto the end of the tag list.
        f"Tags: {tags_str}\n"
        f"GroupName: {product.get('groupName', 'N/A')}\n"
        # NOTE(review): "DetpName" looks like a typo for "DeptName"; the label
        # is kept unchanged so emitted documents stay stable - confirm.
        f"DetpName: {product.get('deptName', 'N/A')}\n"
        f"OnlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info
|
||||
|
||||
|
||||
# Combine all data together.
#
# For every raw product that passes validation, build its metadata record,
# compute L2-normalized CLIP embeddings for its image and for its description
# text, and add both to the collection under the ids "<item_id>_img" and
# "<item_id>_txt".
RAW_CATEGORY_WHITELIST = ['Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches']
VALID_GENDERS = ['female', 'male', 'unisex']
# Built once instead of on every iteration.
valid_occasions = set(OCCASION)

new_category = {}  # category not in ALL_CATEGORY -> ids of the items carrying it
valid_count = 0    # items that passed every check below
all_count = 0      # items whose raw data was usable (id, raw category, image)
for raw_item in tqdm(raw_data['products']):
    item_id = raw_item.get('id')
    if not item_id:
        print(f"This item {raw_item} did not have a valid item_id")
        continue

    raw_category = raw_item.get("category")
    if raw_category not in RAW_CATEGORY_WHITELIST:
        continue

    image_path = os.path.join(IMAGE_DIR, f"{item_id}.jpg")
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        continue

    # Everything above is a raw-data problem and is therefore excluded from
    # the valid-ratio statistics.
    all_count += 1

    processed_item = categorized_data.get(item_id, {})
    if not processed_item:
        print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
        continue

    category = processed_item.get("category")
    gender = processed_item.get("gender")
    # ``or []`` guards against keys that are present but null, which would
    # otherwise break the list handling below.
    applicable_occasions = processed_item.get("applicable_occasions") or []
    inappropriate_occasions = processed_item.get("inappropriate_occasions") or []

    if category not in ALL_CATEGORY:
        print(f"{item_id}'s category, {category}, does not valid.")
        # Remember unknown categories so they can be reported at the end.
        new_category.setdefault(category, []).append(item_id)
        continue

    if gender not in VALID_GENDERS:
        print(f"{item_id}'s gender is not valid in {VALID_GENDERS}")
        continue

    # Unknown occasions are dropped silently instead of rejecting the item.
    applicable_occasions = [o for o in applicable_occasions if o in valid_occasions]
    inappropriate_occasions = [o for o in inappropriate_occasions if o in valid_occasions]

    description = raw_item.get('description', "")
    if not description:
        # Fix: the message used to be a bare f-string and was never printed.
        print(f"{item_id}'s description is lost.")
        continue

    url = raw_item.get('url', '')
    if not url:
        # Fix: the message used to be a bare f-string and was never printed.
        print(f"{item_id}'s url is lost.")
        continue

    valid_count += 1

    # Prepare metadata for db
    item_img_metadata = {
        "item_id": item_id,
        "category": category,
        "description": description,
        "gender": gender,
        'brand': raw_item.get('brand', ''),
        'color': raw_item.get('color', ''),
        'price': raw_item.get('price', ''),
        'tags': ",".join(raw_item.get('tags') or []),
        'url': url,
        "modality": "image",
        "batch_source": BATCH_SOURCE,
    }
    # One key per occasion: 0 = not mentioned, 1 = applicable,
    # -1 = inappropriate. An occasion listed in both ends up as -1 because
    # the inappropriate pass runs last.
    for occasion in OCCASION:
        item_img_metadata[occasion] = 0
    for occasion in applicable_occasions:
        item_img_metadata[occasion] = 1
    for occasion in inappropriate_occasions:
        item_img_metadata[occasion] = -1

    # The text record shares all metadata except the modality marker.
    item_txt_metadata = deepcopy(item_img_metadata)
    item_txt_metadata["modality"] = "text"

    # Get image feature
    image = Image.open(image_path).convert("RGB")
    image_inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        img_features = model.get_image_features(**image_inputs)
    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
    img_embedding = img_features.cpu().numpy().flatten().tolist()

    # Get text feature
    text_inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        txt_features = model.get_text_features(**text_inputs)
    txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
    txt_embedding = txt_features.cpu().numpy().flatten().tolist()

    product_info = format_product_info(raw_item)
    # Insert into ChromaDB: one image record and one text record per item,
    # both carrying the same document text.
    collection.add(
        ids=[f'{item_id}_img', f'{item_id}_txt'],
        documents=[product_info, product_info],
        embeddings=[img_embedding, txt_embedding],
        metadatas=[item_img_metadata, item_txt_metadata],
    )

if all_count:
    print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
else:
    # Fix: avoid ZeroDivisionError when no raw item was usable at all.
    print(f"No usable raw item was found. Total number is {all_count}, Valid number is {valid_count}")
print(f'Found new category for consideration: {new_category}')
|
||||
Reference in New Issue
Block a user