1
This commit is contained in:
146
trellis/modules/attention/modules.py
Normal file
146
trellis/modules/attention/modules.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .full_attn import scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadRMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int, heads: int):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim ** 0.5
|
||||||
|
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryPositionEmbedder(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, in_channels: int = 3):
|
||||||
|
super().__init__()
|
||||||
|
assert hidden_size % 2 == 0, "Hidden size must be divisible by 2"
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.freq_dim = hidden_size // in_channels // 2
|
||||||
|
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
||||||
|
self.freqs = 1.0 / (10000 ** self.freqs)
|
||||||
|
|
||||||
|
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
||||||
|
self.freqs = self.freqs.to(indices.device)
|
||||||
|
phases = torch.outer(indices, self.freqs)
|
||||||
|
phases = torch.polar(torch.ones_like(phases), phases)
|
||||||
|
return phases
|
||||||
|
|
||||||
|
def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
||||||
|
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||||
|
x_rotated = x_complex * phases
|
||||||
|
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
||||||
|
return x_embed
|
||||||
|
|
||||||
|
def forward(self, q: torch.Tensor, k: torch.Tensor, indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
q (sp.SparseTensor): [..., N, D] tensor of queries
|
||||||
|
k (sp.SparseTensor): [..., N, D] tensor of keys
|
||||||
|
indices (torch.Tensor): [..., N, C] tensor of spatial positions
|
||||||
|
"""
|
||||||
|
if indices is None:
|
||||||
|
indices = torch.arange(q.shape[-2], device=q.device)
|
||||||
|
if len(q.shape) > 2:
|
||||||
|
indices = indices.unsqueeze(0).expand(q.shape[:-2] + (-1,))
|
||||||
|
|
||||||
|
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||||
|
if phases.shape[1] < self.hidden_size // 2:
|
||||||
|
phases = torch.cat([phases, torch.polar(
|
||||||
|
torch.ones(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device),
|
||||||
|
torch.zeros(*phases.shape[:-1], self.hidden_size // 2 - phases.shape[1], device=phases.device)
|
||||||
|
)], dim=-1)
|
||||||
|
q_embed = self._rotary_embedding(q, phases)
|
||||||
|
k_embed = self._rotary_embedding(k, phases)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
ctx_channels: Optional[int]=None,
|
||||||
|
type: Literal["self", "cross"] = "self",
|
||||||
|
attn_mode: Literal["full", "windowed"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert channels % num_heads == 0
|
||||||
|
assert type in ["self", "cross"], f"Invalid attention type: {type}"
|
||||||
|
assert attn_mode in ["full", "windowed"], f"Invalid attention mode: {attn_mode}"
|
||||||
|
assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
|
||||||
|
|
||||||
|
if attn_mode == "windowed":
|
||||||
|
raise NotImplementedError("Windowed attention is not yet implemented")
|
||||||
|
|
||||||
|
self.channels = channels
|
||||||
|
self.head_dim = channels // num_heads
|
||||||
|
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self._type = type
|
||||||
|
self.attn_mode = attn_mode
|
||||||
|
self.window_size = window_size
|
||||||
|
self.shift_window = shift_window
|
||||||
|
self.use_rope = use_rope
|
||||||
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
|
||||||
|
if self._type == "self":
|
||||||
|
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
||||||
|
else:
|
||||||
|
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
||||||
|
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
||||||
|
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
||||||
|
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
|
||||||
|
|
||||||
|
self.to_out = nn.Linear(channels, channels)
|
||||||
|
|
||||||
|
if use_rope:
|
||||||
|
self.rope = RotaryPositionEmbedder(channels)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
B, L, C = x.shape
|
||||||
|
if self._type == "self":
|
||||||
|
qkv = self.to_qkv(x)
|
||||||
|
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
||||||
|
if self.use_rope:
|
||||||
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
q, k = self.rope(q, k, indices)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
if self.attn_mode == "full":
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
q = self.q_rms_norm(q)
|
||||||
|
k = self.k_rms_norm(k)
|
||||||
|
h = scaled_dot_product_attention(q, k, v)
|
||||||
|
else:
|
||||||
|
h = scaled_dot_product_attention(qkv)
|
||||||
|
elif self.attn_mode == "windowed":
|
||||||
|
raise NotImplementedError("Windowed attention is not yet implemented")
|
||||||
|
else:
|
||||||
|
Lkv = context.shape[1]
|
||||||
|
q = self.to_q(x)
|
||||||
|
kv = self.to_kv(context)
|
||||||
|
q = q.reshape(B, L, self.num_heads, -1)
|
||||||
|
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
q = self.q_rms_norm(q)
|
||||||
|
k, v = kv.unbind(dim=2)
|
||||||
|
k = self.k_rms_norm(k)
|
||||||
|
h = scaled_dot_product_attention(q, k, v)
|
||||||
|
else:
|
||||||
|
h = scaled_dot_product_attention(q, kv)
|
||||||
|
h = h.reshape(B, L, -1)
|
||||||
|
h = self.to_out(h)
|
||||||
|
return h
|
||||||
139
trellis/modules/sparse/attention/modules.py
Normal file
139
trellis/modules/sparse/attention/modules.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .. import SparseTensor
|
||||||
|
from .full_attn import sparse_scaled_dot_product_attention
|
||||||
|
from .serialized_attn import SerializeMode, sparse_serialized_scaled_dot_product_self_attention
|
||||||
|
from .windowed_attn import sparse_windowed_scaled_dot_product_self_attention
|
||||||
|
from ...attention import RotaryPositionEmbedder
|
||||||
|
|
||||||
|
|
||||||
|
class SparseMultiHeadRMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int, heads: int):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim ** 0.5
|
||||||
|
self.gamma = nn.Parameter(torch.ones(heads, dim))
|
||||||
|
|
||||||
|
def forward(self, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
|
||||||
|
x_type = x.dtype
|
||||||
|
x = x.float()
|
||||||
|
if isinstance(x, SparseTensor):
|
||||||
|
x = x.replace(F.normalize(x.feats, dim=-1))
|
||||||
|
else:
|
||||||
|
x = F.normalize(x, dim=-1)
|
||||||
|
return (x * self.gamma * self.scale).to(x_type)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseMultiHeadAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
ctx_channels: Optional[int] = None,
|
||||||
|
type: Literal["self", "cross"] = "self",
|
||||||
|
attn_mode: Literal["full", "serialized", "windowed"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_sequence: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
serialize_mode: Optional[SerializeMode] = None,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert channels % num_heads == 0
|
||||||
|
assert type in ["self", "cross"], f"Invalid attention type: {type}"
|
||||||
|
assert attn_mode in ["full", "serialized", "windowed"], f"Invalid attention mode: {attn_mode}"
|
||||||
|
assert type == "self" or attn_mode == "full", "Cross-attention only supports full attention"
|
||||||
|
assert type == "self" or use_rope is False, "Rotary position embeddings only supported for self-attention"
|
||||||
|
self.channels = channels
|
||||||
|
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self._type = type
|
||||||
|
self.attn_mode = attn_mode
|
||||||
|
self.window_size = window_size
|
||||||
|
self.shift_sequence = shift_sequence
|
||||||
|
self.shift_window = shift_window
|
||||||
|
self.serialize_mode = serialize_mode
|
||||||
|
self.use_rope = use_rope
|
||||||
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
|
||||||
|
if self._type == "self":
|
||||||
|
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
|
||||||
|
else:
|
||||||
|
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
|
||||||
|
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
|
||||||
|
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
self.q_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
|
||||||
|
self.k_rms_norm = SparseMultiHeadRMSNorm(channels // num_heads, num_heads)
|
||||||
|
|
||||||
|
self.to_out = nn.Linear(channels, channels)
|
||||||
|
|
||||||
|
if use_rope:
|
||||||
|
self.rope = RotaryPositionEmbedder(channels)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _linear(module: nn.Linear, x: Union[SparseTensor, torch.Tensor]) -> Union[SparseTensor, torch.Tensor]:
|
||||||
|
if isinstance(x, SparseTensor):
|
||||||
|
return x.replace(module(x.feats))
|
||||||
|
else:
|
||||||
|
return module(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reshape_chs(x: Union[SparseTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[SparseTensor, torch.Tensor]:
|
||||||
|
if isinstance(x, SparseTensor):
|
||||||
|
return x.reshape(*shape)
|
||||||
|
else:
|
||||||
|
return x.reshape(*x.shape[:2], *shape)
|
||||||
|
|
||||||
|
def _fused_pre(self, x: Union[SparseTensor, torch.Tensor], num_fused: int) -> Union[SparseTensor, torch.Tensor]:
|
||||||
|
if isinstance(x, SparseTensor):
|
||||||
|
x_feats = x.feats.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
x_feats = x
|
||||||
|
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
|
||||||
|
return x.replace(x_feats.squeeze(0)) if isinstance(x, SparseTensor) else x_feats
|
||||||
|
|
||||||
|
def _rope(self, qkv: SparseTensor) -> SparseTensor:
|
||||||
|
q, k, v = qkv.feats.unbind(dim=1) # [T, H, C]
|
||||||
|
q, k = self.rope(q, k, qkv.coords[:, 1:])
|
||||||
|
qkv = qkv.replace(torch.stack([q, k, v], dim=1))
|
||||||
|
return qkv
|
||||||
|
|
||||||
|
def forward(self, x: Union[SparseTensor, torch.Tensor], context: Optional[Union[SparseTensor, torch.Tensor]] = None) -> Union[SparseTensor, torch.Tensor]:
|
||||||
|
if self._type == "self":
|
||||||
|
qkv = self._linear(self.to_qkv, x)
|
||||||
|
qkv = self._fused_pre(qkv, num_fused=3)
|
||||||
|
if self.use_rope:
|
||||||
|
qkv = self._rope(qkv)
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
q, k, v = qkv.unbind(dim=1)
|
||||||
|
q = self.q_rms_norm(q)
|
||||||
|
k = self.k_rms_norm(k)
|
||||||
|
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
|
||||||
|
if self.attn_mode == "full":
|
||||||
|
h = sparse_scaled_dot_product_attention(qkv)
|
||||||
|
elif self.attn_mode == "serialized":
|
||||||
|
h = sparse_serialized_scaled_dot_product_self_attention(
|
||||||
|
qkv, self.window_size, serialize_mode=self.serialize_mode, shift_sequence=self.shift_sequence, shift_window=self.shift_window
|
||||||
|
)
|
||||||
|
elif self.attn_mode == "windowed":
|
||||||
|
h = sparse_windowed_scaled_dot_product_self_attention(
|
||||||
|
qkv, self.window_size, shift_window=self.shift_window
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q = self._linear(self.to_q, x)
|
||||||
|
q = self._reshape_chs(q, (self.num_heads, -1))
|
||||||
|
kv = self._linear(self.to_kv, context)
|
||||||
|
kv = self._fused_pre(kv, num_fused=2)
|
||||||
|
if self.qk_rms_norm:
|
||||||
|
q = self.q_rms_norm(q)
|
||||||
|
k, v = kv.unbind(dim=1)
|
||||||
|
k = self.k_rms_norm(k)
|
||||||
|
kv = kv.replace(torch.stack([k.feats, v.feats], dim=1))
|
||||||
|
h = sparse_scaled_dot_product_attention(q, kv)
|
||||||
|
h = self._reshape_chs(h, (-1,))
|
||||||
|
h = self._linear(self.to_out, h)
|
||||||
|
return h
|
||||||
15
trellis/modules/sparse/linear.py
Normal file
15
trellis/modules/sparse/linear.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from . import SparseTensor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'SparseLinear'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SparseLinear(nn.Linear):
|
||||||
|
def __init__(self, in_features, out_features, bias=True):
|
||||||
|
super(SparseLinear, self).__init__(in_features, out_features, bias)
|
||||||
|
|
||||||
|
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||||
|
return input.replace(super().forward(input.feats))
|
||||||
166
trellis/modules/sparse/transformer/modulated.py
Normal file
166
trellis/modules/sparse/transformer/modulated.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from ..basic import SparseTensor
|
||||||
|
from ..attention import SparseMultiHeadAttention, SerializeMode
|
||||||
|
from ...norm import LayerNorm32
|
||||||
|
from .blocks import SparseFeedForwardNet
|
||||||
|
|
||||||
|
|
||||||
|
class ModulatedSparseTransformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_sequence: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
serialize_mode: Optional[SerializeMode] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
share_mod: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.share_mod = share_mod
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.attn = SparseMultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_sequence=shift_sequence,
|
||||||
|
shift_window=shift_window,
|
||||||
|
serialize_mode=serialize_mode,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.mlp = SparseFeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
if not share_mod:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(channels, 6 * channels, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
|
||||||
|
if self.share_mod:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||||
|
else:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||||
|
h = x.replace(self.norm1(x.feats))
|
||||||
|
h = h * (1 + scale_msa) + shift_msa
|
||||||
|
h = self.attn(h)
|
||||||
|
h = h * gate_msa
|
||||||
|
x = x + h
|
||||||
|
h = x.replace(self.norm2(x.feats))
|
||||||
|
h = h * (1 + scale_mlp) + shift_mlp
|
||||||
|
h = self.mlp(h)
|
||||||
|
h = h * gate_mlp
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x, mod)
|
||||||
|
|
||||||
|
|
||||||
|
class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
ctx_channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_sequence: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
serialize_mode: Optional[SerializeMode] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qk_rms_norm_cross: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
share_mod: bool = False,
|
||||||
|
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.share_mod = share_mod
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.self_attn = SparseMultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="self",
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_sequence=shift_sequence,
|
||||||
|
shift_window=shift_window,
|
||||||
|
serialize_mode=serialize_mode,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.cross_attn = SparseMultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
ctx_channels=ctx_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="cross",
|
||||||
|
attn_mode="full",
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
|
)
|
||||||
|
self.mlp = SparseFeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
if not share_mod:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(channels, 6 * channels, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
|
||||||
|
if self.share_mod:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||||
|
else:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||||
|
h = x.replace(self.norm1(x.feats))
|
||||||
|
h = h * (1 + scale_msa) + shift_msa
|
||||||
|
h = self.self_attn(h)
|
||||||
|
h = h * gate_msa
|
||||||
|
x = x + h
|
||||||
|
h = x.replace(self.norm2(x.feats))
|
||||||
|
h = self.cross_attn(h, context)
|
||||||
|
x = x + h
|
||||||
|
h = x.replace(self.norm3(x.feats))
|
||||||
|
h = h * (1 + scale_mlp) + shift_mlp
|
||||||
|
h = self.mlp(h)
|
||||||
|
h = h * gate_mlp
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor) -> SparseTensor:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x, mod, context)
|
||||||
157
trellis/modules/transformer/modulated.py
Normal file
157
trellis/modules/transformer/modulated.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from ..attention import MultiHeadAttention
|
||||||
|
from ..norm import LayerNorm32
|
||||||
|
from .blocks import FeedForwardNet
|
||||||
|
|
||||||
|
|
||||||
|
class ModulatedTransformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer block (MSA + FFN) with adaptive layer norm conditioning.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "windowed"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
share_mod: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.share_mod = share_mod
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.attn = MultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_window=shift_window,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.mlp = FeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
if not share_mod:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(channels, 6 * channels, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.share_mod:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||||
|
else:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||||
|
h = self.norm1(x)
|
||||||
|
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||||
|
h = self.attn(h)
|
||||||
|
h = h * gate_msa.unsqueeze(1)
|
||||||
|
x = x + h
|
||||||
|
h = self.norm2(x)
|
||||||
|
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||||
|
h = self.mlp(h)
|
||||||
|
h = h * gate_mlp.unsqueeze(1)
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x, mod)
|
||||||
|
|
||||||
|
|
||||||
|
class ModulatedTransformerCrossBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
ctx_channels: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
attn_mode: Literal["full", "windowed"] = "full",
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_rope: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qk_rms_norm_cross: bool = False,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
share_mod: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.share_mod = share_mod
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.self_attn = MultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="self",
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_window=shift_window,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
use_rope=use_rope,
|
||||||
|
qk_rms_norm=qk_rms_norm,
|
||||||
|
)
|
||||||
|
self.cross_attn = MultiHeadAttention(
|
||||||
|
channels,
|
||||||
|
ctx_channels=ctx_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
type="cross",
|
||||||
|
attn_mode="full",
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_rms_norm=qk_rms_norm_cross,
|
||||||
|
)
|
||||||
|
self.mlp = FeedForwardNet(
|
||||||
|
channels,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
)
|
||||||
|
if not share_mod:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(channels, 6 * channels, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
|
||||||
|
if self.share_mod:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||||
|
else:
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||||
|
h = self.norm1(x)
|
||||||
|
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||||
|
h = self.self_attn(h)
|
||||||
|
h = h * gate_msa.unsqueeze(1)
|
||||||
|
x = x + h
|
||||||
|
h = self.norm2(x)
|
||||||
|
h = self.cross_attn(h, context)
|
||||||
|
x = x + h
|
||||||
|
h = self.norm3(x)
|
||||||
|
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||||
|
h = self.mlp(h)
|
||||||
|
h = h * gate_mlp.unsqueeze(1)
|
||||||
|
x = x + h
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor):
|
||||||
|
if self.use_checkpoint:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
return self._forward(x, mod, context)
|
||||||
|
|
||||||
133
trellis/renderers/mesh_renderer.py
Normal file
133
trellis/renderers/mesh_renderer.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import torch
|
||||||
|
import nvdiffrast.torch as dr
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
from ..representations.mesh import MeshExtractResult
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def intrinsics_to_projection(
|
||||||
|
intrinsics: torch.Tensor,
|
||||||
|
near: float,
|
||||||
|
far: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
OpenCV intrinsics to OpenGL perspective matrix
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
|
||||||
|
near (float): near plane to clip
|
||||||
|
far (float): far plane to clip
|
||||||
|
Returns:
|
||||||
|
(torch.Tensor): [4, 4] OpenGL perspective matrix
|
||||||
|
"""
|
||||||
|
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
|
||||||
|
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
|
||||||
|
ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
|
||||||
|
ret[0, 0] = 2 * fx
|
||||||
|
ret[1, 1] = 2 * fy
|
||||||
|
ret[0, 2] = 2 * cx - 1
|
||||||
|
ret[1, 2] = - 2 * cy + 1
|
||||||
|
ret[2, 2] = far / (far - near)
|
||||||
|
ret[2, 3] = near * far / (near - far)
|
||||||
|
ret[3, 2] = 1.
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class MeshRenderer:
|
||||||
|
"""
|
||||||
|
Renderer for the Mesh representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rendering_options (dict): Rendering options.
|
||||||
|
glctx (nvdiffrast.torch.RasterizeGLContext): RasterizeGLContext object for CUDA/OpenGL interop.
|
||||||
|
"""
|
||||||
|
def __init__(self, rendering_options={}, device='cuda'):
|
||||||
|
self.rendering_options = edict({
|
||||||
|
"resolution": None,
|
||||||
|
"near": None,
|
||||||
|
"far": None,
|
||||||
|
"ssaa": 1
|
||||||
|
})
|
||||||
|
self.rendering_options.update(rendering_options)
|
||||||
|
self.glctx = dr.RasterizeCudaContext(device=device)
|
||||||
|
self.device=device
|
||||||
|
|
||||||
|
def render(
|
||||||
|
self,
|
||||||
|
mesh : MeshExtractResult,
|
||||||
|
extrinsics: torch.Tensor,
|
||||||
|
intrinsics: torch.Tensor,
|
||||||
|
return_types = ["mask", "normal", "depth"]
|
||||||
|
) -> edict:
|
||||||
|
"""
|
||||||
|
Render the mesh.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh : meshmodel
|
||||||
|
extrinsics (torch.Tensor): (4, 4) camera extrinsics
|
||||||
|
intrinsics (torch.Tensor): (3, 3) camera intrinsics
|
||||||
|
return_types (list): list of return types, can be "mask", "depth", "normal_map", "normal", "color"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
edict based on return_types containing:
|
||||||
|
color (torch.Tensor): [3, H, W] rendered color image
|
||||||
|
depth (torch.Tensor): [H, W] rendered depth image
|
||||||
|
normal (torch.Tensor): [3, H, W] rendered normal image
|
||||||
|
normal_map (torch.Tensor): [3, H, W] rendered normal map image
|
||||||
|
mask (torch.Tensor): [H, W] rendered mask image
|
||||||
|
"""
|
||||||
|
resolution = self.rendering_options["resolution"]
|
||||||
|
near = self.rendering_options["near"]
|
||||||
|
far = self.rendering_options["far"]
|
||||||
|
ssaa = self.rendering_options["ssaa"]
|
||||||
|
|
||||||
|
if mesh.vertices.shape[0] == 0 or mesh.faces.shape[0] == 0:
|
||||||
|
default_img = torch.zeros((1, resolution, resolution, 3), dtype=torch.float32, device=self.device)
|
||||||
|
ret_dict = {k : default_img if k in ['normal', 'normal_map', 'color'] else default_img[..., :1] for k in return_types}
|
||||||
|
return ret_dict
|
||||||
|
|
||||||
|
perspective = intrinsics_to_projection(intrinsics, near, far)
|
||||||
|
|
||||||
|
RT = extrinsics.unsqueeze(0)
|
||||||
|
full_proj = (perspective @ extrinsics).unsqueeze(0)
|
||||||
|
|
||||||
|
vertices = mesh.vertices.unsqueeze(0)
|
||||||
|
|
||||||
|
vertices_homo = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
|
||||||
|
vertices_camera = torch.bmm(vertices_homo, RT.transpose(-1, -2))
|
||||||
|
vertices_clip = torch.bmm(vertices_homo, full_proj.transpose(-1, -2))
|
||||||
|
faces_int = mesh.faces.int()
|
||||||
|
rast, _ = dr.rasterize(
|
||||||
|
self.glctx, vertices_clip, faces_int, (resolution * ssaa, resolution * ssaa))
|
||||||
|
|
||||||
|
out_dict = edict()
|
||||||
|
for type in return_types:
|
||||||
|
img = None
|
||||||
|
if type == "mask" :
|
||||||
|
img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
|
||||||
|
elif type == "depth":
|
||||||
|
img = dr.interpolate(vertices_camera[..., 2:3].contiguous(), rast, faces_int)[0]
|
||||||
|
img = dr.antialias(img, rast, vertices_clip, faces_int)
|
||||||
|
elif type == "normal" :
|
||||||
|
img = dr.interpolate(
|
||||||
|
mesh.face_normal.reshape(1, -1, 3), rast,
|
||||||
|
torch.arange(mesh.faces.shape[0] * 3, device=self.device, dtype=torch.int).reshape(-1, 3)
|
||||||
|
)[0]
|
||||||
|
img = dr.antialias(img, rast, vertices_clip, faces_int)
|
||||||
|
# normalize norm pictures
|
||||||
|
img = (img + 1) / 2
|
||||||
|
elif type == "normal_map" :
|
||||||
|
img = dr.interpolate(mesh.vertex_attrs[:, 3:].contiguous(), rast, faces_int)[0]
|
||||||
|
img = dr.antialias(img, rast, vertices_clip, faces_int)
|
||||||
|
elif type == "color" :
|
||||||
|
img = dr.interpolate(mesh.vertex_attrs[:, :3].contiguous(), rast, faces_int)[0]
|
||||||
|
img = dr.antialias(img, rast, vertices_clip, faces_int)
|
||||||
|
|
||||||
|
if ssaa > 1:
|
||||||
|
img = F.interpolate(img.permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False, antialias=True)
|
||||||
|
img = img.squeeze()
|
||||||
|
else:
|
||||||
|
img = img.permute(0, 3, 1, 2).squeeze()
|
||||||
|
out_dict[type] = img
|
||||||
|
|
||||||
|
return out_dict
|
||||||
102
trellis/representations/mesh/flexicubes/examples/loss.py
Normal file
102
trellis/representations/mesh/flexicubes/examples/loss.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import torch
|
||||||
|
import torch_scatter
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Pytorch implementation of the developability regularizer introduced in paper
|
||||||
|
# "Developability of Triangle Meshes" by Stein et al.
|
||||||
|
###############################################################################
|
||||||
|
def mesh_developable_reg(mesh):
|
||||||
|
|
||||||
|
verts = mesh.vertices
|
||||||
|
tris = mesh.faces
|
||||||
|
|
||||||
|
device = verts.device
|
||||||
|
V = verts.shape[0]
|
||||||
|
F = tris.shape[0]
|
||||||
|
|
||||||
|
POS_EPS = 1e-6
|
||||||
|
REL_EPS = 1e-6
|
||||||
|
|
||||||
|
def normalize(vecs):
|
||||||
|
return vecs / (torch.linalg.norm(vecs, dim=-1, keepdim=True) + POS_EPS)
|
||||||
|
|
||||||
|
tri_pos = verts[tris]
|
||||||
|
|
||||||
|
vert_normal_covariance_sum = torch.zeros((V, 9), device=device)
|
||||||
|
vert_area = torch.zeros(V, device=device)
|
||||||
|
vert_degree = torch.zeros(V, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
for iC in range(3): # loop over three corners of each triangle
|
||||||
|
|
||||||
|
# gather tri verts
|
||||||
|
pRoot = tri_pos[:, iC, :]
|
||||||
|
pA = tri_pos[:, (iC + 1) % 3, :]
|
||||||
|
pB = tri_pos[:, (iC + 2) % 3, :]
|
||||||
|
|
||||||
|
# compute the corner angle & normal
|
||||||
|
vA = pA - pRoot
|
||||||
|
vAn = normalize(vA)
|
||||||
|
vB = pB - pRoot
|
||||||
|
vBn = normalize(vB)
|
||||||
|
area_normal = torch.linalg.cross(vA, vB, dim=-1)
|
||||||
|
face_area = 0.5 * torch.linalg.norm(area_normal, dim=-1)
|
||||||
|
normal = normalize(area_normal)
|
||||||
|
corner_angle = torch.acos(torch.clamp(torch.sum(vAn * vBn, dim=-1), min=-1., max=1.))
|
||||||
|
|
||||||
|
# add up the contribution to the covariance matrix
|
||||||
|
outer = normal[:, :, None] @ normal[:, None, :]
|
||||||
|
contrib = corner_angle[:, None] * outer.reshape(-1, 9)
|
||||||
|
|
||||||
|
# scatter the result to the appropriate matrices
|
||||||
|
vert_normal_covariance_sum = torch_scatter.scatter_add(src=contrib,
|
||||||
|
index=tris[:, iC],
|
||||||
|
dim=-2,
|
||||||
|
out=vert_normal_covariance_sum)
|
||||||
|
|
||||||
|
vert_area = torch_scatter.scatter_add(src=face_area / 3.,
|
||||||
|
index=tris[:, iC],
|
||||||
|
dim=-1,
|
||||||
|
out=vert_area)
|
||||||
|
|
||||||
|
vert_degree = torch_scatter.scatter_add(src=torch.ones(F, dtype=torch.int32, device=device),
|
||||||
|
index=tris[:, iC],
|
||||||
|
dim=-1,
|
||||||
|
out=vert_degree)
|
||||||
|
|
||||||
|
# The energy is the smallest eigenvalue of the outer-product matrix
|
||||||
|
vert_normal_covariance_sum = vert_normal_covariance_sum.reshape(
|
||||||
|
-1, 3, 3) # reshape to a batch of matrices
|
||||||
|
vert_normal_covariance_sum = vert_normal_covariance_sum + torch.eye(
|
||||||
|
3, device=device)[None, :, :] * REL_EPS
|
||||||
|
|
||||||
|
min_eigvals = torch.min(torch.linalg.eigvals(vert_normal_covariance_sum).abs(), dim=-1).values
|
||||||
|
|
||||||
|
# Mask out degree-3 vertices
|
||||||
|
vert_area = torch.where(vert_degree == 3, torch.tensor(0, dtype=vert_area.dtype,device=vert_area.device), vert_area)
|
||||||
|
|
||||||
|
# Adjust the vertex area weighting so it is unit-less, and 1 on average
|
||||||
|
vert_area = vert_area * (V / torch.sum(vert_area, dim=-1, keepdim=True))
|
||||||
|
|
||||||
|
return vert_area * min_eigvals
|
||||||
|
|
||||||
|
def sdf_reg_loss(sdf, all_edges):
|
||||||
|
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
|
||||||
|
mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
|
||||||
|
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
||||||
|
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
|
||||||
|
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
|
||||||
|
return sdf_diff
|
||||||
92
trellis/utils/loss_utils.py
Normal file
92
trellis/utils/loss_utils.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from math import exp
|
||||||
|
from lpips import LPIPS
|
||||||
|
|
||||||
|
|
||||||
|
def smooth_l1_loss(pred, target, beta=1.0):
|
||||||
|
diff = torch.abs(pred - target)
|
||||||
|
loss = torch.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
|
||||||
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
|
def l1_loss(network_output, gt):
|
||||||
|
return torch.abs((network_output - gt)).mean()
|
||||||
|
|
||||||
|
|
||||||
|
def l2_loss(network_output, gt):
|
||||||
|
return ((network_output - gt) ** 2).mean()
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian(window_size, sigma):
|
||||||
|
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
||||||
|
return gauss / gauss.sum()
|
||||||
|
|
||||||
|
|
||||||
|
def create_window(window_size, channel):
|
||||||
|
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||||
|
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||||
|
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
||||||
|
return window
|
||||||
|
|
||||||
|
|
||||||
|
def psnr(img1, img2, max_val=1.0):
|
||||||
|
mse = F.mse_loss(img1, img2)
|
||||||
|
return 20 * torch.log10(max_val / torch.sqrt(mse))
|
||||||
|
|
||||||
|
|
||||||
|
def ssim(img1, img2, window_size=11, size_average=True):
|
||||||
|
channel = img1.size(-3)
|
||||||
|
window = create_window(window_size, channel)
|
||||||
|
|
||||||
|
if img1.is_cuda:
|
||||||
|
window = window.cuda(img1.get_device())
|
||||||
|
window = window.type_as(img1)
|
||||||
|
|
||||||
|
return _ssim(img1, img2, window, window_size, channel, size_average)
|
||||||
|
|
||||||
|
def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
||||||
|
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
||||||
|
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
||||||
|
|
||||||
|
mu1_sq = mu1.pow(2)
|
||||||
|
mu2_sq = mu2.pow(2)
|
||||||
|
mu1_mu2 = mu1 * mu2
|
||||||
|
|
||||||
|
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
||||||
|
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
||||||
|
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
||||||
|
|
||||||
|
C1 = 0.01 ** 2
|
||||||
|
C2 = 0.03 ** 2
|
||||||
|
|
||||||
|
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||||
|
|
||||||
|
if size_average:
|
||||||
|
return ssim_map.mean()
|
||||||
|
else:
|
||||||
|
return ssim_map.mean(1).mean(1).mean(1)
|
||||||
|
|
||||||
|
|
||||||
|
loss_fn_vgg = None
|
||||||
|
def lpips(img1, img2, value_range=(0, 1)):
|
||||||
|
global loss_fn_vgg
|
||||||
|
if loss_fn_vgg is None:
|
||||||
|
loss_fn_vgg = LPIPS(net='vgg').cuda().eval()
|
||||||
|
# normalize to [-1, 1]
|
||||||
|
img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
|
||||||
|
img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1
|
||||||
|
return loss_fn_vgg(img1, img2).mean()
|
||||||
|
|
||||||
|
|
||||||
|
def normal_angle(pred, gt):
|
||||||
|
pred = pred * 2.0 - 1.0
|
||||||
|
gt = gt * 2.0 - 1.0
|
||||||
|
norms = pred.norm(dim=-1) * gt.norm(dim=-1)
|
||||||
|
cos_sim = (pred * gt).sum(-1) / (norms + 1e-9)
|
||||||
|
cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
|
||||||
|
ang = torch.rad2deg(torch.acos(cos_sim[norms > 1e-9])).mean()
|
||||||
|
if ang.isnan():
|
||||||
|
return -1
|
||||||
|
return ang
|
||||||
Reference in New Issue
Block a user