This commit is contained in:
zcr
2026-03-17 11:38:36 +08:00
parent 7531afd162
commit c110ed1db0
8 changed files with 950 additions and 0 deletions

View File

@@ -0,0 +1,15 @@
import torch
import torch.nn as nn
from . import SparseTensor
__all__ = [
'SparseLinear'
]
class SparseLinear(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super(SparseLinear, self).__init__(in_features, out_features, bias)
def forward(self, input: SparseTensor) -> SparseTensor:
return input.replace(super().forward(input.feats))