reconstruct whole recommendation pipeline and add new rec mode one-ask-for-all

This commit is contained in:
pangkaicheng
2025-12-12 17:37:07 +08:00
parent 0e9546aa1a
commit 85390d5e6d
12 changed files with 684 additions and 565 deletions

View File

@@ -1,6 +1,3 @@
import chromadb
import os
import json
@@ -11,7 +8,7 @@ from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from app.taxonomy import CATEGORY, ALL_CATEGORY, OCCASION
from app.taxonomy import ALL_SUBCATEGORY_LIST, OCCASION
BATCH_SOURCE = '2025_q4'
@@ -20,6 +17,7 @@ IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'
RAW_DATA_PATH = f'{DATA_DIR}/products-all.json'
CATEGORIZED_METADATA_PATH = f'{DATA_DIR}/metadata_extraction.json'
ADD_TEXT_EMBEDDING = False
## Load data
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
@@ -36,11 +34,11 @@ collection = client.get_or_create_collection(
)
# if you wish to delete some item, uncomment following
# results = collection.delete(
# where={
# "batch_source": BATCH_SOURCE
# }
# )
results = collection.delete(
where={
"batch_source": BATCH_SOURCE
}
)
# Load model
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
@@ -63,9 +61,13 @@ def format_product_info(product):
)
return info
def raw_category_mapping(raw_category: str) -> str:
if raw_category == 'Fine Jewellery And Watches':
return 'accessories'
else:
return raw_category.lower()
# Combine all data together
new_category = {}
valid_count = 0
all_count = 0
for raw_item in tqdm(raw_data['products']):
@@ -91,18 +93,14 @@ for raw_item in tqdm(raw_data['products']):
print(f"{item_id} has not been categorized. It does not exist in {CATEGORIZED_METADATA_PATH}")
continue
category = processed_item.get("category")
category = raw_category_mapping(raw_category)
subcategory = processed_item.get("subcategory")
gender = processed_item.get("gender")
applicable_occasions = processed_item.get("applicable_occasions", [])
inappropriate_occasions = processed_item.get("inappropriate_occasions", [])
if category not in ALL_CATEGORY:
if subcategory not in ALL_SUBCATEGORY_LIST:
print(f"{item_id}'s category, {category}, does not valid.")
if category not in new_category:
new_category[category] = [item_id]
else:
new_category[category].append(item_id)
continue
if gender not in ['female', 'male', 'unisex']:
print(f"{item_id}'s gender is not valid in {['female', 'male', 'unisex']}")
@@ -129,6 +127,7 @@ for raw_item in tqdm(raw_data['products']):
item_img_metadata = {
"item_id": item_id,
"category": category,
"subcategory": subcategory,
"description": description,
"gender": gender,
'brand': raw_item.get('brand', ''),
@@ -146,10 +145,6 @@ for raw_item in tqdm(raw_data['products']):
for occasion in inappropriate_occasions:
item_img_metadata[occasion] = -1
item_txt_metadata = deepcopy(item_img_metadata)
item_txt_metadata["modality"] = "text"
# Get image feature
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device)
@@ -158,21 +153,30 @@ for raw_item in tqdm(raw_data['products']):
img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
img_embedding = img_features.cpu().numpy().flatten().tolist()
# Get text feature
inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
with torch.no_grad():
txt_features = model.get_text_features(**inputs)
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
txt_embedding = txt_features.cpu().numpy().flatten().tolist()
product_info = format_product_info(raw_item)
# 插入到 ChromaDB
collection.add(
ids=[f'{item_id}_img', f'{item_id}_txt'],
documents=[product_info, product_info],
embeddings=[img_embedding, txt_embedding],
metadatas=[item_img_metadata, item_txt_metadata],
ids=[f'{item_id}_img'],
documents=[product_info],
embeddings=[img_embedding],
metadatas=[item_img_metadata],
)
print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")
print(f'Found new category for consideration: {new_category}')
if ADD_TEXT_EMBEDDING:
item_txt_metadata = deepcopy(item_img_metadata)
item_txt_metadata["modality"] = "text"
# Get text feature
inputs = processor(text=[description], return_tensors="pt", padding=True, truncation=True).to(device)
with torch.no_grad():
txt_features = model.get_text_features(**inputs)
txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
txt_embedding = txt_features.cpu().numpy().flatten().tolist()
collection.add(
ids=[f'{item_id}_txt'],
documents=[product_info],
embeddings=[txt_embedding],
metadatas=[item_txt_metadata],
)
print(f"Final valid ratio is {valid_count / all_count * 100}%. Total number is {all_count}, Valid number is {valid_count}")