# --- Gitea page residue (not part of the script); preserved as comments so the file parses ---
# Files
# FiDA-Gen-Img-Flux2-klein/app/litserve/cli.py
# zcr b3eda2f7c7
# Some checks failed
# CI / lint (push) Failing after 11s
# 1
# 2026-03-18 11:09:01 +08:00
#
# 390 lines
# 13 KiB
# Python
# Executable File
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FLUX.2 [klein] image-to-image editing with CLI args (single or batch).
Prereqs:
pip install -U "diffusers>=0.30.0" transformers accelerate safetensors pillow
# and you have CUDA set up for torch.
Examples:
# Single image
python flux2_klein_cli.py \
--image test_sketch/sofa_2.png \
--user_input "caramel brown leather, matte black legs" \
--out outputs/sofa_2_to_real.png
# Batch: one prompt applied to all images
python flux2_klein_cli.py \
--image_dir test_sketch \
--user_input "pure white background, photorealistic materials" \
--out_dir outputs/batch
# Batch: a prompt per image using a prompt file (mapping or list)
python flux2_klein_cli.py \
--image_dir test_sketch \
--prompt_file prompt.txt \
--out_dir outputs/batch
# Optional knobs
python flux2_klein_cli.py \
--image test_sketch/a.png \
--user_input "..." \
--out outputs/a.png \
--aspect_ratio 4:3 \
--base_long_edge 1024 \
--steps 28 \
--guidance 4.0 \
--seed 0 \
--cpu_offload
"""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import torch
from PIL import Image, ImageOps
from diffusers import Flux2KleinPipeline
from diffusers.utils import load_image
# Hugging Face repo id and runtime defaults for the FLUX.2 [klein] pipeline.
DEFAULT_REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
DEFAULT_DEVICE = "cuda"
DEFAULT_TORCH_DTYPE = "bfloat16"
# Baseline system prompt: white-background studio product shot that preserves
# the reference sketch's geometry. Per-image --user_input text is appended to
# this string by build_final_prompt().
DEFAULT_SYSTEM_PROMPT = (
    "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment, no interior, "
    "no floor texture, and no visible floor/wall lines. Studio product photography lighting with soft "
    "diffused light; ultra-clean industrial design presentation. Render a single furniture object only, "
    "isolated and perfectly centered on a borderless canvas. Strictly preserve the geometry, silhouette, "
    "proportions, and key structural details of the reference sketch (do not change shape, do not add/remove parts). "
    "Realistic scale, high material fidelity, sharp focus, 4K quality. No people, no props, no text, no logos, no watermark."
)
# -----------------------------
# Prompt parsing (optional)
# -----------------------------
def load_prompts(prompt_txt: Path) -> Tuple[Dict[str, str], List[str]]:
    """Parse a prompt file into either a filename->prompt mapping or a bare list.

    Two layouts are accepted:
      A) mapping lines:  "sofa_2.png: caramel brown leather ..."
      B) bare prompts:   one prompt per line, matched later by sorted filenames

    Blank lines and "#" comments are ignored. Returns (mapping, list_only);
    at most one of the two is non-empty.
    """
    if not prompt_txt.exists():
        raise FileNotFoundError(f"prompt file not found: {prompt_txt}")

    entries = [
        stripped
        for stripped in (raw.strip() for raw in prompt_txt.read_text(encoding="utf-8").splitlines())
        if stripped and not stripped.startswith("#")
    ]

    # A file is treated as a mapping when any line starts with "<image-name>:".
    image_exts = (".png", ".jpg", ".jpeg", ".webp")
    is_mapping = any(
        ":" in entry and entry.split(":", 1)[0].strip().lower().endswith(image_exts)
        for entry in entries
    )
    if not is_mapping:
        return {}, entries

    mapping: Dict[str, str] = {}
    for entry in entries:
        if ":" not in entry:
            continue
        name_part, prompt_part = entry.split(":", 1)
        name, prompt = name_part.strip(), prompt_part.strip()
        if name and prompt:
            mapping[name] = prompt
    return mapping, []
def list_images(image_dir: Path) -> List[Path]:
    """Return the sorted image files (png/jpg/jpeg/webp) directly inside *image_dir*.

    Raises RuntimeError when the directory holds no matching files.
    """
    allowed = {".png", ".jpg", ".jpeg", ".webp"}
    found: List[Path] = []
    for candidate in sorted(image_dir.iterdir()):
        if candidate.is_file() and candidate.suffix.lower() in allowed:
            found.append(candidate)
    if not found:
        raise RuntimeError(f"No images found in: {image_dir}")
    return found
def build_final_prompt(system_prompt: str, user_input: str) -> str:
    """Append the user's requirements to the system prompt; empty input yields the system prompt alone."""
    cleaned = (user_input or "").strip()
    return f"{system_prompt} User requirements: {cleaned}" if cleaned else system_prompt
# -----------------------------
# Aspect ratio utils (resize+pad)
# -----------------------------
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert a "W:H" string plus a long-edge length into (width, height).

    The long edge equals *base_long_edge*; both dims are floored to a multiple
    of 8 and clamped to at least 64.
    """
    w_part, h_part = aspect_ratio.split(":")
    ratio_w, ratio_h = float(w_part), float(h_part)

    if ratio_w >= ratio_h:
        # Landscape or square: width carries the long edge.
        width, height = base_long_edge, int(round(base_long_edge * (ratio_h / ratio_w)))
    else:
        # Portrait: height carries the long edge.
        width, height = int(round(base_long_edge * (ratio_w / ratio_h))), base_long_edge

    def snap(value: int) -> int:
        return max(64, (value // 8) * 8)

    return snap(width), snap(height)
def resize_pad_to_target(img: Image.Image, width: int, height: int, fill=(255, 255, 255)) -> Image.Image:
    """Fit *img* inside (width, height) keeping aspect ratio, then pad to the exact size with *fill*."""
    rgb = img.convert("RGB")
    fitted = ImageOps.contain(rgb, (width, height), method=Image.Resampling.LANCZOS)
    fit_w, fit_h = fitted.size
    extra_w, extra_h = width - fit_w, height - fit_h
    # Split the leftover space evenly; odd leftovers go to the right/bottom.
    left, top = extra_w // 2, extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(fitted, border=border, fill=fill)
def resize_to_long_edge(img: Image.Image, base_long_edge: int) -> Image.Image:
    """Scale *img* so its long edge equals *base_long_edge* (dims floored to /8, min 64)."""
    rgb = img.convert("RGB")
    w, h = rgb.size
    if max(w, h) == base_long_edge:
        # Already at target size; skip the resample.
        return rgb
    if w >= h:
        new_w, new_h = base_long_edge, int(round(base_long_edge * h / w))
    else:
        new_w, new_h = int(round(base_long_edge * w / h)), base_long_edge
    new_w = max(64, (new_w // 8) * 8)
    new_h = max(64, (new_h // 8) * 8)
    return rgb.resize((new_w, new_h), Image.Resampling.LANCZOS)
# -----------------------------
# Run one
# -----------------------------
@torch.inference_mode()
def run_one(
    pipe: Flux2KleinPipeline,
    device: str,
    image_path: Path,
    final_prompt: str,
    out_path: Path,
    aspect_ratio: Optional[str],
    base_long_edge: int,
    steps: int,
    guidance: float,
    seed: int,
) -> None:
    """Run one img2img edit through the pipeline and save the result to *out_path*.

    When *aspect_ratio* is set, the reference is letterboxed to that fixed
    canvas; otherwise it is scaled to *base_long_edge* and the output inherits
    the resulting size.
    """
    reference = load_image(str(image_path))  # PIL.Image
    if aspect_ratio:
        target_w, target_h = aspect_to_wh(aspect_ratio, base_long_edge)
        reference = resize_pad_to_target(reference, target_w, target_h)
    else:
        reference = resize_to_long_edge(reference, base_long_edge)
        target_w, target_h = reference.size

    # Seeded generator on the same device for reproducible sampling.
    generator = torch.Generator(device=device).manual_seed(seed)
    output = pipe(
        image=reference,
        prompt=final_prompt,
        height=target_h,
        width=target_w,
        guidance_scale=guidance,
        num_inference_steps=steps,
        generator=generator,
    ).images[0]

    out_path.parent.mkdir(parents=True, exist_ok=True)
    output.save(out_path)
    print(f"[OK] saved: {out_path}")
def parse_args() -> argparse.Namespace:
    """Define the CLI, parse argv, and enforce the single-vs-batch mode rules."""
    parser = argparse.ArgumentParser(description="FLUX.2 [klein] img2img editing with CLI args.")

    io_group = parser.add_argument_group("I/O")
    io_group.add_argument("--image", type=str, default="", help="Single input image path.")
    io_group.add_argument("--image_dir", type=str, default="", help="Directory of images for batch mode.")
    io_group.add_argument("--out", type=str, default="", help="Single output image path (for --image).")
    io_group.add_argument("--out_dir", type=str, default="", help="Output directory (for --image_dir).")

    prompt_group = parser.add_argument_group("Prompt")
    prompt_group.add_argument("--user_input", type=str, default="", help="User prompt (applied to single image or all images).")
    prompt_group.add_argument(
        "--prompt_file",
        type=str,
        default="",
        help='Optional prompt file for batch mode. Supports "filename: prompt" mapping or one-prompt-per-line list.',
    )
    prompt_group.add_argument("--system_prompt", type=str, default=DEFAULT_SYSTEM_PROMPT, help="System prompt string.")

    model_group = parser.add_argument_group("Model/Runtime")
    model_group.add_argument("--repo_id", type=str, default=DEFAULT_REPO_ID, help=f"HF repo id (default: {DEFAULT_REPO_ID}).")
    model_group.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='Device, e.g. "cuda" or "cpu".')
    model_group.add_argument(
        "--torch_dtype",
        type=str,
        default=DEFAULT_TORCH_DTYPE,
        choices=["float16", "bfloat16", "float32"],
        help="torch dtype for loading weights.",
    )
    model_group.add_argument("--cpu_offload", action="store_true", help="Enable model CPU offload (saves VRAM).")

    gen_group = parser.add_argument_group("Generation")
    gen_group.add_argument("--aspect_ratio", type=str, default="4:3", help='e.g. "16:9", "1:1". Set "" to disable.')
    gen_group.add_argument("--base_long_edge", type=int, default=1024, help="Target long edge for output/resized reference.")
    gen_group.add_argument("--steps", type=int, default=28, help="num_inference_steps.")
    gen_group.add_argument("--guidance", type=float, default=4.0, help="guidance_scale.")
    gen_group.add_argument("--seed", type=int, default=0, help="Random seed.")

    args = parser.parse_args()

    # Exactly one of --image / --image_dir must be supplied.
    if bool(args.image) == bool(args.image_dir):
        raise SystemExit("Provide exactly one of --image or --image_dir.")

    if args.image:
        # Single-image mode: needs an output path and at least one prompt source.
        if not args.out:
            raise SystemExit("In single-image mode, --out is required.")
        if not args.user_input and not args.system_prompt:
            raise SystemExit("Provide --user_input and/or --system_prompt.")
    else:
        # Batch mode: needs an output directory and some prompt source.
        if not args.out_dir:
            raise SystemExit("In batch mode, --out_dir is required.")
        if not args.prompt_file and not args.user_input:
            raise SystemExit("Batch mode requires either --prompt_file or --user_input.")
    return args
def resolve_torch_dtype(dtype_str: str) -> torch.dtype:
    """Map a dtype name to the torch dtype; anything unrecognized falls back to float32."""
    dtype_by_name = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    return dtype_by_name.get(dtype_str, torch.float32)
def main() -> None:
    """CLI entry point: load the pipeline, then run single-image or batch editing.

    Refactor note: the original repeated the ten-keyword run_one(...) call at
    four sites (single, mapping, list, fallback). All modes now resolve a
    per-image user prompt and funnel through one shared `process` closure so
    generation knobs cannot drift between branches.
    """
    args = parse_args()

    # Normalize aspect_ratio: an empty/whitespace string disables fixed-aspect letterboxing.
    aspect_ratio = (args.aspect_ratio or "").strip() or None

    torch_dtype = resolve_torch_dtype(args.torch_dtype)
    pipe = Flux2KleinPipeline.from_pretrained(args.repo_id, torch_dtype=torch_dtype)
    if args.cpu_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(args.device)

    system_prompt = args.system_prompt

    def process(image_path: Path, user_input: str, out_path: Path) -> None:
        # Single shared invocation so every mode uses identical generation settings.
        run_one(
            pipe=pipe,
            device=args.device,
            image_path=image_path,
            final_prompt=build_final_prompt(system_prompt, user_input),
            out_path=out_path,
            aspect_ratio=aspect_ratio,
            base_long_edge=args.base_long_edge,
            steps=args.steps,
            guidance=args.guidance,
            seed=args.seed,
        )

    # --- Single-image mode ---
    if args.image:
        img_path = Path(args.image)
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found: {img_path}")
        process(img_path, args.user_input, Path(args.out))
        return

    # --- Batch mode ---
    image_dir = Path(args.image_dir)
    if not image_dir.exists():
        raise FileNotFoundError(f"Image dir not found: {image_dir}")
    images = list_images(image_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    mapping: Dict[str, str] = {}
    list_only: List[str] = []
    if args.prompt_file:
        mapping, list_only = load_prompts(Path(args.prompt_file))

    # Resolve one user prompt per image, depending on the prompt source.
    if mapping:
        # Mapping mode: filename -> prompt; unmapped images get an empty user prompt.
        per_image = [mapping.get(p.name, "") for p in images]
    elif list_only:
        # List mode: one prompt per line, matched to sorted filenames by position.
        if len(list_only) != len(images):
            print(
                f"[WARN] prompt lines ({len(list_only)}) != images ({len(images)}). "
                "Will match by min length and leave the rest empty."
            )
        per_image = [list_only[i] if i < len(list_only) else "" for i in range(len(images))]
    else:
        # No usable prompt file: apply --user_input to every image.
        per_image = [args.user_input] * len(images)

    for img_path, user_input in zip(images, per_image):
        process(img_path, user_input, out_dir / f"{img_path.stem}_to_real_flux2klein.png")
if __name__ == "__main__":
main()