trellis/utils/dist_utils.py  (new file, 93 lines)
@@ -0,0 +1,93 @@
import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)


def read_file_dist(path):
    """
    Read a binary file in a distributed setting.
    The file is read only once, by the rank 0 process, and broadcast to the other processes.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # read file
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            with open(path, 'rb') as f:
                data = f.read()
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # broadcast size
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        # broadcast data
        dist.broadcast(data, src=0)
        # convert to io.BytesIO
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        with open(path, 'rb') as f:
            data = f.read()
        data = io.BytesIO(data)
        return data


def unwrap_dist(model):
    """
    Unwrap the model from distributed training.
    """
    if isinstance(model, DDP):
        return model.module
    return model


@contextmanager
def master_first():
    """
    A context manager that ensures the master process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield


@contextmanager
def local_master_first():
    """
    A context manager that ensures the local master process executes first.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() % torch.cuda.device_count() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield
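
A minimal launcher sketch showing how these helpers are typically wired together (not part of the commit): it assumes a torchrun-style environment that exports RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT, and the build_model() and prepare_dataset_cache() helpers are hypothetical placeholders.

import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from trellis.utils import dist_utils

def main():
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    dist_utils.setup_dist(
        rank, local_rank, world_size,
        os.environ.get('MASTER_ADDR', 'localhost'),
        os.environ.get('MASTER_PORT', '29500'),
    )

    # Only the local master prepares shared data; the other ranks wait at the barrier.
    with dist_utils.local_master_first():
        prepare_dataset_cache()  # hypothetical placeholder

    # The checkpoint is read from disk once (rank 0) and broadcast to every rank.
    ckpt_bytes = dist_utils.read_file_dist('checkpoint.pt')
    state_dict = torch.load(ckpt_bytes, map_location='cpu')

    model = build_model().cuda()  # hypothetical placeholder
    model.load_state_dict(state_dict)
    model = DDP(model, device_ids=[local_rank])

    # Strip the DDP wrapper before saving so the raw module's state_dict is stored.
    if rank == 0:
        torch.save(dist_utils.unwrap_dist(model).state_dict(), 'checkpoint_out.pt')

if __name__ == '__main__':
    main()
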
trellis/utils/elastic_utils.py  (new file, 228 lines)
@@ -0,0 +1,228 @@
from abc import abstractmethod
from contextlib import contextmanager
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np


class MemoryController:
    """
    Base class for memory management during training.
    """

    _last_input_size = None
    _last_mem_ratio = []

    @contextmanager
    def record(self):
        pass

    def update_run_states(self, input_size=None, mem_ratio=None):
        if self._last_input_size is None:
            self._last_input_size = input_size
        elif self._last_input_size != input_size:
            raise ValueError('Input size should not change for different ElasticModules.')
        self._last_mem_ratio.append(mem_ratio)

    @abstractmethod
    def get_mem_ratio(self, input_size):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def log(self):
        pass


class LinearMemoryController(MemoryController):
    """
    A simple controller for memory management during training.
    The memory usage is modeled as a linear function of:
        - the number of input parameters
        - the ratio of memory the model uses compared to the maximum usage (with no checkpointing)
            memory_usage = k * input_size * mem_ratio + b
    The controller keeps track of the memory usage and gives the
    expected memory ratio to keep the memory usage under a target.
    """
    def __init__(
        self,
        buffer_size=1000,
        update_every=500,
        target_ratio=0.8,
        available_memory=None,
        max_mem_ratio_start=0.1,
        params=None,
        device=None
    ):
        self.buffer_size = buffer_size
        self.update_every = update_every
        self.target_ratio = target_ratio
        self.device = device or torch.cuda.current_device()
        self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3

        self._memory = np.zeros(buffer_size, dtype=np.float32)
        self._input_size = np.zeros(buffer_size, dtype=np.float32)
        self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
        self._buffer_ptr = 0
        self._buffer_length = 0
        self._params = tuple(params) if params is not None else (0.0, 0.0)
        self._max_mem_ratio = max_mem_ratio_start
        self.step = 0

    def __repr__(self):
        return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'

    def _add_sample(self, memory, input_size, mem_ratio):
        self._memory[self._buffer_ptr] = memory
        self._input_size[self._buffer_ptr] = input_size
        self._mem_ratio[self._buffer_ptr] = mem_ratio
        self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
        self._buffer_length = min(self._buffer_length + 1, self.buffer_size)

    @contextmanager
    def record(self):
        torch.cuda.reset_peak_memory_stats(self.device)
        self._last_input_size = None
        self._last_mem_ratio = []
        yield
        self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
        self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
        self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
        self.step += 1
        if self.step % self.update_every == 0:
            self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
            self._fit_params()

    def _fit_params(self):
        memory_usage = self._memory[:self._buffer_length]
        input_size = self._input_size[:self._buffer_length]
        mem_ratio = self._mem_ratio[:self._buffer_length]

        x = input_size * mem_ratio
        y = memory_usage
        k, b = np.polyfit(x, y, 1)
        self._params = (k, b)
        # self._visualize()

    def _visualize(self):
        import matplotlib.pyplot as plt
        memory_usage = self._memory[:self._buffer_length]
        input_size = self._input_size[:self._buffer_length]
        mem_ratio = self._mem_ratio[:self._buffer_length]
        k, b = self._params

        plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
        x = np.array([0.0, 20000.0])
        plt.plot(x, k * x + b, c='r')
        plt.savefig(f'linear_memory_controller_{self.step}.png')
        plt.cla()

    def get_mem_ratio(self, input_size):
        k, b = self._params
        if k == 0: return np.random.rand() * self._max_mem_ratio
        pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
        return min(self._max_mem_ratio, max(0.0, pred))

    def state_dict(self):
        return {
            'params': self._params,
        }

    def load_state_dict(self, state_dict):
        self._params = tuple(state_dict['params'])

    def log(self):
        return {
            'params/k': self._params[0],
            'params/b': self._params[1],
            'memory': self._last_memory,
            'input_size': self._last_input_size,
            'mem_ratio': self._last_mem_ratio,
        }


class ElasticModule(nn.Module):
    """
    Module for training with elastic memory management.
    """
    def __init__(self):
        super().__init__()
        self._memory_controller: MemoryController = None

    @abstractmethod
    def _get_input_size(self, *args, **kwargs) -> int:
        """
        Get the size of the input data.

        Returns:
            int: The size of the input data.
        """
        pass

    @abstractmethod
    def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
        """
        Forward with a given memory ratio.
        """
        pass

    def register_memory_controller(self, memory_controller: MemoryController):
        self._memory_controller = memory_controller

    def forward(self, *args, **kwargs):
        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
            _, ret = self._forward_with_mem_ratio(*args, **kwargs)
        else:
            input_size = self._get_input_size(*args, **kwargs)
            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
            mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
            self._memory_controller.update_run_states(input_size, mem_ratio)
        return ret


class ElasticModuleMixin:
    """
    Mixin for training with elastic memory management.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._memory_controller: MemoryController = None

    @abstractmethod
    def _get_input_size(self, *args, **kwargs) -> int:
        """
        Get the size of the input data.

        Returns:
            int: The size of the input data.
        """
        pass

    @abstractmethod
    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0) -> float:
        """
        Context manager for training with a reduced memory ratio compared to the full memory usage.

        Returns:
            float: The exact memory ratio used during the forward pass.
        """
        pass

    def register_memory_controller(self, memory_controller: MemoryController):
        self._memory_controller = memory_controller

    def forward(self, *args, **kwargs):
        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
            ret = super().forward(*args, **kwargs)
        else:
            input_size = self._get_input_size(*args, **kwargs)
            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
            with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
                ret = super().forward(*args, **kwargs)
            self._memory_controller.update_run_states(input_size, exact_mem_ratio)
        return ret
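
An illustrative sketch of how the pieces above fit together (not from the commit; requires a CUDA device, and BlockStack, its checkpointing scheme, and the tensor shapes are hypothetical). The subclass implements _get_input_size and with_mem_ratio, and the controller learns memory_usage = k * input_size * mem_ratio + b from recorded steps.

from contextlib import contextmanager
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from trellis.utils.elastic_utils import ElasticModuleMixin, LinearMemoryController

class BlockStack(nn.Module):
    """Plain block stack; checkpoints its first `_num_ckpt` blocks when training."""
    def __init__(self, dim=512, num_blocks=12):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self._num_ckpt = 0

    def forward(self, x):
        for i, blk in enumerate(self.blocks):
            if self.training and i < self._num_ckpt:
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            x = torch.relu(x)
        return x

class ElasticBlockStack(ElasticModuleMixin, BlockStack):
    def _get_input_size(self, x):
        # "input size" is whatever quantity memory should scale with; here, tokens per batch
        return x.shape[0] * x.shape[1]

    @contextmanager
    def with_mem_ratio(self, mem_ratio=1.0):
        # checkpoint a (1 - mem_ratio) fraction of blocks; report the ratio actually used
        self._num_ckpt = round((1 - mem_ratio) * len(self.blocks))
        try:
            yield 1 - self._num_ckpt / len(self.blocks)
        finally:
            self._num_ckpt = 0

model = ElasticBlockStack().cuda()
controller = LinearMemoryController(target_ratio=0.8)
model.register_memory_controller(controller)

x = torch.randn(8, 1024, 512, device='cuda')
with controller.record():  # records peak memory for this step and periodically refits (k, b)
    loss = model(x).mean()
loss.backward()
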
trellis/utils/postprocessing_utils.py  (new file, 587 lines)
@@ -0,0 +1,587 @@
from typing import *
import numpy as np
import torch
import utils3d
import nvdiffrast.torch as dr
from tqdm import tqdm
import trimesh
import trimesh.visual
import xatlas
import pyvista as pv
from pymeshfix import _meshfix
import igraph
import cv2
from PIL import Image
from .random_utils import sphere_hammersley_sequence
from .render_utils import render_multiview
from ..renderers import GaussianRenderer
from ..representations import Strivec, Gaussian, MeshExtractResult


@torch.no_grad()
def _fill_holes(
    verts,
    faces,
    max_hole_size=0.04,
    max_hole_nbe=32,
    resolution=128,
    num_views=500,
    debug=False,
    verbose=False
):
    """
    Rasterize a mesh from multiple views and remove invisible faces.
    Also includes postprocessing to:
        1. Remove connected components that have low visibility.
        2. Mincut to remove faces on the inner side of the mesh that are connected to the outer side through a small hole.

    Args:
        verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
        faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
        max_hole_size (float): Maximum area of a hole to fill.
        resolution (int): Resolution of the rasterization.
        num_views (int): Number of views to rasterize the mesh.
        verbose (bool): Whether to print progress.
    """
    # Construct cameras
    yaws = []
    pitchs = []
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views)
        yaws.append(y)
        pitchs.append(p)
    yaws = torch.tensor(yaws).cuda()
    pitchs = torch.tensor(pitchs).cuda()
    radius = 2.0
    fov = torch.deg2rad(torch.tensor(40)).cuda()
    projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
    views = []
    for (yaw, pitch) in zip(yaws, pitchs):
        orig = torch.tensor([
            torch.sin(yaw) * torch.cos(pitch),
            torch.cos(yaw) * torch.cos(pitch),
            torch.sin(pitch),
        ]).cuda().float() * radius
        view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
        views.append(view)
    views = torch.stack(views, dim=0)

    # Rasterize
    visblity = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
    rastctx = utils3d.torch.RastContext(backend='cuda')
    for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
        view = views[i]
        buffers = utils3d.torch.rasterize_triangle_faces(
            rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
        )
        face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
        face_id = torch.unique(face_id).long()
        visblity[face_id] += 1
    visblity = visblity.float() / num_views

    # Mincut
    ## construct outer faces
    edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
    boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
    connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
    outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
    for i in range(len(connected_components)):
        outer_face_indices[connected_components[i]] = visblity[connected_components[i]] > min(max(visblity[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
    outer_face_indices = outer_face_indices.nonzero().reshape(-1)

    ## construct inner faces
    inner_face_indices = torch.nonzero(visblity == 0).reshape(-1)
    if verbose:
        tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
    if inner_face_indices.shape[0] == 0:
        return verts, faces

    ## Construct dual graph (faces as nodes, edges as edges)
    dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
    dual_edge2edge = edges[dual_edge2edge]
    dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
    if verbose:
        tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')

    ## solve mincut problem
    ### construct main graph
    g = igraph.Graph()
    g.add_vertices(faces.shape[0])
    g.add_edges(dual_edges.cpu().numpy())
    g.es['weight'] = dual_edges_weights.cpu().numpy()

    ### source and target
    g.add_vertex('s')
    g.add_vertex('t')

    ### connect invisible faces to source
    g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})

    ### connect outer faces to target
    g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})

    ### solve mincut
    cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
    remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
    if verbose:
        tqdm.write(f'Mincut solved, start checking the cut')

    ### check if the cut is valid with each connected component
    to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
    if debug:
        tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
    valid_remove_cc = []
    cutting_edges = []
    for cc in to_remove_cc:
        #### check if the connected component has low visibility
        visblity_median = visblity[remove_face_indices[cc]].median()
        if debug:
            tqdm.write(f'visblity_median: {visblity_median}')
        if visblity_median > 0.25:
            continue

        #### check if the cutting loop is small enough
        cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
        cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
        cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
        if len(cc_new_boundary_edge_indices) > 0:
            cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
            cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
            cc_new_boundary_edges_cc_area = []
            for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
                _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
                _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
                cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
            if debug:
                cutting_edges.append(cc_new_boundary_edge_indices)
                tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
            if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
                continue

        valid_remove_cc.append(cc)

    if debug:
        face_v = verts[faces].mean(dim=1).cpu().numpy()
        vis_dual_edges = dual_edges.cpu().numpy()
        vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
        vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
        vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
        vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
        if len(valid_remove_cc) > 0:
            vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
        utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)

        vis_verts = verts.cpu().numpy()
        vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
        utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)

    if len(valid_remove_cc) > 0:
        remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
        mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
        mask[remove_face_indices] = 0
        faces = faces[mask]
        faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
        if verbose:
            tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
    else:
        if verbose:
            tqdm.write(f'Removed 0 faces by mincut')

    mesh = _meshfix.PyTMesh()
    mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
    mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
    verts, faces = mesh.return_arrays()
    verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)

    return verts, faces


def postprocess_mesh(
    vertices: np.array,
    faces: np.array,
    simplify: bool = True,
    simplify_ratio: float = 0.9,
    fill_holes: bool = True,
    fill_holes_max_hole_size: float = 0.04,
    fill_holes_max_hole_nbe: int = 32,
    fill_holes_resolution: int = 1024,
    fill_holes_num_views: int = 1000,
    debug: bool = False,
    verbose: bool = False,
):
    """
    Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
        simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
        simplify_ratio (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
        fill_holes_resolution (int): Resolution of the rasterization.
        fill_holes_num_views (int): Number of views to rasterize the mesh.
        verbose (bool): Whether to print progress.
    """

    if verbose:
        tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Simplify
    if simplify and simplify_ratio > 0:
        mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
        if verbose:
            tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    # Remove invisible faces
    if fill_holes:
        vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
        vertices, faces = _fill_holes(
            vertices, faces,
            max_hole_size=fill_holes_max_hole_size,
            max_hole_nbe=fill_holes_max_hole_nbe,
            resolution=fill_holes_resolution,
            num_views=fill_holes_num_views,
            debug=debug,
            verbose=verbose,
        )
        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
        if verbose:
            tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')

    return vertices, faces


def parametrize_mesh(vertices: np.array, faces: np.array):
    """
    Parametrize a mesh to a texture space, using xatlas.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
    """

    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)

    vertices = vertices[vmapping]
    faces = indices

    return vertices, faces, uvs


def bake_texture(
    vertices: np.array,
    faces: np.array,
    uvs: np.array,
    observations: List[np.array],
    masks: List[np.array],
    extrinsics: List[np.array],
    intrinsics: List[np.array],
    texture_size: int = 2048,
    near: float = 0.1,
    far: float = 10.0,
    mode: Literal['fast', 'opt'] = 'opt',
    lambda_tv: float = 1e-2,
    verbose: bool = False,
):
    """
    Bake texture to a mesh from multiple observations.

    Args:
        vertices (np.array): Vertices of the mesh. Shape (V, 3).
        faces (np.array): Faces of the mesh. Shape (F, 3).
        uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
        masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W).
        extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
        texture_size (int): Size of the texture.
        near (float): Near plane of the camera.
        far (float): Far plane of the camera.
        mode (Literal['fast', 'opt']): Mode of texture baking.
        lambda_tv (float): Weight of total variation loss in optimization.
        verbose (bool): Whether to print progress.
    """
    vertices = torch.tensor(vertices).cuda()
    faces = torch.tensor(faces.astype(np.int32)).cuda()
    uvs = torch.tensor(uvs).cuda()
    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
    masks = [torch.tensor(m > 0).bool().cuda() for m in masks]
    views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
    projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]

    if mode == 'fast':
        texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
        texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
        rastctx = utils3d.torch.RastContext(backend='cuda')
        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                uv_map = rast['uv'][0].detach().flip(0)
                mask = rast['mask'][0].detach().bool() & masks[0]

            # nearest neighbor interpolation
            uv_map = (uv_map * texture_size).floor().long()
            obs = observation[mask]
            uv_map = uv_map[mask]
            idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
            texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
            texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))

        mask = texture_weights > 0
        texture[mask] /= texture_weights[mask][:, None]
        texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)

        # inpaint
        mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)

    elif mode == 'opt':
        rastctx = utils3d.torch.RastContext(backend='cuda')
        observations = [observations.flip(0) for observations in observations]
        masks = [m.flip(0) for m in masks]
        _uv = []
        _uv_dr = []
        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                _uv.append(rast['uv'].detach())
                _uv_dr.append(rast['uv_dr'].detach())

        texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)

        def exp_anealing(optimizer, step, total_steps, start_lr, end_lr):
            return start_lr * (end_lr / start_lr) ** (step / total_steps)

        def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

        def tv_loss(texture):
            return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
                torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])

        total_steps = 2500
        with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
            for step in range(total_steps):
                optimizer.zero_grad()
                selected = np.random.randint(0, len(views))
                uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
                render = dr.texture(texture, uv, uv_dr)[0]
                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
                if lambda_tv > 0:
                    loss += lambda_tv * tv_loss(texture)
                loss.backward()
                optimizer.step()
                # annealing
                optimizer.param_groups[0]['lr'] = cosine_anealing(optimizer, step, total_steps, 1e-2, 1e-5)
                pbar.set_postfix({'loss': loss.item()})
                pbar.update()
        texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
        mask = 1 - utils3d.torch.rasterize_triangle_faces(
            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
        )['mask'][0].detach().cpu().numpy().astype(np.uint8)
        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
    else:
        raise ValueError(f'Unknown mode: {mode}')

    return texture


def to_glb(
    app_rep: Union[Strivec, Gaussian],
    mesh: MeshExtractResult,
    simplify: float = 0.95,
    fill_holes: bool = True,
    fill_holes_max_size: float = 0.04,
    texture_size: int = 1024,
    debug: bool = False,
    verbose: bool = True,
) -> trimesh.Trimesh:
    """
    Convert a generated asset to a glb file.

    Args:
        app_rep (Union[Strivec, Gaussian]): Appearance representation.
        mesh (MeshExtractResult): Extracted mesh.
        simplify (float): Ratio of faces to remove in simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_size (float): Maximum area of a hole to fill.
        texture_size (int): Size of the texture.
        debug (bool): Whether to print debug information.
        verbose (bool): Whether to print progress.
    """
    vertices = mesh.vertices.cpu().numpy()
    faces = mesh.faces.cpu().numpy()

    # mesh postprocess
    vertices, faces = postprocess_mesh(
        vertices, faces,
        simplify=simplify > 0,
        simplify_ratio=simplify,
        fill_holes=fill_holes,
        fill_holes_max_hole_size=fill_holes_max_size,
        fill_holes_max_hole_nbe=int(250 * np.sqrt(1 - simplify)),
        fill_holes_resolution=1024,
        fill_holes_num_views=1000,
        debug=debug,
        verbose=verbose,
    )

    # parametrize mesh
    vertices, faces, uvs = parametrize_mesh(vertices, faces)

    # bake texture
    observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
    masks = [np.any(observation > 0, axis=-1) for observation in observations]
    extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
    intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
    texture = bake_texture(
        vertices, faces, uvs,
        observations, masks, extrinsics, intrinsics,
        texture_size=texture_size, mode='opt',
        lambda_tv=0.01,
        verbose=verbose
    )
    texture = Image.fromarray(texture)

    # rotate mesh (from z-up to y-up)
    vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=texture,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
    )
    mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material))
    return mesh


def simplify_gs(
    gs: Gaussian,
    simplify: float = 0.95,
    verbose: bool = True,
):
    """
    Simplify 3D Gaussians.
    NOTE: this function is not used in the current implementation because of its unsatisfactory performance.

    Args:
        gs (Gaussian): 3D Gaussian.
        simplify (float): Ratio of Gaussians to remove in simplification.
    """
    if simplify <= 0:
        return gs

    # simplify
    observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
    observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]

    # Following https://arxiv.org/pdf/2411.06019
    renderer = GaussianRenderer({
        "resolution": 1024,
        "near": 0.8,
        "far": 1.6,
        "ssaa": 1,
        "bg_color": (0, 0, 0),
    })
    new_gs = Gaussian(**gs.init_params)
    new_gs._features_dc = gs._features_dc.clone()
    new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
    new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
    new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
    new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
    new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())

    start_lr = [1e-4, 1e-3, 5e-3, 0.025]
    end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
    optimizer = torch.optim.Adam([
        {"params": new_gs._xyz, "lr": start_lr[0]},
        {"params": new_gs._rotation, "lr": start_lr[1]},
        {"params": new_gs._scaling, "lr": start_lr[2]},
        {"params": new_gs._opacity, "lr": start_lr[3]},
    ], lr=start_lr[0])

    def exp_anealing(optimizer, step, total_steps, start_lr, end_lr):
        return start_lr * (end_lr / start_lr) ** (step / total_steps)

    def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr):
        return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))

    _zeta = new_gs.get_opacity.clone().detach().squeeze()
    _lambda = torch.zeros_like(_zeta)
    _delta = 1e-7
    _interval = 10
    num_target = int((1 - simplify) * _zeta.shape[0])

    with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar:
        for i in range(2500):
            # prune
            if i % 100 == 0:
                mask = new_gs.get_opacity.squeeze() > 0.05
                mask = torch.nonzero(mask).squeeze()
                new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
                new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
                new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
                new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
                new_gs._features_dc = new_gs._features_dc[mask]
                new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
                _zeta = _zeta[mask]
                _lambda = _lambda[mask]
                # update optimizer state
                for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]):
                    stored_state = optimizer.state[param_group['params'][0]]
                    if 'exp_avg' in stored_state:
                        stored_state['exp_avg'] = stored_state['exp_avg'][mask]
                        stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
                    del optimizer.state[param_group['params'][0]]
                    param_group['params'][0] = new_param
                    optimizer.state[param_group['params'][0]] = stored_state

            opacity = new_gs.get_opacity.squeeze()

            # sparsify
            if i % _interval == 0:
                _zeta = _lambda + opacity.detach()
                if opacity.shape[0] > num_target:
                    index = _zeta.topk(num_target)[1]
                    _m = torch.ones_like(_zeta, dtype=torch.bool)
                    _m[index] = 0
                    _zeta[_m] = 0
                _lambda = _lambda + opacity.detach() - _zeta

            # sample a random view
            view_idx = np.random.randint(len(observations))
            observation = observations[view_idx]
            extrinsic = extrinsics[view_idx]
            intrinsic = intrinsics[view_idx]

            color = renderer.render(new_gs, extrinsic, intrinsic)['color']
            rgb_loss = torch.nn.functional.l1_loss(color, observation)
            loss = rgb_loss + \
                _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update lr
            for j in range(len(optimizer.param_groups)):
                optimizer.param_groups[j]['lr'] = cosine_anealing(optimizer, i, 2500, start_lr[j], end_lr[j])

            pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
            pbar.update()

    new_gs._xyz = new_gs._xyz.data
    new_gs._rotation = new_gs._rotation.data
    new_gs._scaling = new_gs._scaling.data
    new_gs._opacity = new_gs._opacity.data

    return new_gs
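
For reference, end-to-end usage of to_glb typically looks like the sketch below (not part of this commit; the pipeline object and file names are hypothetical, and a CUDA device with nvdiffrast is required).

from trellis.utils.postprocessing_utils import to_glb

outputs = pipeline.run(image)        # hypothetical TRELLIS pipeline producing 'gaussian' and 'mesh' outputs
glb = to_glb(
    outputs['gaussian'][0],          # appearance representation (Gaussian)
    outputs['mesh'][0],              # MeshExtractResult
    simplify=0.95,                   # fraction of faces to remove before texturing
    texture_size=1024,
)
glb.export('sample.glb')             # trimesh writes the textured GLB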