This commit is contained in:
zcr
2026-03-23 10:18:30 +08:00
parent edf5ef1231
commit 8991927cd9
6 changed files with 22 additions and 44 deletions

110
app/litserve_app/app.py Executable file
View File

@@ -0,0 +1,110 @@
import torch
import litserve as ls
from PIL import Image
import io
import base64
from diffusers import Flux2KleinPipeline
# Helper retained from the original script.
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Translate a "W:H" ratio string into pixel dimensions.

    The longer side is fixed to ``base_long_edge``; the shorter side follows
    the ratio. Both sides are floored to a multiple of 8 with a 64-px minimum.
    """
    ratio_w, ratio_h = (float(part) for part in aspect_ratio.split(":"))
    if ratio_w >= ratio_h:
        raw_w, raw_h = base_long_edge, int(round(base_long_edge * ratio_h / ratio_w))
    else:
        raw_w, raw_h = int(round(base_long_edge * ratio_w / ratio_h)), base_long_edge

    def _snap(value: int) -> int:
        # Floor to a multiple of 8; never drop below 64 px.
        return max(64, (value // 8) * 8)

    return _snap(raw_w), _snap(raw_h)
class FluxKleinAPI(ls.LitAPI):
    """LitServe API wrapping the FLUX.2-klein image-to-image pipeline.

    Requests carry a base64 image plus generation knobs; responses return
    the generated image as a base64 PNG string.
    """

    def setup(self, device):
        """Load the pipeline once per worker and move it to *device*."""
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16
        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            torch_dtype=self.dtype
        )
        self.pipe.to(device)
        # NOTE(review): this default system prompt is never used by
        # decode_request/predict below — confirm whether it should be
        # prepended to the user prompt.
        self.default_system_prompt = (
            "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment... "
            "Realistic scale, high material fidelity, sharp focus, 4K quality."
        )

    def decode_request(self, request):
        """Decode the JSON request into a payload dict with a PIL image.

        Raises:
            ValueError: if the request has no non-empty "image" field.
        """
        # The request must carry a base64-encoded image.
        image_data = request.get("image", None)
        if not image_data:
            raise ValueError("Request must contain 'image' field with base64 data.")
        # Strip an optional "data:image/png;base64," style prefix.
        if "," in image_data:
            image_data = image_data.split(",")[1]
        # Decode base64 and convert to an RGB PIL image.
        img_bytes = base64.b64decode(image_data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return {
            "image": img,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 28),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "aspect_ratio": request.get("aspect_ratio", "4:3"),
            "base_long_edge": request.get("base_long_edge", 1024)
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run the diffusion pipeline and return the generated PIL image."""
        img = payload.get("image", None)
        prompt = payload.get("prompt", "")
        W, H = aspect_to_wh(payload["aspect_ratio"], payload["base_long_edge"])
        # Bug fix: the request "seed" was collected in decode_request but the
        # generator was never seeded, making results non-reproducible.
        gen = torch.Generator(device=self.device).manual_seed(payload["seed"])
        if img is not None:
            # Simplified resize (direct stretch); an aspect-preserving
            # resize-and-pad could be used here instead.
            img = img.convert("RGB").resize((W, H), Image.Resampling.LANCZOS)
        # The original if/else branches were byte-identical, so the pipeline
        # is invoked once regardless of whether an image was supplied.
        output = self.pipe(
            image=img,
            prompt=prompt,
            height=H,
            width=W,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        ).images[0]
        return output

    def encode_response(self, output):
        """Encode the output PIL image as a base64 PNG in a JSON-able dict."""
        buffered = io.BytesIO()
        output.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}
if __name__ == "__main__":
    # Start the LitServe HTTP server on a single CUDA device, port 8451.
    api = FluxKleinAPI()
    server = ls.LitServer(api, accelerator="cuda", devices=1)
    server.run(port=8451)

389
app/litserve_app/cli.py Executable file
View File

@@ -0,0 +1,389 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FLUX.2 [klein] image-to-image editing with CLI args (single or batch).
Prereqs:
pip install -U "diffusers>=0.30.0" transformers accelerate safetensors pillow
# and you have CUDA set up for torch.
Examples:
# Single image
python flux2_klein_cli.py \
--image test_sketch/sofa_2.png \
--user_input "caramel brown leather, matte black legs" \
--out outputs/sofa_2_to_real.png
# Batch: one prompt applied to all images
python flux2_klein_cli.py \
--image_dir test_sketch \
--user_input "pure white background, photorealistic materials" \
--out_dir outputs/batch
# Batch: a prompt per image using a prompt file (mapping or list)
python flux2_klein_cli.py \
--image_dir test_sketch \
--prompt_file prompt.txt \
--out_dir outputs/batch
# Optional knobs
python flux2_klein_cli.py \
--image test_sketch/a.png \
--user_input "..." \
--out outputs/a.png \
--aspect_ratio 4:3 \
--base_long_edge 1024 \
--steps 28 \
--guidance 4.0 \
--seed 0 \
--cpu_offload
"""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import torch
from PIL import Image, ImageOps
from diffusers import Flux2KleinPipeline
from diffusers.utils import load_image
# Model / runtime defaults surfaced as CLI flags in parse_args().
DEFAULT_REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
DEFAULT_DEVICE = "cuda"
DEFAULT_TORCH_DTYPE = "bfloat16"
# Baseline prompt enforcing a clean white-studio product shot while strictly
# preserving the reference sketch's geometry; per-image user requirements are
# appended to it by build_final_prompt().
DEFAULT_SYSTEM_PROMPT = (
    "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment, no interior, "
    "no floor texture, and no visible floor/wall lines. Studio product photography lighting with soft "
    "diffused light; ultra-clean industrial design presentation. Render a single furniture object only, "
    "isolated and perfectly centered on a borderless canvas. Strictly preserve the geometry, silhouette, "
    "proportions, and key structural details of the reference sketch (do not change shape, do not add/remove parts). "
    "Realistic scale, high material fidelity, sharp focus, 4K quality. No people, no props, no text, no logos, no watermark."
)
# -----------------------------
# Prompt parsing (optional)
# -----------------------------
def load_prompts(prompt_txt: Path) -> Tuple[Dict[str, str], List[str]]:
    """Parse a prompt file into a filename→prompt mapping or a plain list.

    Supports:
      A) "sofa_2.png: caramel brown leather ..." (filename: user_input)
      B) "caramel brown leather ..."             (user_input only; matched by sorted image filenames)
    """
    if not prompt_txt.exists():
        raise FileNotFoundError(f"prompt file not found: {prompt_txt}")

    # Keep non-empty, non-comment lines only.
    lines: List[str] = [
        stripped
        for stripped in (raw.strip() for raw in prompt_txt.read_text(encoding="utf-8").splitlines())
        if stripped and not stripped.startswith("#")
    ]

    image_exts = (".png", ".jpg", ".jpeg", ".webp")
    # Mapping mode is detected when any line looks like "<image file>: prompt".
    is_mapping = any(
        ":" in line and line.split(":", 1)[0].strip().lower().endswith(image_exts)
        for line in lines
    )

    mapping: Dict[str, str] = {}
    list_only: List[str] = []
    if not is_mapping:
        list_only = lines
        return mapping, list_only

    for line in lines:
        if ":" not in line:
            continue
        name_part, prompt_part = line.split(":", 1)
        name, prompt = name_part.strip(), prompt_part.strip()
        if name and prompt:
            mapping[name] = prompt
    return mapping, list_only
def list_images(image_dir: Path) -> List[Path]:
    """Return the sorted image files (png/jpg/jpeg/webp) in *image_dir*.

    Raises:
        RuntimeError: if no matching files are found.
    """
    allowed = {".png", ".jpg", ".jpeg", ".webp"}
    found = [
        entry
        for entry in sorted(image_dir.iterdir())
        if entry.is_file() and entry.suffix.lower() in allowed
    ]
    if not found:
        raise RuntimeError(f"No images found in: {image_dir}")
    return found
def build_final_prompt(system_prompt: str, user_input: str) -> str:
    """Combine the system prompt with optional user requirements.

    Empty/whitespace-only (or None) user input yields the system prompt alone.
    """
    cleaned = (user_input or "").strip()
    return f"{system_prompt} User requirements: {cleaned}" if cleaned else system_prompt
# -----------------------------
# Aspect ratio utils (resize+pad)
# -----------------------------
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Map a "W:H" ratio string to (width, height) in pixels.

    The long edge is pinned to ``base_long_edge``; both sides are floored to
    a multiple of 8 and never drop below 64 px.
    """
    num, den = aspect_ratio.split(":")
    rw, rh = float(num), float(den)
    landscape = rw >= rh
    if landscape:
        out_w = base_long_edge
        out_h = int(round(base_long_edge * rh / rw))
    else:
        out_h = base_long_edge
        out_w = int(round(base_long_edge * rw / rh))
    # Snap to 8-px multiples with a 64-px floor.
    return max(64, (out_w // 8) * 8), max(64, (out_h // 8) * 8)
def resize_pad_to_target(img: Image.Image, width: int, height: int, fill=(255, 255, 255)) -> Image.Image:
    """Fit *img* inside (width, height) and center it on a solid-fill canvas."""
    scaled = ImageOps.contain(img.convert("RGB"), (width, height), method=Image.Resampling.LANCZOS)
    extra_w = width - scaled.size[0]
    extra_h = height - scaled.size[1]
    left, top = extra_w // 2, extra_h // 2
    # Right/bottom borders absorb any odd remainder so the total size is exact.
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(scaled, border=border, fill=fill)
def resize_to_long_edge(img: Image.Image, base_long_edge: int) -> Image.Image:
    """Scale *img* so its long edge equals *base_long_edge*.

    Scaled dimensions are floored to 8-px multiples with a 64-px minimum.
    """
    img = img.convert("RGB")
    cur_w, cur_h = img.size
    if max(cur_w, cur_h) == base_long_edge:
        # NOTE(review): this early return skips the 8-px snapping below, so
        # the short edge may not be a multiple of 8 — confirm intended.
        return img
    if cur_w >= cur_h:
        new_w, new_h = base_long_edge, int(round(base_long_edge * cur_h / cur_w))
    else:
        new_w, new_h = int(round(base_long_edge * cur_w / cur_h)), base_long_edge
    new_w = max(64, (new_w // 8) * 8)
    new_h = max(64, (new_h // 8) * 8)
    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
# -----------------------------
# Run one
# -----------------------------
@torch.inference_mode()
def run_one(
    pipe: Flux2KleinPipeline,
    device: str,
    image_path: Path,
    final_prompt: str,
    out_path: Path,
    aspect_ratio: Optional[str],
    base_long_edge: int,
    steps: int,
    guidance: float,
    seed: int,
) -> None:
    """Edit one reference image with the pipeline and save the result.

    When *aspect_ratio* is given the reference is padded to that ratio;
    otherwise it is scaled so its long edge equals *base_long_edge*.
    """
    reference = load_image(str(image_path))  # PIL.Image
    if aspect_ratio:
        width, height = aspect_to_wh(aspect_ratio, base_long_edge)
        reference = resize_pad_to_target(reference, width, height)
    else:
        reference = resize_to_long_edge(reference, base_long_edge)
        width, height = reference.size

    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipe(
        image=reference,
        prompt=final_prompt,
        height=height,
        width=width,
        guidance_scale=guidance,
        num_inference_steps=steps,
        generator=generator,
    ).images[0]

    out_path.parent.mkdir(parents=True, exist_ok=True)
    result.save(out_path)
    print(f"[OK] saved: {out_path}")
def parse_args() -> argparse.Namespace:
    """Build the CLI parser, parse sys.argv, and validate mode-specific flags."""
    parser = argparse.ArgumentParser(description="FLUX.2 [klein] img2img editing with CLI args.")

    io_group = parser.add_argument_group("I/O")
    io_group.add_argument("--image", type=str, default="", help="Single input image path.")
    io_group.add_argument("--image_dir", type=str, default="", help="Directory of images for batch mode.")
    io_group.add_argument("--out", type=str, default="", help="Single output image path (for --image).")
    io_group.add_argument("--out_dir", type=str, default="", help="Output directory (for --image_dir).")

    prompt_group = parser.add_argument_group("Prompt")
    prompt_group.add_argument("--user_input", type=str, default="", help="User prompt (applied to single image or all images).")
    prompt_group.add_argument(
        "--prompt_file",
        type=str,
        default="",
        help='Optional prompt file for batch mode. Supports "filename: prompt" mapping or one-prompt-per-line list.',
    )
    prompt_group.add_argument("--system_prompt", type=str, default=DEFAULT_SYSTEM_PROMPT, help="System prompt string.")

    model_group = parser.add_argument_group("Model/Runtime")
    model_group.add_argument("--repo_id", type=str, default=DEFAULT_REPO_ID, help=f"HF repo id (default: {DEFAULT_REPO_ID}).")
    model_group.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='Device, e.g. "cuda" or "cpu".')
    model_group.add_argument(
        "--torch_dtype",
        type=str,
        default=DEFAULT_TORCH_DTYPE,
        choices=["float16", "bfloat16", "float32"],
        help="torch dtype for loading weights.",
    )
    model_group.add_argument("--cpu_offload", action="store_true", help="Enable model CPU offload (saves VRAM).")

    gen_group = parser.add_argument_group("Generation")
    gen_group.add_argument("--aspect_ratio", type=str, default="4:3", help='e.g. "16:9", "1:1". Set "" to disable.')
    gen_group.add_argument("--base_long_edge", type=int, default=1024, help="Target long edge for output/resized reference.")
    gen_group.add_argument("--steps", type=int, default=28, help="num_inference_steps.")
    gen_group.add_argument("--guidance", type=float, default=4.0, help="guidance_scale.")
    gen_group.add_argument("--seed", type=int, default=0, help="Random seed.")

    args = parser.parse_args()

    # Mode validation: exactly one of --image / --image_dir must be set.
    single_mode = bool(args.image)
    if single_mode == bool(args.image_dir):
        raise SystemExit("Provide exactly one of --image or --image_dir.")
    if single_mode:
        if not args.out:
            raise SystemExit("In single-image mode, --out is required.")
        if not args.user_input and not args.system_prompt:
            raise SystemExit("Provide --user_input and/or --system_prompt.")
        return args
    if not args.out_dir:
        raise SystemExit("In batch mode, --out_dir is required.")
    if not args.prompt_file and not args.user_input:
        raise SystemExit("Batch mode requires either --prompt_file or --user_input.")
    return args
def resolve_torch_dtype(dtype_str: str) -> torch.dtype:
    """Map a dtype name to its torch dtype; unknown names fall back to float32."""
    return {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(dtype_str, torch.float32)
def main() -> None:
    """CLI entry point: load the pipeline, then process one image or a batch.

    Fix: the original repeated the same 11-argument run_one(...) invocation
    three times (mapping mode, list mode, fallback); the shared call is now a
    single closure and the three modes differ only in per-image prompt choice.
    """
    args = parse_args()

    # Normalize aspect_ratio: an empty string disables fixed-ratio padding.
    aspect_ratio = (args.aspect_ratio or "").strip() or None

    torch_dtype = resolve_torch_dtype(args.torch_dtype)
    pipe = Flux2KleinPipeline.from_pretrained(args.repo_id, torch_dtype=torch_dtype)
    if args.cpu_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(args.device)

    system_prompt = args.system_prompt

    def process(img_path: Path, out_path: Path, user_input: str) -> None:
        # Single shared invocation used by every mode below.
        run_one(
            pipe=pipe,
            device=args.device,
            image_path=img_path,
            final_prompt=build_final_prompt(system_prompt, user_input),
            out_path=out_path,
            aspect_ratio=aspect_ratio,
            base_long_edge=args.base_long_edge,
            steps=args.steps,
            guidance=args.guidance,
            seed=args.seed,
        )

    # Single-image mode.
    if args.image:
        img_path = Path(args.image)
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found: {img_path}")
        process(img_path, Path(args.out), args.user_input)
        return

    # Batch mode.
    image_dir = Path(args.image_dir)
    if not image_dir.exists():
        raise FileNotFoundError(f"Image dir not found: {image_dir}")
    images = list_images(image_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    mapping: Dict[str, str] = {}
    list_only: List[str] = []
    if args.prompt_file:
        mapping, list_only = load_prompts(Path(args.prompt_file))

    # List mode warns once when prompt count and image count disagree.
    if list_only and not mapping and len(list_only) != len(images):
        print(
            f"[WARN] prompt lines ({len(list_only)}) != images ({len(images)}). "
            "Will match by min length and leave the rest empty."
        )

    for i, img_path in enumerate(images):
        if mapping:
            # Mapping mode: filename -> prompt (missing names get empty input).
            user_input = mapping.get(img_path.name, "")
        elif list_only:
            # List mode: one prompt per line, matched by sorted filenames.
            user_input = list_only[i] if i < len(list_only) else ""
        else:
            # Fallback: apply --user_input to every image.
            user_input = args.user_input
        process(img_path, out_dir / f"{img_path.stem}_to_real_flux2klein.png", user_input)
# Script entry point.
if __name__ == "__main__":
    main()

16
app/litserve_app/client.py Executable file
View File

@@ -0,0 +1,16 @@
import httpx
import asyncio
async def main():
    """Send a test request to the local LitServe endpoint and print the reply.

    NOTE(review): the server's decode_request requires an "image" field, so
    this image-less request is expected to fail server-side — confirm whether
    a base64 image should be included here.
    """
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://localhost:8451/predict",
            json={
                "prompt": "紫色实木窗帘",
            }
        )
        print(response.json())


# Bug fix: asyncio.run() previously executed at import time; guard it so
# importing this module does not fire a network request.
if __name__ == "__main__":
    asyncio.run(main())

84
app/litserve_app/model.py Executable file
View File

@@ -0,0 +1,84 @@
# model.py
import torch
from diffusers import FluxPipeline # 假设你用的是 flux from diffusers
from PIL import Image
import os
from typing import List, Optional, Dict, Any
class Flux2KleinModel:
    """Thin wrapper around a diffusers Flux pipeline for img2img editing."""

    def __init__(self, model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
        # NOTE(review): despite the class name, the default checkpoint is
        # FLUX.1-dev loaded via FluxPipeline — confirm this is intended.
        print(f"Loading Flux model on {device} ...")
        self.pipe = FluxPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16
        )
        # CPU offload saves VRAM; switch to .to(device) when memory allows.
        self.pipe.enable_model_cpu_offload()
        # self.pipe.vae.enable_slicing()  # optional, saves VRAM
        self.device = device

    def preprocess_image(self, image_path: str) -> Image.Image:
        """Open an image file and normalize it to RGB."""
        img = Image.open(image_path).convert("RGB")
        return img

    @staticmethod
    def _compute_dims(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
        """Map a "W:H" ratio string to (width, height) pixels.

        The long edge equals ``base_long_edge``; both sides are floored to a
        multiple of 8 (a common Flux requirement). A ratio without ":" falls
        back to 1:1.

        Fixes vs. the original inline code: ratio parts are parsed as floats
        (so "4.5:3" works, matching aspect_to_wh in the CLI module), and a
        64-px floor prevents extreme ratios from producing zero-sized dims.
        """
        if ":" in aspect_ratio:
            w, h = map(float, aspect_ratio.split(":"))
            ratio = w / h
        else:
            ratio = 1.0
        if ratio >= 1:
            width = base_long_edge
            height = int(base_long_edge / ratio)
        else:
            height = base_long_edge
            width = int(base_long_edge * ratio)
        width = max(64, (width // 8) * 8)
        height = max(64, (height // 8) * 8)
        return width, height

    def run_inference(
        self,
        image: Image.Image,
        prompt: str,
        aspect_ratio: str = "1:1",
        base_long_edge: int = 1024,
        steps: int = 28,
        guidance: float = 4.0,
        seed: int = 0,
        **kwargs
    ) -> Image.Image:
        """Run one seeded img2img generation and return the resulting image."""
        width, height = self._compute_dims(aspect_ratio, base_long_edge)
        generator = torch.Generator(device=self.device).manual_seed(seed)
        result = self.pipe(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            **kwargs
        ).images[0]
        return result

    def batch_run(
        self,
        image_paths: List[str],
        prompts: List[str],
        out_paths: List[str],
        **common_kwargs
    ):
        """Process aligned input/prompt/output lists; log failures and continue."""
        assert len(image_paths) == len(prompts) == len(out_paths)
        for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
            try:
                img = self.preprocess_image(img_path)
                result = self.run_inference(img, prompt, **common_kwargs)
                # Bug fix: os.path.dirname() is "" for bare filenames and
                # os.makedirs("") raises — only create a directory when set.
                parent = os.path.dirname(out_path)
                if parent:
                    os.makedirs(parent, exist_ok=True)
                result.save(out_path)
                print(f"Saved: {out_path}")
            except Exception as e:
                # Deliberate best-effort batch: report the failure, keep going.
                print(f"Failed {img_path}: {e}")