1
This commit is contained in:
25
trellis/modules/norm.py
Normal file
25
trellis/modules/norm.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LayerNorm32(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
"""
|
||||
A GroupNorm layer that converts to float32 before the forward pass.
|
||||
"""
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class ChannelLayerNorm32(LayerNorm32):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
DIM = x.dim()
|
||||
x = x.permute(0, *range(2, DIM), 1).contiguous()
|
||||
x = super().forward(x)
|
||||
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user