FLUX.2 [klein]

timudk
2026-01-15 15:12:38 +01:00
parent ab7cca6801
commit b56ac61450
12 changed files with 530 additions and 119 deletions


@@ -16,11 +16,12 @@ from flux2.sampling import (
     batched_prc_img,
     batched_prc_txt,
     denoise,
+    denoise_cfg,
     encode_image_refs,
     get_schedule,
     scatter_ids,
 )

-from flux2.util import FLUX2_MODEL_INFO, load_ae, load_flow_model, load_mistral_small_embedder
+from flux2.util import FLUX2_MODEL_INFO, load_ae, load_flow_model, load_text_encoder
 # from flux2.watermark import embed_watermark
@@ -167,6 +168,7 @@ def print_help():
     print("""
 Available commands:
   [Enter]     - Run generation with current config
+  <any text>  - Set as prompt (then press Enter to generate)
   run         - Run generation with current config
   show        - Show current configuration
   reset       - Reset configuration to defaults
@@ -206,40 +208,114 @@ Parameters:
 """)


+def validate_model_params(model_name: str, cfg: Config) -> bool:
+    """Validate that config parameters match model requirements. Returns True if valid."""
+    model_info = FLUX2_MODEL_INFO[model_name]
+    defaults = model_info.get("defaults", {})
+    fixed_params = model_info.get("fixed_params", set())
+
+    errors = []
+    if "num_steps" in fixed_params and cfg.num_steps != defaults["num_steps"]:
+        errors.append(
+            f"Model '{model_name}' requires num_steps={defaults['num_steps']}, "
+            f"but you specified num_steps={cfg.num_steps}"
+        )
+    if "guidance" in fixed_params and cfg.guidance != defaults["guidance"]:
+        errors.append(
+            f"Model '{model_name}' requires guidance={defaults['guidance']}, "
+            f"but you specified guidance={cfg.guidance}"
+        )
+
+    if errors:
+        print("\nERROR: Invalid parameters for selected model:", file=sys.stderr)
+        for error in errors:
+            print(f"  - {error}", file=sys.stderr)
+        print("\nPlease adjust your parameters and try again.", file=sys.stderr)
+        return False
+    return True
+
+
 # ---------- Main Loop ----------
 def main(
-    model_name: str = "flux.2-dev",
+    model_name: str | None = None,
     single_eval: bool = False,
     prompt: str | None = None,
     debug_mode: bool = False,
     cpu_offloading: bool = False,
     **overwrite,
 ):
+    # Prompt for model selection if not provided
+    if model_name is None:
+        available_models = list(FLUX2_MODEL_INFO.keys())
+        print("Available models:")
+        for i, name in enumerate(available_models, 1):
+            print(f"  {i}. {name}")
+        while True:
+            try:
+                choice = input(f"\nSelect a model [default: {available_models[0]}]: ").strip()
+                if choice == "":
+                    model_name = available_models[0]
+                    break
+                elif choice.isdigit():
+                    idx = int(choice) - 1
+                    if 0 <= idx < len(available_models):
+                        model_name = available_models[idx]
+                        break
+                    print(f"Please enter a number between 1 and {len(available_models)}")
+                elif choice.lower() in FLUX2_MODEL_INFO:
+                    model_name = choice.lower()
+                    break
+                else:
+                    print(f"Invalid choice. Available models: {', '.join(available_models)}")
+            except (EOFError, KeyboardInterrupt):
+                print("\nbye!")
+                return
+
     assert (
         model_name.lower() in FLUX2_MODEL_INFO
     ), f"{model_name} is not available, choose from {FLUX2_MODEL_INFO.keys()}"
+    model_info = FLUX2_MODEL_INFO[model_name]

     torch_device = torch.device("cuda")
-    mistral = load_mistral_small_embedder()
+    text_encoder = load_text_encoder(model_name, device=torch_device)
+    if "klein" in model_name:
+        mod_and_upsampling_model = load_text_encoder("flux.2-dev")
+    else:
+        mod_and_upsampling_model = text_encoder
     model = load_flow_model(
         model_name, debug_mode=debug_mode, device="cpu" if cpu_offloading else torch_device
     )
     ae = load_ae(model_name)

     ae.eval()
-    mistral.eval()
+    text_encoder.eval()

     # API client will be initialized lazily when needed
     openrouter_api_client: Optional[OpenRouterAPIClient] = None

     cfg = DEFAULTS.copy()
+    # Apply model defaults if not overridden
+    defaults = model_info.get("defaults", {})
+    if "num_steps" in defaults and "num_steps" not in overwrite:
+        cfg.num_steps = defaults["num_steps"]
+    if "guidance" in defaults and "guidance" not in overwrite:
+        cfg.guidance = defaults["guidance"]
+
     changes = [f"{key}={value}" for key, value in overwrite.items()]
     updates = parse_key_values(" ".join(changes))
     apply_updates(cfg, updates)
     if prompt is not None:
         cfg.prompt = prompt

+    # Validate initial config
+    if not validate_model_params(model_name, cfg):
+        sys.exit(1)
+
     print_config(cfg)

     while True:
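
Note: validate_model_params and the defaults plumbing above read only three fields per FLUX2_MODEL_INFO entry: "defaults", "fixed_params", and (later in this diff) "guidance_distilled". A minimal sketch of the entry shape this implies — the real table lives in flux2/util.py, and every key name below besides those three fields, and every value, is hypothetical:

    # Hypothetical entries illustrating the fields this diff reads;
    # actual model names and numbers are defined in flux2/util.py.
    FLUX2_MODEL_INFO = {
        "flux.2-dev": {
            "guidance_distilled": True,    # sampled with denoise() and a guidance scalar
            "defaults": {"num_steps": 50, "guidance": 4.0},
            "fixed_params": set(),         # nothing pinned; user may override freely
        },
        "flux.2-klein": {
            "guidance_distilled": False,   # sampled with denoise_cfg() (true CFG)
            "defaults": {"num_steps": 4, "guidance": 1.0},
            "fixed_params": {"num_steps", "guidance"},  # validate_model_params rejects overrides
        },
    }
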
@@ -255,17 +331,24 @@ def main(
             cmd = "run"
             updates = {}
         else:
-            try:
-                updates = parse_key_values(line)
-            except Exception as e:  # noqa: BLE001
-                print(f"  ! Failed to parse command: {type(e).__name__}: {e}", file=sys.stderr)
-                print(
-                    "  ! Please check your syntax (e.g., matching quotes) and try again.\n",
-                    file=sys.stderr,
-                )
-                continue
+            # Check if this is plain text (no key=value pairs and not a known command)
+            known_commands = {"run", "show", "reset", "quit", "q", "exit", "help", "h", "?"}
+            if "=" not in line and line.lower() not in known_commands:
+                # Treat the entire line as a prompt
+                updates = {"prompt": line}
+                cmd = None
+            else:
+                try:
+                    updates = parse_key_values(line)
+                except Exception as e:  # noqa: BLE001
+                    print(f"  ! Failed to parse command: {type(e).__name__}: {e}", file=sys.stderr)
+                    print(
+                        "  ! Please check your syntax (e.g., matching quotes) and try again.\n",
+                        file=sys.stderr,
+                    )
+                    continue

-        if "prompt" in updates and mistral.test_txt(updates["prompt"]):
+        if "prompt" in updates and mod_and_upsampling_model.test_txt(updates["prompt"]):
             print(
                 "Your prompt has been flagged for potential copyright or public personas concerns. Please choose another."
             )
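
With this hunk, any REPL line that contains no "=" and is not a known command becomes the prompt instead of a parse error. A hypothetical session (the prompt marker and echoed output may differ from the real CLI):

    > a tiny red fox in deep snow, watercolor   <- no "=", not a command: stored as cfg.prompt
    > num_steps=28 guidance=4.0                 <- key=value pairs: parsed as config updates
    > run                                       <- known command: generates with current config
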
@@ -274,7 +357,7 @@ def main(
         if "input_images" in updates:
             flagged = False
             for image in updates["input_images"]:
-                if mistral.test_image(image):
+                if mod_and_upsampling_model.test_image(image):
                     print(f"The image {image} has been flagged as unsuitable. Please choose another.")
                     flagged = True
             if flagged:
@@ -294,6 +377,11 @@ def main(
                 break
         elif cmd == "reset":
             cfg = DEFAULTS.copy()
+            # Re-apply model defaults
+            if "num_steps" in defaults:
+                cfg.num_steps = defaults["num_steps"]
+            if "guidance" in defaults:
+                cfg.guidance = defaults["guidance"]
             print_config(cfg)
             continue
         elif cmd == "show":
@@ -305,7 +393,16 @@ def main(
         # Apply key=value changes
         if updates:
-            apply_updates(cfg, updates)
+            # Create a temporary copy to test the updates
+            temp_cfg = cfg.copy()
+            apply_updates(temp_cfg, updates)
+
+            # Validate the temporary config
+            if not validate_model_params(model_name, temp_cfg):
+                continue
+
+            # Only apply to actual config if validation passed
+            cfg = temp_cfg
             print_config(cfg)
             continue
@@ -453,7 +550,7 @@ def main(
             prompt = cfg.prompt
         elif cfg.upsample_prompt_mode == "local":
             # Use local model for upsampling
-            upsampled_prompts = mistral.upsample_prompt(
+            upsampled_prompts = mod_and_upsampling_model.upsample_prompt(
                 [cfg.prompt], img=[img_ctx] if img_ctx else None
             )
             prompt = upsampled_prompts[0] if upsampled_prompts else cfg.prompt
@@ -463,13 +560,20 @@ def main(
         print("Generating with prompt: ", prompt)
-        ctx = mistral([prompt]).to(torch.bfloat16)
+        if model_info["guidance_distilled"]:
+            ctx = text_encoder([prompt]).to(torch.bfloat16)
+        else:
+            ctx_empty = text_encoder([""]).to(torch.bfloat16)
+            ctx_prompt = text_encoder([prompt]).to(torch.bfloat16)
+            ctx = torch.cat([ctx_empty, ctx_prompt], dim=0)
         ctx, ctx_ids = batched_prc_txt(ctx)

         if cpu_offloading:
-            mistral = mistral.cpu()
+            text_encoder = text_encoder.cpu()
             torch.cuda.empty_cache()
             model = model.to(torch_device)
+            if "klein" in model_name:
+                mod_and_upsampling_model = mod_and_upsampling_model.cpu()

         # Create noise
         shape = (1, 128, height // 16, width // 16)
@@ -478,17 +582,30 @@ def main(
         x, x_ids = batched_prc_img(randn)
         timesteps = get_schedule(cfg.num_steps, x.shape[1])

-        x = denoise(
-            model,
-            x,
-            x_ids,
-            ctx,
-            ctx_ids,
-            timesteps=timesteps,
-            guidance=cfg.guidance,
-            img_cond_seq=ref_tokens,
-            img_cond_seq_ids=ref_ids,
-        )
+        if model_info["guidance_distilled"]:
+            x = denoise(
+                model,
+                x,
+                x_ids,
+                ctx,
+                ctx_ids,
+                timesteps=timesteps,
+                guidance=cfg.guidance,
+                img_cond_seq=ref_tokens,
+                img_cond_seq_ids=ref_ids,
+            )
+        else:
+            x = denoise_cfg(
+                model,
+                x,
+                x_ids,
+                ctx,
+                ctx_ids,
+                timesteps=timesteps,
+                guidance=cfg.guidance,
+                img_cond_seq=ref_tokens,
+                img_cond_seq_ids=ref_ids,
+            )
         x = torch.cat(scatter_ids(x, x_ids)).squeeze(2)
         x = ae.decode(x).float()
         # x = embed_watermark(x)
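
denoise_cfg is only imported in this diff; its body is not shown. For orientation, a minimal sketch of the classifier-free guidance step it presumably performs, assuming ctx stacks the empty-prompt and prompt embeddings on the batch axis as built above (the helper name and model call signature here are hypothetical; the real implementation is in flux2/sampling.py):

    def cfg_step_sketch(model, x, t_curr, t_prev, ctx_uncond, ctx_cond, guidance):
        # One Euler step of rectified-flow sampling with classifier-free guidance.
        # Hypothetical helper, not the real flux2.sampling.denoise_cfg.
        v_uncond = model(x, timestep=t_curr, ctx=ctx_uncond)  # empty-prompt velocity
        v_cond = model(x, timestep=t_curr, ctx=ctx_cond)      # prompt-conditioned velocity
        v = v_uncond + guidance * (v_cond - v_uncond)         # CFG combination
        return x + (t_prev - t_curr) * v                      # Euler update toward t_prev

This is why the non-distilled branch above concatenates ctx_empty and ctx_prompt into one tensor: both branches of the guidance computation can then run through the model together.
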
@@ -496,13 +613,17 @@ def main(
         if cpu_offloading:
             model = model.cpu()
             torch.cuda.empty_cache()
-            mistral = mistral.to(torch_device)
+            text_encoder = text_encoder.to(torch_device)
+            if "klein" in model_name:
+                mod_and_upsampling_model = mod_and_upsampling_model.to(torch_device)

         x = x.clamp(-1, 1)
         x = rearrange(x[0], "c h w -> h w c")
         img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
-        if mistral.test_image(img):
+        if mod_and_upsampling_model.test_image(img):
             print("Your output has been flagged. Please choose another prompt / input image combination")
         else:
             exif_data = Image.Exif()