FLUX.2 launch
src/flux2/autoencoder.py (new file, 336 lines)
@@ -0,0 +1,336 @@
import math
from dataclasses import dataclass, field

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int = 256
    in_channels: int = 3
    ch: int = 128
    out_ch: int = 3
    ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
    num_res_blocks: int = 2
    z_channels: int = 32


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int | None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h = self.quant_conv(h)
        return h


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        z = self.post_quant_conv(z)

        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class AutoEncoder(nn.Module):
    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )

        self.bn_eps = 1e-4
        self.bn_momentum = 0.1
        self.ps = [2, 2]
        self.bn = torch.nn.BatchNorm2d(
            math.prod(self.ps) * params.z_channels,
            eps=self.bn_eps,
            momentum=self.bn_momentum,
            affine=False,
            track_running_stats=True,
        )

    def normalize(self, z):
        self.bn.eval()
        return self.bn(z)

    def inv_normalize(self, z):
        self.bn.eval()
        s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)
        m = self.bn.running_mean.view(1, -1, 1, 1)
        return z * s + m

    def encode(self, x: Tensor) -> Tensor:
        moments = self.encoder(x)
        mean = torch.chunk(moments, 2, dim=1)[0]

        z = rearrange(
            mean,
            "... c (i pi) (j pj) -> ... (c pi pj) i j",
            pi=self.ps[0],
            pj=self.ps[1],
        )
        z = self.normalize(z)
        return z

    def decode(self, z: Tensor) -> Tensor:
        z = self.inv_normalize(z)
        z = rearrange(
            z,
            "... (c pi pj) i j -> ... c (i pi) (j pj)",
            pi=self.ps[0],
            pj=self.ps[1],
        )
        dec = self.decoder(z)
        return dec
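
Example (not part of the commit): a minimal shape round-trip for the autoencoder above, assuming randomly initialized weights (no checkpoint) and that the package is importable as flux2. The four-entry ch_mult gives three 2x downsamples (8x total), and the 2x2 latent packing in encode() adds another 2x, so a 256x256 RGB image maps to a latent with 4 * z_channels = 128 channels at 16x16.

import torch

from flux2.autoencoder import AutoEncoder, AutoEncoderParams

ae = AutoEncoder(AutoEncoderParams()).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)  # image in [-1, 1]
    z = ae.encode(x)                 # (1, 128, 16, 16): 8x from downsampling, 2x from packing
    assert z.shape == (1, 4 * ae.params.z_channels, 16, 16)
    x_rec = ae.decode(z)             # back to the input resolution
    assert x_rec.shape == x.shape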
src/flux2/model.py (new file, 451 lines)
@@ -0,0 +1,451 @@
import math
from dataclasses import dataclass, field

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class Flux2Params:
    in_channels: int = 128
    context_in_dim: int = 15360
    hidden_size: int = 6144
    num_heads: int = 48
    depth: int = 8
    depth_single_blocks: int = 48
    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
    theta: int = 2000
    mlp_ratio: float = 3.0


class Flux2(nn.Module):
    def __init__(self, params: Flux2Params):
        super().__init__()

        self.in_channels = params.in_channels
        self.out_channels = params.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.double_stream_modulation_img = Modulation(
            self.hidden_size,
            double=True,
            disable_bias=True,
        )
        self.double_stream_modulation_txt = Modulation(
            self.hidden_size,
            double=True,
            disable_bias=True,
        )
        self.single_stream_modulation = Modulation(self.hidden_size, double=False, disable_bias=True)

        self.final_layer = LastLayer(
            self.hidden_size,
            self.out_channels,
        )

    def forward(
        self,
        x: Tensor,
        x_ids: Tensor,
        timesteps: Tensor,
        ctx: Tensor,
        ctx_ids: Tensor,
        guidance: Tensor,
    ):
        num_txt_tokens = ctx.shape[1]

        timestep_emb = timestep_embedding(timesteps, 256)
        vec = self.time_in(timestep_emb)
        guidance_emb = timestep_embedding(guidance, 256)
        vec = vec + self.guidance_in(guidance_emb)

        double_block_mod_img = self.double_stream_modulation_img(vec)
        double_block_mod_txt = self.double_stream_modulation_txt(vec)
        single_block_mod, _ = self.single_stream_modulation(vec)

        img = self.img_in(x)
        txt = self.txt_in(ctx)

        pe_x = self.pe_embedder(x_ids)
        pe_ctx = self.pe_embedder(ctx_ids)

        for block in self.double_blocks:
            img, txt = block(
                img,
                txt,
                pe_x,
                pe_ctx,
                double_block_mod_img,
                double_block_mod_txt,
            )

        img = torch.cat((txt, img), dim=1)
        pe = torch.cat((pe_ctx, pe_x), dim=2)

        for block in self.single_blocks:
            img = block(
                img,
                pe,
                single_block_mod,
            )

        img = img[:, num_txt_tokens:, ...]

        img = self.final_layer(img, vec)
        return img


class SelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)

        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim, bias=False)


class SiLUActivation(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return self.gate_fn(x1) * x2


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool, disable_bias: bool = False):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=not disable_bias)

    def forward(self, vec: torch.Tensor):
        out = self.lin(nn.functional.silu(vec))
        if out.ndim == 2:
            out = out[:, None, :]
        out = out.chunk(self.multiplier, dim=-1)
        # first triple is (shift, scale, gate); the second triple only exists for double blocks
        return out[:3], (out[3:] if self.is_double else None)


class LastLayer(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        out_channels: int,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=False)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=False))

    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
        mod = self.adaLN_modulation(vec)
        shift, scale = mod.chunk(2, dim=-1)
        if shift.ndim == 2:
            shift = shift[:, None, :]
            scale = scale[:, None, :]
        x = (1 + scale) * self.norm_final(x) + shift
        x = self.linear(x)
        return x


class SingleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = head_dim**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp_mult_factor = 2

        self.linear1 = nn.Linear(
            hidden_size,
            hidden_size * 3 + self.mlp_hidden_dim * self.mlp_mult_factor,
            bias=False,
        )

        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=False)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = SiLUActivation()

    def forward(
        self,
        x: Tensor,
        pe: Tensor,
        mod: tuple[Tensor, Tensor, Tensor],
    ) -> Tensor:
        mod_shift, mod_scale, mod_gate = mod
        x_mod = (1 + mod_scale) * self.pre_norm(x) + mod_shift

        qkv, mlp = torch.split(
            self.linear1(x_mod),
            [3 * self.hidden_size, self.mlp_hidden_dim * self.mlp_mult_factor],
            dim=-1,
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        attn = attention(q, k, v, pe)

        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod_gate * output


class DoubleStreamBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
    ):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        assert hidden_size % num_heads == 0, f"{hidden_size=} must be divisible by {num_heads=}"

        self.hidden_size = hidden_size
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_mult_factor = 2

        self.img_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim * self.mlp_mult_factor, bias=False),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(
                hidden_size,
                mlp_hidden_dim * self.mlp_mult_factor,
                bias=False,
            ),
            SiLUActivation(),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=False),
        )

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        pe_ctx: Tensor,
        mod_img: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
        mod_txt: tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]],
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = mod_img
        txt_mod1, txt_mod2 = mod_txt

        img_mod1_shift, img_mod1_scale, img_mod1_gate = img_mod1
        img_mod2_shift, img_mod2_scale, img_mod2_gate = img_mod2
        txt_mod1_shift, txt_mod1_scale, txt_mod1_gate = txt_mod1
        txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = txt_mod2

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift

        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift

        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        pe = torch.cat((pe_ctx, pe), dim=2)
        attn = attention(q, k, v, pe)
        txt_attn, img_attn = attn[:, : txt_q.shape[2]], attn[:, txt_q.shape[2] :]

        # calculate the img blocks
        img = img + img_mod1_gate * self.img_attn.proj(img_attn)
        img = img + img_mod2_gate * self.img_mlp(
            (1 + img_mod2_scale) * (self.img_norm2(img)) + img_mod2_shift
        )

        # calculate the txt blocks
        txt = txt + txt_mod1_gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2_gate * self.txt_mlp(
            (1 + txt_mod2_scale) * (self.txt_norm2(txt)) + txt_mod2_shift
        )
        return img, txt


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, disable_bias: bool = False):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=not disable_bias)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=not disable_bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(len(self.axes_dim))],
            dim=-3,
        )

        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: scaling applied to t before embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, device=t.device, dtype=torch.float32) / half
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms).to(dtype=x_dtype) * self.scale


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
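
Example (not part of the commit): a self-contained sanity check of the RoPE helpers and the timestep embedding above on toy shapes; the flux2 import path is an assumption. Each of the four id axes contributes axes_dim[i]/2 rotation frequencies, so the head dimension D must equal sum(axes_dim).

import torch

from flux2.model import EmbedND, attention, timestep_embedding

B, H, L, D = 1, 4, 10, 128                  # D == sum(axes_dim)
emb = EmbedND(dim=D, theta=2000, axes_dim=[32, 32, 32, 32])

ids = torch.zeros(B, L, 4)                  # one position id per axis
ids[..., 3] = torch.arange(L)               # vary only the last axis
pe = emb(ids)                               # (B, 1, L, D/2, 2, 2) rotation matrices

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
out = attention(q, k, v, pe)                # heads merged back: (B, L, H*D)
assert out.shape == (B, L, H * D)

t_emb = timestep_embedding(torch.tensor([0.5]), 256)
assert t_emb.shape == (1, 256)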
src/flux2/openrouter_api_client.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""OpenRouter API client for prompt upsampling."""

import os
from typing import Any

from openai import OpenAI
from PIL import Image

from .system_messages import SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I
from .util import image_to_base64

DEFAULT_SAMPLING_PARAMS = {"mistralai/pixtral-large-2411": dict(temperature=0.15)}


class OpenRouterAPIClient:
    """Client for OpenRouter API-based prompt upsampling."""

    def __init__(
        self,
        sampling_params: dict[str, Any],
        model: str = "mistralai/pixtral-large-2411",
        max_tokens: int = 512,
    ):
        """
        Initialize the OpenRouter API client.

        Args:
            sampling_params: Extra sampling parameters passed to the chat completion call
                (e.g., temperature).
            model: Model name to use for upsampling. Defaults to "mistralai/pixtral-large-2411".
                Can be any OpenRouter model (e.g., "mistralai/pixtral-large-2411",
                "qwen/qwen3-vl-235b-a22b-instruct", etc.)
            max_tokens: Maximum number of tokens to generate per completion.
        """

        self.api_key = os.environ["OPENROUTER_API_KEY"]
        self.client = OpenAI(api_key=self.api_key, base_url="https://openrouter.ai/api/v1")
        self.model = model
        self.sampling_params = sampling_params
        self.max_tokens = max_tokens

    def _format_messages(
        self,
        prompt: str,
        system_message: str,
        images: list[Image.Image] | None = None,
    ) -> list[dict[str, Any]]:
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": system_message},
        ]

        if images:
            content = []

            for img in images:
                img_base64 = image_to_base64(img)
                content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{img_base64}",
                        },
                    }
                )
            content.append({"type": "text", "text": prompt})
            messages.append({"role": "user", "content": content})
        else:
            messages.append({"role": "user", "content": prompt})

        return messages

    def upsample_prompt(
        self,
        txt: list[str],
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
    ) -> list[str]:
        """
        Upsample prompts using OpenRouter API.

        Args:
            txt: List of input prompts to upsample
            img: Optional list of images or list of lists of images.
                If None or empty, uses t2i mode, otherwise i2i mode.

        Returns:
            List of upsampled prompts
        """
        # Determine system message based on whether images are provided
        has_images = img is not None and len(img) > 0
        if has_images and isinstance(img[0], list):
            has_images = len(img[0]) > 0

        if has_images:
            system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
        else:
            system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I

        upsampled_prompts = []

        # Process each prompt (potentially with images)
        for i, prompt in enumerate(txt):
            # Get images for this prompt
            prompt_images: list[Image.Image] | None = None
            if img is not None and len(img) > i:
                if isinstance(img[i], list):
                    prompt_images = img[i] if len(img[i]) > 0 else None
                elif isinstance(img[i], Image.Image):
                    prompt_images = [img[i]]

            # Format messages
            messages = self._format_messages(
                prompt=prompt,
                system_message=system_message,
                images=prompt_images,
            )

            # Call API
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    **self.sampling_params,
                )

                upsampled = response.choices[0].message.content.strip()
                upsampled_prompts.append(upsampled)
            except Exception as e:
                print(f"Error upsampling prompt via OpenRouter API: {e}, returning original prompt")
                upsampled_prompts.append(prompt)

        return upsampled_prompts
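
Example (not part of the commit): a hedged usage sketch of the client above. It requires OPENROUTER_API_KEY in the environment and network access, so treat it as illustrative rather than a test; on API errors the client falls back to the original prompt.

from flux2.openrouter_api_client import DEFAULT_SAMPLING_PARAMS, OpenRouterAPIClient

model = "mistralai/pixtral-large-2411"
client = OpenRouterAPIClient(
    sampling_params=DEFAULT_SAMPLING_PARAMS[model],
    model=model,
)

# t2i mode (no images); pass img=... with PIL images for i2i mode
prompts = client.upsample_prompt(txt=["a cat on a windowsill"])
print(prompts[0])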
src/flux2/sampling.py (new file, 339 lines)
@@ -0,0 +1,339 @@
import math

import torch
import torchvision
from einops import rearrange
from PIL import Image
from torch import Tensor

from .model import Flux2


def compress_time(t_ids: Tensor) -> Tensor:
    assert t_ids.ndim == 1
    t_ids_max = torch.max(t_ids)
    t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)
    t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
    t_remap[t_unique_sorted_ids] = torch.arange(
        len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype
    )
    t_ids_compressed = t_remap[t_ids]
    return t_ids_compressed


def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
    """
    Use position ids to scatter tokens back into a dense (t, h, w) grid.
    """
    x_list = []
    t_coords = []
    for data, pos in zip(x, x_ids):
        _, ch = data.shape  # noqa: F841
        t_ids = pos[:, 0].to(torch.int64)
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)

        t_ids_cmpr = compress_time(t_ids)

        t = torch.max(t_ids_cmpr) + 1
        h = torch.max(h_ids) + 1
        w = torch.max(w_ids) + 1

        flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids

        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)

        x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
        t_coords.append(torch.unique(t_ids, sorted=True))
    return x_list


def encode_image_refs(ae, img_ctx: list[Image.Image]):
    scale = 10

    if len(img_ctx) > 1:
        limit_pixels = 1024**2
    elif len(img_ctx) == 1:
        limit_pixels = 2024**2
    else:
        limit_pixels = None

    if not img_ctx:
        return None, None

    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(img_ctx_prep, list):
        img_ctx_prep = [img_ctx_prep]

    # Encode each reference image
    encoded_refs = []
    for img in img_ctx_prep:
        encoded = ae.encode(img[None].cuda())[0]
        encoded_refs.append(encoded)

    # Create time offsets for each reference
    t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
    t_off = [t.view(-1) for t in t_off]

    # Process with position IDs
    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)

    # Concatenate all references along sequence dimension
    ref_tokens = torch.cat(ref_tokens, dim=0)  # (total_ref_tokens, C)
    ref_ids = torch.cat(ref_ids, dim=0)  # (total_ref_tokens, 4)

    # Add batch dimension
    ref_tokens = ref_tokens.unsqueeze(0)  # (1, total_ref_tokens, C)
    ref_ids = ref_ids.unsqueeze(0)  # (1, total_ref_tokens, 4)

    return ref_tokens.to(torch.bfloat16), ref_ids


def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    _l, _ = x.shape  # noqa: F841

    coords = {
        "t": torch.arange(1) if t_coord is None else t_coord,
        "h": torch.arange(1),  # dummy dimension
        "w": torch.arange(1),  # dummy dimension
        "l": torch.arange(_l),
    }
    x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
    return x, x_ids.to(x.device)


def batched_wrapper(fn):
    def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
        results = []
        for i in range(len(x)):
            results.append(
                fn(
                    x[i],
                    t_coord[i] if t_coord is not None else None,
                )
            )
        x, x_ids = zip(*results)
        return torch.stack(x), torch.stack(x_ids)

    return batched_prc


def listed_wrapper(fn):
    def listed_prc(
        x: list[Tensor],
        t_coord: list[Tensor] | None = None,
    ) -> tuple[list[Tensor], list[Tensor]]:
        results = []
        for i in range(len(x)):
            results.append(
                fn(
                    x[i],
                    t_coord[i] if t_coord is not None else None,
                )
            )
        x, x_ids = zip(*results)
        return list(x), list(x_ids)

    return listed_prc


def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    _, h, w = x.shape  # noqa: F841
    x_coords = {
        "t": torch.arange(1) if t_coord is None else t_coord,
        "h": torch.arange(h),
        "w": torch.arange(w),
        "l": torch.arange(1),
    }
    x_ids = torch.cartesian_prod(x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"])
    x = rearrange(x, "c h w -> (h w) c")
    return x, x_ids.to(x.device)


listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)


def center_crop_to_multiple_of_x(
    img: Image.Image | list[Image.Image], x: int
) -> Image.Image | list[Image.Image]:
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(_img, x) for _img in img]  # type: ignore

    w, h = img.size
    new_w = (w // x) * x
    new_h = (h // x) * x

    left = (w - new_w) // 2
    top = (h - new_h) // 2
    right = left + new_w
    bottom = top + new_h

    resized = img.crop((left, top, right, bottom))
    return resized


def cap_pixels(img: Image.Image | list[Image.Image], k):
    if isinstance(img, list):
        return [cap_pixels(_img, k) for _img in img]
    w, h = img.size
    pixel_count = w * h

    if pixel_count <= k:
        return img

    # Scaling factor to reduce total pixels below k
    scale = math.sqrt(k / pixel_count)
    new_w = int(w * scale)
    new_h = int(h * scale)

    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)


def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    if isinstance(img, list):
        return [cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) for _img in img]
    w, h = img.size
    if w < min_sidelength or h < min_sidelength:
        raise ValueError(f"Skipping image: side length below minimum of {min_sidelength} (h={h}, w={w})")
    if w / h > max_ar or h / w > max_ar:
        raise ValueError(f"Skipping image: aspect ratio above maximum of {max_ar} (h={h}, w={w})")
    return img


def to_rgb(img: Image.Image | list[Image.Image]):
    if isinstance(img, list):
        return [to_rgb(_img) for _img in img]
    return img.convert("RGB")


def default_images_prep(
    x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
    if isinstance(x, list):
        return [default_images_prep(e) for e in x]  # type: ignore
    x_tensor = torchvision.transforms.ToTensor()(x)
    return 2 * x_tensor - 1


def default_prep(
    img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16
) -> torch.Tensor | list[torch.Tensor]:
    img_rgb = to_rgb(img)
    img_min = cap_min_pixels(img_rgb)  # type: ignore
    if limit_pixels is not None:
        img_cap = cap_pixels(img_min, limit_pixels)  # type: ignore
    else:
        img_cap = img_min
    img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple)  # type: ignore
    img_tensor = default_images_prep(img_crop)
    return img_tensor


def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
    mu = compute_empirical_mu(image_seq_len, num_steps)
    timesteps = torch.linspace(1, 0, num_steps + 1)
    timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
    return timesteps.tolist()


def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    if image_seq_len > 4300:
        mu = a2 * image_seq_len + b2
        return float(mu)

    m_200 = a2 * image_seq_len + b2
    m_10 = a1 * image_seq_len + b1

    a = (m_200 - m_10) / 190.0
    b = m_200 - 200.0 * a
    mu = a * num_steps + b

    return float(mu)


def denoise(
    model: Flux2,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float,
    # extra img tokens (sequence-wise)
    img_cond_seq: Tensor | None = None,
    img_cond_seq_ids: Tensor | None = None,
):
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        img_input = img
        img_input_ids = img_ids
        if img_cond_seq is not None:
            assert (
                img_cond_seq_ids is not None
            ), "You need to provide either both or neither of the sequence conditioning"
            img_input = torch.cat((img_input, img_cond_seq), dim=1)
            img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
        pred = model(
            x=img_input,
            x_ids=img_input_ids,
            timesteps=t_vec,
            ctx=txt,
            ctx_ids=txt_ids,
            guidance=guidance_vec,
        )
        if img_cond_seq is not None:
            # drop the predictions for the appended conditioning tokens
            pred = pred[:, : img.shape[1]]

        img = img + (t_prev - t_curr) * pred

    return img


def concatenate_images(
    images: list[Image.Image],
) -> Image.Image:
    """
    Concatenate a list of PIL images horizontally with center alignment and white background.
    """

    # If only one image, return a copy of it
    if len(images) == 1:
        return images[0].copy()

    # Convert all images to RGB if not already
    images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]

    # Calculate dimensions for horizontal concatenation
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)

    # Create new image with white background
    background_color = (255, 255, 255)
    new_img = Image.new("RGB", (total_width, max_height), background_color)

    # Paste images with center alignment
    x_offset = 0
    for img in images:
        y_offset = (max_height - img.height) // 2
        new_img.paste(img, (x_offset, y_offset))
        x_offset += img.width

    return new_img
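
Example (not part of the commit): the resolution-dependent schedule in get_schedule() can be checked locally without a model. A larger image_seq_len yields a larger mu in compute_empirical_mu(), which keeps the trajectory at higher noise levels for longer; the flux2 import path is an assumption.

from flux2.sampling import get_schedule

ts_small = get_schedule(num_steps=50, image_seq_len=1024)   # e.g. smaller latents
ts_large = get_schedule(num_steps=50, image_seq_len=4096)   # e.g. larger latents

assert len(ts_small) == 51                   # num_steps + 1 boundaries
assert ts_small[0] == 1.0 and ts_small[-1] == 0.0
# at the same step index, the larger image stays at a higher noise level
assert ts_large[25] > ts_small[25]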
src/flux2/system_messages.py (new file, 82 lines)
@@ -0,0 +1,82 @@
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation."""

SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.

Guidelines:
1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.

Output only the revised prompt and nothing else."""

SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).

Rules:
- Single instruction only, no commentary
- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
- Specify what changes AND what stays the same (face, lighting, composition)
- Reference actual image elements
- Turn negatives into positives ("don't change X" → "keep X")
- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
- Keep content PG-13

Output only the final instruction in plain text and nothing else."""


SYSTEM_PROMPT_CONTENT_FILTER = """
You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.
Your task is identifying images and text for copyright concerns and depictions of public personas.
"""

PROMPT_IMAGE_INTEGRITY = """
Task: Analyze an image to identify potential copyright concerns or depictions of public figures.

Output: Respond with only "yes" or "no"

Criteria for "yes":
- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.)
- The image displays a trademarked logo or brand
- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.)

Criteria for "no":
- All other cases
- When you cannot identify the specific copyrighted work or named individual

Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person depicted
2. General references to demographics or characteristics are not sufficient
3. Base your decision solely on visual content, not interpretation
4. Provide only the one-word answer: "yes" or "no"
""".strip()


PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?"

PROMPT_TEXT_INTEGRITY = """
Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures.

Output: Respond with only "yes" or "no"

Criteria for "Yes":
- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.)
- The prompt explicitly mentions a trademarked logo or brand
- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.)

Criteria for "No":
- All other cases
- When you cannot identify the specific copyrighted work or named individual

Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person referenced
2. General demographic descriptions or characteristics are not sufficient
3. Analyze only the prompt text, not potential image outcomes
4. Provide only the one-word answer: "yes" or "no"

The prompt to check is:
-----
{prompt}
-----

Does this prompt have copyright concerns or includes public figures?
""".strip()
src/flux2/text_encoder.py (new file, 356 lines)
@@ -0,0 +1,356 @@
from pathlib import Path

import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, pipeline

from .sampling import cap_pixels, concatenate_images
from .system_messages import (
    PROMPT_IMAGE_INTEGRITY,
    PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
    PROMPT_TEXT_INTEGRITY,
    SYSTEM_MESSAGE,
    SYSTEM_MESSAGE_UPSAMPLING_I2I,
    SYSTEM_MESSAGE_UPSAMPLING_T2I,
    SYSTEM_PROMPT_CONTENT_FILTER,
)

OUTPUT_LAYERS = [10, 20, 30]
MAX_LENGTH = 512
NSFW_THRESHOLD = 0.85
UPSAMPLING_MAX_IMAGE_SIZE = 768**2


class Mistral3SmallEmbedder(nn.Module):
    def __init__(
        self,
        model_spec: str = "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
        model_spec_processor: str = "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
        torch_dtype: str = "bfloat16",
    ):
        super().__init__()

        self.model: Mistral3ForConditionalGeneration = Mistral3ForConditionalGeneration.from_pretrained(
            model_spec,
            torch_dtype=getattr(torch, torch_dtype),
        )
        self.processor = AutoProcessor.from_pretrained(model_spec_processor, use_fast=False)
        self.yes_token, self.no_token = self.processor.tokenizer.encode(
            ["yes", "no"], add_special_tokens=False
        )

        self.max_length = MAX_LENGTH
        self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

        self.nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")

    def _validate_and_process_images(
        self, img: list[list[Image.Image]] | list[Image.Image]
    ) -> list[list[Image.Image]]:
        # Simple validation: ensure it's a list of PIL images or list of lists of PIL images
        if not img:
            return []

        # Check if it's a list of lists or a list of images
        if isinstance(img[0], Image.Image):
            # It's a list of images, convert to list of lists
            img = [[im] for im in img]

        # potentially concatenate multiple images to reduce the size
        img = [[concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in img]

        # cap the pixels
        img = [[cap_pixels(im, self.upsampling_max_image_size) for im in img_i] for img_i in img]
        return img

    def format_input(
        self,
        txt: list[str],
        system_message: str = SYSTEM_MESSAGE,
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
    ) -> list[list[dict]]:
        """
        Format a batch of text prompts into the conversation format expected by apply_chat_template.
        Optionally, add images to the input.

        Args:
            txt: List of text prompts
            system_message: System message to use (default: SYSTEM_MESSAGE)
            img: List of images to add to the input.

        Returns:
            List of conversations, where each conversation is a list of message dicts
        """
        # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
        # when truncation is enabled. The processor counts [IMG] tokens and fails
        # if the count changes after truncation.
        cleaned_txt = [prompt.replace("[IMG]", "") for prompt in txt]

        if img is None or len(img) == 0:
            return [
                [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": system_message}],
                    },
                    {"role": "user", "content": [{"type": "text", "text": prompt}]},
                ]
                for prompt in cleaned_txt
            ]
        else:
            assert len(img) == len(txt), "Number of images must match number of prompts"
            img = self._validate_and_process_images(img)

            messages = [
                [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": system_message}],
                    },
                ]
                for _ in cleaned_txt
            ]

            for i, (el, images) in enumerate(zip(messages, img)):
                # optionally add the images per batch element.
                if images is not None:
                    el.append(
                        {
                            "role": "user",
                            "content": [{"type": "image", "image": image_obj} for image_obj in images],
                        }
                    )
                # add the text.
                el.append(
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": cleaned_txt[i]}],
                    }
                )

            return messages

    @torch.no_grad()
    def upsample_prompt(
        self,
        txt: list[str],
        img: list[Image.Image] | list[list[Image.Image]] | None = None,
        temperature: float = 0.15,
    ) -> list[str]:
        """
        Upsample prompts using the model's generate method.

        Args:
            txt: List of input prompts to upsample
            img: Optional list of images or list of lists of images. If None or all None, uses t2i mode, otherwise i2i mode.

        Returns:
            List of upsampled prompts
        """
        # Set system message based on whether images are provided
        if img is None or len(img) == 0 or img[0] is None:
            system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
        else:
            system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I

        # Format input messages
        messages_batch = self.format_input(txt=txt, system_message=system_message, img=img)

        # Process all messages at once
        # with image processing a too short max length can throw an error in here.
        try:
            inputs = self.processor.apply_chat_template(
                messages_batch,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=2048,
            )
        except ValueError as e:
            print(
                f"Error processing input: {e}, your max length is probably too short, when you have images in the input."
            )
            raise e

        # Move to device
        inputs["input_ids"] = inputs["input_ids"].to(self.model.device)
        inputs["attention_mask"] = inputs["attention_mask"].to(self.model.device)

        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(self.model.device, self.model.dtype)

        # Generate text using the model's generate method
        try:
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=temperature,
                use_cache=True,
            )

            # Decode only the newly generated tokens (skip input tokens)
            # Extract only the generated portion
            input_length = inputs["input_ids"].shape[1]
            generated_tokens = generated_ids[:, input_length:]

            raw_txt = self.processor.tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            return raw_txt
        except Exception as e:
            print(f"Error generating upsampled prompt: {e}, returning original prompt")
            return txt

    @torch.no_grad()
    def forward(self, txt: list[str]):
        # Format input messages
        messages_batch = self.format_input(txt=txt)

        # Process all messages at once
        # with image processing a too short max length can throw an error in here.
        inputs = self.processor.apply_chat_template(
            messages_batch,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        # Move to device
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        # Forward pass through the model
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        out = torch.stack([output.hidden_states[k] for k in OUTPUT_LAYERS], dim=1)
        return rearrange(out, "b c l d -> b l (c d)")

    def yes_no_logit_processor(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Sets all tokens but yes/no to the minimum.
        """
        scores_yes_token = scores[:, self.yes_token].clone()
        scores_no_token = scores[:, self.no_token].clone()
        scores_min = scores.min()
        scores[:, :] = scores_min - 1
        scores[:, self.yes_token] = scores_yes_token
        scores[:, self.no_token] = scores_no_token
        return scores

    def test_image(self, image: Image.Image | str | Path | torch.Tensor) -> bool:
        if isinstance(image, torch.Tensor):
            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
        elif isinstance(image, (str, Path)):
            image = Image.open(image)

        classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
        if classification["score"] > NSFW_THRESHOLD:
            return True

        # 512^2 pixels are enough for checking
        w, h = image.size
        f = (512**2 / (w * h)) ** 0.5
        image = image.resize((int(f * w), int(f * h)))

        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT_IMAGE_INTEGRITY,
                    },
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
                    },
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)
        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=self.model.dtype)

        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],
            do_sample=False,
        )

        return generate_ids[0, -1].item() == self.yes_token

    def test_txt(self, txt: str) -> bool:
        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYSTEM_PROMPT_CONTENT_FILTER,
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
                    },
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],
            do_sample=False,
        )
        return generate_ids[0, -1].item() == self.yes_token
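
Example (not part of the commit): the constrained yes/no decoding used by test_image/test_txt, replicated standalone with hypothetical token ids (7 and 11) and a toy vocabulary instead of the Mistral tokenizer. It shows how the logit processor forces generation onto exactly two tokens.

import torch

yes_token, no_token = 7, 11   # hypothetical ids; the real ones come from the tokenizer

def yes_no_logit_processor(input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    # keep the yes/no scores, push every other token below the current minimum
    scores_yes = scores[:, yes_token].clone()
    scores_no = scores[:, no_token].clone()
    scores_min = scores.min()
    scores[:, :] = scores_min - 1
    scores[:, yes_token] = scores_yes
    scores[:, no_token] = scores_no
    return scores

scores = torch.randn(1, 32)   # fake vocabulary of 32 tokens
out = yes_no_logit_processor(torch.zeros(1, 1, dtype=torch.long), scores)
assert out.argmax(dim=-1).item() in (yes_token, no_token)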
src/flux2/util.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import base64
import io
import os
import sys

import huggingface_hub
import torch
from PIL import Image
from safetensors.torch import load_file as load_sft

from .autoencoder import AutoEncoder, AutoEncoderParams
from .model import Flux2, Flux2Params
from .text_encoder import Mistral3SmallEmbedder

FLUX2_MODEL_INFO = {
    "flux.2-dev": {
        "repo_id": "black-forest-labs/FLUX.2-dev",
        "filename": "flux2-dev.safetensors",
        "filename_ae": "ae.safetensors",
        "params": Flux2Params(),
    }
}


def load_flow_model(model_name: str, debug_mode: bool = False, device: str | torch.device = "cuda") -> Flux2:
    config = FLUX2_MODEL_INFO[model_name.lower()]

    if debug_mode:
        config["params"].depth = 1
        config["params"].depth_single_blocks = 1
    else:
        if "FLUX2_MODEL_PATH" in os.environ:
            weight_path = os.environ["FLUX2_MODEL_PATH"]
            assert os.path.exists(weight_path), f"Provided weight path {weight_path} does not exist"
        else:
            # download from huggingface
            try:
                weight_path = huggingface_hub.hf_hub_download(
                    repo_id=config["repo_id"],
                    filename=config["filename"],
                    repo_type="model",
                )
            except huggingface_hub.errors.RepositoryNotFoundError:
                print(
                    "Failed to access the model repository. Please check your internet "
                    f"connection and make sure you have access to {config['repo_id']}. Stopping."
                )
                sys.exit(1)

    if not debug_mode:
        with torch.device("meta"):
            model = Flux2(FLUX2_MODEL_INFO[model_name.lower()]["params"]).to(torch.bfloat16)
        print(f"Loading {weight_path} for the FLUX.2 weights")
        sd = load_sft(weight_path, device=str(device))
        model.load_state_dict(sd, strict=False, assign=True)
        return model.to(device)
    else:
        with torch.device(device):
            return Flux2(FLUX2_MODEL_INFO[model_name.lower()]["params"]).to(torch.bfloat16)


def load_mistral_small_embedder(device: str | torch.device = "cuda") -> Mistral3SmallEmbedder:
    return Mistral3SmallEmbedder().to(device)


def load_ae(model_name: str, device: str | torch.device = "cuda") -> AutoEncoder:
    config = FLUX2_MODEL_INFO[model_name.lower()]

    if "AE_MODEL_PATH" in os.environ:
        weight_path = os.environ["AE_MODEL_PATH"]
        assert os.path.exists(weight_path), f"Provided weight path {weight_path} does not exist"
    else:
        # download from huggingface
        try:
            weight_path = huggingface_hub.hf_hub_download(
                repo_id=config["repo_id"],
                filename=config["filename_ae"],
                repo_type="model",
            )
        except huggingface_hub.errors.RepositoryNotFoundError:
            print(
                "Failed to access the model repository. Please check your internet "
                f"connection and make sure you have access to {config['repo_id']}. Stopping."
            )
            sys.exit(1)

    if isinstance(device, str):
        device = torch.device(device)
    with torch.device("meta"):
        ae = AutoEncoder(AutoEncoderParams())

    print(f"Loading {weight_path} for the AutoEncoder weights")
    sd = load_sft(weight_path, device=str(device))
    ae.load_state_dict(sd, strict=True, assign=True)
    return ae.to(device)


def image_to_base64(image: Image.Image) -> str:
    """Convert PIL Image to base64 string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str
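
Example (not part of the commit): a quick round trip for image_to_base64 using only PIL, no model downloads; the flux2 import path is an assumption.

import base64
import io

from PIL import Image

from flux2.util import image_to_base64

img = Image.new("RGB", (64, 64), (255, 0, 0))
b64 = image_to_base64(img)

# decode the base64 PNG back into a PIL image and check it survived intact
decoded = Image.open(io.BytesIO(base64.b64decode(b64)))
assert decoded.size == (64, 64) and decoded.mode == "RGB"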
src/flux2/watermark.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import torch
from einops import rearrange
from imwatermark import WatermarkEncoder


class WatermarkEmbedder:
    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        image = 0.5 * image + 0.5
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        # torch (n, b, c, h, w) in [0, 1] -> numpy (n*b, h, w, c) in [0, 255]
        # the watermarking library expects input in cv2 BGR format
        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
            image.device
        )
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        image = 2 * image - 1
        return image


# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
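
Example (not part of the commit): a hedged usage sketch for the module-level embedder above. It assumes the invisible-watermark package (imwatermark) is installed; inputs and outputs are (B, C, H, W) tensors in [-1, 1].

import torch

from flux2.watermark import embed_watermark

img = torch.rand(2, 3, 256, 256) * 2 - 1   # fake batch of images in [-1, 1]
marked = embed_watermark(img)
assert marked.shape == img.shape
assert marked.min() >= -1.0 and marked.max() <= 1.0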