25  trellis/modules/norm.py  Normal file
@@ -0,0 +1,25 @@
import torch
import torch.nn as nn


class LayerNorm32(nn.LayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.float()).type(x.dtype)


class GroupNorm32(nn.GroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.float()).type(x.dtype)


class ChannelLayerNorm32(LayerNorm32):
    """
    A LayerNorm32 that normalizes over the channel dimension (dim 1)
    of an (N, C, *spatial) tensor by moving channels to the last axis.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        DIM = x.dim()
        x = x.permute(0, *range(2, DIM), 1).contiguous()        # (N, C, *S) -> (N, *S, C)
        x = super().forward(x)
        x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()  # back to (N, C, *S)
        return x
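Note: all three classes share one pattern: the input is upcast to float32 just for the normalization and the result is cast back to the input dtype, so mixed-precision (fp16/bf16) models do not lose precision in the mean/variance reduction. A minimal usage sketch of ChannelLayerNorm32 on a channels-first voxel grid (the shapes below are illustrative assumptions, not taken from the commit):

    import torch
    from trellis.modules.norm import ChannelLayerNorm32

    norm = ChannelLayerNorm32(64)                           # normalized_shape = C = 64
    x = torch.randn(2, 64, 8, 8, 8, dtype=torch.float16)    # (N, C, D, H, W)
    y = norm(x)                                             # stats computed in fp32
    assert y.dtype == torch.float16 and y.shape == x.shape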
35  trellis/modules/sparse/nonlinearity.py  Normal file
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
from . import SparseTensor

__all__ = [
    'SparseReLU',
    'SparseSiLU',
    'SparseGELU',
    'SparseActivation'
]


class SparseReLU(nn.ReLU):
    def forward(self, input: SparseTensor) -> SparseTensor:
        return input.replace(super().forward(input.feats))


class SparseSiLU(nn.SiLU):
    def forward(self, input: SparseTensor) -> SparseTensor:
        return input.replace(super().forward(input.feats))


class SparseGELU(nn.GELU):
    def forward(self, input: SparseTensor) -> SparseTensor:
        return input.replace(super().forward(input.feats))


class SparseActivation(nn.Module):
    """
    Applies an arbitrary dense activation module to the features of a SparseTensor.
    """
    def __init__(self, activation: nn.Module):
        super().__init__()
        self.activation = activation

    def forward(self, input: SparseTensor) -> SparseTensor:
        return input.replace(self.activation(input.feats))
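Note: all four wrappers follow the same recipe: apply the dense activation elementwise to input.feats (the flat feature table of active voxels) and rebuild the output with input.replace(...), so coordinates and layout pass through untouched; SparseActivation extends this to any dense module. A hedged sketch (the feats/coords constructor arguments for SparseTensor are assumed here, this commit does not show them):

    import torch
    import torch.nn as nn
    from trellis.modules.sparse import SparseTensor
    from trellis.modules.sparse.nonlinearity import SparseSiLU, SparseActivation

    # assumed: 100 active voxels, 32 channels, coords = (batch, x, y, z)
    st = SparseTensor(feats=torch.randn(100, 32),
                      coords=torch.zeros(100, 4, dtype=torch.int32))
    out1 = SparseSiLU()(st)                         # built-in wrapper
    out2 = SparseActivation(nn.LeakyReLU(0.2))(st)  # wrap any dense activation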
58  trellis/modules/sparse/norm.py  Normal file
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
from . import SparseTensor
from . import DEBUG

__all__ = [
    'SparseGroupNorm',
    'SparseLayerNorm',
    'SparseGroupNorm32',
    'SparseLayerNorm32',
]


class SparseGroupNorm(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)

    def forward(self, input: SparseTensor) -> SparseTensor:
        nfeats = torch.zeros_like(input.feats)
        for k in range(input.shape[0]):
            if DEBUG:
                assert (input.coords[input.layout[k], 0] == k).all(), "SparseGroupNorm: batch index mismatch"
            bfeats = input.feats[input.layout[k]]                          # (M_k, C) slice for sample k
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)   # -> (1, C, M_k)
            bfeats = super().forward(bfeats)
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)      # back to (M_k, C)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)


class SparseLayerNorm(nn.LayerNorm):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)

    def forward(self, input: SparseTensor) -> SparseTensor:
        nfeats = torch.zeros_like(input.feats)
        for k in range(input.shape[0]):
            bfeats = input.feats[input.layout[k]]
            bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
            bfeats = super().forward(bfeats)
            bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
            nfeats[input.layout[k]] = bfeats
        return input.replace(nfeats)


class SparseGroupNorm32(SparseGroupNorm):
    """
    A GroupNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: SparseTensor) -> SparseTensor:
        return super().forward(x.float()).type(x.dtype)


class SparseLayerNorm32(SparseLayerNorm):
    """
    A LayerNorm layer that converts to float32 before the forward pass.
    """
    def forward(self, x: SparseTensor) -> SparseTensor:
        return super().forward(x.float()).type(x.dtype)
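Note: because each sample in a sparse batch owns a different number of active voxels, these norms loop over the batch: input.layout[k] selects sample k's rows from the flat feature table, and the (M_k, C) slice is permuted and reshaped to the (1, C, M_k) channels-first layout that nn.GroupNorm expects before being written back. A dense sketch of just that round-trip (shapes are illustrative):

    import torch
    import torch.nn as nn

    gn = nn.GroupNorm(num_groups=4, num_channels=32)
    bfeats = torch.randn(100, 32)                      # (M_k voxels, C channels)
    y = gn(bfeats.permute(1, 0).reshape(1, 32, -1))    # -> (1, C, M_k)
    y = y.reshape(32, -1).permute(1, 0)                # back to (M_k, C)
    assert y.shape == bfeats.shape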