Files
sora_python/app/service/lookbooks/service.py
2024-10-21 16:57:57 +08:00

204 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import os
import logging
import tqdm
import aiofiles
from openai import OpenAI
from app.service.lookbooks.utils.image_utils import base64_encode_image, generate_text_id
from app.service.lookbooks.utils.openai_utils import wait_for_job_completion
# Module-level logger (root logger; consider logging.getLogger(__name__) per convention).
logger = logging.getLogger()
# OpenAI configuration.
# SECURITY(review): API key is hardcoded in source and has effectively been leaked by
# being committed — rotate it and load from an environment variable / secret store.
OPENAI_API_KEY = "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"
# NOTE(review): requests are routed through a third-party proxy endpoint — confirm this
# relay is trusted, since it sees the API key and all image payloads.
OPENAI_API_BASE = "https://pangkaichen-openai-prox-98.deno.dev/v1"
# Shared OpenAI client used by the batch-captioning functions below.
client = OpenAI(
api_key=OPENAI_API_KEY,
base_url=OPENAI_API_BASE,
)
def create_image_batch_requests(
    image_list,
    output_path,
    prompt="""You are an AI assistant specializing in fashion analysis and tagging for an advanced clothing indexing system. Your task is to analyze images of outfits and provide concise, relevant information. Please structure your response as follows:
1. Brief Summary: Start with a one-sentence summary of the overall style and vibe of the outfit.
2. Tags: Provide relevant tags in the following categories. Use multiple tags where appropriate, separated by commas.
Season: [e.g., Spring/Summer, Fall/Winter]
Style: [e.g., Minimalist, Bohemian, Streetwear, Business Casual]
Occasion: [e.g., Office, Casual, Party, Outdoor]
Colors: [List main colors used]
Materials: [e.g., Cotton, Denim, Leather, Silk]
Key Elements: [Any distinctive fashion elements]
3. Item Descriptions:
- Briefly describe each main clothing item (top, bottom, outerwear). Number and categorize each item (e.g., 1. Top: ..., 2. Bottom: ...). Keep descriptions concise but include key details about style, color, and distinctive features.
- In a single sentence, list all accessories and their overall effect on the outfit.
Ensure your tags and descriptions are accurate, relevant to current fashion trends, and useful for indexing and retrieval purposes."""
):
    """Caption a list of outfit images via the OpenAI Batch API.

    Builds one chat-completion request per not-yet-captioned image, submits
    them as a single batch job, waits for completion, and appends the parsed
    captions to ``<output_path>/image_description_results.jsonl``.

    Args:
        image_list: Iterable of local image file paths.
        output_path: Directory for the request file and the results file.
        prompt: System prompt sent with every image (whitespace-normalized).

    Returns:
        Path to the results JSONL file when there was work to submit
        (returned even if the batch job failed, matching prior behavior),
        or None when every image was already captioned.
    """
    # Normalize the prompt: collapse all runs of whitespace/newlines to single spaces.
    prompt = ' '.join(prompt.split())
    os.makedirs(output_path, exist_ok=True)
    results_path = os.path.join(output_path, "image_description_results.jsonl")

    # IDs captioned by previous runs; a set gives O(1) membership tests
    # (the original used a list, O(n) per image).
    completed_ids = set()
    if os.path.exists(results_path):
        with open(results_path, "r") as f:
            for line in f:
                completed_ids.add(json.loads(line)["custom_id"])

    tasks = []
    id2img = {}  # custom_id -> source image path, to attach the URL to each result
    for image_filename in image_list:
        image_base64 = base64_encode_image(image_filename)
        if not image_base64:
            # Unreadable / un-encodable image: skip silently (matches prior behavior).
            continue
        current_id = generate_text_id(image_base64)
        if current_id in completed_ids:
            continue
        tasks.append({
            "custom_id": current_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o",
                "temperature": 1.0,
                "max_tokens": 500,
                "messages": [
                    {
                        "role": "system",
                        "content": prompt
                    },
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{image_base64}",
                                    "detail": "low"
                                }
                            }
                        ]
                    }
                ]
            }
        })
        id2img[current_id] = image_filename

    logger.info(f"In total {len(tasks)} images to process")
    if not tasks:
        return None

    # Write the batch request file, one JSON object per line.
    batch_file_name = os.path.join(output_path, "image_batch_requests.jsonl")
    with open(batch_file_name, 'w', encoding='utf-8') as file:
        for obj in tasks:
            file.write(json.dumps(obj, ensure_ascii=False) + '\n')

    # Upload inside a context manager — the original leaked this file handle.
    with open(batch_file_name, "rb") as upload:
        batch_file = client.files.create(
            file=upload,
            purpose="batch"
        )
    batch_job = client.batches.create(
        input_file_id=batch_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h"
    )

    if wait_for_job_completion(client, batch_job.id):
        output_file_id = client.batches.retrieve(batch_job.id).output_file_id
        file_response = client.files.content(output_file_id)
        file_content_str = file_response.read().decode('utf-8')
        # BUG FIX: append ("a") instead of truncating ("w"). The original opened
        # with "w", discarding the very results whose IDs were just used to skip
        # completed images, so resumed runs lost all earlier captions.
        with open(results_path, "a", encoding='utf-8') as f:
            for line in file_content_str.splitlines():
                if not line.strip():
                    continue
                try:
                    result = json.loads(line)
                    image_id = result['custom_id']
                    caption = result['response']['body']['choices'][0]['message']['content']
                    output = json.dumps({
                        "custom_id": image_id,
                        "summary": caption,
                        "url": id2img[image_id]
                    })
                    f.write(output + '\n')
                except json.JSONDecodeError as error:
                    logger.error(f"Error parsing: {error} -- at line: {line}")
    else:
        logger.error("Job failed")
    # Returned even on failure, matching the original control flow.
    return results_path
async def process_lookbook_task(lookbook_list, tag, year):
    """Background async task: extract images from lookbook PDFs, caption them,
    and save the captions to the vector database.

    Args:
        lookbook_list: Paths to lookbook PDF files.
        tag: Tag stored in each image's vector-DB metadata.
        year: Year stored in each image's vector-DB metadata.

    Raises:
        Re-raises any exception after logging it.
    """
    image_list = []
    try:
        for look_book_path in tqdm.tqdm(lookbook_list):
            lookbook_name = os.path.splitext(os.path.basename(look_book_path))[0]
            output_dir = os.path.join("fashion_documents/lookbook/images", lookbook_name)
            os.makedirs(output_dir, exist_ok=True)
            if not os.listdir(output_dir):
                # Deferred import: `unstructured` is heavy and only needed when
                # the images haven't been extracted yet.
                from unstructured.partition.pdf import partition_pdf
                partition_pdf(
                    filename=look_book_path,
                    extract_images_in_pdf=True,
                    infer_table_structure=False,
                    chunking_strategy="by_title",
                    max_characters=4000,
                    new_after_n_chars=3800,
                    combine_text_under_n_chars=2000,
                    extract_image_block_output_dir=output_dir,
                )
            # BUG FIX: collect images unconditionally. The original extended
            # image_list only in the `else` branch, so images freshly extracted
            # by partition_pdf above were never captioned on the first run.
            image_list.extend(
                os.path.join(output_dir, name) for name in os.listdir(output_dir)
            )
        # 1. Caption the images via a batch request.
        image_description_results_file = create_image_batch_requests(image_list, "fashion_documents/lookbook/results")
        # 2. Persist the captions to the vector database.
        if image_description_results_file:
            save_to_vector_db(image_description_results_file, tag, year)
    except Exception as e:
        logger.error(f"Error processing lookbooks: {str(e)}")
        raise e
def save_to_vector_db(image_description_results_file, tag, year):
    """Load image captions from a JSONL file and store them in the vector DB.

    Each line of the input file is a JSON object with keys ``custom_id``,
    ``summary`` and ``url`` (as written by create_image_batch_requests).
    Duplicate custom_ids are skipped; the first occurrence wins.

    Args:
        image_description_results_file: Path to the results JSONL file.
        tag: Tag stored in each record's metadata.
        year: Year stored in each record's metadata.

    Raises:
        Re-raises any exception after logging it.
    """
    # BUG FIX: the original accumulated IDs in a set and passed list(set) as
    # `ids`. Sets are unordered, so the IDs could be misaligned with the
    # summaries/metadatas built in file order, storing texts under wrong IDs.
    # Keep an ordered list for the DB call and a set only for dedup checks.
    seen_ids = set()
    image_ids = []
    image_summaries = []
    image_metadatas = []
    try:
        with open(image_description_results_file, "r", encoding="utf-8") as f:
            for raw_line in f:
                record = json.loads(raw_line)
                custom_id = record["custom_id"]
                if custom_id in seen_ids:
                    continue
                seen_ids.add(custom_id)
                image_ids.append(custom_id)
                image_summaries.append(record["summary"])
                image_metadatas.append({
                    "data_type": "image",
                    "url": record["url"].replace("\\", "/"),  # normalize Windows separators
                    "source": "mitu",
                    "tag": tag,
                    "year": year,
                    "gender": "female"
                })
        # NOTE(review): `collection` is not defined or imported anywhere in this
        # file — presumably injected as a module global elsewhere; otherwise this
        # raises NameError. Confirm and add an explicit import.
        collection.add_texts(texts=image_summaries, metadatas=image_metadatas, ids=image_ids)
        logger.info("Successfully saved data to vector database")
    except Exception as e:
        logger.error(f"Error saving to vector database: {e}")
        raise  # bare raise preserves the original traceback