trellis/modules/attention/modules.py (new file, 146 lines)
@@ -0,0 +1,146 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from .full_attn import scaled_dot_product_attention


class MultiHeadRMSNorm(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (F.normalize(x.float(), dim=-1) * self.gamma * self.scale).to(x.dtype)


class RotaryPositionEmbedder(nn.Module):
    def __init__(self, hidden_size: int, in_channels: int = 3):
        super().__init__()
        assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
        self.hidden_size = hidden_size
        self.in_channels = in_channels
        self.freq_dim = hidden_size // in_channels // 2
        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** self.freqs)

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        self.freqs = self.freqs.to(indices.device)
        phases = torch.outer(indices, self.freqs)
        phases = torch.polar(torch.ones_like(phases), phases)
        return phases

    def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        x_rotated = x_complex * phases
        x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
        return x_embed

    def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            q (torch.Tensor): [..., N, D] tensor of queries
            k (torch.Tensor): [..., N, D] tensor of keys
            indices (torch.Tensor): [..., N, C] tensor of spatial positions
        """
        if indices is None:
            indices = torch.arange(q.shape[-2], device=q.device)
            if len(q.shape) > 2:
                indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))

        phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
        if phases.shape[1] < self.hidden_size // 2:
            phases = torch.cat([phases, torch.polar(
                torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
                torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
            )], dim=-1)
        q_embed = self._rotary_embedding(q, phases)
        k_embed = self._rotary_embedding(k, phases)
        return q_embed, k_embed


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"

        if attn_mode == "windowed":
            raise NotImplementedError("Windowed attention is not yet implemented")

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = RotaryPositionEmbedder(channels)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, L, C = x.shape
        if self._type == "self":
            qkv = self.to_qkv(x)
            qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
            if self.use_rope:
                q, k, v = qkv.unbind(dim=2)
                q, k = self.rope(q, k, indices)
                qkv = torch.stack([q, k, v], dim=2)
            if self.attn_mode == "full":
                if self.qk_rms_norm:
                    q, k, v = qkv.unbind(dim=2)
                    q = self.q_rms_norm(q)
                    k = self.k_rms_norm(k)
                    h = scaled_dot_product_attention(q, k, v)
                else:
                    h = scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                raise NotImplementedError("Windowed attention is not yet implemented")
        else:
            Lkv = context.shape[1]
            q = self.to_q(x)
            kv = self.to_kv(context)
            q = q.reshape(B, L, self.num_heads, -1)
            kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=2)
                k = self.k_rms_norm(k)
                h = scaled_dot_product_attention(q, k, v)
            else:
                h = scaled_dot_product_attention(q, kv)
        h = h.reshape(B, L, -1)
        h = self.to_out(h)
        return h
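For reference, a minimal usage sketch of the dense attention module added above. This is not part of the commit; the shapes and hyperparameters are illustrative, and only APIs visible in this diff (plus torch) are used.

    import torch
    from trellis.modules.attention.modules import MultiHeadAttention

    # Self-attention: fused QKV projection, full attention over the sequence.
    attn = MultiHeadAttention(channels=512, num_heads=8, type="self", attn_mode="full")
    x = torch.randn(2, 1024, 512)          # [B, L, C]
    out = attn(x)                          # [B, L, C]

    # Cross-attention: queries from x, keys/values from a context of different width.
    cross = MultiHeadAttention(channels=512, num_heads=8, ctx_channels=768, type="cross")
    ctx = torch.randn(2, 77, 768)          # [B, Lkv, ctx_channels]
    out = cross(x, context=ctx)            # output keeps the query length: [B, L, C]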
trellis/modules/sparse/attention/modules.py (new file, 139 lines)
@@ -0,0 +1,139 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from .. import SparseTensor
from .full_attn import sparse_scaled_dot_product_attention
from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
from ...attention import RotaryPositionEmbedder


class SparseMultiHeadRMSNorm(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
        x_type = x.dtype
        x = x.float()
        if isinstance(x, SparseTensor):
            x = x.replace(F.normalize(x.feats, dim=-1))
        else:
            x = F.normalize(x, dim=-1)
        return (x * self.gamma * self.scale).to(x_type)


class SparseMultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "serialized", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        assert channels % num_heads == 0
        assert type in ["self", "cross"], f"Invalid attention type: {type}"
        assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
        assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
        assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
        self.channels = channels
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_sequence = shift_sequence
        self.shift_window = shift_window
        self.serialize_mode = serialize_mode
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
            self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = RotaryPositionEmbedder(channels)

    @staticmethod
    def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
        if isinstance(x, SparseTensor):
            return x.replace(module(x.feats))
        else:
            return module(x)

    @staticmethod
    def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
        if isinstance(x, SparseTensor):
            return x.reshape(*shape)
        else:
            return x.reshape(*x.shape[:2], *shape)

    def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
        if isinstance(x, SparseTensor):
            x_feats = x.feats.unsqueeze(0)
        else:
            x_feats = x
        x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
        return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats

    def _rope(self, qkv: SparseTensor) -> SparseTensor:
        q, k, v = qkv.feats.unbind(dim=1)   # [T, H, C]
        q, k = self.rope(q, k, qkv.coords[:, 1:])
        qkv = qkv.replace(torch.stack([q, k, v], dim=1))
        return qkv

    def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
        if self._type == "self":
            qkv = self._linear(self.to_qkv, x)
            qkv = self._fused_pre(qkv, num_fused=3)
            if self.use_rope:
                qkv = self._rope(qkv)
            if self.qk_rms_norm:
                q, k, v = qkv.unbind(dim=1)
                q = self.q_rms_norm(q)
                k = self.k_rms_norm(k)
                qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
            if self.attn_mode == "full":
                h = sparse_scaled_dot_product_attention(qkv)
            elif self.attn_mode == "serialized":
                h = sparse_serialized_scaled_dot_product_self_attention(
                    qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
                )
            elif self.attn_mode == "windowed":
                h = sparse_windowed_scaled_dot_product_self_attention(
                    qkv, self.window_size, shift_window=self.shift_window
                )
        else:
            q = self._linear(self.to_q, x)
            q = self._reshape_chs(q, (self.num_heads, -1))
            kv = self._linear(self.to_kv, context)
            kv = self._fused_pre(kv, num_fused=2)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=1)
                k = self.k_rms_norm(k)
                kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
            h = sparse_scaled_dot_product_attention(q, kv)
        h = self._reshape_chs(h, (-1,))
        h = self._linear(self.to_out, h)
        return h
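A hypothetical sketch of the sparse self-attention path, not part of the commit. It assumes SparseTensor can be built as SparseTensor(feats=..., coords=...) with integer coords laid out as [batch_idx, x, y, z]; that constructor signature comes from outside this diff, so treat it as an assumption.

    import torch
    from trellis.modules.sparse import SparseTensor
    from trellis.modules.sparse.attention.modules import SparseMultiHeadAttention

    # Assumed SparseTensor constructor: per-point features plus [batch, x, y, z] coords.
    coords = torch.randint(0, 64, (4096, 3))
    coords = torch.cat([torch.zeros(4096, 1, dtype=torch.long), coords], dim=1)  # one batch element
    x = SparseTensor(feats=torch.randn(4096, 512), coords=coords.int())

    attn = SparseMultiHeadAttention(channels=512, num_heads=8, attn_mode="full")
    out = attn(x)   # SparseTensor with the same coords and 512-channel features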
trellis/modules/sparse/linear.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import torch
import torch.nn as nn
from . import SparseTensor

__all__ = [
    'SparseLinear'
]


class SparseLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super(SparseLinear, self).__init__(in_features, out_features, bias)

    def forward(self, input: SparseTensor) -> SparseTensor:
        return input.replace(super().forward(input.feats))
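SparseLinear applies the dense projection to the feature matrix only and re-wraps it via replace(), so the coordinate structure is untouched. A tiny sketch (again assuming the SparseTensor(feats=..., coords=...) constructor, which is outside this diff):

    import torch
    from trellis.modules.sparse import SparseTensor
    from trellis.modules.sparse.linear import SparseLinear

    x = SparseTensor(feats=torch.randn(4096, 512),
                     coords=torch.zeros(4096, 4, dtype=torch.int32))  # dummy coords, assumed layout
    proj = SparseLinear(512, 1024)
    y = proj(x)   # y.feats: [4096, 1024]; y.coords identical to x.coords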
trellis/modules/sparse/transformer/modulated.py (new file, 166 lines)
@@ -0,0 +1,166 @@
from typing import *
import torch
import torch.nn as nn
from ..basic import SparseTensor
from ..attention import SparseMultiHeadAttention, SerializeMode
from ...norm import LayerNorm32
from .blocks import SparseFeedForwardNet


class ModulatedSparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.attn(h)
        h = h * gate_msa
        x = x + h
        h = x.replace(self.norm2(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
        else:
            return self._forward(x, mod)


class ModulatedSparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.self_attn(h)
        h = h * gate_msa
        x = x + h
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        h = x.replace(self.norm3(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)
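A hypothetical wiring of the sparse cross block (not part of the commit): mod is a per-sample conditioning vector broadcast through the adaLN branch, and context is a dense token sequence attended to via cross-attention. The SparseTensor constructor and the per-batch broadcasting of mod against sparse features are assumptions from outside this diff; shapes are illustrative.

    import torch
    from trellis.modules.sparse import SparseTensor
    from trellis.modules.sparse.transformer.modulated import ModulatedSparseTransformerCrossBlock

    block = ModulatedSparseTransformerCrossBlock(
        channels=512, ctx_channels=768, num_heads=8, attn_mode="full", qk_rms_norm=True
    )
    x = SparseTensor(feats=torch.randn(4096, 512),
                     coords=torch.zeros(4096, 4, dtype=torch.int32))  # assumed constructor
    mod = torch.randn(1, 512)         # [B, channels] conditioning vector
    context = torch.randn(1, 77, 768) # [B, Lkv, ctx_channels] dense tokens
    y = block(x, mod, context)        # SparseTensor with the same layout as x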
trellis/modules/transformer/modulated.py (new file, 157 lines)
@@ -0,0 +1,157 @@
from typing import *
import torch
import torch.nn as nn
from ..attention import MultiHeadAttention
from ..norm import LayerNorm32
from .blocks import FeedForwardNet


class ModulatedTransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN) with adaptive layer norm conditioning.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.attn(h)
        h = h * gate_msa.unsqueeze(1)
        x = x + h
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        h = h * gate_mlp.unsqueeze(1)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
        else:
            return self._forward(x, mod)


class ModulatedTransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        h = self.norm1(x)
        h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        h = self.self_attn(h)
        h = h * gate_msa.unsqueeze(1)
        x = x + h
        h = self.norm2(x)
        h = self.cross_attn(h, context)
        x = x + h
        h = self.norm3(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        h = h * gate_mlp.unsqueeze(1)
        x = x + h
        return x

    def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)
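Finally, a sketch of how the dense cross block might be driven (not part of the commit; shapes are illustrative). The 6*channels adaLN output is split into (shift, scale, gate) triples for the attention branch and the MLP branch, each applied as x + gate * f((1 + scale) * norm(x) + shift).

    import torch
    from trellis.modules.transformer.modulated import ModulatedTransformerCrossBlock

    block = ModulatedTransformerCrossBlock(channels=512, ctx_channels=768, num_heads=8)
    x = torch.randn(2, 1024, 512)      # [B, L, C] tokens
    mod = torch.randn(2, 512)          # [B, C] conditioning (e.g. a timestep embedding)
    context = torch.randn(2, 77, 768)  # [B, Lkv, ctx_channels]
    y = block(x, mod, context)         # [B, L, C]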