import torch
import os
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import json
from tqdm import tqdm
from app.taxonomy import OCCASION, CATEGORY, ALL_CATEGORY

# --- Data configuration ---
BATCH_SOURCE = '2025_q4'
RAW_DATA_PATH = f'./data/{BATCH_SOURCE}/products-all.json'
IMAGE_DIR = f'./data/{BATCH_SOURCE}/image_data'

# --- MLLM configuration ---
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"
DEVICE = "cuda:0"  # make sure this matches the device in your environment
BATCH_SIZE = 50
OUTPUT_FILE = f'./data/{BATCH_SOURCE}/metadata_extraction.json'

# --- Load model & processor ---
# Decoder-only generation with padded batches requires left padding, otherwise
# the newly generated tokens are misaligned with the prompt.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
if processor.tokenizer.padding_side != 'left':
    processor.tokenizer.padding_side = 'left'
    print(f"Set tokenizer padding_side to '{processor.tokenizer.padding_side}' for correct generation.")

model = AutoModelForVision2Seq.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to(DEVICE)
model.eval()

# --- Load raw product data ---
with open(RAW_DATA_PATH, 'r', encoding='utf-8') as file:
    data = json.load(file)

# --- Few-shot example 1: a skirt ---
EXAMPLE_1_INFO = """
Product Name: ARMARIUM - Loren Wool Blend Tube Skirt
Category: Clothing / Bottoms
Color: RED
Description: Cut from cardinal-red virgin wool, Armarium's Loren skirt wields tailoring's exactitude in a column of colour. The low-slung waist and clean tube line are punctuated by a razor back slit—stride from boardroom to candlelit bar with modern hauteur. 
Tags: armarium, clothing, in-stock, new, loren, wool, blend, tube
"""
EXAMPLE_1_JSON = json.dumps({
    "category": "skirts",
    "gender": "female",
    "applicable_occasions": [
        "Business/workwear",
        "Evening",
        "Cocktail / Semi-Formal",
        "Party / Clubbing",
        "Formal"
    ],
    "inappropriate_occasions": [
        "Activewear",
        "Beach / Swim",
        "Athleisure",
        "Ski / Snow / Mountain",
        "Casual"
    ]
}, indent=4)

# --- Few-shot example 2: a brooch (pin) ---
EXAMPLE_2_INFO = """
Product Name: TATEOSSIAN - Mayfair 18K Yellow Gold Rhodium Plated Sterling Silver Peg Pin
Category: Accessories / Accessories
Color: MULTI
Description: Crafted from 18k yellow gold and rhodium-plated sterling silver, this unique pins has been artfully finished with Tateossian's signature diamond engraving pattern.
Tags: tateossian, accessories, in-stock, new, mayfair, yellow, gold, rhodium
"""
EXAMPLE_2_JSON = json.dumps({
    "category": "jewelry",
    "gender": "female",
    "applicable_occasions": [
        "Formal",
        "Black Tie / White Tie",
        "Bridal / Wedding",
        "Business/workwear",
        "Cocktail / Semi-Formal"
    ],
    "inappropriate_occasions": [
        "Casual",
        "Activewear",
        "Beach / Swim",
        "Outdoor",
        "Athleisure",
        "Ski / Snow / Mountain"
    ]
}, indent=4)
# --- 2. Build the conversation-format prompt ---
BOS_TOKEN = "<|begin_of_text|>"
EOS_TOKEN = "<|eot_id|>"
SYSTEM_HEADER = "<|start_header_id|>system<|end_header_id|>\n"
USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n"
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n"
IMAGE_TOKEN = "<|image|>"


def format_product_info(product):
    """Render one product record as the plain-text block embedded in the prompt.

    Returns a single string. (BUGFIX: the original had a stray comma after the
    ``Tags:`` fragment, which split the parenthesized literal into a 2-tuple of
    strings, so the prompt contained a tuple repr instead of clean text.)
    """
    tags_str = ", ".join(product.get('tags', []))
    info = (
        f"Product Name: {product.get('name', 'N/A')}\n"
        f"Category: {product.get('category', 'N/A')} / {product.get('deptName', 'N/A')}\n"
        f"Color: {product.get('color', 'N/A')}\n"
        f"Description: {product.get('description', '')}\n"
        f"Tags: {tags_str}\n"
        f"groupName: {product.get('groupName', 'N/A')}\n"
        f"onlineBU: {product.get('onlineBU', 'N/A')}\n"
    )
    return info


def generate_full_prompt(product_info, raw_category):
    """Assemble the full Llama-3.2-Vision chat prompt for one product.

    Two text-only few-shot rounds precede the target query; only the target
    user turn carries the single <|image|> placeholder, matching the one image
    per sample passed to the processor.
    """
    # Map the watches/jewellery department onto the accessories taxonomy.
    if raw_category == 'Fine Jewellery And Watches':
        category = 'accessories'
    else:
        category = raw_category.lower()
    # NOTE(review): CATEGORY.get returns None for unknown categories; the
    # filter below restricts inputs to known ones — verify keys stay in sync.
    subcategory_list = CATEGORY.get(category)

    SYSTEM_PROMPT = f"""You are an expert fashion AI assistant. Your task is to analyze the provided product image and product details to:
1. determine the suitable occasions for wearing or using the item. You must choose occasions ONLY from the following strict list: {json.dumps(OCCASION, indent=4)}. Only relevant suitable or inappropriate occasions should be selected.
2. categorize it into suitable category in strict list: {json.dumps(subcategory_list)}.
3. categorize it into appropriate gender in ["female", "male", "unisex"]
Output Format: Return ONLY a valid JSON object with four keys: "category", "gender", "applicable_occasions" and "inappropriate_occasions". Do not include any analysis or extra text outside of the final JSON object.
"""

    # Compose the dialogue sequence.
    dialogue_prompt = (
        # 1. System instruction
        f"{BOS_TOKEN}{SYSTEM_HEADER}{SYSTEM_PROMPT}{EOS_TOKEN}"
        # 2. Example 1 (few-shot round 1), text-only user turn
        f"{USER_HEADER}\n{EXAMPLE_1_INFO}{EOS_TOKEN}"
        f"{ASSISTANT_HEADER}{EXAMPLE_1_JSON}{EOS_TOKEN}"
        # 3. Example 2 (few-shot round 2)
        f"{USER_HEADER}\n{EXAMPLE_2_INFO}{EOS_TOKEN}"
        f"{ASSISTANT_HEADER}{EXAMPLE_2_JSON}{EOS_TOKEN}"
        # 4. Target item (target query) with the image placeholder
        f"{USER_HEADER}{IMAGE_TOKEN}\nInput Data:\n{product_info}{EOS_TOKEN}"
        # The trailing assistant header tells the model to start generating here.
        f"{ASSISTANT_HEADER}"
    )
    return dialogue_prompt


# Keep only products in supported top-level categories that have a local image.
products = data['products']
product_list = [
    product for product in products
    if product.get('category') in ['Clothing', 'Accessories', 'Shoes', 'Bags', 'Fine Jewellery And Watches']
    and os.path.exists(os.path.join(IMAGE_DIR, f"{product.get('id')}.jpg"))
]


def validate_results():
    """Load any existing output file and split the work.

    Returns (unfinished_products, final_results): the products that are
    missing or whose stored category/gender is invalid, plus the dict of
    results accumulated so far.
    """
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, 'r') as f:
            final_results = json.load(f)
    else:
        final_results = {}

    unfinished_ids = []
    for product in product_list:
        item_id = product.get('id')
        if item_id not in final_results:
            unfinished_ids.append(product)
        else:
            processed_item = final_results[item_id]
            category = processed_item.get("category")
            gender = processed_item.get("gender")
            # BUGFIX: the original used two independent ifs, so a product
            # failing both checks was appended twice and processed twice.
            if category not in ALL_CATEGORY or gender not in ['female', 'male', 'unisex']:
                unfinished_ids.append(product)
    return unfinished_ids, final_results


attempts = 0
while attempts < 3:
    attempts += 1
    unfinished_products, final_results = validate_results()
    # Fraction of the catalogue that still needs (re)processing.
    # Guard against an empty product_list (no images downloaded yet).
    unfinished_ratio = (len(unfinished_products) / len(product_list)) if product_list else 0.0
    # BUGFIX: the original tested `ratio > 0.95` on the *unfinished* fraction
    # and declared success — the exact opposite. We are done when <=5% remain.
    if unfinished_ratio <= 0.05:
        print("valid results surpass 95%. Finish Now.")
        break
    else:
        print(f"Start {attempts} categorization process. Current ratio: {(1 - unfinished_ratio) * 100}%")
    try:
        # Iterate over the pending products in slices of BATCH_SIZE.
        for i in tqdm(range(0, len(unfinished_products), BATCH_SIZE)):
            batch_samples = unfinished_products[i:i + BATCH_SIZE]
            target_images = []
            target_prompts = []
            target_products_in_batch = []

            # Prepare the inputs for the current batch.
            for product in batch_samples:
                product_id = product['id']
                raw_category = product.get('category')
                image_path = os.path.join(IMAGE_DIR, f"{product_id}.jpg")
                try:
                    # Collect image, prompt and product record together so the
                    # three lists stay index-aligned.
                    image = Image.open(image_path).convert("RGB")
                    product_info = format_product_info(product)
                    full_prompt = generate_full_prompt(product_info, raw_category)
                    target_images.append(image)
                    target_prompts.append(full_prompt)
                    target_products_in_batch.append(product)
                except Exception as e:
                    # Skip any single sample that fails to load.
                    print(f"Skipping product {product_id} due to loading error: {e}")
                    continue

            if not target_images:
                continue  # skip the batch entirely if no image survived

            # 4. Batched inference.
            # BUGFIX: batch count used int(n/B)+1, which over-counts by one
            # when n is an exact multiple of B; use integer ceiling instead.
            num_batches = (len(unfinished_products) + BATCH_SIZE - 1) // BATCH_SIZE
            print(f"\nProcessing batch {i//BATCH_SIZE + 1}/{num_batches} (Size: {len(target_images)})...")

            # Processor input: nested lists [[img1], [img2], ...] — one image
            # per sample, matching the single <|image|> token per prompt.
            inputs = processor(
                images=[[img] for img in target_images],
                text=target_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=False
                )

            # 5. Batched decoding and result parsing.
            # With left padding every row shares the same prompt length, so a
            # single slice offset isolates the newly generated tokens.
            input_lengths = inputs.input_ids.size(1)
            for j in range(len(target_products_in_batch)):
                product = target_products_in_batch[j]
                product_id = product['id']
                # outputs is [batch_size, sequence_length]
                newly_generated_tokens = outputs[j, input_lengths:]
                generated_text = processor.decode(newly_generated_tokens, skip_special_tokens=True)

                # Clean and parse the model output.
                if generated_text.endswith(processor.tokenizer.eos_token):
                    generated_text = generated_text[:-len(processor.tokenizer.eos_token)]
                try:
                    start_idx = generated_text.find('{')
                    end_idx = generated_text.rfind('}') + 1
                    # BUGFIX: rfind returns -1 on failure, so end_idx is 0 —
                    # the original `end_idx == -1` check could never trigger.
                    if start_idx == -1 or end_idx == 0:
                        raise ValueError("JSON start or end delimiter not found.")
                    json_str = generated_text[start_idx:end_idx]
                    result_dict = json.loads(json_str)
                    final_results[product_id] = result_dict
                except Exception as e:
                    print(f"ID {product_id}: FAILED to parse JSON. Raw Output: {generated_text.strip()}")
                    final_results[product_id] = {"error": str(e), "raw_output": generated_text.strip()}

            # Free GPU memory between batches (recommended for long runs).
            del inputs, outputs
            torch.cuda.empty_cache()

            # Checkpoint after every batch so a crash loses at most one batch.
            with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
                json.dump(final_results, f, indent=4, ensure_ascii=False)

        # 6. Save the final results.
        print("\n\n=== ALL BATCHES COMPLETE ===")
        with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
            json.dump(final_results, f, indent=4, ensure_ascii=False)
        print(f"Results saved to {OUTPUT_FILE}")
    except Exception as e:
        print(f"\n--- Execution Error ---")
        print(f"An unexpected error occurred: {e}")