FLUX.2 [klein]

This commit is contained in:
timudk
2026-01-15 15:12:38 +01:00
parent ab7cca6801
commit b56ac61450
12 changed files with 530 additions and 119 deletions

View File

@@ -17,6 +17,35 @@ class Flux2Params:
axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
theta: int = 2000
mlp_ratio: float = 3.0
use_guidance_embed: bool = True
@dataclass
class Klein9BParams:
    """Default hyperparameters for the Klein 9B model configuration.

    Every field carries a default, so ``Klein9BParams()`` yields the complete
    configuration. The field set matches ``Klein4BParams``; only the sizes
    differ.
    """

    # NOTE(review): presumably the input width of the image projection
    # (Flux2.img_in) -- confirm against Flux2.__init__.
    in_channels: int = 128
    # Flux2 uses this as the input width of its text projection (txt_in).
    context_in_dim: int = 12288
    hidden_size: int = 4096
    num_heads: int = 32
    # NOTE(review): presumably the number of double-stream blocks -- confirm.
    depth: int = 8
    # NOTE(review): presumably the number of single-stream blocks -- confirm.
    depth_single_blocks: int = 24
    # default_factory gives each instance its own list (no shared mutable
    # default). Flux2 forwards axes_dim and theta to EmbedND.
    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
    theta: int = 2000
    mlp_ratio: float = 3.0
    # When False, Flux2 neither builds guidance_in nor adds a guidance
    # embedding to the conditioning vector.
    use_guidance_embed: bool = False
@dataclass
class Klein4BParams:
    """Default hyperparameters for the Klein 4B model configuration.

    Every field carries a default, so ``Klein4BParams()`` yields the complete
    configuration. The field set matches ``Klein9BParams``; only the sizes
    differ.
    """

    # NOTE(review): presumably the input width of the image projection
    # (Flux2.img_in) -- confirm against Flux2.__init__.
    in_channels: int = 128
    # Flux2 uses this as the input width of its text projection (txt_in).
    context_in_dim: int = 7680
    hidden_size: int = 3072
    num_heads: int = 24
    # NOTE(review): presumably the number of double-stream blocks -- confirm.
    depth: int = 5
    # NOTE(review): presumably the number of single-stream blocks -- confirm.
    depth_single_blocks: int = 20
    # default_factory gives each instance its own list (no shared mutable
    # default). Flux2 forwards axes_dim and theta to EmbedND.
    axes_dim: list[int] = field(default_factory=lambda: [32, 32, 32, 32])
    theta: int = 2000
    mlp_ratio: float = 3.0
    # When False, Flux2 neither builds guidance_in nor adds a guidance
    # embedding to the conditioning vector.
    use_guidance_embed: bool = False
class Flux2(nn.Module):
@@ -37,9 +66,12 @@ class Flux2(nn.Module):
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=False)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, bias=False)
self.use_guidance_embed = params.use_guidance_embed
if self.use_guidance_embed:
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, disable_bias=True)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@@ -86,14 +118,15 @@ class Flux2(nn.Module):
timesteps: Tensor,
ctx: Tensor,
ctx_ids: Tensor,
guidance: Tensor,
guidance: Tensor | None,
):
num_txt_tokens = ctx.shape[1]
timestep_emb = timestep_embedding(timesteps, 256)
vec = self.time_in(timestep_emb)
guidance_emb = timestep_embedding(guidance, 256)
vec = vec + self.guidance_in(guidance_emb)
if self.use_guidance_embed:
guidance_emb = timestep_embedding(guidance, 256)
vec = vec + self.guidance_in(guidance_emb)
double_block_mod_img = self.double_stream_modulation_img(vec)
double_block_mod_txt = self.double_stream_modulation_txt(vec)