1
This commit is contained in:
96
trellis/models/__init__.py
Normal file
96
trellis/models/__init__.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import importlib
|
||||
|
||||
__attributes = {
|
||||
'SparseStructureEncoder': 'sparse_structure_vae',
|
||||
'SparseStructureDecoder': 'sparse_structure_vae',
|
||||
|
||||
'SparseStructureFlowModel': 'sparse_structure_flow',
|
||||
|
||||
'SLatEncoder': 'structured_latent_vae',
|
||||
'SLatGaussianDecoder': 'structured_latent_vae',
|
||||
'SLatRadianceFieldDecoder': 'structured_latent_vae',
|
||||
'SLatMeshDecoder': 'structured_latent_vae',
|
||||
'ElasticSLatEncoder': 'structured_latent_vae',
|
||||
'ElasticSLatGaussianDecoder': 'structured_latent_vae',
|
||||
'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
|
||||
'ElasticSLatMeshDecoder': 'structured_latent_vae',
|
||||
|
||||
'SLatFlowModel': 'structured_latent_flow',
|
||||
'ElasticSLatFlowModel': 'structured_latent_flow',
|
||||
}
|
||||
|
||||
__submodules = []
|
||||
|
||||
__all__ = list(__attributes.keys()) + __submodules
|
||||
|
||||
def __getattr__(name):
|
||||
if name not in globals():
|
||||
if name in __attributes:
|
||||
module_name = __attributes[name]
|
||||
module = importlib.import_module(f".{module_name}", __name__)
|
||||
globals()[name] = getattr(module, name)
|
||||
elif name in __submodules:
|
||||
module = importlib.import_module(f".{name}", __name__)
|
||||
globals()[name] = module
|
||||
else:
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
return globals()[name]
|
||||
|
||||
|
||||
def from_pretrained(path: str, **kwargs):
|
||||
"""
|
||||
Load a model from a pretrained checkpoint.
|
||||
|
||||
Args:
|
||||
path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
|
||||
NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
|
||||
**kwargs: Additional arguments for the model constructor.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from safetensors.torch import load_file
|
||||
is_local = os.path.exists(f"{path}.json") and os.path.exists(f"{path}.safetensors")
|
||||
|
||||
if is_local:
|
||||
config_file = f"{path}.json"
|
||||
model_file = f"{path}.safetensors"
|
||||
else:
|
||||
from huggingface_hub import hf_hub_download
|
||||
path_parts = path.split('/')
|
||||
repo_id = f'{path_parts[0]}/{path_parts[1]}'
|
||||
model_name = '/'.join(path_parts[2:])
|
||||
config_file = hf_hub_download(repo_id, f"{model_name}.json")
|
||||
model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")
|
||||
|
||||
with open(config_file, 'r') as f:
|
||||
config = json.load(f)
|
||||
model = __getattr__(config['name'])(**config['args'], **kwargs)
|
||||
model.load_state_dict(load_file(model_file))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# For Pylance
|
||||
if __name__ == '__main__':
|
||||
from .sparse_structure_vae import (
|
||||
SparseStructureEncoder,
|
||||
SparseStructureDecoder,
|
||||
)
|
||||
|
||||
from .sparse_structure_flow import SparseStructureFlowModel
|
||||
|
||||
from .structured_latent_vae import (
|
||||
SLatEncoder,
|
||||
SLatGaussianDecoder,
|
||||
SLatRadianceFieldDecoder,
|
||||
SLatMeshDecoder,
|
||||
ElasticSLatEncoder,
|
||||
ElasticSLatGaussianDecoder,
|
||||
ElasticSLatRadianceFieldDecoder,
|
||||
ElasticSLatMeshDecoder,
|
||||
)
|
||||
|
||||
from .structured_latent_flow import (
|
||||
SLatFlowModel,
|
||||
ElasticSLatFlowModel,
|
||||
)
|
||||
4
trellis/models/structured_latent_vae/__init__.py
Normal file
4
trellis/models/structured_latent_vae/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .encoder import SLatEncoder, ElasticSLatEncoder
|
||||
from .decoder_gs import SLatGaussianDecoder, ElasticSLatGaussianDecoder
|
||||
from .decoder_rf import SLatRadianceFieldDecoder, ElasticSLatRadianceFieldDecoder
|
||||
from .decoder_mesh import SLatMeshDecoder, ElasticSLatMeshDecoder
|
||||
117
trellis/models/structured_latent_vae/base.py
Normal file
117
trellis/models/structured_latent_vae/base.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ...modules.utils import convert_module_to_f16, convert_module_to_f32
|
||||
from ...modules import sparse as sp
|
||||
from ...modules.transformer import AbsolutePositionEmbedder
|
||||
from ...modules.sparse.transformer import SparseTransformerBlock
|
||||
|
||||
|
||||
def block_attn_config(self):
|
||||
"""
|
||||
Return the attention configuration of the model.
|
||||
"""
|
||||
for i in range(self.num_blocks):
|
||||
if self.attn_mode == "shift_window":
|
||||
yield "serialized", self.window_size, 0, (16 * (i % 2),) * 3, sp.SerializeMode.Z_ORDER
|
||||
elif self.attn_mode == "shift_sequence":
|
||||
yield "serialized", self.window_size, self.window_size // 2 * (i % 2), (0, 0, 0), sp.SerializeMode.Z_ORDER
|
||||
elif self.attn_mode == "shift_order":
|
||||
yield "serialized", self.window_size, 0, (0, 0, 0), sp.SerializeModes[i % 4]
|
||||
elif self.attn_mode == "full":
|
||||
yield "full", None, None, None, None
|
||||
elif self.attn_mode == "swin":
|
||||
yield "windowed", self.window_size, None, self.window_size // 2 * (i % 2), None
|
||||
|
||||
|
||||
class SparseTransformerBase(nn.Module):
|
||||
"""
|
||||
Sparse Transformer without output layers.
|
||||
Serve as the base class for encoder and decoder.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
num_blocks: int,
|
||||
num_heads: Optional[int] = None,
|
||||
num_head_channels: Optional[int] = 64,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
pe_mode: Literal["ape", "rope"] = "ape",
|
||||
use_fp16: bool = False,
|
||||
use_checkpoint: bool = False,
|
||||
qk_rms_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.window_size = window_size
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.attn_mode = attn_mode
|
||||
self.pe_mode = pe_mode
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
|
||||
if pe_mode == "ape":
|
||||
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
||||
|
||||
self.input_layer = sp.SparseLinear(in_channels, model_channels)
|
||||
self.blocks = nn.ModuleList([
|
||||
SparseTransformerBlock(
|
||||
model_channels,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
attn_mode=attn_mode,
|
||||
window_size=window_size,
|
||||
shift_sequence=shift_sequence,
|
||||
shift_window=shift_window,
|
||||
serialize_mode=serialize_mode,
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
use_rope=(pe_mode == "rope"),
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
)
|
||||
for attn_mode, window_size, shift_sequence, shift_window, serialize_mode in block_attn_config(self)
|
||||
])
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
"""
|
||||
Return the device of the model.
|
||||
"""
|
||||
return next(self.parameters()).device
|
||||
|
||||
def convert_to_fp16(self) -> None:
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.blocks.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self) -> None:
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.blocks.apply(convert_module_to_f32)
|
||||
|
||||
def initialize_weights(self) -> None:
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
self.apply(_basic_init)
|
||||
|
||||
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
||||
h = self.input_layer(x)
|
||||
if self.pe_mode == "ape":
|
||||
h = h + self.pos_embedder(x.coords[:, 1:])
|
||||
h = h.type(self.dtype)
|
||||
for block in self.blocks:
|
||||
h = block(h)
|
||||
return h
|
||||
Reference in New Issue
Block a user