trellis/modules/sparse/__init__.py (new executable file, 102 lines)
@@ -0,0 +1,102 @@
from typing import *

BACKEND = 'spconv'
DEBUG = False
ATTN = 'flash_attn'

def __from_env():
    import os

    global BACKEND
    global DEBUG
    global ATTN

    env_sparse_backend = os.environ.get('SPARSE_BACKEND')
    env_sparse_debug = os.environ.get('SPARSE_DEBUG')
    env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
    if env_sparse_attn is None:
        env_sparse_attn = os.environ.get('ATTN_BACKEND')

    if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
        BACKEND = env_sparse_backend
    if env_sparse_debug is not None:
        DEBUG = env_sparse_debug == '1'
    if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
        ATTN = env_sparse_attn

    print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")


__from_env()


def set_backend(backend: Literal['spconv', 'torchsparse']):
    global BACKEND
    BACKEND = backend


def set_debug(debug: bool):
    global DEBUG
    DEBUG = debug


def set_attn(attn: Literal['xformers', 'flash_attn']):
    global ATTN
    ATTN = attn
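
# Usage sketch (illustrative, not executed here): the backend and attention
# implementation can be chosen either via environment variables before the
# first import, or via the setters above.
#
#   # shell: SPARSE_BACKEND=torchsparse SPARSE_ATTN_BACKEND=xformers python train.py
#
#   import trellis.modules.sparse as sp   # package path assumed
#   sp.set_backend('torchsparse')
#   sp.set_debug(True)
#   sp.set_attn('xformers')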


import importlib

# Map each lazily-exported attribute to the submodule that defines it.
__attributes = {
    'SparseTensor': 'basic',
    'sparse_batch_broadcast': 'basic',
    'sparse_batch_op': 'basic',
    'sparse_cat': 'basic',
    'sparse_unbind': 'basic',
    'SparseGroupNorm': 'norm',
    'SparseLayerNorm': 'norm',
    'SparseGroupNorm32': 'norm',
    'SparseLayerNorm32': 'norm',
    'SparseReLU': 'nonlinearity',
    'SparseSiLU': 'nonlinearity',
    'SparseGELU': 'nonlinearity',
    'SparseActivation': 'nonlinearity',
    'SparseLinear': 'linear',
    'sparse_scaled_dot_product_attention': 'attention',
    'SerializeMode': 'attention',
    'sparse_serialized_scaled_dot_product_self_attention': 'attention',
    'sparse_windowed_scaled_dot_product_self_attention': 'attention',
    'SparseMultiHeadAttention': 'attention',
    'SparseConv3d': 'conv',
    'SparseInverseConv3d': 'conv',
    'SparseDownsample': 'spatial',
    'SparseUpsample': 'spatial',
    'SparseSubdivide': 'spatial',
}

__submodules = ['transformer']

__all__ = list(__attributes.keys()) + __submodules


def __getattr__(name):
    # PEP 562 module-level __getattr__: import attributes and submodules lazily
    # on first access and cache them in globals().
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
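
# Lazy-loading sketch (illustrative): nothing in __attributes is imported until
# first attribute access, which keeps `import trellis.modules.sparse` cheap.
#
#   import trellis.modules.sparse as sp               # package path assumed
#   from trellis.modules.sparse import SparseLinear   # imports .linear on demand
#   sp.transformer                                    # imports the submodule on demand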


# For Pylance: this block never runs at runtime; it only exposes the symbols
# below to static analysis.
if __name__ == '__main__':
    from .basic import *
    from .norm import *
    from .nonlinearity import *
    from .linear import *
    from .attention import *
    from .conv import *
    from .spatial import *
    import transformer

trellis/modules/sparse/attention/__init__.py (new executable file, 4 lines)
@@ -0,0 +1,4 @@
from .full_attn import *
from .serialized_attn import *
from .windowed_attn import *
from .modules import *

trellis/modules/sparse/basic.py (new executable file, 459 lines)
@@ -0,0 +1,459 @@
from typing import *
import torch
import torch.nn as nn
from . import BACKEND, DEBUG

SparseTensorData = None  # Lazy import


__all__ = [
    'SparseTensor',
    'sparse_batch_broadcast',
    'sparse_batch_op',
    'sparse_cat',
    'sparse_unbind',
]


class SparseTensor:
    """
    Sparse tensor with support for both torchsparse and spconv backends.

    Parameters:
    - feats (torch.Tensor): Features of the sparse tensor.
    - coords (torch.Tensor): Coordinates of the sparse tensor.
    - shape (torch.Size): Shape of the sparse tensor.
    - layout (List[slice]): Layout of the sparse tensor for each batch.
    - data (SparseTensorData): Backend sparse tensor data used for convolution.

    NOTE:
    - Data belonging to the same batch must be contiguous.
    - Coordinates should be within the range [0, 1023].
    """
    @overload
    def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...

    @overload
    def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...

    def __init__(self, *args, **kwargs):
        # Lazy import of the sparse tensor backend
        global SparseTensorData
        if SparseTensorData is None:
            import importlib
            if BACKEND == 'torchsparse':
                SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif BACKEND == 'spconv':
                SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        # Dispatch between the two overloads: method 0 builds from
        # (feats, coords), method 1 wraps existing backend data.
        method_id = 0
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            feats, coords, shape, layout = args + (None,) * (4 - len(args))
            if 'feats' in kwargs:
                feats = kwargs['feats']
                del kwargs['feats']
            if 'coords' in kwargs:
                coords = kwargs['coords']
                del kwargs['coords']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            if shape is None:
                shape = self.__cal_shape(feats, coords)
            if layout is None:
                layout = self.__cal_layout(coords, shape[0])
            if BACKEND == 'torchsparse':
                self.data = SparseTensorData(feats, coords, **kwargs)
            elif BACKEND == 'spconv':
                spatial_shape = list(coords.max(0)[0] + 1)[1:]
                self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
                self.data._features = feats
        elif method_id == 1:
            data, shape, layout = args + (None,) * (3 - len(args))
            if 'data' in kwargs:
                data = kwargs['data']
                del kwargs['data']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            self.data = data
            if shape is None:
                shape = self.__cal_shape(self.feats, self.coords)
            if layout is None:
                layout = self.__cal_layout(self.coords, shape[0])

        self._shape = shape
        self._layout = layout
        self._scale = kwargs.get('scale', (1, 1, 1))
        self._spatial_cache = kwargs.get('spatial_cache', {})

        if DEBUG:
            try:
                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
                for i in range(self.shape[0]):
                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
            except Exception as e:
                print('Debugging information:')
                print(f"- Shape: {self.shape}")
                print(f"- Layout: {self.layout}")
                print(f"- Scale: {self._scale}")
                print(f"- Coords: {self.coords}")
                raise e

    def __cal_shape(self, feats, coords):
        shape = []
        shape.append(coords[:, 0].max().item() + 1)
        shape.extend([*feats.shape[1:]])
        return torch.Size(shape)

    def __cal_layout(self, coords, batch_size):
        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
        offset = torch.cumsum(seq_len, dim=0)
        layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
        return layout

    @property
    def shape(self) -> torch.Size:
        return self._shape

    def dim(self) -> int:
        return len(self.shape)

    @property
    def layout(self) -> List[slice]:
        return self._layout

    @property
    def feats(self) -> torch.Tensor:
        if BACKEND == 'torchsparse':
            return self.data.F
        elif BACKEND == 'spconv':
            return self.data.features

    @feats.setter
    def feats(self, value: torch.Tensor):
        if BACKEND == 'torchsparse':
            self.data.F = value
        elif BACKEND == 'spconv':
            self.data.features = value

    @property
    def coords(self) -> torch.Tensor:
        if BACKEND == 'torchsparse':
            return self.data.C
        elif BACKEND == 'spconv':
            return self.data.indices

    @coords.setter
    def coords(self, value: torch.Tensor):
        if BACKEND == 'torchsparse':
            self.data.C = value
        elif BACKEND == 'spconv':
            self.data.indices = value

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @overload
    def to(self, dtype: torch.dtype) -> 'SparseTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...

    def to(self, *args, **kwargs) -> 'SparseTensor':
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']

        new_feats = self.feats.to(device=device, dtype=dtype)
        new_coords = self.coords.to(device=device)
        return self.replace(new_feats, new_coords)

    def type(self, dtype):
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'SparseTensor':
        new_feats = self.feats.cpu()
        new_coords = self.coords.cpu()
        return self.replace(new_feats, new_coords)

    def cuda(self) -> 'SparseTensor':
        new_feats = self.feats.cuda()
        new_coords = self.coords.cuda()
        return self.replace(new_feats, new_coords)

    def half(self) -> 'SparseTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'SparseTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'SparseTensor':
        new_coords = self.coords.detach()
        new_feats = self.feats.detach()
        return self.replace(new_feats, new_coords)

    def dense(self) -> torch.Tensor:
        # Both backends expose a dense() conversion on their data object.
        return self.data.dense()

    def reshape(self, *shape) -> 'SparseTensor':
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['SparseTensor']:
        return sparse_unbind(self, dim)

    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
        # Build a new SparseTensor that shares this tensor's metadata but
        # carries new features (and optionally new coordinates).
        new_shape = [self.shape[0]]
        new_shape.extend(feats.shape[1:])
        if BACKEND == 'torchsparse':
            new_data = SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            new_data._caches = self.data._caches
        elif BACKEND == 'spconv':
            new_data = SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict
            )
            new_data._features = feats
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
        return new_tensor

    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
        # Create a dense grid covering the axis-aligned bounding box
        # aabb = (x0, y0, z0, x1, y1, z1), replicated over N batches with C channels.
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
        coords = torch.cat([
            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
            coords.repeat(N, 1),
        ], dim=1).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)

    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
        # Merge the spatial caches of two operands, preferring entries from
        # self and updating them with entries from other.
        new_cache = {}
        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
            if k in self._spatial_cache:
                new_cache[k] = self._spatial_cache[k]
            if k in other._spatial_cache:
                if k not in new_cache:
                    new_cache[k] = other._spatial_cache[k]
                else:
                    new_cache[k].update(other._spatial_cache[k])
        return new_cache

    def __neg__(self) -> 'SparseTensor':
        return self.replace(-self.feats)

    def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)
                other = sparse_batch_broadcast(self, other)
            except Exception:
                # Not batch-broadcastable; rely on op's own broadcasting below.
                pass
        # Remember whether the operand was sparse, so the spatial caches can be
        # merged after extracting its features.
        sparse_other = other if isinstance(other, SparseTensor) else None
        if sparse_other is not None:
            other = other.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        if sparse_other is not None:
            new_tensor._spatial_cache = self.__merge_sparse_cache(sparse_other)
        return new_tensor

    def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.sub)

    def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))

    def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.div)

    def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.div(y, x))

    def __getitem__(self, idx):
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        coords = []
        feats = []
        for new_idx, old_idx in enumerate(idx):
            coords.append(self.coords[self.layout[old_idx]].clone())
            coords[-1][:, 0] = new_idx
            feats.append(self.feats[self.layout[old_idx]])
        coords = torch.cat(coords, dim=0).contiguous()
        feats = torch.cat(feats, dim=0).contiguous()
        return SparseTensor(feats=feats, coords=coords)

    def register_spatial_cache(self, key, value) -> None:
        """
        Register a spatial cache.
        The spatial cache can hold anything you want to cache.
        Registration and retrieval of the cache are keyed by the current scale.
        """
        scale_key = str(self._scale)
        if scale_key not in self._spatial_cache:
            self._spatial_cache[scale_key] = {}
        self._spatial_cache[scale_key][key] = value

    def get_spatial_cache(self, key=None):
        """
        Retrieve a spatial cache entry for the current scale, or the whole
        cache dict for this scale if no key is given.
        """
        scale_key = str(self._scale)
        cur_scale_cache = self._spatial_cache.get(scale_key, {})
        if key is None:
            return cur_scale_cache
        return cur_scale_cache.get(key, None)
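
# A minimal construction sketch (illustrative; shapes and values are made up).
# Coordinates are (batch_idx, x, y, z) int32 rows, features are per-voxel
# vectors, and rows of the same batch index are contiguous, as required above.
#
#   coords = torch.tensor([[0, 0, 0, 0],
#                          [0, 1, 2, 3],
#                          [1, 4, 4, 4]], dtype=torch.int32)
#   feats = torch.randn(3, 8)
#   st = SparseTensor(feats=feats, coords=coords)  # st.shape == torch.Size([2, 8])
#   st2 = (st * 2.0 + 1.0).half()                  # elementwise ops return new SparseTensors
#   first = st[0]                                  # re-indexed single-batch SparseTensor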


def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a per-batch tensor to the features of a sparse tensor along the batch dimension.

    Args:
        input (SparseTensor): Sparse tensor providing the batch layout.
        other (torch.Tensor): Tensor of shape (batch_size, ...) to broadcast.

    Returns:
        torch.Tensor: Broadcasted features aligned with input.feats.
    """
    coords, feats = input.coords, input.feats
    broadcasted = torch.zeros_like(feats)
    for k in range(input.shape[0]):
        broadcasted[input.layout[k]] = other[k]
    return broadcasted


def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
    """
    Broadcast a per-batch tensor to a sparse tensor along the batch dimension, then apply an operation.

    Args:
        input (SparseTensor): Sparse tensor to operate on.
        other (torch.Tensor): Tensor of shape (batch_size, ...) to broadcast.
        op (callable): Operation to perform after broadcasting. Defaults to torch.add.
    """
    return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
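
# Illustrative use of the two helpers above (variable names are made up):
#
#   t = torch.randn(st.shape[0], st.feats.shape[1])  # one row per batch
#   dense = sparse_batch_broadcast(st, t)            # row i of dense gets t[batch(i)]
#   st_biased = sparse_batch_op(st, t)               # equivalent to adding a per-batch bias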


def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate a list of sparse tensors.

    Args:
        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
        dim (int): Dimension along which to concatenate. Concatenating along
            dim 0 merges the batches; any other dim concatenates the features.
    """
    if dim == 0:
        start = 0
        coords = []
        for input in inputs:
            coords.append(input.coords.clone())
            coords[-1][:, 0] += start
            start += input.shape[0]
        coords = torch.cat(coords, dim=0)
        feats = torch.cat([input.feats for input in inputs], dim=0)
        output = SparseTensor(
            coords=coords,
            feats=feats,
        )
    else:
        feats = torch.cat([input.feats for input in inputs], dim=dim)
        output = inputs[0].replace(feats)

    return output


def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Unbind a sparse tensor along a dimension.

    Args:
        input (SparseTensor): Sparse tensor to unbind.
        dim (int): Dimension to unbind. Unbinding along dim 0 splits the
            batches; any other dim unbinds the features.
    """
    if dim == 0:
        return [input[i] for i in range(input.shape[0])]
    else:
        feats = input.feats.unbind(dim)
        return [input.replace(f) for f in feats]
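
# Round-trip sketch (illustrative): concatenating along dim 0 merges batches,
# and unbinding along dim 0 splits them back into single-batch tensors.
#
#   merged = sparse_cat([st_a, st_b])      # batch_size = st_a.shape[0] + st_b.shape[0]
#   parts = sparse_unbind(merged, dim=0)   # list of single-batch SparseTensors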

trellis/modules/sparse/conv/__init__.py (new executable file, 21 lines)
@@ -0,0 +1,21 @@
from .. import BACKEND


SPCONV_ALGO = 'auto'  # 'auto', 'implicit_gemm', 'native'

def __from_env():
    import os

    global SPCONV_ALGO
    env_spconv_algo = os.environ.get('SPCONV_ALGO')
    if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
        SPCONV_ALGO = env_spconv_algo
    print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")


__from_env()

if BACKEND == 'torchsparse':
    from .conv_torchsparse import *
elif BACKEND == 'spconv':
    from .conv_spconv import *
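
# Illustrative: force the spconv algorithm via the environment before the
# first import, e.g. `SPCONV_ALGO=native python run.py` (script name is a
# placeholder).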

trellis/modules/sparse/transformer/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .blocks import *
from .modulated import *

trellis/modules/sparse/transformer/blocks.py (new file, 151 lines)
@@ -0,0 +1,151 @@
from typing import *
import torch
import torch.nn as nn
from ..basic import SparseTensor
from ..linear import SparseLinear
from ..nonlinearity import SparseGELU
from ..attention import SparseMultiHeadAttention, SerializeMode
from ...norm import LayerNorm32


class SparseFeedForwardNet(nn.Module):
    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp = nn.Sequential(
            SparseLinear(channels, int(channels * mlp_ratio)),
            SparseGELU(approximate="tanh"),
            SparseLinear(int(channels * mlp_ratio), channels),
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        return self.mlp(x)


class SparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN).
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor) -> SparseTensor:
        # Pre-norm residual blocks: x = x + Attn(LN(x)); x = x + FFN(LN(x)).
        h = x.replace(self.norm1(x.feats))
        h = self.attn(h)
        x = x + h
        h = x.replace(self.norm2(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
        else:
            return self._forward(x)


class SparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN).
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor, context: torch.Tensor):
        h = x.replace(self.norm1(x.feats))
        h = self.self_attn(h)
        x = x + h
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        h = x.replace(self.norm3(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor, context: torch.Tensor):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)
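
# Instantiation sketch (illustrative; sizes are made up, and the dense context
# layout (B, L, ctx_channels) is an assumption about SparseMultiHeadAttention):
#
#   block = SparseTransformerBlock(channels=512, num_heads=8, attn_mode="full")
#   x = block(x)                                  # x: SparseTensor with 512-channel feats
#
#   cross = SparseTransformerCrossBlock(channels=512, ctx_channels=768, num_heads=8)
#   x = cross(x, context)                         # context: e.g. (B, L, 768) conditioning tokens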