#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FLUX.2 [klein] image-to-image editing with CLI args (single or batch).

Prereqs:
    pip install -U "diffusers>=0.30.0" transformers accelerate safetensors pillow
    # and you have CUDA set up for torch.

Examples:
    # Single image
    python flux2_klein_cli.py \
        --image test_sketch/sofa_2.png \
        --user_input "caramel brown leather, matte black legs" \
        --out outputs/sofa_2_to_real.png

    # Batch: one prompt applied to all images
    python flux2_klein_cli.py \
        --image_dir test_sketch \
        --user_input "pure white background, photorealistic materials" \
        --out_dir outputs/batch

    # Batch: a prompt per image using a prompt file (mapping or list)
    python flux2_klein_cli.py \
        --image_dir test_sketch \
        --prompt_file prompt.txt \
        --out_dir outputs/batch

    # Optional knobs
    python flux2_klein_cli.py \
        --image test_sketch/a.png \
        --user_input "..." \
        --out outputs/a.png \
        --aspect_ratio 4:3 \
        --base_long_edge 1024 \
        --steps 28 \
        --guidance 4.0 \
        --seed 0 \
        --cpu_offload
"""
|
|
|
|
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image, ImageOps

from diffusers import Flux2KleinPipeline
from diffusers.utils import load_image
|
|
|
|
# Hugging Face model repo used when --repo_id is not passed on the CLI.
DEFAULT_REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
# Default compute device for the pipeline (overridable with --device).
DEFAULT_DEVICE = "cuda"
# Default weight dtype name; mapped to a torch.dtype by resolve_torch_dtype().
DEFAULT_TORCH_DTYPE = "bfloat16"

# Baseline prompt prepended to every request: forces a clean white-studio
# product shot while instructing the model to preserve the reference
# sketch's geometry. User text is appended after it (see build_final_prompt).
DEFAULT_SYSTEM_PROMPT = (
    "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment, no interior, "
    "no floor texture, and no visible floor/wall lines. Studio product photography lighting with soft "
    "diffused light; ultra-clean industrial design presentation. Render a single furniture object only, "
    "isolated and perfectly centered on a borderless canvas. Strictly preserve the geometry, silhouette, "
    "proportions, and key structural details of the reference sketch (do not change shape, do not add/remove parts). "
    "Realistic scale, high material fidelity, sharp focus, 4K quality. No people, no props, no text, no logos, no watermark."
)
|
|
|
|
|
|
# -----------------------------
|
|
# Prompt parsing (optional)
|
|
# -----------------------------
|
|
# -----------------------------
# Prompt parsing (optional)
# -----------------------------
def load_prompts(prompt_txt: Path) -> Tuple[Dict[str, str], List[str]]:
    """Parse a batch prompt file.

    Two layouts are accepted:
      A) mapping lines:  "sofa_2.png: caramel brown leather ..."  (filename: user_input)
      B) plain lines:    "caramel brown leather ..."              (matched by sorted image filenames)

    Returns (mapping, list_only); exactly one of the two is populated.
    Blank lines and "#" comments are ignored.
    """
    if not prompt_txt.exists():
        raise FileNotFoundError(f"prompt file not found: {prompt_txt}")

    lines: List[str] = [
        stripped
        for stripped in (raw.strip() for raw in prompt_txt.read_text(encoding="utf-8").splitlines())
        if stripped and not stripped.startswith("#")
    ]

    mapping: Dict[str, str] = {}
    list_only: List[str] = []

    image_exts = (".png", ".jpg", ".jpeg", ".webp")

    def _is_mapping_line(line: str) -> bool:
        # A line counts as a mapping entry when the text before the first
        # colon ends with a known image extension.
        if ":" not in line:
            return False
        return line.split(":", 1)[0].strip().lower().endswith(image_exts)

    if any(_is_mapping_line(ln) for ln in lines):
        # Mapping layout: take every "key: value" line (value may itself contain colons).
        for ln in lines:
            if ":" not in ln:
                continue
            key, _, val = ln.partition(":")
            key, val = key.strip(), val.strip()
            if key and val:
                mapping[key] = val
    else:
        list_only = lines

    return mapping, list_only
|
|
|
|
|
|
def list_images(image_dir: Path) -> List[Path]:
    """Return all images in *image_dir* (non-recursive), sorted by path name.

    Raises RuntimeError when the directory contains no supported images.
    """
    allowed = {".png", ".jpg", ".jpeg", ".webp"}
    found = sorted(
        entry
        for entry in image_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in allowed
    )
    if not found:
        raise RuntimeError(f"No images found in: {image_dir}")
    return found
|
|
|
|
|
|
def build_final_prompt(system_prompt: str, user_input: str) -> str:
    """Combine the system prompt with the (optional) user requirements.

    Empty/whitespace-only/None user input yields the system prompt unchanged.
    """
    cleaned = (user_input or "").strip()
    return f"{system_prompt} User requirements: {cleaned}" if cleaned else system_prompt
|
|
|
|
|
|
# -----------------------------
|
|
# Aspect ratio utils (resize+pad)
|
|
# -----------------------------
|
|
# -----------------------------
# Aspect ratio utils (resize+pad)
# -----------------------------
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert a "W:H" ratio plus a long-edge length into (width, height).

    The long edge gets *base_long_edge*; both dims are floored to a multiple
    of 8 and clamped to at least 64.
    """
    w_part, h_part = aspect_ratio.split(":")
    ratio_w, ratio_h = float(w_part), float(h_part)

    if ratio_w >= ratio_h:
        # Landscape (or square): width is the long edge.
        width = base_long_edge
        height = int(round(base_long_edge * (ratio_h / ratio_w)))
    else:
        # Portrait: height is the long edge.
        height = base_long_edge
        width = int(round(base_long_edge * (ratio_w / ratio_h)))

    def _snap(value: int) -> int:
        return max(64, (value // 8) * 8)

    return _snap(width), _snap(height)
|
|
|
|
|
|
def resize_pad_to_target(img: Image.Image, width: int, height: int, fill=(255, 255, 255)) -> Image.Image:
    """Fit *img* inside (width, height), padding the remainder with *fill*.

    The image is scaled down/up to fit (LANCZOS) and centered; padding is
    split as evenly as possible on each side (extra pixel goes right/bottom).
    """
    rgb = img.convert("RGB")
    fitted = ImageOps.contain(rgb, (width, height), method=Image.Resampling.LANCZOS)
    extra_w = width - fitted.size[0]
    extra_h = height - fitted.size[1]
    left, top = extra_w // 2, extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(fitted, border=border, fill=fill)
|
|
|
|
|
|
def resize_to_long_edge(img: Image.Image, base_long_edge: int) -> Image.Image:
    """Scale *img* so its long edge equals *base_long_edge*.

    The new dimensions are floored to a multiple of 8 (min 64).
    NOTE(review): when the long edge already equals *base_long_edge* the image
    is returned without the /8 snap — preserved as-is; confirm intentional.
    """
    rgb = img.convert("RGB")
    w, h = rgb.size
    if max(w, h) == base_long_edge:
        return rgb
    if w >= h:
        new_w, new_h = base_long_edge, int(round(base_long_edge * h / w))
    else:
        new_w, new_h = int(round(base_long_edge * w / h)), base_long_edge
    new_w = max(64, (new_w // 8) * 8)
    new_h = max(64, (new_h // 8) * 8)
    return rgb.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
|
|
|
|
# -----------------------------
|
|
# Run one
|
|
# -----------------------------
|
|
# -----------------------------
# Run one
# -----------------------------
@torch.inference_mode()
def run_one(
    pipe: Flux2KleinPipeline,
    device: str,
    image_path: Path,
    final_prompt: str,
    out_path: Path,
    aspect_ratio: Optional[str],
    base_long_edge: int,
    steps: int,
    guidance: float,
    seed: int,
) -> None:
    """Edit a single image with the pipeline and save the result to *out_path*.

    When *aspect_ratio* is set, the reference is letterboxed onto a white
    canvas of that ratio; otherwise it is scaled to *base_long_edge*.
    """
    reference = load_image(str(image_path))  # PIL.Image

    if aspect_ratio:
        target_w, target_h = aspect_to_wh(aspect_ratio, base_long_edge)
        reference = resize_pad_to_target(reference, target_w, target_h)
    else:
        reference = resize_to_long_edge(reference, base_long_edge)
        target_w, target_h = reference.size

    # Seeded generator makes a run reproducible for a given seed/device.
    generator = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        image=reference,
        prompt=final_prompt,
        height=target_h,
        width=target_w,
        guidance_scale=guidance,
        num_inference_steps=steps,
        generator=generator,
    ).images[0]

    out_path.parent.mkdir(parents=True, exist_ok=True)
    result.save(out_path)
    print(f"[OK] saved: {out_path}")
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the CLI, parse argv, and validate mode-specific requirements."""
    parser = argparse.ArgumentParser(description="FLUX.2 [klein] img2img editing with CLI args.")

    g_io = parser.add_argument_group("I/O")
    g_io.add_argument("--image", type=str, default="", help="Single input image path.")
    g_io.add_argument("--image_dir", type=str, default="", help="Directory of images for batch mode.")
    g_io.add_argument("--out", type=str, default="", help="Single output image path (for --image).")
    g_io.add_argument("--out_dir", type=str, default="", help="Output directory (for --image_dir).")

    g_prompt = parser.add_argument_group("Prompt")
    g_prompt.add_argument("--user_input", type=str, default="", help="User prompt (applied to single image or all images).")
    g_prompt.add_argument(
        "--prompt_file",
        type=str,
        default="",
        help='Optional prompt file for batch mode. Supports "filename: prompt" mapping or one-prompt-per-line list.',
    )
    g_prompt.add_argument("--system_prompt", type=str, default=DEFAULT_SYSTEM_PROMPT, help="System prompt string.")

    g_model = parser.add_argument_group("Model/Runtime")
    g_model.add_argument("--repo_id", type=str, default=DEFAULT_REPO_ID, help=f"HF repo id (default: {DEFAULT_REPO_ID}).")
    g_model.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='Device, e.g. "cuda" or "cpu".')
    g_model.add_argument(
        "--torch_dtype",
        type=str,
        default=DEFAULT_TORCH_DTYPE,
        choices=["float16", "bfloat16", "float32"],
        help="torch dtype for loading weights.",
    )
    g_model.add_argument("--cpu_offload", action="store_true", help="Enable model CPU offload (saves VRAM).")

    g_gen = parser.add_argument_group("Generation")
    g_gen.add_argument("--aspect_ratio", type=str, default="4:3", help='e.g. "16:9", "1:1". Set "" to disable.')
    g_gen.add_argument("--base_long_edge", type=int, default=1024, help="Target long edge for output/resized reference.")
    g_gen.add_argument("--steps", type=int, default=28, help="num_inference_steps.")
    g_gen.add_argument("--guidance", type=float, default=4.0, help="guidance_scale.")
    g_gen.add_argument("--seed", type=int, default=0, help="Random seed.")

    args = parser.parse_args()

    # Exactly one input mode (single image XOR directory) must be chosen.
    if bool(args.image) == bool(args.image_dir):
        raise SystemExit("Provide exactly one of --image or --image_dir.")

    if args.image:
        # Single-image mode: needs an output path and at least one prompt source.
        if not args.out:
            raise SystemExit("In single-image mode, --out is required.")
        if not args.user_input and not args.system_prompt:
            raise SystemExit("Provide --user_input and/or --system_prompt.")
    else:
        # Batch mode: needs an output directory and a prompt source.
        if not args.out_dir:
            raise SystemExit("In batch mode, --out_dir is required.")
        if not args.prompt_file and not args.user_input:
            raise SystemExit("Batch mode requires either --prompt_file or --user_input.")

    return args
|
|
|
|
|
|
def resolve_torch_dtype(dtype_str: str) -> torch.dtype:
    """Map a CLI dtype name to a torch.dtype (unrecognized names -> float32)."""
    return {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(dtype_str, torch.float32)
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: load the pipeline, then run single-image or batch editing.

    The three previous batch branches repeated an identical ~20-line
    run_one(...) call; the invocation is now shared via one local helper and
    a single loop that only selects the per-image user prompt.
    """
    args = parse_args()

    # An empty/blank --aspect_ratio disables the resize+pad path.
    aspect_ratio: Optional[str] = (args.aspect_ratio or "").strip() or None

    torch_dtype = resolve_torch_dtype(args.torch_dtype)

    pipe = Flux2KleinPipeline.from_pretrained(args.repo_id, torch_dtype=torch_dtype)
    if args.cpu_offload:
        pipe.enable_model_cpu_offload()  # trades speed for lower VRAM use
    else:
        pipe.to(args.device)

    system_prompt = args.system_prompt

    def _edit(image_path: Path, user_input: str, out_path: Path) -> None:
        # Single shared invocation for both single-image and batch modes.
        run_one(
            pipe=pipe,
            device=args.device,
            image_path=image_path,
            final_prompt=build_final_prompt(system_prompt, user_input),
            out_path=out_path,
            aspect_ratio=aspect_ratio,
            base_long_edge=args.base_long_edge,
            steps=args.steps,
            guidance=args.guidance,
            seed=args.seed,
        )

    # ---- Single-image mode ----
    if args.image:
        img_path = Path(args.image)
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found: {img_path}")
        _edit(img_path, args.user_input, Path(args.out))
        return

    # ---- Batch mode ----
    image_dir = Path(args.image_dir)
    if not image_dir.exists():
        raise FileNotFoundError(f"Image dir not found: {image_dir}")

    images = list_images(image_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    mapping: Dict[str, str] = {}
    list_only: List[str] = []
    if args.prompt_file:
        mapping, list_only = load_prompts(Path(args.prompt_file))

    # List mode pairs prompts with sorted filenames; warn on a length mismatch.
    if list_only and not mapping and len(list_only) != len(images):
        print(
            f"[WARN] prompt lines ({len(list_only)}) != images ({len(images)}). "
            "Will match by min length and leave the rest empty."
        )

    for idx, img_path in enumerate(images):
        if mapping:
            # Mapping mode: filename -> prompt (missing entries get no user text).
            user_input = mapping.get(img_path.name, "")
        elif list_only:
            # List mode: one prompt per line, matched by sorted filenames.
            user_input = list_only[idx] if idx < len(list_only) else ""
        else:
            # No prompt file parsed (or none given): apply --user_input to all.
            user_input = args.user_input
        _edit(img_path, user_input, out_dir / f"{img_path.stem}_to_real_flux2klein.png")
|
|
|
|
|
|
# Script entry point when run directly (not on import).
if __name__ == "__main__":
    main()
|