FLUX.2 launch

timudk committed on 2025-11-25 07:25:25 -08:00
commit e80b84ed9f
24 changed files with 3238 additions and 0 deletions

src/flux2/autoencoder.py (new file, 336 lines)

@@ -0,0 +1,336 @@
import math
from dataclasses import dataclass, field
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int = 256
in_channels: int = 3
ch: int = 128
out_ch: int = 3
ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
num_res_blocks: int = 2
z_channels: int = 32
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
h = self.quant_conv(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
z = self.post_quant_conv(z)
# get dtype for proper tracing
upscale_dtype = next(self.up.parameters()).dtype
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# cast to proper dtype
h = h.to(upscale_dtype)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.params = params
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.bn_eps = 1e-4
self.bn_momentum = 0.1
self.ps = [2, 2]
self.bn = torch.nn.BatchNorm2d(
math.prod(self.ps) * params.z_channels,
eps=self.bn_eps,
momentum=self.bn_momentum,
affine=False,
track_running_stats=True,
)
def normalize(self, z):
self.bn.eval()
return self.bn(z)
def inv_normalize(self, z):
self.bn.eval()
s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
m = self.bn.running_mean.view(1, -1, 1, 1)
return z * s + m
def encode(self, x: Tensor) -> Tensor:
moments = self.encoder(x)
mean = torch.chunk(moments, 2, dim=1)[0]
z = rearrange(
mean,
"... c (i pi) (j pj) -> ... (c pi pj) i j",
pi=self.ps[0],
pj=self.ps[1],
)
z = self.normalize(z)
return z
def decode(self, z: Tensor) -> Tensor:
z = self.inv_normalize(z)
z = rearrange(
z,
"... (c pi pj) i j -> ... c (i pi) (j pj)",
pi=self.ps[0],
pj=self.ps[1],
)
dec = self.decoder(z)
return dec
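
A minimal shape sketch of the encode/decode round trip (randomly initialized weights, shown for shapes only): with ch_mult=[1, 2, 4, 4] the encoder downsamples by 8x, and the 2x2 pixel shuffle in encode() brings the total spatial reduction to 16x with 4 * z_channels = 128 latent channels.

import torch

ae = AutoEncoder(AutoEncoderParams()).eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)   # image tensor in [-1, 1]
    z = ae.encode(x)                  # -> (1, 128, 16, 16)
    x_rec = ae.decode(z)              # -> (1, 3, 256, 256)
print(z.shape, x_rec.shape)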

src/flux2/model.py (new file, 451 lines)

@@ -0,0 +1,451 @@
import math
from dataclasses import dataclass, field
import torch
from einops import rearrange
from torch import Tensor, nn
@dataclass
class Flux2Params:
in_channels: int = 128
context_in_dim: int = 15360
hidden_size: int = 6144
num_heads: int = 48
depth: int = 8
depth_single_blocks: int = 48
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
theta: int = 2000
mlp_ratio: float = 3.0
class Flux2(nn.Module):
def __init__(self, params: Flux2Params):
super().__init__()
self.in_channels = params.in_channels
self.out_channels = params.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
)
for _ in range(params.depth_single_blocks)
]
)
self.double_stream_modulation_img = Modulation(
self.hidden_size,
double=True,
disable_bias=True,
)
self.double_stream_modulation_txt = Modulation(
self.hidden_size,
double=True,
disable_bias=True,
)
self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)
self.final_layer = LastLayer(
self.hidden_size,
self.out_channels,
)
def forward(
self,
x: Tensor,
x_ids: Tensor,
timesteps: Tensor,
ctx: Tensor,
ctx_ids: Tensor,
guidance: Tensor,
):
num_txt_tokens = ctx.shape[1]
timestep_emb = timestep_embedding(timesteps, 256)
vec = self.time_in(timestep_emb)
guidance_emb = timestep_embedding(guidance, 256)
vec = vec + self.guidance_in(guidance_emb)
double_block_mod_img = self.double_stream_modulation_img(vec)
double_block_mod_txt = self.double_stream_modulation_txt(vec)
single_block_mod, _ = self.single_stream_modulation(vec)
img = self.img_in(x)
txt = self.txt_in(ctx)
pe_x = self.pe_embedder(x_ids)
pe_ctx = self.pe_embedder(ctx_ids)
for block in self.double_blocks:
img, txt = block(
img,
txt,
pe_x,
pe_ctx,
double_block_mod_img,
double_block_mod_txt,
)
img = torch.cat((txt, img), dim=1)
pe = torch.cat((pe_ctx, pe_x), dim=2)
for i, block in enumerate(self.single_blocks):
img = block(
img,
pe,
single_block_mod,
)
img = img[:, num_txt_tokens:, ...]
img = self.final_layer(img, vec)
return img
class SelfAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim, bias=False)
class SiLUActivation(nn.Module):
def __init__(self):
super().__init__()
self.gate_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return self.gate_fn(x1) * x2
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, disable_bias: bool = False):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)
def forward(self, vec: torch.Tensor):
out = self.lin(nn.functional.silu(vec))
if out.ndim == 2:
out = out[:, None, :]
out = out.chunk(self.multiplier, dim=-1)
return out[:3], out[3:] if self.is_double else None
class LastLayer(nn.Module):
def __init__(
self,
hidden_size: int,
out_channels: int,
):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=False)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False))
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
mod = self.adaLN_modulation(vec)
shift, scale = mod.chunk(2, dim=-1)
if shift.ndim == 2:
shift = shift[:, None, :]
scale = scale[:, None, :]
x = (1 + scale) * self.norm_final(x) + shift
x = self.linear(x)
return x
class SingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_mult_factor = 2
self.linear1 = nn.Linear(
hidden_size,
hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
bias=False,
)
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=False)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_act = SiLUActivation()
def forward(
self,
x: Tensor,
pe: Tensor,
mod: tuple[Tensor, Tensor],
) -> Tensor:
mod_shift, mod_scale, mod_gate = mod
x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift
qkv, mlp = torch.split(
self.linear1(x_mod),
[3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
dim=-1,
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
attn = attention(q, k, v, pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod_gate * output
class DoubleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float,
):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
assert hidden_size % num_heads == 0, f"{hidden_size=} must be divisible by {num_heads=}"
self.hidden_size = hidden_size
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp_mult_factor = 2
self.img_attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
SiLUActivation(),
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(
dim=hidden_size,
num_heads=num_heads,
)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(
hidden_size,
mlp_hidden_dim * self.mlp_mult_factor,
bias=False,
),
SiLUActivation(),
nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
)
def forward(
self,
img: Tensor,
txt: Tensor,
pe: Tensor,
pe_ctx: Tensor,
mod_img: tuple[Tensor, Tensor],
mod_txt: tuple[Tensor, Tensor],
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = mod_img
txt_mod1, txt_mod2 = mod_txt
img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
pe = torch.cat((pe_ctx, pe), dim=2)
attn = attention(q, k, v, pe)
txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]
# calculate the img blocks
img = img + img_mod1_gate * self.img_attn.proj(img_attn)
img = img + img_mod2_gate * self.img_mlp(
(1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift
)
# calculate the txt blocks
txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2_gate * self.txt_mlp(
(1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift
)
return img, txt
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(len(self.axes_dim))],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scale applied to t before computing the embedding.
    :return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, device=t.device, dtype=torch.float32) / half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
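# Example (added for illustration): a batch of timesteps in [0, 1] is scaled by
# time_factor and expanded into concatenated cos/sin features.
#   >>> emb = timestep_embedding(torch.tensor([0.25, 0.5]), dim=256)
#   >>> emb.shape
#   torch.Size([2, 256])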
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * self.scale
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
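
A shape-check sketch using a deliberately tiny, hypothetical configuration (the release config above uses hidden_size=6144 with 8 double and 48 single blocks); the only hard constraint is that hidden_size / num_heads must equal sum(axes_dim).

import torch

params = Flux2Params(context_in_dim=64, hidden_size=256, num_heads=2, depth=1, depth_single_blocks=1)
model = Flux2(params).eval()
B, L_img, L_txt = 1, 16, 8
x = torch.randn(B, L_img, params.in_channels)
x_ids = torch.zeros(B, L_img, 4)      # (t, h, w, l) position per image token
ctx = torch.randn(B, L_txt, params.context_in_dim)
ctx_ids = torch.zeros(B, L_txt, 4)
with torch.no_grad():
    out = model(x=x, x_ids=x_ids, timesteps=torch.full((B,), 0.5),
                ctx=ctx, ctx_ids=ctx_ids, guidance=torch.full((B,), 4.0))
print(out.shape)                      # torch.Size([1, 16, 128])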

(new file, 129 lines)

@@ -0,0 +1,129 @@
"""OpenRouter API client for prompt upsampling."""
import os
from typing import Any
from openai import OpenAI
from PIL import Image
from .system_messages import SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I
from .util import image_to_base64
DEFAULT_SAMPLING_PARAMS = {"mistralai/pixtral-large-2411": dict(temperature=0.15)}
class OpenRouterAPIClient:
"""Client for OpenRouter API-based prompt upsampling."""
def __init__(
self,
sampling_params: dict[str, Any],
model: str = "mistralai/pixtral-large-2411",
max_tokens: int = 512,
):
"""
Initialize the OpenRouter API client.
        Args:
            sampling_params: Extra sampling parameters forwarded to the chat
                completions call (e.g. temperature).
            model: Model name to use for upsampling. Defaults to "mistralai/pixtral-large-2411".
                Can be any OpenRouter model (e.g., "mistralai/pixtral-large-2411",
                "qwen/qwen3-vl-235b-a22b-instruct", etc.)
            max_tokens: Maximum number of tokens to generate per upsampled prompt.
"""
self.api_key = os.environ["OPENROUTER_API_KEY"]
self.client = OpenAI(api_key=self.api_key, base_url="https://openrouter.ai/api/v1")
self.model = model
self.sampling_params = sampling_params
self.max_tokens = max_tokens
def _format_messages(
self,
prompt: str,
system_message: str,
images: list[Image.Image] | None = None,
) -> list[dict[str, str]]:
messages: list[dict[str, str]] = [
{"role": "system", "content": system_message},
]
if images:
content = []
for img in images:
img_base64 = image_to_base64(img)
content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{img_base64}",
},
}
)
content.append({"type": "text", "text": prompt})
messages.append({"role": "user", "content": content})
else:
messages.append({"role": "user", "content": prompt})
return messages
def upsample_prompt(
self,
txt: list[str],
img: list[Image.Image] | list[list[Image.Image]] | None = None,
) -> list[str]:
"""
Upsample prompts using OpenRouter API.
Args:
txt: List of input prompts to upsample
img: Optional list of images or list of lists of images.
If None or empty, uses t2i mode, otherwise i2i mode.
Returns:
List of upsampled prompts
"""
# Determine system message based on whether images are provided
has_images = img is not None and len(img) > 0
if has_images and isinstance(img[0], list):
has_images = len(img[0]) > 0
if has_images:
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
else:
system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
upsampled_prompts = []
# Process each prompt (potentially with images)
for i, prompt in enumerate(txt):
# Get images for this prompt
prompt_images: list[Image.Image] | None = None
if img is not None and len(img) > i:
if isinstance(img[i], list):
prompt_images = img[i] if len(img[i]) > 0 else None
elif isinstance(img[i], Image.Image):
prompt_images = [img[i]]
# Format messages
messages = self._format_messages(
prompt=prompt,
system_message=system_message,
images=prompt_images,
)
# Call API
try:
response = self.client.chat.completions.create(
model=self.model,
messages=messages,
max_tokens=self.max_tokens,
**self.sampling_params,
)
upsampled = response.choices[0].message.content.strip()
upsampled_prompts.append(upsampled)
except Exception as e:
print(f"Error upsampling prompt via OpenRouter API: {e}, returning original prompt")
upsampled_prompts.append(prompt)
return upsampled_prompts
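
A minimal usage sketch (requires OPENROUTER_API_KEY in the environment and network access; the prompt is only an example):

client = OpenRouterAPIClient(sampling_params=DEFAULT_SAMPLING_PARAMS["mistralai/pixtral-large-2411"])
upsampled = client.upsample_prompt(["a red bicycle leaning against a brick wall"])
print(upsampled[0])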

src/flux2/sampling.py (new file, 339 lines)

@@ -0,0 +1,339 @@
import math
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch import Tensor
from .model import Flux2
def compress_time(t_ids: Tensor) -> Tensor:
assert t_ids.ndim == 1
t_ids_max = torch.max(t_ids)
t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)
t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
t_remap[t_unique_sorted_ids] = torch.arange(
len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype
)
t_ids_compressed = t_remap[t_ids]
return t_ids_compressed
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
"""
    Scatter tokens back into dense (t, h, w) grids using their position ids.
"""
x_list = []
t_coords = []
for data, pos in zip(x, x_ids):
_, ch = data.shape # noqa: F841
t_ids = pos[:, 0].to(torch.int64)
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
t_ids_cmpr = compress_time(t_ids)
t = torch.max(t_ids_cmpr) + 1
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids
out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
t_coords.append(torch.unique(t_ids, sorted=True))
return x_list
def encode_image_refs(ae, img_ctx: list[Image.Image]):
scale = 10
if len(img_ctx) > 1:
limit_pixels = 1024**2
elif len(img_ctx) == 1:
limit_pixels = 2024**2
else:
limit_pixels = None
if not img_ctx:
return None, None
img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
if not isinstance(img_ctx_prep, list):
img_ctx_prep = [img_ctx_prep]
# Encode each reference image
encoded_refs = []
for img in img_ctx_prep:
encoded = ae.encode(img[None].cuda())[0]
encoded_refs.append(encoded)
# Create time offsets for each reference
t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
t_off = [t.view(-1) for t in t_off]
# Process with position IDs
ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)
# Concatenate all references along sequence dimension
ref_tokens = torch.cat(ref_tokens, dim=0) # (total_ref_tokens, C)
ref_ids = torch.cat(ref_ids, dim=0) # (total_ref_tokens, 4)
# Add batch dimension
ref_tokens = ref_tokens.unsqueeze(0) # (1, total_ref_tokens, C)
ref_ids = ref_ids.unsqueeze(0) # (1, total_ref_tokens, 4)
return ref_tokens.to(torch.bfloat16), ref_ids
def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
_l, _ = x.shape # noqa: F841
coords = {
"t": torch.arange(1) if t_coord is None else t_coord,
"h": torch.arange(1), # dummy dimension
"w": torch.arange(1), # dummy dimension
"l": torch.arange(_l),
}
x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
return x, x_ids.to(x.device)
def batched_wrapper(fn):
def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
results = []
for i in range(len(x)):
results.append(
fn(
x[i],
t_coord[i] if t_coord is not None else None,
)
)
x, x_ids = zip(*results)
return torch.stack(x), torch.stack(x_ids)
return batched_prc
def listed_wrapper(fn):
def listed_prc(
x: list[Tensor],
t_coord: list[Tensor] | None = None,
) -> tuple[list[Tensor], list[Tensor]]:
results = []
for i in range(len(x)):
results.append(
fn(
x[i],
t_coord[i] if t_coord is not None else None,
)
)
x, x_ids = zip(*results)
return list(x), list(x_ids)
return listed_prc
def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
_, h, w = x.shape # noqa: F841
x_coords = {
"t": torch.arange(1) if t_coord is None else t_coord,
"h": torch.arange(h),
"w": torch.arange(w),
"l": torch.arange(1),
}
x_ids = torch.cartesian_prod(x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"])
x = rearrange(x, "c h w -> (h w) c")
return x, x_ids.to(x.device)
listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)
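# Example (added for illustration): prc_img flattens a (c, h, w) latent into h*w
# tokens and pairs each token with a (t, h, w, l) coordinate.
#   >>> tokens, ids = prc_img(torch.randn(4, 2, 3))
#   >>> tokens.shape, ids.shape
#   (torch.Size([6, 4]), torch.Size([6, 4]))
#   >>> ids[:3]
#   tensor([[0, 0, 0, 0],
#           [0, 0, 1, 0],
#           [0, 0, 2, 0]])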
def center_crop_to_multiple_of_x(
img: Image.Image | list[Image.Image], x: int
) -> Image.Image | list[Image.Image]:
if isinstance(img, list):
return [center_crop_to_multiple_of_x(_img, x) for _img in img] # type: ignore
w, h = img.size
new_w = (w // x) * x
new_h = (h // x) * x
left = (w - new_w) // 2
top = (h - new_h) // 2
right = left + new_w
bottom = top + new_h
resized = img.crop((left, top, right, bottom))
return resized
def cap_pixels(img: Image.Image | list[Image.Image], k):
if isinstance(img, list):
return [cap_pixels(_img, k) for _img in img]
w, h = img.size
pixel_count = w * h
if pixel_count <= k:
return img
# Scaling factor to reduce total pixels below K
scale = math.sqrt(k / pixel_count)
new_w = int(w * scale)
new_h = int(h * scale)
return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
if isinstance(img, list):
return [cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) for _img in img]
w, h = img.size
if w < min_sidelength or h < min_sidelength:
raise ValueError(f"Skipping due to minimal sidelength underschritten h {h} w {w}")
if w / h > max_ar or h / w > max_ar:
raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")
return img
def to_rgb(img: Image.Image | list[Image.Image]):
if isinstance(img, list):
return [
to_rgb(
_img,
)
for _img in img
]
return img.convert("RGB")
def default_images_prep(
x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
if isinstance(x, list):
return [default_images_prep(e) for e in x] # type: ignore
x_tensor = torchvision.transforms.ToTensor()(x)
return 2 * x_tensor - 1
def default_prep(
img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16
) -> torch.Tensor | list[torch.Tensor]:
img_rgb = to_rgb(img)
img_min = cap_min_pixels(img_rgb) # type: ignore
if limit_pixels is not None:
img_cap = cap_pixels(img_min, limit_pixels) # type: ignore
else:
img_cap = img_min
img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple) # type: ignore
img_tensor = default_images_prep(img_crop)
return img_tensor
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
mu = compute_empirical_mu(image_seq_len, num_steps)
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
return timesteps.tolist()
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
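# Example (added for illustration): a 1024x1024 generation corresponds to a
# 64x64 latent grid (16x total compression), i.e. image_seq_len = 4096 tokens.
#   >>> schedule = get_schedule(num_steps=50, image_seq_len=4096)
#   >>> len(schedule), schedule[0], schedule[-1]
#   (51, 1.0, 0.0)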
def denoise(
model: Flux2,
# model input
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
# sampling parameters
timesteps: list[float],
guidance: float,
# extra img tokens (sequence-wise)
img_cond_seq: Tensor | None = None,
img_cond_seq_ids: Tensor | None = None,
):
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
img_input = img
img_input_ids = img_ids
if img_cond_seq is not None:
assert (
img_cond_seq_ids is not None
), "You need to provide either both or neither of the sequence conditioning"
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = model(
x=img_input,
x_ids=img_input_ids,
timesteps=t_vec,
ctx=txt,
ctx_ids=txt_ids,
guidance=guidance_vec,
)
if img_input_ids is not None:
pred = pred[:, : img.shape[1]]
img = img + (t_prev - t_curr) * pred
return img
def concatenate_images(
images: list[Image.Image],
) -> Image.Image:
"""
Concatenate a list of PIL images horizontally with center alignment and white background.
"""
# If only one image, return a copy of it
if len(images) == 1:
return images[0].copy()
# Convert all images to RGB if not already
images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]
# Calculate dimensions for horizontal concatenation
total_width = sum(img.width for img in images)
max_height = max(img.height for img in images)
# Create new image with white background
background_color = (255, 255, 255)
new_img = Image.new("RGB", (total_width, max_height), background_color)
# Paste images with center alignment
x_offset = 0
for img in images:
y_offset = (max_height - img.height) // 2
new_img.paste(img, (x_offset, y_offset))
x_offset += img.width
return new_img
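
Putting the pieces of this file together, a hedged text-to-image sketch; it assumes a loaded Flux2 model (model), autoencoder (ae) and text embeddings (txt, txt_ids) prepared elsewhere, and the step count and guidance value are illustrative.

h_lat, w_lat = 64, 64                                       # 1024x1024 output at 16x total compression
img, img_ids = prc_img(torch.randn(128, h_lat, w_lat))      # (4096, 128), (4096, 4)
img = img[None].to("cuda", torch.bfloat16)
img_ids = img_ids[None].float().to("cuda")                  # position ids as floats for the RoPE embedder
timesteps = get_schedule(num_steps=50, image_seq_len=img.shape[1])
x = denoise(model, img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
            timesteps=timesteps, guidance=4.0)
decoded = ae.decode(rearrange(x[0], "(h w) c -> 1 c h w", h=h_lat, w=w_lat))  # (1, 3, 1024, 1024)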

(new file, 82 lines)

@@ -0,0 +1,82 @@
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation."""
SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
Guidelines:
1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
Output only the revised prompt and nothing else."""
SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
Rules:
- Single instruction only, no commentary
- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
- Specify what changes AND what stays the same (face, lighting, composition)
- Reference actual image elements
- Turn negatives into positives ("don't change X" → "keep X")
- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
- Keep content PG-13
Output only the final instruction in plain text and nothing else."""
SYSTEM_PROMPT_CONTENT_FILTER = """
You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.
Your task is identifying images and text for copyright concerns and depictions of public personas.
"""
PROMPT_IMAGE_INTEGRITY = """
Task: Analyze an image to identify potential copyright concerns or depictions of public figures.
Output: Respond with only "yes" or "no"
Criteria for "yes":
- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.)
- The image displays a trademarked logo or brand
- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.)
Criteria for "no":
- All other cases
- When you cannot identify the specific copyrighted work or named individual
Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person depicted
2. General references to demographics or characteristics are not sufficient
3. Base your decision solely on visual content, not interpretation
4. Provide only the one-word answer: "yes" or "no"
""".strip()
PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or include public figures?"
PROMPT_TEXT_INTEGRITY = """
Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures.
Output: Respond with only "yes" or "no"
Criteria for "Yes":
- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.)
- The prompt explicitly mentions a trademarked logo or brand
- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.)
Criteria for "No":
- All other cases
- When you cannot identify the specific copyrighted work or named individual
Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person referenced
2. General demographic descriptions or characteristics are not sufficient
3. Analyze only the prompt text, not potential image outcomes
4. Provide only the one-word answer: "yes" or "no"
The prompt to check is:
-----
{prompt}
-----
Does this prompt have copyright concerns or include public figures?
""".strip()

src/flux2/text_encoder.py (new file, 356 lines)

@@ -0,0 +1,356 @@
from pathlib import Path
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, pipeline
from .sampling import cap_pixels, concatenate_images
from .system_messages import (
PROMPT_IMAGE_INTEGRITY,
PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
PROMPT_TEXT_INTEGRITY,
SYSTEM_MESSAGE,
SYSTEM_MESSAGE_UPSAMPLING_I2I,
SYSTEM_MESSAGE_UPSAMPLING_T2I,
SYSTEM_PROMPT_CONTENT_FILTER,
)
OUTPUT_LAYERS = [10, 20, 30]
MAX_LENGTH = 512
NSFW_THRESHOLD = 0.85
UPSAMPLING_MAX_IMAGE_SIZE = 768**2
class Mistral3SmallEmbedder(nn.Module):
def __init__(
self,
model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
torch_dtype: str = "bfloat16",
):
super().__init__()
self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
model_spec,
torch_dtype=getattr(torch, torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
self.yes_token, self.no_token = self.processor.tokenizer.encode(
["yes", "no"], add_special_tokens=False
)
self.max_length = MAX_LENGTH
self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
self.nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
def _validate_and_process_images(
self, img: list[list[Image.Image]] | list[Image.Image]
) -> list[list[Image.Image]]:
# Simple validation: ensure it's a list of PIL images or list of lists of PIL images
if not img:
return []
# Check if it's a list of lists or a list of images
if isinstance(img[0], Image.Image):
# It's a list of images, convert to list of lists
img = [[im] for im in img]
# potentially concatenate multiple images to reduce the size
img = [[concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in img]
# cap the pixels
img = [[cap_pixels(img_i, self.upsampling_max_image_size) for img_i in img_i] for img_i in img]
return img
def format_input(
self,
txt: list[str],
system_message: str = SYSTEM_MESSAGE,
img: list[Image.Image] | list[list[Image.Image]] | None = None,
) -> list[list[dict]]:
"""
Format a batch of text prompts into the conversation format expected by apply_chat_template.
Optionally, add images to the input.
Args:
txt: List of text prompts
            system_message: System message to use (default: SYSTEM_MESSAGE)
img: List of images to add to the input.
Returns:
List of conversations, where each conversation is a list of message dicts
"""
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
# when truncation is enabled. The processor counts [IMG] tokens and fails
# if the count changes after truncation.
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]
if img is None or len(img) == 0:
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
else:
assert len(img) == len(txt), "Number of images must match number of prompts"
img = self._validate_and_process_images(img)
messages = [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
]
for _ in cleaned_txt
]
for i, (el, images) in enumerate(zip(messages, img)):
# optionally add the images per batch element.
if images is not None:
el.append(
{
"role": "user",
"content": [{"type": "image", "image": image_obj} for image_obj in images],
}
)
# add the text.
el.append(
{
"role": "user",
"content": [{"type": "text", "text": cleaned_txt[i]}],
}
)
return messages
@torch.no_grad()
def upsample_prompt(
self,
txt: list[str],
img: list[Image.Image] | list[list[Image.Image]] | None = None,
temperature: float = 0.15,
) -> list[str]:
"""
Upsample prompts using the model's generate method.
Args:
txt: List of input prompts to upsample
img: Optional list of images or list of lists of images. If None or all None, uses t2i mode, otherwise i2i mode.
Returns:
List of upsampled prompts
"""
# Set system message based on whether images are provided
if img is None or len(img) == 0 or img[0] is None:
system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
else:
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
# Format input messages
messages_batch = self.format_input(txt=txt, system_message=system_message, img=img)
# Process all messages at once
        # with image inputs, a too-short max_length can raise an error here.
try:
inputs = self.processor.apply_chat_template(
messages_batch,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=2048,
)
except ValueError as e:
print(
f"Error processing input: {e}, your max length is probably too short, when you have images in the input."
)
raise e
# Move to device
inputs["input_ids"] = inputs["input_ids"].to(self.model.device)
inputs["attention_mask"] = inputs["attention_mask"].to(self.model.device)
if "pixel_values" in inputs:
inputs["pixel_values"] = inputs["pixel_values"].to(self.model.device, self.model.dtype)
# Generate text using the model's generate method
try:
generated_ids = self.model.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=temperature,
use_cache=True,
)
# Decode only the newly generated tokens (skip input tokens)
# Extract only the generated portion
input_length = inputs["input_ids"].shape[1]
generated_tokens = generated_ids[:, input_length:]
raw_txt = self.processor.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return raw_txt
except Exception as e:
print(f"Error generating upsampled prompt: {e}, returning original prompt")
return txt
@torch.no_grad()
def forward(self, txt: list[str]):
# Format input messages
messages_batch = self.format_input(txt=txt)
# Process all messages at once
        # with image inputs, a too-short max_length can raise an error here.
inputs = self.processor.apply_chat_template(
messages_batch,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_length,
)
# Move to device
input_ids = inputs["input_ids"].to(self.model.device)
attention_mask = inputs["attention_mask"].to(self.model.device)
# Forward pass through the model
output = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
return rearrange(out, "b c l d -> b l (c d)")
def yes_no_logit_processor(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
"""
Sets all tokens but yes/no to the minimum.
"""
scores_yes_token = scores[:, self.yes_token].clone()
scores_no_token = scores[:, self.no_token].clone()
scores_min = scores.min()
scores[:, :] = scores_min - 1
scores[:, self.yes_token] = scores_yes_token
scores[:, self.no_token] = scores_no_token
return scores
def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
if isinstance(image, torch.Tensor):
image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
elif isinstance(image, (str, Path)):
image = Image.open(image)
classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
if classification["score"] > NSFW_THRESHOLD:
return True
# 512^2 pixels are enough for checking
w, h = image.size
f = (512**2 / (w * h)) ** 0.5
image = image.resize((int(f * w), int(f * h)))
chat = [
{
"role": "system",
"content": [
{
"type": "text",
"text": SYSTEM_PROMPT_CONTENT_FILTER,
},
],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": PROMPT_IMAGE_INTEGRITY,
},
{
"type": "image",
"image": image,
},
{
"type": "text",
"text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
},
],
},
]
inputs = self.processor.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.model.device)
inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)
generate_ids = self.model.generate(
**inputs,
max_new_tokens=1,
logits_processor=[self.yes_no_logit_processor],
do_sample=False,
)
return generate_ids[0, -1].item() == self.yes_token
def test_txt(self, txt: str) -> bool:
chat = [
{
"role": "system",
"content": [
{
"type": "text",
"text": SYSTEM_PROMPT_CONTENT_FILTER,
},
],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
},
],
},
]
inputs = self.processor.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(self.model.device)
generate_ids = self.model.generate(
**inputs,
max_new_tokens=1,
logits_processor=[self.yes_no_logit_processor],
do_sample=False,
)
return generate_ids[0, -1].item() == self.yes_token
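
A hedged usage sketch (this downloads the ~24B-parameter Mistral weights): forward() concatenates hidden states from layers 10/20/30, so the output feature size is three times the LM hidden size (assuming a 5120-dim backbone, that matches the 15360 expected by Flux2Params.context_in_dim).

embedder = Mistral3SmallEmbedder().to("cuda")
ctx = embedder(["a watercolor painting of a lighthouse at dusk"])
print(ctx.shape)                                            # (1, 512, 15360) with MAX_LENGTH = 512
print(embedder.test_txt("a portrait of a famous living politician"))  # True if the filter answers "yes"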

src/flux2/util.py (new file, 105 lines)

@@ -0,0 +1,105 @@
import base64
import io
import os
import sys
import huggingface_hub
import torch
from PIL import Image
from safetensors.torch import load_file as load_sft
from .autoencoder import AutoEncoder, AutoEncoderParams
from .model import Flux2, Flux2Params
from .text_encoder import Mistral3SmallEmbedder
FLUX2_MODEL_INFO = {
"flux.2-dev": {
"repo_id": "black-forest-labs/FLUX.2-dev",
"filename": "flux2-dev.safetensors",
"filename_ae": "ae.safetensors",
"params": Flux2Params(),
}
}
def load_flow_model(model_name: str, debug_mode: bool = False, device: str | torch.device = "cuda") -> Flux2:
config = FLUX2_MODEL_INFO[model_name.lower()]
if debug_mode:
config["params"].depth = 1
config["params"].depth_single_blocks = 1
else:
if "FLUX2_MODEL_PATH" in os.environ:
weight_path = os.environ["FLUX2_MODEL_PATH"]
assert os.path.exists(weight_path), f"Provided weight path {weight_path} does not exist"
else:
# download from huggingface
try:
weight_path = huggingface_hub.hf_hub_download(
repo_id=config["repo_id"],
filename=config["filename"],
repo_type="model",
)
except huggingface_hub.errors.RepositoryNotFoundError:
print(
f"Failed to access the model repository. Please check your internet "
f"connection and make sure you've access to {config['repo_id']}."
"Stopping."
)
sys.exit(1)
if not debug_mode:
with torch.device("meta"):
model = Flux2(FLUX2_MODEL_INFO[model_name.lower()]["params"]).to(torch.bfloat16)
print(f"Loading {weight_path} for the FLUX.2 weights")
sd = load_sft(weight_path, device=str(device))
model.load_state_dict(sd, strict=False, assign=True)
return model.to(device)
else:
with torch.device(device):
return Flux2(FLUX2_MODEL_INFO[model_name.lower()]["params"]).to(torch.bfloat16)
def load_mistral_small_embedder(device: str | torch.device = "cuda") -> Mistral3SmallEmbedder:
return Mistral3SmallEmbedder().to(device)
def load_ae(model_name: str, device: str | torch.device = "cuda") -> AutoEncoder:
config = FLUX2_MODEL_INFO[model_name.lower()]
if "AE_MODEL_PATH" in os.environ:
weight_path = os.environ["AE_MODEL_PATH"]
assert os.path.exists(weight_path), f"Provided weight path {weight_path} does not exist"
else:
# download from huggingface
try:
weight_path = huggingface_hub.hf_hub_download(
repo_id=config["repo_id"],
filename=config["filename_ae"],
repo_type="model",
)
except huggingface_hub.errors.RepositoryNotFoundError:
print(
f"Failed to access the model repository. Please check your internet "
f"connection and make sure you've access to {config['repo_id']}."
"Stopping."
)
sys.exit(1)
if isinstance(device, str):
device = torch.device(device)
with torch.device("meta"):
ae = AutoEncoder(AutoEncoderParams())
print(f"Loading {weight_path} for the AutoEncoder weights")
sd = load_sft(weight_path, device=str(device))
ae.load_state_dict(sd, strict=True, assign=True)
return ae.to(device)
def image_to_base64(image: Image.Image) -> str:
"""Convert PIL Image to base64 string."""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str
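
A loading sketch (requires access to the black-forest-labs/FLUX.2-dev repository on Hugging Face, or FLUX2_MODEL_PATH / AE_MODEL_PATH pointing at local safetensors files):

model = load_flow_model("flux.2-dev", device="cuda")
ae = load_ae("flux.2-dev", device="cuda")
text_encoder = load_mistral_small_embedder(device="cuda")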

src/flux2/watermark.py (new file, 47 lines)

@@ -0,0 +1,47 @@
import torch
from einops import rearrange
from imwatermark import WatermarkEncoder
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, RGB, H, W) in range [-1, 1]
Returns:
same as input but watermarked
"""
image = 0.5 * image + 0.5
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
        # watermarking library expects input as cv2 BGR format
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
image.device
)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
image = 2 * image - 1
return image
# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
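
Usage sketch: embed_watermark takes image tensors in [-1, 1] and returns a watermarked tensor of the same shape and range.

import torch

imgs = torch.rand(2, 3, 256, 256) * 2 - 1     # (B, RGB, H, W) in [-1, 1]
imgs_wm = embed_watermark(imgs)               # same shape, watermark embedded via dwtDct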