feat 新增 process lookbooks 接口

fix
This commit is contained in:
zhouchengrong
2024-10-21 11:01:28 +08:00
parent 3417bcb2ab
commit e88ba6994a
11 changed files with 371 additions and 6 deletions

View File

View File

@@ -0,0 +1,152 @@
from typing import Optional, Type, List, Dict, Any
import hashlib
import base64
import io
import os
from PIL import Image
import torch
from torchvision import transforms
def generate_text_id(text):
return hashlib.md5(text.encode('utf-8')).hexdigest()
def get_file_hash(file_path):
with open(file_path, 'rb') as f:
return hashlib.md5(f.read()).hexdigest()
def resize_image(img, max_size=512):
"""调整图片大小保持原始比例使长宽均不超过max_size并保存到output_path"""
width, height = img.size
if width > max_size or height > max_size:
# 保持图片的宽高比
if width > height:
new_width = max_size
new_height = int(max_size * height / width)
else:
new_height = max_size
new_width = int(max_size * width / height)
# 进行图片缩放
img_resized = img.resize((new_width, new_height), Image.LANCZOS)
else:
img_resized = img.copy() # 如果图片本身符合条件,则不进行缩放
return img_resized
def base64_encode_image(image_path):
try:
image_path = image_path.replace("\\", "/")
with Image.open(image_path) as img:
# 如果图像模式不是 RGB转换为 RGB
if img.mode != 'RGB':
img = img.convert('RGB')
# 获取当前图片的宽和高
img_resized = resize_image(img)
# 将PIL Image对象编码为base64字符串。
buffered = io.BytesIO()
img_resized.save(buffered, format="JPEG") # 保存调整后的图像到内存流
return base64.b64encode(buffered.getvalue()).decode("utf-8")
except FileNotFoundError as e:
print(f"File not found: {image_path}")
print(f"Absolute path: {os.path.abspath(image_path)}")
print(f"Current working directory: {os.getcwd()}")
return None
except Exception as e:
print(f"Error processing image {image_path}: {str(e)}")
return None
def get_possible_image_paths(path):
image_path = path.split(".")
# Image path might contain file extension (e.g. .jpg),
# In this case, we want the path without the extension
image_path = image_path if len(image_path) == 1 else image_path[:-1]
for ext in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
image_ext = ".".join(image_path) + ext
if os.path.isfile(image_ext):
path = image_ext
break
return path
def preprocess(image_urls: List[str]):
transform_pipeline = transforms.Compose([
# 缩放图像尺寸到 256x256
transforms.Resize((256, 256)),
# 从图像中心裁剪 224x224
transforms.CenterCrop(224),
# 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor
transforms.ToTensor(),
# 对图像进行标准化
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
])
image_data = []
if len(image_urls) == 0:
return None
for url in image_urls:
path = get_possible_image_paths(url)
try:
image = Image.open(path).convert("RGB")
except FileNotFoundError as e:
print(f"File not found: {path}, {e}")
continue
image_tensor = transform_pipeline(image)
image_data.append(image_tensor)
image_data = torch.stack(image_data, dim=0)
return image_data.numpy()
def is_image_accessible(path: str) -> bool:
if not os.path.exists(path):
return False
try:
with Image.open(path) as img:
img.verify() # 验证图像文件
return True
except (IOError, SyntaxError):
return False
def create_image_grid(image_paths, grid_size=(2, 2)):
# 打开所有图片
images = [Image.open(path) for path in image_paths]
# 确定每个图片的大小
max_width = max(img.width for img in images)
max_height = max(img.height for img in images)
# 创建一个新的空白图片
grid_width = grid_size[0] * max_width
grid_height = grid_size[1] * max_height
grid_image = Image.new('RGB', (grid_width, grid_height), color="white")
# 将图片粘贴到网格中
for i, img in enumerate(images):
x = (i % grid_size[0]) * max_width
y = (i // grid_size[0]) * max_height
# 计算缩放比例,保持原始宽高比
ratio = min(max_width / img.width, max_height / img.height)
new_size = (int(img.width * ratio), int(img.height * ratio))
# 调整大小并保持宽高比
img_resized = img.resize(new_size, Image.LANCZOS)
# 计算居中位置
paste_x = x + (max_width - new_size[0]) // 2
paste_y = y + (max_height - new_size[1]) // 2
# 粘贴到网格中
grid_image.paste(img_resized, (paste_x, paste_y))
# 将图片转换为字节流
img_byte_arr = io.BytesIO()
grid_image.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()
return img_byte_arr

View File

@@ -0,0 +1,31 @@
import time
def wait_for_job_completion(client, batch_job_id, timeout=6000, interval=10):
"""
等待批处理作业完成。
参数:
- client: 批处理作业客户端对象
- batch_job_id: 批处理作业的ID
- timeout: 最大等待时间(秒)
- interval: 检查状态的间隔时间(秒)
返回:
- True 如果作业完成False 如果超时
"""
start_time = time.time()
while time.time() - start_time < timeout:
batch_job = client.batches.retrieve(batch_job_id)
if batch_job.status == 'completed':
print("Batch job completed successfully.")
return True
elif batch_job.status == 'failed':
print("Batch job failed.")
return False
else:
print(f"Current status: {batch_job.status}. Waiting for completion...")
time.sleep(interval)
print("Timeout reached, job did not complete in time.")
return False