1
This commit is contained in:
193
trellis/modules/sparse/attention/serialized_attn.py
Executable file
193
trellis/modules/sparse/attention/serialized_attn.py
Executable file
@@ -0,0 +1,193 @@
|
||||
from typing import *
|
||||
from enum import Enum
|
||||
import torch
|
||||
import math
|
||||
from .. import SparseTensor
|
||||
from .. import DEBUG, ATTN
|
||||
|
||||
if ATTN == 'xformers':
|
||||
import xformers.ops as xops
|
||||
elif ATTN == 'flash_attn':
|
||||
import flash_attn
|
||||
else:
|
||||
raise ValueError(f"Unknown attention module: {ATTN}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
'sparse_serialized_scaled_dot_product_self_attention',
|
||||
]
|
||||
|
||||
|
||||
class SerializeMode(Enum):
    """Orderings available for serializing sparse voxel coordinates.

    Each member selects the curve type and axis permutation that
    ``calc_serialization`` passes to ``vox2seq.encode``.
    """

    # Z-order curve, axes in their natural order (permute=[0, 1, 2]).
    Z_ORDER = 0
    # Z-order curve with the first two axes swapped (permute=[1, 0, 2]).
    Z_ORDER_TRANSPOSED = 1
    # Hilbert curve, axes in their natural order (permute=[0, 1, 2]).
    HILBERT = 2
    # Hilbert curve with the first two axes swapped (permute=[1, 0, 2]).
    HILBERT_TRANSPOSED = 3
|
||||
|
||||
|
||||
# Every serialization mode, in declaration order
# (Z_ORDER, Z_ORDER_TRANSPOSED, HILBERT, HILBERT_TRANSPOSED).
SerializeModes = list(SerializeMode)
|
||||
|
||||
|
||||
def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    The spatial coordinates of ``tensor`` are encoded into a 1D code with
    ``vox2seq``; each batch element (one slice of ``tensor.layout``) is sorted
    by that code and split into windows. A batch element that fits in a single
    window keeps its own length; otherwise every window has exactly
    ``window_size`` entries, padded by wrapping around the sorted sequence.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor, torch.Tensor, List[int], List[int]):
            fwd_indices: row indices into the input that gather the serialized
                (possibly padded) sequence of all windows.
            bwd_indices: for every input row, its position inside the
                serialized sequence (inverse of the gather for valid entries).
            seq_lens: length of every window, in serialized order.
            seq_batch_indices: batch index every window belongs to.

    Raises:
        ValueError: If ``serialize_mode`` is not a known mode.
    """
    # Import lazily so the module can be loaded without the extension.
    # The import must be unconditional: an `import` statement inside a function
    # makes the name local, so guarding it with a `globals()` check would leave
    # the local unbound (UnboundLocalError) whenever the guard is false.
    # Repeated imports are cheap because of the `sys.modules` cache.
    import vox2seq

    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    offsets = [0]

    # Serialize the input
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
    if serialize_mode == SerializeMode.Z_ORDER:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
    elif serialize_mode == SerializeMode.HILBERT:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")

    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        num_windows = (num_points + window_size - 1) // window_size
        # Average (fractional) number of valid points per window.
        valid_window_size = num_points / num_windows
        to_ordered = torch.argsort(code[s.start:s.stop])
        if num_windows == 1:
            # Whole batch element fits in one window: no padding needed.
            fwd_indices.append(to_ordered)
            bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
            # Convert batch-local indices to global row / sequence positions.
            fwd_indices[-1] += s.start
            bwd_indices[-1] += offsets[-1]
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            offsets.append(offsets[-1] + seq_lens[-1])
        else:
            # Partition the input
            offset = 0
            mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
            split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
            bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
            for i in range(num_windows):
                mid = mids[i]
                valid_start = split[i]
                valid_end = split[i + 1]
                # Fixed-size window centred on the valid range; indices wrap
                # around the sorted sequence via the modulo below.
                padded_start = math.floor(mid - 0.5 * window_size)
                padded_end = padded_start + window_size
                fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
                offset += valid_start - padded_start
                # Only the valid (non-padding) part of the window is mapped back.
                bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
                offset += padded_end - valid_start
                # Must happen after the scatter above, which needs batch-local indices.
                fwd_indices[-1] += s.start
            seq_lens.extend([window_size] * num_windows)
            seq_batch_indices.extend([bi] * num_windows)
            bwd_indices.append(bwd_index + offsets[-1])
            offsets.append(offsets[-1] + num_windows * window_size)

    fwd_indices = torch.cat(fwd_indices)
    bwd_indices = torch.cat(bwd_indices)

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
|
||||
|
||||
|
||||
def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    The serialization (gather / scatter indices and window lengths) is computed
    once per parameter combination and stored in the spatial cache of ``qkv``.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: ``qkv`` with its features replaced by the attention
        output, one [H, C] entry per input row.

    Raises:
        ValueError: If the configured attention backend ``ATTN`` is unknown.
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    H = qkv.feats.shape[2]
    C = qkv.feats.shape[3]

    # Gather rows into serialized window order (padding rows are duplicates).
    qkv_feats = qkv.feats[fwd_indices]      # [M, 3, H, C]

    if DEBUG:
        # Every window must contain rows from a single batch element only.
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for seq_len, batch_index in zip(seq_lens, seq_batch_indices):
            assert (qkv_coords[start:start+seq_len, 0] == batch_index).all(), "SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            start += seq_len

    if all(seq_len == window_size for seq_len in seq_lens):
        # All windows have the same length: run dense batched attention.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)                              # [M, H, C]
    else:
        # Windows differ in length: use the variable-length attention kernels.
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)                       # [M, H, C]
            q = q.unsqueeze(0)                                      # [1, M, H, C]
            k = k.unsqueeze(0)                                      # [1, M, H, C]
            v = v.unsqueeze(0)                                      # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                        .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
        else:
            # Mirror the dense branch instead of leaving `out` unbound.
            raise ValueError(f"Unknown attention module: {ATTN}")

    # Scatter back to the original row order, dropping padding rows.
    out = out[bwd_indices]      # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
|
||||
118
trellis/renderers/sh_utils.py
Executable file
118
trellis/renderers/sh_utils.py
Executable file
@@ -0,0 +1,118 @@
|
||||
# Copyright 2021 The PlenOctree Authors.
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# 1. Redistributions of source code must retain the above copyright notice,
|
||||
# this list of conditions and the following disclaimer.
|
||||
#
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
# POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import torch
|
||||
|
||||
# Hard-coded normalization constants of the real spherical-harmonics basis,
# grouped by degree (C<l> holds the 2l + 1 factors of degree l).

# Degree 0: 1 / (2 * sqrt(pi)).
C0 = 0.28209479177387814

# Degree 1: sqrt(3 / (4 * pi)); shared by all three degree-1 terms.
C1 = 0.4886025119029199

# Degree 2 (coefficient indices 4..8).
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396,
]

# Degree 3 (coefficient indices 9..15).
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435,
]

# Degree 4 (coefficient indices 16..24).
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]
|
||||
|
||||
|
||||
def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Degrees 0-4 are supported
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
        dirs: jnp.ndarray unit directions [..., 3]
    Returns:
        [..., C]
    Raises:
        AssertionError: if deg is outside 0-4 or sh carries fewer than
            (deg + 1) ** 2 coefficients.
    """
    assert deg <= 4 and deg >= 0, f"SH degree must be in [0, 4], got {deg}"
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff, \
        f"degree {deg} needs {coeff} SH coefficients, got {sh.shape[-1]}"

    # Degree 0: constant term.
    result = C0 * sh[..., 0]
    if deg == 0:
        return result

    # Keep a trailing axis of size 1 so the components broadcast against
    # the per-channel coefficients sh[..., k].
    x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
    result = (result -
            C1 * y * sh[..., 1] +
            C1 * z * sh[..., 2] -
            C1 * x * sh[..., 3])
    if deg == 1:
        return result

    xx, yy, zz = x * x, y * y, z * z
    xy, yz, xz = x * y, y * z, x * z
    result = (result +
            C2[0] * xy * sh[..., 4] +
            C2[1] * yz * sh[..., 5] +
            C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
            C2[3] * xz * sh[..., 7] +
            C2[4] * (xx - yy) * sh[..., 8])
    if deg == 2:
        return result

    result = (result +
            C3[0] * y * (3 * xx - yy) * sh[..., 9] +
            C3[1] * xy * z * sh[..., 10] +
            C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
            C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
            C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
            C3[5] * z * (xx - yy) * sh[..., 14] +
            C3[6] * x * (xx - 3 * yy) * sh[..., 15])
    if deg == 3:
        return result

    # Degree 4.
    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result
|
||||
|
||||
def RGB2SH(rgb):
    """Convert a colour value to its degree-0 SH coefficient: (rgb - 0.5) / C0."""
    centered = rgb - 0.5
    return centered / C0
|
||||
|
||||
def SH2RGB(sh):
    """Convert a degree-0 SH coefficient back to a colour value: sh * C0 + 0.5."""
    scaled = sh * C0
    return scaled + 0.5
|
||||
274
trellis/representations/mesh/flexicubes/examples/render.py
Normal file
274
trellis/representations/mesh/flexicubes/examples/render.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import copy
|
||||
import math
|
||||
from ipywidgets import interactive, HBox, VBox, FloatLogSlider, IntSlider
|
||||
|
||||
import torch
|
||||
import nvdiffrast.torch as dr
|
||||
import kaolin as kal
|
||||
import util
|
||||
|
||||
###############################################################################
|
||||
# Functions adapted from https://github.com/NVlabs/nvdiffrec
|
||||
###############################################################################
|
||||
|
||||
def get_random_camera_batch(batch_size, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
    """Sample a batch of random cameras looking towards the origin.

    Args:
        batch_size: number of cameras to sample.
        fovy: field of view passed to the camera / projection (radians by default).
        iter_res: (height, width) of the rendered image.
        cam_near_far: (near, far) clipping distances.
        cam_radius: distance of the cameras from the origin.
        device: device the cameras / matrices are created on. Honoured by both
            branches (the kaolin branch previously hard-coded 'cuda').
        use_kaolin: if True return a kaolin ``Camera``; otherwise return
            model-view and model-view-projection matrices built with ``util``.

    Returns:
        A ``kal.render.camera.Camera`` when ``use_kaolin`` is True, else a tuple
        ``(mv, mvp)`` of stacked matrices on ``device``.
    """
    if use_kaolin:
        # Random directions over the full sphere, placed at distance cam_radius.
        camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(
            *kal.ops.random.sample_spherical_coords((batch_size,), azimuth_low=0., azimuth_high=math.pi * 2,
                                                    elevation_low=-math.pi / 2., elevation_high=math.pi / 2., device=device),
            cam_radius
        ), dim=-1)
        return kal.render.camera.Camera.from_args(
            # Per-camera jitter in [-0.25, 0.25), broadcast over x, y, z.
            eye=camera_pos + torch.rand((batch_size, 1), device=device) * 0.5 - 0.25,
            at=torch.zeros(batch_size, 3),
            up=torch.tensor([[0., 1., 0.]]),
            fov=fovy,
            near=cam_near_far[0], far=cam_near_far[1],
            height=iter_res[0], width=iter_res[1],
            device=device
        )
    else:
        def get_random_camera():
            proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
            mv = util.translate(0, 0, -cam_radius) @ util.random_rotation_translation(0.25)
            mvp = proj_mtx @ mv
            return mv, mvp
        mv_batch = []
        mvp_batch = []
        for _ in range(batch_size):
            mv, mvp = get_random_camera()
            mv_batch.append(mv)
            mvp_batch.append(mvp)
        return torch.stack(mv_batch).to(device), torch.stack(mvp_batch).to(device)
|
||||
|
||||
def get_rotate_camera(itr, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
    """Build the camera for step ``itr`` of a turntable rotation.

    One full revolution takes 10 steps (``ang = itr / 10 * 2 * pi``).

    Args:
        itr: rotation step index.
        fovy: field of view passed to the camera / projection (radians by default).
        iter_res: (height, width) of the rendered image.
        cam_near_far: (near, far) clipping distances.
        cam_radius: distance of the camera from the origin.
        device: device the camera / matrices are created on. Honoured by both
            branches (the kaolin branch previously hard-coded 'cuda').
        use_kaolin: if True return a kaolin ``Camera``; otherwise return
            model-view and model-view-projection matrices built with ``util``.

    Returns:
        A ``kal.render.camera.Camera`` when ``use_kaolin`` is True, else a tuple
        ``(mv, mvp)`` on ``device``.
    """
    # Smooth rotation for display.
    ang = (itr / 10) * np.pi * 2
    if use_kaolin:
        camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(torch.tensor(ang), torch.tensor(0.4), -torch.tensor(cam_radius)))
        return kal.render.camera.Camera.from_args(
            eye=camera_pos,
            at=torch.zeros(3),
            up=torch.tensor([0., 1., 0.]),
            fov=fovy,
            near=cam_near_far[0], far=cam_near_far[1],
            height=iter_res[0], width=iter_res[1],
            device=device
        )
    else:
        proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
        mv = util.translate(0, 0, -cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
        mvp = proj_mtx @ mv
        return mv.to(device), mvp.to(device)
|
||||
|
||||
glctx = dr.RasterizeGLContext()
|
||||
def render_mesh(mesh, camera, iter_res, return_types = ["mask", "depth"], white_bg=False, wireframe_thickness=0.4):
    """Rasterize ``mesh`` from ``camera`` and return the requested images.

    Args:
        mesh: mesh providing ``vertices``, ``faces`` and (for "normals")
            ``face_normals``.
        camera: kaolin camera providing ``extrinsics`` and ``projection_matrix()``.
        iter_res: rasterization resolution passed to ``dr.rasterize``.
        return_types: any of "mask", "depth", "wireframe", "normals".
        white_bg: if True, blend each image over a white background using the
            rasterizer coverage as alpha.
        wireframe_thickness: barycentric threshold used for the "wireframe" image.

    Returns:
        dict mapping each requested return type to its image tensor.

    Raises:
        ValueError: if ``return_types`` contains an unsupported entry
            (previously this raised NameError or silently reused the image of
            the preceding entry).
    """
    vertices_camera = camera.extrinsics.transform(mesh.vertices)
    face_vertices_camera = kal.ops.mesh.index_vertices_by_faces(
        vertices_camera, mesh.faces
    )

    # Projection: nvdiffrast take clip coordinates as input to apply barycentric perspective correction.
    # Using `camera.intrinsics.transform(vertices_camera) would return the normalized device coordinates.
    proj = camera.projection_matrix().unsqueeze(1)
    # Negate the y scale of the projection (vertical flip of the image).
    proj[:, :, 1, 1] = -proj[:, :, 1, 1]
    homogeneous_vecs = kal.render.camera.up_to_homogeneous(
        vertices_camera
    )
    vertices_clip = (proj @ homogeneous_vecs.unsqueeze(-1)).squeeze(-1)
    faces_int = mesh.faces.int()

    rast, _ = dr.rasterize(
        glctx, vertices_clip, faces_int, iter_res)

    out_dict = {}
    for return_type in return_types:
        if return_type == "mask":
            img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
        elif return_type == "depth":
            img = dr.interpolate(homogeneous_vecs, rast, faces_int)[0]
        elif return_type == "wireframe":
            # A pixel is on an edge when one of its three barycentric
            # coordinates (u, v, 1 - u - v) falls below the thickness.
            img = torch.logical_or(
                torch.logical_or(rast[..., 0] < wireframe_thickness, rast[..., 1] < wireframe_thickness),
                (rast[..., 0] + rast[..., 1]) > (1. - wireframe_thickness)
            ).unsqueeze(-1)
        elif return_type == "normals":
            # Face normals are stored per face corner, so index them with a
            # running (face * 3 + corner) index instead of the vertex indices.
            img = dr.interpolate(
                mesh.face_normals.reshape(len(mesh), -1, 3), rast,
                torch.arange(mesh.faces.shape[0] * 3, device='cuda', dtype=torch.int).reshape(-1, 3)
            )[0]
        else:
            raise ValueError(f"Unsupported return type: {return_type}")
        if white_bg:
            bg = torch.ones_like(img)
            alpha = (rast[..., -1:] > 0).float()
            img = torch.lerp(bg, img, alpha)
        out_dict[return_type] = img

    return out_dict
|
||||
|
||||
def render_mesh_paper(mesh, mv, mvp, iter_res, return_types = ["mask", "depth"], white_bg=False):
    '''
    The rendering function used to produce the results in the paper.

    Args:
        mesh: mesh providing ``vertices``, ``faces`` and, depending on the
            requested types, ``nrm`` (face normals) / ``v_nrm`` (vertex normals).
        mv: model-view matrix, used for the "depth" output.
        mvp: model-view-projection matrix, used for rasterization.
        iter_res: rasterization resolution passed to ``dr.rasterize``.
        return_types: any of "mask", "depth", "normal", "vertex_normal".
        white_bg: if True, blend each image over a white background using the
            rasterizer coverage as alpha.

    Returns:
        dict mapping each requested return type to its image tensor.

    Raises:
        ValueError: if ``return_types`` contains an unsupported entry.
    '''
    v_pos_clip = util.xfm_points(mesh.vertices.unsqueeze(0), mvp)  # Rotate it to camera coordinates
    faces_int = mesh.faces.int()
    # Reuse the module-level GL context instead of creating a new one per call.
    rast, _ = dr.rasterize(
        glctx, v_pos_clip, faces_int, iter_res)

    out_dict = {}
    for return_type in return_types:
        if return_type == "mask":
            img = dr.antialias((rast[..., -1:] > 0).float(), rast, v_pos_clip, faces_int)
        elif return_type == "depth":
            v_pos_cam = util.xfm_points(mesh.vertices.unsqueeze(0), mv)
            img, _ = util.interpolate(v_pos_cam, rast, faces_int)
        elif return_type == "normal":
            # One normal per face: all three corners index the same row.
            normal_indices = (torch.arange(0, mesh.nrm.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
            img, _ = util.interpolate(mesh.nrm.unsqueeze(0).contiguous(), rast, normal_indices.int())
        elif return_type == "vertex_normal":
            img, _ = util.interpolate(mesh.v_nrm.unsqueeze(0).contiguous(), rast, faces_int)
            # NOTE(review): the [-1, 1] -> [0, 1] remap + antialias is applied to
            # vertex normals only — confirm this matches the intended nesting.
            img = dr.antialias((img + 1) * 0.5, rast, v_pos_clip, faces_int)
        else:
            raise ValueError(f"Unsupported return type: {return_type}")
        if white_bg:
            bg = torch.ones_like(img)
            alpha = (rast[..., -1:] > 0).float()
            img = torch.lerp(bg, img, alpha)
        out_dict[return_type] = img
    return out_dict
|
||||
|
||||
class SplitVisualizer():
    """Interactive turntable viewer showing two meshes side by side.

    The left and right meshes are rendered with the same camera and
    concatenated along the image width; a slider controls the wireframe
    thickness.
    """

    def __init__(self, lh_mesh, rh_mesh, height, width):
        """Store the two meshes and the per-mesh render size."""
        self.lh_mesh, self.rh_mesh = lh_mesh, rh_mesh
        self.height, self.width = height, width
        self.wireframe_thickness = 0.4

    def render(self, camera):
        """Render both meshes from ``camera`` and join them horizontally.

        Returns a dict with the uint8 shaded wireframe image ('img') and the
        joined normal map ('normals').
        """
        resolution = (self.height, self.width)
        per_mesh = []
        for mesh in (self.lh_mesh, self.rh_mesh):
            per_mesh.append(render_mesh(
                mesh, camera, resolution,
                return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
            ))

        outputs = {}
        for key in ["normals", "wireframe"]:
            # Swap height/width, stack along the (now leading) width axis,
            # then swap back: left and right images end up side by side.
            halves = [rendered[key][0].permute(1, 0, 2) for rendered in per_mesh]
            outputs[key] = torch.cat(halves, dim=0).permute(1, 0, 2)

        shaded = outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255
        return {
            'img': shaded.to(torch.uint8),
            'normals': outputs['normals']
        }

    def show(self, init_camera):
        """Display the interactive viewer together with its thickness slider."""
        # NOTE(review): `display` is not imported in the visible code; presumably
        # it is supplied by the IPython/Jupyter environment — confirm.
        visualizer = kal.visualize.IpyTurntableVisualizer(
            self.height, self.width * 2, copy.deepcopy(init_camera), self.render,
            max_fps=24, world_up_axis=1)

        def slider_callback(new_wireframe_thickness):
            """ipywidgets sliders callback"""
            with visualizer.out: # This is in case of bug
                self.wireframe_thickness = new_wireframe_thickness
                # this is how we request a new update
                visualizer.render_update()

        thickness_slider = FloatLogSlider(
            value=self.wireframe_thickness,
            base=10,
            min=-3,
            max=-0.4,
            step=0.1,
            description='wireframe_thickness',
            continuous_update=True,
            readout=True,
            readout_format='.3f',
        )

        controls = interactive(
            slider_callback,
            new_wireframe_thickness=thickness_slider,
        )

        # Canvas on top, slider underneath.
        display(VBox([visualizer.canvas, controls]), visualizer.out)
|
||||
|
||||
class TimelineVisualizer():
    """Interactive turntable viewer for a sequence of meshes.

    One mesh is shown at a time; sliders select the mesh index and the
    wireframe thickness. The last mesh is selected initially.
    """

    def __init__(self, meshes, height, width):
        """Store the mesh sequence and the render size."""
        self.meshes = meshes
        self.height, self.width = height, width
        self.wireframe_thickness = 0.4
        self.idx = len(meshes) - 1

    def render(self, camera):
        """Render the currently selected mesh from ``camera``.

        Returns a dict with the uint8 shaded wireframe image ('img') and the
        normal map ('normals'), both taken from the first batch entry.
        """
        current_mesh = self.meshes[self.idx]
        outputs = render_mesh(
            current_mesh, camera, (self.height, self.width),
            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
        )

        shaded = outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255
        return {
            'img': shaded.to(torch.uint8)[0],
            'normals': outputs['normals'][0]
        }

    def show(self, init_camera):
        """Display the interactive viewer together with its two sliders."""
        # NOTE(review): `display` is not imported in the visible code; presumably
        # it is supplied by the IPython/Jupyter environment — confirm.
        visualizer = kal.visualize.IpyTurntableVisualizer(
            self.height, self.width, copy.deepcopy(init_camera), self.render,
            max_fps=24, world_up_axis=1)

        def slider_callback(new_wireframe_thickness, new_idx):
            """ipywidgets sliders callback"""
            with visualizer.out: # This is in case of bug
                self.wireframe_thickness = new_wireframe_thickness
                self.idx = new_idx
                # this is how we request a new update
                visualizer.render_update()

        thickness_slider = FloatLogSlider(
            value=self.wireframe_thickness,
            base=10,
            min=-3,
            max=-0.4,
            step=0.1,
            description='wireframe_thickness',
            continuous_update=True,
            readout=True,
            readout_format='.3f',
        )

        mesh_index_slider = IntSlider(
            value=self.idx,
            min=0,
            max=len(self.meshes) - 1,
            description='idx',
            continuous_update=True,
            readout=True
        )

        controls = interactive(
            slider_callback,
            new_wireframe_thickness=thickness_slider,
            new_idx=mesh_index_slider
        )

        # Canvas on the left, sliders on the right.
        display(HBox([visualizer.canvas, controls]), visualizer.out)
|
||||
30
trellis/utils/random_utils.py
Normal file
30
trellis/utils/random_utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import numpy as np
|
||||
|
||||
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
|
||||
|
||||
def radical_inverse(base, n):
    """Return the radical inverse of ``n`` in the given ``base``.

    The base-``base`` digits of ``n`` are mirrored around the radix point,
    e.g. in base 2, 6 = 110b maps to 0.011b = 0.375.

    Args:
        base: integer base, must be at least 2.
        n: non-negative integer index; values <= 0 yield 0.0.

    Returns:
        float in [0, 1).

    Raises:
        ValueError: if ``base`` is smaller than 2 (base 1 would never
            terminate and base 0 would divide by zero).
    """
    if base < 2:
        raise ValueError(f"base must be >= 2, got {base}")
    # Start from a float so that n == 0 also returns a float.
    val = 0.0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        # Peel off the least significant digit and weight it by base**-k.
        n, digit = divmod(n, base)
        val += digit * inv_base_n
        inv_base_n *= inv_base
    return val
|
||||
|
||||
def halton_sequence(dim, n):
    """Return the ``n``-th point of the ``dim``-dimensional Halton sequence.

    Axis ``k`` uses the radical inverse in base ``PRIMES[k]``.
    """
    point = []
    for axis in range(dim):
        point.append(radical_inverse(PRIMES[axis], n))
    return point
|
||||
|
||||
def hammersley_sequence(dim, n, num_samples):
    """Return the ``n``-th of ``num_samples`` points of a ``dim``-dimensional Hammersley set.

    The first coordinate is ``n / num_samples``; the remaining ``dim - 1``
    coordinates come from the Halton sequence.
    """
    first_coordinate = n / num_samples
    remaining = halton_sequence(dim - 1, n)
    return [first_coordinate] + remaining
|
||||
|
||||
def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
    """Map the ``n``-th 2D Hammersley point to a pair of angles on the sphere.

    Args:
        n: index of the sample.
        num_samples: total number of samples in the set.
        offset: (u, v) offsets; the first is divided by ``num_samples``
            before being added, the second is added as is.
        remap: if True, apply a piecewise-linear remap to ``u`` before the
            angle conversion.

    Returns:
        ``[phi, theta]`` with ``phi = 2 * pi * v`` and
        ``theta = arccos(1 - 2 * u) - pi / 2``.
    """
    u, v = hammersley_sequence(2, n, num_samples)
    u = u + offset[0] / num_samples
    v = v + offset[1]
    if remap:
        if u < 0.25:
            u = 2 * u
        else:
            u = 2 / 3 * u + 1 / 3
    phi = v * 2 * np.pi
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    return [phi, theta]
|
||||
120
trellis/utils/render_utils.py
Normal file
120
trellis/utils/render_utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import utils3d
|
||||
from PIL import Image
|
||||
|
||||
from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
|
||||
from ..representations import Octree, Gaussian, MeshExtractResult
|
||||
from ..modules import sparse as sp
|
||||
from .random_utils import sphere_hammersley_sequence
|
||||
|
||||
|
||||
def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
    """Build look-at camera matrices from yaw / pitch / radius / field of view.

    Each camera sits at distance ``r`` from the origin in the direction given
    by ``yaw`` and ``pitch`` (radians), looks at the origin with +z as up, and
    uses ``fov`` (degrees) for both image axes.

    Args:
        yaws: a single yaw or a list of yaws.
        pitchs: a single pitch or a list of pitches, matching ``yaws``.
        rs: a single radius (shared by all cameras) or a list of radii.
        fovs: a single field of view (shared) or a list, in degrees.

    Returns:
        ``(extrinsics, intrinsics)``: lists with one entry per camera when
        ``yaws`` is a list, otherwise the single pair of matrices.
    """
    single_camera = not isinstance(yaws, list)
    if single_camera:
        yaws, pitchs = [yaws], [pitchs]
    num_cameras = len(yaws)
    # Scalars are shared by every camera.
    if not isinstance(rs, list):
        rs = [rs] * num_cameras
    if not isinstance(fovs, list):
        fovs = [fovs] * num_cameras

    extrinsics, intrinsics = [], []
    for yaw, pitch, radius, fov in zip(yaws, pitchs, rs, fovs):
        fov_rad = torch.deg2rad(torch.tensor(float(fov))).cuda()
        yaw_t = torch.tensor(float(yaw)).cuda()
        pitch_t = torch.tensor(float(pitch)).cuda()
        # Camera position on a sphere of the given radius.
        direction = torch.tensor([
            torch.sin(yaw_t) * torch.cos(pitch_t),
            torch.cos(yaw_t) * torch.cos(pitch_t),
            torch.sin(pitch_t),
        ]).cuda()
        eye = direction * radius
        look_at = torch.tensor([0, 0, 0]).float().cuda()
        up = torch.tensor([0, 0, 1]).float().cuda()
        extrinsics.append(utils3d.torch.extrinsics_look_at(eye, look_at, up))
        intrinsics.append(utils3d.torch.intrinsics_from_fov_xy(fov_rad, fov_rad))

    if single_camera:
        return extrinsics[0], intrinsics[0]
    return extrinsics, intrinsics
|
||||
|
||||
|
||||
def get_renderer(sample, **kwargs):
    """Create and configure the renderer matching the type of ``sample``.

    Recognised keyword options (defaults depend on the sample type):
    ``resolution``, ``near``, ``far``, ``ssaa``, plus ``bg_color`` for
    octrees / Gaussians and ``kernel_size`` for Gaussians.

    Raises:
        ValueError: if ``sample`` is not an Octree, Gaussian or
            MeshExtractResult.
    """
    option = kwargs.get

    if isinstance(sample, Octree):
        renderer = OctreeRenderer()
        settings = renderer.rendering_options
        settings.resolution = option('resolution', 512)
        settings.near = option('near', 0.8)
        settings.far = option('far', 1.6)
        settings.bg_color = option('bg_color', (0, 0, 0))
        settings.ssaa = option('ssaa', 4)
        renderer.pipe.primitive = sample.primitive
        return renderer

    if isinstance(sample, Gaussian):
        renderer = GaussianRenderer()
        settings = renderer.rendering_options
        settings.resolution = option('resolution', 512)
        settings.near = option('near', 0.8)
        settings.far = option('far', 1.6)
        settings.bg_color = option('bg_color', (0, 0, 0))
        settings.ssaa = option('ssaa', 1)
        renderer.pipe.kernel_size = option('kernel_size', 0.1)
        renderer.pipe.use_mip_gaussian = True
        return renderer

    if isinstance(sample, MeshExtractResult):
        renderer = MeshRenderer()
        settings = renderer.rendering_options
        settings.resolution = option('resolution', 512)
        settings.near = option('near', 1)
        settings.far = option('far', 100)
        settings.ssaa = option('ssaa', 4)
        return renderer

    raise ValueError(f'Unsupported sample type: {type(sample)}')
|
||||
|
||||
|
||||
def render_frames(sample, extrinsics, intrinsics, options=None, colors_overwrite=None, verbose=True, **kwargs):
    """Render ``sample`` once per camera and collect the frames.

    Args:
        sample: an Octree, Gaussian or MeshExtractResult.
        extrinsics: iterable of camera extrinsics, one per frame.
        intrinsics: iterable of camera intrinsics, one per frame.
        options: renderer options forwarded to ``get_renderer``
            (``None`` means no overrides).
        colors_overwrite: forwarded to the renderer for non-mesh samples.
        verbose: show a progress bar when True.
        **kwargs: accepted for call-site compatibility; not used here.

    Returns:
        dict of lists with one entry per frame: ``'normal'`` (uint8 HxWxC
        images) for meshes, otherwise ``'color'`` (uint8 HxWxC images) and
        ``'depth'`` (arrays, or ``None`` when the renderer returned no depth).
        The dict is empty when there are no cameras.
    """
    # `options=None` avoids sharing one mutable default dict between calls.
    renderer = get_renderer(sample, **(options or {}))
    is_mesh = isinstance(sample, MeshExtractResult)

    def to_uint8_image(tensor):
        # CxHxW float image in [0, 1] -> HxWxC uint8 image.
        return np.clip(tensor.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)

    rets = {}
    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc='Rendering', disable=not verbose):
        if is_mesh:
            res = renderer.render(sample, extr, intr)
            rets.setdefault('normal', []).append(to_uint8_image(res['normal']))
        else:
            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
            rets.setdefault('color', []).append(to_uint8_image(res['color']))
            depth_frames = rets.setdefault('depth', [])
            # Prefer the normalized depth when the renderer provides it.
            if 'percent_depth' in res:
                depth_frames.append(res['percent_depth'].detach().cpu().numpy())
            elif 'depth' in res:
                depth_frames.append(res['depth'].detach().cpu().numpy())
            else:
                depth_frames.append(None)
    return rets
|
||||
|
||||
|
||||
def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
    """Render an orbit around ``sample``.

    The yaw sweeps once around the object while the pitch oscillates as
    ``0.25 + 0.5 * sin(t)``; extra keyword arguments go to ``render_frames``.
    """
    # One (approximate) full turn, sampled at num_frames points.
    sweep = torch.linspace(0, 2 * 3.1415, num_frames)
    yaw_list = sweep.tolist()
    pitch_list = (0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * 3.1415, num_frames))).tolist()
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw_list, pitch_list, r, fov)
    render_options = {'resolution': resolution, 'bg_color': bg_color}
    return render_frames(sample, extrinsics, intrinsics, render_options, **kwargs)
|
||||
|
||||
|
||||
def render_multiview(sample, resolution=512, nviews=30):
    """Render ``sample`` from ``nviews`` cameras placed with a spherical Hammersley sequence.

    Uses a fixed radius of 2, a field of view of 40 degrees and a black
    background.

    Returns:
        ``(colors, extrinsics, intrinsics)`` with one entry per view.
    """
    radius, fov = 2, 40
    yaws, pitchs = [], []
    for view in range(nviews):
        yaw, pitch = sphere_hammersley_sequence(view, nviews)
        yaws.append(yaw)
        pitchs.append(pitch)
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, radius, fov)
    frames = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
    return frames['color'], extrinsics, intrinsics
|
||||
|
||||
|
||||
def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
    """Render four views of ``samples``, spaced 90 degrees apart in yaw.

    ``offset`` is ``(yaw_offset, pitch)`` in radians: the first entry is added
    to every yaw and the second is used as the pitch of all four views. Extra
    keyword arguments go to ``render_frames``.
    """
    yaw_offset, view_pitch = offset[0], offset[1]
    quarter_turns = [0, np.pi/2, np.pi, 3*np.pi/2]
    yaws = [base_yaw + yaw_offset for base_yaw in quarter_turns]
    pitchs = [view_pitch] * 4
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
    render_options = {'resolution': resolution, 'bg_color': bg_color}
    return render_frames(samples, extrinsics, intrinsics, render_options, **kwargs)
|
||||
Reference in New Issue
Block a user