|
|
|
|
import torch
|
|
|
|
|
|
import os
|
|
|
|
|
|
from transformers import AutoProcessor, AutoModelForVision2Seq
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
import json
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
from app.taxonomy import OCCASION, FASHION_TAXONOMY, ALL_SUBCATEGORY_LIST
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# data config
BATCH_SOURCE = '2025_q4'  # which ingestion batch to process
RAW_DATA_PATH = f'./data/{BATCH_SOURCE}/products-all.json'  # product metadata dump
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'  # expects one <product_id>.jpg per product

# MLLM config
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"
DEVICE = "cuda:0"  # make sure the device matches your environment / traceback
BATCH_SIZE = 50  # products per generate() call
OUTPUT_FILE = f'./data/{BATCH_SOURCE}/metadata_extraction.json'  # checkpoint + final output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load Model
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Batched generation with a decoder-only model needs LEFT padding so that the
# newly generated tokens start at the same offset for every row in the batch
# (the decode step below slices all rows at a single input length).
if processor.tokenizer.padding_side != 'left':
    processor.tokenizer.padding_side = 'left'
    print(f"Set tokenizer padding_side to '{processor.tokenizer.padding_side}' for correct generation.")

model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to(DEVICE)
model.eval()  # inference only: disable dropout / training-mode behavior
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load Data
# The dump is a JSON object; the product records live under data['products']
# (see the filtering step further down).
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
    data = json.load(file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Few-shot example 1: a skirt. Text-only demonstration of the expected
# input format and the exact JSON reply the model should imitate.
EXAMPLE_1_INFO = """
Product Name: ARMARIUM - Loren Wool Blend Tube Skirt
Category: Clothing / Bottoms
Color: RED
Description: Cut from cardinal-red virgin wool, Armarium's Loren skirt wields tailoring's exactitude in a column of colour. The low-slung waist and clean tube line are punctuated by a razor back slit—stride from boardroom to candlelit bar with modern hauteur.
Tags: armarium, clothing, in-stock, new, loren, wool, blend, tube
"""
# Expected assistant answer for example 1 (serialized once at import time).
EXAMPLE_1_JSON = json.dumps({
    "subcategory": "skirts",
    "gender": "female",
    "applicable_occasions": [
        "Business/workwear", "Evening", "Cocktail / Semi-Formal", "Party / Clubbing", "Formal"
    ],
    "inappropriate_occasions": [
        "Activewear", "Beach / Swim", "Athleisure", "Ski / Snow / Mountain", "Casual"
    ]
}, indent=4)
|
|
|
|
|
|
|
|
|
|
|
|
# Few-shot example 2: a brooch (pin) — an accessories/jewelry case.
EXAMPLE_2_INFO = """
Product Name: TATEOSSIAN - Mayfair 18K Yellow Gold Rhodium Plated Sterling Silver Peg Pin
Category: Accessories / Accessories
Color: MULTI
Description: Crafted from 18k yellow gold and rhodium-plated sterling silver, this unique pins has been artfully finished with Tateossian's signature diamond engraving pattern.
Tags: tateossian, accessories, in-stock, new, mayfair, yellow, gold, rhodium
"""
# Expected assistant answer for example 2.
EXAMPLE_2_JSON = json.dumps({
    "subcategory": "jewelry",
    "gender": "female",
    "applicable_occasions": [
        "Formal", "Black Tie / White Tie", "Bridal / Wedding", "Business/workwear", "Cocktail / Semi-Formal"
    ],
    "inappropriate_occasions": [
        "Casual", "Activewear", "Beach / Swim", "Outdoor", "Athleisure", "Ski / Snow / Mountain"
    ]
}, indent=4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- 2. Chat-format prompt building blocks ---
# Literal special-token strings used to assemble the Llama-3-style dialogue
# by hand in generate_full_prompt() (instead of the processor's chat template).
BOS_TOKEN = "<|begin_of_text|>"
EOS_TOKEN = "<|eot_id|>"  # end-of-turn marker
SYSTEM_HEADER = "<|start_header_id|>system<|end_header_id|>\n"
USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n"
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n"
IMAGE_TOKEN = "<|image|>"  # placeholder the processor pairs with the PIL image
|
|
|
|
|
|
|
|
|
|
|
|
def format_product_info(product):
    """Render one product record as a multi-line text blurb for the prompt.

    Args:
        product: dict with optional keys 'name', 'category', 'deptName',
            'color', 'description', 'tags', 'groupName', 'onlineBU'.
            Missing keys fall back to 'N/A' ('' for description, [] for tags).

    Returns:
        A single string (one "Key: value" line per field).
    """
    tags_str = ", ".join(product.get('tags', []))
    # BUGFIX: the original had a stray comma after the Tags f-string, which
    # turned `info` into a 2-tuple of strings instead of one concatenated
    # string — its tuple repr would then leak into the generated prompt.
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        f"Tags: {tags_str}\n"
        f"groupName: {product.get('groupName', 'N/A')}\n"
        f"onlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def raw_category_mapping(raw_category: str) -> str:
    """Map a raw catalogue category name onto a taxonomy key.

    'Fine Jewellery And Watches' folds into 'accessories'; every other
    category name is simply lower-cased.
    """
    if raw_category != 'Fine Jewellery And Watches':
        return raw_category.lower()
    return 'accessories'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_full_prompt(product_info, raw_category):
    """Build the full chat-format prompt for one product.

    Two text-only few-shot rounds (skirt, pin) precede the target turn; the
    target turn is the only round that carries the <|image|> token, so the
    processor receives exactly one image per prompt.

    Args:
        product_info: formatted product text from format_product_info().
        raw_category: the product's raw catalogue category name.

    Returns:
        The complete prompt string, ending with an open assistant header so
        generation starts at the model's answer.
    """
    category = raw_category_mapping(raw_category)
    # NOTE(review): FASHION_TAXONOMY.get() returns None for an unknown
    # category, which would render as "null" in the system prompt — confirm
    # all raw categories map to taxonomy keys.
    subcategory_list = FASHION_TAXONOMY.get(category)

    SYSTEM_PROMPT = f"""You are an expert fashion AI assistant. Your task is to analyze the provided product image and product details to:
1. determine the suitable occasions for wearing or using the item. You must choose occasions ONLY from the following strict list: {json.dumps(OCCASION, indent=4)}. Only relevant suitable or inappropriate occasions should be selected.
2. categorize it into suitable subcategory in strict list: {json.dumps(subcategory_list)}.
3. categorize it into appropriate gender in ["female", "male", "unisex"]

Output Format:
Return ONLY a valid JSON object with four keys: "subcategory", "gender", "applicable_occasions" and "inappropriate_occasions". Do not include any analysis or extra text outside of the final JSON object.
"""

    # Assemble the dialogue sequence
    dialogue_prompt = (
        # 1. System instruction
        f"{BOS_TOKEN}{SYSTEM_HEADER}{SYSTEM_PROMPT}{EOS_TOKEN}"

        # 2. Example 1 (few-shot round 1; text only — no image token)
        # Format: <|start_header_id|>user<|end_header_id|>\n{text}<|eot_id|>
        f"{USER_HEADER}\n{EXAMPLE_1_INFO}{EOS_TOKEN}"
        f"{ASSISTANT_HEADER}{EXAMPLE_1_JSON}{EOS_TOKEN}"

        # 3. Example 2 (few-shot round 2)
        f"{USER_HEADER}\n{EXAMPLE_2_INFO}{EOS_TOKEN}"
        f"{ASSISTANT_HEADER}{EXAMPLE_2_JSON}{EOS_TOKEN}"

        # 4. Target item (the only turn carrying the image token)
        f"{USER_HEADER}{IMAGE_TOKEN}\nInput Data:\n{product_info}{EOS_TOKEN}"
        f"{ASSISTANT_HEADER}"  # trailing assistant header: the model generates from here
    )
    return dialogue_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Load the product records and keep only the usable ones: supported
# top-level categories that also have a downloaded image on disk.
products = data['products']
product_list = []
for candidate in products:
    if candidate.get('category') not in ('Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches'):
        continue
    if not os.path.exists(os.path.join(IMAGE_DIR, f"{candidate.get('id')}.jpg")):
        continue
    product_list.append(candidate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_result(result_dict):
    """Return True iff the parsed model output is usable.

    A result is valid when it carries a non-empty subcategory that belongs
    to the global taxonomy and a gender from the allowed set. The occasion
    lists are not validated here.
    """
    subcategory = result_dict.get("subcategory")
    gender = result_dict.get("gender")
    return (
        bool(subcategory)
        and bool(gender)
        and subcategory in ALL_SUBCATEGORY_LIST
        and gender in ('female', 'male', 'unisex')
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Resume support: start from any results a previous run already saved,
# falling back to an empty mapping on the first run.
try:
    with open(OUTPUT_FILE, 'r') as f:
        final_results = json.load(f)
except FileNotFoundError:
    final_results = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Main driver: run the extraction in up to 3 passes. Each pass only
# re-processes products without a validated result yet, so transient
# generation/parse failures get another chance.
attempts = 0
while attempts < 3:
    # BUGFIX: JSON object keys are always strings after a save/load
    # round-trip, so product ids must be normalized to str here (and when
    # storing below) — otherwise integer ids would never match the resumed
    # keys and finished items would be reprocessed on every run.
    unfinished_products = [product for product in product_list if str(product.get('id')) not in final_results.keys()]
    attempts += 1
    # Guard against an empty product list (would otherwise divide by zero).
    # NOTE(review): final_results may contain ids from other batches, so the
    # ratio can exceed 1.0 — confirm whether that is intended.
    completion_ratio = len(final_results) / len(product_list) if product_list else 1.0
    if (completion_ratio > 0.85):
        print("valid results surpass 85%. Finish Now.")
        break
    else:
        print(f"Start {attempts} categorization process. Current ratio: {completion_ratio * 100}%")

    try:
        # Iterate over the remaining products in slices of BATCH_SIZE.
        for i in tqdm(range(0, len(unfinished_products), BATCH_SIZE)):
            batch_samples = unfinished_products[i:i + BATCH_SIZE]

            target_images = []
            target_prompts = []
            target_products_in_batch = []

            # Prepare the inputs for the current batch.
            for product in batch_samples:
                product_id = product['id']
                raw_category = product.get('category')
                image_path = os.path.join(IMAGE_DIR, f"{product_id}.jpg")

                try:
                    # Collect image, prompt and product record together so
                    # the three lists stay index-aligned.
                    image = Image.open(image_path).convert("RGB")
                    product_info = format_product_info(product)
                    full_prompt = generate_full_prompt(product_info, raw_category)

                    target_images.append(image)
                    target_prompts.append(full_prompt)
                    target_products_in_batch.append(product)
                except Exception as e:
                    # Skip any single sample that fails to load.
                    print(f"Skipping product {product_id} due to loading error: {e}")
                    continue

            if not target_images:
                continue  # the whole batch had no usable image

            # 4. Batched inference
            print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{int(len(unfinished_products)/BATCH_SIZE)+1} (Size: {len(target_images)})...")

            # Processor input: one nested image list per prompt [[img1], [img2], ...]
            inputs = processor(
                images=[[img] for img in target_images],
                text=target_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=False
                )

            # 5. Decode and parse the batch results.
            # With left padding every row has the same prompt length, so one
            # offset is valid for all rows; outputs is [batch, seq_len].
            input_lengths = inputs.input_ids.size(1)

            for j in range(len(target_products_in_batch)):
                product = target_products_in_batch[j]
                product_id = product['id']

                # Keep only the newly generated continuation for this item.
                newly_generated_tokens = outputs[j, input_lengths:]
                generated_text = processor.decode(newly_generated_tokens, skip_special_tokens=True)

                # Strip a trailing EOS token if present.
                if generated_text.endswith(processor.tokenizer.eos_token):
                    generated_text = generated_text[:-len(processor.tokenizer.eos_token)]

                try:
                    start_idx = generated_text.find('{')
                    end_idx = generated_text.rfind('}') + 1

                    # BUGFIX: rfind returns -1 on failure, so a missing '}'
                    # yields end_idx == 0 — the original test for -1 could
                    # never fire.
                    if start_idx == -1 or end_idx == 0:
                        raise ValueError("JSON start or end delimiter not found.")

                    json_str = generated_text[start_idx:end_idx]
                    result_dict = json.loads(json_str)

                    if validate_result(result_dict):
                        # Store under a string key so resume lookups match
                        # what json.load gives back on the next run.
                        final_results[str(product_id)] = result_dict

                except Exception as e:
                    print(f"ID {product_id}: FAILED to parse JSON. Raw Output: {generated_text.strip()}")

            # Free GPU memory between batches (recommended for long runs).
            del inputs, outputs
            torch.cuda.empty_cache()

            # Checkpoint after every batch so progress survives a crash.
            with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, indent=4, ensure_ascii=False)

        # 6. Save the final results of this pass.
        print("\n\n=== ALL BATCHES COMPLETE ===")

        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=4, ensure_ascii=False)

        print(f"Results saved to {OUTPUT_FILE}")

    except Exception as e:
        # Top-level boundary: report and fall through to the next attempt.
        print(f"\n--- Execution Error ---")
        print(f"An unexpected error occurred: {e}")
|