This commit is contained in:
zcr
2026-03-17 11:28:52 +08:00
commit 59570f8812
45 changed files with 5308 additions and 0 deletions

6
trellis/__init__.py Executable file
View File

@@ -0,0 +1,6 @@
from . import models
from . import modules
from . import pipelines
from . import renderers
from . import representations
from . import utils

View File

@@ -0,0 +1,58 @@
import importlib
__attributes = {
'SparseStructure': 'sparse_structure',
'SparseFeat2Render': 'sparse_feat2render',
'SLat2Render':'structured_latent2render',
'Slat2RenderGeo':'structured_latent2render',
'SparseStructureLatent': 'sparse_structure_latent',
'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',
'SLat': 'structured_latent',
'TextConditionedSLat': 'structured_latent',
'ImageConditionedSLat': 'structured_latent',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .sparse_structure import SparseStructure
from .sparse_feat2render import SparseFeat2Render
from .structured_latent2render import (
SLat2Render,
Slat2RenderGeo,
)
from .sparse_structure_latent import (
SparseStructureLatent,
TextConditionedSparseStructureLatent,
ImageConditionedSparseStructureLatent,
)
from .structured_latent import (
SLat,
TextConditionedSLat,
ImageConditionedSLat,
)

View File

@@ -0,0 +1,96 @@
import importlib
__attributes = {
'SparseStructureEncoder': 'sparse_structure_vae',
'SparseStructureDecoder': 'sparse_structure_vae',
'SparseStructureFlowModel': 'sparse_structure_flow',
'SLatEncoder': 'structured_latent_vae',
'SLatGaussianDecoder': 'structured_latent_vae',
'SLatRadianceFieldDecoder': 'structured_latent_vae',
'SLatMeshDecoder': 'structured_latent_vae',
'ElasticSLatEncoder': 'structured_latent_vae',
'ElasticSLatGaussianDecoder': 'structured_latent_vae',
'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
'ElasticSLatMeshDecoder': 'structured_latent_vae',
'SLatFlowModel': 'structured_latent_flow',
'ElasticSLatFlowModel': 'structured_latent_flow',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
def from_pretrained(path: str, **kwargs):
"""
Load a model from a pretrained checkpoint.
Args:
path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
**kwargs: Additional arguments for the model constructor.
"""
import os
import json
from safetensors.torch import load_file
is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
if is_local:
config_file = f"{path}.json"
model_file = f"{path}.safetensors"
else:
from huggingface_hub import hf_hub_download
path_parts = path.split('/')
repo_id = f'{path_parts[0]}/{path_parts[1]}'
model_name = '/'.join(path_parts[2:])
config_file = hf_hub_download(repo_id, f"{model_name}.json")
model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
with open(config_file, 'r') as f:
config = json.load(f)
model = __getattr__(config['name'])(**config['args'], **kwargs)
model.load_state_dict(load_file(model_file))
return model
# For Pylance
if __name__ == '__main__':
from .sparse_structure_vae import (
SparseStructureEncoder,
SparseStructureDecoder,
)
from .sparse_structure_flow import SparseStructureFlowModel
from .structured_latent_vae import (
SLatEncoder,
SLatGaussianDecoder,
SLatRadianceFieldDecoder,
SLatMeshDecoder,
ElasticSLatEncoder,
ElasticSLatGaussianDecoder,
ElasticSLatRadianceFieldDecoder,
ElasticSLatMeshDecoder,
)
from .structured_latent_flow import (
SLatFlowModel,
ElasticSLatFlowModel,
)

View File

@@ -0,0 +1,4 @@
from .encoder import SLatEncoder, ElasticSLatEncoder
from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder

View File

@@ -0,0 +1,117 @@
from typing import *
import torch
import torch.nn as nn
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
from ...modules import sparse as sp
from ...modules.transformer import AbsolutePositionEmbedder
from ...modules.sparse.transformer import SparseTransformerBlock
def block_attn_config(self):
"""
Return the attention configuration of the model.
"""
for i in range(self.num_blocks):
if self.attn_mode == "shift_window":
yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
elif self.attn_mode == "shift_sequence":
yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
elif self.attn_mode == "shift_order":
yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
elif self.attn_mode == "full":
yield "full", None, None, None, None
elif self.attn_mode == "swin":
yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
class SparseTransformerBase(nn.Module):
"""
Sparse Transformer without output layers.
Serve as the base class for encoder and decoder.
"""
def __init__(
self,
in_channels: int,
model_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
window_size: Optional[int] = None,
pe_mode: Literal["ape", "rope"] = "ape",
use_fp16: bool = False,
use_checkpoint: bool = False,
qk_rms_norm: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.model_channels = model_channels
self.num_blocks = num_blocks
self.window_size = window_size
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.attn_mode = attn_mode
self.pe_mode = pe_mode
self.use_fp16 = use_fp16
self.use_checkpoint = use_checkpoint
self.qk_rms_norm = qk_rms_norm
self.dtype = torch.float16 if use_fp16 else torch.float32
if pe_mode == "ape":
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
self.input_layer = sp.SparseLinear(in_channels, model_channels)
self.blocks = nn.ModuleList([
SparseTransformerBlock(
model_channels,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
attn_mode=attn_mode,
window_size=window_size,
shift_sequence=shift_sequence,
shift_window=shift_window,
serialize_mode=serialize_mode,
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
qk_rms_norm=self.qk_rms_norm,
)
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
])
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
self.blocks.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.blocks.apply(convert_module_to_f32)
def initialize_weights(self) -> None:
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
h = self.input_layer(x)
if self.pe_mode == "ape":
h = h + self.pos_embedder(x.coords[:, 1:])
h = h.type(self.dtype)
for block in self.blocks:
h = block(h)
return h

View File

@@ -0,0 +1,36 @@
from typing import *
BACKEND = 'flash_attn'
DEBUG = False
def __from_env():
import os
global BACKEND
global DEBUG
env_attn_backend = os.environ.get('ATTN_BACKEND')
env_sttn_debug = os.environ.get('ATTN_DEBUG')
if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
BACKEND = env_attn_backend
if env_sttn_debug is not None:
DEBUG = env_sttn_debug == '1'
print(f"[ATTENTION] Using backend: {BACKEND}")
__from_env()
def set_backend(backend: Literal['xformers', 'flash_attn']):
global BACKEND
BACKEND = backend
def set_debug(debug: bool):
global DEBUG
DEBUG = debug
from .full_attn import *
from .modules import *

View File

@@ -0,0 +1,102 @@
from typing import *
BACKEND = 'spconv'
DEBUG = False
ATTN = 'flash_attn'
def __from_env():
import os
global BACKEND
global DEBUG
global ATTN
env_sparse_backend = os.environ.get('SPARSE_BACKEND')
env_sparse_debug = os.environ.get('SPARSE_DEBUG')
env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
if env_sparse_attn is None:
env_sparse_attn = os.environ.get('ATTN_BACKEND')
if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
BACKEND = env_sparse_backend
if env_sparse_debug is not None:
DEBUG = env_sparse_debug == '1'
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
ATTN = env_sparse_attn
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
__from_env()
def set_backend(backend: Literal['spconv', 'torchsparse']):
global BACKEND
BACKEND = backend
def set_debug(debug: bool):
global DEBUG
DEBUG = debug
def set_attn(attn: Literal['xformers', 'flash_attn']):
global ATTN
ATTN = attn
import importlib
__attributes = {
'SparseTensor': 'basic',
'sparse_batch_broadcast': 'basic',
'sparse_batch_op': 'basic',
'sparse_cat': 'basic',
'sparse_unbind': 'basic',
'SparseGroupNorm': 'norm',
'SparseLayerNorm': 'norm',
'SparseGroupNorm32': 'norm',
'SparseLayerNorm32': 'norm',
'SparseReLU': 'nonlinearity',
'SparseSiLU': 'nonlinearity',
'SparseGELU': 'nonlinearity',
'SparseActivation': 'nonlinearity',
'SparseLinear': 'linear',
'sparse_scaled_dot_product_attention': 'attention',
'SerializeMode': 'attention',
'sparse_serialized_scaled_dot_product_self_attention': 'attention',
'sparse_windowed_scaled_dot_product_self_attention': 'attention',
'SparseMultiHeadAttention': 'attention',
'SparseConv3d': 'conv',
'SparseInverseConv3d': 'conv',
'SparseDownsample': 'spatial',
'SparseUpsample': 'spatial',
'SparseSubdivide' : 'spatial'
}
__submodules = ['transformer']
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .basic import *
from .norm import *
from .nonlinearity import *
from .linear import *
from .attention import *
from .conv import *
from .spatial import *
import transformer

View File

@@ -0,0 +1,4 @@
from .full_attn import *
from .serialized_attn import *
from .windowed_attn import *
from .modules import *

459
trellis/modules/sparse/basic.py Executable file
View File

@@ -0,0 +1,459 @@
from typing import *
import torch
import torch.nn as nn
from . import BACKEND, DEBUG
SparseTensorData = None # Lazy import
__all__ = [
'SparseTensor',
'sparse_batch_broadcast',
'sparse_batch_op',
'sparse_cat',
'sparse_unbind',
]
class SparseTensor:
"""
Sparse tensor with support for both torchsparse and spconv backends.
Parameters:
- feats (torch.Tensor): Features of the sparse tensor.
- coords (torch.Tensor): Coordinates of the sparse tensor.
- shape (torch.Size): Shape of the sparse tensor.
- layout (List[slice]): Layout of the sparse tensor for each batch
- data (SparseTensorData): Sparse tensor data used for convolusion
NOTE:
- Data corresponding to a same batch should be contiguous.
- Coords should be in [0, 1023]
"""
@overload
def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
@overload
def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...
def __init__(self, *args, **kwargs):
# Lazy import of sparse tensor backend
global SparseTensorData
if SparseTensorData is None:
import importlib
if BACKEND == 'torchsparse':
SparseTensorData = importlib.import_module('torchsparse').SparseTensor
elif BACKEND == 'spconv':
SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor
method_id = 0
if len(args) != 0:
method_id = 0 if isinstance(args[0], torch.Tensor) else 1
else:
method_id = 1 if 'data' in kwargs else 0
if method_id == 0:
feats, coords, shape, layout = args + (None,) * (4 - len(args))
if 'feats' in kwargs:
feats = kwargs['feats']
del kwargs['feats']
if 'coords' in kwargs:
coords = kwargs['coords']
del kwargs['coords']
if 'shape' in kwargs:
shape = kwargs['shape']
del kwargs['shape']
if 'layout' in kwargs:
layout = kwargs['layout']
del kwargs['layout']
if shape is None:
shape = self.__cal_shape(feats, coords)
if layout is None:
layout = self.__cal_layout(coords, shape[0])
if BACKEND == 'torchsparse':
self.data = SparseTensorData(feats, coords, **kwargs)
elif BACKEND == 'spconv':
spatial_shape = list(coords.max(0)[0] + 1)[1:]
self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
self.data._features = feats
elif method_id == 1:
data, shape, layout = args + (None,) * (3 - len(args))
if 'data' in kwargs:
data = kwargs['data']
del kwargs['data']
if 'shape' in kwargs:
shape = kwargs['shape']
del kwargs['shape']
if 'layout' in kwargs:
layout = kwargs['layout']
del kwargs['layout']
self.data = data
if shape is None:
shape = self.__cal_shape(self.feats, self.coords)
if layout is None:
layout = self.__cal_layout(self.coords, shape[0])
self._shape = shape
self._layout = layout
self._scale = kwargs.get('scale', (1, 1, 1))
self._spatial_cache = kwargs.get('spatial_cache', {})
if DEBUG:
try:
assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
for i in range(self.shape[0]):
assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
except Exception as e:
print('Debugging information:')
print(f"- Shape: {self.shape}")
print(f"- Layout: {self.layout}")
print(f"- Scale: {self._scale}")
print(f"- Coords: {self.coords}")
raise e
def __cal_shape(self, feats, coords):
shape = []
shape.append(coords[:, 0].max().item() + 1)
shape.extend([*feats.shape[1:]])
return torch.Size(shape)
def __cal_layout(self, coords, batch_size):
seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
offset = torch.cumsum(seq_len, dim=0)
layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
return layout
@property
def shape(self) -> torch.Size:
return self._shape
def dim(self) -> int:
return len(self.shape)
@property
def layout(self) -> List[slice]:
return self._layout
@property
def feats(self) -> torch.Tensor:
if BACKEND == 'torchsparse':
return self.data.F
elif BACKEND == 'spconv':
return self.data.features
@feats.setter
def feats(self, value: torch.Tensor):
if BACKEND == 'torchsparse':
self.data.F = value
elif BACKEND == 'spconv':
self.data.features = value
@property
def coords(self) -> torch.Tensor:
if BACKEND == 'torchsparse':
return self.data.C
elif BACKEND == 'spconv':
return self.data.indices
@coords.setter
def coords(self, value: torch.Tensor):
if BACKEND == 'torchsparse':
self.data.C = value
elif BACKEND == 'spconv':
self.data.indices = value
@property
def dtype(self):
return self.feats.dtype
@property
def device(self):
return self.feats.device
@overload
def to(self, dtype: torch.dtype) -> 'SparseTensor': ...
@overload
def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...
def to(self, *args, **kwargs) -> 'SparseTensor':
device = None
dtype = None
if len(args) == 2:
device, dtype = args
elif len(args) == 1:
if isinstance(args[0], torch.dtype):
dtype = args[0]
else:
device = args[0]
if 'dtype' in kwargs:
assert dtype is None, "to() received multiple values for argument 'dtype'"
dtype = kwargs['dtype']
if 'device' in kwargs:
assert device is None, "to() received multiple values for argument 'device'"
device = kwargs['device']
new_feats = self.feats.to(device=device, dtype=dtype)
new_coords = self.coords.to(device=device)
return self.replace(new_feats, new_coords)
def type(self, dtype):
new_feats = self.feats.type(dtype)
return self.replace(new_feats)
def cpu(self) -> 'SparseTensor':
new_feats = self.feats.cpu()
new_coords = self.coords.cpu()
return self.replace(new_feats, new_coords)
def cuda(self) -> 'SparseTensor':
new_feats = self.feats.cuda()
new_coords = self.coords.cuda()
return self.replace(new_feats, new_coords)
def half(self) -> 'SparseTensor':
new_feats = self.feats.half()
return self.replace(new_feats)
def float(self) -> 'SparseTensor':
new_feats = self.feats.float()
return self.replace(new_feats)
def detach(self) -> 'SparseTensor':
new_coords = self.coords.detach()
new_feats = self.feats.detach()
return self.replace(new_feats, new_coords)
def dense(self) -> torch.Tensor:
if BACKEND == 'torchsparse':
return self.data.dense()
elif BACKEND == 'spconv':
return self.data.dense()
def reshape(self, *shape) -> 'SparseTensor':
new_feats = self.feats.reshape(self.feats.shape[0], *shape)
return self.replace(new_feats)
def unbind(self, dim: int) -> List['SparseTensor']:
return sparse_unbind(self, dim)
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
new_shape = [self.shape[0]]
new_shape.extend(feats.shape[1:])
if BACKEND == 'torchsparse':
new_data = SparseTensorData(
feats=feats,
coords=self.data.coords if coords is None else coords,
stride=self.data.stride,
spatial_range=self.data.spatial_range,
)
new_data._caches = self.data._caches
elif BACKEND == 'spconv':
new_data = SparseTensorData(
self.data.features.reshape(self.data.features.shape[0], -1),
self.data.indices,
self.data.spatial_shape,
self.data.batch_size,
self.data.grid,
self.data.voxel_num,
self.data.indice_dict
)
new_data._features = feats
new_data.benchmark = self.data.benchmark
new_data.benchmark_record = self.data.benchmark_record
new_data.thrust_allocator = self.data.thrust_allocator
new_data._timer = self.data._timer
new_data.force_algo = self.data.force_algo
new_data.int8_scale = self.data.int8_scale
if coords is not None:
new_data.indices = coords
new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
return new_tensor
@staticmethod
def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
N, C = dim
x = torch.arange(aabb[0], aabb[3] + 1)
y = torch.arange(aabb[1], aabb[4] + 1)
z = torch.arange(aabb[2], aabb[5] + 1)
coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
coords = torch.cat([
torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
coords.repeat(N, 1),
], dim=1).to(dtype=torch.int32, device=device)
feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
return SparseTensor(feats=feats, coords=coords)
def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
new_cache = {}
for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
if k in self._spatial_cache:
new_cache[k] = self._spatial_cache[k]
if k in other._spatial_cache:
if k not in new_cache:
new_cache[k] = other._spatial_cache[k]
else:
new_cache[k].update(other._spatial_cache[k])
return new_cache
def __neg__(self) -> 'SparseTensor':
return self.replace(-self.feats)
def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
if isinstance(other, torch.Tensor):
try:
other = torch.broadcast_to(other, self.shape)
other = sparse_batch_broadcast(self, other)
except:
pass
if isinstance(other, SparseTensor):
other = other.feats
new_feats = op(self.feats, other)
new_tensor = self.replace(new_feats)
if isinstance(other, SparseTensor):
new_tensor._spatial_cache = self.__merge_sparse_cache(other)
return new_tensor
def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.add)
def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.add)
def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.sub)
def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, lambda x, y: torch.sub(y, x))
def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.mul)
def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.mul)
def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, torch.div)
def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
return self.__elemwise__(other, lambda x, y: torch.div(y, x))
def __getitem__(self, idx):
if isinstance(idx, int):
idx = [idx]
elif isinstance(idx, slice):
idx = range(*idx.indices(self.shape[0]))
elif isinstance(idx, torch.Tensor):
if idx.dtype == torch.bool:
assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
idx = idx.nonzero().squeeze(1)
elif idx.dtype in [torch.int32, torch.int64]:
assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
else:
raise ValueError(f"Unknown index type: {idx.dtype}")
else:
raise ValueError(f"Unknown index type: {type(idx)}")
coords = []
feats = []
for new_idx, old_idx in enumerate(idx):
coords.append(self.coords[self.layout[old_idx]].clone())
coords[-1][:, 0] = new_idx
feats.append(self.feats[self.layout[old_idx]])
coords = torch.cat(coords, dim=0).contiguous()
feats = torch.cat(feats, dim=0).contiguous()
return SparseTensor(feats=feats, coords=coords)
def register_spatial_cache(self, key, value) -> None:
"""
Register a spatial cache.
The spatial cache can be any thing you want to cache.
The registery and retrieval of the cache is based on current scale.
"""
scale_key = str(self._scale)
if scale_key not in self._spatial_cache:
self._spatial_cache[scale_key] = {}
self._spatial_cache[scale_key][key] = value
def get_spatial_cache(self, key=None):
"""
Get a spatial cache.
"""
scale_key = str(self._scale)
cur_scale_cache = self._spatial_cache.get(scale_key, {})
if key is None:
return cur_scale_cache
return cur_scale_cache.get(key, None)
def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
"""
Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
Args:
input (torch.Tensor): 1D tensor to broadcast.
target (SparseTensor): Sparse tensor to broadcast to.
op (callable): Operation to perform after broadcasting. Defaults to torch.add.
"""
coords, feats = input.coords, input.feats
broadcasted = torch.zeros_like(feats)
for k in range(input.shape[0]):
broadcasted[input.layout[k]] = other[k]
return broadcasted
def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
"""
Broadcast a 1D tensor to a sparse tensor along the batch dimension then perform an operation.
Args:
input (torch.Tensor): 1D tensor to broadcast.
target (SparseTensor): Sparse tensor to broadcast to.
op (callable): Operation to perform after broadcasting. Defaults to torch.add.
"""
return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
"""
Concatenate a list of sparse tensors.
Args:
inputs (List[SparseTensor]): List of sparse tensors to concatenate.
"""
if dim == 0:
start = 0
coords = []
for input in inputs:
coords.append(input.coords.clone())
coords[-1][:, 0] += start
start += input.shape[0]
coords = torch.cat(coords, dim=0)
feats = torch.cat([input.feats for input in inputs], dim=0)
output = SparseTensor(
coords=coords,
feats=feats,
)
else:
feats = torch.cat([input.feats for input in inputs], dim=dim)
output = inputs[0].replace(feats)
return output
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
"""
Unbind a sparse tensor along a dimension.
Args:
input (SparseTensor): Sparse tensor to unbind.
dim (int): Dimension to unbind.
"""
if dim == 0:
return [input[i] for i in range(input.shape[0])]
else:
feats = input.feats.unbind(dim)
return [input.replace(f) for f in feats]

View File

@@ -0,0 +1,21 @@
from .. import BACKEND
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'
def __from_env():
import os
global SPCONV_ALGO
env_spconv_algo = os.environ.get('SPCONV_ALGO')
if env_spconv_algo is not None and env_spconv_algo in ['auto', 'implicit_gemm', 'native']:
SPCONV_ALGO = env_spconv_algo
print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")
__from_env()
if BACKEND == 'torchsparse':
from .conv_torchsparse import *
elif BACKEND == 'spconv':
from .conv_spconv import *

View File

@@ -0,0 +1,2 @@
from .blocks import *
from .modulated import *

View File

@@ -0,0 +1,151 @@
from typing import *
import torch
import torch.nn as nn
from ..basic import SparseTensor
from ..linear import SparseLinear
from ..nonlinearity import SparseGELU
from ..attention import SparseMultiHeadAttention, SerializeMode
from ...norm import LayerNorm32
class SparseFeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp = nn.Sequential(
SparseLinear(channels, int(channels * mlp_ratio)),
SparseGELU(approximate="tanh"),
SparseLinear(int(channels * mlp_ratio), channels),
)
def forward(self, x: SparseTensor) -> SparseTensor:
return self.mlp(x)
class SparseTransformerBlock(nn.Module):
"""
Sparse Transformer block (MSA + FFN).
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
window_size: Optional[int] = None,
shift_sequence: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
serialize_mode: Optional[SerializeMode] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_sequence=shift_sequence,
shift_window=shift_window,
serialize_mode=serialize_mode,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: SparseTensor) -> SparseTensor:
h = x.replace(self.norm1(x.feats))
h = self.attn(h)
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.mlp(h)
x = x + h
return x
def forward(self, x: SparseTensor) -> SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class SparseTransformerCrossBlock(nn.Module):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN).
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
window_size: Optional[int] = None,
shift_sequence: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
serialize_mode: Optional[SerializeMode] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.self_attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_sequence=shift_sequence,
shift_window=shift_window,
serialize_mode=serialize_mode,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = SparseMultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: torch.Tensor):
h = x.replace(self.norm1(x.feats))
h = self.self_attn(h)
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.cross_attn(h, context)
x = x + h
h = x.replace(self.norm3(x.feats))
h = self.mlp(h)
x = x + h
return x
def forward(self, x: SparseTensor, context: torch.Tensor):
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
else:
return self._forward(x, context)

View File

@@ -0,0 +1,2 @@
from .blocks import *
from .modulated import *

View File

@@ -0,0 +1,182 @@
from typing import *
import torch
import torch.nn as nn
from ..attention import MultiHeadAttention
from ..norm import LayerNorm32
class AbsolutePositionEmbedder(nn.Module):
"""
Embeds spatial positions into vector representations.
"""
def __init__(self, channels: int, in_channels: int = 3):
super().__init__()
self.channels = channels
self.in_channels = in_channels
self.freq_dim = channels // in_channels // 2
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
self.freqs = 1.0 / (10000 ** self.freqs)
def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
"""
Create sinusoidal position embeddings.
Args:
x: a 1-D Tensor of N indices
Returns:
an (N, D) Tensor of positional embeddings.
"""
self.freqs = self.freqs.to(x.device)
out = torch.outer(x, self.freqs)
out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
return out
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x (torch.Tensor): (N, D) tensor of spatial positions
"""
N, D = x.shape
assert D == self.in_channels, "Input dimension must match number of input channels"
embed = self._sin_cos_embedding(x.reshape(-1))
embed = embed.reshape(N, -1)
if embed.shape[1] < self.channels:
embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1], device=embed.device)], dim=-1)
return embed
class FeedForwardNet(nn.Module):
def __init__(self, channels: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(channels, int(channels * mlp_ratio)),
nn.GELU(approximate="tanh"),
nn.Linear(int(channels * mlp_ratio), channels),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
class TransformerBlock(nn.Module):
"""
Transformer block (MSA + FFN).
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[int] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.attn = MultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: torch.Tensor) -> torch.Tensor:
h = self.norm1(x)
h = self.attn(h)
x = x + h
h = self.norm2(x)
h = self.mlp(h)
x = x + h
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
else:
return self._forward(x)
class TransformerCrossBlock(nn.Module):
"""
Transformer cross-attention block (MSA + MCA + FFN).
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
ln_affine: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
self.self_attn = MultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = MultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = FeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
def _forward(self, x: torch.Tensor, context: torch.Tensor):
h = self.norm1(x)
h = self.self_attn(h)
x = x + h
h = self.norm2(x)
h = self.cross_attn(h, context)
x = x + h
h = self.norm3(x)
h = self.mlp(h)
x = x + h
return x
def forward(self, x: torch.Tensor, context: torch.Tensor):
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
else:
return self._forward(x, context)

View File

@@ -0,0 +1,25 @@
from . import samplers
from .trellis_image_to_3d import TrellisImageTo3DPipeline
from .trellis_text_to_3d import TrellisTextTo3DPipeline
def from_pretrained(path: str):
"""
Load a pipeline from a model folder or a Hugging Face model hub.
Args:
path: The path to the model. Can be either local path or a Hugging Face model name.
"""
import os
import json
is_local = os.path.exists(f"{path}/pipeline.json")
if is_local:
config_file = f"{path}/pipeline.json"
else:
from huggingface_hub import hf_hub_download
config_file = hf_hub_download(path, "pipeline.json")
with open(config_file, 'r') as f:
config = json.load(f)
return globals()[config['name']].from_pretrained(path)

68
trellis/pipelines/base.py Normal file
View File

@@ -0,0 +1,68 @@
from typing import *
import torch
import torch.nn as nn
from .. import models
class Pipeline:
"""
A base class for pipelines.
"""
def __init__(
self,
models: dict[str, nn.Module] = None,
):
if models is None:
return
self.models = models
for model in self.models.values():
model.eval()
@staticmethod
def from_pretrained(path: str) -> "Pipeline":
"""
Load a pretrained model.
"""
import os
import json
is_local = os.path.exists(f"{path}/pipeline.json")
if is_local:
config_file = f"{path}/pipeline.json"
else:
from huggingface_hub import hf_hub_download
config_file = hf_hub_download(path, "pipeline.json")
with open(config_file, 'r') as f:
args = json.load(f)['args']
_models = {}
for k, v in args['models'].items():
try:
_models[k] = models.from_pretrained(f"{path}/{v}")
except:
_models[k] = models.from_pretrained(v)
new_pipeline = Pipeline(_models)
new_pipeline._pretrained_args = args
return new_pipeline
@property
def device(self) -> torch.device:
for model in self.models.values():
if hasattr(model, 'device'):
return model.device
for model in self.models.values():
if hasattr(model, 'parameters'):
return next(model.parameters()).device
raise RuntimeError("No device found.")
def to(self, device: torch.device) -> None:
for model in self.models.values():
model.to(device)
def cuda(self) -> None:
self.to(torch.device("cuda"))
def cpu(self) -> None:
self.to(torch.device("cpu"))

View File

@@ -0,0 +1,2 @@
from .base import Sampler
from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler

View File

@@ -0,0 +1,20 @@
from typing import *
from abc import ABC, abstractmethod
class Sampler(ABC):
"""
A base class for samplers.
"""
@abstractmethod
def sample(
self,
model,
**kwargs
):
"""
Sample from a model.
"""
pass

View File

@@ -0,0 +1,12 @@
from typing import *
class ClassifierFreeGuidanceSamplerMixin:
"""
A mixin class for samplers that apply classifier-free guidance.
"""
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
return (1 + cfg_strength) * pred - cfg_strength * neg_pred

31
trellis/renderers/__init__.py Executable file
View File

@@ -0,0 +1,31 @@
import importlib
__attributes = {
'OctreeRenderer': 'octree_renderer',
'GaussianRenderer': 'gaussian_render',
'MeshRenderer': 'mesh_renderer',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .octree_renderer import OctreeRenderer
from .gaussian_render import GaussianRenderer
from .mesh_renderer import MeshRenderer

View File

@@ -0,0 +1,4 @@
from .radiance_field import Strivec
from .octree import DfsOctree as Octree
from .gaussian import Gaussian
from .mesh import MeshExtractResult

View File

@@ -0,0 +1 @@
from .gaussian_model import Gaussian

View File

@@ -0,0 +1 @@
from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult

View File

@@ -0,0 +1 @@
from .octree_dfs import DfsOctree

View File

@@ -0,0 +1 @@
from .strivec import Strivec

View File

@@ -0,0 +1,63 @@
import importlib
__attributes = {
'BasicTrainer': 'basic',
'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',
'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',
'FlowMatchingTrainer': 'flow_matching.flow_matching',
'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
}
__submodules = []
__all__ = list(__attributes.keys()) + __submodules
def __getattr__(name):
if name not in globals():
if name in __attributes:
module_name = __attributes[name]
module = importlib.import_module(f".{module_name}", __name__)
globals()[name] = getattr(module, name)
elif name in __submodules:
module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
return globals()[name]
# For Pylance
if __name__ == '__main__':
from .basic import BasicTrainer
from .vae.sparse_structure_vae import SparseStructureVaeTrainer
from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer
from .flow_matching.flow_matching import (
FlowMatchingTrainer,
FlowMatchingCFGTrainer,
TextConditionedFlowMatchingCFGTrainer,
ImageConditionedFlowMatchingCFGTrainer,
)
from .flow_matching.sparse_flow_matching import (
SparseFlowMatchingTrainer,
SparseFlowMatchingCFGTrainer,
TextConditionedSparseFlowMatchingCFGTrainer,
ImageConditionedSparseFlowMatchingCFGTrainer,
)

451
trellis/trainers/base.py Normal file
View File

@@ -0,0 +1,451 @@
from abc import abstractmethod
import os
import time
import json
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
import numpy as np
from torchvision import utils
from torch.utils.tensorboard import SummaryWriter
from .utils import *
from ..utils.general_utils import *
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
class Trainer:
"""
Base class for training.
"""
def __init__(self,
models,
dataset,
*,
output_dir,
load_dir,
step,
max_steps,
batch_size=None,
batch_size_per_gpu=None,
batch_split=None,
optimizer={},
lr_scheduler=None,
elastic=None,
grad_clip=None,
ema_rate=0.9999,
fp16_mode='inflat_all',
fp16_scale_growth=1e-3,
finetune_ckpt=None,
log_param_stats=False,
prefetch_data=True,
i_print=1000,
i_log=500,
i_sample=10000,
i_save=10000,
i_ddpcheck=10000,
**kwargs
):
assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
self.models = models
self.dataset = dataset
self.batch_split = batch_split if batch_split is not None else 1
self.max_steps = max_steps
self.optimizer_config = optimizer
self.lr_scheduler_config = lr_scheduler
self.elastic_controller_config = elastic
self.grad_clip = grad_clip
self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
self.fp16_mode = fp16_mode
self.fp16_scale_growth = fp16_scale_growth
self.log_param_stats = log_param_stats
self.prefetch_data = prefetch_data
if self.prefetch_data:
self._data_prefetched = None
self.output_dir = output_dir
self.i_print = i_print
self.i_log = i_log
self.i_sample = i_sample
self.i_save = i_save
self.i_ddpcheck = i_ddpcheck
if dist.is_initialized():
# Multi-GPU params
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
self.local_rank = dist.get_rank() % torch.cuda.device_count()
self.is_master = self.rank == 0
else:
# Single-GPU params
self.world_size = 1
self.rank = 0
self.local_rank = 0
self.is_master = True
self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
self.init_models_and_more(**kwargs)
self.prepare_dataloader(**kwargs)
# Load checkpoint
self.step = 0
if load_dir is not None and step is not None:
self.load(load_dir, step)
elif finetune_ckpt is not None:
self.finetune_from(finetune_ckpt)
if self.is_master:
os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
if self.world_size > 1:
self.check_ddp()
if self.is_master:
print('\n\nTrainer initialized.')
print(self)
@property
def device(self):
for _, model in self.models.items():
if hasattr(model, 'device'):
return model.device
return next(list(self.models.values())[0].parameters()).device
@abstractmethod
def init_models_and_more(self, **kwargs):
"""
Initialize models and more.
"""
pass
def prepare_dataloader(self, **kwargs):
"""
Prepare dataloader.
"""
self.data_sampler = ResumableSampler(
self.dataset,
shuffle=True,
)
self.dataloader = DataLoader(
self.dataset,
batch_size=self.batch_size_per_gpu,
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
pin_memory=True,
drop_last=True,
persistent_workers=True,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
sampler=self.data_sampler,
)
self.data_iterator = cycle(self.dataloader)
@abstractmethod
def load(self, load_dir, step=0):
"""
Load a checkpoint.
Should be called by all processes.
"""
pass
@abstractmethod
def save(self):
"""
Save a checkpoint.
Should be called only by the rank 0 process.
"""
pass
@abstractmethod
def finetune_from(self, finetune_ckpt):
"""
Finetune from a checkpoint.
Should be called by all processes.
"""
pass
@abstractmethod
def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
"""
Run a snapshot of the model.
"""
pass
@torch.no_grad()
def visualize_sample(self, sample):
"""
Convert a sample to an image.
"""
if hasattr(self.dataset, 'visualize_sample'):
return self.dataset.visualize_sample(sample)
else:
return sample
@torch.no_grad()
def snapshot_dataset(self, num_samples=100):
"""
Sample images from the dataset.
"""
dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=num_samples,
num_workers=0,
shuffle=True,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
data = next(iter(dataloader))
data = recursive_to_device(data, self.device)
vis = self.visualize_sample(data)
if isinstance(vis, dict):
save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
else:
save_cfg = [('dataset', vis)]
for name, image in save_cfg:
utils.save_image(
image,
os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
nrow=int(np.sqrt(num_samples)),
normalize=True,
value_range=self.dataset.value_range,
)
@torch.no_grad()
def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
"""
Sample images from the model.
NOTE: This function should be called by all processes.
"""
if self.is_master:
print(f'\nSampling {num_samples} images...', end='')
if suffix is None:
suffix = f'step{self.step:07d}'
# Assign tasks
num_samples_per_process = int(np.ceil(num_samples / self.world_size))
samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)
# Preprocess images
for key in list(samples.keys()):
if samples[key]['type'] == 'sample':
vis = self.visualize_sample(samples[key]['value'])
if isinstance(vis, dict):
for k, v in vis.items():
samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
del samples[key]
else:
samples[key] = {'value': vis, 'type': 'image'}
# Gather results
if self.world_size > 1:
for key in samples.keys():
samples[key]['value'] = samples[key]['value'].contiguous()
if self.is_master:
all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
else:
all_images = []
dist.gather(samples[key]['value'], all_images, dst=0)
if self.is_master:
samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]
# Save images
if self.is_master:
os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
for key in samples.keys():
if samples[key]['type'] == 'image':
utils.save_image(
samples[key]['value'],
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
nrow=int(np.sqrt(num_samples)),
normalize=True,
value_range=self.dataset.value_range,
)
elif samples[key]['type'] == 'number':
min = samples[key]['value'].min()
max = samples[key]['value'].max()
images = (samples[key]['value'] - min) / (max - min)
images = utils.make_grid(
images,
nrow=int(np.sqrt(num_samples)),
normalize=False,
)
save_image_with_notes(
images,
os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
notes=f'{key} min: {min}, max: {max}',
)
if self.is_master:
print(' Done.')
@abstractmethod
def update_ema(self):
"""
Update exponential moving average.
Should only be called by the rank 0 process.
"""
pass
@abstractmethod
def check_ddp(self):
"""
Check if DDP is working properly.
Should be called by all process.
"""
pass
@abstractmethod
def training_losses(**mb_data):
"""
Compute training losses.
"""
pass
def load_data(self):
"""
Load data.
"""
if self.prefetch_data:
if self._data_prefetched is None:
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
data = self._data_prefetched
self._data_prefetched = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
else:
data = recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)
# if the data is a dict, we need to split it into multiple dicts with batch_size_per_gpu
if isinstance(data, dict):
if self.batch_split == 1:
data_list = [data]
else:
batch_size = list(data.values())[0].shape[0]
data_list = [
{k: v[i * batch_size // self.batch_split:(i + 1) * batch_size // self.batch_split] for k, v in data.items()}
for i in range(self.batch_split)
]
elif isinstance(data, list):
data_list = data
else:
raise ValueError('Data must be a dict or a list of dicts.')
return data_list
@abstractmethod
def run_step(self, data_list):
"""
Run a training step.
"""
pass
def run(self):
"""
Run training.
"""
if self.is_master:
print('\nStarting training...')
self.snapshot_dataset()
if self.step == 0:
self.snapshot(suffix='init')
else: # resume
self.snapshot(suffix=f'resume_step{self.step:07d}')
log = []
time_last_print = 0.0
time_elapsed = 0.0
while self.step < self.max_steps:
time_start = time.time()
data_list = self.load_data()
step_log = self.run_step(data_list)
time_end = time.time()
time_elapsed += time_end - time_start
self.step += 1
# Print progress
if self.is_master and self.step % self.i_print == 0:
speed = self.i_print / (time_elapsed - time_last_print) * 3600
columns = [
f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
f'Elapsed: {time_elapsed / 3600:.2f} h',
f'Speed: {speed:.2f} steps/h',
f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
]
print(' | '.join([c.ljust(25) for c in columns]), flush=True)
time_last_print = time_elapsed
# Check ddp
if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
self.check_ddp()
# Sample images
if self.step % self.i_sample == 0:
self.snapshot()
if self.is_master:
log.append((self.step, {}))
# Log time
log[-1][1]['time'] = {
'step': time_end - time_start,
'elapsed': time_elapsed,
}
# Log losses
if step_log is not None:
log[-1][1].update(step_log)
# Log scale
if self.fp16_mode == 'amp':
log[-1][1]['scale'] = self.scaler.get_scale()
elif self.fp16_mode == 'inflat_all':
log[-1][1]['log_scale'] = self.log_scale
# Save log
if self.step % self.i_log == 0:
## save to log file
log_str = '\n'.join([
f'{step}: {json.dumps(log)}' for step, log in log
])
with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
log_file.write(log_str + '\n')
# show with mlflow
log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))]
log_show = dict_reduce(log_show, lambda x: np.mean(x))
log_show = dict_flatten(log_show, sep='/')
for key, value in log_show.items():
self.writer.add_scalar(key, value, self.step)
log = []
# Save checkpoint
if self.step % self.i_save == 0:
self.save()
if self.is_master:
self.snapshot(suffix='final')
self.writer.close()
print('Training finished.')
def profile(self, wait=2, warmup=3, active=5):
"""
Profile the training loop.
"""
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
profile_memory=True,
with_stack=True,
) as prof:
for _ in range(wait + warmup + active):
self.run_step()
prof.step()

438
trellis/trainers/basic.py Normal file
View File

@@ -0,0 +1,438 @@
import os
import copy
from functools import partial
from contextlib import nullcontext
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from .utils import *
from .base import Trainer
from ..utils.general_utils import *
from ..utils.dist_utils import *
from ..utils import grad_clip_utils, elastic_utils
class BasicTrainer(Trainer):
"""
Trainer for basic training loop.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
"""
def __str__(self):
lines = []
lines.append(self.__class__.__name__)
lines.append(f' - Models:')
for name, model in self.models.items():
lines.append(f' - {name}: {model.__class__.__name__}')
lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
lines.append(f' - Dataloader:')
lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
lines.append(f' - Num workers: {self.dataloader.num_workers}')
lines.append(f' - Number of steps: {self.max_steps}')
lines.append(f' - Number of GPUs: {self.world_size}')
lines.append(f' - Batch size: {self.batch_size}')
lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
lines.append(f' - Batch split: {self.batch_split}')
lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
if self.lr_scheduler_config is not None:
lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
if self.elastic_controller_config is not None:
lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
if self.grad_clip is not None:
lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
lines.append(f' - EMA rate: {self.ema_rate}')
lines.append(f' - FP16 mode: {self.fp16_mode}')
return '\n'.join(lines)
def init_models_and_more(self, **kwargs):
"""
Initialize models and more.
"""
if self.world_size > 1:
# Prepare distributed data parallel
self.training_models = {
name: DDP(
model,
device_ids=[self.local_rank],
output_device=self.local_rank,
bucket_cap_mb=128,
find_unused_parameters=False
)
for name, model in self.models.items()
}
else:
self.training_models = self.models
# Build master params
self.model_params = sum(
[[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
, [])
if self.fp16_mode == 'amp':
self.master_params = self.model_params
self.scaler = torch.GradScaler() if self.fp16_mode == 'amp' else None
elif self.fp16_mode == 'inflat_all':
self.master_params = make_master_params(self.model_params)
self.fp16_scale_growth = self.fp16_scale_growth
self.log_scale = 20.0
elif self.fp16_mode is None:
self.master_params = self.model_params
else:
raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.')
# Build EMA params
if self.is_master:
self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]
# Initialize optimizer
if hasattr(torch.optim, self.optimizer_config['name']):
self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
else:
self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])
# Initalize learning rate scheduler
if self.lr_scheduler_config is not None:
if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
else:
self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])
# Initialize elastic memory controller
if self.elastic_controller_config is not None:
assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
for model in self.models.values():
if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
model.register_memory_controller(self.elastic_controller)
# Initialize gradient clipper
if self.grad_clip is not None:
if isinstance(self.grad_clip, (float, int)):
self.grad_clip = float(self.grad_clip)
else:
self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
def _master_params_to_state_dicts(self, master_params):
"""
Convert master params to dict of state_dicts.
"""
if self.fp16_mode == 'inflat_all':
master_params = unflatten_master_params(self.model_params, master_params)
state_dicts = {name: model.state_dict() for name, model in self.models.items()}
master_params_names = sum(
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
, [])
for i, (model_name, param_name) in enumerate(master_params_names):
state_dicts[model_name][param_name] = master_params[i]
return state_dicts
def _state_dicts_to_master_params(self, master_params, state_dicts):
"""
Convert a state_dict to master params.
"""
master_params_names = sum(
[[(name, n) for n, p in model.named_parameters() if p.requires_grad] for name, model in self.models.items()]
, [])
params = [state_dicts[name][param_name] for name, param_name in master_params_names]
if self.fp16_mode == 'inflat_all':
model_params_to_master_params(params, master_params)
else:
for i, param in enumerate(params):
master_params[i].data.copy_(param.data)
def load(self, load_dir, step=0):
"""
Load a checkpoint.
Should be called by all processes.
"""
if self.is_master:
print(f'\nLoading checkpoint from step {step}...', end='')
model_ckpts = {}
for name, model in self.models.items():
model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
model_ckpts[name] = model_ckpt
model.load_state_dict(model_ckpt)
if self.fp16_mode == 'inflat_all':
model.convert_to_fp16()
self._state_dicts_to_master_params(self.master_params, model_ckpts)
del model_ckpts
if self.is_master:
for i, ema_rate in enumerate(self.ema_rate):
ema_ckpts = {}
for name, model in self.models.items():
ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
ema_ckpts[name] = ema_ckpt
self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
del ema_ckpts
misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
self.optimizer.load_state_dict(misc_ckpt['optimizer'])
self.step = misc_ckpt['step']
self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
if self.fp16_mode == 'amp':
self.scaler.load_state_dict(misc_ckpt['scaler'])
elif self.fp16_mode == 'inflat_all':
self.log_scale = misc_ckpt['log_scale']
if self.lr_scheduler_config is not None:
self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
if self.elastic_controller_config is not None:
self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
del misc_ckpt
if self.world_size > 1:
dist.barrier()
if self.is_master:
print(' Done.')
if self.world_size > 1:
self.check_ddp()
def save(self):
"""
Save a checkpoint.
Should be called only by the rank 0 process.
"""
assert self.is_master, 'save() should be called only by the rank 0 process.'
print(f'\nSaving checkpoint at step {self.step}...', end='')
model_ckpts = self._master_params_to_state_dicts(self.master_params)
for name, model_ckpt in model_ckpts.items():
torch.save(model_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_step{self.step:07d}.pt'))
for i, ema_rate in enumerate(self.ema_rate):
ema_ckpts = self._master_params_to_state_dicts(self.ema_params[i])
for name, ema_ckpt in ema_ckpts.items():
torch.save(ema_ckpt, os.path.join(self.output_dir, 'ckpts', f'{name}_ema{ema_rate}_step{self.step:07d}.pt'))
misc_ckpt = {
'optimizer': self.optimizer.state_dict(),
'step': self.step,
'data_sampler': self.data_sampler.state_dict(),
}
if self.fp16_mode == 'amp':
misc_ckpt['scaler'] = self.scaler.state_dict()
elif self.fp16_mode == 'inflat_all':
misc_ckpt['log_scale'] = self.log_scale
if self.lr_scheduler_config is not None:
misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
if self.elastic_controller_config is not None:
misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
if self.grad_clip is not None and not isinstance(self.grad_clip, float):
misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
torch.save(misc_ckpt, os.path.join(self.output_dir, 'ckpts', f'misc_step{self.step:07d}.pt'))
print(' Done.')
def finetune_from(self, finetune_ckpt):
"""
Finetune from a checkpoint.
Should be called by all processes.
"""
if self.is_master:
print('\nFinetuning from:')
for name, path in finetune_ckpt.items():
print(f' - {name}: {path}')
model_ckpts = {}
for name, model in self.models.items():
model_state_dict = model.state_dict()
if name in finetune_ckpt:
model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
for k, v in model_ckpt.items():
if model_ckpt[k].shape != model_state_dict[k].shape:
if self.is_master:
print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
model_ckpt[k] = model_state_dict[k]
model_ckpts[name] = model_ckpt
model.load_state_dict(model_ckpt)
if self.fp16_mode == 'inflat_all':
model.convert_to_fp16()
else:
if self.is_master:
print(f'Warning: {name} not found in finetune_ckpt, skipped.')
model_ckpts[name] = model_state_dict
self._state_dicts_to_master_params(self.master_params, model_ckpts)
del model_ckpts
if self.world_size > 1:
dist.barrier()
if self.is_master:
print('Done.')
if self.world_size > 1:
self.check_ddp()
def update_ema(self):
"""
Update exponential moving average.
Should only be called by the rank 0 process.
"""
assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
for i, ema_rate in enumerate(self.ema_rate):
for master_param, ema_param in zip(self.master_params, self.ema_params[i]):
ema_param.detach().mul_(ema_rate).add_(master_param, alpha=1.0 - ema_rate)
def check_ddp(self):
"""
Check if DDP is working properly.
Should be called by all process.
"""
if self.is_master:
print('\nPerforming DDP check...')
if self.is_master:
print('Checking if parameters are consistent across processes...')
dist.barrier()
try:
for p in self.master_params:
# split to avoid OOM
for i in range(0, p.numel(), 10000000):
sub_size = min(10000000, p.numel() - i)
sub_p = p.detach().view(-1)[i:i+sub_size]
# gather from all processes
sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
dist.all_gather(sub_p_gather, sub_p)
# check if equal
assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
except AssertionError as e:
if self.is_master:
print(f'\n\033[91mError: {e}\033[0m')
print('DDP check failed.')
raise e
dist.barrier()
if self.is_master:
print('Done.')
def run_step(self, data_list):
"""
Run a training step.
"""
step_log = {'loss': {}, 'status': {}}
amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext
elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext
# Train
losses = []
statuses = []
elastic_controller_logs = []
zero_grad(self.model_params)
for i, mb_data in enumerate(data_list):
## sync at the end of each batch split
sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
with nested_contexts(*sync_contexts), elastic_controller_context():
with amp_context():
loss, status = self.training_losses(**mb_data)
l = loss['loss'] / len(data_list)
## backward
if self.fp16_mode == 'amp':
self.scaler.scale(l).backward()
elif self.fp16_mode == 'inflat_all':
scaled_l = l * (2 ** self.log_scale)
scaled_l.backward()
else:
l.backward()
## log
losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
if self.elastic_controller_config is not None:
elastic_controller_logs.append(self.elastic_controller.log())
## gradient clip
if self.grad_clip is not None:
if self.fp16_mode == 'amp':
self.scaler.unscale_(self.optimizer)
elif self.fp16_mode == 'inflat_all':
model_grads_to_master_grads(self.model_params, self.master_params)
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
if isinstance(self.grad_clip, float):
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
else:
grad_norm = self.grad_clip(self.master_params)
if torch.isfinite(grad_norm):
statuses[-1]['grad_norm'] = grad_norm.item()
## step
if self.fp16_mode == 'amp':
prev_scale = self.scaler.get_scale()
self.scaler.step(self.optimizer)
self.scaler.update()
elif self.fp16_mode == 'inflat_all':
prev_scale = 2 ** self.log_scale
if not any(not p.grad.isfinite().all() for p in self.model_params):
if self.grad_clip is None:
model_grads_to_master_grads(self.model_params, self.master_params)
self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
self.optimizer.step()
master_params_to_model_params(self.model_params, self.master_params)
self.log_scale += self.fp16_scale_growth
else:
self.log_scale -= 1
else:
prev_scale = 1.0
if not any(not p.grad.isfinite().all() for p in self.model_params):
self.optimizer.step()
else:
print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
## adjust learning rate
if self.lr_scheduler_config is not None:
statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
self.lr_scheduler.step()
# Logs
step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
if self.elastic_controller_config is not None:
step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
if self.grad_clip is not None:
step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()
# Check grad and norm of each param
if self.log_param_stats:
param_norms = {}
param_grads = {}
for name, param in self.backbone.named_parameters():
if param.requires_grad:
param_norms[name] = param.norm().item()
if param.grad is not None and torch.isfinite(param.grad).all():
param_grads[name] = param.grad.norm().item() / prev_scale
step_log['param_norms'] = param_norms
step_log['param_grads'] = param_grads
# Update exponential moving average
if self.is_master:
self.update_ema()
return step_log

View File

@@ -0,0 +1,59 @@
import torch
import numpy as np
from ....utils.general_utils import dict_foreach
from ....pipelines import samplers
class ClassifierFreeGuidanceMixin:
def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
super().__init__(*args, **kwargs)
self.p_uncond = p_uncond
def get_cond(self, cond, neg_cond=None, **kwargs):
"""
Get the conditioning data.
"""
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
if self.p_uncond > 0:
# randomly drop the class label
def get_batch_size(cond):
if isinstance(cond, torch.Tensor):
return cond.shape[0]
elif isinstance(cond, list):
return len(cond)
else:
raise ValueError(f"Unsupported type of cond: {type(cond)}")
ref_cond = cond if not isinstance(cond, dict) else cond[list(cond.keys())[0]]
B = get_batch_size(ref_cond)
def select(cond, neg_cond, mask):
if isinstance(cond, torch.Tensor):
mask = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
return torch.where(mask, neg_cond, cond)
elif isinstance(cond, list):
return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
else:
raise ValueError(f"Unsupported type of cond: {type(cond)}")
mask = list(np.random.rand(B) < self.p_uncond)
if not isinstance(cond, dict):
cond = select(cond, neg_cond, mask)
else:
cond = dict_foreach([cond, neg_cond], lambda x: select(x[0], x[1], mask))
return cond
def get_inference_cond(self, cond, neg_cond=None, **kwargs):
"""
Get the conditioning data for inference.
"""
assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
return {'cond': cond, 'neg_cond': neg_cond, **kwargs}
def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
"""
Get the sampler for the diffusion process.
"""
return samplers.FlowEulerCfgSampler(self.sigma_min)

0
trellis/utils/__init__.py Executable file
View File