feat 新增 process lookbooks 接口
fix
This commit is contained in:
0
app/service/lookbooks/utils/__init__.py
Normal file
0
app/service/lookbooks/utils/__init__.py
Normal file
152
app/service/lookbooks/utils/image_utils.py
Normal file
152
app/service/lookbooks/utils/image_utils.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import Optional, Type, List, Dict, Any
|
||||
import hashlib
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def generate_text_id(text):
|
||||
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
def get_file_hash(file_path):
|
||||
with open(file_path, 'rb') as f:
|
||||
return hashlib.md5(f.read()).hexdigest()
|
||||
|
||||
|
||||
def resize_image(img, max_size=512):
|
||||
"""调整图片大小,保持原始比例,使长宽均不超过max_size,并保存到output_path"""
|
||||
width, height = img.size
|
||||
if width > max_size or height > max_size:
|
||||
# 保持图片的宽高比
|
||||
if width > height:
|
||||
new_width = max_size
|
||||
new_height = int(max_size * height / width)
|
||||
else:
|
||||
new_height = max_size
|
||||
new_width = int(max_size * width / height)
|
||||
|
||||
# 进行图片缩放
|
||||
img_resized = img.resize((new_width, new_height), Image.LANCZOS)
|
||||
else:
|
||||
img_resized = img.copy() # 如果图片本身符合条件,则不进行缩放
|
||||
|
||||
return img_resized
|
||||
|
||||
|
||||
def base64_encode_image(image_path):
|
||||
try:
|
||||
image_path = image_path.replace("\\", "/")
|
||||
|
||||
with Image.open(image_path) as img:
|
||||
# 如果图像模式不是 RGB,转换为 RGB
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
# 获取当前图片的宽和高
|
||||
img_resized = resize_image(img)
|
||||
# 将PIL Image对象编码为base64字符串。
|
||||
buffered = io.BytesIO()
|
||||
img_resized.save(buffered, format="JPEG") # 保存调整后的图像到内存流
|
||||
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
except FileNotFoundError as e:
|
||||
print(f"File not found: {image_path}")
|
||||
print(f"Absolute path: {os.path.abspath(image_path)}")
|
||||
print(f"Current working directory: {os.getcwd()}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error processing image {image_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def get_possible_image_paths(path):
|
||||
image_path = path.split(".")
|
||||
# Image path might contain file extension (e.g. .jpg),
|
||||
# In this case, we want the path without the extension
|
||||
image_path = image_path if len(image_path) == 1 else image_path[:-1]
|
||||
for ext in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
|
||||
image_ext = ".".join(image_path) + ext
|
||||
if os.path.isfile(image_ext):
|
||||
path = image_ext
|
||||
break
|
||||
return path
|
||||
|
||||
|
||||
def preprocess(image_urls: List[str]):
|
||||
transform_pipeline = transforms.Compose([
|
||||
# 缩放图像尺寸到 256x256
|
||||
transforms.Resize((256, 256)),
|
||||
# 从图像中心裁剪 224x224
|
||||
transforms.CenterCrop(224),
|
||||
# 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor
|
||||
transforms.ToTensor(),
|
||||
# 对图像进行标准化
|
||||
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
std=[0.26862954, 0.26130258, 0.27577711])
|
||||
])
|
||||
image_data = []
|
||||
if len(image_urls) == 0:
|
||||
return None
|
||||
for url in image_urls:
|
||||
path = get_possible_image_paths(url)
|
||||
try:
|
||||
image = Image.open(path).convert("RGB")
|
||||
except FileNotFoundError as e:
|
||||
print(f"File not found: {path}, {e}")
|
||||
continue
|
||||
image_tensor = transform_pipeline(image)
|
||||
image_data.append(image_tensor)
|
||||
image_data = torch.stack(image_data, dim=0)
|
||||
return image_data.numpy()
|
||||
|
||||
|
||||
def is_image_accessible(path: str) -> bool:
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
try:
|
||||
with Image.open(path) as img:
|
||||
img.verify() # 验证图像文件
|
||||
return True
|
||||
except (IOError, SyntaxError):
|
||||
return False
|
||||
|
||||
def create_image_grid(image_paths, grid_size=(2, 2)):
|
||||
# 打开所有图片
|
||||
images = [Image.open(path) for path in image_paths]
|
||||
|
||||
# 确定每个图片的大小
|
||||
max_width = max(img.width for img in images)
|
||||
max_height = max(img.height for img in images)
|
||||
|
||||
# 创建一个新的空白图片
|
||||
grid_width = grid_size[0] * max_width
|
||||
grid_height = grid_size[1] * max_height
|
||||
grid_image = Image.new('RGB', (grid_width, grid_height), color="white")
|
||||
|
||||
# 将图片粘贴到网格中
|
||||
for i, img in enumerate(images):
|
||||
x = (i % grid_size[0]) * max_width
|
||||
y = (i // grid_size[0]) * max_height
|
||||
|
||||
# 计算缩放比例,保持原始宽高比
|
||||
ratio = min(max_width / img.width, max_height / img.height)
|
||||
new_size = (int(img.width * ratio), int(img.height * ratio))
|
||||
|
||||
# 调整大小并保持宽高比
|
||||
img_resized = img.resize(new_size, Image.LANCZOS)
|
||||
|
||||
# 计算居中位置
|
||||
paste_x = x + (max_width - new_size[0]) // 2
|
||||
paste_y = y + (max_height - new_size[1]) // 2
|
||||
|
||||
# 粘贴到网格中
|
||||
grid_image.paste(img_resized, (paste_x, paste_y))
|
||||
|
||||
# 将图片转换为字节流
|
||||
img_byte_arr = io.BytesIO()
|
||||
grid_image.save(img_byte_arr, format='PNG')
|
||||
img_byte_arr = img_byte_arr.getvalue()
|
||||
return img_byte_arr
|
||||
31
app/service/lookbooks/utils/openai_utils.py
Normal file
31
app/service/lookbooks/utils/openai_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import time
|
||||
|
||||
|
||||
def wait_for_job_completion(client, batch_job_id, timeout=6000, interval=10):
|
||||
"""
|
||||
等待批处理作业完成。
|
||||
|
||||
参数:
|
||||
- client: 批处理作业客户端对象
|
||||
- batch_job_id: 批处理作业的ID
|
||||
- timeout: 最大等待时间(秒)
|
||||
- interval: 检查状态的间隔时间(秒)
|
||||
|
||||
返回:
|
||||
- True 如果作业完成,False 如果超时
|
||||
"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
batch_job = client.batches.retrieve(batch_job_id)
|
||||
if batch_job.status == 'completed':
|
||||
print("Batch job completed successfully.")
|
||||
return True
|
||||
elif batch_job.status == 'failed':
|
||||
print("Batch job failed.")
|
||||
return False
|
||||
else:
|
||||
print(f"Current status: {batch_job.status}. Waiting for completion...")
|
||||
time.sleep(interval)
|
||||
|
||||
print("Timeout reached, job did not complete in time.")
|
||||
return False
|
||||
Reference in New Issue
Block a user