Files
lc_stylist_agent/data_ingestion/process_item.py

284 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import os
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import json
from tqdm import tqdm
from app.taxonomy import OCCASION, FASHION_TAXONOMY, ALL_SUBCATEGORY_LIST
# --- Data config ---
BATCH_SOURCE = '2025_q4'
RAW_DATA_PATH = f'./data/{BATCH_SOURCE}/products-all.json'
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'
# --- MLLM config ---
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"
DEVICE = "cuda:0" # make sure this matches the device in your environment/traceback
BATCH_SIZE = 50
OUTPUT_FILE = f'./data/{BATCH_SOURCE}/metadata_extraction.json'
# Load Model
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Decoder-only generation with batched prompts needs LEFT padding so every
# prompt ends at the same position right before generation begins; the
# per-row slice `outputs[j, input_lengths:]` below relies on this.
if processor.tokenizer.padding_side != 'left':
    processor.tokenizer.padding_side = 'left'
    print(f"Set tokenizer padding_side to '{processor.tokenizer.padding_side}' for correct generation.")
model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to(DEVICE)
model.eval()  # inference only: disables dropout / training-mode layers
# Load Data
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
    data = json.load(file)
# --- Few-shot example 1: a skirt (Clothing / Bottoms) ---
EXAMPLE_1_INFO = """
Product Name: ARMARIUM - Loren Wool Blend Tube Skirt
Category: Clothing / Bottoms
Color: RED
Description: Cut from cardinal-red virgin wool, Armarium's Loren skirt wields tailoring's exactitude in a column of colour. The low-slung waist and clean tube line are punctuated by a razor back slit—stride from boardroom to candlelit bar with modern hauteur.
Tags: armarium, clothing, in-stock, new, loren, wool, blend, tube
"""
# Expected assistant answer for example 1, serialized exactly as the model
# should reply (a bare JSON object).
EXAMPLE_1_JSON = json.dumps({
    "subcategory": "skirts",
    "gender": "female",
    "applicable_occasions": [
        "Business/workwear", "Evening", "Cocktail / Semi-Formal", "Party / Clubbing", "Formal"
    ],
    "inappropriate_occasions": [
        "Activewear", "Beach / Swim", "Athleisure", "Ski / Snow / Mountain", "Casual"
    ]
}, indent=4)
# --- Few-shot example 2: a pin/brooch (Accessories) ---
EXAMPLE_2_INFO = """
Product Name: TATEOSSIAN - Mayfair 18K Yellow Gold Rhodium Plated Sterling Silver Peg Pin
Category: Accessories / Accessories
Color: MULTI
Description: Crafted from 18k yellow gold and rhodium-plated sterling silver, this unique pins has been artfully finished with Tateossian's signature diamond engraving pattern.
Tags: tateossian, accessories, in-stock, new, mayfair, yellow, gold, rhodium
"""
# Expected assistant answer for example 2.
EXAMPLE_2_JSON = json.dumps({
    "subcategory": "jewelry",
    "gender": "female",
    "applicable_occasions": [
        "Formal", "Black Tie / White Tie", "Bridal / Wedding", "Business/workwear", "Cocktail / Semi-Formal"
    ],
    "inappropriate_occasions": [
        "Casual", "Activewear", "Beach / Swim", "Outdoor", "Athleisure", "Ski / Snow / Mountain"
    ]
}, indent=4)
# --- 2. Llama-3.2 chat-template special tokens used to hand-build the prompt ---
BOS_TOKEN = "<|begin_of_text|>"
EOS_TOKEN = "<|eot_id|>"
SYSTEM_HEADER = "<|start_header_id|>system<|end_header_id|>\n"
USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n"
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n"
IMAGE_TOKEN = "<|image|>"
def format_product_info(product):
    """Render one product record as a plain-text block for the LLM prompt.

    Args:
        product: dict with optional keys 'name', 'category', 'deptName',
            'color', 'description', 'tags', 'groupName', 'onlineBU'.

    Returns:
        A single multi-line string. Missing fields fall back to 'N/A'
        (empty string for description, empty tag list for tags).
    """
    tags_str = ", ".join(product.get('tags', []))
    # BUGFIX: the original had a stray comma after the Tags line, which made
    # `info` a 2-tuple of strings instead of one string — interpolating it
    # into the prompt produced a tuple repr. Also restored the missing
    # newline between the Tags and groupName lines.
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        f"Tags: {tags_str}\n"
        f"groupName: {product.get('groupName', 'N/A')}\n"
        f"onlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info
def raw_category_mapping(raw_category: str) -> str:
    """Map a raw source category name onto the taxonomy's category key.

    'Fine Jewellery And Watches' is folded into 'accessories'; every other
    category is simply lower-cased.
    """
    special_cases = {'Fine Jewellery And Watches': 'accessories'}
    return special_cases.get(raw_category, raw_category.lower())
def generate_full_prompt(product_info, raw_category):
    """Assemble the full Llama-3.2 chat prompt for one product.

    Builds a system instruction constrained to the taxonomy for this
    product's category, two text-only few-shot rounds, and a final user
    turn carrying the <|image|> placeholder plus the product details.
    The prompt ends with an open assistant header so generation starts
    exactly there.
    """
    category = raw_category_mapping(raw_category)
    subcategory_list = FASHION_TAXONOMY.get(category)
    system_prompt = f"""You are an expert fashion AI assistant. Your task is to analyze the provided product image and product details to:
1. determine the suitable occasions for wearing or using the item. You must choose occasions ONLY from the following strict list: {json.dumps(OCCASION, indent=4)}. Only relevant suitable or inappropriate occasions should be selected.
2. categorize it into suitable subcategory in strict list: {json.dumps(subcategory_list)}.
3. categorize it into appropriate gender in ["female", "male", "unisex"]
Output Format:
Return ONLY a valid JSON object with four keys: "subcategory", "gender", "applicable_occasions" and "inappropriate_occasions". Do not include any analysis or extra text outside of the final JSON object.
"""
    segments = [
        # 1. System instruction.
        f"{BOS_TOKEN}{SYSTEM_HEADER}{system_prompt}{EOS_TOKEN}",
        # 2. Few-shot round 1 (text only — no image token in examples).
        f"{USER_HEADER}\n{EXAMPLE_1_INFO}{EOS_TOKEN}",
        f"{ASSISTANT_HEADER}{EXAMPLE_1_JSON}{EOS_TOKEN}",
        # 3. Few-shot round 2.
        f"{USER_HEADER}\n{EXAMPLE_2_INFO}{EOS_TOKEN}",
        f"{ASSISTANT_HEADER}{EXAMPLE_2_JSON}{EOS_TOKEN}",
        # 4. Target query: the only turn that carries the image.
        f"{USER_HEADER}{IMAGE_TOKEN}\nInput Data:\n{product_info}{EOS_TOKEN}",
        # Trailing assistant header tells the model to generate from here.
        ASSISTANT_HEADER,
    ]
    return "".join(segments)
# 2. Select the workable products: supported top-level category AND an
# already-downloaded image file on disk.
products = data['products']
allowed_categories = {'Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches'}
product_list = []
for product in products:
    if product.get('category') in allowed_categories:
        image_file = os.path.join(IMAGE_DIR, f"{product.get('id')}.jpg")
        if os.path.exists(image_file):
            product_list.append(product)
def validate_result(result_dict):
    """Return True iff the parsed model output is usable.

    Requires a non-empty "subcategory" present in the global taxonomy
    list and a "gender" in the fixed three-value set.
    """
    subcategory = result_dict.get("subcategory")
    gender = result_dict.get("gender")
    return (
        bool(subcategory)
        and bool(gender)
        and subcategory in ALL_SUBCATEGORY_LIST
        and gender in ('female', 'male', 'unisex')
    )
# Resume support: reload any previously saved results so re-runs only
# process products that are still missing.
if os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, 'r') as f:
        final_results = json.load(f)
else:
    final_results = {}
# NOTE(review): "attemps" is a typo for "attempts"; kept as-is here.
attemps = 0
# Up to 3 passes over the still-unclassified products; stop early once
# more than 85% of the catalogue has valid results.
while attemps < 3:
    unfinished_products = [product for product in product_list if product.get('id') not in final_results.keys()]
    attemps += 1
    completion_ratio = len(final_results) / len(product_list)
    if (completion_ratio > 0.85):
        print("valid results surpass 85%. Finish Now.")
        break
    else:
        print(f"Start {attemps} categorization process. Current ratio: {completion_ratio * 100}%")
    try:
        # Iterate over the remaining products in slices of BATCH_SIZE.
        for i in tqdm(range(0, len(unfinished_products), BATCH_SIZE)):
            batch_samples = unfinished_products[i:i + BATCH_SIZE]
            target_images = []
            target_prompts = []
            target_products_in_batch = []
            # Prepare the inputs for the current batch.
            for product in batch_samples:
                product_id = product['id']
                raw_category = product.get('category')
                image_path = os.path.join(IMAGE_DIR, f"{product_id}.jpg")
                try:
                    # Collect image, prompt, and the product record itself.
                    image = Image.open(image_path).convert("RGB")
                    product_info = format_product_info(product)
                    full_prompt = generate_full_prompt(product_info, raw_category)
                    target_images.append(image)
                    target_prompts.append(full_prompt)
                    target_products_in_batch.append(product)
                except Exception as e:
                    # Skip any individual sample that fails to load.
                    print(f"Skipping product {product_id} due to loading error: {e}")
                    continue
            if not target_images:
                continue  # whole batch had no valid images; skip it
            # 4. Batched inference.
            print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{int(len(unfinished_products)/BATCH_SIZE)+1} (Size: {len(target_images)})...")
            # Processor input: nested image lists, one inner list per sample.
            inputs = processor(
                images=[[img] for img in target_images],
                text=target_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=False
                )
            # 5. Decode and parse results for the whole batch.
            # With left padding every row shares the same (padded) prompt
            # length, so the newly generated tokens start at input_lengths.
            input_lengths = inputs.input_ids.size(1)
            for j in range(len(target_products_in_batch)):
                product = target_products_in_batch[j]
                product_id = product['id']
                # Slice out this item's generated continuation.
                # NOTE: outputs is [batch_size, sequence_length].
                newly_generated_tokens = outputs[j, input_lengths:]
                generated_text = processor.decode(newly_generated_tokens, skip_special_tokens=True)
                # Strip a trailing EOS token, if any survived decoding.
                if generated_text.endswith(processor.tokenizer.eos_token):
                    generated_text = generated_text[:-len(processor.tokenizer.eos_token)]
                try:
                    # Extract the outermost {...} span from the raw text.
                    start_idx = generated_text.find('{')
                    end_idx = generated_text.rfind('}') + 1
                    # NOTE(review): end_idx is rfind()+1, so it is 0 (never
                    # -1) when '}' is missing — this check cannot fire for
                    # end_idx; json.loads on the empty slice raises instead.
                    if start_idx == -1 or end_idx == -1:
                        raise ValueError("JSON start or end delimiter not found.")
                    json_str = generated_text[start_idx:end_idx]
                    result_dict = json.loads(json_str)
                    # Only keep results that pass taxonomy/gender validation.
                    if validate_result(result_dict):
                        final_results[product_id] = result_dict
                except Exception as e:
                    print(f"ID {product_id}: FAILED to parse JSON. Raw Output: {generated_text.strip()}")
            # Free GPU memory (optional, but recommended on long jobs).
            del inputs, outputs
            torch.cuda.empty_cache()
            # Checkpoint after every batch so progress survives a crash.
            with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, indent=4, ensure_ascii=False)
        # 6. Save the final results.
        print("\n\n=== ALL BATCHES COMPLETE ===")
        # Persist the final results to the JSON output file.
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=4, ensure_ascii=False)
        print(f"Results saved to {OUTPUT_FILE}")
    except Exception as e:
        print(f"\n--- Execution Error ---")
        print(f"An unexpected error occurred: {e}")