This commit is contained in:
zcr
2026-03-17 11:29:31 +08:00
parent 24e4c120be
commit a6d9bac6d0
11 changed files with 2159 additions and 0 deletions

View File

@@ -0,0 +1,528 @@
import argparse, sys, os, math, re, glob
from typing import *
import bpy
from mathutils import Vector, Matrix
import numpy as np
import json
import glob
"""=============== BLENDER ==============="""
IMPORT_FUNCTIONS: Dict[str, Callable] = {
"obj": bpy.ops.import_scene.obj,
"glb": bpy.ops.import_scene.gltf,
"gltf": bpy.ops.import_scene.gltf,
"usd": bpy.ops.import_scene.usd,
"fbx": bpy.ops.import_scene.fbx,
"stl": bpy.ops.import_mesh.stl,
"usda": bpy.ops.import_scene.usda,
"dae": bpy.ops.wm.collada_import,
"ply": bpy.ops.import_mesh.ply,
"abc": bpy.ops.wm.alembic_import,
"blend": bpy.ops.wm.append,
}
EXT = {
'PNG': 'png',
'JPEG': 'jpg',
'OPEN_EXR': 'exr',
'TIFF': 'tiff',
'BMP': 'bmp',
'HDR': 'hdr',
'TARGA': 'tga'
}
def init_render(engine='CYCLES', resolution=512, geo_mode=False):
    """Configure Blender's render settings for square, transparent RGBA output on GPU.

    Args:
        engine (str): render engine identifier, e.g. 'CYCLES'.
        resolution (int): output width and height in pixels (square frames).
        geo_mode (bool): when True, drop samples/bounces to a minimum —
            suitable for geometry-only passes where shading quality is irrelevant.
    """
    bpy.context.scene.render.engine = engine
    bpy.context.scene.render.resolution_x = resolution
    bpy.context.scene.render.resolution_y = resolution
    bpy.context.scene.render.resolution_percentage = 100
    bpy.context.scene.render.image_settings.file_format = 'PNG'
    bpy.context.scene.render.image_settings.color_mode = 'RGBA'
    # Transparent background so renders composite cleanly.
    bpy.context.scene.render.film_transparent = True
    bpy.context.scene.cycles.device = 'GPU'
    bpy.context.scene.cycles.samples = 128 if not geo_mode else 1
    # Box filter of width 1 keeps pixels crisp (no AA blur across the silhouette).
    bpy.context.scene.cycles.filter_type = 'BOX'
    bpy.context.scene.cycles.filter_width = 1
    bpy.context.scene.cycles.diffuse_bounces = 1
    bpy.context.scene.cycles.glossy_bounces = 1
    bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0
    bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1
    bpy.context.scene.cycles.use_denoising = True
    # Enumerate and select CUDA devices for GPU rendering.
    bpy.context.preferences.addons['cycles'].preferences.get_devices()
    bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False):
    """Set up compositor file-output nodes for the requested auxiliary passes.

    Args:
        save_depth (bool): write a 16-bit BW PNG of the (remapped) depth pass.
        save_normal (bool): write a 16-bit EXR of the normal pass.
        save_albedo (bool): write an 8-bit RGBA PNG of the diffuse-color pass.
        save_mist (bool): write a 16-bit BW PNG of the mist pass.

    Returns:
        Tuple[dict, dict]: (file-output nodes keyed by pass name; special
        nodes keyed by role — currently only 'depth_map', the map-range node
        whose from-min/from-max callers may retune per view).
    """
    if not any([save_depth, save_normal, save_albedo, save_mist]):
        return {}, {}
    outputs = {}
    spec_nodes = {}

    bpy.context.scene.use_nodes = True
    view_layer = bpy.context.scene.view_layers['View Layer']
    view_layer.use_pass_z = save_depth
    view_layer.use_pass_normal = save_normal
    view_layer.use_pass_diffuse_color = save_albedo
    view_layer.use_pass_mist = save_mist

    nodes = bpy.context.scene.node_tree.nodes
    links = bpy.context.scene.node_tree.links
    # Iterate over a snapshot: removing entries from the live bpy collection
    # while iterating it can skip nodes.
    for node in list(nodes):
        nodes.remove(node)
    render_layers = nodes.new('CompositorNodeRLayers')

    if save_depth:
        depth_file_output = nodes.new('CompositorNodeOutputFile')
        depth_file_output.base_path = ''
        depth_file_output.file_slots[0].use_node_format = True
        depth_file_output.format.file_format = 'PNG'
        depth_file_output.format.color_depth = '16'
        depth_file_output.format.color_mode = 'BW'
        # Remap raw depth to [0, 1]. ('map' was renamed — it shadowed the builtin.)
        map_range = nodes.new(type="CompositorNodeMapRange")
        map_range.inputs[1].default_value = 0   # from-min (closest expected depth)
        map_range.inputs[2].default_value = 10  # from-max (farthest expected depth)
        map_range.inputs[3].default_value = 0   # to-min
        map_range.inputs[4].default_value = 1   # to-max
        links.new(render_layers.outputs['Depth'], map_range.inputs[0])
        links.new(map_range.outputs[0], depth_file_output.inputs[0])
        outputs['depth'] = depth_file_output
        spec_nodes['depth_map'] = map_range
    if save_normal:
        normal_file_output = nodes.new('CompositorNodeOutputFile')
        normal_file_output.base_path = ''
        normal_file_output.file_slots[0].use_node_format = True
        normal_file_output.format.file_format = 'OPEN_EXR'
        normal_file_output.format.color_mode = 'RGB'
        normal_file_output.format.color_depth = '16'
        links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0])
        outputs['normal'] = normal_file_output
    if save_albedo:
        albedo_file_output = nodes.new('CompositorNodeOutputFile')
        albedo_file_output.base_path = ''
        albedo_file_output.file_slots[0].use_node_format = True
        albedo_file_output.format.file_format = 'PNG'
        albedo_file_output.format.color_mode = 'RGBA'
        albedo_file_output.format.color_depth = '8'
        # Carry the render alpha into the albedo image so the background stays transparent.
        alpha_albedo = nodes.new('CompositorNodeSetAlpha')
        links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image'])
        links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha'])
        links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0])
        outputs['albedo'] = albedo_file_output
    if save_mist:
        bpy.data.worlds['World'].mist_settings.start = 0
        bpy.data.worlds['World'].mist_settings.depth = 10
        mist_file_output = nodes.new('CompositorNodeOutputFile')
        mist_file_output.base_path = ''
        mist_file_output.file_slots[0].use_node_format = True
        mist_file_output.format.file_format = 'PNG'
        mist_file_output.format.color_mode = 'BW'
        mist_file_output.format.color_depth = '16'
        links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0])
        outputs['mist'] = mist_file_output
    return outputs, spec_nodes
def init_scene() -> None:
    """Resets the scene to a clean state.

    Removes every object, material, texture and image datablock from the file.

    Returns:
        None
    """
    # Iterate over snapshots (list(...)): removing entries from a live bpy
    # collection while iterating it directly can skip datablocks.
    # delete everything
    for obj in list(bpy.data.objects):
        bpy.data.objects.remove(obj, do_unlink=True)
    # delete all the materials
    for material in list(bpy.data.materials):
        bpy.data.materials.remove(material, do_unlink=True)
    # delete all the textures
    for texture in list(bpy.data.textures):
        bpy.data.textures.remove(texture, do_unlink=True)
    # delete all the images
    for image in list(bpy.data.images):
        bpy.data.images.remove(image, do_unlink=True)
def init_camera():
    """Create the scene camera, constrained to always look at the origin.

    Returns:
        bpy.types.Object: the camera object (tracks an empty at (0, 0, 0)
        via a TRACK_TO constraint; caller only needs to set its location).
    """
    cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
    bpy.context.collection.objects.link(cam)
    bpy.context.scene.camera = cam
    # Square sensor so horizontal and vertical FoV match the square render.
    cam.data.sensor_height = cam.data.sensor_width = 32
    cam_constraint = cam.constraints.new(type='TRACK_TO')
    cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
    cam_constraint.up_axis = 'UP_Y'
    cam_empty = bpy.data.objects.new("Empty", None)
    cam_empty.location = (0, 0, 0)
    bpy.context.scene.collection.objects.link(cam_empty)
    cam_constraint.target = cam_empty
    return cam
def init_lighting():
    """Replace all existing lights with a default key/top/bottom light setup.

    Returns:
        dict: the created light objects keyed by role
        ('default_light', 'top_light', 'bottom_light').
    """
    # Clear existing lights
    bpy.ops.object.select_all(action="DESELECT")
    bpy.ops.object.select_by_type(type="LIGHT")
    bpy.ops.object.delete()
    # Create key light
    default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT"))
    bpy.context.collection.objects.link(default_light)
    default_light.data.energy = 1000
    default_light.location = (4, 1, 6)
    default_light.rotation_euler = (0, 0, 0)
    # create top light
    top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA"))
    bpy.context.collection.objects.link(top_light)
    top_light.data.energy = 10000
    top_light.location = (0, 0, 10)
    # Large area light for soft, even illumination from above.
    top_light.scale = (100, 100, 100)
    # create bottom light
    bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA"))
    bpy.context.collection.objects.link(bottom_light)
    bottom_light.data.energy = 1000
    bottom_light.location = (0, 0, -10)
    bottom_light.rotation_euler = (0, 0, 0)
    return {
        "default_light": default_light,
        "top_light": top_light,
        "bottom_light": bottom_light
    }
def load_object(object_path: str) -> None:
    """Loads a model with a supported file extension into the scene.

    Args:
        object_path (str): Path to the model file.

    Raises:
        ValueError: If the file extension is not supported.

    Returns:
        None
    """
    file_extension = object_path.split(".")[-1].lower()
    # Bug fix: the previous check compared the extension to None (never true),
    # so unsupported files crashed later with a KeyError instead of raising
    # the intended ValueError.
    if file_extension != "usdz" and file_extension not in IMPORT_FUNCTIONS:
        raise ValueError(f"Unsupported file type: {object_path}")
    if file_extension == "usdz":
        # install usdz io package (shipped next to this script)
        dirname = os.path.dirname(os.path.realpath(__file__))
        usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
        bpy.ops.preferences.addon_install(filepath=usdz_package)
        # enable it
        addon_name = "io_scene_usdz"
        bpy.ops.preferences.addon_enable(module=addon_name)
        # import the usdz
        from io_scene_usdz.import_usdz import import_usdz
        # Bug fix: 'context' was an undefined name here; pass bpy.context.
        import_usdz(bpy.context, filepath=object_path, materials=True, animations=True)
        return None
    # load from existing import functions
    import_function = IMPORT_FUNCTIONS[file_extension]
    print(f"Loading object from {object_path}")
    if file_extension == "blend":
        # .blend content is appended, not imported; the operator takes a directory.
        import_function(directory=object_path, link=False)
    elif file_extension in {"glb", "gltf"}:
        import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
    else:
        import_function(filepath=object_path)
def delete_invisible_objects() -> None:
    """Deletes all invisible objects in the scene.

    Objects hidden in either the viewport or the render are first un-hidden
    (so they become selectable) and then deleted; collections hidden in the
    viewport are removed as well.

    Returns:
        None
    """
    # bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")
    for obj in bpy.context.scene.objects:
        if obj.hide_viewport or obj.hide_render:
            obj.hide_viewport = False
            obj.hide_render = False
            obj.hide_select = False
            obj.select_set(True)
    bpy.ops.object.delete()
    # Delete invisible collections
    invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
    for col in invisible_collections:
        bpy.data.collections.remove(col)
def split_mesh_normal():
    """Split custom normals of every mesh in the scene (single edit-mode pass)."""
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    # Robustness fix: scenes without meshes previously raised IndexError on objs[0].
    if not objs:
        return
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    bpy.ops.mesh.select_all(action='SELECT')
    bpy.ops.mesh.split_normals()
    bpy.ops.object.mode_set(mode='OBJECT')
    bpy.ops.object.select_all(action="DESELECT")
def delete_custom_normals():
    """Clear custom split-normals data on every mesh object in the file."""
    for this_obj in bpy.data.objects:
        if this_obj.type == "MESH":
            # The clear operator acts on the active object.
            bpy.context.view_layer.objects.active = this_obj
            bpy.ops.mesh.customdata_custom_splitnormals_clear()
def override_material():
    """Install a flat grey diffuse material as the view-layer material override.

    Used for geometry-only renders where original materials must be ignored.
    """
    new_mat = bpy.data.materials.new(name="Override0123456789")
    new_mat.use_nodes = True
    new_mat.node_tree.nodes.clear()
    bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
    bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1)  # mid-grey base color
    bsdf.inputs[1].default_value = 1  # full roughness
    output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
    new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
    bpy.context.scene.view_layers['View Layer'].material_override = new_mat
def unhide_all_objects() -> None:
    """Unhides all objects in the scene.

    Returns:
        None
    """
    for obj in bpy.context.scene.objects:
        obj.hide_set(False)
def convert_to_meshes() -> None:
    """Converts all objects in the scene to meshes.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    # The convert operator requires an active object; assumes the scene
    # contains at least one mesh (IndexError otherwise).
    bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
    for obj in bpy.context.scene.objects:
        obj.select_set(True)
    bpy.ops.object.convert(target="MESH")
def triangulate_meshes() -> None:
    """Triangulates all meshes in the scene.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    # Assumes at least one mesh exists (IndexError otherwise).
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    # Reveal hidden geometry so the whole mesh is triangulated.
    bpy.ops.mesh.reveal()
    bpy.ops.mesh.select_all(action="SELECT")
    bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
    bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")
def scene_bbox() -> Tuple[Vector, Vector]:
    """Compute the world-space bounding box over all meshes in the scene.

    Adapted from the Shap-E rendering script
    (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)

    Returns:
        Tuple[Vector, Vector]: minimum and maximum corners of the box.

    Raises:
        RuntimeError: if the scene contains no mesh objects.
    """
    meshes = [o for o in bpy.context.scene.objects.values() if isinstance(o.data, bpy.types.Mesh)]
    if not meshes:
        raise RuntimeError("no objects in scene to compute bounding box for")
    mins = [math.inf] * 3
    maxs = [-math.inf] * 3
    for mesh_obj in meshes:
        for corner in mesh_obj.bound_box:
            world_pt = mesh_obj.matrix_world @ Vector(corner)
            for axis in range(3):
                mins[axis] = min(mins[axis], world_pt[axis])
                maxs[axis] = max(maxs[axis], world_pt[axis])
    return Vector(mins), Vector(maxs)
def normalize_scene() -> Tuple[float, Vector]:
    """Normalizes the scene by scaling and translating it to fit in a unit cube centered
    at the origin.

    Mostly taken from the Point-E / Shap-E rendering script
    (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
    but fix for multiple root objects: (see bug report here:
    https://github.com/openai/shap-e/pull/60).

    Returns:
        Tuple[float, Vector]: The scale factor and the offset applied to the scene.
    """
    scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
    if len(scene_root_objects) > 1:
        # create an empty object to be used as a parent for all root objects,
        # so one uniform transform normalizes everything at once
        scene = bpy.data.objects.new("ParentEmpty", None)
        bpy.context.scene.collection.objects.link(scene)
        # parent all root objects to the empty object
        for obj in scene_root_objects:
            obj.parent = scene
    else:
        scene = scene_root_objects[0]
    bbox_min, bbox_max = scene_bbox()
    # Scale so the largest bbox dimension becomes 1.
    scale = 1 / max(bbox_max - bbox_min)
    scene.scale = scene.scale * scale
    # Apply scale to matrix_world.
    bpy.context.view_layer.update()
    # Recompute the (now scaled) bbox before centering at the origin.
    bbox_min, bbox_max = scene_bbox()
    offset = -(bbox_min + bbox_max) / 2
    scene.matrix_world.translation += offset
    bpy.ops.object.select_all(action="DESELECT")
    return scale, offset
def get_transform_matrix(obj: bpy.types.Object) -> list:
    """Return the object's 4x4 world transform as a row-major nested list.

    Args:
        obj: object whose matrix_world is decomposed into rotation + translation
            (scale is discarded).

    Returns:
        list: four rows of four numbers, last row [0, 0, 0, 1].
    """
    translation, rotation, _ = obj.matrix_world.decompose()
    rot3 = rotation.to_matrix()
    rows = [
        [rot3[r][0], rot3[r][1], rot3[r][2], translation[r]]
        for r in range(3)
    ]
    rows.append([0, 0, 0, 1])
    return rows
def main(arg):
    """Render every requested view of a single object and save all outputs.

    Writes per-view RGBA PNGs (plus optional depth/normal/albedo/mist passes),
    a transforms.json with camera parameters, and optionally a triangulated
    mesh.ply, all under arg.output_folder.

    Args:
        arg: parsed CLI namespace (see the argparse definition below).
    """
    os.makedirs(arg.output_folder, exist_ok=True)
    # Initialize context
    init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode)
    outputs, spec_nodes = init_nodes(
        save_depth=arg.save_depth,
        save_normal=arg.save_normal,
        save_albedo=arg.save_albedo,
        save_mist=arg.save_mist
    )
    if arg.object.endswith(".blend"):
        # .blend files are opened directly by Blender; just prune hidden parts.
        delete_invisible_objects()
    else:
        init_scene()
        load_object(arg.object)
        if arg.split_normal:
            split_mesh_normal()
        # delete_custom_normals()
    print('[INFO] Scene initialized.')
    # normalize scene into the unit cube centered at the origin
    scale, offset = normalize_scene()
    print('[INFO] Scene normalized.')
    # Initialize camera and lighting
    cam = init_camera()
    init_lighting()
    print('[INFO] Camera and lighting initialized.')
    # Override material
    if arg.geo_mode:
        override_material()
    # Create a list of views
    to_export = {
        "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        "scale": scale,
        "offset": [offset.x, offset.y, offset.z],
        "frames": []
    }
    views = json.loads(arg.views)
    for i, view in enumerate(views):
        # Spherical camera placement from yaw/pitch/radius; the TRACK_TO
        # constraint (init_camera) keeps it aimed at the origin.
        cam.location = (
            view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']),
            view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']),
            view['radius'] * np.sin(view['pitch'])
        )
        # Focal length from FoV for the 32mm sensor set in init_camera().
        cam.data.lens = 16 / np.tan(view['fov'] / 2)
        if arg.save_depth:
            # Tighten depth remap to the unit cube (half-diagonal sqrt(3)/2)
            # as seen from this camera radius.
            spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3)
            spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3)
        bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png')
        for name, output in outputs.items():
            output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}')
        # Render the scene
        bpy.ops.render.render(write_still=True)
        bpy.context.view_layer.update()
        for name, output in outputs.items():
            # File-output nodes append a frame number; rename to a clean path.
            ext = EXT[output.format.file_format]
            path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0]
            os.rename(path, f'{output.file_slots[0].path}.{ext}')
        # Save camera parameters
        metadata = {
            "file_path": f'{i:03d}.png',
            "camera_angle_x": view['fov'],
            "transform_matrix": get_transform_matrix(cam)
        }
        if arg.save_depth:
            metadata['depth'] = {
                'min': view['radius'] - 0.5 * np.sqrt(3),
                'max': view['radius'] + 0.5 * np.sqrt(3)
            }
        to_export["frames"].append(metadata)
    # Save the camera parameters
    with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f:
        json.dump(to_export, f, indent=4)
    if arg.save_mesh:
        # triangulate meshes
        unhide_all_objects()
        convert_to_meshes()
        triangulate_meshes()
        print('[INFO] Meshes triangulated.')
        # export ply mesh
        bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply'))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.')
    parser.add_argument('--views', type=str, help='JSON string of views. Contains a list of {yaw, pitch, radius, fov} object.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
    parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
    parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.')
    parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
    parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.')
    parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.')
    parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.')
    parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.')
    parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.')
    parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.')
    parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.')
    # Blender consumes its own CLI args; ours follow the '--' separator.
    argv = sys.argv[sys.argv.index("--") + 1:]
    args = parser.parse_args(argv)
    main(args)

121
dataset_toolkits/render.py Normal file
View File

@@ -0,0 +1,121 @@
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
from utils import sphere_hammersley_sequence
# Blender release used for headless rendering; downloaded on demand
# by _install_blender() into BLENDER_INSTALLATION_PATH.
BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
def _install_blender():
    """Download and unpack Blender (plus required system libraries) if absent."""
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        # Headless Blender still needs these X/render shared libraries.
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
def _render(file_path, sha256, output_dir, num_views):
    """Render num_views fixed-radius views of one object via headless Blender.

    Args:
        file_path: path to the source model file.
        sha256: content hash; renders go to <output_dir>/renders/<sha256>.
        output_dir: dataset root directory.
        num_views: number of camera views to render.

    Returns:
        dict: {'sha256': ..., 'rendered': True} on success, None otherwise.
    """
    output_folder = os.path.join(output_dir, 'renders', sha256)
    # Build camera {yaw, pitch, radius, fov}: Hammersley-distributed directions
    # with a random per-object offset, fixed radius 2 and 40-degree FoV.
    offset = (np.random.rand(), np.random.rand())
    directions = [sphere_hammersley_sequence(i, num_views, offset) for i in range(num_views)]
    fov_rad = 40 / 180 * np.pi
    views = [{'yaw': yaw, 'pitch': pitch, 'radius': 2, 'fov': fov_rad} for yaw, pitch in directions]
    cmd = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--resolution', '512',
        '--output_folder', output_folder,
        '--engine', 'CYCLES',
        '--save_mesh',
    ]
    if file_path.endswith('.blend'):
        # .blend files must be given to Blender itself, right after the binary.
        cmd.insert(1, file_path)
    call(cmd, stdout=DEVNULL, stderr=DEVNULL)
    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'rendered': True}
if __name__ == '__main__':
    # First positional CLI argument selects the dataset-specific helper module.
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=150,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))
    os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True)
    # install blender
    print('Checking blender...', flush=True)
    _install_blender()
    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        # No explicit instance list: take all downloaded objects, optionally
        # filtered by aesthetic score and by not-yet-rendered status.
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'rendered' in metadata.columns:
            metadata = metadata[metadata['rendered'] == False]
    else:
        # Explicit instances: either a file of hashes or a comma-separated list.
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]
    # Shard the work across ranks.
    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []
    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]
    print(f'Processing {len(metadata)} objects...')
    # process objects
    func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views)
    rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    rendered = pd.concat([rendered, pd.DataFrame.from_records(records)])
    rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False)

View File

@@ -0,0 +1,125 @@
import os
import json
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
from subprocess import DEVNULL, call
import numpy as np
from utils import sphere_hammersley_sequence
# Blender release used for headless rendering; downloaded on demand
# by _install_blender() into BLENDER_INSTALLATION_PATH.
BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
BLENDER_INSTALLATION_PATH = '/tmp'
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
def _install_blender():
    """Download and unpack Blender (plus required system libraries) if absent."""
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        # Headless Blender still needs these X/render shared libraries.
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
def _render_cond(file_path, sha256, output_dir, num_views):
    """Render conditioning views (randomized radius and FoV) of one object.

    Args:
        file_path: path to the source model file.
        sha256: content hash; renders go to <output_dir>/renders_cond/<sha256>.
        output_dir: dataset root directory.
        num_views: number of camera views to render.

    Returns:
        dict: {'sha256': ..., 'cond_rendered': True} on success, None otherwise.
    """
    output_folder = os.path.join(output_dir, 'renders_cond', sha256)
    # Build camera {yaw, pitch, radius, fov}
    yaws = []
    pitchs = []
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views, offset)
        yaws.append(y)
        pitchs.append(p)
    # Radius and FoV are coupled so the unit cube (half-diagonal sqrt(3)/2)
    # always fills the frame: r * sin(fov/2) = sqrt(3)/2.
    fov_min, fov_max = 10, 70
    radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
    radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
    k_min = 1 / radius_max**2
    k_max = 1 / radius_min**2
    # Bug fix: draw exactly num_views samples instead of a hard-coded
    # 1,000,000 (the excess was silently discarded by the zip below).
    ks = np.random.uniform(k_min, k_max, (num_views,))
    radius = [1 / np.sqrt(k) for k in ks]
    fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
    views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]
    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--output_folder', os.path.expanduser(output_folder),
        '--resolution', '1024',
    ]
    if file_path.endswith('.blend'):
        # .blend files must be given to Blender itself, right after the binary.
        args.insert(1, file_path)
    call(args, stdout=DEVNULL)
    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'cond_rendered': True}
if __name__ == '__main__':
    # First positional CLI argument selects the dataset-specific helper module.
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=24,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))
    os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True)
    # install blender
    print('Checking blender...', flush=True)
    _install_blender()
    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        # No explicit instance list: take all downloaded objects, optionally
        # filtered by aesthetic score and by not-yet-rendered status.
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'cond_rendered' in metadata.columns:
            metadata = metadata[metadata['cond_rendered'] == False]
    else:
        # Explicit instances: either a file of hashes or a comma-separated list.
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]
    # Shard the work across ranks.
    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []
    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'cond_rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]
    print(f'Processing {len(metadata)} objects...')
    # process objects
    func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views)
    cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
    cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False)

116
render_model.py Normal file
View File

@@ -0,0 +1,116 @@
import subprocess
import bpy
import os
import sys
from mathutils import Vector
def clear_scene():
    """Delete every object in the current scene."""
    bpy.ops.object.select_all(action='SELECT')
    bpy.ops.object.delete(use_global=False)
def import_model(model_path):
    """Import an .obj, .glb or .gltf file into the current scene.

    Raises:
        ValueError: if the file extension is not one of the supported formats.
    """
    extension = os.path.splitext(model_path)[1].lower()
    if extension == ".obj":
        bpy.ops.wm.obj_import(filepath=model_path)
    elif extension in (".glb", ".gltf"):
        bpy.ops.import_scene.gltf(filepath=model_path)
    else:
        raise ValueError(f"Unsupported format: {extension}")
def get_scene_bbox():
    """Compute the world-space axis-aligned bounding box over all meshes.

    Returns:
        (Vector, Vector): minimum and maximum corners.

    Raises:
        RuntimeError: if the scene contains no mesh objects.
    """
    meshes = [o for o in bpy.context.scene.objects if o.type == 'MESH']
    if not meshes:
        raise RuntimeError("No mesh objects found")
    lo = Vector((float("inf"), float("inf"), float("inf")))
    hi = Vector((float("-inf"), float("-inf"), float("-inf")))
    for mesh_obj in meshes:
        for corner in mesh_obj.bound_box:
            world_pt = mesh_obj.matrix_world @ Vector(corner)
            for axis in range(3):
                lo[axis] = min(lo[axis], world_pt[axis])
                hi[axis] = max(hi[axis], world_pt[axis])
    return lo, hi
def normalize_model():
    """Center the scene's meshes at the origin and scale them to a 2-unit cube."""
    min_corner, max_corner = get_scene_bbox()
    center = (min_corner + max_corner) / 2
    # Largest bounding-box dimension drives the uniform scale.
    size = max(max_corner.x - min_corner.x,
               max(max_corner.y - min_corner.y, max_corner.z - min_corner.z))
    objs = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH']
    for obj in objs:
        obj.location -= center
    if size > 0:
        scale = 2.0 / size
        for obj in objs:
            obj.scale *= scale
    bpy.context.view_layer.update()
def add_camera():
    """Add a camera at a fixed 3/4 viewpoint, aimed at the origin."""
    bpy.ops.object.camera_add(location=(2.5, -2.5, 2.0))
    cam = bpy.context.active_object
    # Rotate the camera so its -Z axis points at the origin.
    direction = Vector((0, 0, 0)) - cam.location
    cam.rotation_euler = direction.to_track_quat('-Z', 'Y').to_euler()
    bpy.context.scene.camera = cam
def add_light():
    """Add a sun light plus a soft area fill light."""
    bpy.ops.object.light_add(type='SUN', location=(5, -5, 8))
    bpy.context.active_object.data.energy = 3.0
    bpy.ops.object.light_add(type='AREA', location=(3, -3, 3))
    area = bpy.context.active_object
    area.data.energy = 3000
    area.data.size = 5
def setup_render(output_path, resolution=1024):
    """Configure Cycles for a square PNG render with a transparent background.

    Args:
        output_path (str): file path the rendered image will be written to.
        resolution (int): output width and height in pixels.
    """
    scene = bpy.context.scene
    scene.render.engine = 'CYCLES'
    scene.cycles.samples = 128
    scene.render.filepath = output_path
    scene.render.image_settings.file_format = 'PNG'
    scene.render.resolution_x = resolution
    scene.render.resolution_y = resolution
    scene.render.film_transparent = True
def parse_args():
    """Extract (model_path, output_path) from the CLI arguments after '--'.

    Returns:
        tuple[str, str]: the model path and the output image path.

    Raises:
        ValueError: if fewer than two arguments follow the '--' separator.
    """
    all_args = sys.argv
    script_args = all_args[all_args.index("--") + 1:] if "--" in all_args else []
    if len(script_args) < 2:
        raise ValueError("Usage: blender -b -P render_model.py -- model_path output_path")
    return script_args[0], script_args[1]
def get_static_model_image():
    """Entry point: render one preview image of the model given on the CLI."""
    model_path, output_path = parse_args()
    clear_scene()
    import_model(model_path)
    normalize_model()
    add_camera()
    add_light()
    setup_render(output_path)
    bpy.ops.render.render(write_still=True)
    print(f"Saved to {output_path}")
if __name__ == "__main__":
    get_static_model_image()

439
server.py Normal file
View File

@@ -0,0 +1,439 @@
import mimetypes
import os
import secrets
import subprocess
import tempfile
import uuid
from datetime import datetime
from io import BytesIO
import imageio
import numpy as np
import trimesh
import litserve as ls
from minio import Minio
from glb2svg import glb_to_obj, obj_to_step, step_to_svg
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
from utils.new_oss_client import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, minio_get_image, MINIO_BUCKET, upload_bytes, upload_local_file, download_from_minio
# Shared MinIO client used for all object-storage reads/writes in this module.
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def generate_unique_name(original_name: str) -> str:
    """Derive a collision-resistant file name: <stem>_<timestamp>_<8 hex>{ext}."""
    stem, ext = os.path.splitext(original_name)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    suffix = secrets.token_hex(4)  # 8 random hex characters
    return "".join([stem, "_", stamp, "_", suffix, ext])
def load_mesh(file_path):
    """Load an .obj or .glb/.gltf file and return its vertex data.

    Args:
        file_path: path to the model file (.obj, .glb or .gltf).

    Returns:
        numpy.ndarray: vertex array of shape (N, 3).

    Raises:
        ValueError: on unsupported extensions or when no vertices are found.
    """
    file_ext = os.path.splitext(file_path)[1].lower()
    if file_ext == '.obj':
        mesh = trimesh.load(file_path, file_type='obj')
    elif file_ext == '.glb' or file_ext == '.gltf':
        # Bug fix: ASCII .gltf files were previously forced through the binary
        # 'glb' parser; pass the real extension so trimesh picks the right loader.
        mesh = trimesh.load(file_path, file_type=file_ext[1:])
    else:
        raise ValueError(f"不支持的文件格式: {file_ext},仅支持.obj和.glb/.gltf")
    if isinstance(mesh, trimesh.Scene):
        # A scene: stack the vertices of every contained geometry.
        vertices = np.vstack([geom.vertices for geom in mesh.geometry.values()])
    else:
        # A single geometry.
        vertices = mesh.vertices
    if len(vertices) == 0:
        raise ValueError("文件中未找到顶点数据")
    return vertices
def analyze_mesh(file_path):
    """Analyze a 3D model file: centroid, bounding box, size and axis ratios.

    Args:
        file_path: path to the model file (.obj / .glb / .gltf).

    Returns:
        dict: summary statistics computed over the model's vertices.
    """
    vertices = load_mesh(file_path)
    # Axis-aligned bounding box (per-axis min/max).
    min_coords = vertices.min(axis=0)
    max_coords = vertices.max(axis=0)
    centroid = vertices.mean(axis=0)
    # Size = bounding-box extent per axis.
    size = max_coords - min_coords
    total_size = np.sum(size)
    # Bug fix: the zero-size fallback used to be a plain Python list, which
    # made the .tolist() calls below raise AttributeError; use an ndarray.
    size_ratio = size / total_size if total_size != 0 else np.zeros(3)
    info = {
        "file_format": os.path.splitext(file_path)[1].lower(),
        "vertex_count": len(vertices),
        "centroid": centroid.tolist(),
        "bounding_box_min": min_coords.tolist(),
        "bounding_box_max": max_coords.tolist(),
        "size": size.tolist(),
        "size_ratio": size_ratio.tolist(),
        "size_ratio_percentage": (size_ratio * 100).tolist()
    }
    return info
def render_glb_preview(glb_path, output_path):
    """Render a static preview image of a GLB by running Blender headless.

    Args:
        glb_path: Path of the GLB model to render.
        output_path: Destination image path (parent dirs are created).

    Returns:
        The output_path that was rendered to.

    Raises:
        RuntimeError: If the Blender subprocess exits with a non-zero code.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    command = [
        "blender", "--background",
        "--python", "render_model.py",
        "--", glb_path, output_path,
    ]
    proc = subprocess.run(command, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(
            f"Blender render failed\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}"
        )
    return output_path
class TrellisAPI(ls.LitAPI):
    """LitServe endpoint turning one or more input images into a 3D asset.

    Flow: fetch images from MinIO -> run the TRELLIS image-to-3D pipeline ->
    export and upload a GLB -> analyze its geometry -> render and upload a
    static preview image.
    """

    def setup(self, device):
        """Load the pretrained TRELLIS pipeline once per worker."""
        os.environ.setdefault("SPCONV_ALGO", "native")
        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(
            "microsoft/TRELLIS-image-large"
        )
        self.pipeline.to(device)

    def decode_request(self, request):
        """Resolve MinIO image paths to images and collect run parameters.

        Each entry of request["image_paths"] has the form
        "<bucket>/<object_name>".
        """
        image_paths = request["image_paths"]
        images = []
        for path in image_paths:
            # Split "<bucket>/<object_name>" at the first slash.
            bucket = path.split('/')[0]
            object_name = path[path.find('/') + 1:]
            image = minio_get_image(minio_client, bucket, object_name)
            images.append(image)
        # Sampling/export parameters with server-side defaults.
        params = {
            "file_name": uuid.uuid4().hex,
            "model": request.get("model", "single"),
            "seed": request.get("seed", 1),
            "steps_sparse": request.get("steps_sparse", 12),
            "cfg_sparse": request.get("cfg_sparse", 7.5),
            "steps_slat": request.get("steps_slat", 12),
            "cfg_slat": request.get("cfg_slat", 3.0),
            "simplify": request.get("simplify", 0.95),
            "texture_size": request.get("texture_size", 1024),
            "fps": request.get("fps", 30),
        }
        return images, params

    def predict(self, inputs):
        """Run the pipeline (single- or multi-image) and package the results."""
        images, params = inputs
        if params["model"] == "single":
            outputs = self.pipeline.run(
                images[0],
                seed=params["seed"],
                sparse_structure_sampler_params={
                    "steps": params["steps_sparse"],
                    "cfg_strength": params["cfg_sparse"],
                },
                slat_sampler_params={
                    "steps": params["steps_slat"],
                    "cfg_strength": params["cfg_slat"],
                },
            )
        else:
            # Any other value of "model" is treated as multi-image.
            outputs = self.pipeline.run_multi_image(
                images,
                seed=params['seed'],
                sparse_structure_sampler_params={
                    "steps": params['steps_sparse'],
                    "cfg_strength": params['cfg_sparse'],
                },
                slat_sampler_params={
                    "steps": params['steps_slat'],
                    "cfg_strength": params['cfg_slat'],
                },
            )
        # video_path = self.upload_video(outputs, params)
        minio_glb_path, local_glb_path = self.upload_glb(outputs, params)
        glb_info = analyze_mesh(local_glb_path)
        local_static_model_image_path = os.path.join("glb_output", generate_unique_name("static_model_image.png"))
        static_model_image = self.get_static_model_image(model_path=local_glb_path, output_path=local_static_model_image_path)
        return {
            "glb_path": minio_glb_path,
            "glb_static_img_path": static_model_image,
            "glb_info": glb_info,
        }

    def encode_response(self, output):
        """predict() already returns a JSON-serializable dict; pass through."""
        return output

    def upload_video(self, outputs, params):
        """Render turntable videos (gaussian/radiance-field/mesh) and upload.

        Currently unused by predict(); kept as an optional output path.
        """
        gaussian_name = f"3d_result/video/{params['file_name']}-gaussian.mp4"
        radiance_field_name = f"3d_result/video/{params['file_name']}-radiance_field.mp4"
        mesh_name = f"3d_result/video/{params['file_name']}-mesh.mp4"
        # gaussian video
        video = render_utils.render_video(outputs["gaussian"][0])["color"]
        buffer = BytesIO()
        imageio.mimsave(buffer, video, format="mp4", fps=params['fps'])
        gaussian_video_path = upload_bytes(
            buffer.getvalue(),
            gaussian_name,
            "video/mp4",
        )
        # radiance field video
        video = render_utils.render_video(outputs["radiance_field"][0])["color"]
        buffer = BytesIO()
        imageio.mimsave(buffer, video, format="mp4", fps=params['fps'])
        radiance_field_video_path = upload_bytes(
            buffer.getvalue(),
            radiance_field_name,
            "video/mp4",
        )
        # mesh video (normals channel)
        video = render_utils.render_video(outputs["mesh"][0])["normal"]
        buffer = BytesIO()
        imageio.mimsave(buffer, video, format="mp4", fps=params['fps'])
        mesh_path = upload_bytes(
            buffer.getvalue(),
            mesh_name,
            "video/mp4",
        )
        return {
            "gaussian": gaussian_video_path,
            "radiance_field": radiance_field_video_path,
            "mesh": mesh_path
        }

    def upload_glb(self, outputs, params):
        """Bake gaussians+mesh into a textured GLB, save locally, upload.

        Returns:
            (minio_path, local_path) of the exported GLB.
        """
        file_name = f"3d_result/glb/{params['file_name']}.glb"
        local_glb_path = os.path.join("glb_output", generate_unique_name("sample.glb"))
        out_dir = os.path.dirname(local_glb_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        glb = postprocessing_utils.to_glb(
            outputs["gaussian"][0],
            outputs["mesh"][0],
            simplify=params['simplify'],
            texture_size=params['texture_size'],
        )
        glb.export(
            file_obj=local_glb_path,
            file_type="glb"
        )
        glb_path = upload_local_file(
            local_glb_path,
            file_name,
            "application/octet-stream",
        )
        return glb_path, local_glb_path

    def upload_ply(self, outputs, params):
        """Save the gaussian representation as .ply and upload it."""
        file_name = f"3d_result/ply/{params['file_name']}.ply"
        with tempfile.NamedTemporaryFile(suffix=".ply") as tmp:
            outputs["gaussian"][0].save_ply(tmp.name)
            tmp.seek(0)
            ply_path = upload_bytes(
                tmp.read(),
                file_name,
                "application/octet-stream",
            )
        return {"ply": ply_path}

    def get_static_model_image(self, model_path, output_path):
        """Render a static preview of the GLB via Blender and upload it.

        NOTE(review): the incoming `output_path` argument is ignored — a
        fresh local path is generated below and the parameter is then
        overwritten by the render result. Confirm this is intended.
        """
        local_static_model_image_path = os.path.join(
            "glb_output",
            generate_unique_name("static_model_image.png")
        )
        print(f"model_path : {model_path}")
        print(f"local_static_model_image_path :{local_static_model_image_path}")
        output_path = render_glb_preview(model_path, local_static_model_image_path)
        static_model_image = self.upload_local_file(
            output_path,
            "png"
        )
        print(f"Saved to {static_model_image}")
        return static_model_image

    def upload_local_file(self, local_path, type):
        """Generic upload helper: supports SVG, PNG, OBJ, etc.

        Returns "<bucket>/<object_name>" on success, None on failure.
        """
        object_name = f"3d_result/{type}/{uuid.uuid4().hex}.{type}"
        if not os.path.exists(local_path):
            print(f"错误: 文件 {local_path} 不存在")
            return None
        # Infer the Content-Type from the file suffix,
        # e.g. .svg -> image/svg+xml, .png -> image/png.
        content_type, _ = mimetypes.guess_type(local_path)
        if content_type is None:
            content_type = "application/octet-stream"
        try:
            minio_client.fput_object(
                bucket_name=MINIO_BUCKET,
                object_name=object_name,
                file_path=local_path,
                content_type=content_type
            )
            print(f"成功上传 [{content_type}]: {object_name}")
            return f"{MINIO_BUCKET}/{object_name}"
        except Exception as e:
            print(f"上传失败: {e}")
            return None
class ModelToThreeViews(ls.LitAPI):
    """LitServe endpoint converting a GLB model into three-view drawings.

    Flow: download GLB from MinIO -> convert to OBJ -> convert to STEP ->
    project to combined SVG/PNG views -> upload the result.
    """

    def setup(self, device):
        # Pure conversion pipeline; no model weights to load.
        pass

    def upload_local_file(self, local_path, type):
        """Generic upload helper: supports SVG, PNG, OBJ, etc.

        Returns "<bucket>/<object_name>" on success, None on failure.
        """
        object_name = f"3d_result/{type}/{uuid.uuid4().hex}.{type}"
        if not os.path.exists(local_path):
            print(f"错误: 文件 {local_path} 不存在")
            return None
        # Infer the Content-Type from the file suffix,
        # e.g. .svg -> image/svg+xml, .png -> image/png.
        content_type, _ = mimetypes.guess_type(local_path)
        if content_type is None:
            content_type = "application/octet-stream"
        try:
            minio_client.fput_object(
                bucket_name=MINIO_BUCKET,
                object_name=object_name,
                file_path=local_path,
                content_type=content_type
            )
            print(f"成功上传 [{content_type}]: {object_name}")
            return f"{MINIO_BUCKET}/{object_name}"
        except Exception as e:
            print(f"上传失败: {e}")
            return None

    def predict(self, request):
        """Run the GLB -> OBJ -> STEP -> SVG conversion chain and upload."""
        minio_glb_path = request['minio_glb_path']
        work_dir = f"glb_to_obj"
        os.makedirs(work_dir, exist_ok=True)
        glb_path = os.path.join(work_dir, f"model{uuid.uuid4().hex}.glb")
        step_dir = os.path.join(work_dir, "step")
        svg_dir = os.path.join(work_dir, "svg")
        os.makedirs(step_dir, exist_ok=True)
        os.makedirs(svg_dir, exist_ok=True)
        print(f"""
入参阶段:
input glb-obj minio-path:{minio_glb_path},\n
work_dir : {work_dir},glb_path : {glb_path},step_dir : {step_dir},svg_dir : {svg_dir}\n
""")
        print("=" * 10)
        print(f" 第一阶段 下载glb文件: ")
        # 1. Download the GLB from MinIO.
        glb_result = download_from_minio(object_path=minio_glb_path, local_path=glb_path)
        print(f" 下载结果 : {glb_result} \n")
        print("=" * 10)
        print(f" 第二阶段 glb -> obj: ")
        # 2. Convert GLB -> OBJ.
        obj_result = glb_to_obj(glb_result)
        print(f" glb -> obj 结果 : {obj_result} \n")
        print("=" * 10)
        print(f" 第三阶段 obj -> step: ")
        # 3. Convert OBJ -> STEP via an external helper script.
        step_result = obj_to_step(
            input_obj=obj_result,
            output_dir=step_dir,
            script_path="1_obj_to_step.py"
        )
        print(f" obj -> step 结果 : {step_result} \n")
        print("=" * 10)
        print(f" 第四阶段 step -> svg: ")
        # 4. Project STEP -> combined SVG and PNG views.
        combined_svg, combined_png = step_to_svg(
            step_path=step_result,
            out_dir=svg_dir
        )
        print(f" step -> svg 结果 : {combined_svg} \n")
        print("=" * 10)
        # 5. Upload the result.
        # NOTE(review): the PNG file is uploaded under an "svg" object
        # name/extension — confirm whether combined_svg was intended here.
        minio_svg_path = self.upload_local_file(combined_png, "svg")
        return {"minio_svg_path": minio_svg_path}
if __name__ == "__main__":
    # Mount both endpoints on a single LitServer instance.
    trellis_api = TrellisAPI(api_path="/canvas/img_to_3D")
    model_to_three_api = ModelToThreeViews(api_path="/canvas/3d_to_3views")
    server = ls.LitServer([
        trellis_api,
        model_to_three_api])
    server.run(port=8120)

95
single_image_to_3D.py Normal file
View File

@@ -0,0 +1,95 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import argparse
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
def build_parser():
    """Create the CLI argument parser for single-image -> 3D generation."""
    parser = argparse.ArgumentParser("TRELLIS CLI: single image -> 3D")
    parser.add_argument("-i", "--image", required=True, help="Input image path")
    parser.add_argument("-o", "--out_dir", default="trellis_out", help="Output directory")
    # Sampling hyper-parameters.
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--steps_sparse", type=int, default=12)
    parser.add_argument("--cfg_sparse", type=float, default=7.5)
    parser.add_argument("--steps_slat", type=int, default=12)
    parser.add_argument("--cfg_slat", type=float, default=3.0)
    # GLB export parameters.
    parser.add_argument("--simplify", type=float, default=0.95)
    parser.add_argument("--texture_size", type=int, default=1024)
    # Paired on/off toggles; every feature defaults to enabled.
    for feature in ("export_glb", "save_ply", "save_video"):
        parser.add_argument(f"--{feature}", dest=feature, action="store_true", default=True)
        parser.add_argument(f"--no-{feature}", dest=feature, action="store_false")
    # Video output options.
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--video_gs_name", type=str, default="sample_gs.mp4")
    parser.add_argument("--video_rf_name", type=str, default="sample_rf.mp4")
    parser.add_argument("--video_mesh_name", type=str, default="sample_mesh.mp4")
    return parser
def main():
    """CLI entry: run TRELLIS on one image and export videos / GLB / PLY."""
    args = build_parser().parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    # Optional env
    os.environ.setdefault("SPCONV_ALGO", "native")
    pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
    pipeline.cuda()
    image = Image.open(args.image)
    outputs = pipeline.run(
        image,
        seed=args.seed,
        sparse_structure_sampler_params={
            "steps": args.steps_sparse,
            "cfg_strength": args.cfg_sparse,
        },
        slat_sampler_params={
            "steps": args.steps_slat,
            "cfg_strength": args.cfg_slat,
        },
    )
    if args.save_video:
        # Lazy import: imageio is only needed when videos are written.
        import imageio
        video = render_utils.render_video(outputs["gaussian"][0])["color"]
        imageio.mimsave(os.path.join(args.out_dir, args.video_gs_name), video, fps=args.fps)
        video = render_utils.render_video(outputs["radiance_field"][0])["color"]
        imageio.mimsave(os.path.join(args.out_dir, args.video_rf_name), video, fps=args.fps)
        video = render_utils.render_video(outputs["mesh"][0])["normal"]
        imageio.mimsave(os.path.join(args.out_dir, args.video_mesh_name), video, fps=args.fps)
    if args.export_glb:
        # Bake gaussians + mesh into a textured GLB.
        glb = postprocessing_utils.to_glb(
            outputs["gaussian"][0],
            outputs["mesh"][0],
            simplify=args.simplify,
            texture_size=args.texture_size,
        )
        glb.export(os.path.join(args.out_dir, "sample.glb"))
    if args.save_ply:
        outputs["gaussian"][0].save_ply(os.path.join(args.out_dir, "sample.ply"))
# Script entry point.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,193 @@
from typing import *
from enum import Enum
import torch
import math
from .. import SparseTensor
from .. import DEBUG, ATTN
# Select the attention backend at import time; fail fast on an unknown value.
if ATTN == 'xformers':
    import xformers.ops as xops
elif ATTN == 'flash_attn':
    import flash_attn
else:
    raise ValueError(f"Unknown attention module: {ATTN}")

__all__ = [
    'sparse_serialized_scaled_dot_product_self_attention',
]
class SerializeMode(Enum):
    """Space-filling-curve orderings used to serialize sparse voxel coords."""
    Z_ORDER = 0
    Z_ORDER_TRANSPOSED = 1
    HILBERT = 2
    HILBERT_TRANSPOSED = 3
# Convenience list of every supported serialization mode.
SerializeModes = [
    SerializeMode.Z_ORDER,
    SerializeMode.Z_ORDER_TRANSPOSED,
    SerializeMode.HILBERT,
    SerializeMode.HILBERT_TRANSPOSED
]
def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor, torch.Tensor, List[int], List[int]):
            forward indices (gather order into serialized windows),
            backward indices (scatter order back to the original layout),
            per-window sequence lengths, and per-window batch indices.
    """
    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    offsets = [0]

    # Lazy import: vox2seq is only needed (and may only be installed) here.
    if 'vox2seq' not in globals():
        import vox2seq

    # Serialize the input
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
    if serialize_mode == SerializeMode.Z_ORDER:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
    elif serialize_mode == SerializeMode.HILBERT:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")

    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        num_windows = (num_points + window_size - 1) // window_size  # ceil division
        valid_window_size = num_points / num_windows  # fractional: evens out window fill
        to_ordered = torch.argsort(code[s.start:s.stop])
        if num_windows == 1:
            # Sequence fits in one window: no padding needed.
            fwd_indices.append(to_ordered)
            bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
            fwd_indices[-1] += s.start
            bwd_indices[-1] += offsets[-1]
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            offsets.append(offsets[-1] + seq_lens[-1])
        else:
            # Partition the input
            offset = 0
            mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
            split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
            bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
            for i in range(num_windows):
                mid = mids[i]
                valid_start = split[i]
                valid_end = split[i + 1]
                padded_start = math.floor(mid - 0.5 * window_size)
                padded_end = padded_start + window_size
                # Wrap with modulo so every window is exactly window_size long.
                fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
                offset += valid_start - padded_start
                bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
                offset += padded_end - valid_start
                fwd_indices[-1] += s.start
            seq_lens.extend([window_size] * num_windows)
            seq_batch_indices.extend([bi] * num_windows)
            bwd_indices.append(bwd_index + offsets[-1])
            offsets.append(offsets[-1] + num_windows * window_size)

    fwd_indices = torch.cat(fwd_indices)
    bwd_indices = torch.cat(bwd_indices)
    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: qkv with its features replaced by the attention output.
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # Serialization indices are cached per (mode, window, shifts) combination.
    serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]  # total serialized (possibly padded) length
    T = qkv.feats.shape[0]    # original token count
    H = qkv.feats.shape[2]    # number of attention heads
    C = qkv.feats.shape[3]    # channels per head

    qkv_feats = qkv.feats[fwd_indices]  # [M, 3, H, C]

    if DEBUG:
        # Verify every serialized window stays within a single batch element.
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        # Uniform windows: batch them as one dense [B, N, ...] attention call.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)  # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)  # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)  # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)  # [M, H, C]
    else:
        # Ragged windows: variable-length attention via a block-diagonal mask
        # (xformers) or cumulative sequence lengths (flash_attn).
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)  # [M, H, C]
            q = q.unsqueeze(0)  # [1, M, H, C]
            k = k.unsqueeze(0)  # [1, M, H, C]
            v = v.unsqueeze(0)  # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0]  # [M, H, C]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens))  # [M, H, C]

    # Scatter the attention output back to the original token order.
    out = out[bwd_indices]  # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)

118
trellis/renderers/sh_utils.py Executable file
View File

@@ -0,0 +1,118 @@
# Copyright 2021 The PlenOctree Authors.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import torch
# Hardcoded real spherical-harmonic basis coefficients, degrees 0 through 4.
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate real spherical harmonics at unit directions using hardcoded
    SH polynomials. Works with torch/np/jnp; any number of leading batch
    dimensions is allowed.

    Args:
        deg: int SH degree, 0 to 4 inclusive.
        sh: SH coefficients, shape [..., C, (deg + 1) ** 2].
        dirs: unit directions, shape [..., 3].

    Returns:
        Evaluated values, shape [..., C].
    """
    assert 0 <= deg <= 4
    assert sh.shape[-1] >= (deg + 1) ** 2

    # Degree 0: constant term.
    res = C0 * sh[..., 0]
    if deg < 1:
        return res

    # Degree 1: linear in direction components.
    x = dirs[..., 0:1]
    y = dirs[..., 1:2]
    z = dirs[..., 2:3]
    res = res - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]
    if deg < 2:
        return res

    # Degree 2: quadratic terms.
    xx, yy, zz = x * x, y * y, z * z
    xy, yz, xz = x * y, y * z, x * z
    res = (res +
           C2[0] * xy * sh[..., 4] +
           C2[1] * yz * sh[..., 5] +
           C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
           C2[3] * xz * sh[..., 7] +
           C2[4] * (xx - yy) * sh[..., 8])
    if deg < 3:
        return res

    # Degree 3: cubic terms.
    res = (res +
           C3[0] * y * (3 * xx - yy) * sh[..., 9] +
           C3[1] * xy * z * sh[..., 10] +
           C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
           C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
           C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
           C3[5] * z * (xx - yy) * sh[..., 14] +
           C3[6] * x * (xx - 3 * yy) * sh[..., 15])
    if deg < 4:
        return res

    # Degree 4: quartic terms.
    res = (res +
           C4[0] * xy * (xx - yy) * sh[..., 16] +
           C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
           C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
           C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
           C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
           C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
           C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
           C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
           C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return res


def RGB2SH(rgb):
    """Convert RGB in [0, 1] to degree-0 SH coefficients."""
    return (rgb - 0.5) / C0


def SH2RGB(sh):
    """Convert degree-0 SH coefficients back to RGB in [0, 1]."""
    return sh * C0 + 0.5

View File

@@ -0,0 +1,274 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import copy
import math
from ipywidgets import interactive, HBox, VBox, FloatLogSlider, IntSlider
import torch
import nvdiffrast.torch as dr
import kaolin as kal
import util
###############################################################################
# Functions adapted from https://github.com/NVlabs/nvdiffrec
###############################################################################
def get_random_camera_batch(batch_size, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
    """Sample a batch of random cameras on a sphere around the origin.

    Returns a kaolin Camera batch (use_kaolin=True) or a pair of stacked
    (model-view, model-view-projection) matrix tensors.
    """
    if use_kaolin:
        camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(
            *kal.ops.random.sample_spherical_coords((batch_size,), azimuth_low=0., azimuth_high=math.pi * 2,
                                                    elevation_low=-math.pi / 2., elevation_high=math.pi / 2., device='cuda'),
            cam_radius
        ), dim=-1)
        return kal.render.camera.Camera.from_args(
            # Jitter the eye slightly around the sampled sphere point.
            eye=camera_pos + torch.rand((batch_size, 1), device='cuda') * 0.5 - 0.25,
            at=torch.zeros(batch_size, 3),
            up=torch.tensor([[0., 1., 0.]]),
            fov=fovy,
            near=cam_near_far[0], far=cam_near_far[1],
            height=iter_res[0], width=iter_res[1],
            device='cuda'
        )
    else:
        def get_random_camera():
            # Build a single random (mv, mvp) pair the nvdiffrec way.
            proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
            mv = util.translate(0, 0, -cam_radius) @ util.random_rotation_translation(0.25)
            mvp = proj_mtx @ mv
            return mv, mvp
        mv_batch = []
        mvp_batch = []
        for i in range(batch_size):
            mv, mvp = get_random_camera()
            mv_batch.append(mv)
            mvp_batch.append(mvp)
        return torch.stack(mv_batch).to(device), torch.stack(mvp_batch).to(device)
def get_rotate_camera(itr, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
    """Camera for display: one full turn around the object every 10 `itr` steps."""
    if use_kaolin:
        ang = (itr / 10) * np.pi * 2
        # NOTE(review): the radius is passed negated — confirm the
        # spherical2cartesian sign convention this relies on.
        camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(torch.tensor(ang), torch.tensor(0.4), -torch.tensor(cam_radius)))
        return kal.render.camera.Camera.from_args(
            eye=camera_pos,
            at=torch.zeros(3),
            up=torch.tensor([0., 1., 0.]),
            fov=fovy,
            near=cam_near_far[0], far=cam_near_far[1],
            height=iter_res[0], width=iter_res[1],
            device='cuda'
        )
    else:
        proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
        # Smooth rotation for display.
        ang = (itr / 10) * np.pi * 2
        mv = util.translate(0, 0, -cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
        mvp = proj_mtx @ mv
        return mv.to(device), mvp.to(device)
# Shared OpenGL rasterization context, created once at import time.
glctx = dr.RasterizeGLContext()
def render_mesh(mesh, camera, iter_res, return_types=("mask", "depth"), white_bg=False, wireframe_thickness=0.4):
    """Rasterize a mesh with nvdiffrast for the requested output channels.

    Args:
        mesh: kaolin-style mesh with .vertices, .faces, .face_normals.
        camera: kaolin Camera providing extrinsics and a projection matrix.
        iter_res: (height, width) rasterization resolution.
        return_types: channels to produce, out of "mask", "depth",
            "wireframe", "normals". (Default changed from a mutable list
            to an equivalent tuple; it is only iterated.)
        white_bg: composite each channel over white using the coverage mask.
        wireframe_thickness: barycentric threshold for the wireframe channel.

    Returns:
        dict mapping each requested channel name to its rendered tensor.

    Raises:
        ValueError: If an unknown entry appears in return_types (previously
            an unknown entry silently reused the prior channel's image, or
            raised NameError on the first iteration).
    """
    vertices_camera = camera.extrinsics.transform(mesh.vertices)
    face_vertices_camera = kal.ops.mesh.index_vertices_by_faces(
        vertices_camera, mesh.faces
    )
    # Projection: nvdiffrast takes clip coordinates as input to apply
    # barycentric perspective correction. (`camera.intrinsics.transform`
    # would instead return normalized device coordinates.)
    proj = camera.projection_matrix().unsqueeze(1)
    proj[:, :, 1, 1] = -proj[:, :, 1, 1]  # flip Y for nvdiffrast's convention
    homogeneous_vecs = kal.render.camera.up_to_homogeneous(
        vertices_camera
    )
    vertices_clip = (proj @ homogeneous_vecs.unsqueeze(-1)).squeeze(-1)
    faces_int = mesh.faces.int()

    rast, _ = dr.rasterize(
        glctx, vertices_clip, faces_int, iter_res)

    out_dict = {}
    # Loop variable renamed from `type` to avoid shadowing the builtin.
    for channel in return_types:
        if channel == "mask":
            img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
        elif channel == "depth":
            img = dr.interpolate(homogeneous_vecs, rast, faces_int)[0]
        elif channel == "wireframe":
            img = torch.logical_or(
                torch.logical_or(rast[..., 0] < wireframe_thickness, rast[..., 1] < wireframe_thickness),
                (rast[..., 0] + rast[..., 1]) > (1. - wireframe_thickness)
            ).unsqueeze(-1)
        elif channel == "normals":
            img = dr.interpolate(
                mesh.face_normals.reshape(len(mesh), -1, 3), rast,
                torch.arange(mesh.faces.shape[0] * 3, device='cuda', dtype=torch.int).reshape(-1, 3)
            )[0]
        else:
            raise ValueError(f"Unknown return type: {channel}")
        if white_bg:
            bg = torch.ones_like(img)
            alpha = (rast[..., -1:] > 0).float()
            img = torch.lerp(bg, img, alpha)
        out_dict[channel] = img
    return out_dict
def render_mesh_paper(mesh, mv, mvp, iter_res, return_types=("mask", "depth"), white_bg=False):
    """The rendering function used to produce the results in the paper.

    Args:
        mesh: mesh with .vertices, .faces, .nrm (face normals), .v_nrm.
        mv, mvp: model-view and model-view-projection matrices.
        iter_res: (height, width) rasterization resolution.
        return_types: channels out of "mask", "depth", "normal",
            "vertex_normal". (Default changed from a mutable list to an
            equivalent tuple; it is only iterated.)
        white_bg: composite over white using the coverage mask.

    Raises:
        ValueError: If an unknown entry appears in return_types (previously
            it silently reused the prior channel's image or raised NameError).
    """
    v_pos_clip = util.xfm_points(mesh.vertices.unsqueeze(0), mvp)  # Rotate it to camera coordinates
    # NOTE: a fresh GL context is created per call; fine for one-off figure
    # generation, wasteful in a loop (see the module-level `glctx`).
    rast, db = dr.rasterize(
        dr.RasterizeGLContext(), v_pos_clip, mesh.faces.int(), iter_res)

    out_dict = {}
    # Loop variable renamed from `type` to avoid shadowing the builtin.
    for channel in return_types:
        if channel == "mask":
            img = dr.antialias((rast[..., -1:] > 0).float(), rast, v_pos_clip, mesh.faces.int())
        elif channel == "depth":
            v_pos_cam = util.xfm_points(mesh.vertices.unsqueeze(0), mv)
            img, _ = util.interpolate(v_pos_cam, rast, mesh.faces.int())
        elif channel == "normal":
            normal_indices = (torch.arange(0, mesh.nrm.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
            img, _ = util.interpolate(mesh.nrm.unsqueeze(0).contiguous(), rast, normal_indices.int())
        elif channel == "vertex_normal":
            img, _ = util.interpolate(mesh.v_nrm.unsqueeze(0).contiguous(), rast, mesh.faces.int())
            img = dr.antialias((img + 1) * 0.5, rast, v_pos_clip, mesh.faces.int())
        else:
            raise ValueError(f"Unknown return type: {channel}")
        if white_bg:
            bg = torch.ones_like(img)
            alpha = (rast[..., -1:] > 0).float()
            img = torch.lerp(bg, img, alpha)
        out_dict[channel] = img
    return out_dict
class SplitVisualizer():
    """Side-by-side turntable viewer for two meshes (ipywidgets/kaolin based)."""

    def __init__(self, lh_mesh, rh_mesh, height, width):
        self.lh_mesh = lh_mesh  # mesh shown on the left half
        self.rh_mesh = rh_mesh  # mesh shown on the right half
        self.height = height
        self.width = width
        self.wireframe_thickness = 0.4  # live-tunable via the slider in show()

    def render(self, camera):
        """Render both meshes and stitch them into one side-by-side frame."""
        lh_outputs = render_mesh(
            self.lh_mesh, camera, (self.height, self.width),
            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
        )
        rh_outputs = render_mesh(
            self.rh_mesh, camera, (self.height, self.width),
            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
        )
        # Concatenate along width via the permute/cat/permute dance.
        outputs = {
            k: torch.cat(
                [lh_outputs[k][0].permute(1, 0, 2), rh_outputs[k][0].permute(1, 0, 2)],
                dim=0
            ).permute(1, 0, 2) for k in ["normals", "wireframe"]
        }
        return {
            'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8),
            'normals': outputs['normals']
        }

    def show(self, init_camera):
        """Display the interactive viewer with a wireframe-thickness slider."""
        visualizer = kal.visualize.IpyTurntableVisualizer(
            self.height, self.width * 2, copy.deepcopy(init_camera), self.render,
            max_fps=24, world_up_axis=1)

        def slider_callback(new_wireframe_thickness):
            """ipywidgets sliders callback"""
            with visualizer.out:  # This is in case of bug
                self.wireframe_thickness = new_wireframe_thickness
                # this is how we request a new update
                visualizer.render_update()

        wireframe_thickness_slider = FloatLogSlider(
            value=self.wireframe_thickness,
            base=10,
            min=-3,
            max=-0.4,
            step=0.1,
            description='wireframe_thickness',
            continuous_update=True,
            readout=True,
            readout_format='.3f',
        )
        interactive_slider = interactive(
            slider_callback,
            new_wireframe_thickness=wireframe_thickness_slider,
        )
        full_output = VBox([visualizer.canvas, interactive_slider])
        # NOTE(review): `display` is the IPython builtin, notebook-only.
        display(full_output, visualizer.out)
class TimelineVisualizer():
    """Turntable viewer that scrubs through a list of meshes over time."""

    def __init__(self, meshes, height, width):
        self.meshes = meshes
        self.height = height
        self.width = width
        self.wireframe_thickness = 0.4  # live-tunable via slider
        self.idx = len(meshes) - 1      # start at the last (latest) mesh

    def render(self, camera):
        """Render the currently selected mesh as shaded normals + wireframe."""
        outputs = render_mesh(
            self.meshes[self.idx], camera, (self.height, self.width),
            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
        )
        return {
            'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8)[0],
            'normals': outputs['normals'][0]
        }

    def show(self, init_camera):
        """Display the interactive viewer with thickness and timeline sliders."""
        visualizer = kal.visualize.IpyTurntableVisualizer(
            self.height, self.width, copy.deepcopy(init_camera), self.render,
            max_fps=24, world_up_axis=1)

        def slider_callback(new_wireframe_thickness, new_idx):
            """ipywidgets sliders callback"""
            with visualizer.out:  # This is in case of bug
                self.wireframe_thickness = new_wireframe_thickness
                self.idx = new_idx
                # this is how we request a new update
                visualizer.render_update()

        wireframe_thickness_slider = FloatLogSlider(
            value=self.wireframe_thickness,
            base=10,
            min=-3,
            max=-0.4,
            step=0.1,
            description='wireframe_thickness',
            continuous_update=True,
            readout=True,
            readout_format='.3f',
        )
        idx_slider = IntSlider(
            value=self.idx,
            min=0,
            max=len(self.meshes) - 1,
            description='idx',
            continuous_update=True,
            readout=True
        )
        interactive_slider = interactive(
            slider_callback,
            new_wireframe_thickness=wireframe_thickness_slider,
            new_idx=idx_slider
        )
        full_output = HBox([visualizer.canvas, interactive_slider])
        # NOTE(review): `display` is the IPython builtin, notebook-only.
        display(full_output, visualizer.out)

View File

@@ -0,0 +1,30 @@
import numpy as np
# The first 16 primes: one radix per Halton dimension.
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]


def radical_inverse(base, n):
    """Return the van der Corput radical inverse of integer n in `base`."""
    value = 0
    inv = 1.0 / base
    scale = inv
    while n > 0:
        value += (n % base) * scale
        n //= base
        scale *= inv
    return value


def halton_sequence(dim, n):
    """Return the n-th point of the `dim`-dimensional Halton sequence."""
    # One distinct prime base per dimension (loop variable renamed so it
    # no longer shadows the `dim` parameter).
    return [radical_inverse(PRIMES[d], n) for d in range(dim)]


def hammersley_sequence(dim, n, num_samples):
    """Return the n-th point of the `dim`-dimensional Hammersley set."""
    # First coordinate is uniform n/num_samples; the rest come from Halton.
    return [n / num_samples] + halton_sequence(dim - 1, n)


def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
    """Map the n-th 2-D Hammersley point onto the sphere.

    Returns [phi, theta] (azimuth, elevation in radians).
    """
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    if remap:
        # Piecewise remap of u redistributes samples across elevations.
        u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]

View File

@@ -0,0 +1,120 @@
import torch
import numpy as np
from tqdm import tqdm
import utils3d
from PIL import Image
from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
from ..representations import Octree, Gaussian, MeshExtractResult
from ..modules import sparse as sp
from .random_utils import sphere_hammersley_sequence
def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
    """Convert yaw/pitch/radius/fov camera specs to extrinsics and intrinsics.

    Args:
        yaws: yaw angle(s) in radians — scalar or list.
        pitchs: pitch angle(s) in radians — scalar or list (scalars are now
            broadcast to len(yaws), matching the handling of rs/fovs; the
            original required a list whenever yaws was a list).
        rs: camera distance(s) from the origin — scalar or list (broadcast).
        fovs: field(s) of view in degrees — scalar or list (broadcast).

    Returns:
        (extrinsics, intrinsics): lists of CUDA tensors, or a single pair
        when *yaws* was passed as a scalar.
    """
    is_list = isinstance(yaws, list)
    if not is_list:
        yaws = [yaws]
    # Broadcast any scalar argument to one entry per view.
    if not isinstance(pitchs, list):
        pitchs = [pitchs] * len(yaws)
    if not isinstance(rs, list):
        rs = [rs] * len(yaws)
    if not isinstance(fovs, list):
        fovs = [fovs] * len(yaws)
    extrinsics = []
    intrinsics = []
    for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
        fov = torch.deg2rad(torch.tensor(float(fov))).cuda()  # fov arrives in degrees
        yaw = torch.tensor(float(yaw)).cuda()
        pitch = torch.tensor(float(pitch)).cuda()
        # Camera position on a sphere of radius r around the origin;
        # look-at target is the origin with +z as world up.
        orig = torch.tensor([
            torch.sin(yaw) * torch.cos(pitch),
            torch.cos(yaw) * torch.cos(pitch),
            torch.sin(pitch),
        ]).cuda() * r
        extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
        intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
        extrinsics.append(extr)
        intrinsics.append(intr)
    if not is_list:
        # Scalar input -> unwrap to a single (extr, intr) pair.
        extrinsics = extrinsics[0]
        intrinsics = intrinsics[0]
    return extrinsics, intrinsics
def get_renderer(sample, **kwargs):
    """Instantiate and configure a renderer matching the sample's type.

    Accepts option overrides via keyword arguments (resolution, near, far,
    bg_color, ssaa, kernel_size); unspecified options fall back to the
    per-renderer defaults below.

    Raises:
        ValueError: if *sample* is not an Octree, Gaussian, or MeshExtractResult.
    """
    opt = kwargs.get  # shorthand: opt(name, default)
    if isinstance(sample, Octree):
        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = opt('resolution', 512)
        renderer.rendering_options.near = opt('near', 0.8)
        renderer.rendering_options.far = opt('far', 1.6)
        renderer.rendering_options.bg_color = opt('bg_color', (0, 0, 0))
        renderer.rendering_options.ssaa = opt('ssaa', 4)
        renderer.pipe.primitive = sample.primitive
        return renderer
    if isinstance(sample, Gaussian):
        renderer = GaussianRenderer()
        renderer.rendering_options.resolution = opt('resolution', 512)
        renderer.rendering_options.near = opt('near', 0.8)
        renderer.rendering_options.far = opt('far', 1.6)
        renderer.rendering_options.bg_color = opt('bg_color', (0, 0, 0))
        renderer.rendering_options.ssaa = opt('ssaa', 1)
        renderer.pipe.kernel_size = opt('kernel_size', 0.1)
        renderer.pipe.use_mip_gaussian = True
        return renderer
    if isinstance(sample, MeshExtractResult):
        renderer = MeshRenderer()
        renderer.rendering_options.resolution = opt('resolution', 512)
        renderer.rendering_options.near = opt('near', 1)
        renderer.rendering_options.far = opt('far', 100)
        renderer.rendering_options.ssaa = opt('ssaa', 4)
        return renderer
    raise ValueError(f'Unsupported sample type: {type(sample)}')
def render_frames(sample, extrinsics, intrinsics, options=None, colors_overwrite=None, verbose=True, **kwargs):
    """Render *sample* once for each (extrinsic, intrinsic) camera pair.

    Args:
        sample: Octree / Gaussian / MeshExtractResult to render.
        extrinsics, intrinsics: equal-length sequences of per-frame cameras.
        options: renderer option overrides forwarded to get_renderer
            (resolution, near, far, bg_color, ssaa, ...). Was a mutable
            default ``{}``; ``None`` avoids the shared-mutable-default trap
            and is treated identically.
        colors_overwrite: optional color override passed to non-mesh renders.
        verbose: show a tqdm progress bar.

    Returns:
        dict of channel name -> list of per-frame results: 'normal' (uint8
        HxWxC) for meshes; 'color' (uint8 HxWxC) and 'depth' (float array
        or None when the renderer provides no depth) otherwise.
    """
    renderer = get_renderer(sample, **(options or {}))
    rets = {}
    # enumerate index was unused in the original loop; dropped.
    for extr, intr in tqdm(zip(extrinsics, intrinsics), desc='Rendering', disable=not verbose):
        if isinstance(sample, MeshExtractResult):
            res = renderer.render(sample, extr, intr)
            rets.setdefault('normal', []).append(
                np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
        else:
            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
            rets.setdefault('color', []).append(
                np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
            # Prefer 'percent_depth' when the renderer supplies it.
            if 'percent_depth' in res:
                depth = res['percent_depth'].detach().cpu().numpy()
            elif 'depth' in res:
                depth = res['depth'].detach().cpu().numpy()
            else:
                depth = None
            rets.setdefault('depth', []).append(depth)
    return rets
def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
    """Render a turntable video: one full yaw orbit with a bobbing pitch.

    Args:
        sample: representation accepted by render_frames.
        resolution: output image size (square).
        bg_color: background RGB tuple.
        num_frames: number of frames over the full orbit.
        r: camera distance from the origin.
        fov: field of view in degrees.

    Returns:
        dict of per-frame channels from render_frames.
    """
    # Use np.pi instead of the truncated 3.1415 so the yaw sweep covers a
    # true full revolution (the old constant left a ~0.02 rad seam gap).
    yaws = torch.linspace(0, 2 * np.pi, num_frames).tolist()
    # Pitch oscillates sinusoidally around 0.25 rad with amplitude 0.5 rad.
    pitch = (0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * np.pi, num_frames))).tolist()
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
    return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
def render_multiview(sample, resolution=512, nviews=30):
    """Render *nviews* images from quasi-uniform Hammersley sphere cameras.

    Returns:
        (colors, extrinsics, intrinsics): per-view uint8 color images and
        the camera tensors used to render them.
    """
    radius = 2
    fov_deg = 40
    views = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
    yaws = [view[0] for view in views]
    pitchs = [view[1] for view in views]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, radius, fov_deg)
    frames = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
    return frames['color'], extrinsics, intrinsics
def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
    """Render four orthogonal snapshots (90-degree yaw steps) of *samples*.

    offset[0] rotates all four yaws together; offset[1] sets the shared
    pitch. A large r with a narrow fov gives a near-orthographic look.
    """
    base_yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
    yaws = [angle + offset[0] for angle in base_yaws]
    pitchs = [offset[1]] * 4
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
    return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)