From e88ba6994abf34683e40e69b2c342b0b4cf73e9f Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 21 Oct 2024 11:01:28 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=20=20=E6=96=B0=E5=A2=9E=20process=20loo?= =?UTF-8?q?kbooks=20=E6=8E=A5=E5=8F=A3=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- app/api/api_process_lookbooks.py | 55 +++++++ app/api/api_route.py | 6 +- app/core/config.py | 2 +- app/schemas/lookbooks.py | 0 app/service/lookbooks/__init__.py | 0 app/service/lookbooks/service.py | 128 +++++++++++++++++ app/service/lookbooks/utils/__init__.py | 0 app/service/lookbooks/utils/image_utils.py | 152 ++++++++++++++++++++ app/service/lookbooks/utils/openai_utils.py | 31 ++++ requirements.txt | Bin 740 -> 830 bytes 11 files changed, 371 insertions(+), 6 deletions(-) create mode 100644 app/api/api_process_lookbooks.py create mode 100644 app/schemas/lookbooks.py create mode 100644 app/service/lookbooks/__init__.py create mode 100644 app/service/lookbooks/service.py create mode 100644 app/service/lookbooks/utils/__init__.py create mode 100644 app/service/lookbooks/utils/image_utils.py create mode 100644 app/service/lookbooks/utils/openai_utils.py diff --git a/.gitignore b/.gitignore index 2e0eb4a..fa15abc 100644 --- a/.gitignore +++ b/.gitignore @@ -69,4 +69,5 @@ app/logs feature/ test -*.zip \ No newline at end of file +*.zip +*.pdf \ No newline at end of file diff --git a/app/api/api_process_lookbooks.py b/app/api/api_process_lookbooks.py new file mode 100644 index 0000000..752fce9 --- /dev/null +++ b/app/api/api_process_lookbooks.py @@ -0,0 +1,55 @@ +import logging +import os +import shutil +from typing import List + +import tqdm +from fastapi import UploadFile, File, APIRouter + +from app.service.lookbooks.service import create_image_batch_requests + +logger = logging.getLogger() +router = APIRouter() + + +@router.post("/process_lookbooks/") +async def process_lookbooks(files: List[UploadFile] = File(...)): + lookbook_dir = "service/lookbooks/temp_lookbooks" + os.makedirs(lookbook_dir, exist_ok=True) + + lookbook_list = [] + for file in files: + file_path = os.path.join(lookbook_dir, file.filename) + with open(file_path, "wb") as f: + shutil.copyfileobj(file.file, f) + lookbook_list.append(file_path) + + image_list = [] + for look_book_path in tqdm.tqdm(lookbook_list): + lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0] + output_dir = os.path.join("app/service/lookbooks/fashion_documents/lookbook/images", lookbook_name) + os.makedirs(output_dir, exist_ok=True) + if not os.listdir(output_dir): + from unstructured.partition.pdf import partition_pdf + partition_pdf( + filename=look_book_path, + extract_images_in_pdf=True, + infer_table_structure=False, + chunking_strategy="by_title", + max_characters=4000, + new_after_n_chars=3800, + combine_text_under_n_chars=2000, + extract_image_block_output_dir=output_dir, + ) + else: + current_images = os.listdir(output_dir) + image_list.extend([os.path.join(output_dir, x) for x in current_images]) + + image_description_results_file = create_image_batch_requests(image_list, "app/service/lookbooks/fashion_documents/lookbook/results") + + shutil.rmtree(lookbook_dir) + + if image_description_results_file: + return {"message": "Lookbooks processed successfully", "result_file": image_description_results_file} + else: + return {"message": "No new images to process"} diff --git a/app/api/api_route.py b/app/api/api_route.py index 983e725..923a9b0 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,9 +1,6 @@ from fastapi import APIRouter -from app.api import api_test -from app.api import api_outfit_matcher -from app.api import api_attribute -from app.api import api_similar_match +from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks router = APIRouter() @@ -11,3 +8,4 @@ router.include_router(api_test.router, tags=["test"], prefix="/test") router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api/outfit_matcher") router.include_router(api_attribute.router, tags=["attribute"], prefix="/api/attribute") router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api/similar_match") +router.include_router(api_process_lookbooks.router, tags=["process_lookbooks"], prefix="/api/process_lookbooks") diff --git a/app/core/config.py b/app/core/config.py index f15c473..7f0a2b8 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -35,7 +35,7 @@ ATT_TRITON_PORT = "20010" MILVUS_URL = "http://10.1.1.240:19530" -DEBUG = 1 +DEBUG = 2 SHOW_OR_SAVE_result_image = False # service env : 1 # pycharm debug : 2 diff --git a/app/schemas/lookbooks.py b/app/schemas/lookbooks.py new file mode 100644 index 0000000..e69de29 diff --git a/app/service/lookbooks/__init__.py b/app/service/lookbooks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/service/lookbooks/service.py b/app/service/lookbooks/service.py new file mode 100644 index 0000000..065d7ee --- /dev/null +++ b/app/service/lookbooks/service.py @@ -0,0 +1,128 @@ +import json +import os + +from openai import OpenAI + +from app.service.lookbooks.utils.image_utils import base64_encode_image, generate_text_id +from app.service.lookbooks.utils.openai_utils import wait_for_job_completion + +OPENAI_API_KEY = "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" +OPENAI_API_BASE = "https://pangkaichen-openai-prox-98.deno.dev/v1" +client = OpenAI( + api_key=OPENAI_API_KEY, + base_url=OPENAI_API_BASE, +) + + +def create_image_batch_requests( + image_list, + output_path, + prompt="""You are an AI assistant specializing in fashion analysis and tagging for an advanced clothing indexing system. Your task is to analyze images of outfits and provide concise, relevant information. Please structure your response as follows: + + 1. Brief Summary: Start with a one-sentence summary of the overall style and vibe of the outfit. + + 2. Tags: Provide relevant tags in the following categories. Use multiple tags where appropriate, separated by commas. + + Season: [e.g., Spring/Summer, Fall/Winter] + Style: [e.g., Minimalist, Bohemian, Streetwear, Business Casual] + Occasion: [e.g., Office, Casual, Party, Outdoor] + Colors: [List main colors used] + Materials: [e.g., Cotton, Denim, Leather, Silk] + Key Elements: [Any distinctive fashion elements] + + 3. Item Descriptions: + - Briefly describe each main clothing item (top, bottom, outerwear). Number and categorize each item (e.g., 1. Top: ..., 2. Bottom: ...). Keep descriptions concise but include key details about style, color, and distinctive features. + - In a single sentence, list all accessories and their overall effect on the outfit. + + Ensure your tags and descriptions are accurate, relevant to current fashion trends, and useful for indexing and retrieval purposes.""" +): + # 预处理 prompt:移除多余的空白和换行符 + prompt = ' '.join(prompt.split()) + + completed_id = [] + if os.path.exists(os.path.join(output_path, "image_description_results.jsonl")): + with open(os.path.join(output_path, "image_description_results.jsonl"), "r") as f: + for line in f: + completed_id.append(json.loads(line)["custom_id"]) + + tasks = [] + id2img = {} + for idx, image_filename in enumerate(image_list): + image_base64 = base64_encode_image(image_filename) + if not image_base64: + continue + current_id = generate_text_id(image_base64) + if current_id in completed_id: + continue + task = { + "custom_id": current_id, + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "gpt-4o", + "temperature": 1.0, + "max_tokens": 500, + "messages": [ + { + "role": "system", + "content": prompt + }, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}", + "detail": "low" + } + } + ] + } + ] + } + } + tasks.append(task) + id2img[current_id] = image_filename + print(f"In total {len(tasks)} images") + if tasks: + batch_file_name = os.path.join(output_path, "image_batch_requests.jsonl") + with open(batch_file_name, 'w', encoding='utf-8') as file: + for obj in tasks: + file.write(json.dumps(obj, ensure_ascii=False) + '\n') + + batch_file = client.files.create( + file=open(batch_file_name, "rb"), + purpose="batch" + ) + batch_job = client.batches.create( + input_file_id=batch_file.id, + endpoint="/v1/chat/completions", + completion_window="24h" + ) + + if wait_for_job_completion(client, batch_job.id): + output_file_id = client.batches.retrieve(batch_job.id).output_file_id + file_response = client.files.content(output_file_id) + file_content_str = file_response.read().decode('utf-8') + + with open(os.path.join(output_path, "image_description_results.jsonl"), "w", encoding='utf-8') as f: + for line in file_content_str.splitlines(): + if line.strip(): + try: + result = json.loads(line) + image_id = result['custom_id'] + caption = result['response']['body']['choices'][0]['message']['content'] + output = json.dumps({ + "custom_id": image_id, + "summary": caption, + "url": id2img[image_id] + }) + f.write(output + '\n') + except json.JSONDecodeError as error: + print(f"Error parsing: {error} -- at line: {line}") + else: + print("Job failed") + return os.path.join(output_path, "image_description_results.jsonl") + else: + return None diff --git a/app/service/lookbooks/utils/__init__.py b/app/service/lookbooks/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/service/lookbooks/utils/image_utils.py b/app/service/lookbooks/utils/image_utils.py new file mode 100644 index 0000000..08faaea --- /dev/null +++ b/app/service/lookbooks/utils/image_utils.py @@ -0,0 +1,152 @@ +from typing import Optional, Type, List, Dict, Any +import hashlib +import base64 +import io +import os + +from PIL import Image +import torch +from torchvision import transforms + + +def generate_text_id(text): + return hashlib.md5(text.encode('utf-8')).hexdigest() + + +def get_file_hash(file_path): + with open(file_path, 'rb') as f: + return hashlib.md5(f.read()).hexdigest() + + +def resize_image(img, max_size=512): + """调整图片大小,保持原始比例,使长宽均不超过max_size,并保存到output_path""" + width, height = img.size + if width > max_size or height > max_size: + # 保持图片的宽高比 + if width > height: + new_width = max_size + new_height = int(max_size * height / width) + else: + new_height = max_size + new_width = int(max_size * width / height) + + # 进行图片缩放 + img_resized = img.resize((new_width, new_height), Image.LANCZOS) + else: + img_resized = img.copy() # 如果图片本身符合条件,则不进行缩放 + + return img_resized + + +def base64_encode_image(image_path): + try: + image_path = image_path.replace("\\", "/") + + with Image.open(image_path) as img: + # 如果图像模式不是 RGB,转换为 RGB + if img.mode != 'RGB': + img = img.convert('RGB') + # 获取当前图片的宽和高 + img_resized = resize_image(img) + # 将PIL Image对象编码为base64字符串。 + buffered = io.BytesIO() + img_resized.save(buffered, format="JPEG") # 保存调整后的图像到内存流 + return base64.b64encode(buffered.getvalue()).decode("utf-8") + except FileNotFoundError as e: + print(f"File not found: {image_path}") + print(f"Absolute path: {os.path.abspath(image_path)}") + print(f"Current working directory: {os.getcwd()}") + return None + except Exception as e: + print(f"Error processing image {image_path}: {str(e)}") + return None + + +def get_possible_image_paths(path): + image_path = path.split(".") + # Image path might contain file extension (e.g. .jpg), + # In this case, we want the path without the extension + image_path = image_path if len(image_path) == 1 else image_path[:-1] + for ext in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"): + image_ext = ".".join(image_path) + ext + if os.path.isfile(image_ext): + path = image_ext + break + return path + + +def preprocess(image_urls: List[str]): + transform_pipeline = transforms.Compose([ + # 缩放图像尺寸到 256x256 + transforms.Resize((256, 256)), + # 从图像中心裁剪 224x224 + transforms.CenterCrop(224), + # 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor + transforms.ToTensor(), + # 对图像进行标准化 + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + ]) + image_data = [] + if len(image_urls) == 0: + return None + for url in image_urls: + path = get_possible_image_paths(url) + try: + image = Image.open(path).convert("RGB") + except FileNotFoundError as e: + print(f"File not found: {path}, {e}") + continue + image_tensor = transform_pipeline(image) + image_data.append(image_tensor) + image_data = torch.stack(image_data, dim=0) + return image_data.numpy() + + +def is_image_accessible(path: str) -> bool: + if not os.path.exists(path): + return False + try: + with Image.open(path) as img: + img.verify() # 验证图像文件 + return True + except (IOError, SyntaxError): + return False + +def create_image_grid(image_paths, grid_size=(2, 2)): + # 打开所有图片 + images = [Image.open(path) for path in image_paths] + + # 确定每个图片的大小 + max_width = max(img.width for img in images) + max_height = max(img.height for img in images) + + # 创建一个新的空白图片 + grid_width = grid_size[0] * max_width + grid_height = grid_size[1] * max_height + grid_image = Image.new('RGB', (grid_width, grid_height), color="white") + + # 将图片粘贴到网格中 + for i, img in enumerate(images): + x = (i % grid_size[0]) * max_width + y = (i // grid_size[0]) * max_height + + # 计算缩放比例,保持原始宽高比 + ratio = min(max_width / img.width, max_height / img.height) + new_size = (int(img.width * ratio), int(img.height * ratio)) + + # 调整大小并保持宽高比 + img_resized = img.resize(new_size, Image.LANCZOS) + + # 计算居中位置 + paste_x = x + (max_width - new_size[0]) // 2 + paste_y = y + (max_height - new_size[1]) // 2 + + # 粘贴到网格中 + grid_image.paste(img_resized, (paste_x, paste_y)) + + # 将图片转换为字节流 + img_byte_arr = io.BytesIO() + grid_image.save(img_byte_arr, format='PNG') + img_byte_arr = img_byte_arr.getvalue() + return img_byte_arr diff --git a/app/service/lookbooks/utils/openai_utils.py b/app/service/lookbooks/utils/openai_utils.py new file mode 100644 index 0000000..bb5f9cb --- /dev/null +++ b/app/service/lookbooks/utils/openai_utils.py @@ -0,0 +1,31 @@ +import time + + +def wait_for_job_completion(client, batch_job_id, timeout=6000, interval=10): + """ + 等待批处理作业完成。 + + 参数: + - client: 批处理作业客户端对象 + - batch_job_id: 批处理作业的ID + - timeout: 最大等待时间(秒) + - interval: 检查状态的间隔时间(秒) + + 返回: + - True 如果作业完成,False 如果超时 + """ + start_time = time.time() + while time.time() - start_time < timeout: + batch_job = client.batches.retrieve(batch_job_id) + if batch_job.status == 'completed': + print("Batch job completed successfully.") + return True + elif batch_job.status == 'failed': + print("Batch job failed.") + return False + else: + print(f"Current status: {batch_job.status}. Waiting for completion...") + time.sleep(interval) + + print("Timeout reached, job did not complete in time.") + return False diff --git a/requirements.txt b/requirements.txt index a7bcd05acb9e1d6316f125a2fb66b1febd1f0cca..1ca892730a1e186fbe49c305943250b945919bbb 100644 GIT binary patch delta 98 zcmaFDx{qzc6Q(F$1}+8=E@jALC}t>OC<3yQfpjU5mCBI95Y3RtkOPFeKz2S*rkEiX gtUjNi04Sdal+9$Y1wumxJqBY2Qy^&wBta$u0Qoi%N&o-= delta 7 OcmdnT_JnoA6D9x+zXL=7