1
This commit is contained in:
6
trellis/__init__.py
Executable file
6
trellis/__init__.py
Executable file
@@ -0,0 +1,6 @@
|
||||
from . import models
|
||||
from . import modules
|
||||
from . import pipelines
|
||||
from . import renderers
|
||||
from . import representations
|
||||
from . import utils
|
||||
58
trellis/datasets/__init__.py
Normal file
58
trellis/datasets/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import importlib
|
||||
|
||||
# Maps each lazily-exported dataset class name to the submodule that defines it.
__attributes = {
    'SparseStructure': 'sparse_structure',

    'SparseFeat2Render': 'sparse_feat2render',
    'SLat2Render': 'structured_latent2render',
    'Slat2RenderGeo': 'structured_latent2render',

    'SparseStructureLatent': 'sparse_structure_latent',
    'TextConditionedSparseStructureLatent': 'sparse_structure_latent',
    'ImageConditionedSparseStructureLatent': 'sparse_structure_latent',

    'SLat': 'structured_latent',
    'TextConditionedSLat': 'structured_latent',
    'ImageConditionedSLat': 'structured_latent',
}

# Submodules re-exported as package attributes (none for datasets).
__submodules = []

__all__ = list(__attributes.keys()) + __submodules


def __getattr__(name):
    """Resolve exported names on first attribute access (PEP 562 lazy import)."""
    if name in globals():
        return globals()[name]
    if name in __attributes:
        # Import the defining submodule once and cache the attribute.
        source = importlib.import_module(f".{__attributes[name]}", __name__)
        globals()[name] = getattr(source, name)
    elif name in __submodules:
        globals()[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
|
||||
|
||||
|
||||
# For Pylance
# Static imports mirroring the lazy `__attributes` table above so type
# checkers / IDEs can resolve the names. The `__name__ == '__main__'` guard
# means they never execute on a normal package import, keeping runtime
# imports lazy.
if __name__ == '__main__':
    from .sparse_structure import SparseStructure

    from .sparse_feat2render import SparseFeat2Render
    from .structured_latent2render import (
        SLat2Render,
        Slat2RenderGeo,
    )

    from .sparse_structure_latent import (
        SparseStructureLatent,
        TextConditionedSparseStructureLatent,
        ImageConditionedSparseStructureLatent,
    )

    from .structured_latent import (
        SLat,
        TextConditionedSLat,
        ImageConditionedSLat,
    )
|
||||
|
||||
96
trellis/models/__init__.py
Normal file
96
trellis/models/__init__.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import importlib
|
||||
|
||||
# Maps each lazily-exported model class name to the submodule that defines it.
__attributes = {
    'SparseStructureEncoder': 'sparse_structure_vae',
    'SparseStructureDecoder': 'sparse_structure_vae',

    'SparseStructureFlowModel': 'sparse_structure_flow',

    'SLatEncoder': 'structured_latent_vae',
    'SLatGaussianDecoder': 'structured_latent_vae',
    'SLatRadianceFieldDecoder': 'structured_latent_vae',
    'SLatMeshDecoder': 'structured_latent_vae',
    'ElasticSLatEncoder': 'structured_latent_vae',
    'ElasticSLatGaussianDecoder': 'structured_latent_vae',
    'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
    'ElasticSLatMeshDecoder': 'structured_latent_vae',

    'SLatFlowModel': 'structured_latent_flow',
    'ElasticSLatFlowModel': 'structured_latent_flow',
}

# Submodules re-exported as package attributes (none for models).
__submodules = []

__all__ = list(__attributes.keys()) + __submodules


def __getattr__(name):
    """Resolve exported names on first attribute access (PEP 562 lazy import)."""
    if name in globals():
        return globals()[name]
    if name in __attributes:
        # Import the defining submodule once and cache the attribute.
        source = importlib.import_module(f".{__attributes[name]}", __name__)
        globals()[name] = getattr(source, name)
    elif name in __submodules:
        globals()[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
|
||||
|
||||
|
||||
def from_pretrained(path: str, **kwargs):
    """
    Load a model from a pretrained checkpoint.

    Args:
        path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
            NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
        **kwargs: Additional arguments for the model constructor.

    Returns:
        The instantiated model with its state dict loaded.

    Raises:
        ValueError: If a non-local path does not have the
            '<org>/<repo>/<model_name>' form required to resolve it on the
            Hugging Face Hub.
    """
    import os
    import json
    from safetensors.torch import load_file
    is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")

    if is_local:
        config_file = f"{path}.json"
        model_file = f"{path}.safetensors"
    else:
        from huggingface_hub import hf_hub_download
        path_parts = path.split('/')
        # Hub paths must be '<org>/<repo>/<file path inside repo>'; anything
        # shorter would previously raise IndexError or request a bogus file.
        if len(path_parts) < 3:
            raise ValueError(f"Invalid pretrained path: {path}")
        repo_id = f'{path_parts[0]}/{path_parts[1]}'
        model_name = '/'.join(path_parts[2:])
        config_file = hf_hub_download(repo_id, f"{model_name}.json")
        model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")

    # JSON configs are UTF-8; don't depend on the platform default encoding.
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    # config['name'] selects the model class through this module's lazy loader.
    model = __getattr__(config['name'])(**config['args'], **kwargs)
    model.load_state_dict(load_file(model_file))

    return model
|
||||
|
||||
|
||||
# For Pylance
# Static imports mirroring the lazy `__attributes` table above so type
# checkers / IDEs can resolve the names; never executed on a normal import.
if __name__ == '__main__':
    from .sparse_structure_vae import (
        SparseStructureEncoder,
        SparseStructureDecoder,
    )

    from .sparse_structure_flow import SparseStructureFlowModel

    from .structured_latent_vae import (
        SLatEncoder,
        SLatGaussianDecoder,
        SLatRadianceFieldDecoder,
        SLatMeshDecoder,
        ElasticSLatEncoder,
        ElasticSLatGaussianDecoder,
        ElasticSLatRadianceFieldDecoder,
        ElasticSLatMeshDecoder,
    )

    from .structured_latent_flow import (
        SLatFlowModel,
        ElasticSLatFlowModel,
    )
|
||||
4
trellis/models/structured_latent_vae/__init__.py
Normal file
4
trellis/models/structured_latent_vae/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .encoder import SLatEncoder, ElasticSLatEncoder
|
||||
from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
|
||||
from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
|
||||
from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder
|
||||
117
trellis/models/structured_latent_vae/base.py
Normal file
117
trellis/models/structured_latent_vae/base.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
|
||||
from ...modules import sparse as sp
|
||||
from ...modules.transformer import AbsolutePositionEmbedder
|
||||
from ...modules.sparse.transformer import SparseTransformerBlock
|
||||
|
||||
|
||||
def block_attn_config(self):
    """
    Return the attention configuration of the model.

    Yields one (attn_mode, window_size, shift_sequence, shift_window,
    serialize_mode) tuple per transformer block, implementing the per-block
    alternation schemes used by the sparse transformer.
    """
    for i in range(self.num_blocks):
        if self.attn_mode == "shift_window":
            # Alternate the 3D window shift (0 vs 16 voxels) every other block.
            yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_sequence":
            # Alternate a half-window sequence shift every other block.
            yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
        elif self.attn_mode == "shift_order":
            # Cycle through four serialization orders.
            # NOTE(review): `sp.SerializeModes` (plural) is not among the names
            # exported by trellis.modules.sparse in this view — confirm it exists.
            yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
        elif self.attn_mode == "full":
            yield "full", None, None, None, None
        elif self.attn_mode == "swin":
            # Alternate a half-window spatial shift every other block.
            yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
        else:
            # Previously an unknown mode silently yielded nothing, which would
            # construct a model with zero attention blocks; fail loudly instead.
            raise ValueError(f"Unknown attention mode: {self.attn_mode}")
|
||||
|
||||
|
||||
class SparseTransformerBase(nn.Module):
    """
    Sparse Transformer without output layers.
    Serve as the base class for encoder and decoder.

    Args:
        in_channels: Channels of the input sparse features.
        model_channels: Width of the transformer torso.
        num_blocks: Number of transformer blocks.
        num_heads: Attention head count; derived from num_head_channels if None.
        num_head_channels: Channels per head when num_heads is None.
        mlp_ratio: Hidden/width ratio of the feed-forward nets.
        attn_mode: Per-block attention scheme (see block_attn_config).
        window_size: Window length for the windowed/serialized modes.
        pe_mode: "ape" adds absolute position embeddings at the input;
            "rope" applies rotary embeddings inside the attention blocks.
        use_fp16: Run the torso in float16.
        use_checkpoint: Enable gradient checkpointing in the blocks.
        qk_rms_norm: Apply RMS norm to queries/keys in attention.
    """
    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_blocks = num_blocks
        self.window_size = window_size
        # Derive the head count when not given explicitly.
        # NOTE(review): if both num_heads and num_head_channels are None this
        # raises a TypeError — presumably callers always provide one of them.
        self.num_heads = num_heads or model_channels // num_head_channels
        self.mlp_ratio = mlp_ratio
        self.attn_mode = attn_mode
        self.pe_mode = pe_mode
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint
        self.qk_rms_norm = qk_rms_norm
        # Torso compute dtype; convert_to_fp16/fp32 change the parameters but
        # not this attribute, which forward() uses to cast activations.
        self.dtype = torch.float16 if use_fp16 else torch.float32

        # Absolute position embedding is only needed for "ape"; for "rope"
        # the rotary embedding is applied inside the blocks (use_rope below).
        if pe_mode == "ape":
            self.pos_embedder = AbsolutePositionEmbedder(model_channels)

        self.input_layer = sp.SparseLinear(in_channels, model_channels)
        # Per-block attention settings come from block_attn_config(self); the
        # tuple unpacked in the comprehension's `for` clause supplies the
        # attn_mode/window/shift/serialize arguments used above it.
        self.blocks = nn.ModuleList([
            SparseTransformerBlock(
                model_channels,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                attn_mode=attn_mode,
                window_size=window_size,
                shift_sequence=shift_sequence,
                shift_window=shift_window,
                serialize_mode=serialize_mode,
                use_checkpoint=self.use_checkpoint,
                use_rope=(pe_mode == "rope"),
                qk_rms_norm=self.qk_rms_norm,
            )
            for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
        ])

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        self.blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        self.blocks.apply(convert_module_to_f32)

    def initialize_weights(self) -> None:
        """Xavier-initialize all nn.Linear weights and zero their biases.

        NOTE(review): not invoked from __init__ here — presumably called by
        subclasses after they add their output layers; confirm against them.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """Project input features, add position encoding (ape), run the blocks."""
        h = self.input_layer(x)
        if self.pe_mode == "ape":
            # coords[:, 0] is the batch index; positions use the remaining columns.
            h = h + self.pos_embedder(x.coords[:, 1:])
        h = h.type(self.dtype)
        for block in self.blocks:
            h = block(h)
        return h
|
||||
36
trellis/modules/attention/__init__.py
Executable file
36
trellis/modules/attention/__init__.py
Executable file
@@ -0,0 +1,36 @@
|
||||
from typing import *
|
||||
|
||||
# Attention backend selection for the dense attention ops in this package.
# BACKEND picks the kernel implementation; DEBUG enables a checked path
# (consumers live in full_attn / modules — not visible here).
BACKEND = 'flash_attn'
DEBUG = False

def __from_env():
    """Override BACKEND / DEBUG from the ATTN_BACKEND / ATTN_DEBUG env vars."""
    import os

    global BACKEND
    global DEBUG

    env_attn_backend = os.environ.get('ATTN_BACKEND')
    env_sttn_debug = os.environ.get('ATTN_DEBUG')

    # Only accept known backend names; anything else keeps the default.
    if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
        BACKEND = env_attn_backend
    if env_sttn_debug is not None:
        DEBUG = env_sttn_debug == '1'

    print(f"[ATTENTION] Using backend: {BACKEND}")


# Apply environment overrides once at import time.
__from_env()


def set_backend(backend: Literal['xformers', 'flash_attn']):
    # NOTE(review): __from_env also accepts 'sdpa' and 'naive'; this Literal
    # looks out of sync with that list — confirm the intended public set.
    global BACKEND
    BACKEND = backend

def set_debug(debug: bool):
    global DEBUG
    DEBUG = debug


# Imported after the globals above so the submodules can read BACKEND/DEBUG
# at their own import time.
from .full_attn import *
from .modules import *
|
||||
102
trellis/modules/sparse/__init__.py
Executable file
102
trellis/modules/sparse/__init__.py
Executable file
@@ -0,0 +1,102 @@
|
||||
from typing import *
|
||||
|
||||
# Sparse-tensor backend selection: BACKEND picks the sparse-conv library,
# DEBUG enables invariant checks in SparseTensor, ATTN picks the attention
# kernel used by the sparse attention ops.
BACKEND = 'spconv'
DEBUG = False
ATTN = 'flash_attn'

def __from_env():
    """Override BACKEND / DEBUG / ATTN from environment variables."""
    import os

    global BACKEND
    global DEBUG
    global ATTN

    env_sparse_backend = os.environ.get('SPARSE_BACKEND')
    env_sparse_debug = os.environ.get('SPARSE_DEBUG')
    env_sparse_attn = os.environ.get('SPARSE_ATTN_BACKEND')
    # Fall back to the shared ATTN_BACKEND variable when no sparse-specific
    # override is given.
    if env_sparse_attn is None:
        env_sparse_attn = os.environ.get('ATTN_BACKEND')

    # Only accept known values; anything else keeps the defaults.
    if env_sparse_backend is not None and env_sparse_backend in ['spconv', 'torchsparse']:
        BACKEND = env_sparse_backend
    if env_sparse_debug is not None:
        DEBUG = env_sparse_debug == '1'
    if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
        ATTN = env_sparse_attn

    print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")


# Apply environment overrides once at import time.
__from_env()


def set_backend(backend: Literal['spconv', 'torchsparse']):
    global BACKEND
    BACKEND = backend

def set_debug(debug: bool):
    global DEBUG
    DEBUG = debug

def set_attn(attn: Literal['xformers', 'flash_attn']):
    global ATTN
    ATTN = attn
|
||||
|
||||
|
||||
import importlib

# Maps each lazily-exported name to the submodule that defines it.
__attributes = {
    'SparseTensor': 'basic',
    'sparse_batch_broadcast': 'basic',
    'sparse_batch_op': 'basic',
    'sparse_cat': 'basic',
    'sparse_unbind': 'basic',
    'SparseGroupNorm': 'norm',
    'SparseLayerNorm': 'norm',
    'SparseGroupNorm32': 'norm',
    'SparseLayerNorm32': 'norm',
    'SparseReLU': 'nonlinearity',
    'SparseSiLU': 'nonlinearity',
    'SparseGELU': 'nonlinearity',
    'SparseActivation': 'nonlinearity',
    'SparseLinear': 'linear',
    'sparse_scaled_dot_product_attention': 'attention',
    'SerializeMode': 'attention',
    'sparse_serialized_scaled_dot_product_self_attention': 'attention',
    'sparse_windowed_scaled_dot_product_self_attention': 'attention',
    'SparseMultiHeadAttention': 'attention',
    'SparseConv3d': 'conv',
    'SparseInverseConv3d': 'conv',
    'SparseDownsample': 'spatial',
    'SparseUpsample': 'spatial',
    'SparseSubdivide': 'spatial',
}

# Whole submodules re-exported as attributes of this package.
__submodules = ['transformer']

__all__ = list(__attributes.keys()) + __submodules


def __getattr__(name):
    """Resolve exported names on first attribute access (PEP 562 lazy import)."""
    if name in globals():
        return globals()[name]
    if name in __attributes:
        # Import the defining submodule once and cache the attribute.
        source = importlib.import_module(f".{__attributes[name]}", __name__)
        globals()[name] = getattr(source, name)
    elif name in __submodules:
        globals()[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]
|
||||
|
||||
|
||||
# For Pylance
# Static mirror of the lazy exports above so type checkers can resolve them;
# the guard keeps these imports from running on a normal package import.
if __name__ == '__main__':
    from .basic import *
    from .norm import *
    from .nonlinearity import *
    from .linear import *
    from .attention import *
    from .conv import *
    from .spatial import *
    # Was `import transformer`, which would raise ModuleNotFoundError — the
    # submodule is package-relative, matching __submodules = ['transformer'].
    from . import transformer
|
||||
4
trellis/modules/sparse/attention/__init__.py
Executable file
4
trellis/modules/sparse/attention/__init__.py
Executable file
@@ -0,0 +1,4 @@
|
||||
from .full_attn import *
|
||||
from .serialized_attn import *
|
||||
from .windowed_attn import *
|
||||
from .modules import *
|
||||
459
trellis/modules/sparse/basic.py
Executable file
459
trellis/modules/sparse/basic.py
Executable file
@@ -0,0 +1,459 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from . import BACKEND, DEBUG
|
||||
SparseTensorData = None # Lazy import
|
||||
|
||||
|
||||
__all__ = [
|
||||
'SparseTensor',
|
||||
'sparse_batch_broadcast',
|
||||
'sparse_batch_op',
|
||||
'sparse_cat',
|
||||
'sparse_unbind',
|
||||
]
|
||||
|
||||
|
||||
class SparseTensor:
    """
    Sparse tensor with support for both torchsparse and spconv backends.

    Parameters:
    - feats (torch.Tensor): Features of the sparse tensor.
    - coords (torch.Tensor): Coordinates of the sparse tensor.
    - shape (torch.Size): Shape of the sparse tensor.
    - layout (List[slice]): Layout of the sparse tensor for each batch
    - data (SparseTensorData): Sparse tensor data used for convolution

    NOTE:
    - Data corresponding to a same batch should be contiguous.
    - Coords should be in [0, 1023]
    """
    @overload
    def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...

    @overload
    def __init__(self, data, shape: Optional[torch.Size] = None, layout: Optional[List[slice]] = None, **kwargs): ...

    def __init__(self, *args, **kwargs):
        # Lazy import of sparse tensor backend
        global SparseTensorData
        if SparseTensorData is None:
            import importlib
            if BACKEND == 'torchsparse':
                SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif BACKEND == 'spconv':
                SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        # Overload dispatch: method 0 = (feats, coords, ...), method 1 = (data, ...).
        method_id = 0
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            # Pad positional args to (feats, coords, shape, layout), then let
            # keyword args override; remaining kwargs reach the backend ctor.
            feats, coords, shape, layout = args + (None,) * (4 - len(args))
            if 'feats' in kwargs:
                feats = kwargs['feats']
                del kwargs['feats']
            if 'coords' in kwargs:
                coords = kwargs['coords']
                del kwargs['coords']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            if shape is None:
                shape = self.__cal_shape(feats, coords)
            if layout is None:
                layout = self.__cal_layout(coords, shape[0])
            if BACKEND == 'torchsparse':
                self.data = SparseTensorData(feats, coords, **kwargs)
            elif BACKEND == 'spconv':
                # spconv needs (N, C) features plus an explicit spatial shape
                # derived from the max coordinate per axis.
                spatial_shape = list(coords.max(0)[0] + 1)[1:]
                self.data = SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape, shape[0], **kwargs)
                # Keep the original (possibly >2D) feature view on the backend tensor.
                self.data._features = feats
        elif method_id == 1:
            data, shape, layout = args + (None,) * (3 - len(args))
            if 'data' in kwargs:
                data = kwargs['data']
                del kwargs['data']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']
            if 'layout' in kwargs:
                layout = kwargs['layout']
                del kwargs['layout']

            self.data = data
            if shape is None:
                shape = self.__cal_shape(self.feats, self.coords)
            if layout is None:
                layout = self.__cal_layout(self.coords, shape[0])

        self._shape = shape
        self._layout = layout
        # Spatial scale and per-scale cache, propagated through replace().
        self._scale = kwargs.get('scale', (1, 1, 1))
        self._spatial_cache = kwargs.get('spatial_cache', {})

        # Expensive invariant checks, enabled via the module DEBUG flag.
        if DEBUG:
            try:
                assert self.feats.shape[0] == self.coords.shape[0], f"Invalid feats shape: {self.feats.shape}, coords shape: {self.coords.shape}"
                assert self.shape == self.__cal_shape(self.feats, self.coords), f"Invalid shape: {self.shape}"
                assert self.layout == self.__cal_layout(self.coords, self.shape[0]), f"Invalid layout: {self.layout}"
                for i in range(self.shape[0]):
                    assert torch.all(self.coords[self.layout[i], 0] == i), f"The data of batch {i} is not contiguous"
            except Exception as e:
                print('Debugging information:')
                print(f"- Shape: {self.shape}")
                print(f"- Layout: {self.layout}")
                print(f"- Scale: {self._scale}")
                print(f"- Coords: {self.coords}")
                raise e

    def __cal_shape(self, feats, coords):
        # Shape = (batch_size, *feature_dims); batch size from max batch index.
        shape = []
        shape.append(coords[:, 0].max().item() + 1)
        shape.extend([*feats.shape[1:]])
        return torch.Size(shape)

    def __cal_layout(self, coords, batch_size):
        # One contiguous row-slice per batch element (relies on the batch
        # column being sorted/contiguous, per the class NOTE above).
        seq_len = torch.bincount(coords[:, 0], minlength=batch_size)
        offset = torch.cumsum(seq_len, dim=0)
        layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)]
        return layout

    @property
    def shape(self) -> torch.Size:
        return self._shape

    def dim(self) -> int:
        return len(self.shape)

    @property
    def layout(self) -> List[slice]:
        return self._layout

    @property
    def feats(self) -> torch.Tensor:
        # Backend-specific feature storage.
        if BACKEND == 'torchsparse':
            return self.data.F
        elif BACKEND == 'spconv':
            return self.data.features

    @feats.setter
    def feats(self, value: torch.Tensor):
        if BACKEND == 'torchsparse':
            self.data.F = value
        elif BACKEND == 'spconv':
            self.data.features = value

    @property
    def coords(self) -> torch.Tensor:
        # Integer rows of (batch_idx, spatial coords...).
        if BACKEND == 'torchsparse':
            return self.data.C
        elif BACKEND == 'spconv':
            return self.data.indices

    @coords.setter
    def coords(self, value: torch.Tensor):
        if BACKEND == 'torchsparse':
            self.data.C = value
        elif BACKEND == 'spconv':
            self.data.indices = value

    @property
    def dtype(self):
        return self.feats.dtype

    @property
    def device(self):
        return self.feats.device

    @overload
    def to(self, dtype: torch.dtype) -> 'SparseTensor': ...

    @overload
    def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None) -> 'SparseTensor': ...

    def to(self, *args, **kwargs) -> 'SparseTensor':
        # Mirrors torch.Tensor.to for (device, dtype); coords only change device.
        device = None
        dtype = None
        if len(args) == 2:
            device, dtype = args
        elif len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype = args[0]
            else:
                device = args[0]
        if 'dtype' in kwargs:
            assert dtype is None, "to() received multiple values for argument 'dtype'"
            dtype = kwargs['dtype']
        if 'device' in kwargs:
            assert device is None, "to() received multiple values for argument 'device'"
            device = kwargs['device']

        new_feats = self.feats.to(device=device, dtype=dtype)
        new_coords = self.coords.to(device=device)
        return self.replace(new_feats, new_coords)

    def type(self, dtype):
        new_feats = self.feats.type(dtype)
        return self.replace(new_feats)

    def cpu(self) -> 'SparseTensor':
        new_feats = self.feats.cpu()
        new_coords = self.coords.cpu()
        return self.replace(new_feats, new_coords)

    def cuda(self) -> 'SparseTensor':
        new_feats = self.feats.cuda()
        new_coords = self.coords.cuda()
        return self.replace(new_feats, new_coords)

    def half(self) -> 'SparseTensor':
        new_feats = self.feats.half()
        return self.replace(new_feats)

    def float(self) -> 'SparseTensor':
        new_feats = self.feats.float()
        return self.replace(new_feats)

    def detach(self) -> 'SparseTensor':
        new_coords = self.coords.detach()
        new_feats = self.feats.detach()
        return self.replace(new_feats, new_coords)

    def dense(self) -> torch.Tensor:
        # NOTE(review): both branches call the same .dense(); kept as-is,
        # presumably for symmetry with the other backend-dispatch accessors.
        if BACKEND == 'torchsparse':
            return self.data.dense()
        elif BACKEND == 'spconv':
            return self.data.dense()

    def reshape(self, *shape) -> 'SparseTensor':
        # Reshape feature dims only; the row (nnz) dimension is preserved.
        new_feats = self.feats.reshape(self.feats.shape[0], *shape)
        return self.replace(new_feats)

    def unbind(self, dim: int) -> List['SparseTensor']:
        return sparse_unbind(self, dim)

    def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
        # Build a new SparseTensor sharing this one's structure with new feats
        # (and optionally coords), carrying over backend caches and the
        # spatial cache so downstream conv/upsample reuse stays valid.
        new_shape = [self.shape[0]]
        new_shape.extend(feats.shape[1:])
        if BACKEND == 'torchsparse':
            new_data = SparseTensorData(
                feats=feats,
                coords=self.data.coords if coords is None else coords,
                stride=self.data.stride,
                spatial_range=self.data.spatial_range,
            )
            new_data._caches = self.data._caches
        elif BACKEND == 'spconv':
            new_data = SparseTensorData(
                self.data.features.reshape(self.data.features.shape[0], -1),
                self.data.indices,
                self.data.spatial_shape,
                self.data.batch_size,
                self.data.grid,
                self.data.voxel_num,
                self.data.indice_dict
            )
            new_data._features = feats
            # Preserve spconv bookkeeping (benchmarking, allocators, cached
            # convolution index pairs).
            new_data.benchmark = self.data.benchmark
            new_data.benchmark_record = self.data.benchmark_record
            new_data.thrust_allocator = self.data.thrust_allocator
            new_data._timer = self.data._timer
            new_data.force_algo = self.data.force_algo
            new_data.int8_scale = self.data.int8_scale
            if coords is not None:
                new_data.indices = coords
        new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
        return new_tensor

    @staticmethod
    def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor':
        # Create a SparseTensor that densely covers the inclusive axis-aligned
        # box `aabb` = (x0, y0, z0, x1, y1, z1), with `dim` = (N, C) batch size
        # and channel count, every feature filled with `value`.
        # NOTE(review): the `dim` parameter shadows the `dim()` method name.
        N, C = dim
        x = torch.arange(aabb[0], aabb[3] + 1)
        y = torch.arange(aabb[1], aabb[4] + 1)
        z = torch.arange(aabb[2], aabb[5] + 1)
        coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3)
        coords = torch.cat([
            torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1),
            coords.repeat(N, 1),
        ], dim=1).to(dtype=torch.int32, device=device)
        feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device)
        return SparseTensor(feats=feats, coords=coords)

    def __merge_sparse_cache(self, other: 'SparseTensor') -> dict:
        # Union of both operands' spatial caches; on key collision, this
        # tensor's entry wins and is updated with the other's sub-entries.
        new_cache = {}
        for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())):
            if k in self._spatial_cache:
                new_cache[k] = self._spatial_cache[k]
            if k in other._spatial_cache:
                if k not in new_cache:
                    new_cache[k] = other._spatial_cache[k]
                else:
                    new_cache[k].update(other._spatial_cache[k])
        return new_cache

    def __neg__(self) -> 'SparseTensor':
        return self.replace(-self.feats)

    def __elemwise__(self, other: Union[torch.Tensor, 'SparseTensor'], op: callable) -> 'SparseTensor':
        # Dense operands are broadcast per batch element when possible; a
        # failure here deliberately falls through to plain tensor broadcasting.
        if isinstance(other, torch.Tensor):
            try:
                other = torch.broadcast_to(other, self.shape)
                other = sparse_batch_broadcast(self, other)
            except:
                pass
        if isinstance(other, SparseTensor):
            other = other.feats
        new_feats = op(self.feats, other)
        new_tensor = self.replace(new_feats)
        if isinstance(other, SparseTensor):
            # NOTE(review): this branch looks unreachable — `other` was
            # replaced by its .feats above, so it is never a SparseTensor
            # here, and caches are therefore never merged. Confirm intent.
            new_tensor._spatial_cache = self.__merge_sparse_cache(other)
        return new_tensor

    def __add__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __radd__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.add)

    def __sub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.sub)

    def __rsub__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.sub(y, x))

    def __mul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __rmul__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.mul)

    def __truediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, torch.div)

    def __rtruediv__(self, other: Union[torch.Tensor, 'SparseTensor', float]) -> 'SparseTensor':
        return self.__elemwise__(other, lambda x, y: torch.div(y, x))

    def __getitem__(self, idx):
        # Index along the batch dimension; the result's batch indices are
        # renumbered from zero.
        if isinstance(idx, int):
            idx = [idx]
        elif isinstance(idx, slice):
            idx = range(*idx.indices(self.shape[0]))
        elif isinstance(idx, torch.Tensor):
            if idx.dtype == torch.bool:
                assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}"
                idx = idx.nonzero().squeeze(1)
            elif idx.dtype in [torch.int32, torch.int64]:
                assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}"
            else:
                raise ValueError(f"Unknown index type: {idx.dtype}")
        else:
            raise ValueError(f"Unknown index type: {type(idx)}")

        coords = []
        feats = []
        for new_idx, old_idx in enumerate(idx):
            coords.append(self.coords[self.layout[old_idx]].clone())
            coords[-1][:, 0] = new_idx
            feats.append(self.feats[self.layout[old_idx]])
        coords = torch.cat(coords, dim=0).contiguous()
        feats = torch.cat(feats, dim=0).contiguous()
        return SparseTensor(feats=feats, coords=coords)

    def register_spatial_cache(self, key, value) -> None:
        """
        Register a spatial cache.
        The spatial cache can be any thing you want to cache.
        The registration and retrieval of the cache is based on current scale.
        """
        scale_key = str(self._scale)
        if scale_key not in self._spatial_cache:
            self._spatial_cache[scale_key] = {}
        self._spatial_cache[scale_key][key] = value

    def get_spatial_cache(self, key=None):
        """
        Get a spatial cache.
        """
        scale_key = str(self._scale)
        cur_scale_cache = self._spatial_cache.get(scale_key, {})
        if key is None:
            return cur_scale_cache
        return cur_scale_cache.get(key, None)
|
||||
|
||||
|
||||
def sparse_batch_broadcast(input: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """
    Broadcast a per-batch tensor to the rows of a sparse tensor.

    Args:
        input (SparseTensor): Sparse tensor whose layout defines which rows
            belong to which batch element.
        other (torch.Tensor): Tensor with one entry per batch element.

    Returns:
        torch.Tensor: Tensor shaped like ``input.feats`` where every row of
        batch element ``k`` holds ``other[k]``.
    """
    # (Original docstring was copy-pasted from sparse_batch_op and described a
    # non-existent `op` parameter; the unused `coords` unpack is also dropped.)
    broadcasted = torch.zeros_like(input.feats)
    for k in range(input.shape[0]):
        # layout[k] is the contiguous row-slice of batch element k.
        broadcasted[input.layout[k]] = other[k]
    return broadcasted
|
||||
|
||||
|
||||
def sparse_batch_op(input: SparseTensor, other: torch.Tensor, op: callable = torch.add) -> SparseTensor:
    """
    Broadcast a per-batch tensor across a sparse tensor's rows, then apply an
    elementwise operation.

    Args:
        input (SparseTensor): Sparse tensor providing features and batch layout.
        other (torch.Tensor): 1D tensor with one entry per batch element.
        op (callable): Operation applied as ``op(input.feats, broadcasted)``.
            Defaults to torch.add.

    Returns:
        SparseTensor: New sparse tensor with the operation's result as features.
    """
    return input.replace(op(input.feats, sparse_batch_broadcast(input, other)))
|
||||
|
||||
|
||||
def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor:
    """
    Concatenate a list of sparse tensors.

    Args:
        inputs (List[SparseTensor]): List of sparse tensors to concatenate.
        dim (int): Dimension to concatenate along; 0 concatenates along the
            batch dimension, any other value concatenates feature dimensions.
    """
    if dim != 0:
        # Feature-dim concat: the rows line up, only the features widen.
        stacked = torch.cat([item.feats for item in inputs], dim=dim)
        return inputs[0].replace(stacked)

    # Batch concat: shift each tensor's batch indices past the previous ones.
    shifted = []
    base = 0
    for item in inputs:
        batch_coords = item.coords.clone()
        batch_coords[:, 0] += base
        shifted.append(batch_coords)
        base += item.shape[0]
    return SparseTensor(
        coords=torch.cat(shifted, dim=0),
        feats=torch.cat([item.feats for item in inputs], dim=0),
    )
|
||||
|
||||
|
||||
def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]:
    """
    Unbind a sparse tensor along a dimension.

    Args:
        input (SparseTensor): Sparse tensor to unbind.
        dim (int): Dimension to unbind; 0 is the batch dimension.
    """
    if dim == 0:
        # Batch dimension: indexing a SparseTensor yields one sparse tensor
        # per batch element.
        return [input[b] for b in range(input.shape[0])]
    # Other dimensions: unbind the dense feats and re-wrap each piece.
    return [input.replace(piece) for piece in input.feats.unbind(dim)]
|
||||
21
trellis/modules/sparse/conv/__init__.py
Executable file
21
trellis/modules/sparse/conv/__init__.py
Executable file
@@ -0,0 +1,21 @@
|
||||
from .. import BACKEND
|
||||
|
||||
|
||||
SPCONV_ALGO = 'auto' # 'auto', 'implicit_gemm', 'native'

def __from_env():
    """Override SPCONV_ALGO from the SPCONV_ALGO environment variable, if valid."""
    import os

    global SPCONV_ALGO
    env_spconv_algo = os.environ.get('SPCONV_ALGO')
    # Membership test alone suffices: None is never in the tuple, so the
    # previous `is not None and` check was redundant. Unknown values keep
    # the 'auto' default.
    if env_spconv_algo in ('auto', 'implicit_gemm', 'native'):
        SPCONV_ALGO = env_spconv_algo
    print(f"[SPARSE][CONV] spconv algo: {SPCONV_ALGO}")


__from_env()

# Select the conv implementation matching the active sparse backend.
if BACKEND == 'torchsparse':
    from .conv_torchsparse import *
elif BACKEND == 'spconv':
    from .conv_spconv import *
|
||||
2
trellis/modules/sparse/transformer/__init__.py
Normal file
2
trellis/modules/sparse/transformer/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .blocks import *
|
||||
from .modulated import *
|
||||
151
trellis/modules/sparse/transformer/blocks.py
Normal file
151
trellis/modules/sparse/transformer/blocks.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..basic import SparseTensor
|
||||
from ..linear import SparseLinear
|
||||
from ..nonlinearity import SparseGELU
|
||||
from ..attention import SparseMultiHeadAttention, SerializeMode
|
||||
from ...norm import LayerNorm32
|
||||
|
||||
|
||||
class SparseFeedForwardNet(nn.Module):
    """
    Two-layer MLP (Linear -> tanh-approximated GELU -> Linear) over sparse features.
    """
    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            SparseLinear(channels, hidden),
            SparseGELU(approximate="tanh"),
            SparseLinear(hidden, channels),
        )

    def forward(self, x: SparseTensor) -> SparseTensor:
        return self.mlp(x)
|
||||
|
||||
|
||||
class SparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN) with a pre-norm residual layout.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        # Pre-attention and pre-MLP layer norms (affine parameters optional).
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(channels, mlp_ratio=mlp_ratio)

    def _forward(self, x: SparseTensor) -> SparseTensor:
        # Pre-norm residual: x <- x + Attn(LN(x)); x <- x + MLP(LN(x)).
        x = x + self.attn(x.replace(self.norm1(x.feats)))
        x = x + self.mlp(x.replace(self.norm2(x.feats)))
        return x

    def forward(self, x: SparseTensor) -> SparseTensor:
        if not self.use_checkpoint:
            return self._forward(x)
        # Gradient checkpointing trades compute for activation memory.
        return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
|
||||
|
||||
|
||||
class SparseTransformerCrossBlock(nn.Module):
    """
    Sparse Transformer cross-attention block (MSA + MCA + FFN).
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_sequence: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        serialize_mode: Optional[SerializeMode] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_sequence=shift_sequence,
            shift_window=shift_window,
            serialize_mode=serialize_mode,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Cross-attention always runs in "full" mode over the dense context.
        self.cross_attn = SparseMultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )

    def _forward(self, x: SparseTensor, context: torch.Tensor):
        # BUGFIX: removed the stray `mod` parameter. forward() invokes
        # self._forward(x, context), so with the old (x, mod, context)
        # signature `context` landed in `mod` and the real `context`
        # parameter went unfilled, raising TypeError on every call.
        h = x.replace(self.norm1(x.feats))
        h = self.self_attn(h)
        x = x + h
        h = x.replace(self.norm2(x.feats))
        h = self.cross_attn(h, context)
        x = x + h
        h = x.replace(self.norm3(x.feats))
        h = self.mlp(h)
        x = x + h
        return x

    def forward(self, x: SparseTensor, context: torch.Tensor):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
        else:
            return self._forward(x, context)
|
||||
2
trellis/modules/transformer/__init__.py
Normal file
2
trellis/modules/transformer/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .blocks import *
|
||||
from .modulated import *
|
||||
182
trellis/modules/transformer/blocks.py
Normal file
182
trellis/modules/transformer/blocks.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..attention import MultiHeadAttention
|
||||
from ..norm import LayerNorm32
|
||||
|
||||
|
||||
class AbsolutePositionEmbedder(nn.Module):
    """
    Embeds spatial positions into vector representations using fixed
    sinusoidal frequencies (one sin/cos pair per frequency per input channel).
    """
    def __init__(self, channels: int, in_channels: int = 3):
        super().__init__()
        self.channels = channels
        self.in_channels = in_channels
        # Number of frequencies per input channel; each yields a sin and a cos.
        self.freq_dim = channels // in_channels // 2
        self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (10000 ** self.freqs)

    def _sin_cos_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create sinusoidal position embeddings.

        Args:
            x: a 1-D Tensor of N indices

        Returns:
            an (N, 2 * freq_dim) Tensor of positional embeddings.
        """
        # Lazily move the frequency table to the input's device.
        self.freqs = self.freqs.to(x.device)
        out = torch.outer(x, self.freqs)
        out = torch.cat([torch.sin(out), torch.cos(out)], dim=-1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): (N, D) tensor of spatial positions

        Returns:
            torch.Tensor: (N, self.channels) embeddings, zero-padded on the
            right when channels is not an exact multiple of 2 * in_channels.
        """
        N, D = x.shape
        assert D == self.in_channels, "Input dimension must match number of input channels"
        embed = self._sin_cos_embedding(x.reshape(-1))
        embed = embed.reshape(N, -1)
        if embed.shape[1] < self.channels:
            # BUGFIX: pad with zeros of the SAME dtype as the embeddings.
            # The original used the default dtype (float32), which made
            # torch.cat fail for non-float32 inputs (e.g. float64).
            padding = torch.zeros(N, self.channels - embed.shape[1],
                                  device=embed.device, dtype=embed.dtype)
            embed = torch.cat([embed, padding], dim=-1)
        return embed
|
||||
|
||||
|
||||
class FeedForwardNet(nn.Module):
    """
    Two-layer MLP (Linear -> tanh-approximated GELU -> Linear).
    """
    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden, channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
    """
    Transformer block (MSA + FFN) with a pre-norm residual layout.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[int] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        # Pre-attention and pre-MLP layer norms (affine parameters optional).
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual: x <- x + Attn(LN(x)); x <- x + MLP(LN(x)).
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.use_checkpoint:
            return self._forward(x)
        # Gradient checkpointing trades compute for activation memory.
        return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
|
||||
|
||||
|
||||
class TransformerCrossBlock(nn.Module):
    """
    Transformer cross-attention block (MSA + MCA + FFN), pre-norm residual layout.
    """
    def __init__(
        self,
        channels: int,
        ctx_channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        qk_rms_norm: bool = False,
        qk_rms_norm_cross: bool = False,
        qkv_bias: bool = True,
        ln_affine: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm1 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.norm3 = LayerNorm32(channels, elementwise_affine=ln_affine, eps=1e-6)
        self.self_attn = MultiHeadAttention(
            channels,
            num_heads=num_heads,
            type="self",
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            qk_rms_norm=qk_rms_norm,
        )
        # Cross-attention always runs in "full" mode over the dense context.
        self.cross_attn = MultiHeadAttention(
            channels,
            ctx_channels=ctx_channels,
            num_heads=num_heads,
            type="cross",
            attn_mode="full",
            qkv_bias=qkv_bias,
            qk_rms_norm=qk_rms_norm_cross,
        )
        self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)

    def _forward(self, x: torch.Tensor, context: torch.Tensor):
        # Pre-norm residuals: self-attention, then cross-attention, then MLP.
        x = x + self.self_attn(self.norm1(x))
        x = x + self.cross_attn(self.norm2(x), context)
        x = x + self.mlp(self.norm3(x))
        return x

    def forward(self, x: torch.Tensor, context: torch.Tensor):
        if not self.use_checkpoint:
            return self._forward(x, context)
        return torch.utils.checkpoint.checkpoint(self._forward, x, context, use_reentrant=False)
|
||||
|
||||
25
trellis/pipelines/__init__.py
Normal file
25
trellis/pipelines/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from . import samplers
|
||||
from .trellis_image_to_3d import TrellisImageTo3DPipeline
|
||||
from .trellis_text_to_3d import TrellisTextTo3DPipeline
|
||||
|
||||
|
||||
def from_pretrained(path: str):
    """
    Load a pipeline from a model folder or a Hugging Face model hub.

    Args:
        path: The path to the model. Can be either local path or a Hugging Face model name.
    """
    import os
    import json

    local_config = f"{path}/pipeline.json"
    if os.path.exists(local_config):
        config_file = local_config
    else:
        # Not a local folder: treat `path` as a HF repo id and fetch the config.
        from huggingface_hub import hf_hub_download
        config_file = hf_hub_download(path, "pipeline.json")

    with open(config_file, 'r') as f:
        config = json.load(f)
    # Dispatch to the concrete pipeline class named in the config.
    return globals()[config['name']].from_pretrained(path)
|
||||
68
trellis/pipelines/base.py
Normal file
68
trellis/pipelines/base.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .. import models
|
||||
|
||||
|
||||
class Pipeline:
    """
    A base class for pipelines: holds a dict of named models in eval mode and
    provides device management plus checkpoint loading.
    """
    def __init__(
        self,
        models: dict[str, nn.Module] = None,
    ):
        if models is None:
            # Subclasses may defer model assignment (e.g. during loading).
            return
        self.models = models
        for model in self.models.values():
            model.eval()

    @staticmethod
    def from_pretrained(path: str) -> "Pipeline":
        """
        Load a pretrained pipeline.

        Args:
            path: Local model folder or Hugging Face repo id containing pipeline.json.
        """
        import os
        import json
        is_local = os.path.exists(f"{path}/pipeline.json")

        if is_local:
            config_file = f"{path}/pipeline.json"
        else:
            from huggingface_hub import hf_hub_download
            config_file = hf_hub_download(path, "pipeline.json")

        with open(config_file, 'r') as f:
            args = json.load(f)['args']

        _models = {}
        for k, v in args['models'].items():
            # Prefer a checkpoint inside the pipeline folder; fall back to
            # treating the value as a standalone (e.g. hub) model path.
            # BUGFIX: narrowed the bare `except:` to `except Exception` so
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            try:
                _models[k] = models.from_pretrained(f"{path}/{v}")
            except Exception:
                _models[k] = models.from_pretrained(v)

        new_pipeline = Pipeline(_models)
        new_pipeline._pretrained_args = args
        return new_pipeline

    @property
    def device(self) -> torch.device:
        """Device of the first model exposing one (via `.device` or parameters)."""
        for model in self.models.values():
            if hasattr(model, 'device'):
                return model.device
        for model in self.models.values():
            if hasattr(model, 'parameters'):
                return next(model.parameters()).device
        raise RuntimeError("No device found.")

    def to(self, device: torch.device) -> None:
        """Move all models to `device` in place."""
        for model in self.models.values():
            model.to(device)

    def cuda(self) -> None:
        self.to(torch.device("cuda"))

    def cpu(self) -> None:
        self.to(torch.device("cpu"))
|
||||
2
trellis/pipelines/samplers/__init__.py
Executable file
2
trellis/pipelines/samplers/__init__.py
Executable file
@@ -0,0 +1,2 @@
|
||||
from .base import Sampler
|
||||
from .flow_euler import FlowEulerSampler, FlowEulerCfgSampler, FlowEulerGuidanceIntervalSampler
|
||||
20
trellis/pipelines/samplers/base.py
Normal file
20
trellis/pipelines/samplers/base.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import *
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Sampler(ABC):
    """
    A base class for samplers.
    """

    @abstractmethod
    def sample(self, model, **kwargs):
        """
        Draw samples from a model. Concrete samplers must override this.
        """
        pass
|
||||
|
||||
12
trellis/pipelines/samplers/classifier_free_guidance_mixin.py
Normal file
12
trellis/pipelines/samplers/classifier_free_guidance_mixin.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import *
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceSamplerMixin:
    """
    A mixin class for samplers that apply classifier-free guidance.
    """

    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
        # Run the model with the positive and negative conditions, then
        # extrapolate away from the negative prediction by cfg_strength.
        positive = super()._inference_model(model, x_t, t, cond, **kwargs)
        negative = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
        return (1 + cfg_strength) * positive - cfg_strength * negative
|
||||
31
trellis/renderers/__init__.py
Executable file
31
trellis/renderers/__init__.py
Executable file
@@ -0,0 +1,31 @@
|
||||
import importlib
|
||||
|
||||
# Map of public attribute name -> submodule that defines it (lazy import).
__attributes = {
    'OctreeRenderer': 'octree_renderer',
    'GaussianRenderer': 'gaussian_render',
    'MeshRenderer': 'mesh_renderer',
}

# Submodules exposed directly (none for this package).
__submodules = []

__all__ = list(__attributes.keys()) + __submodules

def __getattr__(name):
    # PEP 562 module-level __getattr__: import each renderer on first access
    # and cache it in globals() so later lookups bypass this hook.
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]


# For Pylance
# These eager imports never execute at runtime (a package __init__ is not
# run as __main__); they only give static analyzers the symbol definitions.
if __name__ == '__main__':
    from .octree_renderer import OctreeRenderer
    from .gaussian_render import GaussianRenderer
    from .mesh_renderer import MeshRenderer
|
||||
4
trellis/representations/__init__.py
Executable file
4
trellis/representations/__init__.py
Executable file
@@ -0,0 +1,4 @@
|
||||
from .radiance_field import Strivec
|
||||
from .octree import DfsOctree as Octree
|
||||
from .gaussian import Gaussian
|
||||
from .mesh import MeshExtractResult
|
||||
1
trellis/representations/gaussian/__init__.py
Executable file
1
trellis/representations/gaussian/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
from .gaussian_model import Gaussian
|
||||
1
trellis/representations/mesh/__init__.py
Normal file
1
trellis/representations/mesh/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .cube2mesh import SparseFeatures2Mesh, MeshExtractResult
|
||||
1
trellis/representations/octree/__init__.py
Executable file
1
trellis/representations/octree/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
from .octree_dfs import DfsOctree
|
||||
1
trellis/representations/radiance_field/__init__.py
Executable file
1
trellis/representations/radiance_field/__init__.py
Executable file
@@ -0,0 +1 @@
|
||||
from .strivec import Strivec
|
||||
63
trellis/trainers/__init__.py
Normal file
63
trellis/trainers/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import importlib
|
||||
|
||||
# Map of public trainer name -> dotted submodule that defines it (lazy import).
__attributes = {
    'BasicTrainer': 'basic',

    'SparseStructureVaeTrainer': 'vae.sparse_structure_vae',

    'SLatVaeGaussianTrainer': 'vae.structured_latent_vae_gaussian',
    'SLatVaeRadianceFieldDecoderTrainer': 'vae.structured_latent_vae_rf_dec',
    'SLatVaeMeshDecoderTrainer': 'vae.structured_latent_vae_mesh_dec',

    'FlowMatchingTrainer': 'flow_matching.flow_matching',
    'FlowMatchingCFGTrainer': 'flow_matching.flow_matching',
    'TextConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',
    'ImageConditionedFlowMatchingCFGTrainer': 'flow_matching.flow_matching',

    'SparseFlowMatchingTrainer': 'flow_matching.sparse_flow_matching',
    'SparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
    'TextConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
    'ImageConditionedSparseFlowMatchingCFGTrainer': 'flow_matching.sparse_flow_matching',
}

# Submodules exposed directly (none for this package).
__submodules = []

__all__ = list(__attributes.keys()) + __submodules

def __getattr__(name):
    # PEP 562 lazy loader: import the defining submodule on first access and
    # cache the attribute in globals() so later lookups bypass this hook.
    if name not in globals():
        if name in __attributes:
            module_name = __attributes[name]
            module = importlib.import_module(f".{module_name}", __name__)
            globals()[name] = getattr(module, name)
        elif name in __submodules:
            module = importlib.import_module(f".{name}", __name__)
            globals()[name] = module
        else:
            raise AttributeError(f"module {__name__} has no attribute {name}")
    return globals()[name]


# For Pylance
# Never executed at runtime (a package __init__ is not run as __main__);
# exists only so static analyzers see the symbol definitions.
if __name__ == '__main__':
    from .basic import BasicTrainer

    from .vae.sparse_structure_vae import SparseStructureVaeTrainer

    from .vae.structured_latent_vae_gaussian import SLatVaeGaussianTrainer
    from .vae.structured_latent_vae_rf_dec import SLatVaeRadianceFieldDecoderTrainer
    from .vae.structured_latent_vae_mesh_dec import SLatVaeMeshDecoderTrainer

    from .flow_matching.flow_matching import (
        FlowMatchingTrainer,
        FlowMatchingCFGTrainer,
        TextConditionedFlowMatchingCFGTrainer,
        ImageConditionedFlowMatchingCFGTrainer,
    )

    from .flow_matching.sparse_flow_matching import (
        SparseFlowMatchingTrainer,
        SparseFlowMatchingCFGTrainer,
        TextConditionedSparseFlowMatchingCFGTrainer,
        ImageConditionedSparseFlowMatchingCFGTrainer,
    )
|
||||
451
trellis/trainers/base.py
Normal file
451
trellis/trainers/base.py
Normal file
@@ -0,0 +1,451 @@
|
||||
from abc import abstractmethod
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
|
||||
from torchvision import utils
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .utils import *
|
||||
from ..utils.general_utils import *
|
||||
from ..utils.data_utils import recursive_to_device, cycle, ResumableSampler
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
Base class for training.
|
||||
"""
|
||||
def __init__(self,
|
||||
models,
|
||||
dataset,
|
||||
*,
|
||||
output_dir,
|
||||
load_dir,
|
||||
step,
|
||||
max_steps,
|
||||
batch_size=None,
|
||||
batch_size_per_gpu=None,
|
||||
batch_split=None,
|
||||
optimizer={},
|
||||
lr_scheduler=None,
|
||||
elastic=None,
|
||||
grad_clip=None,
|
||||
ema_rate=0.9999,
|
||||
fp16_mode='inflat_all',
|
||||
fp16_scale_growth=1e-3,
|
||||
finetune_ckpt=None,
|
||||
log_param_stats=False,
|
||||
prefetch_data=True,
|
||||
i_print=1000,
|
||||
i_log=500,
|
||||
i_sample=10000,
|
||||
i_save=10000,
|
||||
i_ddpcheck=10000,
|
||||
**kwargs
|
||||
):
|
||||
assert batch_size is not None or batch_size_per_gpu is not None, 'Either batch_size or batch_size_per_gpu must be specified.'
|
||||
|
||||
self.models = models
|
||||
self.dataset = dataset
|
||||
self.batch_split = batch_split if batch_split is not None else 1
|
||||
self.max_steps = max_steps
|
||||
self.optimizer_config = optimizer
|
||||
self.lr_scheduler_config = lr_scheduler
|
||||
self.elastic_controller_config = elastic
|
||||
self.grad_clip = grad_clip
|
||||
self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else ema_rate
|
||||
self.fp16_mode = fp16_mode
|
||||
self.fp16_scale_growth = fp16_scale_growth
|
||||
self.log_param_stats = log_param_stats
|
||||
self.prefetch_data = prefetch_data
|
||||
if self.prefetch_data:
|
||||
self._data_prefetched = None
|
||||
|
||||
self.output_dir = output_dir
|
||||
self.i_print = i_print
|
||||
self.i_log = i_log
|
||||
self.i_sample = i_sample
|
||||
self.i_save = i_save
|
||||
self.i_ddpcheck = i_ddpcheck
|
||||
|
||||
if dist.is_initialized():
|
||||
# Multi-GPU params
|
||||
self.world_size = dist.get_world_size()
|
||||
self.rank = dist.get_rank()
|
||||
self.local_rank = dist.get_rank() % torch.cuda.device_count()
|
||||
self.is_master = self.rank == 0
|
||||
else:
|
||||
# Single-GPU params
|
||||
self.world_size = 1
|
||||
self.rank = 0
|
||||
self.local_rank = 0
|
||||
self.is_master = True
|
||||
|
||||
self.batch_size = batch_size if batch_size_per_gpu is None else batch_size_per_gpu * self.world_size
|
||||
self.batch_size_per_gpu = batch_size_per_gpu if batch_size_per_gpu is not None else batch_size // self.world_size
|
||||
assert self.batch_size % self.world_size == 0, 'Batch size must be divisible by the number of GPUs.'
|
||||
assert self.batch_size_per_gpu % self.batch_split == 0, 'Batch size per GPU must be divisible by batch split.'
|
||||
|
||||
self.init_models_and_more(**kwargs)
|
||||
self.prepare_dataloader(**kwargs)
|
||||
|
||||
# Load checkpoint
|
||||
self.step = 0
|
||||
if load_dir is not None and step is not None:
|
||||
self.load(load_dir, step)
|
||||
elif finetune_ckpt is not None:
|
||||
self.finetune_from(finetune_ckpt)
|
||||
|
||||
if self.is_master:
|
||||
os.makedirs(os.path.join(self.output_dir, 'ckpts'), exist_ok=True)
|
||||
os.makedirs(os.path.join(self.output_dir, 'samples'), exist_ok=True)
|
||||
self.writer = SummaryWriter(os.path.join(self.output_dir, 'tb_logs'))
|
||||
|
||||
if self.world_size > 1:
|
||||
self.check_ddp()
|
||||
|
||||
if self.is_master:
|
||||
print('\n\nTrainer initialized.')
|
||||
print(self)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
for _, model in self.models.items():
|
||||
if hasattr(model, 'device'):
|
||||
return model.device
|
||||
return next(list(self.models.values())[0].parameters()).device
|
||||
|
||||
    @abstractmethod
    def init_models_and_more(self, **kwargs):
        """
        Initialize models and other training state (optimizer, EMA, fp16, ...).
        Called once from __init__ before the dataloader is prepared.
        """
        pass
|
||||
|
||||
def prepare_dataloader(self, **kwargs):
|
||||
"""
|
||||
Prepare dataloader.
|
||||
"""
|
||||
self.data_sampler = ResumableSampler(
|
||||
self.dataset,
|
||||
shuffle=True,
|
||||
)
|
||||
self.dataloader = DataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.batch_size_per_gpu,
|
||||
num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
persistent_workers=True,
|
||||
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||
sampler=self.data_sampler,
|
||||
)
|
||||
self.data_iterator = cycle(self.dataloader)
|
||||
|
||||
    @abstractmethod
    def load(self, load_dir, step=0):
        """
        Load a checkpoint.
        Should be called by all processes.
        """
        pass
|
||||
|
||||
    @abstractmethod
    def save(self):
        """
        Save a checkpoint.
        Should be called only by the rank 0 process.
        """
        pass
|
||||
|
||||
    @abstractmethod
    def finetune_from(self, finetune_ckpt):
        """
        Finetune from a checkpoint (weights only, fresh trainer state).
        Should be called by all processes.
        """
        pass
|
||||
|
||||
    @abstractmethod
    def run_snapshot(self, num_samples, batch_size=4, verbose=False, **kwargs):
        """
        Run a snapshot of the model: produce `num_samples` samples, returned as
        a dict of {name: {'value': tensor, 'type': 'sample'|'image'|'number'}}
        consumed by snapshot().
        """
        pass
|
||||
|
||||
@torch.no_grad()
|
||||
def visualize_sample(self, sample):
|
||||
"""
|
||||
Convert a sample to an image.
|
||||
"""
|
||||
if hasattr(self.dataset, 'visualize_sample'):
|
||||
return self.dataset.visualize_sample(sample)
|
||||
else:
|
||||
return sample
|
||||
|
||||
    @torch.no_grad()
    def snapshot_dataset(self, num_samples=100):
        """
        Sample images from the dataset and save them as a grid under
        `output_dir/samples/` (one grid per visualization key).
        """
        # A throwaway dataloader: one big shuffled batch of num_samples items.
        dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=num_samples,
            num_workers=0,
            shuffle=True,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )
        data = next(iter(dataloader))
        data = recursive_to_device(data, self.device)
        vis = self.visualize_sample(data)
        # The visualizer may return several named views or a single image batch.
        if isinstance(vis, dict):
            save_cfg = [(f'dataset_{k}', v) for k, v in vis.items()]
        else:
            save_cfg = [('dataset', vis)]
        for name, image in save_cfg:
            utils.save_image(
                image,
                os.path.join(self.output_dir, 'samples', f'{name}.jpg'),
                nrow=int(np.sqrt(num_samples)),
                normalize=True,
                value_range=self.dataset.value_range,
            )
|
||||
|
||||
    @torch.no_grad()
    def snapshot(self, suffix=None, num_samples=64, batch_size=4, verbose=False):
        """
        Sample images from the model.
        NOTE: This function should be called by all processes.
        """
        if self.is_master:
            print(f'\nSampling {num_samples} images...', end='')

        if suffix is None:
            suffix = f'step{self.step:07d}'

        # Assign tasks: each rank generates its share of the samples.
        num_samples_per_process = int(np.ceil(num_samples / self.world_size))
        samples = self.run_snapshot(num_samples_per_process, batch_size=batch_size, verbose=verbose)

        # Preprocess images: convert raw 'sample' entries into 'image' entries
        # via the dataset visualizer (which may fan one sample into many views).
        for key in list(samples.keys()):
            if samples[key]['type'] == 'sample':
                vis = self.visualize_sample(samples[key]['value'])
                if isinstance(vis, dict):
                    for k, v in vis.items():
                        samples[f'{key}_{k}'] = {'value': v, 'type': 'image'}
                    del samples[key]
                else:
                    samples[key] = {'value': vis, 'type': 'image'}

        # Gather results from all ranks onto rank 0 and trim to num_samples.
        if self.world_size > 1:
            for key in samples.keys():
                samples[key]['value'] = samples[key]['value'].contiguous()
                if self.is_master:
                    all_images = [torch.empty_like(samples[key]['value']) for _ in range(self.world_size)]
                else:
                    all_images = []
                dist.gather(samples[key]['value'], all_images, dst=0)
                if self.is_master:
                    samples[key]['value'] = torch.cat(all_images, dim=0)[:num_samples]

        # Save images (rank 0 only).
        if self.is_master:
            os.makedirs(os.path.join(self.output_dir, 'samples', suffix), exist_ok=True)
            for key in samples.keys():
                if samples[key]['type'] == 'image':
                    utils.save_image(
                        samples[key]['value'],
                        os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
                        nrow=int(np.sqrt(num_samples)),
                        normalize=True,
                        value_range=self.dataset.value_range,
                    )
                elif samples[key]['type'] == 'number':
                    # Normalize scalar maps to [0, 1] for visualization.
                    # NOTE(review): 'min'/'max' shadow the builtins here;
                    # harmless locally but worth renaming.
                    min = samples[key]['value'].min()
                    max = samples[key]['value'].max()
                    images = (samples[key]['value'] - min) / (max - min)
                    images = utils.make_grid(
                        images,
                        nrow=int(np.sqrt(num_samples)),
                        normalize=False,
                    )
                    save_image_with_notes(
                        images,
                        os.path.join(self.output_dir, 'samples', suffix, f'{key}_{suffix}.jpg'),
                        notes=f'{key} min: {min}, max: {max}',
                    )

        if self.is_master:
            print(' Done.')
|
||||
|
||||
    @abstractmethod
    def update_ema(self):
        """
        Update exponential moving average.
        Should only be called by the rank 0 process.
        """
        pass
|
||||
|
||||
    @abstractmethod
    def check_ddp(self):
        """
        Check if DDP is working properly (parameters in sync across ranks).
        Should be called by all processes.
        """
        pass
|
||||
|
||||
@abstractmethod
|
||||
def training_losses(**mb_data):
|
||||
"""
|
||||
Compute training losses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load_data(self):
    """
    Fetch the next batch and split it for gradient accumulation.

    Returns:
        list: Micro-batch dicts, one per batch split (a single-element
        list when ``batch_split == 1`` or when the dataset already
        yields a list).
    """
    def _fetch():
        # Move the next raw batch onto this trainer's device.
        return recursive_to_device(next(self.data_iterator), self.device, non_blocking=True)

    if self.prefetch_data:
        # Keep one batch in flight: hand out the prefetched batch and
        # immediately start moving the next one.
        if self._data_prefetched is None:
            self._data_prefetched = _fetch()
        data, self._data_prefetched = self._data_prefetched, _fetch()
    else:
        data = _fetch()

    # If the data is a dict, split it into batch_split micro-batches.
    if isinstance(data, dict):
        if self.batch_split == 1:
            return [data]
        total = next(iter(data.values())).shape[0]
        bounds = [i * total // self.batch_split for i in range(self.batch_split + 1)]
        return [
            {key: val[lo:hi] for key, val in data.items()}
            for lo, hi in zip(bounds[:-1], bounds[1:])
        ]
    if isinstance(data, list):
        return data
    raise ValueError('Data must be a dict or a list of dicts.')
|
||||
|
||||
@abstractmethod
def run_step(self, data_list):
    """
    Run a single training step over one list of micro-batches.

    Args:
        data_list: Micro-batches as produced by ``load_data()``.

    Returns:
        A per-step log dict (merged into the run log by ``run()``),
        or None.
    """
    pass
|
||||
|
||||
def run(self):
    """
    Run training: main loop of load -> step -> log/snapshot/save until
    ``max_steps`` is reached.

    NOTE(review): leading indentation was lost in this view; the guard
    grouping below is reconstructed — ``snapshot()`` is kept outside the
    ``is_master`` guard during the loop (it performs a dist.gather, so all
    ranks must call it), while the final snapshot follows the displayed
    order under ``is_master``. Confirm against the original file.
    """
    if self.is_master:
        print('\nStarting training...')
        self.snapshot_dataset()
    # Snapshot at start (fresh run) or on resume, tagged accordingly.
    if self.step == 0:
        self.snapshot(suffix='init')
    else: # resume
        self.snapshot(suffix=f'resume_step{self.step:07d}')

    log = []
    time_last_print = 0.0
    time_elapsed = 0.0
    while self.step < self.max_steps:
        time_start = time.time()

        data_list = self.load_data()
        step_log = self.run_step(data_list)

        time_end = time.time()
        time_elapsed += time_end - time_start

        self.step += 1

        # Print progress
        if self.is_master and self.step % self.i_print == 0:
            # steps/hour since the previous print
            speed = self.i_print / (time_elapsed - time_last_print) * 3600
            columns = [
                f'Step: {self.step}/{self.max_steps} ({self.step / self.max_steps * 100:.2f}%)',
                f'Elapsed: {time_elapsed / 3600:.2f} h',
                f'Speed: {speed:.2f} steps/h',
                f'ETA: {(self.max_steps - self.step) / speed:.2f} h',
            ]
            print(' | '.join([c.ljust(25) for c in columns]), flush=True)
            time_last_print = time_elapsed

        # Check ddp
        if self.world_size > 1 and self.i_ddpcheck is not None and self.step % self.i_ddpcheck == 0:
            self.check_ddp()

        # Sample images (collective — all ranks participate)
        if self.step % self.i_sample == 0:
            self.snapshot()

        if self.is_master:
            log.append((self.step, {}))

            # Log time
            log[-1][1]['time'] = {
                'step': time_end - time_start,
                'elapsed': time_elapsed,
            }

            # Log losses
            if step_log is not None:
                log[-1][1].update(step_log)

            # Log scale
            if self.fp16_mode == 'amp':
                log[-1][1]['scale'] = self.scaler.get_scale()
            elif self.fp16_mode == 'inflat_all':
                log[-1][1]['log_scale'] = self.log_scale

            # Save log
            if self.step % self.i_log == 0:
                ## save to log file
                log_str = '\n'.join([
                    f'{step}: {json.dumps(log)}' for step, log in log
                ])
                with open(os.path.join(self.output_dir, 'log.txt'), 'a') as log_file:
                    log_file.write(log_str + '\n')

                # show with mlflow
                # (entries containing NaN anywhere are dropped before averaging)
                log_show = [l for _, l in log if not dict_any(l, lambda x: np.isnan(x))]
                log_show = dict_reduce(log_show, lambda x: np.mean(x))
                log_show = dict_flatten(log_show, sep='/')
                for key, value in log_show.items():
                    self.writer.add_scalar(key, value, self.step)
                log = []

            # Save checkpoint
            if self.step % self.i_save == 0:
                self.save()

    # NOTE(review): final snapshot runs under is_master per the displayed
    # order; with world_size > 1 a master-only collective snapshot could
    # hang — confirm snapshot() handles this.
    if self.is_master:
        self.snapshot(suffix='final')
        self.writer.close()
        print('Training finished.')
|
||||
|
||||
def profile(self, wait=2, warmup=3, active=5):
    """
    Profile the training loop with torch.profiler; traces are written for
    TensorBoard under ``<output_dir>/profile``.

    Args:
        wait (int): Steps to skip before warmup.
        warmup (int): Warmup steps (profiled but discarded).
        active (int): Steps actively recorded.
    """
    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(self.output_dir, 'profile')),
        profile_memory=True,
        with_stack=True,
    ) as prof:
        for _ in range(wait + warmup + active):
            # Fix: run_step(data_list) requires the micro-batch list
            # (see run()); the original called self.run_step() with no
            # arguments, which raises TypeError.
            self.run_step(self.load_data())
            prof.step()
|
||||
|
||||
438
trellis/trainers/basic.py
Normal file
438
trellis/trainers/basic.py
Normal file
@@ -0,0 +1,438 @@
|
||||
import os
|
||||
import copy
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import numpy as np
|
||||
|
||||
from .utils import *
|
||||
from .base import Trainer
|
||||
from ..utils.general_utils import *
|
||||
from ..utils.dist_utils import *
|
||||
from ..utils import grad_clip_utils, elastic_utils
|
||||
|
||||
|
||||
class BasicTrainer(Trainer):
    """
    Trainer for basic training loop.

    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold a inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
    """
|
||||
|
||||
def __str__(self):
    """Return a human-readable multi-line summary of the trainer config.

    NOTE(review): leading spaces inside the f-string literals may have been
    collapsed by the diff viewer this was recovered from — confirm the
    original indentation of the report lines.
    """
    lines = []
    lines.append(self.__class__.__name__)
    lines.append(f' - Models:')
    for name, model in self.models.items():
        lines.append(f' - {name}: {model.__class__.__name__}')
    lines.append(f' - Dataset: {indent(str(self.dataset), 2)}')
    lines.append(f' - Dataloader:')
    lines.append(f' - Sampler: {self.dataloader.sampler.__class__.__name__}')
    lines.append(f' - Num workers: {self.dataloader.num_workers}')
    lines.append(f' - Number of steps: {self.max_steps}')
    lines.append(f' - Number of GPUs: {self.world_size}')
    lines.append(f' - Batch size: {self.batch_size}')
    lines.append(f' - Batch size per GPU: {self.batch_size_per_gpu}')
    lines.append(f' - Batch split: {self.batch_split}')
    lines.append(f' - Optimizer: {self.optimizer.__class__.__name__}')
    lines.append(f' - Learning rate: {self.optimizer.param_groups[0]["lr"]}')
    # Optional components are reported only when configured.
    if self.lr_scheduler_config is not None:
        lines.append(f' - LR scheduler: {self.lr_scheduler.__class__.__name__}')
    if self.elastic_controller_config is not None:
        lines.append(f' - Elastic memory: {indent(str(self.elastic_controller), 2)}')
    if self.grad_clip is not None:
        lines.append(f' - Gradient clip: {indent(str(self.grad_clip), 2)}')
    lines.append(f' - EMA rate: {self.ema_rate}')
    lines.append(f' - FP16 mode: {self.fp16_mode}')
    return '\n'.join(lines)
|
||||
|
||||
def init_models_and_more(self, **kwargs):
    """
    Initialize models and more: DDP wrappers, master/EMA parameter copies,
    optimizer, LR scheduler, elastic memory controller, and gradient clipper.

    Fixes:
      - the scaler assignment was ``torch.GradScaler() if self.fp16_mode ==
        'amp' else None`` inside the branch already guarded by
        ``self.fp16_mode == 'amp'`` — the conditional was dead code;
      - removed the no-op self-assignment
        ``self.fp16_scale_growth = self.fp16_scale_growth``.
    """
    if self.world_size > 1:
        # Prepare distributed data parallel
        self.training_models = {
            name: DDP(
                model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                bucket_cap_mb=128,
                find_unused_parameters=False
            )
            for name, model in self.models.items()
        }
    else:
        self.training_models = self.models

    # Build master params (flat list of all trainable params across models)
    self.model_params = sum(
        [[p for p in model.parameters() if p.requires_grad] for model in self.models.values()]
    , [])
    if self.fp16_mode == 'amp':
        self.master_params = self.model_params
        # NOTE(review): torch.GradScaler requires a recent torch version;
        # older releases spell it torch.cuda.amp.GradScaler — confirm.
        self.scaler = torch.GradScaler()
    elif self.fp16_mode == 'inflat_all':
        # Keep a flattened fp32 master copy; loss scaling is manual
        # via log_scale (see run_step).
        self.master_params = make_master_params(self.model_params)
        self.log_scale = 20.0
    elif self.fp16_mode is None:
        self.master_params = self.model_params
    else:
        raise NotImplementedError(f'FP16 mode {self.fp16_mode} is not implemented.')

    # Build EMA params (master rank only)
    if self.is_master:
        self.ema_params = [copy.deepcopy(self.master_params) for _ in self.ema_rate]

    # Initialize optimizer: torch.optim by name, else a custom class
    # registered in this module's globals.
    if hasattr(torch.optim, self.optimizer_config['name']):
        self.optimizer = getattr(torch.optim, self.optimizer_config['name'])(self.master_params, **self.optimizer_config['args'])
    else:
        self.optimizer = globals()[self.optimizer_config['name']](self.master_params, **self.optimizer_config['args'])

    # Initalize learning rate scheduler (same lookup strategy)
    if self.lr_scheduler_config is not None:
        if hasattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name']):
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_config['name'])(self.optimizer, **self.lr_scheduler_config['args'])
        else:
            self.lr_scheduler = globals()[self.lr_scheduler_config['name']](self.optimizer, **self.lr_scheduler_config['args'])

    # Initialize elastic memory controller
    if self.elastic_controller_config is not None:
        assert any([isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)) for model in self.models.values()]), \
            'No elastic module found in models, please inherit from ElasticModule or ElasticModuleMixin'
        self.elastic_controller = getattr(elastic_utils, self.elastic_controller_config['name'])(**self.elastic_controller_config['args'])
        for model in self.models.values():
            if isinstance(model, (elastic_utils.ElasticModule, elastic_utils.ElasticModuleMixin)):
                model.register_memory_controller(self.elastic_controller)

    # Initialize gradient clipper: a plain float means clip-by-norm,
    # a dict configures a clipper class from grad_clip_utils.
    if self.grad_clip is not None:
        if isinstance(self.grad_clip, (float, int)):
            self.grad_clip = float(self.grad_clip)
        else:
            self.grad_clip = getattr(grad_clip_utils, self.grad_clip['name'])(**self.grad_clip['args'])
|
||||
|
||||
def _master_params_to_state_dicts(self, master_params):
    """
    Convert master params back into per-model state_dicts.

    Args:
        master_params: Flat list of master parameters (or the flattened
            tensor list when fp16_mode == 'inflat_all').

    Returns:
        dict: Mapping model name -> state_dict with trainable entries
        overwritten by the master copies.
    """
    if self.fp16_mode == 'inflat_all':
        # Re-expand the flattened fp32 master copy into per-param tensors.
        master_params = unflatten_master_params(self.model_params, master_params)
    state_dicts = {name: model.state_dict() for name, model in self.models.items()}
    # (model name, param name) pairs in the same order master params were built.
    trainable = [
        (model_name, param_name)
        for model_name, model in self.models.items()
        for param_name, param in model.named_parameters()
        if param.requires_grad
    ]
    for param, (model_name, param_name) in zip(master_params, trainable):
        state_dicts[model_name][param_name] = param
    return state_dicts
|
||||
|
||||
def _state_dicts_to_master_params(self, master_params, state_dicts):
    """
    Copy values from per-model state_dicts into the master params, in place.

    Args:
        master_params: Flat list of master parameters to overwrite.
        state_dicts (dict): Mapping model name -> state_dict to read from.
    """
    # (model name, param name) pairs in the same order master params were built.
    trainable = [
        (model_name, param_name)
        for model_name, model in self.models.items()
        for param_name, param in model.named_parameters()
        if param.requires_grad
    ]
    params = [state_dicts[model_name][param_name] for model_name, param_name in trainable]
    if self.fp16_mode == 'inflat_all':
        # Flatten into the fp32 master copy.
        model_params_to_master_params(params, master_params)
    else:
        for src, dst in zip(params, master_params):
            dst.data.copy_(src.data)
|
||||
|
||||
def load(self, load_dir, step=0):
    """
    Load a checkpoint (model weights, EMA weights, optimizer and misc state).
    Should be called by all processes — it ends with collective ops
    (barrier / check_ddp) when world_size > 1.

    Args:
        load_dir (str): Directory containing a ``ckpts`` subdirectory.
        step (int): Checkpoint step to load.
    """
    if self.is_master:
        print(f'\nLoading checkpoint from step {step}...', end='')

    # Per-model weights: loaded on every rank.
    model_ckpts = {}
    for name, model in self.models.items():
        model_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'{name}_step{step:07d}.pt')), map_location=self.device, weights_only=True)
        model_ckpts[name] = model_ckpt
        model.load_state_dict(model_ckpt)
        if self.fp16_mode == 'inflat_all':
            # Model runs in fp16; master params keep the fp32 copy.
            model.convert_to_fp16()
    self._state_dicts_to_master_params(self.master_params, model_ckpts)
    del model_ckpts

    # EMA weights exist only on the master rank.
    if self.is_master:
        for i, ema_rate in enumerate(self.ema_rate):
            ema_ckpts = {}
            for name, model in self.models.items():
                ema_ckpt = torch.load(os.path.join(load_dir, 'ckpts', f'{name}_ema{ema_rate}_step{step:07d}.pt'), map_location=self.device, weights_only=True)
                ema_ckpts[name] = ema_ckpt
            self._state_dicts_to_master_params(self.ema_params[i], ema_ckpts)
            del ema_ckpts

    # Misc training state (optimizer, step counter, sampler, scalers, ...).
    # weights_only=False because it contains arbitrary python state.
    misc_ckpt = torch.load(read_file_dist(os.path.join(load_dir, 'ckpts', f'misc_step{step:07d}.pt')), map_location=torch.device('cpu'), weights_only=False)
    self.optimizer.load_state_dict(misc_ckpt['optimizer'])
    self.step = misc_ckpt['step']
    self.data_sampler.load_state_dict(misc_ckpt['data_sampler'])
    if self.fp16_mode == 'amp':
        self.scaler.load_state_dict(misc_ckpt['scaler'])
    elif self.fp16_mode == 'inflat_all':
        self.log_scale = misc_ckpt['log_scale']
    if self.lr_scheduler_config is not None:
        self.lr_scheduler.load_state_dict(misc_ckpt['lr_scheduler'])
    if self.elastic_controller_config is not None:
        self.elastic_controller.load_state_dict(misc_ckpt['elastic_controller'])
    if self.grad_clip is not None and not isinstance(self.grad_clip, float):
        self.grad_clip.load_state_dict(misc_ckpt['grad_clip'])
    del misc_ckpt

    if self.world_size > 1:
        dist.barrier()
    if self.is_master:
        print(' Done.')

    # Verify all ranks ended up with identical parameters.
    if self.world_size > 1:
        self.check_ddp()
|
||||
|
||||
def save(self):
    """
    Save a checkpoint.
    Should be called only by the rank 0 process.
    """
    assert self.is_master, 'save() should be called only by the rank 0 process.'
    print(f'\nSaving checkpoint at step {self.step}...', end='')

    ckpt_dir = os.path.join(self.output_dir, 'ckpts')
    suffix = f'step{self.step:07d}.pt'

    # Current weights, reconstructed from the master params.
    for name, state in self._master_params_to_state_dicts(self.master_params).items():
        torch.save(state, os.path.join(ckpt_dir, f'{name}_{suffix}'))

    # One set of EMA weights per EMA rate.
    for ema_rate, ema_params in zip(self.ema_rate, self.ema_params):
        for name, state in self._master_params_to_state_dicts(ema_params).items():
            torch.save(state, os.path.join(ckpt_dir, f'{name}_ema{ema_rate}_{suffix}'))

    # Misc training state needed to resume.
    misc_ckpt = {
        'optimizer': self.optimizer.state_dict(),
        'step': self.step,
        'data_sampler': self.data_sampler.state_dict(),
    }
    if self.fp16_mode == 'amp':
        misc_ckpt['scaler'] = self.scaler.state_dict()
    elif self.fp16_mode == 'inflat_all':
        misc_ckpt['log_scale'] = self.log_scale
    if self.lr_scheduler_config is not None:
        misc_ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
    if self.elastic_controller_config is not None:
        misc_ckpt['elastic_controller'] = self.elastic_controller.state_dict()
    if self.grad_clip is not None and not isinstance(self.grad_clip, float):
        misc_ckpt['grad_clip'] = self.grad_clip.state_dict()
    torch.save(misc_ckpt, os.path.join(ckpt_dir, f'misc_{suffix}'))
    print(' Done.')
|
||||
|
||||
def finetune_from(self, finetune_ckpt):
    """
    Finetune from a checkpoint: load per-model weight files, skipping
    tensors whose shapes do not match the current model.
    Should be called by all processes (ends with collective ops).

    Args:
        finetune_ckpt (dict): Mapping model name -> checkpoint path.
    """
    if self.is_master:
        print('\nFinetuning from:')
        for name, path in finetune_ckpt.items():
            print(f' - {name}: {path}')

    model_ckpts = {}
    for name, model in self.models.items():
        model_state_dict = model.state_dict()
        if name in finetune_ckpt:
            model_ckpt = torch.load(read_file_dist(finetune_ckpt[name]), map_location=self.device, weights_only=True)
            # Replace shape-mismatched tensors with the model's own values
            # so load_state_dict below succeeds.
            # NOTE(review): raises KeyError if the checkpoint contains keys
            # absent from the model — confirm this strictness is intended.
            for k, v in model_ckpt.items():
                if model_ckpt[k].shape != model_state_dict[k].shape:
                    if self.is_master:
                        print(f'Warning: {k} shape mismatch, {model_ckpt[k].shape} vs {model_state_dict[k].shape}, skipped.')
                    model_ckpt[k] = model_state_dict[k]
            model_ckpts[name] = model_ckpt
            model.load_state_dict(model_ckpt)
            if self.fp16_mode == 'inflat_all':
                model.convert_to_fp16()
        else:
            # No checkpoint for this model: keep its current weights.
            if self.is_master:
                print(f'Warning: {name} not found in finetune_ckpt, skipped.')
            model_ckpts[name] = model_state_dict
    self._state_dicts_to_master_params(self.master_params, model_ckpts)
    del model_ckpts

    if self.world_size > 1:
        dist.barrier()
    if self.is_master:
        print('Done.')

    if self.world_size > 1:
        self.check_ddp()
|
||||
|
||||
def update_ema(self):
    """
    Update exponential moving average.
    Should only be called by the rank 0 process.
    """
    assert self.is_master, 'update_ema() should be called only by the rank 0 process.'
    # ema <- rate * ema + (1 - rate) * master, for every EMA rate.
    for rate, ema_copy in zip(self.ema_rate, self.ema_params):
        for src, tgt in zip(self.master_params, ema_copy):
            tgt.detach().mul_(rate).add_(src, alpha=1.0 - rate)
|
||||
|
||||
def check_ddp(self):
    """
    Check if DDP is working properly.
    Should be called by all process — uses all_gather, which deadlocks
    unless every rank participates.
    """
    if self.is_master:
        print('\nPerforming DDP check...')

    if self.is_master:
        print('Checking if parameters are consistent across processes...')
    dist.barrier()
    try:
        for p in self.master_params:
            # split to avoid OOM (compare at most 1e7 elements at a time)
            for i in range(0, p.numel(), 10000000):
                sub_size = min(10000000, p.numel() - i)
                sub_p = p.detach().view(-1)[i:i+sub_size]
                # gather from all processes
                sub_p_gather = [torch.empty_like(sub_p) for _ in range(self.world_size)]
                dist.all_gather(sub_p_gather, sub_p)
                # check if equal
                assert all([torch.equal(sub_p, sub_p_gather[i]) for i in range(self.world_size)]), 'parameters are not consistent across processes'
    except AssertionError as e:
        if self.is_master:
            # ANSI red for visibility in console logs
            print(f'\n\033[91mError: {e}\033[0m')
            print('DDP check failed.')
        raise e

    dist.barrier()
    if self.is_master:
        print('Done.')
|
||||
|
||||
def run_step(self, data_list):
    """
    Run a training step: accumulate gradients over the micro-batches in
    ``data_list``, clip, apply the optimizer, update LR and EMA.

    Args:
        data_list (list): Micro-batch dicts from ``load_data()``.

    Returns:
        dict: Per-step log with 'loss', 'status', and optional
        'elastic' / 'grad_clip' / param-stat entries.
    """
    step_log = {'loss': {}, 'status': {}}
    # Contexts chosen by config: autocast only in amp mode; elastic memory
    # recording only when an elastic controller is configured.
    amp_context = partial(torch.autocast, device_type='cuda') if self.fp16_mode == 'amp' else nullcontext
    elastic_controller_context = self.elastic_controller.record if self.elastic_controller_config is not None else nullcontext

    # Train
    losses = []
    statuses = []
    elastic_controller_logs = []
    zero_grad(self.model_params)
    for i, mb_data in enumerate(data_list):
        ## sync at the end of each batch split
        # (DDP grad all-reduce is suppressed via no_sync on all but the
        # last micro-batch)
        sync_contexts = [self.training_models[name].no_sync for name in self.training_models] if i != len(data_list) - 1 and self.world_size > 1 else [nullcontext]
        with nested_contexts(*sync_contexts), elastic_controller_context():
            with amp_context():
                loss, status = self.training_losses(**mb_data)
                # Average over micro-batches so accumulation matches a
                # single full-batch step.
                l = loss['loss'] / len(data_list)
            ## backward
            if self.fp16_mode == 'amp':
                self.scaler.scale(l).backward()
            elif self.fp16_mode == 'inflat_all':
                # Manual loss scaling by 2**log_scale.
                scaled_l = l * (2 ** self.log_scale)
                scaled_l.backward()
            else:
                l.backward()
            ## log
            losses.append(dict_foreach(loss, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
            statuses.append(dict_foreach(status, lambda x: x.item() if isinstance(x, torch.Tensor) else x))
            if self.elastic_controller_config is not None:
                elastic_controller_logs.append(self.elastic_controller.log())
    ## gradient clip
    if self.grad_clip is not None:
        if self.fp16_mode == 'amp':
            # Unscale before clipping so the threshold applies to true grads.
            self.scaler.unscale_(self.optimizer)
        elif self.fp16_mode == 'inflat_all':
            model_grads_to_master_grads(self.model_params, self.master_params)
            # master_params is a single flattened tensor in this mode,
            # so scaling element [0] rescales every gradient.
            self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
        if isinstance(self.grad_clip, float):
            grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params, self.grad_clip)
        else:
            grad_norm = self.grad_clip(self.master_params)
        if torch.isfinite(grad_norm):
            statuses[-1]['grad_norm'] = grad_norm.item()
    ## step
    if self.fp16_mode == 'amp':
        prev_scale = self.scaler.get_scale()
        self.scaler.step(self.optimizer)
        self.scaler.update()
    elif self.fp16_mode == 'inflat_all':
        prev_scale = 2 ** self.log_scale
        # Skip the update and shrink the scale when any grad is non-finite.
        if not any(not p.grad.isfinite().all() for p in self.model_params):
            if self.grad_clip is None:
                # Not already unscaled by the clipping branch above.
                model_grads_to_master_grads(self.model_params, self.master_params)
                self.master_params[0].grad.mul_(1.0 / (2 ** self.log_scale))
            self.optimizer.step()
            master_params_to_model_params(self.model_params, self.master_params)
            self.log_scale += self.fp16_scale_growth
        else:
            self.log_scale -= 1
    else:
        prev_scale = 1.0
        if not any(not p.grad.isfinite().all() for p in self.model_params):
            self.optimizer.step()
        else:
            # ANSI yellow warning; update intentionally skipped.
            print('\n\033[93mWarning: NaN detected in gradients. Skipping update.\033[0m')
    ## adjust learning rate
    if self.lr_scheduler_config is not None:
        statuses[-1]['lr'] = self.lr_scheduler.get_last_lr()[0]
        self.lr_scheduler.step()

    # Logs
    step_log['loss'] = dict_reduce(losses, lambda x: np.mean(x))
    step_log['status'] = dict_reduce(statuses, lambda x: np.mean(x), special_func={'min': lambda x: np.min(x), 'max': lambda x: np.max(x)})
    if self.elastic_controller_config is not None:
        step_log['elastic'] = dict_reduce(elastic_controller_logs, lambda x: np.mean(x))
    if self.grad_clip is not None:
        step_log['grad_clip'] = self.grad_clip if isinstance(self.grad_clip, float) else self.grad_clip.log()

    # Check grad and norm of each param
    if self.log_param_stats:
        param_norms = {}
        param_grads = {}
        # NOTE(review): self.backbone is not assigned anywhere in this
        # class — this likely should iterate self.models; confirm.
        for name, param in self.backbone.named_parameters():
            if param.requires_grad:
                param_norms[name] = param.norm().item()
                if param.grad is not None and torch.isfinite(param.grad).all():
                    # Divide by the loss scale used this step to report
                    # true gradient magnitudes.
                    param_grads[name] = param.grad.norm().item() / prev_scale
        step_log['param_norms'] = param_norms
        step_log['param_grads'] = param_grads

    # Update exponential moving average
    if self.is_master:
        self.update_ema()

    return step_log
|
||||
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from ....utils.general_utils import dict_foreach
|
||||
from ....pipelines import samplers
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceMixin:
    """
    Mixin that randomly replaces conditioning with a negative conditioning
    during training, enabling classifier-free guidance at sampling time.
    """

    def __init__(self, *args, p_uncond: float = 0.1, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-sample probability of swapping cond for neg_cond.
        self.p_uncond = p_uncond

    def get_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"

        if self.p_uncond <= 0:
            return cond

        # randomly drop the class label
        def get_batch_size(cond):
            if isinstance(cond, torch.Tensor):
                return cond.shape[0]
            if isinstance(cond, list):
                return len(cond)
            raise ValueError(f"Unsupported type of cond: {type(cond)}")

        def select(cond, neg_cond, mask):
            if isinstance(cond, torch.Tensor):
                # Broadcast the per-sample mask over all trailing dims.
                keep = torch.tensor(mask, device=cond.device).reshape(-1, *[1] * (cond.ndim - 1))
                return torch.where(keep, neg_cond, cond)
            if isinstance(cond, list):
                return [nc if m else c for c, nc, m in zip(cond, neg_cond, mask)]
            raise ValueError(f"Unsupported type of cond: {type(cond)}")

        # Use any one entry of a dict cond to determine the batch size.
        ref_cond = cond[list(cond.keys())[0]] if isinstance(cond, dict) else cond
        mask = list(np.random.rand(get_batch_size(ref_cond)) < self.p_uncond)

        if isinstance(cond, dict):
            return dict_foreach([cond, neg_cond], lambda pair: select(pair[0], pair[1], mask))
        return select(cond, neg_cond, mask)

    def get_inference_cond(self, cond, neg_cond=None, **kwargs):
        """
        Get the conditioning data for inference.
        """
        assert neg_cond is not None, "neg_cond must be provided for classifier-free guidance"
        return {'cond': cond, 'neg_cond': neg_cond, **kwargs}

    def get_sampler(self, **kwargs) -> samplers.FlowEulerCfgSampler:
        """
        Get the sampler for the diffusion process.
        """
        return samplers.FlowEulerCfgSampler(self.sigma_min)
|
||||
0
trellis/utils/__init__.py
Executable file
0
trellis/utils/__init__.py
Executable file
Reference in New Issue
Block a user