84
app/litserve/model.py
Executable file
84
app/litserve/model.py
Executable file
@@ -0,0 +1,84 @@
|
||||
# model.py
|
||||
import torch
|
||||
from diffusers import FluxPipeline # 假设你用的是 flux from diffusers
|
||||
from PIL import Image
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
class Flux2KleinModel:
|
||||
def __init__(self, model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
|
||||
print(f"Loading Flux model on {device} ...")
|
||||
self.pipe = FluxPipeline.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
self.pipe.enable_model_cpu_offload() # 或 .to(device) 看显存情况
|
||||
# self.pipe.vae.enable_slicing() # 可选,省显存
|
||||
self.device = device
|
||||
|
||||
def preprocess_image(self, image_path: str) -> Image.Image:
|
||||
img = Image.open(image_path).convert("RGB")
|
||||
return img
|
||||
|
||||
def run_inference(
|
||||
self,
|
||||
image: Image.Image,
|
||||
prompt: str,
|
||||
aspect_ratio: str = "1:1",
|
||||
base_long_edge: int = 1024,
|
||||
steps: int = 28,
|
||||
guidance: float = 4.0,
|
||||
seed: int = 0,
|
||||
**kwargs
|
||||
) -> Image.Image:
|
||||
# 根据 aspect_ratio 计算 width/height(示例实现,可按需修改)
|
||||
if ":" in aspect_ratio:
|
||||
w, h = map(int, aspect_ratio.split(":"))
|
||||
ratio = w / h
|
||||
else:
|
||||
ratio = 1.0
|
||||
|
||||
if ratio >= 1:
|
||||
width = base_long_edge
|
||||
height = int(base_long_edge / ratio)
|
||||
else:
|
||||
height = base_long_edge
|
||||
width = int(base_long_edge * ratio)
|
||||
|
||||
# 取整到 8 的倍数(flux 常见要求)
|
||||
width = (width // 8) * 8
|
||||
height = (height // 8) * 8
|
||||
|
||||
generator = torch.Generator(device=self.device).manual_seed(seed)
|
||||
|
||||
result = self.pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=guidance,
|
||||
generator=generator,
|
||||
**kwargs
|
||||
).images[0]
|
||||
|
||||
return result
|
||||
|
||||
def batch_run(
|
||||
self,
|
||||
image_paths: List[str],
|
||||
prompts: List[str],
|
||||
out_paths: List[str],
|
||||
**common_kwargs
|
||||
):
|
||||
assert len(image_paths) == len(prompts) == len(out_paths)
|
||||
|
||||
for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
|
||||
try:
|
||||
img = self.preprocess_image(img_path)
|
||||
result = self.run_inference(img, prompt, **common_kwargs)
|
||||
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
||||
result.save(out_path)
|
||||
print(f"Saved: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed {img_path}: {e}")
|
||||
Reference in New Issue
Block a user