#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ FLUX.2 [klein] image-to-image editing with CLI args (single or batch). Prereqs: pip install -U "diffusers>=0.30.0" transformers accelerate safetensors pillow # and you have CUDA set up for torch. Examples: # Single image python flux2_klein_cli.py \ --image test_sketch/sofa_2.png \ --user_input "caramel brown leather, matte black legs" \ --out outputs/sofa_2_to_real.png # Batch: one prompt applied to all images python flux2_klein_cli.py \ --image_dir test_sketch \ --user_input "pure white background, photorealistic materials" \ --out_dir outputs/batch # Batch: a prompt per image using a prompt file (mapping or list) python flux2_klein_cli.py \ --image_dir test_sketch \ --prompt_file prompt.txt \ --out_dir outputs/batch # Optional knobs python flux2_klein_cli.py \ --image test_sketch/a.png \ --user_input "..." \ --out outputs/a.png \ --aspect_ratio 4:3 \ --base_long_edge 1024 \ --steps 28 \ --guidance 4.0 \ --seed 0 \ --cpu_offload """ from __future__ import annotations import argparse from pathlib import Path from typing import Dict, List, Tuple, Optional import torch from PIL import Image, ImageOps from diffusers import Flux2KleinPipeline from diffusers.utils import load_image DEFAULT_REPO_ID = "black-forest-labs/FLUX.2-klein-4B" DEFAULT_DEVICE = "cuda" DEFAULT_TORCH_DTYPE = "bfloat16" DEFAULT_SYSTEM_PROMPT = ( "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment, no interior, " "no floor texture, and no visible floor/wall lines. Studio product photography lighting with soft " "diffused light; ultra-clean industrial design presentation. Render a single furniture object only, " "isolated and perfectly centered on a borderless canvas. Strictly preserve the geometry, silhouette, " "proportions, and key structural details of the reference sketch (do not change shape, do not add/remove parts). " "Realistic scale, high material fidelity, sharp focus, 4K quality. No people, no props, no text, no logos, no watermark." ) # ----------------------------- # Prompt parsing (optional) # ----------------------------- def load_prompts(prompt_txt: Path) -> Tuple[Dict[str, str], List[str]]: """ Supports: A) "sofa_2.png: caramel brown leather ..." (filename: user_input) B) "caramel brown leather ..." (user_input only; matched by sorted image filenames) """ if not prompt_txt.exists(): raise FileNotFoundError(f"prompt file not found: {prompt_txt}") lines: List[str] = [] for raw in prompt_txt.read_text(encoding="utf-8").splitlines(): s = raw.strip() if not s or s.startswith("#"): continue lines.append(s) mapping: Dict[str, str] = {} list_only: List[str] = [] looks_like_mapping = any( (":" in ln and ln.split(":", 1)[0].strip().lower().endswith((".png", ".jpg", ".jpeg", ".webp"))) for ln in lines ) if looks_like_mapping: for ln in lines: if ":" not in ln: continue k, v = ln.split(":", 1) key = k.strip() val = v.strip() if key and val: mapping[key] = val else: list_only = lines return mapping, list_only def list_images(image_dir: Path) -> List[Path]: exts = {".png", ".jpg", ".jpeg", ".webp"} imgs = [p for p in sorted(image_dir.iterdir()) if p.is_file() and p.suffix.lower() in exts] if not imgs: raise RuntimeError(f"No images found in: {image_dir}") return imgs def build_final_prompt(system_prompt: str, user_input: str) -> str: user_input = (user_input or "").strip() if not user_input: return system_prompt return f"{system_prompt} User requirements: {user_input}" # ----------------------------- # Aspect ratio utils (resize+pad) # ----------------------------- def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]: w_str, h_str = aspect_ratio.split(":") w, h = float(w_str), float(h_str) if w >= h: width = base_long_edge height = int(round(base_long_edge * (h / w))) else: height = base_long_edge width = int(round(base_long_edge * (w / h))) width = max(64, (width // 8) * 8) height = max(64, (height // 8) * 8) return width, height def resize_pad_to_target(img: Image.Image, width: int, height: int, fill=(255, 255, 255)) -> Image.Image: img = img.convert("RGB") contained = ImageOps.contain(img, (width, height), method=Image.Resampling.LANCZOS) pad_w = width - contained.size[0] pad_h = height - contained.size[1] padding = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2) return ImageOps.expand(contained, border=padding, fill=fill) def resize_to_long_edge(img: Image.Image, base_long_edge: int) -> Image.Image: img = img.convert("RGB") w, h = img.size if max(w, h) == base_long_edge: return img if w >= h: nw, nh = base_long_edge, int(round(base_long_edge * h / w)) else: nw, nh = int(round(base_long_edge * w / h)), base_long_edge nw = max(64, (nw // 8) * 8) nh = max(64, (nh // 8) * 8) return img.resize((nw, nh), Image.Resampling.LANCZOS) # ----------------------------- # Run one # ----------------------------- @torch.inference_mode() def run_one( pipe: Flux2KleinPipeline, device: str, image_path: Path, final_prompt: str, out_path: Path, aspect_ratio: Optional[str], base_long_edge: int, steps: int, guidance: float, seed: int, ) -> None: ref = load_image(str(image_path)) # PIL.Image if aspect_ratio: W, H = aspect_to_wh(aspect_ratio, base_long_edge) ref = resize_pad_to_target(ref, W, H) else: ref = resize_to_long_edge(ref, base_long_edge) W, H = ref.size gen = torch.Generator(device=device).manual_seed(seed) result = pipe( image=ref, prompt=final_prompt, height=H, width=W, guidance_scale=guidance, num_inference_steps=steps, generator=gen, ).images[0] out_path.parent.mkdir(parents=True, exist_ok=True) result.save(out_path) print(f"[OK] saved: {out_path}") def parse_args() -> argparse.Namespace: ap = argparse.ArgumentParser(description="FLUX.2 [klein] img2img editing with CLI args.") io = ap.add_argument_group("I/O") io.add_argument("--image", type=str, default="", help="Single input image path.") io.add_argument("--image_dir", type=str, default="", help="Directory of images for batch mode.") io.add_argument("--out", type=str, default="", help="Single output image path (for --image).") io.add_argument("--out_dir", type=str, default="", help="Output directory (for --image_dir).") prompt = ap.add_argument_group("Prompt") prompt.add_argument("--user_input", type=str, default="", help="User prompt (applied to single image or all images).") prompt.add_argument( "--prompt_file", type=str, default="", help='Optional prompt file for batch mode. Supports "filename: prompt" mapping or one-prompt-per-line list.', ) prompt.add_argument("--system_prompt", type=str, default=DEFAULT_SYSTEM_PROMPT, help="System prompt string.") model = ap.add_argument_group("Model/Runtime") model.add_argument("--repo_id", type=str, default=DEFAULT_REPO_ID, help=f"HF repo id (default: {DEFAULT_REPO_ID}).") model.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='Device, e.g. "cuda" or "cpu".') model.add_argument( "--torch_dtype", type=str, default=DEFAULT_TORCH_DTYPE, choices=["float16", "bfloat16", "float32"], help="torch dtype for loading weights.", ) model.add_argument("--cpu_offload", action="store_true", help="Enable model CPU offload (saves VRAM).") gen = ap.add_argument_group("Generation") gen.add_argument("--aspect_ratio", type=str, default="4:3", help='e.g. "16:9", "1:1". Set "" to disable.') gen.add_argument("--base_long_edge", type=int, default=1024, help="Target long edge for output/resized reference.") gen.add_argument("--steps", type=int, default=28, help="num_inference_steps.") gen.add_argument("--guidance", type=float, default=4.0, help="guidance_scale.") gen.add_argument("--seed", type=int, default=0, help="Random seed.") args = ap.parse_args() # Validate mode if bool(args.image) == bool(args.image_dir): raise SystemExit("Provide exactly one of --image or --image_dir.") if args.image: if not args.out: raise SystemExit("In single-image mode, --out is required.") if not args.user_input and not args.system_prompt: raise SystemExit("Provide --user_input and/or --system_prompt.") else: if not args.out_dir: raise SystemExit("In batch mode, --out_dir is required.") if not args.prompt_file and not args.user_input: raise SystemExit("Batch mode requires either --prompt_file or --user_input.") return args def resolve_torch_dtype(dtype_str: str) -> torch.dtype: if dtype_str == "float16": return torch.float16 if dtype_str == "bfloat16": return torch.bfloat16 return torch.float32 def main() -> None: args = parse_args() # Normalize aspect_ratio aspect_ratio = (args.aspect_ratio or "").strip() if aspect_ratio == "": aspect_ratio = None torch_dtype = resolve_torch_dtype(args.torch_dtype) pipe = Flux2KleinPipeline.from_pretrained(args.repo_id, torch_dtype=torch_dtype) if args.cpu_offload: pipe.enable_model_cpu_offload() else: pipe.to(args.device) system_prompt = args.system_prompt if args.image: img_path = Path(args.image) if not img_path.exists(): raise FileNotFoundError(f"Image not found: {img_path}") out_path = Path(args.out) final_prompt = build_final_prompt(system_prompt, args.user_input) run_one( pipe=pipe, device=args.device, image_path=img_path, final_prompt=final_prompt, out_path=out_path, aspect_ratio=aspect_ratio, base_long_edge=args.base_long_edge, steps=args.steps, guidance=args.guidance, seed=args.seed, ) return # Batch mode image_dir = Path(args.image_dir) if not image_dir.exists(): raise FileNotFoundError(f"Image dir not found: {image_dir}") images = list_images(image_dir) out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) mapping: Dict[str, str] = {} list_only: List[str] = [] if args.prompt_file: mapping, list_only = load_prompts(Path(args.prompt_file)) if mapping: # Mapping mode: filename -> prompt for img_path in images: user_input = mapping.get(img_path.name, "") final_prompt = build_final_prompt(system_prompt, user_input) out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png" run_one( pipe=pipe, device=args.device, image_path=img_path, final_prompt=final_prompt, out_path=out_path, aspect_ratio=aspect_ratio, base_long_edge=args.base_long_edge, steps=args.steps, guidance=args.guidance, seed=args.seed, ) elif list_only: # List mode: one prompt per line, matched by sorted filenames n = min(len(list_only), len(images)) if len(list_only) != len(images): print( f"[WARN] prompt lines ({len(list_only)}) != images ({len(images)}). " "Will match by min length and leave the rest empty." ) for i, img_path in enumerate(images): user_input = list_only[i] if i < n else "" final_prompt = build_final_prompt(system_prompt, user_input) out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png" run_one( pipe=pipe, device=args.device, image_path=img_path, final_prompt=final_prompt, out_path=out_path, aspect_ratio=aspect_ratio, base_long_edge=args.base_long_edge, steps=args.steps, guidance=args.guidance, seed=args.seed, ) else: # No prompt file parsed (or no prompt file). Apply --user_input to all for img_path in images: final_prompt = build_final_prompt(system_prompt, args.user_input) out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png" run_one( pipe=pipe, device=args.device, image_path=img_path, final_prompt=final_prompt, out_path=out_path, aspect_ratio=aspect_ratio, base_long_edge=args.base_long_edge, steps=args.steps, guidance=args.guidance, seed=args.seed, ) if __name__ == "__main__": main()