Files
FiDA-Gen-Img-Flux2-klein/litserve/model.py

84 lines
2.6 KiB
Python
Raw Normal View History

2026-03-18 10:12:06 +08:00
# model.py
import torch
from diffusers import FluxPipeline # 假设你用的是 flux from diffusers
from PIL import Image
import os
from typing import List, Optional, Dict, Any
class Flux2KleinModel:
    """Thin wrapper around a diffusers Flux pipeline.

    The pipeline is loaded once in ``__init__``. ``run_inference`` generates a
    single image from a conditioning image plus a prompt. ``batch_run`` loops
    over many inputs on a best-effort basis and saves each result to disk.
    """

    def __init__(
        self,
        model_id="black-forest-labs/FLUX.1-dev",
        device="cuda",
        *,
        pipeline_cls: Optional[Any] = None,
    ):
        """Load the pipeline in bfloat16 and enable model CPU offload.

        Args:
            model_id: Hugging Face repo id or local path passed to
                ``from_pretrained``.
            device: Device passed to ``enable_model_cpu_offload``. It is also
                the device the per-call sampling generator is created on.
            pipeline_cls: Pipeline class to load; anything exposing
                ``from_pretrained`` works. Defaults to ``FluxPipeline``. Pass
                an image-conditioned pipeline class matching ``model_id`` when
                the input image passed to ``run_inference`` must be used.
        """
        print(f"Loading Flux model on {device} ...")
        if pipeline_cls is None:
            pipeline_cls = FluxPipeline
        self.pipe = pipeline_cls.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        # Keep idle sub-models on CPU to save VRAM (use `.to(device)` instead
        # if memory allows). `device` is passed explicitly so the pipeline
        # executes on the same device the generator in `run_inference` is
        # created on; otherwise the offload hooks use their default device.
        self.pipe.enable_model_cpu_offload(device=device)
        # Optional, further reduces VRAM: self.pipe.vae.enable_slicing()
        self.device = device

    def preprocess_image(self, image_path: str) -> "Image.Image":
        """Open ``image_path`` and return it converted to RGB.

        The file handle is closed before returning. ``convert`` produces a new,
        fully loaded image, so the result stays usable after the ``with`` block.
        """
        with Image.open(image_path) as img:
            return img.convert("RGB")

    @staticmethod
    def _resolve_size(
        aspect_ratio: str, base_long_edge: int, multiple_of: int = 8
    ) -> tuple:
        """Return ``(width, height)`` whose longer edge is ``base_long_edge``.

        Args:
            aspect_ratio: ``"W:H"`` with integer or decimal parts, e.g.
                ``"16:9"`` or ``"2.35:1"``. A string without ``":"`` is
                treated as square.
            base_long_edge: Length in pixels of the longer edge, before
                rounding.
            multiple_of: Both edges are rounded down to a multiple of this
                value, but never below it.

        Raises:
            ValueError: if ``aspect_ratio`` contains ``":"`` but is not two
                positive numbers.
        """
        if ":" in aspect_ratio:
            w_str, _, h_str = aspect_ratio.partition(":")
            try:
                w, h = float(w_str), float(h_str)
            except ValueError:
                raise ValueError(
                    f"Invalid aspect_ratio {aspect_ratio!r}; expected 'W:H'"
                ) from None
            # `not (> 0)` also rejects NaN and avoids division by zero below.
            if not (w > 0 and h > 0):
                raise ValueError(
                    f"Invalid aspect_ratio {aspect_ratio!r}; "
                    "both sides must be positive"
                )
            ratio = w / h
        else:
            ratio = 1.0

        if ratio >= 1:
            width = base_long_edge
            height = int(base_long_edge / ratio)
        else:
            height = base_long_edge
            width = int(base_long_edge * ratio)

        # Round down to a multiple of `multiple_of` (8 is a common Flux
        # requirement). Clamp so extreme ratios never yield a 0-pixel edge.
        # NOTE(review): diffusers' Flux pipelines may snap further to multiples
        # of 16 internally; confirm for the pipeline in use.
        width = max(multiple_of, (width // multiple_of) * multiple_of)
        height = max(multiple_of, (height // multiple_of) * multiple_of)
        return width, height

    def run_inference(
        self,
        image: "Image.Image",
        prompt: str,
        aspect_ratio: str = "1:1",
        base_long_edge: int = 1024,
        steps: int = 28,
        guidance: float = 4.0,
        seed: int = 0,
        **kwargs,
    ) -> "Image.Image":
        """Generate one image conditioned on ``image`` and ``prompt``.

        Args:
            image: Conditioning image, forwarded as the pipeline's ``image``
                keyword.
            prompt: Text prompt.
            aspect_ratio: Output aspect ratio as ``"W:H"``; see
                ``_resolve_size``.
            base_long_edge: Longer output edge in pixels, before rounding.
            steps: Forwarded as ``num_inference_steps``.
            guidance: Forwarded as ``guidance_scale``.
            seed: Seed for a fresh ``torch.Generator`` on ``self.device``,
                created anew on every call.
            **kwargs: Extra keyword arguments forwarded to the pipeline call.

        Returns:
            The first image in the pipeline output's ``images``.

        Raises:
            ValueError: for a malformed ``aspect_ratio``. Errors raised by the
                pipeline propagate unchanged.
        """
        width, height = self._resolve_size(aspect_ratio, base_long_edge)
        generator = torch.Generator(device=self.device).manual_seed(seed)
        # NOTE(review): diffusers' text-to-image `FluxPipeline.__call__` does
        # not appear to accept an `image` argument, so with the default
        # pipeline class this call may raise TypeError. Confirm against the
        # installed diffusers version; load a matching pipeline via
        # `pipeline_cls` if needed.
        output = self.pipe(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            **kwargs,
        )
        return output.images[0]

    def batch_run(
        self,
        image_paths: List[str],
        prompts: List[str],
        out_paths: List[str],
        **common_kwargs,
    ):
        """Run inference for each (image, prompt) pair and save each result.

        Best-effort: an error on one item is printed and the loop continues
        with the next item.

        Args:
            image_paths: Input image files, one per item.
            prompts: Prompts, aligned with ``image_paths``.
            out_paths: Output files, aligned with ``image_paths``. Missing
                parent directories are created.
            **common_kwargs: Forwarded to ``run_inference`` for every item.

        Raises:
            ValueError: if the three lists differ in length.
        """
        if not (len(image_paths) == len(prompts) == len(out_paths)):
            raise ValueError(
                "image_paths, prompts and out_paths must have the same length "
                f"(got {len(image_paths)}, {len(prompts)}, {len(out_paths)})"
            )
        for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
            try:
                img = self.preprocess_image(img_path)
                result = self.run_inference(img, prompt, **common_kwargs)
                out_dir = os.path.dirname(out_path)
                # dirname is "" for a bare filename, and os.makedirs("") raises.
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                result.save(out_path)
            except Exception as e:
                # Deliberately best-effort: report the failure and keep going.
                print(f"Failed {img_path}: {e}")
            else:
                print(f"Saved: {out_path}")