131
trellis/models/structured_latent_vae/decoder_gs.py
Normal file
@@ -0,0 +1,131 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from ...utils.random_utils import hammersley_sequence
from .base import SparseTransformerBase
from ...representations import Gaussian
from ..sparse_elastic_mixin import SparseTransformerElasticMixin


class SLatGaussianDecoder(SparseTransformerBase):
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
        self._build_perturbation()

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _build_perturbation(self) -> None:
        perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
        perturbation = torch.tensor(perturbation).float() * 2 - 1
        perturbation = perturbation / self.rep_config['voxel_size']
        perturbation = torch.atanh(perturbation).to(self.device)
        self.register_buffer('offset_perturbation', perturbation)

    def _calc_layout(self) -> None:
        self.layout = {
            '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
            '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
            '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
            '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
            '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
        }
        start = 0
        for k, v in self.layout.items():
            v['range'] = (start, start + v['size'])
            start += v['size']
        self.out_channels = start
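
    # A worked example of the layout above, assuming num_gaussians = 32 (an
    # illustrative value; the real number comes from representation_config):
    #   _xyz          32 * 3 =  96 channels
    #   _features_dc  32 * 3 =  96 channels
    #   _scaling      32 * 3 =  96 channels
    #   _rotation     32 * 4 = 128 channels
    #   _opacity      32 * 1 =  32 channels
    # giving out_channels = 448, so out_layer projects model_channels -> 448
    # and every active voxel is decoded into 32 Gaussians.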

    def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        ret = []
        for i in range(x.shape[0]):
            representation = Gaussian(
                sh_degree=0,
                aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
                mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
                scaling_bias = self.rep_config['scaling_bias'],
                opacity_bias = self.rep_config['opacity_bias'],
                scaling_activation = self.rep_config['scaling_activation']
            )
            xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
            for k, v in self.layout.items():
                if k == '_xyz':
                    offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
                    offset = offset * self.rep_config['lr'][k]
                    if self.rep_config['perturb_offset']:
                        offset = offset + self.offset_perturbation
                    offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
                    _xyz = xyz.unsqueeze(1) + offset
                    setattr(representation, k, _xyz.flatten(0, 1))
                else:
                    feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
                    feats = feats * self.rep_config['lr'][k]
                    setattr(representation, k, feats)
            ret.append(representation)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)


class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder):
    """
    Slat VAE Gaussian decoder with elastic memory management.
    Used for training with low VRAM.
    """
    pass
176
trellis/models/structured_latent_vae/decoder_mesh.py
Normal file
@@ -0,0 +1,176 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ...representations import MeshExtractResult
from ...representations.mesh import SparseFeatures2Mesh
from ..sparse_elastic_mixin import SparseTransformerElasticMixin


class SparseSubdivideBlock3d(nn.Module):
    """
    A 3D subdivide block that can subdivide the sparse tensor.

    Args:
        channels: channels in the inputs and outputs.
        out_channels: if specified, the number of output channels.
        num_groups: the number of groups for the group norm.
    """
    def __init__(
        self,
        channels: int,
        resolution: int,
        out_channels: Optional[int] = None,
        num_groups: int = 32
    ):
        super().__init__()
        self.channels = channels
        self.resolution = resolution
        self.out_resolution = resolution * 2
        self.out_channels = out_channels or channels

        self.act_layers = nn.Sequential(
            sp.SparseGroupNorm32(num_groups, channels),
            sp.SparseSiLU()
        )

        self.sub = sp.SparseSubdivide()

        self.out_layers = nn.Sequential(
            sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
            sp.SparseGroupNorm32(num_groups, self.out_channels),
            sp.SparseSiLU(),
            zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")

    def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
        """
        Apply the block to a sparse tensor, subdividing it to twice the input resolution.

        Args:
            x: an [N x C x ...] Tensor of features.
        Returns:
            an [N x C x ...] Tensor of outputs.
        """
        h = self.act_layers(x)
        h = self.sub(h)
        x = self.sub(x)
        h = self.out_layers(h)
        h = h + self.skip_connection(x)
        return h
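
# How the subdivide blocks are used by the decoder below: each stage doubles
# the sparse grid resolution and narrows the channel width. With an assumed,
# purely illustrative config of model_channels = 768 and resolution = 64, the
# upsampling path runs 64^3 x 768 -> 128^3 x 192 -> 256^3 x 96, after which
# out_layer maps the 96 channels to the FlexiCubes feature channels expected
# by SparseFeatures2Mesh (whose grid is likewise resolution * 4 = 256).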


class SLatMeshDecoder(SparseTransformerBase):
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
        self.out_channels = self.mesh_extractor.feats_channels
        self.upsample = nn.ModuleList([
            SparseSubdivideBlock3d(
                channels=model_channels,
                resolution=resolution,
                out_channels=model_channels // 4
            ),
            SparseSubdivideBlock3d(
                channels=model_channels // 4,
                resolution=resolution * 2,
                out_channels=model_channels // 8
            )
        ])
        self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def convert_to_fp16(self) -> None:
        """
        Convert the torso of the model to float16.
        """
        super().convert_to_fp16()
        self.upsample.apply(convert_module_to_f16)

    def convert_to_fp32(self) -> None:
        """
        Convert the torso of the model to float32.
        """
        super().convert_to_fp32()
        self.upsample.apply(convert_module_to_f32)

    def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        ret = []
        for i in range(x.shape[0]):
            mesh = self.mesh_extractor(x[i], training=self.training)
            ret.append(mesh)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
        h = super().forward(x)
        for block in self.upsample:
            h = block(h)
        h = h.type(x.dtype)
        h = self.out_layer(h)
        return self.to_representation(h)


class ElasticSLatMeshDecoder(SparseTransformerElasticMixin, SLatMeshDecoder):
    """
    Slat VAE Mesh decoder with elastic memory management.
    Used for training with low VRAM.
    """
    pass
113
trellis/models/structured_latent_vae/decoder_rf.py
Normal file
@@ -0,0 +1,113 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ...representations import Strivec
from ..sparse_elastic_mixin import SparseTransformerElasticMixin


class SLatRadianceFieldDecoder(SparseTransformerBase):
    def __init__(
        self,
        resolution: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
        representation_config: dict = None,
    ):
        super().__init__(
            in_channels=latent_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.rep_config = representation_config
        self._calc_layout()
        self.out_layer = sp.SparseLinear(model_channels, self.out_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def _calc_layout(self) -> None:
        self.layout = {
            'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
            'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
            'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
        }
        start = 0
        for k, v in self.layout.items():
            v['range'] = (start, start + v['size'])
            start += v['size']
        self.out_channels = start
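
    # A worked example of the layout above, assuming rank = 16 and dim = 8
    # (illustrative values; both come from representation_config):
    #   trivec       16 * 3 * 8 = 384 channels
    #   density      16 * 1     =  16 channels
    #   features_dc  16 * 3     =  48 channels
    # giving out_channels = 448, i.e. each active voxel is decoded into a
    # rank-16 tri-vector (Strivec) primitive with per-rank density and color.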

    def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
        """
        Convert a batch of network outputs to 3D representations.

        Args:
            x: The [N x * x C] sparse tensor output by the network.

        Returns:
            list of representations
        """
        ret = []
        for i in range(x.shape[0]):
            representation = Strivec(
                sh_degree=0,
                resolution=self.resolution,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                rank=self.rep_config['rank'],
                dim=self.rep_config['dim'],
                device='cuda',
            )
            representation.density_shift = 0.0
            representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
            for k, v in self.layout.items():
                setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
            representation.trivec = representation.trivec + 1
            ret.append(representation)
        return ret

    def forward(self, x: sp.SparseTensor) -> List[Strivec]:
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)
        return self.to_representation(h)


class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder):
    """
    Slat VAE Radiance Field Decoder with elastic memory management.
    Used for training with low VRAM.
    """
    pass
80
trellis/models/structured_latent_vae/encoder.py
Normal file
@@ -0,0 +1,80 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ..sparse_elastic_mixin import SparseTransformerElasticMixin


class SLatEncoder(SparseTransformerBase):
    def __init__(
        self,
        resolution: int,
        in_channels: int,
        model_channels: int,
        latent_channels: int,
        num_blocks: int,
        num_heads: Optional[int] = None,
        num_head_channels: Optional[int] = 64,
        mlp_ratio: float = 4,
        attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
        window_size: int = 8,
        pe_mode: Literal["ape", "rope"] = "ape",
        use_fp16: bool = False,
        use_checkpoint: bool = False,
        qk_rms_norm: bool = False,
    ):
        super().__init__(
            in_channels=in_channels,
            model_channels=model_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            mlp_ratio=mlp_ratio,
            attn_mode=attn_mode,
            window_size=window_size,
            pe_mode=pe_mode,
            use_fp16=use_fp16,
            use_checkpoint=use_checkpoint,
            qk_rms_norm=qk_rms_norm,
        )
        self.resolution = resolution
        self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)

        self.initialize_weights()
        if use_fp16:
            self.convert_to_fp16()

    def initialize_weights(self) -> None:
        super().initialize_weights()
        # Zero-out output layers:
        nn.init.constant_(self.out_layer.weight, 0)
        nn.init.constant_(self.out_layer.bias, 0)

    def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
        h = super().forward(x)
        h = h.type(x.dtype)
        h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
        h = self.out_layer(h)

        # Sample from the posterior distribution
        mean, logvar = h.feats.chunk(2, dim=-1)
        if sample_posterior:
            std = torch.exp(0.5 * logvar)
            z = mean + std * torch.randn_like(std)
        else:
            z = mean
        z = h.replace(z)

        if return_raw:
            return z, mean, logvar
        else:
            return z
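
    # Reparameterization note: out_layer emits 2 * latent_channels values per
    # active voxel, split into (mean, logvar). Drawing
    #     z = mean + exp(0.5 * logvar) * n,   n ~ N(0, I)
    # keeps the sample differentiable w.r.t. mean and logvar, so a standard KL
    # penalty such as -0.5 * (1 + logvar - mean**2 - logvar.exp()) can be added
    # by the training loop; the loss itself is not part of this module.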


class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder):
    """
    SLat VAE encoder with elastic memory management.
    Used for training with low VRAM.
    """
140
trellis/modules/attention/full_attn.py
Executable file
@@ -0,0 +1,140 @@
from typing import *
import torch
import math
from . import DEBUG, BACKEND

if BACKEND == 'xformers':
    import xformers.ops as xops
elif BACKEND == 'flash_attn':
    import flash_attn
elif BACKEND == 'sdpa':
    from torch.nn.functional import scaled_dot_product_attention as sdpa
elif BACKEND == 'naive':
    pass
else:
    raise ValueError(f"Unknown attention backend: {BACKEND}")


__all__ = [
    'scaled_dot_product_attention',
]


def _naive_sdpa(q, k, v):
    """
    Naive implementation of scaled dot product attention.
    """
    q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
    k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
    v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
    scale_factor = 1 / math.sqrt(q.size(-1))
    attn_weight = q @ k.transpose(-2, -1) * scale_factor
    attn_weight = torch.softmax(attn_weight, dim=-1)
    out = attn_weight @ v
    out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
    return out


@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
        kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
    """
    ...

@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Apply scaled dot product attention.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

def scaled_dot_product_attention(*args, **kwargs):
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
        device = qkv.device

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
        assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
        device = q.device

    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
        assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
        assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
        device = q.device

    if BACKEND == 'xformers':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = xops.memory_efficient_attention(q, k, v)
    elif BACKEND == 'flash_attn':
        if num_all_args == 1:
            out = flash_attn.flash_attn_qkvpacked_func(qkv)
        elif num_all_args == 2:
            out = flash_attn.flash_attn_kvpacked_func(q, kv)
        elif num_all_args == 3:
            out = flash_attn.flash_attn_func(q, k, v)
    elif BACKEND == 'sdpa':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)   # [N, H, L, C]
        k = k.permute(0, 2, 1, 3)   # [N, H, L, C]
        v = v.permute(0, 2, 1, 3)   # [N, H, L, C]
        out = sdpa(q, k, v)         # [N, H, L, C]
        out = out.permute(0, 2, 1, 3)   # [N, L, H, C]
    elif BACKEND == 'naive':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=2)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=2)
        out = _naive_sdpa(q, k, v)
    else:
        raise ValueError(f"Unknown attention module: {BACKEND}")

    return out
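
# Usage sketch for the three calling conventions above (tensor sizes are
# assumptions chosen only for illustration):
#
#   qkv = torch.randn(2, 1024, 3, 8, 64)             # packed self-attention
#   out = scaled_dot_product_attention(qkv)           # -> [2, 1024, 8, 64]
#
#   q  = torch.randn(2, 1024, 8, 64)
#   kv = torch.randn(2, 77, 2, 8, 64)                 # packed cross-attention
#   out = scaled_dot_product_attention(q, kv)         # -> [2, 1024, 8, 64]
#
#   q, k, v = (torch.randn(2, 1024, 8, 64) for _ in range(3))
#   out = scaled_dot_product_attention(q, k, v)       # -> [2, 1024, 8, 64]
#
# Every backend returns [N, L, H, C]; the naive fallback materializes the full
# L x L attention matrix, so it is only practical for short sequences.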
215
trellis/modules/sparse/attention/full_attn.py
Executable file
@@ -0,0 +1,215 @@
from typing import *
import torch
from .. import SparseTensor
from .. import DEBUG, ATTN

if ATTN == 'xformers':
    import xformers.ops as xops
elif ATTN == 'flash_attn':
    import flash_attn
else:
    raise ValueError(f"Unknown attention module: {ATTN}")


__all__ = [
    'sparse_scaled_dot_product_attention',
]


@overload
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
        kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.

    Note:
        k and v are assumed to have the same coordinate map.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
        k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
        v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
    """
    ...

@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
    """
    Apply scaled dot product attention to a sparse tensor.

    Args:
        q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
    """
    ...

def sparse_scaled_dot_product_attention(*args, **kwargs):
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
        assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
        device = qkv.device

        s = qkv
        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
        kv_seqlen = q_seqlen
        qkv = qkv.feats     # [T, 3, H, C]

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
               isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
               f"Invalid types, got {type(q)} and {type(kv)}"
        assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, C]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
            s = None
            N, L, H, C = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)   # [T_Q, H, C]

        if isinstance(kv, SparseTensor):
            assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
            kv = kv.feats     # [T_KV, 2, H, C]
        else:
            assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
            N, L, _, H, C = kv.shape
            kv_seqlen = [L] * N
            kv = kv.reshape(N * L, 2, H, C)   # [T_KV, 2, H, C]

    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
               isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
               f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
        assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
        device = q.device

        if isinstance(q, SparseTensor):
            assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, Ci]
        else:
            assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
            s = None
            N, L, H, CI = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, CI)   # [T_Q, H, Ci]

        if isinstance(k, SparseTensor):
            assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
            assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
            kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
            k = k.feats     # [T_KV, H, Ci]
            v = v.feats     # [T_KV, H, Co]
        else:
            assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
            assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
            N, L, H, CI, CO = *k.shape, v.shape[-1]
            kv_seqlen = [L] * N
            k = k.reshape(N * L, H, CI)   # [T_KV, H, Ci]
            v = v.reshape(N * L, H, CO)   # [T_KV, H, Co]

    if DEBUG:
        if s is not None:
            for i in range(s.shape[0]):
                assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
        if num_all_args in [2, 3]:
            assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
        if num_all_args == 3:
            assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
            assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"

    if ATTN == 'xformers':
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
        out = xops.memory_efficient_attention(q, k, v, mask)[0]
    elif ATTN == 'flash_attn':
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
        if num_all_args == 1:
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
        elif num_all_args == 2:
            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
        elif num_all_args == 3:
            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
    else:
        raise ValueError(f"Unknown attention module: {ATTN}")

    if s is not None:
        return s.replace(out)
    else:
        return out.reshape(N, L, H, -1)
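
# How the variable-length batching above works, with assumed sequence lengths
# for illustration: a sparse batch whose three samples own 5, 3, and 7 voxels
# is flattened into a single packed tensor of T = 15 rows. On the flash_attn
# path the per-sample lengths [5, 3, 7] become cu_seqlens = [0, 5, 8, 15]; on
# the xformers path the same lengths build a BlockDiagonalMask. Either way,
# tokens never attend across sample boundaries even though they share one
# packed tensor, and the result is scattered back into the SparseTensor via
# s.replace(out).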
201
trellis/pipelines/samplers/flow_euler.py
Normal file
@@ -0,0 +1,201 @@
from typing import *
import torch
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict
from .base import Sampler
from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
from .guidance_interval_mixin import GuidanceIntervalSamplerMixin


class FlowEulerSampler(Sampler):
    """
    Generate samples from a flow-matching model using Euler sampling.

    Args:
        sigma_min: The minimum scale of noise in flow.
    """
    def __init__(
        self,
        sigma_min: float,
    ):
        self.sigma_min = sigma_min

    def _eps_to_xstart(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)

    def _xstart_to_eps(self, x_t, t, x_0):
        assert x_t.shape == x_0.shape
        return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)

    def _v_to_xstart_eps(self, x_t, t, v):
        assert x_t.shape == v.shape
        eps = (1 - t) * v + x_t
        x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
        return x_0, eps
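
    # The helpers above all follow the same linear interpolation between data
    # and noise (t = 1 is pure noise, t = 0 is clean data):
    #     x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * eps
    # The model predicts the velocity, i.e. the time derivative
    #     v = dx_t / dt = (1 - sigma_min) * eps - x_0,
    # which is why x_0 and eps can be recovered from (x_t, t, v) in closed
    # form as done in _v_to_xstart_eps.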

    def _inference_model(self, model, x_t, t, cond=None, **kwargs):
        t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
        if cond is not None and cond.shape[0] == 1 and x_t.shape[0] > 1:
            cond = cond.repeat(x_t.shape[0], *([1] * (len(cond.shape) - 1)))
        return model(x_t, t, cond, **kwargs)

    def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
        pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
        pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
        return pred_x_0, pred_eps, pred_v

    @torch.no_grad()
    def sample_once(
        self,
        model,
        x_t,
        t: float,
        t_prev: float,
        cond: Optional[Any] = None,
        **kwargs
    ):
        """
        Sample x_{t-1} from the model using Euler method.

        Args:
            model: The model to sample from.
            x_t: The [N x C x ...] tensor of noisy inputs at time t.
            t: The current timestep.
            t_prev: The previous timestep.
            cond: conditional information.
            **kwargs: Additional arguments for model inference.

        Returns:
            a dict containing the following
            - 'pred_x_prev': x_{t-1}.
            - 'pred_x_0': a prediction of x_0.
        """
        pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
        pred_x_prev = x_t - (t - t_prev) * pred_v
        return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})

    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond: Optional[Any] = None,
        steps: int = 50,
        rescale_t: float = 1.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples from the model using Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of prediction of x_t.
            - 'pred_x_0': a list of prediction of x_0.
        """
        sample = noise
        t_seq = np.linspace(1, 0, steps + 1)
        t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
        t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
        ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
        for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose):
            out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
            sample = out.pred_x_prev
            ret.pred_x_t.append(out.pred_x_prev)
            ret.pred_x_0.append(out.pred_x_0)
        ret.samples = sample
        return ret
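
    # Timestep schedule note: with rescale_t = 1 the schedule is plain
    # linspace(1, 0); larger values warp it toward the noisy end via
    # t' = r * t / (1 + (r - 1) * t). As an illustrative example, r = 3 maps
    # t = 0.5 to 1.5 / 2.0 = 0.75, so more of the step budget is spent early
    # in the trajectory.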


class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
    """
    Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
    """
    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        cfg_strength: float = 3.0,
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples from the model using Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            neg_cond: negative conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            cfg_strength: The strength of classifier-free guidance.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of prediction of x_t.
            - 'pred_x_0': a list of prediction of x_0.
        """
        return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs)


class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
    """
    Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
    """
    @torch.no_grad()
    def sample(
        self,
        model,
        noise,
        cond,
        neg_cond,
        steps: int = 50,
        rescale_t: float = 1.0,
        cfg_strength: float = 3.0,
        cfg_interval: Tuple[float, float] = (0.0, 1.0),
        verbose: bool = True,
        **kwargs
    ):
        """
        Generate samples from the model using Euler method.

        Args:
            model: The model to sample from.
            noise: The initial noise tensor.
            cond: conditional information.
            neg_cond: negative conditional information.
            steps: The number of steps to sample.
            rescale_t: The rescale factor for t.
            cfg_strength: The strength of classifier-free guidance.
            cfg_interval: The interval for classifier-free guidance.
            verbose: If True, show a progress bar.
            **kwargs: Additional arguments for model_inference.

        Returns:
            a dict containing the following
            - 'samples': the model samples.
            - 'pred_x_t': a list of prediction of x_t.
            - 'pred_x_0': a list of prediction of x_0.
        """
        return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
209
trellis/representations/gaussian/gaussian_model.py
Executable file
@@ -0,0 +1,209 @@
import torch
import numpy as np
from plyfile import PlyData, PlyElement
from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation
import utils3d


class Gaussian:
    def __init__(
        self,
        aabb : list,
        sh_degree : int = 0,
        mininum_kernel_size : float = 0.0,
        scaling_bias : float = 0.01,
        opacity_bias : float = 0.1,
        scaling_activation : str = "exp",
        device='cuda'
    ):
        self.init_params = {
            'aabb': aabb,
            'sh_degree': sh_degree,
            'mininum_kernel_size': mininum_kernel_size,
            'scaling_bias': scaling_bias,
            'opacity_bias': opacity_bias,
            'scaling_activation': scaling_activation,
        }

        self.sh_degree = sh_degree
        self.active_sh_degree = sh_degree
        self.mininum_kernel_size = mininum_kernel_size
        self.scaling_bias = scaling_bias
        self.opacity_bias = opacity_bias
        self.scaling_activation_type = scaling_activation
        self.device = device
        self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
        self.setup_functions()

        self._xyz = None
        self._features_dc = None
        self._features_rest = None
        self._scaling = None
        self._rotation = None
        self._opacity = None

    def setup_functions(self):
        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
            actual_covariance = L @ L.transpose(1, 2)
            symm = strip_symmetric(actual_covariance)
            return symm

        if self.scaling_activation_type == "exp":
            self.scaling_activation = torch.exp
            self.inverse_scaling_activation = torch.log
        elif self.scaling_activation_type == "softplus":
            self.scaling_activation = torch.nn.functional.softplus
            self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))

        self.covariance_activation = build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize

        self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda()
        self.rots_bias = torch.zeros((4)).cuda()
        self.rots_bias[0] = 1
        self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda()

    @property
    def get_scaling(self):
        scales = self.scaling_activation(self._scaling + self.scale_bias)
        scales = torch.square(scales) + self.mininum_kernel_size ** 2
        scales = torch.sqrt(scales)
        return scales

    @property
    def get_rotation(self):
        return self.rotation_activation(self._rotation + self.rots_bias[None, :])

    @property
    def get_xyz(self):
        return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3]

    @property
    def get_features(self):
        return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc

    @property
    def get_opacity(self):
        return self.opacity_activation(self._opacity + self.opacity_bias)

    def get_covariance(self, scaling_modifier = 1):
        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :])
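
    # Parameterization note for the getters above: the stored tensors are kept
    # near zero and shifted by fixed biases before activation, so a decoder
    # whose output layer is zero-initialized starts from scaling_bias /
    # opacity_bias / the identity rotation rather than from degenerate values.
    # get_scaling additionally clamps each kernel to
    #     sqrt(activation(s)**2 + mininum_kernel_size**2),
    # which lower-bounds the Gaussian footprint by mininum_kernel_size.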

    def from_scaling(self, scales):
        scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)
        self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias

    def from_rotation(self, rots):
        self._rotation = rots - self.rots_bias[None, :]

    def from_xyz(self, xyz):
        self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]

    def from_features(self, features):
        self._features_dc = features

    def from_opacity(self, opacities):
        self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias

    def construct_list_of_attributes(self):
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        # All channels except the 3 DC
        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
            l.append('f_dc_{}'.format(i))
        l.append('opacity')
        for i in range(self._scaling.shape[1]):
            l.append('scale_{}'.format(i))
        for i in range(self._rotation.shape[1]):
            l.append('rot_{}'.format(i))
        return l

    def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
        xyz = self.get_xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
        scale = torch.log(self.get_scaling).detach().cpu().numpy()
        rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()

        if transform is not None:
            transform = np.array(transform)
            xyz = np.matmul(xyz, transform.T)
            rotation = utils3d.numpy.quaternion_to_matrix(rotation)
            rotation = np.matmul(transform, rotation)
            rotation = utils3d.numpy.matrix_to_quaternion(rotation)

        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)

    def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
        plydata = PlyData.read(path)

        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
                        np.asarray(plydata.elements[0]["y"]),
                        np.asarray(plydata.elements[0]["z"])), axis=1)
        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]

        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])

        if self.sh_degree > 0:
            extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
            extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
            assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3
            features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
            for idx, attr_name in enumerate(extra_f_names):
                features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
            # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
            features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.sh_degree + 1) ** 2 - 1))

        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])

        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])

        if transform is not None:
            transform = np.array(transform)
            xyz = np.matmul(xyz, transform)
            rots = utils3d.numpy.quaternion_to_matrix(rots)
            rots = np.matmul(rots, transform)
            rots = utils3d.numpy.matrix_to_quaternion(rots)

        # convert to actual gaussian attributes
        xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
        features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
        if self.sh_degree > 0:
            features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
        opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device))
        scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
        rots = torch.tensor(rots, dtype=torch.float, device=self.device)

        # convert to _hidden attributes
        self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
        self._features_dc = features_dc
        if self.sh_degree > 0:
            self._features_rest = features_extra
        else:
            self._features_rest = None
        self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
        self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias
        self._rotation = rots - self.rots_bias[None, :]
@@ -0,0 +1,48 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
from zipfile import ZipFile
from tqdm import tqdm
import os

def download_file(url, output_path):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)

    with open(output_path, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        raise Exception("ERROR, something went wrong")


url = "https://vcg.isti.cnr.it/Publications/2014/MPZ14/inputmodels.zip"
zip_file_path = './data/inputmodels.zip'

os.makedirs('./data', exist_ok=True)

download_file(url, zip_file_path)

with ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall('./data')

os.remove(zip_file_path)

print("Download and extraction complete.")
157
trellis/representations/mesh/flexicubes/examples/optimize.py
Normal file
@@ -0,0 +1,157 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
import torch
import nvdiffrast.torch as dr
import trimesh
import os
from util import *
import render
import loss
import imageio

import sys
sys.path.append('..')
from flexicubes import FlexiCubes

###############################################################################
# Functions adapted from https://github.com/NVlabs/nvdiffrec
###############################################################################

def lr_schedule(iter):
    return max(0.0, 10**(-(iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.
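
# Sanity check on the schedule above (plain arithmetic): 10 ** (-5000 * 0.0002)
# = 10 ** -1 = 0.1, so the multiplier decays smoothly from 1.0 at iteration 0
# to 0.1 at iteration 5000, matching the comment on the return line.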
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='flexicubes optimization')
|
||||
parser.add_argument('-o', '--out_dir', type=str, default=None)
|
||||
parser.add_argument('-rm', '--ref_mesh', type=str)
|
||||
|
||||
parser.add_argument('-i', '--iter', type=int, default=1000)
|
||||
parser.add_argument('-b', '--batch', type=int, default=8)
|
||||
parser.add_argument('-r', '--train_res', nargs=2, type=int, default=[2048, 2048])
|
||||
parser.add_argument('-lr', '--learning_rate', type=float, default=0.01)
|
||||
parser.add_argument('--voxel_grid_res', type=int, default=64)
|
||||
|
||||
parser.add_argument('--sdf_loss', type=bool, default=True)
|
||||
parser.add_argument('--develop_reg', type=bool, default=False)
|
||||
parser.add_argument('--sdf_regularizer', type=float, default=0.2)
|
||||
|
||||
parser.add_argument('-dr', '--display_res', nargs=2, type=int, default=[512, 512])
|
||||
parser.add_argument('-si', '--save_interval', type=int, default=20)
|
||||
FLAGS = parser.parse_args()
|
||||
device = 'cuda'
|
||||
|
||||
os.makedirs(FLAGS.out_dir, exist_ok=True)
|
||||
glctx = dr.RasterizeGLContext()
|
||||
|
||||
# Load GT mesh
|
||||
gt_mesh = load_mesh(FLAGS.ref_mesh, device)
|
||||
gt_mesh.auto_normals() # compute face normals for visualization
|
||||
|
||||
# ==============================================================================================
|
||||
# Create and initialize FlexiCubes
|
||||
# ==============================================================================================
|
||||
fc = FlexiCubes(device)
|
||||
x_nx3, cube_fx8 = fc.construct_voxel_grid(FLAGS.voxel_grid_res)
|
||||
x_nx3 *= 2 # scale up the grid so that it's larger than the target object
|
||||
|
||||
sdf = torch.rand_like(x_nx3[:,0]) - 0.1 # randomly init SDF
|
||||
sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
|
||||
# set per-cube learnable weights to zeros
|
||||
weight = torch.zeros((cube_fx8.shape[0], 21), dtype=torch.float, device='cuda')
|
||||
weight = torch.nn.Parameter(weight.clone().detach(), requires_grad=True)
|
||||
deform = torch.nn.Parameter(torch.zeros_like(x_nx3), requires_grad=True)
|
||||
|
||||
# Retrieve all the edges of the voxel grid; these edges will be utilized to
|
||||
# compute the regularization loss in subsequent steps of the process.
|
||||
all_edges = cube_fx8[:, fc.cube_edges].reshape(-1, 2)
|
||||
grid_edges = torch.unique(all_edges, dim=0)
|
||||
|
||||
# ==============================================================================================
|
||||
# Setup optimizer
|
||||
# ==============================================================================================
|
||||
optimizer = torch.optim.Adam([sdf, weight,deform], lr=FLAGS.learning_rate)
|
||||
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x))
|
||||
|
||||
# ==============================================================================================
|
||||
# Train loop
|
||||
# ==============================================================================================
|
||||
for it in range(FLAGS.iter):
|
||||
optimizer.zero_grad()
|
||||
# sample random camera poses
|
||||
mv, mvp = render.get_random_camera_batch(FLAGS.batch, iter_res=FLAGS.train_res, device=device, use_kaolin=False)
|
||||
# render gt mesh
|
||||
target = render.render_mesh_paper(gt_mesh, mv, mvp, FLAGS.train_res)
|
||||
# extract and render FlexiCubes mesh
|
||||
grid_verts = x_nx3 + (2-1e-8) / (FLAGS.voxel_grid_res * 2) * torch.tanh(deform)
|
||||
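# Added note (illustrative): tanh bounds each per-vertex offset to strictly less
# than (2 - 1e-8) / (2 * voxel_grid_res), i.e. just under half a voxel of the
# scaled grid, so deformed vertices cannot cross into neighboring cells.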
vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:,:12], alpha_fx8=weight[:,12:20],
|
||||
gamma_f=weight[:,20], training=True)
|
||||
flexicubes_mesh = Mesh(vertices, faces)
|
||||
buffers = render.render_mesh_paper(flexicubes_mesh, mv, mvp, FLAGS.train_res)
|
||||
|
||||
# evaluate reconstruction loss
|
||||
mask_loss = (buffers['mask'] - target['mask']).abs().mean()
|
||||
depth_loss = ((((buffers['depth'] - target['depth']) * target['mask'])**2).sum(-1) + 1e-8).sqrt().mean() * 10
|
||||
|
||||
t_iter = it / FLAGS.iter
|
||||
sdf_weight = FLAGS.sdf_regularizer - (FLAGS.sdf_regularizer - FLAGS.sdf_regularizer/20)*min(1.0, 4.0 * t_iter)
|
||||
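# Added note (illustrative): sdf_weight anneals linearly from sdf_regularizer
# down to sdf_regularizer / 20 over the first quarter of the iterations, then
# stays constant for the rest of training.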
reg_loss = loss.sdf_reg_loss(sdf, grid_edges).mean() * sdf_weight # Loss to eliminate internal floaters that are not visible
|
||||
reg_loss += L_dev.mean() * 0.5
|
||||
reg_loss += (weight[:,:20]).abs().mean() * 0.1
|
||||
total_loss = mask_loss + depth_loss + reg_loss
|
||||
|
||||
if FLAGS.sdf_loss: # optionally add SDF loss to eliminate internal structures
|
||||
with torch.no_grad():
|
||||
pts = sample_random_points(1000, gt_mesh)
|
||||
gt_sdf = compute_sdf(pts, gt_mesh.vertices, gt_mesh.faces)
|
||||
pred_sdf = compute_sdf(pts, flexicubes_mesh.vertices, flexicubes_mesh.faces)
|
||||
total_loss += torch.nn.functional.mse_loss(pred_sdf, gt_sdf) * 2e3
|
||||
|
||||
# optionally add developability regularizer, as described in paper section 5.2
|
||||
if FLAGS.develop_reg:
|
||||
reg_weight = max(0, t_iter - 0.8) * 5
|
||||
if reg_weight > 0: # only applied after shape converges
|
||||
reg_loss = loss.mesh_developable_reg(flexicubes_mesh).mean() * 10
|
||||
reg_loss += (deform).abs().mean()
|
||||
reg_loss += (weight[:,:20]).abs().mean()
|
||||
total_loss = mask_loss + depth_loss + reg_loss
|
||||
|
||||
total_loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
if (it % FLAGS.save_interval == 0 or it == (FLAGS.iter-1)): # save normal image for visualization
|
||||
with torch.no_grad():
|
||||
# extract mesh with training=False
|
||||
vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:,:12], alpha_fx8=weight[:,12:20],
|
||||
gamma_f=weight[:,20], training=False)
|
||||
flexicubes_mesh = Mesh(vertices, faces)
|
||||
|
||||
flexicubes_mesh.auto_normals() # compute face normals for visualization
|
||||
mv, mvp = render.get_rotate_camera(it//FLAGS.save_interval, iter_res=FLAGS.display_res, device=device, use_kaolin=False)
|
||||
val_buffers = render.render_mesh_paper(flexicubes_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
|
||||
val_image = ((val_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
|
||||
|
||||
gt_buffers = render.render_mesh_paper(gt_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
|
||||
gt_image = ((gt_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
|
||||
imageio.imwrite(os.path.join(FLAGS.out_dir, '{:04d}.png'.format(it)), np.concatenate([val_image, gt_image], 1))
|
||||
print(f"Optimization Step [{it}/{FLAGS.iter}], Loss: {total_loss.item():.4f}")
|
||||
|
||||
# ==============================================================================================
|
||||
# Save output
|
||||
# ==============================================================================================
|
||||
mesh_np = trimesh.Trimesh(vertices = vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy(), process=False)
|
||||
mesh_np.export(os.path.join(FLAGS.out_dir, 'output_mesh.obj'))
|
||||
390
trellis/representations/mesh/flexicubes/flexicubes.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from .tables import *
|
||||
from kaolin.utils.testing import check_tensor
|
||||
|
||||
__all__ = [
|
||||
'FlexiCubes'
|
||||
]
|
||||
|
||||
|
||||
class FlexiCubes:
|
||||
def __init__(self, device="cuda"):
|
||||
|
||||
self.device = device
|
||||
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
|
||||
self.num_vd_table = torch.tensor(num_vd_table,
|
||||
dtype=torch.long, device=device, requires_grad=False)
|
||||
self.check_table = torch.tensor(
|
||||
check_table,
|
||||
dtype=torch.long, device=device, requires_grad=False)
|
||||
|
||||
self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
|
||||
self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
||||
self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
||||
self.quad_split_train = torch.tensor(
|
||||
[0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
|
||||
|
||||
self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
|
||||
1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
|
||||
self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
|
||||
self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
|
||||
2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
|
||||
|
||||
self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
|
||||
dtype=torch.long, device=device)
|
||||
self.dir_faces_table = torch.tensor([
|
||||
[[5, 4], [3, 2], [4, 5], [2, 3]],
|
||||
[[5, 4], [1, 0], [4, 5], [0, 1]],
|
||||
[[3, 2], [1, 0], [2, 3], [0, 1]]
|
||||
], dtype=torch.long, device=device)
|
||||
self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
|
||||
|
||||
def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3,
|
||||
weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False):
|
||||
assert torch.is_tensor(voxelgrid_vertices) and \
|
||||
check_tensor(voxelgrid_vertices, (None, 3), throw=False), \
|
||||
"'voxelgrid_vertices' should be a tensor of shape (num_vertices, 3)"
|
||||
num_vertices = voxelgrid_vertices.shape[0]
|
||||
assert torch.is_tensor(scalar_field) and \
|
||||
check_tensor(scalar_field, (num_vertices,), throw=False), \
|
||||
"'scalar_field' should be a tensor of shape (num_vertices,)"
|
||||
assert torch.is_tensor(cube_idx) and \
|
||||
check_tensor(cube_idx, (None, 8), throw=False), \
|
||||
"'cube_idx' should be a tensor of shape (num_cubes, 8)"
|
||||
num_cubes = cube_idx.shape[0]
|
||||
assert beta is None or (
|
||||
torch.is_tensor(beta) and
|
||||
check_tensor(beta, (num_cubes, 12), throw=False)
|
||||
), "'beta' should be a tensor of shape (num_cubes, 12)"
|
||||
assert alpha is None or (
|
||||
torch.is_tensor(alpha) and
|
||||
check_tensor(alpha, (num_cubes, 8), throw=False)
|
||||
), "'alpha' should be a tensor of shape (num_cubes, 8)"
|
||||
assert gamma_f is None or (
|
||||
torch.is_tensor(gamma_f) and
|
||||
check_tensor(gamma_f, (num_cubes,), throw=False)
|
||||
), "'gamma_f' should be a tensor of shape (num_cubes,)"
|
||||
|
||||
surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx)
|
||||
if surf_cubes.sum() == 0:
|
||||
return (
|
||||
torch.zeros((0, 3), device=self.device),
|
||||
torch.zeros((0, 3), dtype=torch.long, device=self.device),
|
||||
torch.zeros((0), device=self.device),
|
||||
torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None
|
||||
)
|
||||
beta, alpha, gamma_f = self._normalize_weights(
|
||||
beta, alpha, gamma_f, surf_cubes, weight_scale)
|
||||
|
||||
if voxelgrid_colors is not None:
|
||||
voxelgrid_colors = torch.sigmoid(voxelgrid_colors)
|
||||
|
||||
case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution)
|
||||
|
||||
surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
|
||||
scalar_field, cube_idx, surf_cubes
|
||||
)
|
||||
|
||||
vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd(
|
||||
voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field,
|
||||
case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors)
|
||||
vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate(
|
||||
scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map,
|
||||
vd_idx_map, surf_edges_mask, training, vd_color)
|
||||
return vertices, faces, L_dev, vertices_color
|
||||
|
||||
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
|
||||
"""
|
||||
Regularizer L_dev as in Equation 8
|
||||
"""
|
||||
dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
|
||||
mean_l2 = torch.zeros_like(vd[:, 0])
|
||||
mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
|
||||
mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
|
||||
return mad
|
||||
|
||||
def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale):
|
||||
"""
|
||||
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
|
||||
"""
|
||||
n_cubes = surf_cubes.shape[0]
|
||||
|
||||
if beta is not None:
|
||||
beta = (torch.tanh(beta) * weight_scale + 1)
|
||||
else:
|
||||
beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
|
||||
|
||||
if alpha is not None:
|
||||
alpha = (torch.tanh(alpha) * weight_scale + 1)
|
||||
else:
|
||||
alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
|
||||
|
||||
if gamma_f is not None:
|
||||
gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2
|
||||
else:
|
||||
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
|
||||
|
||||
return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes]
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_case_id(self, occ_fx8, surf_cubes, res):
|
||||
"""
|
||||
Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
|
||||
ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
|
||||
supplementary material. It should be noted that this function assumes a regular grid.
|
||||
"""
|
||||
case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
|
||||
|
||||
problem_config = self.check_table.to(self.device)[case_ids]
|
||||
to_check = problem_config[..., 0] == 1
|
||||
problem_config = problem_config[to_check]
|
||||
if not isinstance(res, (list, tuple)):
|
||||
res = [res, res, res]
|
||||
|
||||
# The 'problem_config' tensor only contains configurations for surface cubes. Next, we construct a 3D array,
|
||||
# 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
|
||||
# This allows efficient checking on adjacent cubes.
|
||||
problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
|
||||
vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
|
||||
vol_idx_problem = vol_idx[surf_cubes][to_check]
|
||||
problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
|
||||
vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
|
||||
|
||||
within_range = (
|
||||
vol_idx_problem_adj[..., 0] >= 0) & (
|
||||
vol_idx_problem_adj[..., 0] < res[0]) & (
|
||||
vol_idx_problem_adj[..., 1] >= 0) & (
|
||||
vol_idx_problem_adj[..., 1] < res[1]) & (
|
||||
vol_idx_problem_adj[..., 2] >= 0) & (
|
||||
vol_idx_problem_adj[..., 2] < res[2])
|
||||
|
||||
vol_idx_problem = vol_idx_problem[within_range]
|
||||
vol_idx_problem_adj = vol_idx_problem_adj[within_range]
|
||||
problem_config = problem_config[within_range]
|
||||
problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
|
||||
vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
|
||||
# If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
|
||||
to_invert = (problem_config_adj[..., 0] == 1)
|
||||
idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
|
||||
case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
|
||||
return case_ids
|
||||
|
||||
@torch.no_grad()
|
||||
def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes):
|
||||
"""
|
||||
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
|
||||
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
|
||||
and marks the cube edges with this index.
|
||||
"""
|
||||
occ_n = scalar_field < 0
|
||||
all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2)
|
||||
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
|
||||
|
||||
unique_edges = unique_edges.long()
|
||||
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
|
||||
|
||||
surf_edges_mask = mask_edges[_idx_map]
|
||||
counts = counts[_idx_map]
|
||||
|
||||
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1
|
||||
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device)
|
||||
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
|
||||
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
|
||||
idx_map = mapping[_idx_map]
|
||||
surf_edges = unique_edges[mask_edges]
|
||||
return surf_edges, idx_map, counts, surf_edges_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def _identify_surf_cubes(self, scalar_field, cube_idx):
|
||||
"""
|
||||
Identifies grid cubes that intersect with the underlying surface by checking if the signs at
|
||||
all corners are not identical.
|
||||
"""
|
||||
occ_n = scalar_field < 0
|
||||
occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8)
|
||||
_occ_sum = torch.sum(occ_fx8, -1)
|
||||
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
|
||||
return surf_cubes, occ_fx8
|
||||
|
||||
def _linear_interp(self, edges_weight, edges_x):
|
||||
"""
|
||||
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
|
||||
"""
|
||||
edge_dim = edges_weight.dim() - 2
|
||||
assert edges_weight.shape[edge_dim] == 2
|
||||
edges_weight = torch.cat([
torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim),
-torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)
], edge_dim)
|
||||
denominator = edges_weight.sum(edge_dim)
|
||||
ue = (edges_x * edges_weight).sum(edge_dim) / denominator
|
||||
return ue
|
||||
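# Added note (illustrative): for endpoint values (s0, s1) and positions (x0, x1)
# this evaluates (s1 * x0 - s0 * x1) / (s1 - s0), i.e. the standard linear
# zero crossing x0 + s0 / (s0 - s1) * (x1 - x0).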
|
||||
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale):
|
||||
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
|
||||
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
|
||||
c_bx3 = c_bx3.reshape(-1, 3)
|
||||
A = norm_bxnx3
|
||||
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
|
||||
|
||||
A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
|
||||
B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1)
|
||||
A = torch.cat([A, A_reg], 1)
|
||||
B = torch.cat([B, B_reg], 1)
|
||||
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
|
||||
return dual_verts
|
||||
|
||||
def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field,
|
||||
case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors):
|
||||
"""
|
||||
Computes the location of dual vertices as described in Section 4.2
|
||||
"""
|
||||
alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
|
||||
surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
|
||||
surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
|
||||
zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
|
||||
|
||||
if voxelgrid_colors is not None:
|
||||
C = voxelgrid_colors.shape[-1]
|
||||
surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C)
|
||||
|
||||
idx_map = idx_map.reshape(-1, 12)
|
||||
num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
|
||||
edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
|
||||
|
||||
# if color is not None:
|
||||
# vd_color = []
|
||||
|
||||
total_num_vd = 0
|
||||
vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
|
||||
|
||||
for num in torch.unique(num_vd):
|
||||
cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
|
||||
curr_num_vd = cur_cubes.sum() * num
|
||||
curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
|
||||
curr_edge_group_to_vd = torch.arange(
|
||||
curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
|
||||
total_num_vd += curr_num_vd
|
||||
curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
|
||||
cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
|
||||
|
||||
curr_mask = (curr_edge_group != -1)
|
||||
edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
|
||||
edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
|
||||
edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
|
||||
vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
|
||||
vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
|
||||
# if color is not None:
|
||||
# vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3))
|
||||
|
||||
edge_group = torch.cat(edge_group)
|
||||
edge_group_to_vd = torch.cat(edge_group_to_vd)
|
||||
edge_group_to_cube = torch.cat(edge_group_to_cube)
|
||||
vd_num_edges = torch.cat(vd_num_edges)
|
||||
vd_gamma = torch.cat(vd_gamma)
|
||||
# if color is not None:
|
||||
# vd_color = torch.cat(vd_color)
|
||||
# else:
|
||||
# vd_color = None
|
||||
|
||||
vd = torch.zeros((total_num_vd, 3), device=self.device)
|
||||
beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
|
||||
|
||||
idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
|
||||
|
||||
x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
|
||||
s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
|
||||
|
||||
|
||||
zero_crossing_group = torch.index_select(
|
||||
input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
|
||||
|
||||
alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
|
||||
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
|
||||
ue_group = self._linear_interp(s_group * alpha_group, x_group)
|
||||
|
||||
beta_group = torch.gather(input=beta.reshape(-1), dim=0,
|
||||
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
|
||||
beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
|
||||
vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
|
||||
|
||||
'''
|
||||
Interpolate colors using the same method as for the dual vertices.
|
||||
'''
|
||||
if voxelgrid_colors is not None:
|
||||
vd_color = torch.zeros((total_num_vd, C), device=self.device)
|
||||
c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C)
|
||||
uc_group = self._linear_interp(s_group * alpha_group, c_group)
|
||||
vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum
|
||||
else:
|
||||
vd_color = None
|
||||
|
||||
L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
|
||||
|
||||
v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
|
||||
|
||||
vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
|
||||
12 + edge_group, src=v_idx[edge_group_to_vd])
|
||||
|
||||
return vd, L_dev, vd_gamma, vd_idx_map, vd_color
|
||||
|
||||
def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color):
|
||||
"""
|
||||
Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
|
||||
triangles based on the gamma parameter, as described in Section 4.3.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
|
||||
group = idx_map.reshape(-1)[group_mask]
|
||||
vd_idx = vd_idx_map[group_mask]
|
||||
edge_indices, indices = torch.sort(group, stable=True)
|
||||
quad_vd_idx = vd_idx[indices].reshape(-1, 4)
|
||||
|
||||
# Ensure all face directions point towards the positive SDF to maintain consistent winding.
|
||||
s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
|
||||
flip_mask = s_edges[:, 0] > 0
|
||||
quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
|
||||
quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
|
||||
|
||||
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
|
||||
gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2]
|
||||
gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3]
|
||||
if not training:
|
||||
mask = (gamma_02 > gamma_13)
|
||||
faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
|
||||
faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
|
||||
faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
|
||||
faces = faces.reshape(-1, 3)
|
||||
else:
|
||||
vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
|
||||
vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2
|
||||
vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2
|
||||
weight_sum = (gamma_02 + gamma_13) + 1e-8
|
||||
vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
|
||||
|
||||
if vd_color is not None:
|
||||
color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1])
|
||||
color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2
|
||||
color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2
|
||||
color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
|
||||
vd_color = torch.cat([vd_color, color_center])
|
||||
|
||||
|
||||
vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
|
||||
vd = torch.cat([vd, vd_center])
|
||||
faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
|
||||
faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
|
||||
return vd, faces, s_edges, edge_indices, vd_color
|
||||
353
trellis/trainers/flow_matching/flow_matching.py
Normal file
@@ -0,0 +1,353 @@
|
||||
from typing import *
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
from ..basic import BasicTrainer
|
||||
from ...pipelines import samplers
|
||||
from ...utils.general_utils import dict_reduce
|
||||
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
|
||||
from .mixins.text_conditioned import TextConditionedMixin
|
||||
from .mixins.image_conditioned import ImageConditionedMixin
|
||||
|
||||
|
||||
class FlowMatchingTrainer(BasicTrainer):
|
||||
"""
|
||||
Trainer for diffusion model with flow matching objective.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold an inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
t_schedule (dict): Time schedule for flow matching.
|
||||
sigma_min (float): Minimum noise level.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
t_schedule: dict = {
|
||||
'name': 'logitNormal',
|
||||
'args': {
|
||||
'mean': 0.0,
|
||||
'std': 1.0,
|
||||
}
|
||||
},
|
||||
sigma_min: float = 1e-5,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.t_schedule = t_schedule
|
||||
self.sigma_min = sigma_min
|
||||
|
||||
def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Diffuse the data for a given number of diffusion steps.
|
||||
In other words, sample from q(x_t | x_0).
|
||||
|
||||
Args:
|
||||
x_0: The [N x C x ...] tensor of noiseless inputs.
|
||||
t: The [N] tensor of diffusion steps [0-1].
|
||||
noise: If specified, use this noise instead of generating new noise.
|
||||
|
||||
Returns:
|
||||
x_t, the noisy version of x_0 under timestep t.
|
||||
"""
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x_0)
|
||||
assert noise.shape == x_0.shape, "noise must have same shape as x_0"
|
||||
|
||||
t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
|
||||
x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
|
||||
|
||||
return x_t
|
||||
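# Added sketch (assumed usage, not part of the trainer): diffuse/reverse_diffuse
# form a round trip for t < 1 when the same noise is reused, e.g.
#   x_t = trainer.diffuse(x_0, t, noise=noise)
#   x_0_rec = trainer.reverse_diffuse(x_t, t, noise)   # ~= x_0 up to numerical error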
|
||||
def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Recover the original data x_0 from its noisy version x_t under timestep t, given the noise.
|
||||
"""
|
||||
assert noise.shape == x_t.shape, "noise must have same shape as x_t"
|
||||
t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
|
||||
x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
|
||||
return x_0
|
||||
|
||||
def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute the velocity of the diffusion process at time t.
|
||||
"""
|
||||
return (1 - self.sigma_min) * noise - x_0
|
||||
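# Added note (illustrative): differentiating the interpolation in `diffuse`,
# x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise, with respect
# to t gives d(x_t)/dt = (1 - sigma_min) * noise - x_0, which is the target
# velocity returned here.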
|
||||
def get_cond(self, cond, **kwargs):
|
||||
"""
|
||||
Get the conditioning data.
|
||||
"""
|
||||
return cond
|
||||
|
||||
def get_inference_cond(self, cond, **kwargs):
|
||||
"""
|
||||
Get the conditioning data for inference.
|
||||
"""
|
||||
return {'cond': cond, **kwargs}
|
||||
|
||||
def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
|
||||
"""
|
||||
Get the sampler for the diffusion process.
|
||||
"""
|
||||
return samplers.FlowEulerSampler(self.sigma_min)
|
||||
|
||||
def vis_cond(self, **kwargs):
|
||||
"""
|
||||
Visualize the conditioning data.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def sample_t(self, batch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Sample timesteps.
|
||||
"""
|
||||
if self.t_schedule['name'] == 'uniform':
|
||||
t = torch.rand(batch_size)
|
||||
elif self.t_schedule['name'] == 'logitNormal':
|
||||
mean = self.t_schedule['args']['mean']
|
||||
std = self.t_schedule['args']['std']
|
||||
t = torch.sigmoid(torch.randn(batch_size) * std + mean)
|
||||
else:
|
||||
raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
|
||||
return t
|
||||
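# Added note (illustrative): with the default logitNormal schedule (mean=0.0,
# std=1.0) this draws t = sigmoid(z) with z ~ N(0, 1), which concentrates
# timesteps around t = 0.5 while still covering (0, 1).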
|
||||
def training_losses(
|
||||
self,
|
||||
x_0: torch.Tensor,
|
||||
cond=None,
|
||||
**kwargs
|
||||
) -> Tuple[Dict, Dict]:
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
|
||||
Args:
|
||||
x_0: The [N x C x ...] tensor of noiseless inputs.
|
||||
cond: The [N x ...] tensor of additional conditions.
|
||||
kwargs: Additional arguments to pass to the backbone.
|
||||
|
||||
Returns:
|
||||
a dict with the key "loss" containing a tensor of shape [N].
|
||||
May also contain other keys for different terms.
|
||||
"""
|
||||
noise = torch.randn_like(x_0)
|
||||
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
|
||||
x_t = self.diffuse(x_0, t, noise=noise)
|
||||
cond = self.get_cond(cond, **kwargs)
|
||||
|
||||
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
|
||||
assert pred.shape == noise.shape == x_0.shape
|
||||
target = self.get_v(x_0, noise, t)
|
||||
terms = edict()
|
||||
terms["mse"] = F.mse_loss(pred, target)
|
||||
terms["loss"] = terms["mse"]
|
||||
|
||||
# log loss with time bins
|
||||
mse_per_instance = np.array([
|
||||
F.mse_loss(pred[i], target[i]).item()
|
||||
for i in range(x_0.shape[0])
|
||||
])
|
||||
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
|
||||
for i in range(10):
|
||||
if (time_bin == i).sum() != 0:
|
||||
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
|
||||
|
||||
return terms, {}
|
||||
|
||||
@torch.no_grad()
|
||||
def run_snapshot(
|
||||
self,
|
||||
num_samples: int,
|
||||
batch_size: int,
|
||||
verbose: bool = False,
|
||||
) -> Dict:
|
||||
dataloader = DataLoader(
|
||||
copy.deepcopy(self.dataset),
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||
)
|
||||
|
||||
# inference
|
||||
sampler = self.get_sampler()
|
||||
sample_gt = []
|
||||
sample = []
|
||||
cond_vis = []
|
||||
for i in range(0, num_samples, batch_size):
|
||||
batch = min(batch_size, num_samples - i)
|
||||
data = next(iter(dataloader))
|
||||
data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
|
||||
noise = torch.randn_like(data['x_0'])
|
||||
sample_gt.append(data['x_0'])
|
||||
cond_vis.append(self.vis_cond(**data))
|
||||
del data['x_0']
|
||||
args = self.get_inference_cond(**data)
|
||||
res = sampler.sample(
|
||||
self.models['denoiser'],
|
||||
noise=noise,
|
||||
**args,
|
||||
steps=50, cfg_strength=3.0, verbose=verbose,
|
||||
)
|
||||
sample.append(res.samples)
|
||||
|
||||
sample_gt = torch.cat(sample_gt, dim=0)
|
||||
sample = torch.cat(sample, dim=0)
|
||||
sample_dict = {
|
||||
'sample_gt': {'value': sample_gt, 'type': 'sample'},
|
||||
'sample': {'value': sample, 'type': 'sample'},
|
||||
}
|
||||
sample_dict.update(dict_reduce(cond_vis, None, {
|
||||
'value': lambda x: torch.cat(x, dim=0),
|
||||
'type': lambda x: x[0],
|
||||
}))
|
||||
|
||||
return sample_dict
|
||||
|
||||
|
||||
class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
|
||||
"""
|
||||
Trainer for diffusion model with flow matching objective and classifier-free guidance.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold an inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
t_schedule (dict): Time schedule for flow matching.
|
||||
sigma_min (float): Minimum noise level.
|
||||
p_uncond (float): Probability of dropping conditions.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
|
||||
"""
|
||||
Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold an inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
t_schedule (dict): Time schedule for flow matching.
|
||||
sigma_min (float): Minimum noise level.
|
||||
p_uncond (float): Probability of dropping conditions.
|
||||
text_cond_model (str): Text conditioning model.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
|
||||
"""
|
||||
Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
|
||||
|
||||
Args:
|
||||
models (dict[str, nn.Module]): Models to train.
|
||||
dataset (torch.utils.data.Dataset): Dataset.
|
||||
output_dir (str): Output directory.
|
||||
load_dir (str): Load directory.
|
||||
step (int): Step to load.
|
||||
batch_size (int): Batch size.
|
||||
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||
batch_split (int): Split batch with gradient accumulation.
|
||||
max_steps (int): Max steps.
|
||||
optimizer (dict): Optimizer config.
|
||||
lr_scheduler (dict): Learning rate scheduler config.
|
||||
elastic (dict): Elastic memory management config.
|
||||
grad_clip (float or dict): Gradient clip config.
|
||||
ema_rate (float or list): Exponential moving average rates.
|
||||
fp16_mode (str): FP16 mode.
|
||||
- None: No FP16.
|
||||
- 'inflat_all': Hold an inflated fp32 master param for all params.
|
||||
- 'amp': Automatic mixed precision.
|
||||
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||
finetune_ckpt (dict): Finetune checkpoint.
|
||||
log_param_stats (bool): Log parameter stats.
|
||||
i_print (int): Print interval.
|
||||
i_log (int): Log interval.
|
||||
i_sample (int): Sample interval.
|
||||
i_save (int): Save interval.
|
||||
i_ddpcheck (int): DDP check interval.
|
||||
|
||||
t_schedule (dict): Time schedule for flow matching.
|
||||
sigma_min (float): Minimum noise level.
|
||||
p_uncond (float): Probability of dropping conditions.
|
||||
image_cond_model (str): Image conditioning model.
|
||||
"""
|
||||
pass
|
||||
93
trellis/utils/dist_utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import os
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
|
||||
def setup_dist(rank, local_rank, world_size, master_addr, master_port):
|
||||
os.environ['MASTER_ADDR'] = master_addr
|
||||
os.environ['MASTER_PORT'] = master_port
|
||||
os.environ['WORLD_SIZE'] = str(world_size)
|
||||
os.environ['RANK'] = str(rank)
|
||||
os.environ['LOCAL_RANK'] = str(local_rank)
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group('nccl', rank=rank, world_size=world_size)
|
||||
|
||||
|
||||
def read_file_dist(path):
|
||||
"""
|
||||
Read the binary file distributedly.
|
||||
The file is read only once, by the rank 0 process, and broadcast to the other processes.
|
||||
|
||||
Returns:
|
||||
data (io.BytesIO): The binary data read from the file.
|
||||
"""
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
# read file
|
||||
size = torch.LongTensor(1).cuda()
|
||||
if dist.get_rank() == 0:
|
||||
with open(path, 'rb') as f:
|
||||
data = f.read()
|
||||
data = torch.ByteTensor(
|
||||
torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
|
||||
).cuda()
|
||||
size[0] = data.shape[0]
|
||||
# broadcast size
|
||||
dist.broadcast(size, src=0)
|
||||
if dist.get_rank() != 0:
|
||||
data = torch.ByteTensor(size[0].item()).cuda()
|
||||
# broadcast data
|
||||
dist.broadcast(data, src=0)
|
||||
# convert to io.BytesIO
|
||||
data = data.cpu().numpy().tobytes()
|
||||
data = io.BytesIO(data)
|
||||
return data
|
||||
else:
|
||||
with open(path, 'rb') as f:
|
||||
data = f.read()
|
||||
data = io.BytesIO(data)
|
||||
return data
|
||||
|
||||
|
||||
def unwrap_dist(model):
|
||||
"""
|
||||
Unwrap the model from distributed training.
|
||||
"""
|
||||
if isinstance(model, DDP):
|
||||
return model.module
|
||||
return model
|
||||
|
||||
|
||||
@contextmanager
|
||||
def master_first():
|
||||
"""
|
||||
A context manager that ensures the master process executes first.
|
||||
"""
|
||||
if not dist.is_initialized():
|
||||
yield
|
||||
else:
|
||||
if dist.get_rank() == 0:
|
||||
yield
|
||||
dist.barrier()
|
||||
else:
|
||||
dist.barrier()
|
||||
yield
|
||||
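# Added sketch (assumed usage): typically wraps a cached download or one-off
# preprocessing so that rank 0 finishes before the other ranks proceed, e.g.
#   with master_first():
#       prepare_cache(path)   # hypothetical helper; rank 0 runs it first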
|
||||
|
||||
@contextmanager
|
||||
def local_master_first():
|
||||
"""
|
||||
A context manager that ensures the local master process executes first.
|
||||
"""
|
||||
if not dist.is_initialized():
|
||||
yield
|
||||
else:
|
||||
if dist.get_rank() % torch.cuda.device_count() == 0:
|
||||
yield
|
||||
dist.barrier()
|
||||
else:
|
||||
dist.barrier()
|
||||
yield
|
||||
|
||||
228
trellis/utils/elastic_utils.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MemoryController:
|
||||
"""
|
||||
Base class for memory management during training.
|
||||
"""
|
||||
|
||||
_last_input_size = None
|
||||
_last_mem_ratio = []
|
||||
|
||||
@contextmanager
|
||||
def record(self):
|
||||
pass
|
||||
|
||||
def update_run_states(self, input_size=None, mem_ratio=None):
|
||||
if self._last_input_size is None:
|
||||
self._last_input_size = input_size
|
||||
elif self._last_input_size != input_size:
|
||||
raise ValueError('Input size should not change for different ElasticModules.')
|
||||
self._last_mem_ratio.append(mem_ratio)
|
||||
|
||||
@abstractmethod
|
||||
def get_mem_ratio(self, input_size):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def state_dict(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def log(self):
|
||||
pass
|
||||
|
||||
|
||||
class LinearMemoryController(MemoryController):
|
||||
"""
|
||||
A simple controller for memory management during training.
|
||||
The memory usage is modeled as a linear function of:
|
||||
- the number of input parameters
|
||||
- the ratio of memory the model uses compared to the maximum usage (with no checkpointing)
|
||||
memory_usage = k * input_size * mem_ratio + b
|
||||
The controller keeps track of the memory usage and gives the
|
||||
expected memory ratio that keeps the memory usage under a target budget.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
buffer_size=1000,
|
||||
update_every=500,
|
||||
target_ratio=0.8,
|
||||
available_memory=None,
|
||||
max_mem_ratio_start=0.1,
|
||||
params=None,
|
||||
device=None
|
||||
):
|
||||
self.buffer_size = buffer_size
|
||||
self.update_every = update_every
|
||||
self.target_ratio = target_ratio
|
||||
self.device = device or torch.cuda.current_device()
|
||||
self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3
|
||||
|
||||
self._memory = np.zeros(buffer_size, dtype=np.float32)
|
||||
self._input_size = np.zeros(buffer_size, dtype=np.float32)
|
||||
self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
|
||||
self._buffer_ptr = 0
|
||||
self._buffer_length = 0
|
||||
self._params = tuple(params) if params is not None else (0.0, 0.0)
|
||||
self._max_mem_ratio = max_mem_ratio_start
|
||||
self.step = 0
|
||||
|
||||
def __repr__(self):
|
||||
return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'
|
||||
|
||||
def _add_sample(self, memory, input_size, mem_ratio):
|
||||
self._memory[self._buffer_ptr] = memory
|
||||
self._input_size[self._buffer_ptr] = input_size
|
||||
self._mem_ratio[self._buffer_ptr] = mem_ratio
|
||||
self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
|
||||
self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
|
||||
|
||||
@contextmanager
|
||||
def record(self):
|
||||
torch.cuda.reset_peak_memory_stats(self.device)
|
||||
self._last_input_size = None
|
||||
self._last_mem_ratio = []
|
||||
yield
|
||||
self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
|
||||
self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
|
||||
self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
|
||||
self.step += 1
|
||||
if self.step % self.update_every == 0:
|
||||
self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
|
||||
self._fit_params()
|
||||
|
||||
def _fit_params(self):
|
||||
memory_usage = self._memory[:self._buffer_length]
|
||||
input_size = self._input_size[:self._buffer_length]
|
||||
mem_ratio = self._mem_ratio[:self._buffer_length]
|
||||
|
||||
x = input_size * mem_ratio
|
||||
y = memory_usage
|
||||
k, b = np.polyfit(x, y, 1)
|
||||
self._params = (k, b)
|
||||
# self._visualize()
|
||||
|
||||
def _visualize(self):
|
||||
import matplotlib.pyplot as plt
|
||||
memory_usage = self._memory[:self._buffer_length]
|
||||
input_size = self._input_size[:self._buffer_length]
|
||||
mem_ratio = self._mem_ratio[:self._buffer_length]
|
||||
k, b = self._params
|
||||
|
||||
plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
|
||||
x = np.array([0.0, 20000.0])
|
||||
plt.plot(x, k * x + b, c='r')
|
||||
plt.savefig(f'linear_memory_controller_{self.step}.png')
|
||||
plt.cla()
|
||||
|
||||
def get_mem_ratio(self, input_size):
|
||||
k, b = self._params
|
||||
if k == 0: return np.random.rand() * self._max_mem_ratio
|
||||
pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
|
||||
return min(self._max_mem_ratio, max(0.0, pred))
|
||||
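# Added note (illustrative): this inverts the linear model from the class
# docstring, solving k * input_size * mem_ratio + b = available_memory * target_ratio
# for mem_ratio, then clamps the result to [0, _max_mem_ratio].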
|
||||
def state_dict(self):
|
||||
return {
|
||||
'params': self._params,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._params = tuple(state_dict['params'])
|
||||
|
||||
def log(self):
|
||||
return {
|
||||
'params/k': self._params[0],
|
||||
'params/b': self._params[1],
|
||||
'memory': self._last_memory,
|
||||
'input_size': self._last_input_size,
|
||||
'mem_ratio': self._last_mem_ratio,
|
||||
}
|
||||
|
||||
|
||||
class ElasticModule(nn.Module):
|
||||
"""
|
||||
Module for training with elastic memory management.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._memory_controller: MemoryController = None
|
||||
|
||||
@abstractmethod
|
||||
def _get_input_size(self, *args, **kwargs) -> int:
|
||||
"""
|
||||
Get the size of the input data.
|
||||
|
||||
Returns:
|
||||
int: The size of the input data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
|
||||
"""
|
||||
Forward with a given memory ratio.
|
||||
"""
|
||||
pass
|
||||
|
||||
def register_memory_controller(self, memory_controller: MemoryController):
|
||||
self._memory_controller = memory_controller
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
|
||||
_, ret = self._forward_with_mem_ratio(*args, **kwargs)
|
||||
else:
|
||||
input_size = self._get_input_size(*args, **kwargs)
|
||||
mem_ratio = self._memory_controller.get_mem_ratio(input_size)
|
||||
mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
|
||||
self._memory_controller.update_run_states(input_size, mem_ratio)
|
||||
return ret
|
||||
|
||||
|
||||
class ElasticModuleMixin:
|
||||
"""
|
||||
Mixin for training with elastic memory management.
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._memory_controller: MemoryController = None
|
||||
|
||||
@abstractmethod
|
||||
def _get_input_size(self, *args, **kwargs) -> int:
|
||||
"""
|
||||
Get the size of the input data.
|
||||
|
||||
Returns:
|
||||
int: The size of the input data.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@contextmanager
|
||||
def with_mem_ratio(self, mem_ratio=1.0) -> float:
|
||||
"""
|
||||
Context manager for training with a reduced memory ratio compared to the full memory usage.
|
||||
|
||||
Returns:
|
||||
float: The exact memory ratio used during the forward pass.
|
||||
"""
|
||||
pass
|
||||
|
||||
def register_memory_controller(self, memory_controller: MemoryController):
|
||||
self._memory_controller = memory_controller
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
|
||||
ret = super().forward(*args, **kwargs)
|
||||
else:
|
||||
input_size = self._get_input_size(*args, **kwargs)
|
||||
mem_ratio = self._memory_controller.get_mem_ratio(input_size)
|
||||
with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
|
||||
ret = super().forward(*args, **kwargs)
|
||||
self._memory_controller.update_run_states(input_size, exact_mem_ratio)
|
||||
return ret
|
||||
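# Added sketch (assumed, not part of the file): a subclass mixes this into a
# model and maps mem_ratio to how much activation checkpointing to use, e.g.
#   class ElasticBackbone(ElasticModuleMixin, Backbone):   # hypothetical Backbone
#       def _get_input_size(self, x, *args, **kwargs):
#           return x.shape[0]
#       @contextmanager
#       def with_mem_ratio(self, mem_ratio=1.0):
#           # checkpoint a fraction of blocks so that roughly mem_ratio of the
#           # full activation memory is kept; yield the ratio actually achieved
#           ...
#           yield mem_ratio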
587
trellis/utils/postprocessing_utils.py
Normal file
@@ -0,0 +1,587 @@
|
||||
from typing import *
|
||||
import numpy as np
|
||||
import torch
|
||||
import utils3d
|
||||
import nvdiffrast.torch as dr
|
||||
from tqdm import tqdm
|
||||
import trimesh
|
||||
import trimesh.visual
|
||||
import xatlas
|
||||
import pyvista as pv
|
||||
from pymeshfix import _meshfix
|
||||
import igraph
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from .random_utils import sphere_hammersley_sequence
|
||||
from .render_utils import render_multiview
|
||||
from ..renderers import GaussianRenderer
|
||||
from ..representations import Strivec, Gaussian, MeshExtractResult
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _fill_holes(
|
||||
verts,
|
||||
faces,
|
||||
max_hole_size=0.04,
|
||||
max_hole_nbe=32,
|
||||
resolution=128,
|
||||
num_views=500,
|
||||
debug=False,
|
||||
verbose=False
|
||||
):
|
||||
"""
|
||||
Rasterize a mesh from multiple views and remove invisible faces.
|
||||
Also includes postprocessing to:
|
||||
1. Remove connected components that have low visibility.
|
||||
2. Apply a mincut to remove faces on the inner side of the mesh that are connected to the outer side only through a small hole.
|
||||
|
||||
Args:
|
||||
verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
|
||||
faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
|
||||
max_hole_size (float): Maximum area of a hole to fill.
|
||||
resolution (int): Resolution of the rasterization.
|
||||
num_views (int): Number of views to rasterize the mesh.
|
||||
verbose (bool): Whether to print progress.
|
||||
"""
|
||||
# Construct cameras
|
||||
yaws = []
|
||||
pitchs = []
|
||||
for i in range(num_views):
|
||||
y, p = sphere_hammersley_sequence(i, num_views)
|
||||
yaws.append(y)
|
||||
pitchs.append(p)
|
||||
yaws = torch.tensor(yaws).cuda()
|
||||
pitchs = torch.tensor(pitchs).cuda()
|
||||
radius = 2.0
|
||||
fov = torch.deg2rad(torch.tensor(40)).cuda()
|
||||
projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
|
||||
views = []
|
||||
for (yaw, pitch) in zip(yaws, pitchs):
|
||||
orig = torch.tensor([
|
||||
torch.sin(yaw) * torch.cos(pitch),
|
||||
torch.cos(yaw) * torch.cos(pitch),
|
||||
torch.sin(pitch),
|
||||
]).cuda().float() * radius
|
||||
view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
||||
views.append(view)
|
||||
views = torch.stack(views, dim=0)
|
||||
|
||||
# Rasterize
|
||||
visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
|
||||
rastctx = utils3d.torch.RastContext(backend='cuda')
|
||||
for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
|
||||
view = views[i]
|
||||
buffers = utils3d.torch.rasterize_triangle_faces(
|
||||
rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
|
||||
)
|
||||
face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
|
||||
face_id = torch.unique(face_id).long()
|
||||
visblity[face_id] += 1
|
||||
visblity = visblity.float() / num_views
|
||||
|
||||
# Mincut
|
||||
## construct outer faces
|
||||
edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
|
||||
boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
|
||||
connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
|
||||
outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
|
||||
for i in range(len(connected_components)):
|
||||
outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
|
||||
outer_face_indices = outer_face_indices.nonzero().reshape(-1)
|
||||
|
||||
## construct inner faces
|
||||
inner_face_indices = torch.nonzero(visblity == 0).reshape(-1)
|
||||
if verbose:
|
||||
tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
|
||||
if inner_face_indices.shape[0] == 0:
|
||||
return verts, faces
|
||||
|
||||
## Construct dual graph (faces as nodes, edges as edges)
|
||||
dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
|
||||
dual_edge2edge = edges[dual_edge2edge]
|
||||
dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
|
||||
if verbose:
|
||||
tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
|
||||
|
||||
## solve mincut problem
|
||||
### construct main graph
|
||||
g = igraph.Graph()
|
||||
g.add_vertices(faces.shape[0])
|
||||
g.add_edges(dual_edges.cpu().numpy())
|
||||
g.es['weight'] = dual_edges_weights.cpu().numpy()
|
||||
|
||||
### source and target
|
||||
g.add_vertex('s')
|
||||
g.add_vertex('t')
|
||||
|
||||
### connect invisible faces to source
|
||||
g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
|
||||
|
||||
### connect outer faces to target
|
||||
g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
|
||||
|
||||
### solve mincut
|
||||
cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
|
||||
remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
|
||||
if verbose:
|
||||
tqdm.write(f'Mincut solved, start checking the cut')
|
||||
|
||||
### check if the cut is valid with each connected component
|
||||
to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
|
||||
if debug:
|
||||
tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
|
||||
valid_remove_cc = []
|
||||
cutting_edges = []
|
||||
for cc in to_remove_cc:
|
||||
#### check if the connected component has low visibility
|
||||
visblity_median = visblity[remove_face_indices[cc]].median()
|
||||
if debug:
|
||||
tqdm.write(f'visblity_median: {visblity_median}')
|
||||
if visblity_median > 0.25:
|
||||
continue
|
||||
|
||||
#### check if the cutting loop is small enough
|
||||
cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
|
||||
cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
|
||||
cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
|
||||
if len(cc_new_boundary_edge_indices) > 0:
|
||||
cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
|
||||
cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
|
||||
cc_new_boundary_edges_cc_area = []
|
||||
for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
|
||||
_e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
|
||||
_e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
|
||||
cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
|
||||
if debug:
|
||||
cutting_edges.append(cc_new_boundary_edge_indices)
|
||||
tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
|
||||
if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
|
||||
continue
|
||||
|
||||
valid_remove_cc.append(cc)
|
||||
|
||||
if debug:
|
||||
face_v = verts[faces].mean(dim=1).cpu().numpy()
|
||||
vis_dual_edges = dual_edges.cpu().numpy()
|
||||
vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
|
||||
vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
|
||||
vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
|
||||
vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
|
||||
if len(valid_remove_cc) > 0:
|
||||
vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
|
||||
utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
|
||||
|
||||
vis_verts = verts.cpu().numpy()
|
||||
vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
|
||||
utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
|
||||
|
||||
|
||||
if len(valid_remove_cc) > 0:
|
||||
remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
|
||||
mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
|
||||
mask[remove_face_indices] = 0
|
||||
faces = faces[mask]
|
||||
faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
|
||||
if verbose:
|
||||
tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
|
||||
else:
|
||||
if verbose:
|
||||
tqdm.write(f'Removed 0 faces by mincut')
|
||||
|
||||
mesh = _meshfix.PyTMesh()
|
||||
mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
|
||||
mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
|
||||
verts, faces = mesh.return_arrays()
|
||||
verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
|
||||
|
||||
return verts, faces
|
||||
|
||||
|
||||
def postprocess_mesh(
    vertices: np.array,
    faces: np.array,
    simplify: bool = True,
    simplify_ratio: float = 0.9,
    fill_holes: bool = True,
    fill_holes_max_hole_size: float = 0.04,
    fill_holes_max_hole_nbe: int = 32,
    fill_holes_resolution: int = 1024,
    fill_holes_num_views: int = 1000,
    debug: bool = False,
    verbose: bool = False,
):
    """
    Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
        simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
        simplify_ratio (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
        fill_holes_resolution (int): Resolution of the rasterization.
        fill_holes_num_views (int): Number of views to rasterize the mesh.
        debug (bool): Whether to write debug visualizations.
        verbose (bool): Whether to print progress.
    """

    if verbose:
        tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Simplify
    if simplify and simplify_ratio > 0:
        mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
        if verbose:
            tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Remove invisible faces
    if fill_holes:
        vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
        vertices, faces = _fill_holes(
            vertices, faces,
            max_hole_size=fill_holes_max_hole_size,
            max_hole_nbe=fill_holes_max_hole_nbe,
            resolution=fill_holes_resolution,
            num_views=fill_holes_num_views,
            debug=debug,
            verbose=verbose,
        )
        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
        if verbose:
            tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    return vertices, faces
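# A minimal usage sketch for postprocess_mesh (the names `raw_vertices` / `raw_faces`
# are hypothetical; they stand for the float (V, 3) and int (F, 3) numpy arrays
# produced by an earlier mesh-extraction step):
#
#   new_vertices, new_faces = postprocess_mesh(
#       raw_vertices, raw_faces,
#       simplify=True, simplify_ratio=0.9,
#       fill_holes=True, verbose=True,
#   )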
def parametrize_mesh(vertices: np.array, faces: np.array):
    """
    Parametrize a mesh to a texture space, using xatlas.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
    """

    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
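    # Note: xatlas re-indexes the mesh. `vmapping` maps each new vertex back to its
    # source vertex (vertices get duplicated along UV seams), `indices` are the new
    # faces over the re-indexed vertices, and `uvs` are per-vertex UV coordinates.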
    vertices = vertices[vmapping]
    faces = indices

    return vertices, faces, uvs


def bake_texture(
    vertices: np.array,
    faces: np.array,
    uvs: np.array,
    observations: List[np.array],
    masks: List[np.array],
    extrinsics: List[np.array],
    intrinsics: List[np.array],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal['fast', 'opt'] = 'opt',
    lambda_tv: float = 1e-2,
    verbose: bool = False,
):
    """
    Bake texture to a mesh from multiple observations.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
        uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
        masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W).
        extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): Mode of texture baking.
        lambda_tv (float): Weight of total variation loss in optimization.
        verbose (bool): Whether to print progress.
    """
    vertices = torch.tensor(vertices).cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).cuda()
    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
    masks = [torch.tensor(m>0).bool().cuda() for m in masks]
    views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
    projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]

    if mode == 'fast':
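        # Fast mode: for every view, rasterize the mesh with its UVs, splat the
        # observed pixel colors into their nearest texel with scatter_add, and keep a
        # per-texel hit count. Texels are then averaged by their counts, and texels
        # no view has seen are filled with OpenCV's Telea inpainting.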
        texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
        texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
        rastctx = utils3d.torch.RastContext(backend='cuda')
        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                uv_map = rast['uv'][0].detach().flip(0)
                mask = rast['mask'][0].detach().bool() & masks[0]

            # nearest neighbor interpolation
            uv_map = (uv_map * texture_size).floor().long()
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
            texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))

        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)

        # inpaint
        mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    elif mode == 'opt':
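        # Optimization mode: the per-view UV rasterizations are precomputed once, then
        # the texture image is treated as a learnable parameter and fitted with Adam
        # against randomly sampled views using an L1 photometric loss plus a total
        # variation term (weighted by lambda_tv), with a cosine learning-rate schedule.
        # Texels outside the UV charts are inpainted at the end.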
        rastctx = utils3d.torch.RastContext(backend='cuda')
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
            _uv.append(rast['uv'].detach())
            _uv_dr.append(rast['uv_dr'].detach())

        texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
            return start_lr * (end_lr / start_lr) ** (step / total_steps)

        def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

        def tv_loss(texture):
            return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
                torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])

        total_steps = 2500
        with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                selected = np.random.randint(0, len(views))
                uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()
                # annealing
                optimizer.param_groups[0]['lr'] = cosine_annealing(optimizer, step, total_steps, 1e-2, 1e-5)
                pbar.set_postfix({'loss': loss.item()})
                pbar.update()
        texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
        )['mask'][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
    else:
        raise ValueError(f'Unknown mode: {mode}')

    return texture
def to_glb(
    app_rep: Union[Strivec, Gaussian],
    mesh: MeshExtractResult,
    simplify: float = 0.95,
    fill_holes: bool = True,
    fill_holes_max_size: float = 0.04,
    texture_size: int = 1024,
    debug: bool = False,
    verbose: bool = True,
) -> trimesh.Trimesh:
    """
    Convert a generated asset to a glb file.

    Args:
        app_rep (Union[Strivec, Gaussian]): Appearance representation.
        mesh (MeshExtractResult): Extracted mesh.
        simplify (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_size (float): Maximum area of a hole to fill.
        texture_size (int): Size of the texture.
        debug (bool): Whether to print debug information.
        verbose (bool): Whether to print progress.
    """
    vertices = mesh.vertices.cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    # mesh postprocess
    vertices, faces = postprocess_mesh(
        vertices, faces,
        simplify=simplify > 0,
        simplify_ratio=simplify,
        fill_holes=fill_holes,
        fill_holes_max_hole_size=fill_holes_max_size,
        fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)),
        fill_holes_resolution=1024,
        fill_holes_num_views=1000,
        debug=debug,
        verbose=verbose,
    )

    # parametrize mesh
    vertices, faces, uvs = parametrize_mesh(vertices, faces)

    # bake texture
    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
    masks = [np.any(observation > 0, axis=-1) for observation in observations]
    extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
    intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
    texture = bake_texture(
        vertices, faces, uvs,
        observations, masks, extrinsics, intrinsics,
        texture_size=texture_size, mode='opt',
        lambda_tv=0.01,
        verbose=verbose
    )
    texture = Image.fromarray(texture)

    # rotate mesh (from z-up to y-up)
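    # The matrix below maps (x, y, z) -> (x, z, -y), i.e. the z-up mesh is rotated so
    # that +Y becomes the up axis, matching the glTF convention.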
    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=texture,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
    )
    mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material))
    return mesh
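# A minimal usage sketch (hypothetical variable names: `gs` is a Gaussian appearance
# representation and `mesh_result` a MeshExtractResult, as produced earlier in the
# TRELLIS pipeline):
#
#   glb = to_glb(gs, mesh_result, simplify=0.95, texture_size=1024, verbose=True)
#   glb.export('sample.glb')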
def simplify_gs(
    gs: Gaussian,
    simplify: float = 0.95,
    verbose: bool = True,
):
    """
    Simplify 3D Gaussians.
    NOTE: this function is not used in the current implementation due to its unsatisfactory performance.

    Args:
        gs (Gaussian): 3D Gaussian.
        simplify (float): Ratio of Gaussians to remove in simplification.
        verbose (bool): Whether to print progress.
    """
    if simplify <= 0:
        return gs

    # simplify
    observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
    observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]

    # Following https://arxiv.org/pdf/2411.06019
    renderer = GaussianRenderer({
        "resolution": 1024,
        "near": 0.8,
        "far": 1.6,
        "ssaa": 1,
        "bg_color": (0,0,0),
    })
    new_gs = Gaussian(**gs.init_params)
    new_gs._features_dc = gs._features_dc.clone()
    new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
    new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
    new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
    new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
    new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())

    start_lr = [1e-4, 1e-3, 5e-3, 0.025]
    end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
    optimizer = torch.optim.Adam([
        {"params": new_gs._xyz, "lr": start_lr[0]},
        {"params": new_gs._rotation, "lr": start_lr[1]},
        {"params": new_gs._scaling, "lr": start_lr[2]},
        {"params": new_gs._opacity, "lr": start_lr[3]},
    ], lr=start_lr[0])

    def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
        return start_lr * (end_lr / start_lr) ** (step / total_steps)

    def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
        return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

    _zeta = new_gs.get_opacity.clone().detach().squeeze()
    _lambda = torch.zeros_like(_zeta)
    _delta = 1e-7
    _interval = 10
    num_target = int((1 - simplify) * _zeta.shape[0])
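    # The opacities are sparsified with an ADMM-style split, following the reference
    # above: every `_interval` steps `_zeta` is set to the dual-shifted opacities with
    # only the top `num_target` entries kept, `_lambda` accumulates the residual as a
    # dual variable, and the training loss adds a quadratic penalty (weighted by
    # `_delta`) pulling the opacities toward `_zeta`, so low-opacity Gaussians decay
    # and are pruned every 100 steps.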
    with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar:
        for i in range(2500):
            # prune
            if i % 100 == 0:
                mask = new_gs.get_opacity.squeeze() > 0.05
                mask = torch.nonzero(mask).squeeze()
                new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
                new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
                new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
                new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
                new_gs._features_dc = new_gs._features_dc[mask]
                new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
                _zeta = _zeta[mask]
                _lambda = _lambda[mask]
                # update optimizer state
                for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]):
                    stored_state = optimizer.state[param_group['params'][0]]
                    if 'exp_avg' in stored_state:
                        stored_state['exp_avg'] = stored_state['exp_avg'][mask]
                        stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
                    del optimizer.state[param_group['params'][0]]
                    param_group['params'][0] = new_param
                    optimizer.state[param_group['params'][0]] = stored_state

            opacity = new_gs.get_opacity.squeeze()

            # sparsify
            if i % _interval == 0:
                _zeta = _lambda + opacity.detach()
                if opacity.shape[0] > num_target:
                    index = _zeta.topk(num_target)[1]
                    _m = torch.ones_like(_zeta, dtype=torch.bool)
                    _m[index] = 0
                    _zeta[_m] = 0
                _lambda = _lambda + opacity.detach() - _zeta

            # sample a random view
            view_idx = np.random.randint(len(observations))
            observation = observations[view_idx]
            extrinsic = extrinsics[view_idx]
            intrinsic = intrinsics[view_idx]

            color = renderer.render(new_gs, extrinsic, intrinsic)['color']
            rgb_loss = torch.nn.functional.l1_loss(color, observation)
            loss = rgb_loss + \
                _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update lr
            for j in range(len(optimizer.param_groups)):
                optimizer.param_groups[j]['lr'] = cosine_annealing(optimizer, i, 2500, start_lr[j], end_lr[j])

            pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
            pbar.update()

    new_gs._xyz = new_gs._xyz.data
    new_gs._rotation = new_gs._rotation.data
    new_gs._scaling = new_gs._scaling.data
    new_gs._opacity = new_gs._opacity.data

    return new_gs