feat: add process_lookbooks endpoint
fix
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -69,4 +69,5 @@ app/logs
|
||||
|
||||
feature/
|
||||
test
|
||||
*.zip
|
||||
*.zip
|
||||
*.pdf
|
||||
55
app/api/api_process_lookbooks.py
Normal file
55
app/api/api_process_lookbooks.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from typing import List
|
||||
|
||||
import tqdm
|
||||
from fastapi import UploadFile, File, APIRouter
|
||||
|
||||
from app.service.lookbooks.service import create_image_batch_requests
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/process_lookbooks/")
async def process_lookbooks(files: List[UploadFile] = File(...)):
    """Accept uploaded lookbook PDFs, extract their images, and submit them
    for batch description.

    Each upload is saved to a temporary directory, its images are extracted
    with ``unstructured.partition_pdf`` (skipped when a cached extraction
    already exists for that lookbook), and every extracted image is handed
    to ``create_image_batch_requests``.

    Returns:
        dict: a success message plus the result-file path, or a
        "No new images to process" message when nothing was submitted.
    """
    lookbook_dir = "service/lookbooks/temp_lookbooks"
    os.makedirs(lookbook_dir, exist_ok=True)

    try:
        # Persist the uploads to disk so partition_pdf can read them.
        lookbook_list = []
        for file in files:
            # basename() guards against path traversal in client-supplied names.
            file_path = os.path.join(lookbook_dir, os.path.basename(file.filename))
            with open(file_path, "wb") as f:
                shutil.copyfileobj(file.file, f)
            lookbook_list.append(file_path)

        image_list = []
        for look_book_path in tqdm.tqdm(lookbook_list):
            lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0]
            output_dir = os.path.join("app/service/lookbooks/fashion_documents/lookbook/images", lookbook_name)
            os.makedirs(output_dir, exist_ok=True)
            if not os.listdir(output_dir):
                # Imported lazily: unstructured is heavy and only needed here.
                from unstructured.partition.pdf import partition_pdf
                partition_pdf(
                    filename=look_book_path,
                    extract_images_in_pdf=True,
                    infer_table_structure=False,
                    chunking_strategy="by_title",
                    max_characters=4000,
                    new_after_n_chars=3800,
                    combine_text_under_n_chars=2000,
                    extract_image_block_output_dir=output_dir,
                )
            # BUG FIX: images were previously collected only when the cache
            # already existed, so freshly extracted images were never passed
            # on. Collect the directory contents in both cases.
            current_images = os.listdir(output_dir)
            image_list.extend([os.path.join(output_dir, x) for x in current_images])

        image_description_results_file = create_image_batch_requests(
            image_list, "app/service/lookbooks/fashion_documents/lookbook/results"
        )
    finally:
        # Always drop the temporary uploads, even if extraction failed.
        shutil.rmtree(lookbook_dir, ignore_errors=True)

    if image_description_results_file:
        return {"message": "Lookbooks processed successfully", "result_file": image_description_results_file}
    else:
        return {"message": "No new images to process"}
|
||||
@@ -1,9 +1,6 @@
|
||||
from fastapi import APIRouter

from app.api import api_test, api_outfit_matcher, api_attribute, api_similar_match, api_process_lookbooks

# Top-level API router: mounts every feature router under its own prefix.
# (Diff residue removed: the per-module imports were superseded by the
# single combined import above.)
router = APIRouter()

router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_outfit_matcher.router, tags=["outfit_matcher"], prefix="/api/outfit_matcher")
router.include_router(api_attribute.router, tags=["attribute"], prefix="/api/attribute")
router.include_router(api_similar_match.router, tags=["similar_match"], prefix="/api/similar_match")
router.include_router(api_process_lookbooks.router, tags=["process_lookbooks"], prefix="/api/process_lookbooks")
|
||||
|
||||
@@ -35,7 +35,7 @@ ATT_TRITON_PORT = "20010"
|
||||
|
||||
MILVUS_URL = "http://10.1.1.240:19530"

# Execution mode flag (duplicate DEBUG = 1 assignment removed — it was
# immediately overwritten and only the value below took effect):
#   1 = service environment
#   2 = PyCharm debug
DEBUG = 2
SHOW_OR_SAVE_result_image = False
|
||||
|
||||
0
app/schemas/lookbooks.py
Normal file
0
app/schemas/lookbooks.py
Normal file
0
app/service/lookbooks/__init__.py
Normal file
0
app/service/lookbooks/__init__.py
Normal file
128
app/service/lookbooks/service.py
Normal file
128
app/service/lookbooks/service.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from app.service.lookbooks.utils.image_utils import base64_encode_image, generate_text_id
|
||||
from app.service.lookbooks.utils.openai_utils import wait_for_job_completion
|
||||
|
||||
# SECURITY: an OpenAI API key was previously hard-coded here (and is now
# leaked in version control — it must be rotated). Credentials are read
# from the environment instead.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE = os.environ.get("OPENAI_API_BASE", "https://pangkaichen-openai-prox-98.deno.dev/v1")
client = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_API_BASE,
)
|
||||
|
||||
|
||||
def create_image_batch_requests(
        image_list,
        output_path,
        prompt="""You are an AI assistant specializing in fashion analysis and tagging for an advanced clothing indexing system. Your task is to analyze images of outfits and provide concise, relevant information. Please structure your response as follows:

1. Brief Summary: Start with a one-sentence summary of the overall style and vibe of the outfit.

2. Tags: Provide relevant tags in the following categories. Use multiple tags where appropriate, separated by commas.

Season: [e.g., Spring/Summer, Fall/Winter]
Style: [e.g., Minimalist, Bohemian, Streetwear, Business Casual]
Occasion: [e.g., Office, Casual, Party, Outdoor]
Colors: [List main colors used]
Materials: [e.g., Cotton, Denim, Leather, Silk]
Key Elements: [Any distinctive fashion elements]

3. Item Descriptions:
- Briefly describe each main clothing item (top, bottom, outerwear). Number and categorize each item (e.g., 1. Top: ..., 2. Bottom: ...). Keep descriptions concise but include key details about style, color, and distinctive features.
- In a single sentence, list all accessories and their overall effect on the outfit.

Ensure your tags and descriptions are accurate, relevant to current fashion trends, and useful for indexing and retrieval purposes."""
):
    """Submit an OpenAI Batch job that captions the given lookbook images.

    Args:
        image_list: paths of image files to describe.
        output_path: directory holding the request/result ``.jsonl`` files.
        prompt: system prompt applied to every image.

    Returns:
        Path to ``image_description_results.jsonl`` on success, or ``None``
        when there were no new images or the batch job did not complete.
    """
    # Collapse all whitespace/newlines in the prompt into single spaces.
    prompt = ' '.join(prompt.split())

    results_file = os.path.join(output_path, "image_description_results.jsonl")

    # Skip images whose descriptions were produced by a previous run.
    completed_id = []
    if os.path.exists(results_file):
        with open(results_file, "r") as f:
            for line in f:
                completed_id.append(json.loads(line)["custom_id"])

    tasks = []
    id2img = {}
    for image_filename in image_list:
        image_base64 = base64_encode_image(image_filename)
        if not image_base64:
            continue
        # The id is derived from the encoded bytes, so identical images
        # dedupe regardless of filename.
        current_id = generate_text_id(image_base64)
        if current_id in completed_id:
            continue
        task = {
            "custom_id": current_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o",
                "temperature": 1.0,
                "max_tokens": 500,
                "messages": [
                    {
                        "role": "system",
                        "content": prompt
                    },
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{image_base64}",
                                    "detail": "low"
                                }
                            }
                        ]
                    }
                ]
            }
        }
        tasks.append(task)
        id2img[current_id] = image_filename
    print(f"In total {len(tasks)} images")
    if not tasks:
        return None

    batch_file_name = os.path.join(output_path, "image_batch_requests.jsonl")
    with open(batch_file_name, 'w', encoding='utf-8') as file:
        for obj in tasks:
            file.write(json.dumps(obj, ensure_ascii=False) + '\n')

    # BUG FIX: the request file handle was previously opened inline and
    # never closed; use a context manager for the upload.
    with open(batch_file_name, "rb") as request_file:
        batch_file = client.files.create(
            file=request_file,
            purpose="batch"
        )
    batch_job = client.batches.create(
        input_file_id=batch_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h"
    )

    if not wait_for_job_completion(client, batch_job.id):
        # BUG FIX: the results path used to be returned even on failure,
        # pointing callers at stale or missing data.
        print("Job failed")
        return None

    output_file_id = client.batches.retrieve(batch_job.id).output_file_id
    file_response = client.files.content(output_file_id)
    file_content_str = file_response.read().decode('utf-8')

    # BUG FIX: append instead of truncating — mode "w" discarded the very
    # entries the completed_id dedup above had just preserved.
    with open(results_file, "a", encoding='utf-8') as f:
        for line in file_content_str.splitlines():
            if not line.strip():
                continue
            try:
                result = json.loads(line)
                image_id = result['custom_id']
                caption = result['response']['body']['choices'][0]['message']['content']
                output = json.dumps({
                    "custom_id": image_id,
                    "summary": caption,
                    "url": id2img[image_id]
                })
                f.write(output + '\n')
            except json.JSONDecodeError as error:
                print(f"Error parsing: {error} -- at line: {line}")

    return results_file
|
||||
0
app/service/lookbooks/utils/__init__.py
Normal file
0
app/service/lookbooks/utils/__init__.py
Normal file
152
app/service/lookbooks/utils/image_utils.py
Normal file
152
app/service/lookbooks/utils/image_utils.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import Optional, Type, List, Dict, Any
|
||||
import hashlib
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def generate_text_id(text):
    """Return a deterministic hex id for *text* (MD5 of its UTF-8 bytes)."""
    digest = hashlib.md5(text.encode('utf-8'))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def get_file_hash(file_path):
    """Return the MD5 hex digest of the file at *file_path*.

    Reads in fixed-size chunks so arbitrarily large files need not fit in
    memory (the original slurped the whole file at once).
    """
    digest = hashlib.md5()
    with open(file_path, 'rb') as f:
        # iter() with a sentinel yields chunks until EOF returns b''.
        for chunk in iter(lambda: f.read(65536), b''):
            digest.update(chunk)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def resize_image(img, max_size=512):
    """Scale *img* down so that neither side exceeds ``max_size``.

    The aspect ratio is preserved. An image already within bounds is
    returned as an untouched copy.
    """
    width, height = img.size
    if width <= max_size and height <= max_size:
        # Already small enough — hand back a copy so the caller owns it.
        return img.copy()

    # Shrink the longer edge to max_size; scale the other proportionally.
    if width > height:
        target = (max_size, int(max_size * height / width))
    else:
        target = (int(max_size * width / height), max_size)
    return img.resize(target, Image.LANCZOS)
|
||||
|
||||
|
||||
def base64_encode_image(image_path):
    """Load an image, normalise it to an RGB JPEG (longest side <= 512 px),
    and return it as a base64 string, or ``None`` on any failure."""
    try:
        normalized_path = image_path.replace("\\", "/")

        with Image.open(normalized_path) as img:
            # Convert paletted/CMYK/etc. images before JPEG encoding.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            shrunk = resize_image(img)
            # Encode through an in-memory buffer — nothing touches disk.
            buffer = io.BytesIO()
            shrunk.save(buffer, format="JPEG")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except FileNotFoundError:
        print(f"File not found: {normalized_path}")
        print(f"Absolute path: {os.path.abspath(normalized_path)}")
        print(f"Current working directory: {os.getcwd()}")
        return None
    except Exception as e:
        print(f"Error processing image {image_path}: {str(e)}")
        return None
|
||||
|
||||
|
||||
def get_possible_image_paths(path):
    """Resolve *path* to an existing image file by probing known extensions.

    The trailing extension (if any) is stripped, then each candidate
    extension is tried in order; the first existing file wins. When no
    candidate exists, the original *path* is returned unchanged.
    """
    parts = path.split(".")
    # Drop the extension segment only when one is present.
    stem = ".".join(parts[:-1]) if len(parts) > 1 else parts[0]
    for ext in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
        candidate = stem + ext
        if os.path.isfile(candidate):
            return candidate
    return path
|
||||
|
||||
|
||||
def preprocess(image_urls: List[str]):
    """Load, normalise and stack images into a model input batch.

    Returns a numpy array of shape (N, 3, 224, 224), or ``None`` when
    *image_urls* is empty. Missing files are skipped with a warning.
    The mean/std values appear to be CLIP's normalisation constants —
    confirm against the consuming model.
    """
    if not image_urls:
        return None

    # Resize -> centre-crop -> tensorise -> normalise.
    pipeline = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711]),
    ])

    tensors = []
    for url in image_urls:
        resolved = get_possible_image_paths(url)
        try:
            pil_image = Image.open(resolved).convert("RGB")
        except FileNotFoundError as e:
            print(f"File not found: {resolved}, {e}")
            continue
        tensors.append(pipeline(pil_image))
    batch = torch.stack(tensors, dim=0)
    return batch.numpy()
|
||||
|
||||
|
||||
def is_image_accessible(path: str) -> bool:
    """Return True when *path* exists and holds a readable, valid image."""
    if not os.path.exists(path):
        return False
    try:
        with Image.open(path) as img:
            img.verify()  # Raises on a truncated or corrupt file.
    except (IOError, SyntaxError):
        return False
    return True
|
||||
|
||||
def create_image_grid(image_paths, grid_size=(2, 2)):
    """Compose the images at *image_paths* into one grid image.

    Every cell is sized to the largest source width/height; each image is
    scaled to fit its cell (aspect ratio preserved) and centred within it.
    The finished grid is returned as PNG bytes.
    """
    images = [Image.open(p) for p in image_paths]

    # Cell dimensions come from the largest image on each axis.
    cell_w = max(img.width for img in images)
    cell_h = max(img.height for img in images)

    cols, rows = grid_size
    canvas = Image.new('RGB', (cols * cell_w, rows * cell_h), color="white")

    for index, img in enumerate(images):
        cell_x = (index % cols) * cell_w
        cell_y = (index // cols) * cell_h

        # Fit the image inside its cell without distortion.
        scale = min(cell_w / img.width, cell_h / img.height)
        fitted = img.resize((int(img.width * scale), int(img.height * scale)), Image.LANCZOS)

        # Centre the fitted image within its cell.
        offset_x = cell_x + (cell_w - fitted.width) // 2
        offset_y = cell_y + (cell_h - fitted.height) // 2
        canvas.paste(fitted, (offset_x, offset_y))

    # Serialise the composed grid to PNG in memory.
    out = io.BytesIO()
    canvas.save(out, format='PNG')
    return out.getvalue()
|
||||
31
app/service/lookbooks/utils/openai_utils.py
Normal file
31
app/service/lookbooks/utils/openai_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import time
|
||||
|
||||
|
||||
def wait_for_job_completion(client, batch_job_id, timeout=6000, interval=10):
    """Poll an OpenAI batch job until it reaches a terminal state.

    Args:
        client: OpenAI client (anything exposing ``batches.retrieve``).
        batch_job_id: id of the batch job to poll.
        timeout: maximum total wait in seconds.
        interval: seconds between status checks.

    Returns:
        True when the job completed; False when it failed, expired, was
        cancelled, or the timeout elapsed.
    """
    # Terminal failure states per the OpenAI Batch API. BUG FIX: only
    # 'failed' was handled before, so expired/cancelled jobs kept polling
    # until the full timeout elapsed.
    failure_states = ('failed', 'expired', 'cancelled')
    start_time = time.time()
    while time.time() - start_time < timeout:
        batch_job = client.batches.retrieve(batch_job_id)
        if batch_job.status == 'completed':
            print("Batch job completed successfully.")
            return True
        elif batch_job.status in failure_states:
            print("Batch job failed.")
            return False
        else:
            print(f"Current status: {batch_job.status}. Waiting for completion...")
            time.sleep(interval)

    print("Timeout reached, job did not complete in time.")
    return False
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
Reference in New Issue
Block a user