84 lines
2.6 KiB
Python
84 lines
2.6 KiB
Python
# model.py
|
||
import torch
|
||
from diffusers import FluxPipeline # 假设你用的是 flux from diffusers
|
||
from PIL import Image
|
||
import os
|
||
from typing import List, Optional, Dict, Any
|
||
|
||
class Flux2KleinModel:
    """Thin wrapper around diffusers' FluxPipeline for single-image and
    batch image generation.

    PIL annotations are quoted (lazy) so the class can be defined without
    evaluating ``Image`` at class-creation time.
    """

    def __init__(self, model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
        """Load the Flux pipeline in bfloat16.

        Args:
            model_id: Hugging Face model id to load.
            device: Device string; used for the RNG in ``run_inference``.
        """
        print(f"Loading Flux model on {device} ...")
        self.pipe = FluxPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        # Offload submodules to CPU between uses; switch to .to(device)
        # instead when there is enough GPU memory.
        self.pipe.enable_model_cpu_offload()
        # self.pipe.vae.enable_slicing()  # optional, saves VRAM
        self.device = device

    def preprocess_image(self, image_path: str) -> "Image.Image":
        """Open *image_path* and return it converted to RGB."""
        img = Image.open(image_path).convert("RGB")
        return img

    @staticmethod
    def _compute_dims(aspect_ratio: str, base_long_edge: int) -> "tuple[int, int]":
        """Map a ``"W:H"`` aspect-ratio string to pixel ``(width, height)``.

        The longer edge equals *base_long_edge*; both edges are rounded
        down to a multiple of 8 (a common Flux requirement — TODO confirm
        whether the checkpoint in use needs multiples of 16 instead).
        A string without a colon falls back to square output.

        Raises:
            ValueError: if the ``"W:H"`` parts are not integers.
            ZeroDivisionError: if H is 0.
        """
        if ":" in aspect_ratio:
            w, h = map(int, aspect_ratio.split(":"))
            ratio = w / h
        else:
            ratio = 1.0

        if ratio >= 1:
            width = base_long_edge
            height = int(base_long_edge / ratio)
        else:
            height = base_long_edge
            width = int(base_long_edge * ratio)

        return (width // 8) * 8, (height // 8) * 8

    def run_inference(
        self,
        image: "Image.Image",
        prompt: str,
        aspect_ratio: str = "1:1",
        base_long_edge: int = 1024,
        steps: int = 28,
        guidance: float = 4.0,
        seed: int = 0,
        **kwargs,
    ) -> "Image.Image":
        """Run one generation pass and return the resulting PIL image.

        Args:
            image: Input PIL image (passed through to the pipeline).
            prompt: Text prompt.
            aspect_ratio: ``"W:H"`` string; see ``_compute_dims``.
            base_long_edge: Pixel length of the longer output edge.
            steps: Number of denoising steps.
            guidance: Classifier-free guidance scale.
            seed: RNG seed for deterministic sampling.
            **kwargs: Extra keyword args forwarded to the pipeline call.
        """
        width, height = self._compute_dims(aspect_ratio, base_long_edge)

        # Deterministic sampling for a given seed.
        generator = torch.Generator(device=self.device).manual_seed(seed)

        # NOTE(review): the base FluxPipeline is text-to-image and its
        # __call__ does not document an `image` parameter — if image
        # conditioning is intended this probably needs FluxImg2ImgPipeline.
        # Confirm against the installed diffusers version.
        result = self.pipe(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            **kwargs,
        ).images[0]

        return result

    def batch_run(
        self,
        image_paths: List[str],
        prompts: List[str],
        out_paths: List[str],
        **common_kwargs,
    ):
        """Run inference per (image, prompt) pair and save each result.

        Best-effort: a failure on one item is printed and the batch
        continues. The three lists must have equal length.
        """
        assert len(image_paths) == len(prompts) == len(out_paths)

        for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
            try:
                img = self.preprocess_image(img_path)
                result = self.run_inference(img, prompt, **common_kwargs)
                # Bugfix: os.makedirs("") raises FileNotFoundError when
                # out_path has no directory component — create the parent
                # directory only if there is one.
                out_dir = os.path.dirname(out_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                result.save(out_path)
                print(f"Saved: {out_path}")
            except Exception as e:
                # Deliberate best-effort batch: report and continue.
                print(f"Failed {img_path}: {e}")