16 lines
387 B
Python
Executable File
16 lines
387 B
Python
Executable File
import torch
|
|
import torch.nn as nn
|
|
from . import SparseTensor
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ['SparseLinear']
|
|
|
|
|
|
class SparseLinear(nn.Linear):
    """``nn.Linear`` variant whose ``forward`` takes and returns a ``SparseTensor``.

    The linear transform itself is inherited unchanged from ``nn.Linear``;
    it is applied to the input's ``feats`` attribute rather than to the
    input object directly.
    """

    def __init__(self, in_features, out_features, bias=True):
        """Set up the underlying ``nn.Linear`` layer.

        Args:
            in_features: Forwarded to ``nn.Linear`` as ``in_features``.
            out_features: Forwarded to ``nn.Linear`` as ``out_features``.
            bias: Forwarded to ``nn.Linear`` as ``bias``.
        """
        super().__init__(in_features, out_features, bias)

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Apply the linear layer to ``input.feats``.

        Args:
            input: Object exposing a ``feats`` attribute and a ``replace``
                method.

        Returns:
            Whatever ``input.replace`` returns when given the transformed
            features (presumably a SparseTensor carrying the new features;
            confirm against ``SparseTensor.replace``).
        """
        transformed = super().forward(input.feats)
        return input.replace(transformed)
|