84 lines
2.6 KiB
Python
84 lines
2.6 KiB
Python
|
|
# model.py
|
|||
|
|
import torch
|
|||
|
|
from diffusers import FluxPipeline # 假设你用的是 flux from diffusers
|
|||
|
|
from PIL import Image
|
|||
|
|
import os
|
|||
|
|
from typing import List, Optional, Dict, Any
|
|||
|
|
|
|||
|
|
class Flux2KleinModel:
|
|||
|
|
def __init__(self, model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
|
|||
|
|
print(f"Loading Flux model on {device} ...")
|
|||
|
|
self.pipe = FluxPipeline.from_pretrained(
|
|||
|
|
model_id,
|
|||
|
|
torch_dtype=torch.bfloat16
|
|||
|
|
)
|
|||
|
|
self.pipe.enable_model_cpu_offload() # 或 .to(device) 看显存情况
|
|||
|
|
# self.pipe.vae.enable_slicing() # 可选,省显存
|
|||
|
|
self.device = device
|
|||
|
|
|
|||
|
|
def preprocess_image(self, image_path: str) -> Image.Image:
|
|||
|
|
img = Image.open(image_path).convert("RGB")
|
|||
|
|
return img
|
|||
|
|
|
|||
|
|
def run_inference(
|
|||
|
|
self,
|
|||
|
|
image: Image.Image,
|
|||
|
|
prompt: str,
|
|||
|
|
aspect_ratio: str = "1:1",
|
|||
|
|
base_long_edge: int = 1024,
|
|||
|
|
steps: int = 28,
|
|||
|
|
guidance: float = 4.0,
|
|||
|
|
seed: int = 0,
|
|||
|
|
**kwargs
|
|||
|
|
) -> Image.Image:
|
|||
|
|
# 根据 aspect_ratio 计算 width/height(示例实现,可按需修改)
|
|||
|
|
if ":" in aspect_ratio:
|
|||
|
|
w, h = map(int, aspect_ratio.split(":"))
|
|||
|
|
ratio = w / h
|
|||
|
|
else:
|
|||
|
|
ratio = 1.0
|
|||
|
|
|
|||
|
|
if ratio >= 1:
|
|||
|
|
width = base_long_edge
|
|||
|
|
height = int(base_long_edge / ratio)
|
|||
|
|
else:
|
|||
|
|
height = base_long_edge
|
|||
|
|
width = int(base_long_edge * ratio)
|
|||
|
|
|
|||
|
|
# 取整到 8 的倍数(flux 常见要求)
|
|||
|
|
width = (width // 8) * 8
|
|||
|
|
height = (height // 8) * 8
|
|||
|
|
|
|||
|
|
generator = torch.Generator(device=self.device).manual_seed(seed)
|
|||
|
|
|
|||
|
|
result = self.pipe(
|
|||
|
|
prompt=prompt,
|
|||
|
|
image=image,
|
|||
|
|
height=height,
|
|||
|
|
width=width,
|
|||
|
|
num_inference_steps=steps,
|
|||
|
|
guidance_scale=guidance,
|
|||
|
|
generator=generator,
|
|||
|
|
**kwargs
|
|||
|
|
).images[0]
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
def batch_run(
|
|||
|
|
self,
|
|||
|
|
image_paths: List[str],
|
|||
|
|
prompts: List[str],
|
|||
|
|
out_paths: List[str],
|
|||
|
|
**common_kwargs
|
|||
|
|
):
|
|||
|
|
assert len(image_paths) == len(prompts) == len(out_paths)
|
|||
|
|
|
|||
|
|
for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
|
|||
|
|
try:
|
|||
|
|
img = self.preprocess_image(img_path)
|
|||
|
|
result = self.run_inference(img, prompt, **common_kwargs)
|
|||
|
|
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
|||
|
|
result.save(out_path)
|
|||
|
|
print(f"Saved: {out_path}")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"Failed {img_path}: {e}")
|