Files
FiDA-Gen-Img-Flux2-klein/litserve/model.py
zcr d89a176b63
Some checks failed
CI / lint (push) Has been cancelled
1
2026-03-18 10:12:06 +08:00

84 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# model.py
import torch
from diffusers import FluxPipeline # 假设你用的是 flux from diffusers
from PIL import Image
import os
from typing import List, Optional, Dict, Any
class Flux2KleinModel:
    """Thin wrapper around a diffusers ``FluxPipeline`` for single and batch generation.

    NOTE(review): the default ``model_id`` is FLUX.1-dev although the class is named
    after FLUX.2 klein -- confirm which checkpoint is intended.
    """

    def __init__(self, model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
        """Load the pipeline from ``model_id`` in bfloat16 with model CPU offload enabled.

        Args:
            model_id: Hub repo id or local path passed to ``FluxPipeline.from_pretrained``.
            device: Device on which ``run_inference`` creates its seeded
                ``torch.Generator``. NOTE(review): it is not passed to
                ``enable_model_cpu_offload``; confirm the offload target matches it.
        """
        print(f"Loading Flux model on {device} ...")
        self.pipe = FluxPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        # Offload saves VRAM; use `self.pipe.to(device)` instead when VRAM allows.
        self.pipe.enable_model_cpu_offload()
        # self.pipe.vae.enable_slicing()  # optional: further reduces VRAM use
        self.device = device

    def preprocess_image(self, image_path: str) -> "Image.Image":
        """Open the image at ``image_path`` and return an RGB copy of it.

        The file is opened in a ``with`` block so its handle is closed right away;
        ``convert`` returns a new image, so the result stays valid afterwards.
        """
        with Image.open(image_path) as img:
            return img.convert("RGB")

    @staticmethod
    def _resolve_size(aspect_ratio: str, base_long_edge: int) -> "tuple[int, int]":
        """Return ``(width, height)`` for ``aspect_ratio`` with the long side at ``base_long_edge``.

        ``aspect_ratio`` is ``"W:H"`` with integer parts (e.g. ``"16:9"``); a string
        without ``":"`` falls back to 1:1. Both sides are rounded down to a multiple of 8.

        Raises:
            ValueError: If ``aspect_ratio`` is malformed (non-integer parts, more than
                one ``":"``), has a non-positive part, or a side rounds down to zero.
        """
        if ":" in aspect_ratio:
            w, h = map(int, aspect_ratio.split(":"))
            # Reject zero/negative parts up front instead of dividing by zero or
            # handing a 0-pixel size to the pipeline.
            if w <= 0 or h <= 0:
                raise ValueError(
                    f"aspect_ratio parts must be positive integers, got {aspect_ratio!r}"
                )
            ratio = w / h
        else:
            ratio = 1.0

        if ratio >= 1:
            width, height = base_long_edge, int(base_long_edge / ratio)
        else:
            width, height = int(base_long_edge * ratio), base_long_edge

        # Round down to a multiple of 8 (a common Flux size requirement).
        width = (width // 8) * 8
        height = (height // 8) * 8
        if width <= 0 or height <= 0:
            raise ValueError(
                f"aspect_ratio {aspect_ratio!r} with base_long_edge={base_long_edge} "
                f"yields an empty image size ({width}x{height})"
            )
        return width, height

    def run_inference(
        self,
        image: "Image.Image",
        prompt: str,
        aspect_ratio: str = "1:1",
        base_long_edge: int = 1024,
        steps: int = 28,
        guidance: float = 4.0,
        seed: int = 0,
        **kwargs
    ) -> "Image.Image":
        """Run the pipeline once and return the first generated image.

        Args:
            image: Forwarded to the pipeline as ``image=``. NOTE(review): diffusers'
                ``FluxPipeline`` is a text-to-image pipeline; confirm the pipeline in
                use accepts an ``image`` argument (an image-conditioned pipeline
                class may be intended).
            prompt: Text prompt.
            aspect_ratio: ``"W:H"`` string; see ``_resolve_size``.
            base_long_edge: Length in pixels of the longer side, before rounding.
            steps: Passed as ``num_inference_steps``.
            guidance: Passed as ``guidance_scale``.
            seed: Seed for the ``torch.Generator`` created on ``self.device``.
            **kwargs: Extra keyword arguments forwarded unchanged to the pipeline call.

        Returns:
            ``images[0]`` of the pipeline output.

        Raises:
            ValueError: If ``aspect_ratio`` is invalid (see ``_resolve_size``).
        """
        width, height = self._resolve_size(aspect_ratio, base_long_edge)
        generator = torch.Generator(device=self.device).manual_seed(seed)
        return self.pipe(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            **kwargs
        ).images[0]

    def batch_run(
        self,
        image_paths: List[str],
        prompts: List[str],
        out_paths: List[str],
        **common_kwargs
    ) -> List[str]:
        """Process each ``(image_path, prompt, out_path)`` triple and save the results.

        Each item is best-effort: any exception is printed and the loop moves on.

        Args:
            image_paths: Input image files, opened with ``preprocess_image``.
            prompts: One prompt per input image.
            out_paths: Destination file per input; missing parent directories are created.
            **common_kwargs: Forwarded to every ``run_inference`` call.

        Returns:
            The entries of ``image_paths`` whose processing failed (empty on full success).

        Raises:
            ValueError: If the three lists differ in length.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (len(image_paths) == len(prompts) == len(out_paths)):
            raise ValueError(
                "image_paths, prompts and out_paths must have the same length, got "
                f"{len(image_paths)}, {len(prompts)} and {len(out_paths)}"
            )
        failed: List[str] = []
        for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
            try:
                img = self.preprocess_image(img_path)
                result = self.run_inference(img, prompt, **common_kwargs)
                # A bare filename has no directory part; os.makedirs("") would raise
                # and discard an image that was generated successfully.
                out_dir = os.path.dirname(out_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                result.save(out_path)
            except Exception as e:  # deliberate best-effort: report and continue
                print(f"Failed {img_path}: {e}")
                failed.append(img_path)
            else:
                print(f"Saved: {out_path}")
        return failed