feat 新增 process lookbooks 接口

fix
This commit is contained in:
zhouchengrong
2024-10-21 11:01:28 +08:00
parent 3417bcb2ab
commit e88ba6994a
11 changed files with 371 additions and 6 deletions

3
.gitignore vendored
View File

@@ -69,4 +69,5 @@ app/logs
feature/
test
*.zip
*.pdf

View File

@@ -0,0 +1,55 @@
import logging
import os
import shutil
from typing import List
import tqdm
from fastapi import UploadFile, File, APIRouter
from app.service.lookbooks.service import create_image_batch_requests
# Root logger (no name given), shared module-wide.
logger = logging.getLogger()
# FastAPI router; mounted by app.api under the /api/process_lookbooks prefix.
router = APIRouter()
@router.post("/process_lookbooks/")
async def process_lookbooks(files: List[UploadFile] = File(...)):
    """Accept uploaded lookbook PDFs, extract their images, and submit them
    for batch description.

    Each upload is written to a temp directory; its images are extracted with
    unstructured's ``partition_pdf`` (skipped when the per-lookbook image
    directory is already populated), and every extracted image is handed to
    ``create_image_batch_requests``.

    Returns a JSON message, including the result-file path when new images
    were processed.
    """
    lookbook_dir = "service/lookbooks/temp_lookbooks"
    os.makedirs(lookbook_dir, exist_ok=True)
    try:
        lookbook_list = []
        for file in files:
            # basename() strips any directory components a malicious client
            # could smuggle into the filename (path traversal guard).
            safe_name = os.path.basename(file.filename or "")
            if not safe_name:
                continue
            file_path = os.path.join(lookbook_dir, safe_name)
            with open(file_path, "wb") as f:
                shutil.copyfileobj(file.file, f)
            lookbook_list.append(file_path)
        image_list = []
        for look_book_path in tqdm.tqdm(lookbook_list):
            lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0]
            output_dir = os.path.join("app/service/lookbooks/fashion_documents/lookbook/images", lookbook_name)
            os.makedirs(output_dir, exist_ok=True)
            if not os.listdir(output_dir):
                # Imported lazily: unstructured is heavy and only needed when
                # this lookbook has not been partitioned before.
                from unstructured.partition.pdf import partition_pdf
                partition_pdf(
                    filename=look_book_path,
                    extract_images_in_pdf=True,
                    infer_table_structure=False,
                    chunking_strategy="by_title",
                    max_characters=4000,
                    new_after_n_chars=3800,
                    combine_text_under_n_chars=2000,
                    extract_image_block_output_dir=output_dir,
                )
            # BUGFIX: the original only collected images on the cached (else)
            # branch, so freshly partitioned lookbooks were never described.
            # Collect whatever is in output_dir either way.
            image_list.extend(os.path.join(output_dir, x) for x in os.listdir(output_dir))
        image_description_results_file = create_image_batch_requests(
            image_list, "app/service/lookbooks/fashion_documents/lookbook/results"
        )
    finally:
        # BUGFIX: the original removed the temp dir only on the success path,
        # leaking uploaded files whenever processing raised mid-way.
        shutil.rmtree(lookbook_dir, ignore_errors=True)
    if image_description_results_file:
        return {"message": "Lookbooks processed successfully", "result_file": image_description_results_file}
    else:
        return {"message": "No new images to process"}

View File

@@ -1,9 +1,6 @@
from fastapi import APIRouter

# BUGFIX: the previous revision left the old per-module imports duplicated
# alongside the combined import line; keep a single grouped import.
from app.api import (
    api_attribute,
    api_outfit_matcher,
    api_process_lookbooks,
    api_similar_match,
    api_test,
)

# Aggregate router: every feature router is mounted under its own prefix.
router = APIRouter()

router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api/outfit_matcher")
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api/attribute")
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api/similar_match")
router.include_router(api_process_lookbooks.router, tags=["process_lookbooks"], prefix="/api/process_lookbooks")

View File

@@ -35,7 +35,7 @@ ATT_TRITON_PORT = "20010"
MILVUS_URL = "http://10.1.1.240:19530"

# Runtime mode flag:
#   1 -> service environment
#   2 -> local PyCharm debugging
# BUGFIX: the previous `DEBUG = 1` was dead code, immediately overwritten by
# `DEBUG = 2`; keep a single assignment.
DEBUG = 2

# When True, intermediate result images are shown/saved for debugging.
SHOW_OR_SAVE_result_image = False

0
app/schemas/lookbooks.py Normal file
View File

View File

View File

@@ -0,0 +1,128 @@
import json
import os

from openai import OpenAI

from app.service.lookbooks.utils.image_utils import base64_encode_image, generate_text_id
from app.service.lookbooks.utils.openai_utils import wait_for_job_completion

# SECURITY: the API key below was committed to version control and must be
# considered leaked -- rotate it. Prefer the environment variables; the
# literal fallbacks only keep existing deployments working until rotation.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "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")
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", "https://pangkaichen-openai-prox-98.deno.dev/v1")

# Shared OpenAI client for this module.
client = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_API_BASE,
)
def create_image_batch_requests(
    image_list,
    output_path,
    prompt="""You are an AI assistant specializing in fashion analysis and tagging for an advanced clothing indexing system. Your task is to analyze images of outfits and provide concise, relevant information. Please structure your response as follows:
1. Brief Summary: Start with a one-sentence summary of the overall style and vibe of the outfit.
2. Tags: Provide relevant tags in the following categories. Use multiple tags where appropriate, separated by commas.
Season: [e.g., Spring/Summer, Fall/Winter]
Style: [e.g., Minimalist, Bohemian, Streetwear, Business Casual]
Occasion: [e.g., Office, Casual, Party, Outdoor]
Colors: [List main colors used]
Materials: [e.g., Cotton, Denim, Leather, Silk]
Key Elements: [Any distinctive fashion elements]
3. Item Descriptions:
- Briefly describe each main clothing item (top, bottom, outerwear). Number and categorize each item (e.g., 1. Top: ..., 2. Bottom: ...). Keep descriptions concise but include key details about style, color, and distinctive features.
- In a single sentence, list all accessories and their overall effect on the outfit.
Ensure your tags and descriptions are accurate, relevant to current fashion trends, and useful for indexing and retrieval purposes."""
):
    """Describe new images via the OpenAI Batch API.

    Builds one chat-completion request per not-yet-described image in
    ``image_list``, submits them as a batch job, waits for completion, and
    appends ``{"custom_id", "summary", "url"}`` records to
    ``image_description_results.jsonl`` under ``output_path``.

    Returns the path to the results file on success, or None when there is
    nothing new to process or the batch job failed.
    """
    # Collapse all whitespace/newlines so the prompt is one line.
    prompt = ' '.join(prompt.split())
    results_file = os.path.join(output_path, "image_description_results.jsonl")
    # Set for O(1) membership tests (was a list).
    completed_ids = set()
    if os.path.exists(results_file):
        with open(results_file, "r") as f:
            for line in f:
                completed_ids.add(json.loads(line)["custom_id"])
    tasks = []
    id2img = {}
    for image_filename in image_list:
        image_base64 = base64_encode_image(image_filename)
        if not image_base64:
            continue  # unreadable image; already logged by the encoder
        current_id = generate_text_id(image_base64)
        if current_id in completed_ids:
            continue  # described in a previous run
        tasks.append({
            "custom_id": current_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o",
                "temperature": 1.0,
                "max_tokens": 500,
                "messages": [
                    {
                        "role": "system",
                        "content": prompt
                    },
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{image_base64}",
                                    "detail": "low"
                                }
                            }
                        ]
                    }
                ]
            }
        })
        id2img[current_id] = image_filename
    print(f"In total {len(tasks)} images")
    if not tasks:
        return None
    batch_file_name = os.path.join(output_path, "image_batch_requests.jsonl")
    with open(batch_file_name, 'w', encoding='utf-8') as file:
        for obj in tasks:
            file.write(json.dumps(obj, ensure_ascii=False) + '\n')
    # BUGFIX: the upload handle was previously opened and never closed.
    with open(batch_file_name, "rb") as upload:
        batch_file = client.files.create(
            file=upload,
            purpose="batch"
        )
    batch_job = client.batches.create(
        input_file_id=batch_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h"
    )
    if not wait_for_job_completion(client, batch_job.id):
        # BUGFIX: previously the results path was returned even on failure,
        # which made callers report success without fresh results.
        print("Job failed")
        return None
    output_file_id = client.batches.retrieve(batch_job.id).output_file_id
    file_response = client.files.content(output_file_id)
    file_content_str = file_response.read().decode('utf-8')
    # BUGFIX: append instead of "w" -- write mode destroyed the previously
    # completed results that the dedup pass above just read.
    with open(results_file, "a", encoding='utf-8') as f:
        for line in file_content_str.splitlines():
            if not line.strip():
                continue
            try:
                result = json.loads(line)
                image_id = result['custom_id']
                caption = result['response']['body']['choices'][0]['message']['content']
                f.write(json.dumps({
                    "custom_id": image_id,
                    "summary": caption,
                    "url": id2img[image_id]
                }) + '\n')
            except (json.JSONDecodeError, KeyError) as error:
                # KeyError guards malformed/unknown records, not just bad JSON.
                print(f"Error parsing: {error} -- at line: {line}")
    return results_file

View File

View File

@@ -0,0 +1,152 @@
from typing import Optional, Type, List, Dict, Any
import hashlib
import base64
import io
import os
from PIL import Image
import torch
from torchvision import transforms
def generate_text_id(text):
    """Return a deterministic hex id for *text* (MD5 of its UTF-8 bytes)."""
    digest = hashlib.md5(text.encode('utf-8'))
    return digest.hexdigest()
def get_file_hash(file_path):
    """Return the MD5 hex digest of the file's full contents."""
    with open(file_path, 'rb') as fh:
        contents = fh.read()
    return hashlib.md5(contents).hexdigest()
def resize_image(img, max_size=512):
    """Downscale a PIL image so both sides fit within *max_size*.

    The aspect ratio is preserved; an image that already fits is returned
    as an unmodified copy.
    """
    width, height = img.size
    if width <= max_size and height <= max_size:
        # Already within bounds -- no scaling needed.
        return img.copy()
    if width > height:
        # Landscape: pin the width to max_size.
        new_size = (max_size, int(max_size * height / width))
    else:
        # Portrait/square: pin the height to max_size.
        new_size = (int(max_size * width / height), max_size)
    return img.resize(new_size, Image.LANCZOS)
def base64_encode_image(image_path):
    """Load the image at *image_path*, normalize to RGB, shrink it with
    resize_image(), and return the JPEG bytes base64-encoded as a string.

    Returns None (after printing a diagnostic) on any failure.
    """
    try:
        # Normalize Windows-style separators.
        image_path = image_path.replace("\\", "/")
        with Image.open(image_path) as img:
            # JPEG output requires RGB mode.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            shrunk = resize_image(img)
            # Encode the resized image to JPEG in memory, then to base64.
            buffer = io.BytesIO()
            shrunk.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except FileNotFoundError:
        print(f"File not found: {image_path}")
        print(f"Absolute path: {os.path.abspath(image_path)}")
        print(f"Current working directory: {os.getcwd()}")
        return None
    except Exception as e:
        print(f"Error processing image {image_path}: {str(e)}")
        return None
def get_possible_image_paths(path):
    """Resolve *path* to an existing image file by trying common extensions.

    The incoming path may or may not carry an extension; when it does, the
    final extension is stripped before each candidate suffix is tried. If no
    candidate exists on disk, the original *path* is returned unchanged.
    """
    parts = path.split(".")
    # Keep everything except a trailing extension, when one is present.
    stem = ".".join(parts if len(parts) == 1 else parts[:-1])
    for ext in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
        candidate = stem + ext
        if os.path.isfile(candidate):
            return candidate
    return path
def preprocess(image_urls: List[str]):
    """Load each image and apply the CLIP-style eval transform, returning a
    numpy batch.

    Pipeline: resize to 256x256, center-crop 224, to-tensor, normalize with
    the constants below. Returns None when *image_urls* is empty; files that
    cannot be found are skipped with a printed message.
    """
    pipeline = transforms.Compose([
        # Scale to 256x256, then crop the central 224x224 region.
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        # PIL image -> FloatTensor in [0, 1].
        transforms.ToTensor(),
        # Channel-wise normalization.
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711])
    ])
    if len(image_urls) == 0:
        return None
    tensors = []
    for url in image_urls:
        path = get_possible_image_paths(url)
        try:
            image = Image.open(path).convert("RGB")
        except FileNotFoundError as e:
            print(f"File not found: {path}, {e}")
            continue
        tensors.append(pipeline(image))
    batch = torch.stack(tensors, dim=0)
    return batch.numpy()
def is_image_accessible(path: str) -> bool:
    """Return True when *path* exists and PIL verifies it as a valid image."""
    if not os.path.exists(path):
        return False
    try:
        with Image.open(path) as img:
            img.verify()  # raises on corrupt/truncated image data
    except (IOError, SyntaxError):
        return False
    return True
def create_image_grid(image_paths, grid_size=(2, 2)):
    """Compose the images at *image_paths* into a grid and return PNG bytes.

    The grid has grid_size[0] columns and grid_size[1] rows on a white
    background. Every cell is as large as the biggest source image; each
    image is scaled down (aspect ratio preserved) to fit its cell and
    centered within it.

    Raises ValueError when *image_paths* is empty.
    """
    if not image_paths:
        # max() below would raise a cryptic ValueError; fail fast instead.
        raise ValueError("image_paths must not be empty")
    images = [Image.open(path) for path in image_paths]
    try:
        # Cell size = largest source dimensions across all images.
        max_width = max(img.width for img in images)
        max_height = max(img.height for img in images)
        grid_width = grid_size[0] * max_width
        grid_height = grid_size[1] * max_height
        grid_image = Image.new('RGB', (grid_width, grid_height), color="white")
        for i, img in enumerate(images):
            # Top-left corner of this image's cell (row-major placement).
            x = (i % grid_size[0]) * max_width
            y = (i // grid_size[0]) * max_height
            # Scale to fit the cell while preserving the aspect ratio.
            ratio = min(max_width / img.width, max_height / img.height)
            new_size = (int(img.width * ratio), int(img.height * ratio))
            img_resized = img.resize(new_size, Image.LANCZOS)
            # Center within the cell.
            paste_x = x + (max_width - new_size[0]) // 2
            paste_y = y + (max_height - new_size[1]) // 2
            grid_image.paste(img_resized, (paste_x, paste_y))
        # Serialize the assembled grid to PNG bytes.
        img_byte_arr = io.BytesIO()
        grid_image.save(img_byte_arr, format='PNG')
        return img_byte_arr.getvalue()
    finally:
        # BUGFIX: the original leaked every opened image's file handle.
        for img in images:
            img.close()

View File

@@ -0,0 +1,31 @@
import time
def wait_for_job_completion(client, batch_job_id, timeout=6000, interval=10):
"""
等待批处理作业完成。
参数:
- client: 批处理作业客户端对象
- batch_job_id: 批处理作业的ID
- timeout: 最大等待时间(秒)
- interval: 检查状态的间隔时间(秒)
返回:
- True 如果作业完成False 如果超时
"""
start_time = time.time()
while time.time() - start_time < timeout:
batch_job = client.batches.retrieve(batch_job_id)
if batch_job.status == 'completed':
print("Batch job completed successfully.")
return True
elif batch_job.status == 'failed':
print("Batch job failed.")
return False
else:
print(f"Current status: {batch_job.status}. Waiting for completion...")
time.sleep(interval)
print("Timeout reached, job did not complete in time.")
return False

Binary file not shown.