1
This commit is contained in:
35
trellis/modules/sparse/nonlinearity.py
Normal file
35
trellis/modules/sparse/nonlinearity.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from . import SparseTensor
|
||||
|
||||
__all__ = [
|
||||
'SparseReLU',
|
||||
'SparseSiLU',
|
||||
'SparseGELU',
|
||||
'SparseActivation'
|
||||
]
|
||||
|
||||
|
||||
class SparseReLU(nn.ReLU):
|
||||
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||
return input.replace(super().forward(input.feats))
|
||||
|
||||
|
||||
class SparseSiLU(nn.SiLU):
|
||||
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||
return input.replace(super().forward(input.feats))
|
||||
|
||||
|
||||
class SparseGELU(nn.GELU):
|
||||
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||
return input.replace(super().forward(input.feats))
|
||||
|
||||
|
||||
class SparseActivation(nn.Module):
|
||||
def __init__(self, activation: nn.Module):
|
||||
super().__init__()
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||
return input.replace(self.activation(input.feats))
|
||||
|
||||
Reference in New Issue
Block a user