From 6c79bdb20f93b43c58ac29d06177ef8bbb3ff5a2 Mon Sep 17 00:00:00 2001
From: zcr
Date: Tue, 17 Mar 2026 11:29:03 +0800
Subject: [PATCH] 1

---
 trellis/datasets/components.py              | 137 +++++++++++
 trellis/modules/sparse/conv/conv_spconv.py  |  80 +++++++
 .../modules/sparse/conv/conv_torchsparse.py |  38 +++
 trellis/representations/mesh/cube2mesh.py   | 143 +++++++++++
 trellis/utils/data_utils.py                 | 226 ++++++++++++++++++
 5 files changed, 624 insertions(+)
 create mode 100644 trellis/datasets/components.py
 create mode 100755 trellis/modules/sparse/conv/conv_spconv.py
 create mode 100755 trellis/modules/sparse/conv/conv_torchsparse.py
 create mode 100644 trellis/representations/mesh/cube2mesh.py
 create mode 100644 trellis/utils/data_utils.py

diff --git a/trellis/datasets/components.py b/trellis/datasets/components.py
new file mode 100644
index 0000000..9c863d6
--- /dev/null
+++ b/trellis/datasets/components.py
@@ -0,0 +1,137 @@
+from typing import *
+from abc import abstractmethod
+import os
+import json
+import torch
+import numpy as np
+import pandas as pd
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class StandardDatasetBase(Dataset):
+    """
+    Base class for standard datasets.
+
+    Args:
+        roots (str): comma-separated paths to the dataset roots
+    """
+
+    def __init__(self,
+        roots: str,
+    ):
+        super().__init__()
+        self.roots = roots.split(',')
+        self.instances = []
+        self.metadata = pd.DataFrame()
+
+        self._stats = {}
+        for root in self.roots:
+            key = os.path.basename(root)
+            self._stats[key] = {}
+            metadata = pd.read_csv(os.path.join(root, 'metadata.csv'))
+            self._stats[key]['Total'] = len(metadata)
+            metadata, stats = self.filter_metadata(metadata)
+            self._stats[key].update(stats)
+            self.instances.extend([(root, sha256) for sha256 in metadata['sha256'].values])
+            metadata.set_index('sha256', inplace=True)
+            self.metadata = pd.concat([self.metadata, metadata])
+
+    @abstractmethod
+    def filter_metadata(self, metadata: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]:
+        pass
+
+    @abstractmethod
+    def get_instance(self, root: str, instance: str) -> Dict[str, Any]:
+        pass
+
+    def __len__(self):
+        return len(self.instances)
+
+    def __getitem__(self, index) -> Dict[str, Any]:
+        try:
+            root, instance = self.instances[index]
+            return self.get_instance(root, instance)
+        except Exception as e:
+            # Skip unreadable instances by resampling a random one.
+            print(e)
+            return self.__getitem__(np.random.randint(0, len(self)))
+
+    def __str__(self):
+        lines = []
+        lines.append(self.__class__.__name__)
+        lines.append(f'  - Total instances: {len(self)}')
+        lines.append('  - Sources:')
+        for key, stats in self._stats.items():
+            lines.append(f'    - {key}:')
+            for k, v in stats.items():
+                lines.append(f'      - {k}: {v}')
+        return '\n'.join(lines)
+
+
+class TextConditionedMixin:
+    """
+    Mixin that attaches a randomly sampled text caption as the condition.
+    """
+    def __init__(self, roots, **kwargs):
+        super().__init__(roots, **kwargs)
+        self.captions = {}
+        for instance in self.instances:
+            sha256 = instance[1]
+            self.captions[sha256] = json.loads(self.metadata.loc[sha256]['captions'])
+
+    def filter_metadata(self, metadata):
+        metadata, stats = super().filter_metadata(metadata)
+        metadata = metadata[metadata['captions'].notna()]
+        stats['With captions'] = len(metadata)
+        return metadata, stats
+
+    def get_instance(self, root, instance):
+        pack = super().get_instance(root, instance)
+        text = np.random.choice(self.captions[instance])
+        pack['cond'] = text
+        return pack
+
+
+class ImageConditionedMixin:
+    """
+    Mixin that attaches a randomly sampled conditioning render as the condition.
+    """
+    def __init__(self, roots, *, image_size=518, **kwargs):
+        self.image_size = image_size
+        super().__init__(roots, **kwargs)
+
+    def filter_metadata(self, metadata):
+        metadata, stats = super().filter_metadata(metadata)
+        metadata = metadata[metadata['cond_rendered']]
+        stats['Cond rendered'] = len(metadata)
+        return metadata, stats
+
+    def get_instance(self, root, instance):
+        pack = super().get_instance(root, instance)
+
+        image_root = os.path.join(root, 'renders_cond', instance)
+        with open(os.path.join(image_root, 'transforms.json')) as f:
+            metadata = json.load(f)
+        n_views = len(metadata['frames'])
+        view = np.random.randint(n_views)
+        metadata = metadata['frames'][view]
+
+        image_path = os.path.join(image_root, metadata['file_path'])
+        image = Image.open(image_path)  # expected to be RGBA
+
+        # Crop around the alpha bounding box, enlarged by a 1.2x margin.
+        alpha = np.array(image.getchannel(3))
+        bbox = alpha.nonzero()
+        bbox = [bbox[1].min(), bbox[0].min(), bbox[1].max(), bbox[0].max()]
+        center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
+        hsize = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2
+        aug_size_ratio = 1.2
+        aug_hsize = hsize * aug_size_ratio
+        aug_center_offset = [0, 0]
+        aug_center = [center[0] + aug_center_offset[0], center[1] + aug_center_offset[1]]
+        aug_bbox = [int(aug_center[0] - aug_hsize), int(aug_center[1] - aug_hsize), int(aug_center[0] + aug_hsize), int(aug_center[1] + aug_hsize)]
+        image = image.crop(aug_bbox)
+
+        # Resize, then premultiply RGB by alpha to zero out the background.
+        image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
+        alpha = image.getchannel(3)
+        image = image.convert('RGB')
+        image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
+        alpha = torch.tensor(np.array(alpha)).float() / 255.0
+        image = image * alpha.unsqueeze(0)
+        pack['cond'] = image
+
+        return pack
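+
+
+# --- Usage sketch (illustrative; not referenced elsewhere in this patch) ---
+# The mixins are designed to precede a concrete StandardDatasetBase subclass
+# in the MRO so that their filter_metadata()/get_instance() wrap the
+# subclass's implementations via super(). `VoxelDataset` below is a
+# hypothetical stand-in for such a subclass:
+#
+#     class VoxelDataset(StandardDatasetBase):
+#         def filter_metadata(self, metadata):
+#             return metadata, {}
+#         def get_instance(self, root, instance):
+#             return {'sha256': instance}
+#
+#     class ImageConditionedVoxelDataset(ImageConditionedMixin, VoxelDataset):
+#         pass
+#
+#     dataset = ImageConditionedVoxelDataset('datasets/ObjaverseXL', image_size=518)
+#     print(dataset)     # per-source filtering statistics
+#     pack = dataset[0]  # pack['cond']: [3, 518, 518] float tensor in [0, 1]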
diff --git a/trellis/modules/sparse/conv/conv_spconv.py b/trellis/modules/sparse/conv/conv_spconv.py
new file mode 100755
index 0000000..524bcd4
--- /dev/null
+++ b/trellis/modules/sparse/conv/conv_spconv.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+from .. import DEBUG
+from . import SPCONV_ALGO
+# This module is only imported when the spconv backend is selected, so the
+# import can live at module level; forward() below needs the name as well,
+# which the original lazy import inside __init__ did not provide.
+import spconv.pytorch as spconv
+
+
+class SparseConv3d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None):
+        super(SparseConv3d, self).__init__()
+        algo = None
+        if SPCONV_ALGO == 'native':
+            algo = spconv.ConvAlgo.Native
+        elif SPCONV_ALGO == 'implicit_gemm':
+            algo = spconv.ConvAlgo.MaskImplicitGemm
+        if stride == 1 and (padding is None):
+            self.conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, dilation=dilation, bias=bias, indice_key=indice_key, algo=algo)
+        else:
+            self.conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias, indice_key=indice_key, algo=algo)
+        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
+        self.padding = padding
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        spatial_changed = any(s != 1 for s in self.stride) or (self.padding is not None)
+        new_data = self.conv(x.data)
+        new_shape = [x.shape[0], self.conv.out_channels]
+        new_layout = None if spatial_changed else x.layout
+
+        if spatial_changed and (x.shape[0] != 1):
+            # spconv with non-unit stride breaks the batch-major ordering of the
+            # output tensor; sort by batch index and cache the inverse permutation.
+            fwd = new_data.indices[:, 0].argsort()
+            bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(fwd.shape[0], device=fwd.device))
+            sorted_feats = new_data.features[fwd]
+            sorted_coords = new_data.indices[fwd]
+            unsorted_data = new_data
+            new_data = spconv.SparseConvTensor(sorted_feats, sorted_coords, unsorted_data.spatial_shape, unsorted_data.batch_size)  # type: ignore
+
+        out = SparseTensor(
+            new_data, shape=torch.Size(new_shape), layout=new_layout,
+            scale=tuple([s * stride for s, stride in zip(x._scale, self.stride)]),
+            spatial_cache=x._spatial_cache,
+        )
+
+        if spatial_changed and (x.shape[0] != 1):
+            out.register_spatial_cache(f'conv_{self.stride}_unsorted_data', unsorted_data)
+            out.register_spatial_cache(f'conv_{self.stride}_sort_bwd', bwd)
+
+        return out
+
+
+class SparseInverseConv3d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+        super(SparseInverseConv3d, self).__init__()
+        self.conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, bias=bias, indice_key=indice_key)
+        self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, stride, stride)
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        spatial_changed = any(s != 1 for s in self.stride)
+        if spatial_changed:
+            # Restore the original spconv ordering cached by the matching SparseConv3d.
+            data = x.get_spatial_cache(f'conv_{self.stride}_unsorted_data')
+            bwd = x.get_spatial_cache(f'conv_{self.stride}_sort_bwd')
+            data = data.replace_feature(x.feats[bwd])
+            if DEBUG:
+                assert torch.equal(data.indices, x.coords[bwd]), 'Failed to recover the original order'
+        else:
+            data = x.data
+
+        new_data = self.conv(data)
+        new_shape = [x.shape[0], self.conv.out_channels]
+        new_layout = None if spatial_changed else x.layout
+        out = SparseTensor(
+            new_data, shape=torch.Size(new_shape), layout=new_layout,
+            scale=tuple([s // stride for s, stride in zip(x._scale, self.stride)]),
+            spatial_cache=x._spatial_cache,
+        )
+        return out
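+
+
+# --- Note on the sort/unsort round trip (illustrative) ---
+# `fwd` sorts the strided outputs back to batch-major order and `bwd` is its
+# inverse permutation, cached so SparseInverseConv3d can restore spconv's
+# native ordering before the inverse convolution. The underlying identity:
+#
+#     fwd = coords[:, 0].argsort()
+#     bwd = torch.zeros_like(fwd).scatter_(0, fwd, torch.arange(len(fwd)))
+#     assert torch.equal(tensor[fwd][bwd], tensor)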
diff --git a/trellis/modules/sparse/conv/conv_torchsparse.py b/trellis/modules/sparse/conv/conv_torchsparse.py
new file mode 100755
index 0000000..1d61258
--- /dev/null
+++ b/trellis/modules/sparse/conv/conv_torchsparse.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+from .. import SparseTensor
+# This module is only imported when the torchsparse backend is selected, so
+# the import can live at module level instead of the guarded lazy import
+# that the original code repeated in each __init__.
+import torchsparse
+
+
+class SparseConv3d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+        super(SparseConv3d, self).__init__()
+        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias)
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        out = self.conv(x.data)
+        new_shape = [x.shape[0], self.conv.out_channels]
+        out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
+        out._spatial_cache = x._spatial_cache
+        out._scale = tuple([s * stride for s, stride in zip(x._scale, self.conv.stride)])
+        return out
+
+
+class SparseInverseConv3d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, bias=True, indice_key=None):
+        super(SparseInverseConv3d, self).__init__()
+        self.conv = torchsparse.nn.Conv3d(in_channels, out_channels, kernel_size, stride, 0, dilation, bias, transposed=True)
+
+    def forward(self, x: SparseTensor) -> SparseTensor:
+        out = self.conv(x.data)
+        new_shape = [x.shape[0], self.conv.out_channels]
+        out = SparseTensor(out, shape=torch.Size(new_shape), layout=x.layout if all(s == 1 for s in self.conv.stride) else None)
+        out._spatial_cache = x._spatial_cache
+        out._scale = tuple([s // stride for s, stride in zip(x._scale, self.conv.stride)])
+        return out
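+
+
+# --- Usage sketch (illustrative; assumes the torchsparse backend is active) ---
+#     down = SparseConv3d(32, 64, kernel_size=2, stride=2)
+#     up = SparseInverseConv3d(64, 32, kernel_size=2, stride=2)
+#     y = down(x)  # x: SparseTensor; _scale goes (1, 1, 1) -> (2, 2, 2)
+#     z = up(y)    # _scale restored to (1, 1, 1) via floor division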
diff --git a/trellis/representations/mesh/cube2mesh.py b/trellis/representations/mesh/cube2mesh.py
new file mode 100644
index 0000000..44e8776
--- /dev/null
+++ b/trellis/representations/mesh/cube2mesh.py
@@ -0,0 +1,143 @@
+import torch
+from ...modules.sparse import SparseTensor
+from easydict import EasyDict as edict
+from .utils_cube import *
+from .flexicubes.flexicubes import FlexiCubes
+
+
+class MeshExtractResult:
+    def __init__(self,
+        vertices,
+        faces,
+        vertex_attrs=None,
+        res=64
+    ):
+        self.vertices = vertices
+        self.faces = faces.long()
+        self.vertex_attrs = vertex_attrs
+        self.face_normal = self.comput_face_normals(vertices, faces)
+        self.res = res
+        self.success = (vertices.shape[0] != 0 and faces.shape[0] != 0)
+
+        # training only
+        self.tsdf_v = None
+        self.tsdf_s = None
+        self.reg_loss = None
+
+    def comput_face_normals(self, verts, faces):
+        i0 = faces[..., 0].long()
+        i1 = faces[..., 1].long()
+        i2 = faces[..., 2].long()
+
+        v0 = verts[i0, :]
+        v1 = verts[i1, :]
+        v2 = verts[i2, :]
+        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+        face_normals = torch.nn.functional.normalize(face_normals, dim=1)
+        return face_normals[:, None, :].repeat(1, 3, 1)
+
+    def comput_v_normals(self, verts, faces):
+        i0 = faces[..., 0].long()
+        i1 = faces[..., 1].long()
+        i2 = faces[..., 2].long()
+
+        v0 = verts[i0, :]
+        v1 = verts[i1, :]
+        v2 = verts[i2, :]
+        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
+        # Accumulate area-weighted face normals onto each vertex, then normalize.
+        v_normals = torch.zeros_like(verts)
+        v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals)
+        v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals)
+        v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals)
+
+        v_normals = torch.nn.functional.normalize(v_normals, dim=1)
+        return v_normals
+
+
+class SparseFeatures2Mesh:
+    def __init__(self, device="cuda", res=64, use_color=True):
+        '''
+        A model that extracts a mesh from sparse feature structures using FlexiCubes.
+        '''
+        super().__init__()
+        self.device = device
+        self.res = res
+        self.mesh_extractor = FlexiCubes(device=device)
+        self.sdf_bias = -1.0 / res
+        verts, cube = construct_dense_grid(self.res, self.device)
+        self.reg_c = cube.to(self.device)
+        self.reg_v = verts.to(self.device)
+        self.use_color = use_color
+        self._calc_layout()
+
+    def _calc_layout(self):
+        LAYOUTS = {
+            'sdf': {'shape': (8, 1), 'size': 8},
+            'deform': {'shape': (8, 3), 'size': 8 * 3},
+            'weights': {'shape': (21,), 'size': 21}
+        }
+        if self.use_color:
+            # 6 channels per corner: 3 for color and 3 for the normal map.
+            LAYOUTS['color'] = {'shape': (8, 6,), 'size': 8 * 6}
+        self.layouts = edict(LAYOUTS)
+        start = 0
+        for k, v in self.layouts.items():
+            v['range'] = (start, start + v['size'])
+            start += v['size']
+        self.feats_channels = start
+
+    def get_layout(self, feats: torch.Tensor, name: str):
+        if name not in self.layouts:
+            return None
+        return feats[:, self.layouts[name]['range'][0]:self.layouts[name]['range'][1]].reshape(-1, *self.layouts[name]['shape'])
+
+    def __call__(self, cubefeats: SparseTensor, training=False):
+        """
+        Generate a mesh from the given sparse voxel features.
+
+        Args:
+            cubefeats: SparseTensor whose features pack, per active voxel, the
+                layout computed in _calc_layout: 8x1 corner SDF values, 8x3
+                corner deformations, 21 FlexiCubes weights, and optionally
+                8x6 corner colors.
+            training: whether to keep the regularization loss and TSDF values
+                needed for training.
+
+        Returns:
+            A MeshExtractResult; when training, its reg_loss, tsdf_v, and
+            tsdf_s fields are populated as well.
+        """
+        coords = cubefeats.coords[:, 1:]
+        feats = cubefeats.feats
+
+        sdf, deform, color, weights = [self.get_layout(feats, name) for name in ['sdf', 'deform', 'color', 'weights']]
+        sdf += self.sdf_bias  # bias the iso-surface slightly inward
+        v_attrs = [sdf, deform, color] if self.use_color else [sdf, deform]
+        v_pos, v_attrs, reg_loss = sparse_cube2verts(coords, torch.cat(v_attrs, dim=-1), training=training)
+        v_attrs_d = get_dense_attrs(v_pos, v_attrs, res=self.res + 1, sdf_init=True)
+        weights_d = get_dense_attrs(coords, weights, res=self.res, sdf_init=False)
+        if self.use_color:
+            sdf_d, deform_d, colors_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4], v_attrs_d[..., 4:]
+        else:
+            sdf_d, deform_d = v_attrs_d[..., 0], v_attrs_d[..., 1:4]
+            colors_d = None
+
+        x_nx3 = get_defomed_verts(self.reg_v, deform_d, self.res)
+
+        vertices, faces, L_dev, colors = self.mesh_extractor(
+            voxelgrid_vertices=x_nx3,
+            scalar_field=sdf_d,
+            cube_idx=self.reg_c,
+            resolution=self.res,
+            beta=weights_d[:, :12],
+            alpha=weights_d[:, 12:20],
+            gamma_f=weights_d[:, 20],
+            voxelgrid_colors=colors_d,
+            training=training)
+
+        mesh = MeshExtractResult(vertices=vertices, faces=faces, vertex_attrs=colors, res=self.res)
+        if training:
+            if mesh.success:
+                reg_loss += L_dev.mean() * 0.5
+            reg_loss += (weights[:, :20]).abs().mean() * 0.2
+            mesh.reg_loss = reg_loss
+            mesh.tsdf_v = get_defomed_verts(v_pos, v_attrs[:, 1:4], self.res)
+            mesh.tsdf_s = v_attrs[:, 0]
+        return mesh
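+
+
+# --- Feature layout arithmetic (illustrative) ---
+# With use_color=True, each active voxel's feature vector concatenates:
+#     sdf      8 x 1 = 8   (one SDF value per cube corner)
+#     deform   8 x 3 = 24  (xyz offset per corner)
+#     weights  21          (12 beta + 8 alpha + 1 gamma_f FlexiCubes params)
+#     color    8 x 6 = 48  (3 color + 3 normal channels per corner)
+# so SparseFeatures2Mesh().feats_channels == 8 + 24 + 21 + 48 == 101
+# (53 without color).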
diff --git a/trellis/utils/data_utils.py b/trellis/utils/data_utils.py
new file mode 100644
index 0000000..805b6cc
--- /dev/null
+++ b/trellis/utils/data_utils.py
@@ -0,0 +1,226 @@
+from typing import *
+import math
+import torch
+import numpy as np
+from torch.utils.data import Sampler, Dataset, DataLoader, DistributedSampler
+import torch.distributed as dist
+
+
+def recursive_to_device(
+    data: Any,
+    device: torch.device,
+    non_blocking: bool = False,
+) -> Any:
+    """
+    Recursively move all tensors in a data structure to a device.
+    """
+    if hasattr(data, "to"):
+        return data.to(device, non_blocking=non_blocking)
+    elif isinstance(data, (list, tuple)):
+        return type(data)(recursive_to_device(d, device, non_blocking) for d in data)
+    elif isinstance(data, dict):
+        return {k: recursive_to_device(v, device, non_blocking) for k, v in data.items()}
+    else:
+        return data
+
+
+def load_balanced_group_indices(
+    load: List[int],
+    num_groups: int,
+    equal_size: bool = False,
+) -> List[List[int]]:
+    """
+    Split indices into groups with balanced load.
+
+    Greedy assignment: indices are visited in decreasing order of load and
+    each goes to the group with the smallest accumulated load. With
+    equal_size=True, a group that reaches its quota is closed off by setting
+    its load to infinity.
+    """
+    if equal_size:
+        group_size = len(load) // num_groups
+    indices = np.argsort(load)[::-1]
+    groups = [[] for _ in range(num_groups)]
+    group_load = np.zeros(num_groups)
+    for idx in indices:
+        min_group_idx = np.argmin(group_load)
+        groups[min_group_idx].append(idx)
+        if equal_size and len(groups[min_group_idx]) == group_size:
+            group_load[min_group_idx] = float('inf')
+        else:
+            group_load[min_group_idx] += load[idx]
+    return groups
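+
+
+# --- Example (illustrative) ---
+#     groups = load_balanced_group_indices([5, 1, 4, 2, 3, 3], num_groups=2)
+#     # -> [[0, 4, 1], [2, 5, 3]]: group loads 5+3+1 = 9 and 4+3+2 = 9
+#     # (tie-breaking among equal loads may reorder indices within groups)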
+
+
+def cycle(data_loader: DataLoader) -> Iterator:
+    """
+    Iterate over a DataLoader indefinitely, keeping any ResumableSampler's
+    epoch/index state in sync so training can be checkpointed mid-epoch.
+    """
+    while True:
+        for data in data_loader:
+            if isinstance(data_loader.sampler, ResumableSampler):
+                data_loader.sampler.idx += data_loader.batch_size  # type: ignore[attr-defined]
+            yield data
+        if isinstance(data_loader.sampler, DistributedSampler):
+            data_loader.sampler.epoch += 1
+        if isinstance(data_loader.sampler, ResumableSampler):
+            data_loader.sampler.epoch += 1
+            data_loader.sampler.idx = 0
+
+
+class ResumableSampler(Sampler):
+    """
+    Distributed sampler that is resumable.
+
+    Args:
+        dataset: Dataset used for sampling.
+        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
+            indices.
+        seed (int, optional): random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Default: ``0``.
+        drop_last (bool, optional): if ``True``, then the sampler will drop the
+            tail of the data to make it evenly divisible across the number of
+            replicas. If ``False``, the sampler will add extra indices to make
+            the data evenly divisible across the replicas. Default: ``False``.
+
+    Note:
+        The rank and world size are retrieved from the current distributed
+        group rather than passed as arguments.
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+    ) -> None:
+        self.dataset = dataset
+        self.epoch = 0
+        self.idx = 0
+        self.drop_last = drop_last
+        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
+        self.rank = dist.get_rank() if dist.is_initialized() else 0
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.world_size != 0:  # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.world_size) / self.world_size  # type: ignore[arg-type]
+            )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.world_size)  # type: ignore[arg-type]
+        self.total_size = self.num_samples * self.world_size
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self) -> Iterator:
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
+        else:
+            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[
+                    :padding_size
+                ]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.world_size]
+
+        # resume from previous state
+        indices = indices[self.idx:]
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def state_dict(self) -> Dict[str, int]:
+        return {
+            'epoch': self.epoch,
+            'idx': self.idx,
+        }
+
+    def load_state_dict(self, state_dict):
+        self.epoch = state_dict['epoch']
+        self.idx = state_dict['idx']
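+
+
+# --- Resumable training sketch (illustrative) ---
+#     sampler = ResumableSampler(dataset, shuffle=True, seed=0)
+#     loader = DataLoader(dataset, batch_size=8, sampler=sampler)
+#     stream = cycle(loader)            # keeps sampler.idx in sync
+#     batch = next(stream)
+#     state = sampler.state_dict()      # {'epoch': 0, 'idx': 8}
+#     ...                               # save with the checkpoint, restart
+#     sampler.load_state_dict(state)    # resumes mid-epoch at index 8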
+
+
+class BalancedResumableSampler(ResumableSampler):
+    """
+    Distributed sampler that is resumable and balances the load among the processes.
+
+    Args:
+        dataset: Dataset used for sampling. Must expose a ``loads`` attribute
+            giving the relative cost of each instance.
+        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
+            indices.
+        seed (int, optional): random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Default: ``0``.
+        drop_last (bool, optional): if ``True``, then the sampler will drop the
+            tail of the data to make it evenly divisible across the number of
+            replicas. If ``False``, the sampler will add extra indices to make
+            the data evenly divisible across the replicas. Default: ``False``.
+        batch_size (int, optional): number of samples drawn per rank per batch;
+            load balancing operates on blocks of ``batch_size * world_size``
+            indices. Default: ``1``.
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+        batch_size: int = 1,
+    ) -> None:
+        assert hasattr(dataset, 'loads'), 'Dataset must have "loads" attribute to use BalancedResumableSampler'
+        super().__init__(dataset, shuffle, seed, drop_last)
+        self.batch_size = batch_size
+        self.loads = dataset.loads
+
+    def __iter__(self) -> Iterator:
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
+        else:
+            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[
+                    :padding_size
+                ]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # balance load among processes
+        num_batches = len(indices) // (self.batch_size * self.world_size)
+        balanced_indices = []
+        for i in range(num_batches):
+            start_idx = i * self.batch_size * self.world_size
+            end_idx = (i + 1) * self.batch_size * self.world_size
+            batch_indices = indices[start_idx:end_idx]
+            batch_loads = [self.loads[idx] for idx in batch_indices]
+            groups = load_balanced_group_indices(batch_loads, self.world_size, equal_size=True)
+            balanced_indices.extend([batch_indices[j] for j in groups[self.rank]])
+
+        # resume from previous state
+        indices = balanced_indices[self.idx:]
+
+        return iter(indices)
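+
+
+# --- Balanced sampling sketch (illustrative; dataset must expose `loads`) ---
+#     dataset.loads = [cost_estimate(i) for i in range(len(dataset))]  # any cost proxy
+#     sampler = BalancedResumableSampler(dataset, batch_size=4)
+#     loader = DataLoader(dataset, batch_size=4, sampler=sampler)
+#     # Each block of world_size * 4 shuffled indices is regrouped with
+#     # load_balanced_group_indices(equal_size=True) so every rank gets a
+#     # batch of near-equal total load.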