zcr
2026-03-17 11:38:16 +08:00
parent 0571f65793
commit 7531afd162
16 changed files with 2736 additions and 0 deletions

View File

@@ -0,0 +1,217 @@
import json
import os
from typing import *
import numpy as np
import torch
import utils3d.torch
from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
from ..modules.sparse.basic import SparseTensor
from .. import models
from ..utils.render_utils import get_renderer
from ..utils.data_utils import load_balanced_group_indices
class SLatVisMixin:
def __init__(
self,
*args,
pretrained_slat_dec: str = 'microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
slat_dec_path: Optional[str] = None,
slat_dec_ckpt: Optional[str] = None,
**kwargs
):
super().__init__(*args, **kwargs)
self.slat_dec = None
self.pretrained_slat_dec = pretrained_slat_dec
self.slat_dec_path = slat_dec_path
self.slat_dec_ckpt = slat_dec_ckpt
def _loading_slat_dec(self):
if self.slat_dec is not None:
return
if self.slat_dec_path is not None:
with open(os.path.join(self.slat_dec_path, 'config.json'), 'r') as f:
cfg = json.load(f)
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
else:
decoder = models.from_pretrained(self.pretrained_slat_dec)
self.slat_dec = decoder.cuda().eval()
def _delete_slat_dec(self):
del self.slat_dec
self.slat_dec = None
@torch.no_grad()
def decode_latent(self, z, batch_size=4):
self._loading_slat_dec()
reps = []
if self.normalization is not None:
z = z * self.std.to(z.device) + self.mean.to(z.device)
for i in range(0, z.shape[0], batch_size):
reps.append(self.slat_dec(z[i:i+batch_size]))
reps = sum(reps, [])
self._delete_slat_dec()
return reps
@torch.no_grad()
def visualize_sample(self, x_0: Union[SparseTensor, dict]):
x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
reps = self.decode_latent(x_0.cuda())
# Build camera
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
yaws = [y + yaws_offset for y in yaws]
pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
exts = []
ints = []
for yaw, pitch in zip(yaws, pitches):
orig = torch.tensor([
np.sin(yaw) * np.cos(pitch),
np.cos(yaw) * np.cos(pitch),
np.sin(pitch),
]).float().cuda() * 2
fov = torch.deg2rad(torch.tensor(40)).cuda()
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
exts.append(extrinsics)
ints.append(intrinsics)
renderer = get_renderer(reps[0])
images = []
for representation in reps:
image = torch.zeros(3, 1024, 1024).cuda()
tile = [2, 2]
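# Tile the four rendered 512x512 views into a 2x2 grid on the 1024x1024 canvas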
for j, (ext, intr) in enumerate(zip(exts, ints)):
res = renderer.render(representation, ext, intr)
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
images.append(image)
images = torch.stack(images)
return images
class SLat(SLatVisMixin, StandardDatasetBase):
"""
Structured latent dataset.
Args:
roots (str): paths to the dataset
latent_model (str): name of the latent model
min_aesthetic_score (float): minimum aesthetic score
max_num_voxels (int): maximum number of voxels
normalization (dict): normalization stats
pretrained_slat_dec (str): name of the pretrained slat decoder
slat_dec_path (str): path to the slat decoder; if given, overrides pretrained_slat_dec
slat_dec_ckpt (str): name of the slat decoder checkpoint
"""
def __init__(self,
roots: str,
*,
latent_model: str,
min_aesthetic_score: float = 5.0,
max_num_voxels: int = 32768,
normalization: Optional[dict] = None,
pretrained_slat_dec: str = 'microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
slat_dec_path: Optional[str] = None,
slat_dec_ckpt: Optional[str] = None,
):
self.normalization = normalization
self.latent_model = latent_model
self.min_aesthetic_score = min_aesthetic_score
self.max_num_voxels = max_num_voxels
self.value_range = (0, 1)
super().__init__(
roots,
pretrained_slat_dec=pretrained_slat_dec,
slat_dec_path=slat_dec_path,
slat_dec_ckpt=slat_dec_ckpt,
)
self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]
if self.normalization is not None:
self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata[f'latent_{self.latent_model}']]
stats['With latent'] = len(metadata)
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
return metadata, stats
def get_instance(self, root, instance):
data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
coords = torch.tensor(data['coords']).int()
feats = torch.tensor(data['feats']).float()
if self.normalization is not None:
feats = (feats - self.mean) / self.std
return {
'coords': coords,
'feats': feats,
}
@staticmethod
def collate_fn(batch, split_size=None):
if split_size is None:
group_idx = [list(range(len(batch)))]
else:
group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
packs = []
for group in group_idx:
sub_batch = [batch[i] for i in group]
pack = {}
coords = []
feats = []
layout = []
start = 0
for i, b in enumerate(sub_batch):
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
feats.append(b['feats'])
layout.append(slice(start, start + b['coords'].shape[0]))
start += b['coords'].shape[0]
coords = torch.cat(coords)
feats = torch.cat(feats)
pack['x_0'] = SparseTensor(
coords=coords,
feats=feats,
)
pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
pack['x_0'].register_spatial_cache('layout', layout)
# collate other data
keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
for k in keys:
if isinstance(sub_batch[0][k], torch.Tensor):
pack[k] = torch.stack([b[k] for b in sub_batch])
elif isinstance(sub_batch[0][k], list):
pack[k] = sum([b[k] for b in sub_batch], [])
else:
pack[k] = [b[k] for b in sub_batch]
packs.append(pack)
if split_size is None:
return packs[0]
return packs
class TextConditionedSLat(TextConditionedMixin, SLat):
"""
Text conditioned structured latent dataset
"""
pass
class ImageConditionedSLat(ImageConditionedMixin, SLat):
"""
Image conditioned structured latent dataset
"""
pass
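A minimal usage sketch for the datasets above; the root path, latent model name, and normalization stats are hypothetical placeholders:

from torch.utils.data import DataLoader
dataset = SLat(
'datasets/objaverse', # hypothetical root
latent_model='slat_enc', # hypothetical latent model name
normalization={'mean': [0.0] * 8, 'std': [1.0] * 8}, # hypothetical stats
)
loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)
batch = next(iter(loader)) # {'x_0': SparseTensor, ...}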

View File

@@ -0,0 +1,160 @@
import os
from PIL import Image
import json
import numpy as np
import torch
import utils3d.torch
from ..modules.sparse.basic import SparseTensor
from .components import StandardDatasetBase
class SLat2Render(StandardDatasetBase):
"""
Dataset for Structured Latent and rendered images.
Args:
roots (str): paths to the dataset
image_size (int): size of the image
latent_model (str): latent model name
min_aesthetic_score (float): minimum aesthetic score
max_num_voxels (int): maximum number of voxels
"""
def __init__(
self,
roots: str,
image_size: int,
latent_model: str,
min_aesthetic_score: float = 5.0,
max_num_voxels: int = 32768,
):
self.image_size = image_size
self.latent_model = latent_model
self.min_aesthetic_score = min_aesthetic_score
self.max_num_voxels = max_num_voxels
self.value_range = (0, 1)
super().__init__(roots)
def filter_metadata(self, metadata):
stats = {}
metadata = metadata[metadata[f'latent_{self.latent_model}']]
stats['With latent'] = len(metadata)
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
return metadata, stats
def _get_image(self, root, instance):
with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
metadata = json.load(f)
n_views = len(metadata['frames'])
view = np.random.randint(n_views)
metadata = metadata['frames'][view]
fov = metadata['camera_angle_x']
intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
c2w = torch.tensor(metadata['transform_matrix'])
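# transforms.json stores cameras in the OpenGL/Blender convention; flip the y and z axes to obtain an OpenCV-style camera-to-world matrix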
c2w[:3, 1:3] *= -1
extrinsics = torch.inverse(c2w)
image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
image = Image.open(image_path)
alpha = image.getchannel(3)
image = image.convert('RGB')
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
alpha = torch.tensor(np.array(alpha)).float() / 255.0
return {
'image': image,
'alpha': alpha,
'extrinsics': extrinsics,
'intrinsics': intrinsics,
}
def _get_latent(self, root, instance):
data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
coords = torch.tensor(data['coords']).int()
feats = torch.tensor(data['feats']).float()
return {
'coords': coords,
'feats': feats,
}
@torch.no_grad()
def visualize_sample(self, sample: dict):
return sample['image']
@staticmethod
def collate_fn(batch):
pack = {}
coords = []
for i, b in enumerate(batch):
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
coords = torch.cat(coords)
feats = torch.cat([b['feats'] for b in batch])
pack['latents'] = SparseTensor(
coords=coords,
feats=feats,
)
# collate other data
keys = [k for k in batch[0].keys() if k not in ['coords', 'feats']]
for k in keys:
if isinstance(batch[0][k], torch.Tensor):
pack[k] = torch.stack([b[k] for b in batch])
elif isinstance(batch[0][k], list):
pack[k] = sum([b[k] for b in batch], [])
else:
pack[k] = [b[k] for b in batch]
return pack
def get_instance(self, root, instance):
image = self._get_image(root, instance)
latent = self._get_latent(root, instance)
return {
**image,
**latent,
}
class Slat2RenderGeo(SLat2Render):
def __init__(
self,
roots: str,
image_size: int,
latent_model: str,
min_aesthetic_score: float = 5.0,
max_num_voxels: int = 32768,
):
super().__init__(
roots,
image_size,
latent_model,
min_aesthetic_score,
max_num_voxels,
)
def _get_geo(self, root, instance):
verts, face = utils3d.io.read_ply(os.path.join(root, 'renders', instance, 'mesh.ply'))
mesh = {
"vertices" : torch.from_numpy(verts),
"faces" : torch.from_numpy(face),
}
return {
"mesh" : mesh,
}
def get_instance(self, root, instance):
image = self._get_image(root, instance)
latent = self._get_latent(root, instance)
geo = self._get_geo(root, instance)
return {
**image,
**latent,
**geo,
}
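A short consumption sketch for the datasets above (hypothetical arguments; indexing assumes the usual StandardDatasetBase __getitem__):

from torch.utils.data import DataLoader
dataset = SLat2Render('datasets/objaverse', image_size=512, latent_model='slat_enc') # hypothetical args
loader = DataLoader(dataset, batch_size=4, collate_fn=SLat2Render.collate_fn)
batch = next(iter(loader)) # 'latents' (SparseTensor), 'image', 'alpha', 'extrinsics', 'intrinsics'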

View File

@@ -0,0 +1,276 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
from ..modules.transformer import AbsolutePositionEmbedder
from ..modules.norm import LayerNorm32
from ..modules import sparse as sp
from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
from .sparse_structure_flow import TimestepEmbedder
from .sparse_elastic_mixin import SparseTransformerElasticMixin
class SparseResBlock3d(nn.Module):
def __init__(
self,
channels: int,
emb_channels: int,
out_channels: Optional[int] = None,
downsample: bool = False,
upsample: bool = False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.out_channels = out_channels or channels
self.downsample = downsample
self.upsample = upsample
assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
)
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
self.updown = None
if self.downsample:
self.updown = sp.SparseDownsample(2)
elif self.upsample:
self.updown = sp.SparseUpsample(2)
def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
if self.updown is not None:
x = self.updown(x)
return x
def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
emb_out = self.emb_layers(emb).type(x.dtype)
scale, shift = torch.chunk(emb_out, 2, dim=1)
x = self._updown(x)
h = x.replace(self.norm1(x.feats))
h = h.replace(F.silu(h.feats))
h = self.conv1(h)
h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
h = h.replace(F.silu(h.feats))
h = self.conv2(h)
h = h + self.skip_connection(x)
return h
class SLatFlowModel(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
model_channels: int,
cond_channels: int,
out_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
patch_size: int = 2,
num_io_res_blocks: int = 2,
io_block_channels: List[int] = None,
pe_mode: Literal["ape", "rope"] = "ape",
use_fp16: bool = False,
use_checkpoint: bool = False,
use_skip_connection: bool = True,
share_mod: bool = False,
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
):
super().__init__()
self.resolution = resolution
self.in_channels = in_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.patch_size = patch_size
self.num_io_res_blocks = num_io_res_blocks
self.io_block_channels = io_block_channels
self.pe_mode = pe_mode
self.use_fp16 = use_fp16
self.use_checkpoint = use_checkpoint
self.use_skip_connection = use_skip_connection
self.share_mod = share_mod
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = torch.float16 if use_fp16 else torch.float32
if self.io_block_channels is not None:
assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
assert np.log2(patch_size) == len(io_block_channels), "Number of IO block channels must match the number of downsampling stages"
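# e.g. patch_size=2 implies a single 2x down/upsampling stage, so io_block_channels must hold exactly one entry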
self.t_embedder = TimestepEmbedder(model_channels)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(model_channels, 6 * model_channels, bias=True)
)
if pe_mode == "ape":
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0])
self.input_blocks = nn.ModuleList([])
if io_block_channels is not None:
for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
self.input_blocks.extend([
SparseResBlock3d(
chs,
model_channels,
out_channels=chs,
)
for _ in range(num_io_res_blocks-1)
])
self.input_blocks.append(
SparseResBlock3d(
chs,
model_channels,
out_channels=next_chs,
downsample=True,
)
)
self.blocks = nn.ModuleList([
ModulatedSparseTransformerCrossBlock(
model_channels,
cond_channels,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
attn_mode='full',
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
share_mod=self.share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
)
for _ in range(num_blocks)
])
self.out_blocks = nn.ModuleList([])
if io_block_channels is not None:
for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
self.out_blocks.append(
SparseResBlock3d(
prev_chs * 2 if self.use_skip_connection else prev_chs,
model_channels,
out_channels=chs,
upsample=True,
)
)
self.out_blocks.extend([
SparseResBlock3d(
chs * 2 if self.use_skip_connection else chs,
model_channels,
out_channels=chs,
)
for _ in range(num_io_res_blocks-1)
])
self.out_layer = sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], out_channels)
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
"""
Return the device of the model.
"""
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.blocks.apply(convert_module_to_f16)
self.out_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.blocks.apply(convert_module_to_f32)
self.out_blocks.apply(convert_module_to_f32)
def initialize_weights(self) -> None:
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
if self.share_mod:
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
else:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
h = self.input_layer(x).type(self.dtype)
t_emb = self.t_embedder(t)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = t_emb.type(self.dtype)
cond = cond.type(self.dtype)
skips = []
# pack with input blocks
for block in self.input_blocks:
h = block(h, t_emb)
skips.append(h.feats)
if self.pe_mode == "ape":
h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond)
# unpack with output blocks
for block, skip in zip(self.out_blocks, reversed(skips)):
if self.use_skip_connection:
h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
else:
h = block(h, t_emb)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h.type(x.dtype))
return h
class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
"""
SLat Flow Model with elastic memory management.
Used for training with low VRAM.
"""
pass
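A minimal instantiation sketch for the model above; every size is a hypothetical placeholder chosen to satisfy the constructor's asserts:

model = SLatFlowModel(
resolution=64,
in_channels=8,
model_channels=768,
cond_channels=1024,
out_channels=8,
num_blocks=12,
patch_size=2,
num_io_res_blocks=2,
io_block_channels=[128], # len == log2(patch_size) == 1
)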

trellis/modules/norm.py
View File

@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.float()).type(x.dtype)
class GroupNorm32(nn.GroupNorm):
"""
A GroupNorm layer that converts to float32 before the forward pass.
"""
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.float()).type(x.dtype)
class ChannelLayerNorm32(LayerNorm32):
def forward(self, x: torch.Tensor) -> torch.Tensor:
DIM = x.dim()
x = x.permute(0, *range(2, DIM), 1).contiguous()
x = super().forward(x)
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
return x
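A small sketch of ChannelLayerNorm32 in use: it normalizes a channels-first tensor over its channel dimension by permuting to channels-last, running LayerNorm in float32, and permuting back:

x = torch.randn(2, 16, 8, 8, dtype=torch.float16) # (N, C, H, W)
norm = ChannelLayerNorm32(16)
y = norm(x) # (2, 16, 8, 8), float16, normalized over the 16 channels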

View File

@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
from . import SparseTensor
__all__ = [
'SparseReLU',
'SparseSiLU',
'SparseGELU',
'SparseActivation'
]
class SparseReLU(nn.ReLU):
def forward(self, input: SparseTensor) -> SparseTensor:
return input.replace(super().forward(input.feats))
class SparseSiLU(nn.SiLU):
def forward(self, input: SparseTensor) -> SparseTensor:
return input.replace(super().forward(input.feats))
class SparseGELU(nn.GELU):
def forward(self, input: SparseTensor) -> SparseTensor:
return input.replace(super().forward(input.feats))
class SparseActivation(nn.Module):
def __init__(self, activation: nn.Module):
super().__init__()
self.activation = activation
def forward(self, input: SparseTensor) -> SparseTensor:
return input.replace(self.activation(input.feats))
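SparseActivation lifts any elementwise module to SparseTensor inputs; a two-line sketch, assuming x is a SparseTensor:

act = SparseActivation(nn.LeakyReLU(0.2))
y = act(x) # same coords and layout, activated feats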

View File

@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
from . import SparseTensor
from . import DEBUG
__all__ = [
'SparseGroupNorm',
'SparseLayerNorm',
'SparseGroupNorm32',
'SparseLayerNorm32',
]
class SparseGroupNorm(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
def forward(self, input: SparseTensor) -> SparseTensor:
nfeats = torch.zeros_like(input.feats)
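# Sparse batches carry a variable number of voxels per instance, so normalize each instance separately via its layout slice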
for k in range(input.shape[0]):
if DEBUG:
assert (input.coords[input.layout[k], 0] == k).all(), "SparseGroupNorm: batch index mismatch"
bfeats = input.feats[input.layout[k]]
bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
bfeats = super().forward(bfeats)
bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
nfeats[input.layout[k]] = bfeats
return input.replace(nfeats)
class SparseLayerNorm(nn.LayerNorm):
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
def forward(self, input: SparseTensor) -> SparseTensor:
nfeats = torch.zeros_like(input.feats)
for k in range(input.shape[0]):
bfeats = input.feats[input.layout[k]]
bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
bfeats = super().forward(bfeats)
bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
nfeats[input.layout[k]] = bfeats
return input.replace(nfeats)
class SparseGroupNorm32(SparseGroupNorm):
"""
A GroupNorm layer that converts to float32 before the forward pass.
"""
def forward(self, x: SparseTensor) -> SparseTensor:
return super().forward(x.float()).type(x.dtype)
class SparseLayerNorm32(SparseLayerNorm):
"""
A LayerNorm layer that converts to float32 before the forward pass.
"""
def forward(self, x: SparseTensor) -> SparseTensor:
return super().forward(x.float()).type(x.dtype)

View File

@@ -0,0 +1,300 @@
import numpy as np
import torch
import torch.nn.functional as F
import math
import cv2
from scipy.stats import qmc
from easydict import EasyDict as edict
from ..representations.octree import DfsOctree
def intrinsics_to_projection(
intrinsics: torch.Tensor,
near: float,
far: float,
) -> torch.Tensor:
"""
OpenCV intrinsics to OpenGL perspective matrix
Args:
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
near (float): near plane to clip
far (float): far plane to clip
Returns:
(torch.Tensor): [4, 4] OpenGL perspective matrix
"""
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
ret[0, 0] = 2 * fx
ret[1, 1] = 2 * fy
ret[0, 2] = 2 * cx - 1
ret[1, 2] = - 2 * cy + 1
ret[2, 2] = far / (far - near)
ret[2, 3] = near * far / (near - far)
ret[3, 2] = 1.
return ret
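# A usage sketch (near/far values mirror the trainers in this commit; utils3d builds the normalized intrinsics as elsewhere in the codebase):
# fov = torch.deg2rad(torch.tensor(40.0))
# K = utils3d.torch.intrinsics_from_fov_xy(fov, fov) # normalized pinhole intrinsics
# P = intrinsics_to_projection(K, near=0.8, far=1.6) # [4, 4] OpenGL projection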
def render(viewpoint_camera, octree: DfsOctree, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, used_rank=None, colors_overwrite=None, aux=None, halton_sampler=None):
"""
Render the scene.
Background tensor (bg_color) must be on GPU!
"""
# lazy import
if 'OctreeTrivecRasterizer' not in globals():
from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer
# Set up rasterization configuration
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
raster_settings = edict(
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
bg=bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform,
sh_degree=octree.active_sh_degree,
campos=viewpoint_camera.camera_center,
with_distloss=pipe.with_distloss,
jitter=pipe.jitter,
debug=pipe.debug,
)
positions = octree.get_xyz
if octree.primitive == "voxel":
densities = octree.get_density
elif octree.primitive == "gaussian":
opacities = octree.get_opacity
elif octree.primitive == "trivec":
trivecs = octree.get_trivec
densities = octree.get_density
raster_settings.density_shift = octree.density_shift
elif octree.primitive == "decoupoly":
decoupolys_V, decoupolys_g = octree.get_decoupoly
densities = octree.get_density
raster_settings.density_shift = octree.density_shift
else:
raise ValueError(f"Unknown primitive {octree.primitive}")
depths = octree.get_depth
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
colors_precomp = None
shs = octree.get_features
if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None:
colors_precomp = colors_overwrite
shs = None
ret = edict()
if octree.primitive == "voxel":
renderer = OctreeVoxelRasterizer(raster_settings=raster_settings)
rgb, depth, alpha, distloss = renderer(
positions = positions,
densities = densities,
shs = shs,
colors_precomp = colors_precomp,
depths = depths,
aabb = octree.aabb,
aux = aux,
)
ret['rgb'] = rgb
ret['depth'] = depth
ret['alpha'] = alpha
ret['distloss'] = distloss
elif octree.primitive == "gaussian":
renderer = OctreeGaussianRasterizer(raster_settings=raster_settings)
rgb, depth, alpha = renderer(
positions = positions,
opacities = opacities,
shs = shs,
colors_precomp = colors_precomp,
depths = depths,
aabb = octree.aabb,
aux = aux,
)
ret['rgb'] = rgb
ret['depth'] = depth
ret['alpha'] = alpha
elif octree.primitive == "trivec":
raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1]
renderer = OctreeTrivecRasterizer(raster_settings=raster_settings)
rgb, depth, alpha, percent_depth = renderer(
positions = positions,
trivecs = trivecs,
densities = densities,
shs = shs,
colors_precomp = colors_precomp,
colors_overwrite = colors_overwrite,
depths = depths,
aabb = octree.aabb,
aux = aux,
halton_sampler = halton_sampler,
)
ret['percent_depth'] = percent_depth
ret['rgb'] = rgb
ret['depth'] = depth
ret['alpha'] = alpha
elif octree.primitive == "decoupoly":
raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1]
renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings)
rgb, depth, alpha = renderer(
positions = positions,
decoupolys_V = decoupolys_V,
decoupolys_g = decoupolys_g,
densities = densities,
shs = shs,
colors_precomp = colors_precomp,
depths = depths,
aabb = octree.aabb,
aux = aux,
)
ret['rgb'] = rgb
ret['depth'] = depth
ret['alpha'] = alpha
return ret
class OctreeRenderer:
"""
Renderer for the Voxel representation.
Args:
rendering_options (dict): Rendering options.
"""
def __init__(self, rendering_options={}) -> None:
try:
import diffoctreerast
except ImportError:
print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m")
self.unsupported = True
else:
self.unsupported = False
self.pipe = edict({
"with_distloss": False,
"with_aux": False,
"scale_modifier": 1.0,
"used_rank": None,
"jitter": False,
"debug": False,
})
self.rendering_options = edict({
"resolution": None,
"near": None,
"far": None,
"ssaa": 1,
"bg_color": 'random',
})
self.halton_sampler = qmc.Halton(2, scramble=False)
self.rendering_options.update(rendering_options)
self.bg_color = None
def render(
self,
octree: DfsOctree,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
colors_overwrite: torch.Tensor = None,
) -> edict:
"""
Render the octree.
Args:
octree (Octree): octree
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
colors_overwrite (torch.Tensor): (N, 3) override color
Returns:
edict containing:
color (torch.Tensor): (3, H, W) rendered color
depth (torch.Tensor): (H, W) rendered depth
alpha (torch.Tensor): (H, W) rendered alpha
distloss (Optional[torch.Tensor]): (H, W) rendered distance loss
percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth
aux (Optional[edict]): auxiliary tensors
"""
resolution = self.rendering_options["resolution"]
near = self.rendering_options["near"]
far = self.rendering_options["far"]
ssaa = self.rendering_options["ssaa"]
if self.unsupported:
image = np.zeros((512, 512, 3), dtype=np.uint8)
text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0]
origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2
image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA)
return {
'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255,
}
if self.rendering_options["bg_color"] == 'random':
self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
if np.random.rand() < 0.5:
self.bg_color += 1
else:
self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
if self.pipe["with_aux"]:
aux = {
'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
}
for k in aux.keys():
aux[k].requires_grad_()
aux[k].retain_grad()
else:
aux = None
view = extrinsics
perspective = intrinsics_to_projection(intrinsics, near, far)
camera = torch.inverse(view)[:3, 3]
focalx = intrinsics[0, 0]
focaly = intrinsics[1, 1]
fovx = 2 * torch.atan(0.5 / focalx)
fovy = 2 * torch.atan(0.5 / focaly)
camera_dict = edict({
"image_height": resolution * ssaa,
"image_width": resolution * ssaa,
"FoVx": fovx,
"FoVy": fovy,
"znear": near,
"zfar": far,
"world_view_transform": view.T.contiguous(),
"projection_matrix": perspective.T.contiguous(),
"full_proj_transform": (perspective @ view).T.contiguous(),
"camera_center": camera
})
# Render
render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler)
if ssaa > 1:
render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
if hasattr(render_ret, 'percent_depth'):
render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
ret = edict({
'color': render_ret.rgb,
'depth': render_ret.depth,
'alpha': render_ret.alpha,
})
if self.pipe["with_distloss"] and 'distloss' in render_ret:
ret['distloss'] = render_ret.distloss
if self.pipe["with_aux"]:
ret['aux'] = aux
if hasattr(render_ret, 'percent_depth'):
ret['percent_depth'] = render_ret.percent_depth
return ret
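A minimal rendering sketch, assuming octree is a DfsOctree and extrinsics (4, 4) / intrinsics (3, 3) are CUDA tensors built as in the datasets above:

renderer = OctreeRenderer({'resolution': 512, 'near': 0.8, 'far': 1.6, 'ssaa': 2, 'bg_color': (0, 0, 0)})
out = renderer.render(octree, extrinsics, intrinsics)
color, depth, alpha = out['color'], out['depth'], out['alpha'] # (3, 512, 512), (512, 512), (512, 512)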

View File

@@ -0,0 +1,347 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class DfsOctree:
"""
Sparse Voxel Octree (SVO) implementation for PyTorch.
Using Depth-First Search (DFS) order to store the octree.
DFS order suits rendering and ray tracing.
The structure and data are stored separately.
Structure is stored as a continuous array, each element is a 3*32 bits descriptor.
|-----------------------------------------|
| 0:3 bits | 4:31 bits |
| leaf num | unused |
|-----------------------------------------|
| 0:31 bits |
| child ptr |
|-----------------------------------------|
| 0:31 bits |
| data ptr |
|-----------------------------------------|
Each element represents a non-leaf node in the octree.
The valid mask is used to indicate whether the children are valid.
The leaf mask is used to indicate whether the children are leaf nodes.
The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr.
The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr.
There are also auxiliary arrays to store the additional structural information to facilitate parallel processing.
- Position: the position of the octree nodes.
- Depth: the depth of the octree nodes.
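For example, the initial tree built in __init__ below is a single fully subdivided root: structure = [[8, 1, 0]] (eight leaf children, child ptr 1, data ptr 0) paired with eight depth-1 leaves.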
Args:
depth (int): the depth of the octree.
"""
def __init__(
self,
depth,
aabb=[0,0,0,1,1,1],
sh_degree=2,
primitive='voxel',
primitive_config={},
device='cuda',
):
self.max_depth = depth
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
self.device = device
self.sh_degree = sh_degree
self.active_sh_degree = sh_degree
self.primitive = primitive
self.primitive_config = primitive_config
self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device)
self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device)
self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device)
self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device)
self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device)
self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device)
self.depth[:, 0] = 1
self.data = ['position', 'depth']
self.param_names = []
if primitive == 'voxel':
self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device)
self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
self.data += ['features_dc', 'features_ac']
self.param_names += ['features_dc', 'features_ac']
if not primitive_config.get('solid', False):
self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device)
self.data.append('density')
self.param_names.append('density')
elif primitive == 'gaussian':
self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device)
self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device)
self.data += ['features_dc', 'features_ac', 'opacity']
self.param_names += ['features_dc', 'features_ac', 'opacity']
elif primitive == 'trivec':
self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device)
self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device)
self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device)
self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
self.density_shift = 0
self.data += ['trivec', 'density', 'features_dc', 'features_ac']
self.param_names += ['trivec', 'density', 'features_dc', 'features_ac']
elif primitive == 'decoupoly':
self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device)
self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device)
self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device)
self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device)
self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
self.density_shift = 0
self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac']
self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac']
self.setup_functions()
def setup_functions(self):
self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x)
self.opacity_activation = lambda x: torch.sigmoid(x - 6)
self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6
self.color_activation = lambda x: torch.sigmoid(x)
@property
def num_non_leaf_nodes(self):
return self.structure.shape[0]
@property
def num_leaf_nodes(self):
return self.depth.shape[0]
@property
def cur_depth(self):
return self.depth.max().item()
@property
def occupancy(self):
return self.num_leaf_nodes / 8 ** self.cur_depth
@property
def get_xyz(self):
return self.position
@property
def get_depth(self):
return self.depth
@property
def get_density(self):
if self.primitive == 'voxel' and self.primitive_config.get('solid', False):
return torch.full((self.position.shape[0], 1), torch.finfo(torch.float32).max, dtype=torch.float32, device=self.device)
return self.density_activation(self.density)
@property
def get_opacity(self):
return self.opacity_activation(self.density)
@property
def get_trivec(self):
return self.trivec
@property
def get_decoupoly(self):
return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g
@property
def get_color(self):
return self.color_activation(self.colors)
@property
def get_features(self):
if self.sh_degree == 0:
return self.features_dc
return torch.cat([self.features_dc, self.features_ac], dim=-2)
def state_dict(self):
ret = {
'structure': self.structure,
'position': self.position,
'depth': self.depth,
'sh_degree': self.sh_degree,
'active_sh_degree': self.active_sh_degree,
'primitive_config': self.primitive_config,
'primitive': self.primitive,
}
if hasattr(self, 'density_shift'):
ret['density_shift'] = self.density_shift
for data in set(self.data + self.param_names):
if not isinstance(getattr(self, data), nn.Module):
ret[data] = getattr(self, data)
else:
ret[data] = getattr(self, data).state_dict()
return ret
def load_state_dict(self, state_dict):
keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth']))
for key in keys:
if key not in state_dict:
print(f"Warning: key {key} not found in the state_dict.")
continue
try:
if not isinstance(getattr(self, key), nn.Module):
setattr(self, key, state_dict[key])
else:
getattr(self, key).load_state_dict(state_dict[key])
except Exception as e:
print(e)
raise ValueError(f"Error loading key {key}.")
def gather_from_leaf_children(self, data):
"""
Gather the data from the leaf children.
Args:
data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes.
"""
leaf_cnt = self.structure[:, 0]
leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)]
ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device)
for i in range(8):
if leaf_cnt_masks[i].sum() == 0:
continue
start = self.structure[leaf_cnt_masks[i], 2]
for j in range(i+1):
ret[leaf_cnt_masks[i]] += data[start + j]
return ret
def gather_from_non_leaf_children(self, data):
"""
Gather the data from the non-leaf children.
Args:
data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes.
"""
non_leaf_cnt = 8 - self.structure[:, 0]
non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)]
ret = torch.zeros_like(data, device=self.device)
for i in range(8):
if non_leaf_cnt_masks[i].sum() == 0:
continue
start = self.structure[non_leaf_cnt_masks[i], 1]
for j in range(i+1):
ret[non_leaf_cnt_masks[i]] += data[start + j]
return ret
def structure_control(self, mask):
"""
Control the structure of the octree.
Args:
mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep.
"""
# Don't subdivide when the depth is at the maximum.
mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0)
# Don't merge when the depth is at the minimum.
mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0)
# Gather control mask
structre_ctrl = self.gather_from_leaf_children(mask)
structre_ctrl[structre_ctrl==-8] = -1
new_leaf_num = self.structure[:, 0].clone()
# Modify the leaf num.
structre_valid = structre_ctrl >= 0
new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes.
structre_delete = structre_ctrl < 0
merged_nodes = self.gather_from_non_leaf_children(structre_delete.int())
new_leaf_num += merged_nodes # Delete the merged nodes.
# Update the structure array to allocate new nodes.
mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device)
mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes.
mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes.
new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0)
new_structure_length = new_structre_idx[-1].item()
new_structre_idx = new_structre_idx[:-1]
new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device)
new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid]
# Initialize the new nodes.
new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device)
new_node_mask[new_structre_idx[structre_valid]] = False
new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes.
new_node_num = new_node_mask.sum().item()
# Rebuild child ptr.
non_leaf_cnt = 8 - new_structure[:, 0]
new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]])
new_structure[:, 1] = new_child_ptr + 1
# Rebuild data ptr with old data.
leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device)
leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0])
old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]])
# Update the data array
subdivide_mask = mask == 1
merge_mask = mask == -1
data_valid = ~(subdivide_mask | merge_mask)
mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device)
mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes
mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes
mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes
mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes
new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0)
new_data_length = new_data_idx[-1].item()
new_data_idx = new_data_idx[:-1]
new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data}
for data in self.data:
new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid]
# Rebuild data ptr
leaf_cnt = new_structure[:, 0]
new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]])
new_structure[:, 2] = new_data_ptr
# Initialize the new data array
## For subdivide nodes
if subdivide_mask.sum() > 0:
subdivide_data_ptr = new_structure[new_node_mask, 2]
for data in self.data:
for i in range(8):
if data == 'position':
offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5
scale = 2 ** (-1.0 - self.depth[subdivide_mask])
new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale
elif data == 'depth':
new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1
elif data == 'opacity':
new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask])))
elif data == 'trivec':
offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5
coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1)
axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1)
coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1
new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True)
else:
new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask]
## For merge nodes
if merge_mask.sum() > 0:
merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device)
merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]])
for i in range(8):
merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i
old_merge_data_ptr = self.structure[structre_delete, 2]
for data in self.data:
if data == 'position':
scale = 2 ** (1.0 - self.depth[old_merge_data_ptr])
new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5
elif data == 'depth':
new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1
elif data == 'opacity':
new_data['opacity'][merge_data_ptr] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[old_merge_data_ptr])**2)
elif data == 'trivec':
new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr]
else:
new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr]
# Update the structure and data array
self.structure = new_structure
for data in self.data:
setattr(self, data, new_data[data])
# Save data array control temp variables
self.data_rearrange_buffer = {
'subdivide_mask': subdivide_mask,
'merge_mask': merge_mask,
'data_valid': data_valid,
'new_data_idx': new_data_idx,
'new_data_length': new_data_length,
'new_data': new_data
}

View File

@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..octree import DfsOctree as Octree
class Strivec(Octree):
def __init__(
self,
resolution: int,
aabb: list,
sh_degree: int = 0,
rank: int = 8,
dim: int = 8,
device: str = "cuda",
):
assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2"
self.resolution = resolution
depth = int(np.round(np.log2(resolution)))
super().__init__(
depth=depth,
aabb=aabb,
sh_degree=sh_degree,
primitive="trivec",
primitive_config={"rank": rank, "dim": dim},
device=device,
)

View File

@@ -0,0 +1,275 @@
from typing import *
import copy
import torch
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
import utils3d.torch
from ..basic import BasicTrainer
from ...representations import Gaussian
from ...renderers import GaussianRenderer
from ...modules.sparse import SparseTensor
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
class SLatVaeGaussianTrainer(BasicTrainer):
"""
Trainer for structured latent VAE.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold an inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
loss_type (str): Loss type. Can be 'l1' or 'l2'.
lambda_ssim (float): SSIM loss weight.
lambda_lpips (float): LPIPS loss weight.
lambda_kl (float): KL loss weight.
regularizations (dict): Regularization config.
"""
def __init__(
self,
*args,
loss_type: str = 'l1',
lambda_ssim: float = 0.2,
lambda_lpips: float = 0.2,
lambda_kl: float = 1e-6,
regularizations: Dict = {},
**kwargs
):
super().__init__(*args, **kwargs)
self.loss_type = loss_type
self.lambda_ssim = lambda_ssim
self.lambda_lpips = lambda_lpips
self.lambda_kl = lambda_kl
self.regularizations = regularizations
self._init_renderer()
def _init_renderer(self):
rendering_options = {"near" : 0.8,
"far" : 1.6,
"bg_color" : 'random'}
self.renderer = GaussianRenderer(rendering_options)
self.renderer.pipe.kernel_size = self.models['decoder'].rep_config['2d_filter_kernel_size']
def _render_batch(self, reps: List[Gaussian], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
"""
Render a batch of representations.
Args:
reps: The list of representations.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
"""
ret = None
for i, representation in enumerate(reps):
render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
if ret is None:
ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
for k, v in render_pack.items():
ret[k].append(v)
ret['bg_color'].append(self.renderer.bg_color)
for k, v in ret.items():
ret[k] = torch.stack(v, dim=0)
return ret
@torch.no_grad()
def _get_status(self, z: SparseTensor, reps: List[Gaussian]) -> Dict:
xyz = torch.cat([g.get_xyz for g in reps], dim=0)
xyz_base = (z.coords[:, 1:].float() + 0.5) / self.models['decoder'].resolution - 0.5
offset = xyz - xyz_base.unsqueeze(1).expand(-1, self.models['decoder'].rep_config['num_gaussians'], -1).reshape(-1, 3)
status = {
'xyz': xyz,
'offset': offset,
'scale': torch.cat([g.get_scaling for g in reps], dim=0),
'opacity': torch.cat([g.get_opacity for g in reps], dim=0),
}
for k in list(status.keys()):
status[k] = {
'mean': status[k].mean().item(),
'max': status[k].max().item(),
'min': status[k].min().item(),
}
return status
def _get_regularization_loss(self, reps: List[Gaussian]) -> Tuple[torch.Tensor, Dict]:
loss = 0.0
terms = {}
if 'lambda_vol' in self.regularizations:
scales = torch.cat([g.get_scaling for g in reps], dim=0) # [N x 3]
volume = torch.prod(scales, dim=1) # [N]
terms['reg_vol'] = volume.mean()
loss = loss + self.regularizations['lambda_vol'] * terms['reg_vol']
if 'lambda_opacity' in self.regularizations:
opacity = torch.cat([g.get_opacity for g in reps], dim=0)
terms['reg_opacity'] = (opacity - 1).pow(2).mean()
loss = loss + self.regularizations['lambda_opacity'] * terms['reg_opacity']
return loss, terms
def training_losses(
self,
feats: SparseTensor,
image: torch.Tensor,
alpha: torch.Tensor,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
return_aux: bool = False,
**kwargs
) -> Tuple[Dict, Dict]:
"""
Compute training losses.
Args:
feats: The [N x * x C] sparse tensor of features.
image: The [N x 3 x H x W] tensor of images.
alpha: The [N x H x W] tensor of alpha channels.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
return_aux: Whether to return auxiliary information.
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
z, mean, logvar = self.training_models['encoder'](feats, sample_posterior=True, return_raw=True)
reps = self.training_models['decoder'](z)
self.renderer.rendering_options.resolution = image.shape[-1]
render_results = self._render_batch(reps, extrinsics, intrinsics)
terms = edict(loss = 0.0, rec = 0.0)
rec_image = render_results['color']
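# Composite the ground truth over the same (possibly random) background color the renderer used, so prediction and target agree outside the object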
gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
if self.loss_type == 'l1':
terms["l1"] = l1_loss(rec_image, gt_image)
terms["rec"] = terms["rec"] + terms["l1"]
elif self.loss_type == 'l2':
terms["l2"] = l2_loss(rec_image, gt_image)
terms["rec"] = terms["rec"] + terms["l2"]
else:
raise ValueError(f"Invalid loss type: {self.loss_type}")
if self.lambda_ssim > 0:
terms["ssim"] = 1 - ssim(rec_image, gt_image)
terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
if self.lambda_lpips > 0:
terms["lpips"] = lpips(rec_image, gt_image)
terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
terms["loss"] = terms["loss"] + terms["rec"]
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
reg_loss, reg_terms = self._get_regularization_loss(reps)
terms.update(reg_terms)
terms["loss"] = terms["loss"] + reg_loss
status = self._get_status(z, reps)
if return_aux:
return terms, status, {'rec_image': rec_image, 'gt_image': gt_image}
return terms, status
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
# inference
ret_dict = {}
gt_images = []
exts = []
ints = []
reps = []
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
args = {k: v[:batch].cuda() for k, v in data.items()}
gt_images.append(args['image'] * args['alpha'][:, None])
exts.append(args['extrinsics'])
ints.append(args['intrinsics'])
z = self.models['encoder'](args['feats'], sample_posterior=True, return_raw=False)
reps.extend(self.models['decoder'](z))
gt_images = torch.cat(gt_images, dim=0)
ret_dict.update({'gt_image': {'value': gt_images, 'type': 'image'}})
# render single view
exts = torch.cat(exts, dim=0)
ints = torch.cat(ints, dim=0)
self.renderer.rendering_options.bg_color = (0, 0, 0)
self.renderer.rendering_options.resolution = gt_images.shape[-1]
render_results = self._render_batch(reps, exts, ints)
ret_dict.update({'rec_image': {'value': render_results['color'], 'type': 'image'}})
# render multiview
self.renderer.rendering_options.resolution = 512
## Build camera
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
yaws = [y + yaws_offset for y in yaws]
pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
## render each view
multiview_images = []
for yaw, pitch in zip(yaws, pitches):
orig = torch.tensor([
np.sin(yaw) * np.cos(pitch),
np.cos(yaw) * np.cos(pitch),
np.sin(pitch),
]).float().cuda() * 2
fov = torch.deg2rad(torch.tensor(30)).cuda()
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
render_results = self._render_batch(reps, extrinsics, intrinsics)
multiview_images.append(render_results['color'])
## Concatenate views
multiview_images = torch.cat([
torch.cat(multiview_images[:2], dim=-2),
torch.cat(multiview_images[2:], dim=-2),
], dim=-1)
ret_dict.update({'multiview_image': {'value': multiview_images, 'type': 'image'}})
self.renderer.rendering_options.bg_color = 'random'
return ret_dict

View File

@@ -0,0 +1,382 @@
from typing import *
import copy
import torch
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
import utils3d.torch
from ..basic import BasicTrainer
from ...representations import MeshExtractResult
from ...renderers import MeshRenderer
from ...modules.sparse import SparseTensor
from ...utils.loss_utils import l1_loss, smooth_l1_loss, ssim, lpips
from ...utils.data_utils import recursive_to_device
class SLatVaeMeshDecoderTrainer(BasicTrainer):
"""
Trainer for structured latent VAE Mesh Decoder.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
        depth_loss_type (str): Depth loss type. Can be 'l1' or 'smooth_l1'.
        lambda_depth (float): Depth loss weight.
        lambda_ssim (float): SSIM loss weight.
        lambda_lpips (float): LPIPS loss weight.
        lambda_tsdf (float): TSDF regularization loss weight.
        lambda_color (float): Color loss weight; color supervision is enabled when > 0.
"""
def __init__(
self,
*args,
depth_loss_type: str = 'l1',
        lambda_depth: float = 1.0,
lambda_ssim: float = 0.2,
lambda_lpips: float = 0.2,
lambda_tsdf: float = 0.01,
lambda_color: float = 0.1,
**kwargs
):
super().__init__(*args, **kwargs)
self.depth_loss_type = depth_loss_type
self.lambda_depth = lambda_depth
self.lambda_ssim = lambda_ssim
self.lambda_lpips = lambda_lpips
self.lambda_tsdf = lambda_tsdf
self.lambda_color = lambda_color
self.use_color = self.lambda_color > 0
self._init_renderer()
def _init_renderer(self):
rendering_options = {"near" : 1,
"far" : 3}
self.renderer = MeshRenderer(rendering_options, device=self.device)
def _render_batch(self, reps: List[MeshExtractResult], extrinsics: torch.Tensor, intrinsics: torch.Tensor,
return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
"""
Render a batch of representations.
Args:
            reps: The list of representations.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
Returns:
a dict with
reg_loss : [N] tensor of regularization losses
mask : [N x 1 x H x W] tensor of rendered masks
normal : [N x 3 x H x W] tensor of rendered normals
depth : [N x 1 x H x W] tensor of rendered depths
"""
        ret = {k: [] for k in return_types}
for i, rep in enumerate(reps):
out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
for k in out_dict:
ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
for k in ret:
ret[k] = torch.stack(ret[k])
return ret
@staticmethod
def _tsdf_reg_loss(rep: MeshExtractResult, depth_map: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
# Calculate tsdf
with torch.no_grad():
# Project points to camera and calculate pseudo-sdf as difference between gt depth and projected depth
projected_pts, pts_depth = utils3d.torch.project_cv(extrinsics=extrinsics, intrinsics=intrinsics, points=rep.tsdf_v)
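            # project_cv returns image-plane UVs in [0, 1]; remap to [-1, 1] as
            # grid_sample expects.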
projected_pts = (projected_pts - 0.5) * 2.0
depth_map_res = depth_map.shape[1]
gt_depth = torch.nn.functional.grid_sample(depth_map.reshape(1, 1, depth_map_res, depth_map_res),
projected_pts.reshape(1, 1, -1, 2), mode='bilinear', padding_mode='border', align_corners=True)
pseudo_sdf = gt_depth.flatten() - pts_depth.flatten()
# Truncate pseudo-sdf
delta = 1 / rep.res * 3.0
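            # Discard samples more than 3 voxel widths behind the observed surface;
            # the remaining pseudo-SDF targets are clamped to [-delta, delta].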
trunc_mask = pseudo_sdf > -delta
# Loss
gt_tsdf = pseudo_sdf[trunc_mask]
tsdf = rep.tsdf_s.flatten()[trunc_mask]
gt_tsdf = torch.clamp(gt_tsdf, -delta, delta)
return torch.mean((tsdf - gt_tsdf) ** 2)
    def _calc_tsdf_loss(self, reps: List[MeshExtractResult], depth_maps, extrinsics, intrinsics) -> torch.Tensor:
tsdf_loss = 0.0
for i, rep in enumerate(reps):
tsdf_loss += self._tsdf_reg_loss(rep, depth_maps[i], extrinsics[i], intrinsics[i])
return tsdf_loss / len(reps)
@torch.no_grad()
def _flip_normal(self, normal: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
"""
Flip normal to align with camera.
"""
normal = normal * 2.0 - 1.0
R = torch.zeros_like(extrinsics)
R[:, :3, :3] = extrinsics[:, :3, :3]
R[:, 3, 3] = 1.0
view_dir = utils3d.torch.unproject_cv(
utils3d.torch.image_uv(*normal.shape[-2:], device=self.device).reshape(1, -1, 2),
torch.ones(*normal.shape[-2:], device=self.device).reshape(1, -1),
R, intrinsics
).reshape(-1, *normal.shape[-2:], 3).permute(0, 3, 1, 2)
unflip = (normal * view_dir).sum(1, keepdim=True) < 0
normal *= unflip * 2.0 - 1.0
return (normal + 1.0) / 2.0
def _perceptual_loss(self, gt: torch.Tensor, pred: torch.Tensor, name: str) -> Dict[str, torch.Tensor]:
"""
Combination of L1, SSIM, and LPIPS loss.
"""
if gt.shape[1] != 3:
assert gt.shape[-1] == 3
gt = gt.permute(0, 3, 1, 2)
if pred.shape[1] != 3:
assert pred.shape[-1] == 3
pred = pred.permute(0, 3, 1, 2)
terms = {
f"{name}_loss" : l1_loss(gt, pred),
f"{name}_loss_ssim" : 1 - ssim(gt, pred),
f"{name}_loss_lpips" : lpips(gt, pred)
}
terms[f"{name}_loss_perceptual"] = terms[f"{name}_loss"] + terms[f"{name}_loss_ssim"] * self.lambda_ssim + terms[f"{name}_loss_lpips"] * self.lambda_lpips
return terms
def geometry_losses(
self,
reps: List[MeshExtractResult],
mesh: List[Dict],
normal_map: torch.Tensor,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
):
with torch.no_grad():
gt_meshes = []
for i in range(len(reps)):
gt_mesh = MeshExtractResult(mesh[i]['vertices'].to(self.device), mesh[i]['faces'].to(self.device))
gt_meshes.append(gt_mesh)
target = self._render_batch(gt_meshes, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal'])
target['normal'] = self._flip_normal(target['normal'], extrinsics, intrinsics)
terms = edict(geo_loss = 0.0)
if self.lambda_tsdf > 0:
tsdf_loss = self._calc_tsdf_loss(reps, target['depth'], extrinsics, intrinsics)
terms['tsdf_loss'] = tsdf_loss
terms['geo_loss'] += tsdf_loss * self.lambda_tsdf
return_types = ['mask', 'depth', 'normal', 'normal_map'] if self.use_color else ['mask', 'depth', 'normal']
buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
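        # Mesh extraction can fail for individual samples; mask failures out of
        # every image-space loss below.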
success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
if success_mask.sum() != 0:
for k, v in buffer.items():
buffer[k] = v[success_mask]
for k, v in target.items():
target[k] = v[success_mask]
terms['mask_loss'] = l1_loss(buffer['mask'], target['mask'])
if self.depth_loss_type == 'l1':
terms['depth_loss'] = l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'])
elif self.depth_loss_type == 'smooth_l1':
terms['depth_loss'] = smooth_l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'], beta=1.0 / (2 * reps[0].res))
else:
raise ValueError(f"Unsupported depth loss type: {self.depth_loss_type}")
            terms.update(self._perceptual_loss(target['normal'] * target['mask'], buffer['normal'] * target['mask'], 'normal'))
terms['geo_loss'] = terms['geo_loss'] + terms['mask_loss'] + terms['depth_loss'] * self.lambda_depth + terms['normal_loss_perceptual']
if self.use_color and normal_map is not None:
terms.update(self._perceptual_loss(normal_map[success_mask], buffer['normal_map'], 'normal_map'))
terms['geo_loss'] = terms['geo_loss'] + terms['normal_map_loss_perceptual'] * self.lambda_color
return terms
def color_losses(self, reps, image, alpha, extrinsics, intrinsics):
terms = edict(color_loss = torch.tensor(0.0, device=self.device))
buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=['color'])
success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
if success_mask.sum() != 0:
terms.update(self._perceptual_loss((image * alpha[:, None])[success_mask], buffer['color'][success_mask], 'color'))
terms['color_loss'] = terms['color_loss'] + terms['color_loss_perceptual'] * self.lambda_color
return terms
def training_losses(
self,
latents: SparseTensor,
image: torch.Tensor,
alpha: torch.Tensor,
mesh: List[Dict],
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
normal_map: torch.Tensor = None,
) -> Tuple[Dict, Dict]:
"""
Compute training losses.
Args:
latents: The [N x * x C] sparse latents
image: The [N x 3 x H x W] tensor of images.
alpha: The [N x H x W] tensor of alpha channels.
mesh: The list of dictionaries of meshes.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
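            normal_map: The [N x 3 x H x W] tensor of normal maps (optional).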
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
reps = self.training_models['decoder'](latents)
self.renderer.rendering_options.resolution = image.shape[-1]
terms = edict(loss = 0.0, rec = 0.0)
terms['reg_loss'] = sum([rep.reg_loss for rep in reps]) / len(reps)
terms['loss'] = terms['loss'] + terms['reg_loss']
geo_terms = self.geometry_losses(reps, mesh, normal_map, extrinsics, intrinsics)
terms.update(geo_terms)
terms['loss'] = terms['loss'] + terms['geo_loss']
if self.use_color:
color_terms = self.color_losses(reps, image, alpha, extrinsics, intrinsics)
terms.update(color_terms)
terms['loss'] = terms['loss'] + terms['color_loss']
return terms, {}
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
# inference
ret_dict = {}
gt_images = []
gt_normal_maps = []
gt_meshes = []
exts = []
ints = []
reps = []
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
args = recursive_to_device(data, 'cuda')
gt_images.append(args['image'] * args['alpha'][:, None])
if self.use_color and 'normal_map' in data:
gt_normal_maps.append(args['normal_map'])
gt_meshes.extend(args['mesh'])
exts.append(args['extrinsics'])
ints.append(args['intrinsics'])
reps.extend(self.models['decoder'](args['latents']))
gt_images = torch.cat(gt_images, dim=0)
        ret_dict.update({'gt_image': {'value': gt_images, 'type': 'image'}})
if self.use_color and gt_normal_maps:
gt_normal_maps = torch.cat(gt_normal_maps, dim=0)
            ret_dict.update({'gt_normal_map': {'value': gt_normal_maps, 'type': 'image'}})
# render single view
exts = torch.cat(exts, dim=0)
ints = torch.cat(ints, dim=0)
self.renderer.rendering_options.bg_color = (0, 0, 0)
self.renderer.rendering_options.resolution = gt_images.shape[-1]
gt_render_results = self._render_batch([
MeshExtractResult(vertices=mesh['vertices'].to(self.device), faces=mesh['faces'].to(self.device))
for mesh in gt_meshes
], exts, ints, return_types=['normal'])
        ret_dict.update({'gt_normal': {'value': self._flip_normal(gt_render_results['normal'], exts, ints), 'type': 'image'}})
return_types = ['normal']
if self.use_color:
return_types.append('color')
if 'normal_map' in data:
return_types.append('normal_map')
render_results = self._render_batch(reps, exts, ints, return_types=return_types)
        ret_dict.update({'rec_normal': {'value': render_results['normal'], 'type': 'image'}})
        if 'color' in return_types:
            ret_dict.update({'rec_image': {'value': render_results['color'], 'type': 'image'}})
        if 'normal_map' in return_types:
            ret_dict.update({'rec_normal_map': {'value': render_results['normal_map'], 'type': 'image'}})
# render multiview
self.renderer.rendering_options.resolution = 512
## Build camera
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
yaws = [y + yaws_offset for y in yaws]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
## render each view
multiview_normals = []
multiview_normal_maps = []
        multiview_images = []
        for yaw, pitch in zip(yaws, pitches):
orig = torch.tensor([
np.sin(yaw) * np.cos(pitch),
np.cos(yaw) * np.cos(pitch),
np.sin(pitch),
]).float().cuda() * 2
fov = torch.deg2rad(torch.tensor(30)).cuda()
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
render_results = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
multiview_normals.append(render_results['normal'])
if 'color' in return_types:
                multiview_images.append(render_results['color'])
if 'normal_map' in return_types:
multiview_normal_maps.append(render_results['normal_map'])
## Concatenate views
multiview_normals = torch.cat([
torch.cat(multiview_normals[:2], dim=-2),
torch.cat(multiview_normals[2:], dim=-2),
], dim=-1)
        ret_dict.update({'multiview_normal': {'value': multiview_normals, 'type': 'image'}})
if 'color' in return_types:
            multiview_images = torch.cat([
                torch.cat(multiview_images[:2], dim=-2),
                torch.cat(multiview_images[2:], dim=-2),
            ], dim=-1)
            ret_dict.update({'multiview_image': {'value': multiview_images, 'type': 'image'}})
if 'normal_map' in return_types:
multiview_normal_maps = torch.cat([
torch.cat(multiview_normal_maps[:2], dim=-2),
torch.cat(multiview_normal_maps[2:], dim=-2),
], dim=-1)
            ret_dict.update({'multiview_normal_map': {'value': multiview_normal_maps, 'type': 'image'}})
return ret_dict

View File

@@ -0,0 +1,223 @@
from typing import *
import copy
import torch
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
import utils3d.torch
from ..basic import BasicTrainer
from ...representations import Strivec
from ...renderers import OctreeRenderer
from ...modules.sparse import SparseTensor
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
class SLatVaeRadianceFieldDecoderTrainer(BasicTrainer):
"""
Trainer for structured latent VAE Radiance Field Decoder.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
            - 'inflat_all': Hold an inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
        loss_type (str): Loss type. Can be 'l1' or 'l2'.
lambda_ssim (float): SSIM loss weight.
lambda_lpips (float): LPIPS loss weight.
"""
def __init__(
self,
*args,
loss_type: str = 'l1',
lambda_ssim: float = 0.2,
lambda_lpips: float = 0.2,
**kwargs
):
super().__init__(*args, **kwargs)
self.loss_type = loss_type
self.lambda_ssim = lambda_ssim
self.lambda_lpips = lambda_lpips
self._init_renderer()
def _init_renderer(self):
rendering_options = {"near" : 0.8,
"far" : 1.6,
"bg_color" : 'random'}
self.renderer = OctreeRenderer(rendering_options)
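        # Strivec representations are rasterized via the octree renderer's
        # 'trivec' primitive mode.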
self.renderer.pipe.primitive = 'trivec'
def _render_batch(self, reps: List[Strivec], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
"""
Render a batch of representations.
Args:
            reps: The list of representations.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
            intrinsics: The [N x 3 x 3] tensor of intrinsics.
        Returns:
            a dict of stacked render outputs plus 'bg_color', the per-sample
            background color the renderer actually used.
        """
ret = None
for i, representation in enumerate(reps):
render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
if ret is None:
ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
for k, v in render_pack.items():
ret[k].append(v)
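            # Record the (possibly random) background color used for this sample
            # so training_losses can composite the ground truth identically.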
ret['bg_color'].append(self.renderer.bg_color)
for k, v in ret.items():
ret[k] = torch.stack(v, dim=0)
return ret
def training_losses(
self,
latents: SparseTensor,
image: torch.Tensor,
alpha: torch.Tensor,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
return_aux: bool = False,
**kwargs
) -> Tuple[Dict, Dict]:
"""
Compute training losses.
Args:
latents: The [N x * x C] sparse latents
image: The [N x 3 x H x W] tensor of images.
alpha: The [N x H x W] tensor of alpha channels.
extrinsics: The [N x 4 x 4] tensor of extrinsics.
intrinsics: The [N x 3 x 3] tensor of intrinsics.
return_aux: Whether to return auxiliary information.
Returns:
a dict with the key "loss" containing a scalar tensor.
may also contain other keys for different terms.
"""
reps = self.training_models['decoder'](latents)
self.renderer.rendering_options.resolution = image.shape[-1]
render_results = self._render_batch(reps, extrinsics, intrinsics)
terms = edict(loss = 0.0, rec = 0.0)
rec_image = render_results['color']
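        # Composite the ground truth over the same background color each sample
        # was rendered with, so the loss compares matching backgrounds.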
gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
if self.loss_type == 'l1':
terms["l1"] = l1_loss(rec_image, gt_image)
terms["rec"] = terms["rec"] + terms["l1"]
elif self.loss_type == 'l2':
terms["l2"] = l2_loss(rec_image, gt_image)
terms["rec"] = terms["rec"] + terms["l2"]
else:
raise ValueError(f"Invalid loss type: {self.loss_type}")
if self.lambda_ssim > 0:
terms["ssim"] = 1 - ssim(rec_image, gt_image)
terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
if self.lambda_lpips > 0:
terms["lpips"] = lpips(rec_image, gt_image)
terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
terms["loss"] = terms["loss"] + terms["rec"]
if return_aux:
return terms, {}, {'rec_image': rec_image, 'gt_image': gt_image}
return terms, {}
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
# inference
ret_dict = {}
gt_images = []
exts = []
ints = []
reps = []
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
args = {k: v[:batch].cuda() for k, v in data.items()}
gt_images.append(args['image'] * args['alpha'][:, None])
exts.append(args['extrinsics'])
ints.append(args['intrinsics'])
reps.extend(self.models['decoder'](args['latents']))
gt_images = torch.cat(gt_images, dim=0)
        ret_dict.update({'gt_image': {'value': gt_images, 'type': 'image'}})
# render single view
exts = torch.cat(exts, dim=0)
ints = torch.cat(ints, dim=0)
self.renderer.rendering_options.bg_color = (0, 0, 0)
self.renderer.rendering_options.resolution = gt_images.shape[-1]
render_results = self._render_batch(reps, exts, ints)
        ret_dict.update({'rec_image': {'value': render_results['color'], 'type': 'image'}})
# render multiview
self.renderer.rendering_options.resolution = 512
## Build camera
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
yaws = [y + yaws_offset for y in yaws]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
## render each view
        multiview_images = []
        for yaw, pitch in zip(yaws, pitches):
orig = torch.tensor([
np.sin(yaw) * np.cos(pitch),
np.cos(yaw) * np.cos(pitch),
np.sin(pitch),
]).float().cuda() * 2
fov = torch.deg2rad(torch.tensor(30)).cuda()
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
render_results = self._render_batch(reps, extrinsics, intrinsics)
            multiview_images.append(render_results['color'])
## Concatenate views
        multiview_images = torch.cat([
            torch.cat(multiview_images[:2], dim=-2),
            torch.cat(multiview_images[2:], dim=-2),
        ], dim=-1)
        ret_dict.update({'multiview_image': {'value': multiview_images, 'type': 'image'}})
self.renderer.rendering_options.bg_color = 'random'
return ret_dict