153 lines
5.0 KiB
Python
153 lines
5.0 KiB
Python
|
|
from typing import Optional, Type, List, Dict, Any
|
|||
|
|
import hashlib
|
|||
|
|
import base64
|
|||
|
|
import io
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
from PIL import Image
|
|||
|
|
import torch
|
|||
|
|
from torchvision import transforms
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_text_id(text):
    """Return a deterministic identifier for ``text``.

    The identifier is the hexadecimal MD5 digest of the UTF-8 encoded text,
    so equal strings always produce the same 32-character ID.
    """
    hasher = hashlib.md5()
    hasher.update(text.encode('utf-8'))
    return hasher.hexdigest()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_file_hash(file_path, chunk_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at ``file_path``.

    The file is hashed in fixed-size chunks so that large files are never
    loaded into memory in one piece.

    Args:
        file_path: Path of the file to hash.
        chunk_size: Number of bytes read per iteration.

    Returns:
        The 32-character hexadecimal MD5 digest of the file contents.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    hasher = hashlib.md5()
    with open(file_path, 'rb') as f:
        # iter() with a sentinel stops at the empty bytes object read() returns at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            hasher.update(chunk)
    return hasher.hexdigest()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resize_image(img, max_size=512):
    """Return a copy of ``img`` whose width and height are both at most ``max_size``.

    The aspect ratio is preserved. An image that already fits is returned as
    an unscaled copy, so the caller always receives a new object and the
    input image is never modified. Nothing is written to disk.

    Args:
        img: Image object exposing ``size``, ``resize`` and ``copy``
            (a PIL image in this file).
        max_size: Upper bound, in pixels, for both width and height.

    Returns:
        A new image no larger than ``max_size`` on either side.
    """
    width, height = img.size
    if width <= max_size and height <= max_size:
        # Already within bounds: hand back a copy without scaling.
        return img.copy()

    # Pin the longer side to max_size and scale the other proportionally.
    # The shorter side is clamped to 1 because truncation yields 0 for very
    # elongated images, and resizing to a zero dimension is an error.
    if width > height:
        new_width = max_size
        new_height = max(1, int(max_size * height / width))
    else:
        new_height = max_size
        new_width = max(1, int(max_size * width / height))

    return img.resize((new_width, new_height), Image.LANCZOS)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def base64_encode_image(image_path):
    """Load an image, downscale it and return it as a base64-encoded JPEG string.

    Backslashes in ``image_path`` are replaced with forward slashes before
    the file is opened. The image is converted to RGB when needed, passed
    through ``resize_image`` and serialized as JPEG into memory.

    Returns:
        The base64 text (UTF-8 ``str``) of the JPEG bytes, or ``None`` when
        the image cannot be processed. Failures are reported with ``print``
        and are never raised to the caller.
    """
    try:
        image_path = image_path.replace("\\", "/")

        with Image.open(image_path) as source:
            # JPEG output below requires RGB; convert any other mode first.
            rgb_image = source if source.mode == 'RGB' else source.convert('RGB')
            # Shrink the image (aspect ratio preserved) before encoding.
            shrunk = resize_image(rgb_image)
            # Serialize the resized image into an in-memory JPEG stream.
            stream = io.BytesIO()
            shrunk.save(stream, format="JPEG")
            jpeg_bytes = stream.getvalue()
            return base64.b64encode(jpeg_bytes).decode("utf-8")
    except FileNotFoundError:
        # Print extra context to help diagnose relative-path problems.
        print(f"File not found: {image_path}")
        print(f"Absolute path: {os.path.abspath(image_path)}")
        print(f"Current working directory: {os.getcwd()}")
        return None
    except Exception as exc:
        # Best-effort helper: report any other failure and signal it with None.
        print(f"Error processing image {image_path}: {exc!s}")
        return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_possible_image_paths(path):
    """Resolve ``path`` to an existing image file by trying known extensions.

    Any extension on ``path`` is dropped, then each supported image
    extension is appended in turn and the first candidate that exists as a
    file is returned.

    Args:
        path: Image path, with or without a file extension.

    Returns:
        The first existing ``<path without extension><ext>`` candidate, or
        ``path`` unchanged when no candidate exists.
    """
    # splitext only strips an extension from the final path component, so
    # dots elsewhere in the path ("./images/cat", "data.v2/cat") are kept.
    stem, _ = os.path.splitext(path)
    for ext in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
        candidate = stem + ext
        if os.path.isfile(candidate):
            return candidate
    return path
|
|||
|
|
|
|||
|
|
|
|||
|
|
def preprocess(image_urls: List[str]):
    """Load images and turn them into one normalized NumPy batch.

    Each entry is resolved with ``get_possible_image_paths``, opened,
    converted to RGB, resized to 256x256, center-cropped to 224x224,
    converted to a float tensor and normalized per channel. Files that
    cannot be found are reported with ``print`` and skipped.

    Args:
        image_urls: Image paths, with or without file extensions.

    Returns:
        A NumPy array with one ``3 x 224 x 224`` entry per loaded image,
        stacked along axis 0, or ``None`` when ``image_urls`` is empty or no
        image could be loaded.
    """
    if not image_urls:
        return None

    transform_pipeline = transforms.Compose([
        # Resize the image to 256x256.
        transforms.Resize((256, 256)),
        # Crop the central 224x224 region.
        transforms.CenterCrop(224),
        # Convert the PIL image to a FloatTensor.
        transforms.ToTensor(),
        # Normalize each channel.
        # NOTE(review): these constants look like the CLIP preprocessing
        # mean/std - confirm against the model that consumes this batch.
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711])
    ])

    image_data = []
    for url in image_urls:
        path = get_possible_image_paths(url)
        try:
            # convert() returns an in-memory image, so the file handle can be
            # released as soon as the conversion is done.
            with Image.open(path) as opened:
                image = opened.convert("RGB")
        except FileNotFoundError as e:
            print(f"File not found: {path}, {e}")
            continue
        image_data.append(transform_pipeline(image))

    # torch.stack rejects an empty list; when every file was missing, report
    # "no data" the same way as for empty input.
    if not image_data:
        return None
    return torch.stack(image_data, dim=0).numpy()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def is_image_accessible(path: str) -> bool:
    """Tell whether ``path`` exists and passes PIL's image verification.

    Returns:
        ``True`` when the path exists and ``Image.open`` followed by
        ``verify()`` succeeds; ``False`` when the path is missing or either
        call raises ``IOError``/``SyntaxError``.
    """
    if os.path.exists(path):
        try:
            with Image.open(path) as candidate:
                candidate.verify()  # Validate the image file.
        except (OSError, SyntaxError):  # IOError is an alias of OSError.
            return False
        return True
    return False
|
|||
|
|
|
|||
|
|
def create_image_grid(image_paths, grid_size=(2, 2)):
    """Compose several images into one white-background grid, returned as PNG bytes.

    Every cell is as large as the widest and tallest input image. Each image
    is scaled to fit its cell with its aspect ratio preserved and is centered
    inside the cell. Cells are filled row by row, left to right.

    Args:
        image_paths: Iterable of paths of the images to place in the grid.
        grid_size: ``(columns, rows)`` of the grid. Images beyond
            ``columns * rows`` are positioned below the canvas.

    Returns:
        The PNG-encoded grid image as ``bytes``.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    paths = list(image_paths)
    if not paths:
        raise ValueError("image_paths must contain at least one image path")

    images = []
    try:
        # Open all images; the finally block closes whatever was opened, even
        # if a later path fails to open.
        for path in paths:
            images.append(Image.open(path))

        # The cell size is dictated by the largest width and height.
        max_width = max(img.width for img in images)
        max_height = max(img.height for img in images)

        # Create the blank canvas.
        grid_width = grid_size[0] * max_width
        grid_height = grid_size[1] * max_height
        grid_image = Image.new('RGB', (grid_width, grid_height), color="white")

        # Paste every image into its cell.
        for i, img in enumerate(images):
            x = (i % grid_size[0]) * max_width
            y = (i // grid_size[0]) * max_height

            # Scale factor that fits the image in the cell, keeping its aspect ratio.
            ratio = min(max_width / img.width, max_height / img.height)
            new_size = (int(img.width * ratio), int(img.height * ratio))
            img_resized = img.resize(new_size, Image.LANCZOS)

            # Center the resized image inside its cell.
            paste_x = x + (max_width - new_size[0]) // 2
            paste_y = y + (max_height - new_size[1]) // 2
            grid_image.paste(img_resized, (paste_x, paste_y))

        # Serialize the grid to PNG bytes.
        img_byte_arr = io.BytesIO()
        grid_image.save(img_byte_arr, format='PNG')
        return img_byte_arr.getvalue()
    finally:
        # Release the file handles held by the opened source images.
        for img in images:
            img.close()
|