# model.py
"""Convenience wrapper around the diffusers ``FluxPipeline`` for
image-conditioned generation: single-image inference plus a best-effort
batch driver that saves results to disk."""

import os
from typing import Any, Dict, List, Optional

import torch
from diffusers import FluxPipeline  # Flux image-generation pipeline from diffusers
from PIL import Image


class Flux2KleinModel:
    """Load a Flux pipeline once and expose simple inference helpers.

    Attributes:
        pipe: The loaded ``FluxPipeline`` (CPU-offloaded to save VRAM).
        device: Torch device string used to seed the RNG generator.
    """

    def __init__(self, model_id: str = "black-forest-labs/FLUX.1-dev", device: str = "cuda"):
        """Download/load the pipeline weights in bfloat16.

        Args:
            model_id: Hugging Face model repository id.
            device: Device string for the random generator (e.g. ``"cuda"``).
        """
        print(f"Loading Flux model on {device} ...")
        self.pipe = FluxPipeline.from_pretrained(
            model_id, torch_dtype=torch.bfloat16
        )
        # Offload sub-modules to CPU between forward passes; switch to
        # .to(device) instead if there is enough VRAM to keep everything resident.
        self.pipe.enable_model_cpu_offload()
        # self.pipe.vae.enable_slicing()  # optional: further VRAM savings
        self.device = device

    def preprocess_image(self, image_path: str) -> Image.Image:
        """Open *image_path* and normalize it to RGB mode."""
        return Image.open(image_path).convert("RGB")

    @staticmethod
    def _resolve_size(aspect_ratio: str, base_long_edge: int) -> tuple:
        """Map a ``"W:H"`` aspect-ratio string to a ``(width, height)`` pixel pair.

        The long edge equals *base_long_edge*; both sides are snapped down to
        multiples of 8 (a common Flux/VAE alignment requirement).

        Raises:
            ValueError: If the ratio string contains a non-positive component
                (e.g. ``"1:0"``), instead of an opaque ZeroDivisionError.
        """
        if ":" in aspect_ratio:
            w, h = map(int, aspect_ratio.split(":"))
            if w <= 0 or h <= 0:
                raise ValueError(f"Invalid aspect ratio: {aspect_ratio!r}")
            ratio = w / h
        else:
            # No separator -> fall back to a square image.
            ratio = 1.0
        if ratio >= 1:
            width = base_long_edge
            height = int(base_long_edge / ratio)
        else:
            height = base_long_edge
            width = int(base_long_edge * ratio)
        return (width // 8) * 8, (height // 8) * 8

    def run_inference(
        self,
        image: Image.Image,
        prompt: str,
        aspect_ratio: str = "1:1",
        base_long_edge: int = 1024,
        steps: int = 28,
        guidance: float = 4.0,
        seed: int = 0,
        **kwargs: Any,
    ) -> Image.Image:
        """Generate one image conditioned on *image* and *prompt*.

        Args:
            image: Conditioning image (RGB).
            prompt: Text prompt.
            aspect_ratio: ``"W:H"`` string controlling the output shape.
            base_long_edge: Pixel length of the longer output edge.
            steps: Number of denoising steps.
            guidance: Classifier-free guidance scale.
            seed: RNG seed for reproducible sampling.
            **kwargs: Forwarded verbatim to the pipeline call.

        Returns:
            The first generated ``PIL.Image.Image``.
        """
        width, height = self._resolve_size(aspect_ratio, base_long_edge)
        generator = torch.Generator(device=self.device).manual_seed(seed)
        result = self.pipe(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            **kwargs,
        ).images[0]
        return result

    def batch_run(
        self,
        image_paths: List[str],
        prompts: List[str],
        out_paths: List[str],
        **common_kwargs: Any,
    ) -> None:
        """Run inference over parallel lists of inputs, saving each result.

        Failures on individual items are reported and skipped (best-effort),
        so one bad input does not abort the whole batch.

        Args:
            image_paths: Source image files, one per job.
            prompts: Text prompts, one per job.
            out_paths: Destination file paths, one per job.
            **common_kwargs: Shared keyword arguments for ``run_inference``.
        """
        assert len(image_paths) == len(prompts) == len(out_paths)
        for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
            try:
                img = self.preprocess_image(img_path)
                result = self.run_inference(img, prompt, **common_kwargs)
                out_dir = os.path.dirname(out_path)
                # Guard: for a bare filename dirname() is "" and
                # os.makedirs("") raises FileNotFoundError.
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                result.save(out_path)
                print(f"Saved: {out_path}")
            except Exception as e:
                # Best-effort batch: report and continue with the next item.
                print(f"Failed {img_path}: {e}")