FLUX.2 [klein]
This commit is contained in:
185
scripts/cli.py
185
scripts/cli.py
@@ -16,11 +16,12 @@ from flux2.sampling import (
|
||||
batched_prc_img,
|
||||
batched_prc_txt,
|
||||
denoise,
|
||||
denoise_cfg,
|
||||
encode_image_refs,
|
||||
get_schedule,
|
||||
scatter_ids,
|
||||
)
|
||||
from flux2.util import FLUX2_MODEL_INFO, load_ae, load_flow_model, load_mistral_small_embedder
|
||||
from flux2.util import FLUX2_MODEL_INFO, load_ae, load_flow_model, load_text_encoder
|
||||
|
||||
# from flux2.watermark import embed_watermark
|
||||
|
||||
@@ -167,6 +168,7 @@ def print_help():
|
||||
print("""
|
||||
Available commands:
|
||||
[Enter] - Run generation with current config
|
||||
<any text> - Set as prompt (then press Enter to generate)
|
||||
run - Run generation with current config
|
||||
show - Show current configuration
|
||||
reset - Reset configuration to defaults
|
||||
@@ -206,40 +208,114 @@ Parameters:
|
||||
""")
|
||||
|
||||
|
||||
def validate_model_params(model_name: str, cfg: Config) -> bool:
    """Check the config against the parameters the selected model pins.

    The model's entry in ``FLUX2_MODEL_INFO`` may carry a ``"defaults"``
    mapping and a ``"fixed_params"`` collection. For ``num_steps`` and
    ``guidance``, a parameter listed in ``fixed_params`` must equal the
    corresponding value in ``defaults``; any mismatch is reported on stderr.

    Args:
        model_name: Key into ``FLUX2_MODEL_INFO`` (indexed directly, so an
            unknown name raises ``KeyError``).
        cfg: Current configuration; its ``num_steps`` / ``guidance``
            attributes are read only when the parameter is fixed.

    Returns:
        True when no fixed parameter is violated, False otherwise.
    """
    info = FLUX2_MODEL_INFO[model_name]
    model_defaults = info.get("defaults", {})
    pinned = info.get("fixed_params", set())

    problems = []
    # Only these two parameters are validated, in this order.
    for param in ("num_steps", "guidance"):
        if param not in pinned:
            continue
        actual = getattr(cfg, param)
        required = model_defaults[param]
        if actual != required:
            problems.append(
                f"Model '{model_name}' requires {param}={required}, "
                f"but you specified {param}={actual}"
            )

    # Guard clause: nothing to report.
    if not problems:
        return True

    print("\nERROR: Invalid parameters for selected model:", file=sys.stderr)
    for problem in problems:
        print(f" - {problem}", file=sys.stderr)
    print("\nPlease adjust your parameters and try again.", file=sys.stderr)
    return False
|
||||
|
||||
|
||||
# ---------- Main Loop ----------
|
||||
|
||||
|
||||
def main(
|
||||
model_name: str = "flux.2-dev",
|
||||
model_name: str | None = None,
|
||||
single_eval: bool = False,
|
||||
prompt: str | None = None,
|
||||
debug_mode: bool = False,
|
||||
cpu_offloading: bool = False,
|
||||
**overwrite,
|
||||
):
|
||||
# Prompt for model selection if not provided
|
||||
if model_name is None:
|
||||
available_models = list(FLUX2_MODEL_INFO.keys())
|
||||
print("Available models:")
|
||||
for i, name in enumerate(available_models, 1):
|
||||
print(f" {i}. {name}")
|
||||
while True:
|
||||
try:
|
||||
choice = input(f"\nSelect a model [default: {available_models[0]}]: ").strip()
|
||||
if choice == "":
|
||||
model_name = available_models[0]
|
||||
break
|
||||
elif choice.isdigit():
|
||||
idx = int(choice) - 1
|
||||
if 0 <= idx < len(available_models):
|
||||
model_name = available_models[idx]
|
||||
break
|
||||
print(f"Please enter a number between 1 and {len(available_models)}")
|
||||
elif choice.lower() in FLUX2_MODEL_INFO:
|
||||
model_name = choice.lower()
|
||||
break
|
||||
else:
|
||||
print(f"Invalid choice. Available models: {', '.join(available_models)}")
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\nbye!")
|
||||
return
|
||||
|
||||
assert (
|
||||
model_name.lower() in FLUX2_MODEL_INFO
|
||||
), f"{model_name} is not available, choose from {FLUX2_MODEL_INFO.keys()}"
|
||||
|
||||
model_info = FLUX2_MODEL_INFO[model_name]
|
||||
torch_device = torch.device("cuda")
|
||||
|
||||
mistral = load_mistral_small_embedder()
|
||||
text_encoder = load_text_encoder(model_name, device=torch_device)
|
||||
if "klein" in model_name:
|
||||
mod_and_upsampling_model = load_text_encoder("flux.2-dev")
|
||||
else:
|
||||
mod_and_upsampling_model = text_encoder
|
||||
|
||||
model = load_flow_model(
|
||||
model_name, debug_mode=debug_mode, device="cpu" if cpu_offloading else torch_device
|
||||
)
|
||||
ae = load_ae(model_name)
|
||||
ae.eval()
|
||||
mistral.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
# API client will be initialized lazily when needed
|
||||
openrouter_api_client: Optional[OpenRouterAPIClient] = None
|
||||
|
||||
cfg = DEFAULTS.copy()
|
||||
|
||||
# Apply model defaults if not overridden
|
||||
defaults = model_info.get("defaults", {})
|
||||
if "num_steps" in defaults and "num_steps" not in overwrite:
|
||||
cfg.num_steps = defaults["num_steps"]
|
||||
if "guidance" in defaults and "guidance" not in overwrite:
|
||||
cfg.guidance = defaults["guidance"]
|
||||
|
||||
changes = [f"{key}={value}" for key, value in overwrite.items()]
|
||||
updates = parse_key_values(" ".join(changes))
|
||||
apply_updates(cfg, updates)
|
||||
if prompt is not None:
|
||||
cfg.prompt = prompt
|
||||
|
||||
# Validate initial config
|
||||
if not validate_model_params(model_name, cfg):
|
||||
sys.exit(1)
|
||||
print_config(cfg)
|
||||
|
||||
while True:
|
||||
@@ -255,17 +331,24 @@ def main(
|
||||
cmd = "run"
|
||||
updates = {}
|
||||
else:
|
||||
try:
|
||||
updates = parse_key_values(line)
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f" ! Failed to parse command: {type(e).__name__}: {e}", file=sys.stderr)
|
||||
print(
|
||||
" ! Please check your syntax (e.g., matching quotes) and try again.\n",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
# Check if this is plain text (no key=value pairs and not a known command)
|
||||
known_commands = {"run", "show", "reset", "quit", "q", "exit", "help", "h", "?"}
|
||||
if "=" not in line and line.lower() not in known_commands:
|
||||
# Treat the entire line as a prompt
|
||||
updates = {"prompt": line}
|
||||
cmd = None
|
||||
else:
|
||||
try:
|
||||
updates = parse_key_values(line)
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f" ! Failed to parse command: {type(e).__name__}: {e}", file=sys.stderr)
|
||||
print(
|
||||
" ! Please check your syntax (e.g., matching quotes) and try again.\n",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
if "prompt" in updates and mistral.test_txt(updates["prompt"]):
|
||||
if "prompt" in updates and mod_and_upsampling_model.test_txt(updates["prompt"]):
|
||||
print(
|
||||
"Your prompt has been flagged for potential copyright or public personas concerns. Please choose another."
|
||||
)
|
||||
@@ -274,7 +357,7 @@ def main(
|
||||
if "input_images" in updates:
|
||||
flagged = False
|
||||
for image in updates["input_images"]:
|
||||
if mistral.test_image(image):
|
||||
if mod_and_upsampling_model.test_image(image):
|
||||
print(f"The image {image} has been flagged as unsuitable. Please choose another.")
|
||||
flagged = True
|
||||
if flagged:
|
||||
@@ -294,6 +377,11 @@ def main(
|
||||
break
|
||||
elif cmd == "reset":
|
||||
cfg = DEFAULTS.copy()
|
||||
# Re-apply model defaults
|
||||
if "num_steps" in defaults:
|
||||
cfg.num_steps = defaults["num_steps"]
|
||||
if "guidance" in defaults:
|
||||
cfg.guidance = defaults["guidance"]
|
||||
print_config(cfg)
|
||||
continue
|
||||
elif cmd == "show":
|
||||
@@ -305,7 +393,16 @@ def main(
|
||||
|
||||
# Apply key=value changes
|
||||
if updates:
|
||||
apply_updates(cfg, updates)
|
||||
# Create a temporary copy to test the updates
|
||||
temp_cfg = cfg.copy()
|
||||
apply_updates(temp_cfg, updates)
|
||||
|
||||
# Validate the temporary config
|
||||
if not validate_model_params(model_name, temp_cfg):
|
||||
continue
|
||||
|
||||
# Only apply to actual config if validation passed
|
||||
cfg = temp_cfg
|
||||
print_config(cfg)
|
||||
continue
|
||||
|
||||
@@ -453,7 +550,7 @@ def main(
|
||||
prompt = cfg.prompt
|
||||
elif cfg.upsample_prompt_mode == "local":
|
||||
# Use local model for upsampling
|
||||
upsampled_prompts = mistral.upsample_prompt(
|
||||
upsampled_prompts = mod_and_upsampling_model.upsample_prompt(
|
||||
[cfg.prompt], img=[img_ctx] if img_ctx else None
|
||||
)
|
||||
prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
|
||||
@@ -463,13 +560,20 @@ def main(
|
||||
|
||||
print("Generating with prompt: ", prompt)
|
||||
|
||||
ctx = mistral([prompt]).to(torch.bfloat16)
|
||||
if model_info["guidance_distilled"]:
|
||||
ctx = text_encoder([prompt]).to(torch.bfloat16)
|
||||
else:
|
||||
ctx_empty = text_encoder([""]).to(torch.bfloat16)
|
||||
ctx_prompt = text_encoder([prompt]).to(torch.bfloat16)
|
||||
ctx = torch.cat([ctx_empty, ctx_prompt], dim=0)
|
||||
ctx, ctx_ids = batched_prc_txt(ctx)
|
||||
|
||||
if cpu_offloading:
|
||||
mistral = mistral.cpu()
|
||||
text_encoder = text_encoder.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
model = model.to(torch_device)
|
||||
if "klein" in model_name:
|
||||
mod_and_upsampling_model = mod_and_upsampling_model.cpu()
|
||||
|
||||
# Create noise
|
||||
shape = (1, 128, height // 16, width // 16)
|
||||
@@ -478,17 +582,30 @@ def main(
|
||||
x, x_ids = batched_prc_img(randn)
|
||||
|
||||
timesteps = get_schedule(cfg.num_steps, x.shape[1])
|
||||
x = denoise(
|
||||
model,
|
||||
x,
|
||||
x_ids,
|
||||
ctx,
|
||||
ctx_ids,
|
||||
timesteps=timesteps,
|
||||
guidance=cfg.guidance,
|
||||
img_cond_seq=ref_tokens,
|
||||
img_cond_seq_ids=ref_ids,
|
||||
)
|
||||
if model_info["guidance_distilled"]:
|
||||
x = denoise(
|
||||
model,
|
||||
x,
|
||||
x_ids,
|
||||
ctx,
|
||||
ctx_ids,
|
||||
timesteps=timesteps,
|
||||
guidance=cfg.guidance,
|
||||
img_cond_seq=ref_tokens,
|
||||
img_cond_seq_ids=ref_ids,
|
||||
)
|
||||
else:
|
||||
x = denoise_cfg(
|
||||
model,
|
||||
x,
|
||||
x_ids,
|
||||
ctx,
|
||||
ctx_ids,
|
||||
timesteps=timesteps,
|
||||
guidance=cfg.guidance,
|
||||
img_cond_seq=ref_tokens,
|
||||
img_cond_seq_ids=ref_ids,
|
||||
)
|
||||
x = torch.cat(scatter_ids(x, x_ids)).squeeze(2)
|
||||
x = ae.decode(x).float()
|
||||
# x = embed_watermark(x)
|
||||
@@ -496,13 +613,17 @@ def main(
|
||||
if cpu_offloading:
|
||||
model = model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
mistral = mistral.to(torch_device)
|
||||
text_encoder = text_encoder.to(torch_device)
|
||||
|
||||
if "klein" in model_name:
|
||||
mod_and_upsampling_model = mod_and_upsampling_model.to(torch_device)
|
||||
|
||||
x = x.clamp(-1, 1)
|
||||
x = rearrange(x[0], "c h w -> h w c")
|
||||
|
||||
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
||||
if mistral.test_image(img):
|
||||
|
||||
if mod_and_upsampling_model.test_image(img):
|
||||
print("Your output has been flagged. Please choose another prompt / input image combination")
|
||||
else:
|
||||
exif_data = Image.Exif()
|
||||
|
||||
Reference in New Issue
Block a user