diff --git a/dataset_toolkits/blender_script/render.py b/dataset_toolkits/blender_script/render.py new file mode 100644 index 0000000..1fbd586 --- /dev/null +++ b/dataset_toolkits/blender_script/render.py @@ -0,0 +1,528 @@ +import argparse, sys, os, math, re, glob +from typing import * +import bpy +from mathutils import Vector, Matrix +import numpy as np +import json +import glob + + +"""=============== BLENDER ===============""" + +IMPORT_FUNCTIONS: Dict[str, Callable] = { + "obj": bpy.ops.import_scene.obj, + "glb": bpy.ops.import_scene.gltf, + "gltf": bpy.ops.import_scene.gltf, + "usd": bpy.ops.import_scene.usd, + "fbx": bpy.ops.import_scene.fbx, + "stl": bpy.ops.import_mesh.stl, + "usda": bpy.ops.import_scene.usda, + "dae": bpy.ops.wm.collada_import, + "ply": bpy.ops.import_mesh.ply, + "abc": bpy.ops.wm.alembic_import, + "blend": bpy.ops.wm.append, +} + +EXT = { + 'PNG': 'png', + 'JPEG': 'jpg', + 'OPEN_EXR': 'exr', + 'TIFF': 'tiff', + 'BMP': 'bmp', + 'HDR': 'hdr', + 'TARGA': 'tga' +} + +def init_render(engine='CYCLES', resolution=512, geo_mode=False): + bpy.context.scene.render.engine = engine + bpy.context.scene.render.resolution_x = resolution + bpy.context.scene.render.resolution_y = resolution + bpy.context.scene.render.resolution_percentage = 100 + bpy.context.scene.render.image_settings.file_format = 'PNG' + bpy.context.scene.render.image_settings.color_mode = 'RGBA' + bpy.context.scene.render.film_transparent = True + + bpy.context.scene.cycles.device = 'GPU' + bpy.context.scene.cycles.samples = 128 if not geo_mode else 1 + bpy.context.scene.cycles.filter_type = 'BOX' + bpy.context.scene.cycles.filter_width = 1 + bpy.context.scene.cycles.diffuse_bounces = 1 + bpy.context.scene.cycles.glossy_bounces = 1 + bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0 + bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1 + bpy.context.scene.cycles.use_denoising = True + + bpy.context.preferences.addons['cycles'].preferences.get_devices() + bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA' + +def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False): + if not any([save_depth, save_normal, save_albedo, save_mist]): + return {}, {} + outputs = {} + spec_nodes = {} + + bpy.context.scene.use_nodes = True + bpy.context.scene.view_layers['View Layer'].use_pass_z = save_depth + bpy.context.scene.view_layers['View Layer'].use_pass_normal = save_normal + bpy.context.scene.view_layers['View Layer'].use_pass_diffuse_color = save_albedo + bpy.context.scene.view_layers['View Layer'].use_pass_mist = save_mist + + nodes = bpy.context.scene.node_tree.nodes + links = bpy.context.scene.node_tree.links + for n in nodes: + nodes.remove(n) + + render_layers = nodes.new('CompositorNodeRLayers') + + if save_depth: + depth_file_output = nodes.new('CompositorNodeOutputFile') + depth_file_output.base_path = '' + depth_file_output.file_slots[0].use_node_format = True + depth_file_output.format.file_format = 'PNG' + depth_file_output.format.color_depth = '16' + depth_file_output.format.color_mode = 'BW' + # Remap to 0-1 + map = nodes.new(type="CompositorNodeMapRange") + map.inputs[1].default_value = 0 # (min value you will be getting) + map.inputs[2].default_value = 10 # (max value you will be getting) + map.inputs[3].default_value = 0 # (min value you will map to) + map.inputs[4].default_value = 1 # (max value you will map to) + + links.new(render_layers.outputs['Depth'], map.inputs[0]) + links.new(map.outputs[0], 
depth_file_output.inputs[0]) + + outputs['depth'] = depth_file_output + spec_nodes['depth_map'] = map + + if save_normal: + normal_file_output = nodes.new('CompositorNodeOutputFile') + normal_file_output.base_path = '' + normal_file_output.file_slots[0].use_node_format = True + normal_file_output.format.file_format = 'OPEN_EXR' + normal_file_output.format.color_mode = 'RGB' + normal_file_output.format.color_depth = '16' + + links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0]) + + outputs['normal'] = normal_file_output + + if save_albedo: + albedo_file_output = nodes.new('CompositorNodeOutputFile') + albedo_file_output.base_path = '' + albedo_file_output.file_slots[0].use_node_format = True + albedo_file_output.format.file_format = 'PNG' + albedo_file_output.format.color_mode = 'RGBA' + albedo_file_output.format.color_depth = '8' + + alpha_albedo = nodes.new('CompositorNodeSetAlpha') + + links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image']) + links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha']) + links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0]) + + outputs['albedo'] = albedo_file_output + + if save_mist: + bpy.data.worlds['World'].mist_settings.start = 0 + bpy.data.worlds['World'].mist_settings.depth = 10 + + mist_file_output = nodes.new('CompositorNodeOutputFile') + mist_file_output.base_path = '' + mist_file_output.file_slots[0].use_node_format = True + mist_file_output.format.file_format = 'PNG' + mist_file_output.format.color_mode = 'BW' + mist_file_output.format.color_depth = '16' + + links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0]) + + outputs['mist'] = mist_file_output + + return outputs, spec_nodes + +def init_scene() -> None: + """Resets the scene to a clean state. 
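+ 
+    Removes all objects, materials, textures, and images from bpy.data.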
+ + Returns: + None + """ + # delete everything + for obj in bpy.data.objects: + bpy.data.objects.remove(obj, do_unlink=True) + + # delete all the materials + for material in bpy.data.materials: + bpy.data.materials.remove(material, do_unlink=True) + + # delete all the textures + for texture in bpy.data.textures: + bpy.data.textures.remove(texture, do_unlink=True) + + # delete all the images + for image in bpy.data.images: + bpy.data.images.remove(image, do_unlink=True) + +def init_camera(): + cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera')) + bpy.context.collection.objects.link(cam) + bpy.context.scene.camera = cam + cam.data.sensor_height = cam.data.sensor_width = 32 + cam_constraint = cam.constraints.new(type='TRACK_TO') + cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' + cam_constraint.up_axis = 'UP_Y' + cam_empty = bpy.data.objects.new("Empty", None) + cam_empty.location = (0, 0, 0) + bpy.context.scene.collection.objects.link(cam_empty) + cam_constraint.target = cam_empty + return cam + +def init_lighting(): + # Clear existing lights + bpy.ops.object.select_all(action="DESELECT") + bpy.ops.object.select_by_type(type="LIGHT") + bpy.ops.object.delete() + + # Create key light + default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT")) + bpy.context.collection.objects.link(default_light) + default_light.data.energy = 1000 + default_light.location = (4, 1, 6) + default_light.rotation_euler = (0, 0, 0) + + # create top light + top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA")) + bpy.context.collection.objects.link(top_light) + top_light.data.energy = 10000 + top_light.location = (0, 0, 10) + top_light.scale = (100, 100, 100) + + # create bottom light + bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA")) + bpy.context.collection.objects.link(bottom_light) + bottom_light.data.energy = 1000 + bottom_light.location = (0, 0, -10) + bottom_light.rotation_euler = (0, 0, 0) + + return { + "default_light": default_light, + "top_light": top_light, + "bottom_light": bottom_light + } + + +def load_object(object_path: str) -> None: + """Loads a model with a supported file extension into the scene. + + Args: + object_path (str): Path to the model file. + + Raises: + ValueError: If the file extension is not supported. 
+ + Returns: + None + """ + file_extension = object_path.split(".")[-1].lower() + if file_extension is None: + raise ValueError(f"Unsupported file type: {object_path}") + + if file_extension == "usdz": + # install usdz io package + dirname = os.path.dirname(os.path.realpath(__file__)) + usdz_package = os.path.join(dirname, "io_scene_usdz.zip") + bpy.ops.preferences.addon_install(filepath=usdz_package) + # enable it + addon_name = "io_scene_usdz" + bpy.ops.preferences.addon_enable(module=addon_name) + # import the usdz + from io_scene_usdz.import_usdz import import_usdz + + import_usdz(context, filepath=object_path, materials=True, animations=True) + return None + + # load from existing import functions + import_function = IMPORT_FUNCTIONS[file_extension] + + print(f"Loading object from {object_path}") + if file_extension == "blend": + import_function(directory=object_path, link=False) + elif file_extension in {"glb", "gltf"}: + import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS') + else: + import_function(filepath=object_path) + +def delete_invisible_objects() -> None: + """Deletes all invisible objects in the scene. + + Returns: + None + """ + # bpy.ops.object.mode_set(mode="OBJECT") + bpy.ops.object.select_all(action="DESELECT") + for obj in bpy.context.scene.objects: + if obj.hide_viewport or obj.hide_render: + obj.hide_viewport = False + obj.hide_render = False + obj.hide_select = False + obj.select_set(True) + bpy.ops.object.delete() + + # Delete invisible collections + invisible_collections = [col for col in bpy.data.collections if col.hide_viewport] + for col in invisible_collections: + bpy.data.collections.remove(col) + +def split_mesh_normal(): + bpy.ops.object.select_all(action="DESELECT") + objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"] + bpy.context.view_layer.objects.active = objs[0] + for obj in objs: + obj.select_set(True) + bpy.ops.object.mode_set(mode="EDIT") + bpy.ops.mesh.select_all(action='SELECT') + bpy.ops.mesh.split_normals() + bpy.ops.object.mode_set(mode='OBJECT') + bpy.ops.object.select_all(action="DESELECT") + +def delete_custom_normals(): + for this_obj in bpy.data.objects: + if this_obj.type == "MESH": + bpy.context.view_layer.objects.active = this_obj + bpy.ops.mesh.customdata_custom_splitnormals_clear() + +def override_material(): + new_mat = bpy.data.materials.new(name="Override0123456789") + new_mat.use_nodes = True + new_mat.node_tree.nodes.clear() + bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse') + bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1) + bsdf.inputs[1].default_value = 1 + output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial') + new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface']) + bpy.context.scene.view_layers['View Layer'].material_override = new_mat + +def unhide_all_objects() -> None: + """Unhides all objects in the scene. + + Returns: + None + """ + for obj in bpy.context.scene.objects: + obj.hide_set(False) + +def convert_to_meshes() -> None: + """Converts all objects in the scene to meshes. + + Returns: + None + """ + bpy.ops.object.select_all(action="DESELECT") + bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0] + for obj in bpy.context.scene.objects: + obj.select_set(True) + bpy.ops.object.convert(target="MESH") + +def triangulate_meshes() -> None: + """Triangulates all meshes in the scene. 
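+ 
+    Enters Edit Mode once for all selected meshes and converts quads/ngons
+    to triangles with Blender's BEAUTY methods (see the operator call below).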
+ + Returns: + None + """ + bpy.ops.object.select_all(action="DESELECT") + objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"] + bpy.context.view_layer.objects.active = objs[0] + for obj in objs: + obj.select_set(True) + bpy.ops.object.mode_set(mode="EDIT") + bpy.ops.mesh.reveal() + bpy.ops.mesh.select_all(action="SELECT") + bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY") + bpy.ops.object.mode_set(mode="OBJECT") + bpy.ops.object.select_all(action="DESELECT") + +def scene_bbox() -> Tuple[Vector, Vector]: + """Returns the bounding box of the scene. + + Taken from Shap-E rendering script + (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82) + + Returns: + Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box. + """ + bbox_min = (math.inf,) * 3 + bbox_max = (-math.inf,) * 3 + found = False + scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)] + for obj in scene_meshes: + found = True + for coord in obj.bound_box: + coord = Vector(coord) + coord = obj.matrix_world @ coord + bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) + bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) + if not found: + raise RuntimeError("no objects in scene to compute bounding box for") + return Vector(bbox_min), Vector(bbox_max) + +def normalize_scene() -> Tuple[float, Vector]: + """Normalizes the scene by scaling and translating it to fit in a unit cube centered + at the origin. + + Mostly taken from the Point-E / Shap-E rendering script + (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112), + but fix for multiple root objects: (see bug report here: + https://github.com/openai/shap-e/pull/60). + + Returns: + Tuple[float, Vector]: The scale factor and the offset applied to the scene. + """ + scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent] + if len(scene_root_objects) > 1: + # create an empty object to be used as a parent for all root objects + scene = bpy.data.objects.new("ParentEmpty", None) + bpy.context.scene.collection.objects.link(scene) + + # parent all root objects to the empty object + for obj in scene_root_objects: + obj.parent = scene + else: + scene = scene_root_objects[0] + + bbox_min, bbox_max = scene_bbox() + scale = 1 / max(bbox_max - bbox_min) + scene.scale = scene.scale * scale + + # Apply scale to matrix_world. 
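+    # The view layer must be updated before recomputing the bbox so that
+    # matrix_world reflects the new scale; the offset then recenters the bbox
+    # at the origin, matching the [-0.5, 0.5]^3 "aabb" exported in main().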
+ bpy.context.view_layer.update() + bbox_min, bbox_max = scene_bbox() + offset = -(bbox_min + bbox_max) / 2 + scene.matrix_world.translation += offset + bpy.ops.object.select_all(action="DESELECT") + + return scale, offset + +def get_transform_matrix(obj: bpy.types.Object) -> list: + pos, rt, _ = obj.matrix_world.decompose() + rt = rt.to_matrix() + matrix = [] + for ii in range(3): + a = [] + for jj in range(3): + a.append(rt[ii][jj]) + a.append(pos[ii]) + matrix.append(a) + matrix.append([0, 0, 0, 1]) + return matrix + +def main(arg): + os.makedirs(arg.output_folder, exist_ok=True) + + # Initialize context + init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode) + outputs, spec_nodes = init_nodes( + save_depth=arg.save_depth, + save_normal=arg.save_normal, + save_albedo=arg.save_albedo, + save_mist=arg.save_mist + ) + if arg.object.endswith(".blend"): + delete_invisible_objects() + else: + init_scene() + load_object(arg.object) + if arg.split_normal: + split_mesh_normal() + # delete_custom_normals() + print('[INFO] Scene initialized.') + + # normalize scene + scale, offset = normalize_scene() + print('[INFO] Scene normalized.') + + # Initialize camera and lighting + cam = init_camera() + init_lighting() + print('[INFO] Camera and lighting initialized.') + + # Override material + if arg.geo_mode: + override_material() + + # Create a list of views + to_export = { + "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + "scale": scale, + "offset": [offset.x, offset.y, offset.z], + "frames": [] + } + views = json.loads(arg.views) + for i, view in enumerate(views): + cam.location = ( + view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']), + view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']), + view['radius'] * np.sin(view['pitch']) + ) + cam.data.lens = 16 / np.tan(view['fov'] / 2) + + if arg.save_depth: + spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3) + spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3) + + bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png') + for name, output in outputs.items(): + output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}') + + # Render the scene + bpy.ops.render.render(write_still=True) + bpy.context.view_layer.update() + for name, output in outputs.items(): + ext = EXT[output.format.file_format] + path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0] + os.rename(path, f'{output.file_slots[0].path}.{ext}') + + # Save camera parameters + metadata = { + "file_path": f'{i:03d}.png', + "camera_angle_x": view['fov'], + "transform_matrix": get_transform_matrix(cam) + } + if arg.save_depth: + metadata['depth'] = { + 'min': view['radius'] - 0.5 * np.sqrt(3), + 'max': view['radius'] + 0.5 * np.sqrt(3) + } + to_export["frames"].append(metadata) + + # Save the camera parameters + with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f: + json.dump(to_export, f, indent=4) + + if arg.save_mesh: + # triangulate meshes + unhide_all_objects() + convert_to_meshes() + triangulate_meshes() + print('[INFO] Meshes triangulated.') + + # export ply mesh + bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.') + parser.add_argument('--views', type=str, help='JSON string of views. 
Contains a list of {yaw, pitch, radius, fov} object.') + parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.') + parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.') + parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.') + parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...') + parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.') + parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.') + parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.') + parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.') + parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.') + parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.') + parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.') + argv = sys.argv[sys.argv.index("--") + 1:] + args = parser.parse_args(argv) + + main(args) + \ No newline at end of file diff --git a/dataset_toolkits/render.py b/dataset_toolkits/render.py new file mode 100644 index 0000000..636f3b3 --- /dev/null +++ b/dataset_toolkits/render.py @@ -0,0 +1,121 @@ +import os +import json +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict +from functools import partial +from subprocess import DEVNULL, call +import numpy as np +from utils import sphere_hammersley_sequence + + +BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz' +BLENDER_INSTALLATION_PATH = '/tmp' +BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender' + +def _install_blender(): + if not os.path.exists(BLENDER_PATH): + os.system('sudo apt-get update') + os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6') + os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}') + os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}') + + +def _render(file_path, sha256, output_dir, num_views): + output_folder = os.path.join(output_dir, 'renders', sha256) + + # Build camera {yaw, pitch, radius, fov} + yaws = [] + pitchs = [] + offset = (np.random.rand(), np.random.rand()) + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views, offset) + yaws.append(y) + pitchs.append(p) + radius = [2] * num_views + fov = [40 / 180 * np.pi] * num_views + views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)] + + args = [ + BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'), + '--', + '--views', json.dumps(views), + '--object', os.path.expanduser(file_path), + '--resolution', '512', + '--output_folder', output_folder, + '--engine', 'CYCLES', + '--save_mesh', + ] + if file_path.endswith('.blend'): + args.insert(1, file_path) + + call(args, stdout=DEVNULL, stderr=DEVNULL) + + if os.path.exists(os.path.join(output_folder, 'transforms.json')): + return {'sha256': sha256, 'rendered': True} + + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + 
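+    # Note: sys.argv[1] selects the dataset module (datasets.<name>) above,
+    # so argparse only sees sys.argv[2:] (parsed below).
+    # Example invocation (hypothetical dataset name):
+    #   python dataset_toolkits/render.py ObjaverseXL --output_dir datasets/ObjaverseXL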
parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--num_views', type=int, default=150, + help='Number of views to render') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + parser.add_argument('--max_workers', type=int, default=8) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True) + + # install blender + print('Checking blender...', flush=True) + _install_blender() + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + metadata = metadata[metadata['local_path'].notna()] + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'rendered' in metadata.columns: + metadata = metadata[metadata['rendered'] == False] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + for sha256 in copy.copy(metadata['sha256'].values): + if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')): + records.append({'sha256': sha256, 'rendered': True}) + metadata = metadata[metadata['sha256'] != sha256] + + print(f'Processing {len(metadata)} objects...') + + # process objects + func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views) + rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects') + rendered = pd.concat([rendered, pd.DataFrame.from_records(records)]) + rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False) diff --git a/dataset_toolkits/render_cond.py b/dataset_toolkits/render_cond.py new file mode 100644 index 0000000..b2a40e6 --- /dev/null +++ b/dataset_toolkits/render_cond.py @@ -0,0 +1,125 @@ +import os +import json +import copy +import sys +import importlib +import argparse +import pandas as pd +from easydict import EasyDict as edict +from functools import partial +from subprocess import DEVNULL, call +import numpy as np +from utils import sphere_hammersley_sequence + + +BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz' +BLENDER_INSTALLATION_PATH = '/tmp' +BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender' + +def _install_blender(): + if not os.path.exists(BLENDER_PATH): + os.system('sudo apt-get update') + os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6') + os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}') + os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C 
{BLENDER_INSTALLATION_PATH}') + + +def _render_cond(file_path, sha256, output_dir, num_views): + output_folder = os.path.join(output_dir, 'renders_cond', sha256) + + # Build camera {yaw, pitch, radius, fov} + yaws = [] + pitchs = [] + offset = (np.random.rand(), np.random.rand()) + for i in range(num_views): + y, p = sphere_hammersley_sequence(i, num_views, offset) + yaws.append(y) + pitchs.append(p) + fov_min, fov_max = 10, 70 + radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi) + radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi) + k_min = 1 / radius_max**2 + k_max = 1 / radius_min**2 + ks = np.random.uniform(k_min, k_max, (1000000,)) + radius = [1 / np.sqrt(k) for k in ks] + fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius] + views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)] + + args = [ + BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'), + '--', + '--views', json.dumps(views), + '--object', os.path.expanduser(file_path), + '--output_folder', os.path.expanduser(output_folder), + '--resolution', '1024', + ] + if file_path.endswith('.blend'): + args.insert(1, file_path) + + call(args, stdout=DEVNULL) + + if os.path.exists(os.path.join(output_folder, 'transforms.json')): + return {'sha256': sha256, 'cond_rendered': True} + + +if __name__ == '__main__': + dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}') + + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--instances', type=str, default=None, + help='Instances to process') + parser.add_argument('--num_views', type=int, default=24, + help='Number of views to render') + dataset_utils.add_args(parser) + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--world_size', type=int, default=1) + parser.add_argument('--max_workers', type=int, default=8) + opt = parser.parse_args(sys.argv[2:]) + opt = edict(vars(opt)) + + os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True) + + # install blender + print('Checking blender...', flush=True) + _install_blender() + + # get file list + if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + raise ValueError('metadata.csv not found') + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + if opt.instances is None: + metadata = metadata[metadata['local_path'].notna()] + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + if 'cond_rendered' in metadata.columns: + metadata = metadata[metadata['cond_rendered'] == False] + else: + if os.path.exists(opt.instances): + with open(opt.instances, 'r') as f: + instances = f.read().splitlines() + else: + instances = opt.instances.split(',') + metadata = metadata[metadata['sha256'].isin(instances)] + + start = len(metadata) * opt.rank // opt.world_size + end = len(metadata) * (opt.rank + 1) // opt.world_size + metadata = metadata[start:end] + records = [] + + # filter out objects that are already processed + for sha256 in copy.copy(metadata['sha256'].values): + if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')): + records.append({'sha256': sha256, 'cond_rendered': True}) + 
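+            # A transforms.json already exists for this object, so only
+            # unfinished objects remain in the pending work list.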
metadata = metadata[metadata['sha256'] != sha256] + + print(f'Processing {len(metadata)} objects...') + + # process objects + func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views) + cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects') + cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)]) + cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False) diff --git a/render_model.py b/render_model.py new file mode 100644 index 0000000..ec1e67f --- /dev/null +++ b/render_model.py @@ -0,0 +1,116 @@ +import subprocess + +import bpy +import os +import sys +from mathutils import Vector + + +def clear_scene(): + bpy.ops.object.select_all(action='SELECT') + bpy.ops.object.delete(use_global=False) + + +def import_model(model_path): + ext = os.path.splitext(model_path)[1].lower() + if ext == ".obj": + bpy.ops.wm.obj_import(filepath=model_path) + elif ext in [".glb", ".gltf"]: + bpy.ops.import_scene.gltf(filepath=model_path) + else: + raise ValueError(f"Unsupported format: {ext}") + + +def get_scene_bbox(): + objs = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH'] + if not objs: + raise RuntimeError("No mesh objects found") + + min_corner = Vector((float("inf"), float("inf"), float("inf"))) + max_corner = Vector((float("-inf"), float("-inf"), float("-inf"))) + + for obj in objs: + for v in obj.bound_box: + world_v = obj.matrix_world @ Vector(v) + min_corner.x = min(min_corner.x, world_v.x) + min_corner.y = min(min_corner.y, world_v.y) + min_corner.z = min(min_corner.z, world_v.z) + max_corner.x = max(max_corner.x, world_v.x) + max_corner.y = max(max_corner.y, world_v.y) + max_corner.z = max(max_corner.z, world_v.z) + + return min_corner, max_corner + + +def normalize_model(): + min_corner, max_corner = get_scene_bbox() + center = (min_corner + max_corner) / 2 + size = max(max_corner.x - min_corner.x, + max(max_corner.y - min_corner.y, max_corner.z - min_corner.z)) + + objs = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH'] + for obj in objs: + obj.location -= center + + if size > 0: + scale = 2.0 / size + for obj in objs: + obj.scale *= scale + + bpy.context.view_layer.update() + + +def add_camera(): + bpy.ops.object.camera_add(location=(2.5, -2.5, 2.0)) + cam = bpy.context.active_object + direction = Vector((0, 0, 0)) - cam.location + cam.rotation_euler = direction.to_track_quat('-Z', 'Y').to_euler() + bpy.context.scene.camera = cam + + +def add_light(): + bpy.ops.object.light_add(type='SUN', location=(5, -5, 8)) + bpy.context.active_object.data.energy = 3.0 + + bpy.ops.object.light_add(type='AREA', location=(3, -3, 3)) + area = bpy.context.active_object + area.data.energy = 3000 + area.data.size = 5 + + +def setup_render(output_path, resolution=1024): + scene = bpy.context.scene + scene.render.engine = 'CYCLES' + scene.cycles.samples = 128 + scene.render.filepath = output_path + scene.render.image_settings.file_format = 'PNG' + scene.render.resolution_x = resolution + scene.render.resolution_y = resolution + scene.render.film_transparent = True + + +def parse_args(): + argv = sys.argv + argv = argv[argv.index("--") + 1:] if "--" in argv else [] + if len(argv) < 2: + raise ValueError("Usage: blender -b -P render_model.py -- model_path output_path") + return argv[0], argv[1] + + +def get_static_model_image(): + model_path, output_path = parse_args() + clear_scene() + import_model(model_path) + 
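+    # normalize_model() recenters the bbox at the origin and scales the
+    # longest side to 2 units, so the fixed camera at (2.5, -2.5, 2.0)
+    # frames any input consistently.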
normalize_model() + add_camera() + add_light() + setup_render(output_path) + bpy.ops.render.render(write_still=True) + print(f"Saved to {output_path}") + + + + + +if __name__ == "__main__": + get_static_model_image() diff --git a/server.py b/server.py new file mode 100644 index 0000000..03e79aa --- /dev/null +++ b/server.py @@ -0,0 +1,439 @@ +import mimetypes +import os +import secrets +import subprocess +import tempfile +import uuid +from datetime import datetime +from io import BytesIO + +import imageio +import numpy as np +import trimesh + +import litserve as ls +from minio import Minio + +from glb2svg import glb_to_obj, obj_to_step, step_to_svg +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.utils import render_utils, postprocessing_utils +from utils.new_oss_client import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, minio_get_image, MINIO_BUCKET, upload_bytes, upload_local_file, download_from_minio + +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + +def generate_unique_name(original_name: str) -> str: + stem, ext = os.path.splitext(original_name) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + random_part = secrets.token_hex(4) # 8位随机十六进制 + return f"{stem}_{timestamp}_{random_part}{ext}" + + +def load_mesh(file_path): + """ + 加载.obj或.glb文件,返回顶点数据 + + Args: + file_path: 模型文件路径(支持.obj和.glb格式) + + Returns: + numpy.ndarray: 顶点数组,shape为(N, 3) + """ + file_ext = os.path.splitext(file_path)[1].lower() + + if file_ext == '.obj': + # 使用trimesh加载obj文件 + mesh = trimesh.load(file_path, file_type='obj') + elif file_ext == '.glb' or file_ext == '.gltf': + # 使用trimesh加载glb/gltf文件 + mesh = trimesh.load(file_path, file_type='glb') + else: + raise ValueError(f"不支持的文件格式: {file_ext},仅支持.obj和.glb/.gltf") + + # 获取顶点数据 + if isinstance(mesh, trimesh.Scene): + # 如果是场景,合并所有几何体的顶点 + vertices = [] + for geom in mesh.geometry.values(): + vertices.append(geom.vertices) + vertices = np.vstack(vertices) + else: + # 如果是单个几何体 + vertices = mesh.vertices + + if len(vertices) == 0: + raise ValueError("文件中未找到顶点数据") + + return vertices + + +def analyze_mesh(file_path): + """ + 分析3D模型文件,计算质心、边界框、尺寸等信息 + + Args: + file_path: 模型文件路径(支持.obj和.glb格式) + + Returns: + dict: 包含模型分析信息的字典 + """ + # 加载模型并获取顶点 + vertices = load_mesh(file_path) + + # 边界框(每个轴的最小/最大值) + min_coords = vertices.min(axis=0) + max_coords = vertices.max(axis=0) + + # 质心 + centroid = vertices.mean(axis=0) + + # 尺寸 = 边界框维度 + size = max_coords - min_coords + + # 计算尺寸比例(每个轴占总尺寸的比例) + total_size = np.sum(size) + size_ratio = size / total_size if total_size != 0 else [0, 0, 0] + + info = { + # "file_path": file_path, + "file_format": os.path.splitext(file_path)[1].lower(), + "vertex_count": len(vertices), + "centroid": centroid.tolist(), + "bounding_box_min": min_coords.tolist(), + "bounding_box_max": max_coords.tolist(), + "size": size.tolist(), + "size_ratio": size_ratio.tolist(), + "size_ratio_percentage": (size_ratio * 100).tolist() + } + + return info + + +def render_glb_preview(glb_path, output_path): + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + cmd = [ + "blender", + "--background", + "--python", + "render_model.py", + "--", + glb_path, + output_path + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError( + f"Blender render failed\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}" + ) + + return output_path + + +class TrellisAPI(ls.LitAPI): + + def setup(self, device): + 
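+        # 'native' skips spconv's kernel auto-tuning, trading some speed for
+        # a fast, deterministic startup (per the TRELLIS setup notes).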
os.environ.setdefault("SPCONV_ALGO", "native") + + self.pipeline = TrellisImageTo3DPipeline.from_pretrained( + "microsoft/TRELLIS-image-large" + ) + self.pipeline.to(device) + + def decode_request(self, request): + image_paths = request["image_paths"] + images = [] + for path in image_paths: + bucket = path.split('/')[0] + object_name = path[path.find('/') + 1:] + + image = minio_get_image(minio_client, bucket, object_name) + images.append(image) + + params = { + "file_name": uuid.uuid4().hex, + "model": request.get("model", "single"), + "seed": request.get("seed", 1), + "steps_sparse": request.get("steps_sparse", 12), + "cfg_sparse": request.get("cfg_sparse", 7.5), + "steps_slat": request.get("steps_slat", 12), + "cfg_slat": request.get("cfg_slat", 3.0), + "simplify": request.get("simplify", 0.95), + "texture_size": request.get("texture_size", 1024), + "fps": request.get("fps", 30), + } + + return images, params + + def predict(self, inputs): + images, params = inputs + if params["model"] == "single": + outputs = self.pipeline.run( + images[0], + seed=params["seed"], + sparse_structure_sampler_params={ + "steps": params["steps_sparse"], + "cfg_strength": params["cfg_sparse"], + }, + slat_sampler_params={ + "steps": params["steps_slat"], + "cfg_strength": params["cfg_slat"], + }, + ) + else: + outputs = self.pipeline.run_multi_image( + images, + seed=params['seed'], + sparse_structure_sampler_params={ + "steps": params['steps_sparse'], + "cfg_strength": params['cfg_sparse'], + }, + slat_sampler_params={ + "steps": params['steps_slat'], + "cfg_strength": params['cfg_slat'], + }, + ) + + # video_path = self.upload_video(outputs, params) + + minio_glb_path, local_glb_path = self.upload_glb(outputs, params) + + glb_info = analyze_mesh(local_glb_path) + + local_static_model_image_path = os.path.join("glb_output", generate_unique_name("static_model_image.png")) + static_model_image = self.get_static_model_image(model_path=local_glb_path, output_path=local_static_model_image_path) + + return { + "glb_path": minio_glb_path, + "glb_static_img_path": static_model_image, + "glb_info": glb_info, + } + + def encode_response(self, output): + return output + + def upload_video(self, outputs, params): + gaussian_name = f"3d_result/video/{params['file_name']}-gaussian.mp4" + radiance_field_name = f"3d_result/video/{params['file_name']}-radiance_field.mp4" + mesh_name = f"3d_result/video/{params['file_name']}-mesh.mp4" + + # gaussian video + video = render_utils.render_video(outputs["gaussian"][0])["color"] + buffer = BytesIO() + imageio.mimsave(buffer, video, format="mp4", fps=params['fps']) + gaussian_video_path = upload_bytes( + buffer.getvalue(), + gaussian_name, + "video/mp4", + ) + + # radiance field video + video = render_utils.render_video(outputs["radiance_field"][0])["color"] + buffer = BytesIO() + imageio.mimsave(buffer, video, format="mp4", fps=params['fps']) + radiance_field_video_path = upload_bytes( + buffer.getvalue(), + radiance_field_name, + "video/mp4", + ) + + # mesh video + video = render_utils.render_video(outputs["mesh"][0])["normal"] + buffer = BytesIO() + imageio.mimsave(buffer, video, format="mp4", fps=params['fps']) + mesh_path = upload_bytes( + buffer.getvalue(), + mesh_name, + "video/mp4", + ) + + return { + "gaussian": gaussian_video_path, + "radiance_field": radiance_field_video_path, + "mesh": mesh_path + } + + def upload_glb(self, outputs, params): + file_name = f"3d_result/glb/{params['file_name']}.glb" + local_glb_path = os.path.join("glb_output", 
generate_unique_name("sample.glb")) + out_dir = os.path.dirname(local_glb_path) + if out_dir: + os.makedirs(out_dir, exist_ok=True) + + glb = postprocessing_utils.to_glb( + outputs["gaussian"][0], + outputs["mesh"][0], + simplify=params['simplify'], + texture_size=params['texture_size'], + ) + + glb.export( + file_obj=local_glb_path, + file_type="glb" + ) + + glb_path = upload_local_file( + local_glb_path, + file_name, + "application/octet-stream", + ) + return glb_path, local_glb_path + + def upload_ply(self, outputs, params): + file_name = f"3d_result/ply/{params['file_name']}.ply" + + with tempfile.NamedTemporaryFile(suffix=".ply") as tmp: + outputs["gaussian"][0].save_ply(tmp.name) + tmp.seek(0) + + ply_path = upload_bytes( + tmp.read(), + file_name, + "application/octet-stream", + ) + return {"ply": ply_path} + + def get_static_model_image(self, model_path, output_path): + local_static_model_image_path = os.path.join( + "glb_output", + generate_unique_name("static_model_image.png") + ) + + print(f"model_path : {model_path}") + print(f"local_static_model_image_path :{local_static_model_image_path}") + output_path = render_glb_preview(model_path, local_static_model_image_path) + + static_model_image = self.upload_local_file( + output_path, + "png" + ) + + print(f"Saved to {static_model_image}") + return static_model_image + + def upload_local_file(self, local_path, type): + """ + 通用上传函数:支持 SVG, PNG, OBJ 等 + """ + object_name = f"3d_result/{type}/{uuid.uuid4().hex}.{type}" + if not os.path.exists(local_path): + print(f"错误: 文件 {local_path} 不存在") + return None + + # 自动根据后缀名识别 Content-Type + # 例如: .svg -> image/svg+xml, .png -> image/png + content_type, _ = mimetypes.guess_type(local_path) + if content_type is None: + content_type = "application/octet-stream" + + try: + minio_client.fput_object( + bucket_name=MINIO_BUCKET, + object_name=object_name, + file_path=local_path, + content_type=content_type + ) + print(f"成功上传 [{content_type}]: {object_name}") + return f"{MINIO_BUCKET}/{object_name}" + except Exception as e: + print(f"上传失败: {e}") + return None + + +class ModelToThreeViews(ls.LitAPI): + + def setup(self, device): + pass + + def upload_local_file(self, local_path, type): + """ + 通用上传函数:支持 SVG, PNG, OBJ 等 + """ + object_name = f"3d_result/{type}/{uuid.uuid4().hex}.{type}" + if not os.path.exists(local_path): + print(f"错误: 文件 {local_path} 不存在") + return None + + # 自动根据后缀名识别 Content-Type + # 例如: .svg -> image/svg+xml, .png -> image/png + content_type, _ = mimetypes.guess_type(local_path) + if content_type is None: + content_type = "application/octet-stream" + + try: + minio_client.fput_object( + bucket_name=MINIO_BUCKET, + object_name=object_name, + file_path=local_path, + content_type=content_type + ) + print(f"成功上传 [{content_type}]: {object_name}") + return f"{MINIO_BUCKET}/{object_name}" + except Exception as e: + print(f"上传失败: {e}") + return None + + def predict(self, request): + minio_glb_path = request['minio_glb_path'] + + work_dir = f"glb_to_obj" + os.makedirs(work_dir, exist_ok=True) + + glb_path = os.path.join(work_dir, f"model{uuid.uuid4().hex}.glb") + step_dir = os.path.join(work_dir, "step") + svg_dir = os.path.join(work_dir, "svg") + os.makedirs(step_dir, exist_ok=True) + os.makedirs(svg_dir, exist_ok=True) + print(f""" + 入参阶段: + input glb-obj minio-path:{minio_glb_path},\n + work_dir : {work_dir},glb_path : {glb_path},step_dir : {step_dir},svg_dir : {svg_dir}\n + """) + print("=" * 10) + + print(f" 第一阶段 下载glb文件: ") + # 1 下载 + glb_result = 
download_from_minio(object_path=minio_glb_path, local_path=glb_path) + print(f" 下载结果 : {glb_result} \n") + print("=" * 10) + + print(f" 第二阶段 glb -> obj: ") + # 2 glb -> obj + obj_result = glb_to_obj(glb_result) + print(f" glb -> obj 结果 : {obj_result} \n") + print("=" * 10) + + print(f" 第三阶段 obj -> step: ") + # 3 obj -> step + step_result = obj_to_step( + input_obj=obj_result, + output_dir=step_dir, + script_path="1_obj_to_step.py" + ) + print(f" obj -> step 结果 : {step_result} \n") + print("=" * 10) + + print(f" 第四阶段 step -> svg: ") + # 4 step -> svg + combined_svg, combined_png = step_to_svg( + step_path=step_result, + out_dir=svg_dir + ) + print(f" step -> svg 结果 : {combined_svg} \n") + print("=" * 10) + + # 5 上传 + minio_svg_path = self.upload_local_file(combined_png, "svg") + + return {"minio_svg_path": minio_svg_path} + + +if __name__ == "__main__": + trellis_api = TrellisAPI(api_path="/canvas/img_to_3D") + model_to_three_api = ModelToThreeViews(api_path="/canvas/3d_to_3views") + server = ls.LitServer([ + trellis_api, + model_to_three_api]) + server.run(port=8120) diff --git a/single_image_to_3D.py b/single_image_to_3D.py new file mode 100644 index 0000000..f09483e --- /dev/null +++ b/single_image_to_3D.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import argparse +from PIL import Image + +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.utils import render_utils, postprocessing_utils + +def build_parser(): + p = argparse.ArgumentParser("TRELLIS CLI: single image -> 3D") + + p.add_argument("-i", "--image", required=True, help="Input image path") + p.add_argument("-o", "--out_dir", default="trellis_out", help="Output directory") + + p.add_argument("--seed", type=int, default=1) + p.add_argument("--steps_sparse", type=int, default=12) + p.add_argument("--cfg_sparse", type=float, default=7.5) + p.add_argument("--steps_slat", type=int, default=12) + p.add_argument("--cfg_slat", type=float, default=3.0) + + p.add_argument("--simplify", type=float, default=0.95) + p.add_argument("--texture_size", type=int, default=1024) + + # Export GLB (default True) + p.add_argument("--export_glb", dest="export_glb", action="store_true", default=True) + p.add_argument("--no-export_glb", dest="export_glb", action="store_false") + + # Save PLY (default True) + p.add_argument("--save_ply", dest="save_ply", action="store_true", default=True) + p.add_argument("--no-save_ply", dest="save_ply", action="store_false") + + # Save videos (default False, plus explicit toggle) + p.add_argument("--save_video", dest="save_video", action="store_true", default=True) + p.add_argument("--no-save_video", dest="save_video", action="store_false") + + p.add_argument("--fps", type=int, default=30) + p.add_argument("--video_gs_name", type=str, default="sample_gs.mp4") + p.add_argument("--video_rf_name", type=str, default="sample_rf.mp4") + p.add_argument("--video_mesh_name", type=str, default="sample_mesh.mp4") + + return p + +def main(): + args = build_parser().parse_args() + os.makedirs(args.out_dir, exist_ok=True) + + # Optional env + os.environ.setdefault("SPCONV_ALGO", "native") + + pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large") + pipeline.cuda() + + image = Image.open(args.image) + + outputs = pipeline.run( + image, + seed=args.seed, + sparse_structure_sampler_params={ + "steps": args.steps_sparse, + "cfg_strength": args.cfg_sparse, + }, + slat_sampler_params={ + "steps": args.steps_slat, + "cfg_strength": args.cfg_slat, + }, + ) + 
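+    # outputs maps each representation name ('gaussian', 'radiance_field',
+    # 'mesh') to a list of results, hence the [0] indexing below.
+    # Example invocation (hypothetical paths):
+    #   python single_image_to_3D.py -i input.png -o trellis_out --no-save_video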
+ if args.save_video: + import imageio + + video = render_utils.render_video(outputs["gaussian"][0])["color"] + imageio.mimsave(os.path.join(args.out_dir, args.video_gs_name), video, fps=args.fps) + + video = render_utils.render_video(outputs["radiance_field"][0])["color"] + imageio.mimsave(os.path.join(args.out_dir, args.video_rf_name), video, fps=args.fps) + + video = render_utils.render_video(outputs["mesh"][0])["normal"] + imageio.mimsave(os.path.join(args.out_dir, args.video_mesh_name), video, fps=args.fps) + + if args.export_glb: + glb = postprocessing_utils.to_glb( + outputs["gaussian"][0], + outputs["mesh"][0], + simplify=args.simplify, + texture_size=args.texture_size, + ) + glb.export(os.path.join(args.out_dir, "sample.glb")) + + if args.save_ply: + outputs["gaussian"][0].save_ply(os.path.join(args.out_dir, "sample.ply")) + +if __name__ == "__main__": + main() diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py new file mode 100755 index 0000000..5950b75 --- /dev/null +++ b/trellis/modules/sparse/attention/serialized_attn.py @@ -0,0 +1,193 @@ +from typing import * +from enum import Enum +import torch +import math +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_serialized_scaled_dot_product_self_attention', +] + + +class SerializeMode(Enum): + Z_ORDER = 0 + Z_ORDER_TRANSPOSED = 1 + HILBERT = 2 + HILBERT_TRANSPOSED = 3 + + +SerializeModes = [ + SerializeMode.Z_ORDER, + SerializeMode.Z_ORDER_TRANSPOSED, + SerializeMode.HILBERT, + SerializeMode.HILBERT_TRANSPOSED +] + + +def calc_serialization( + tensor: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + """ + Calculate serialization and partitioning for a set of coordinates. + + Args: + tensor (SparseTensor): The input tensor. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + + Returns: + (torch.Tensor, torch.Tensor): Forwards and backwards indices. 
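+        seq_lens (List[int]): The lengths of the serialized sequences.
+        seq_batch_indices (List[int]): The batch index of each sequence.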
+ """ + fwd_indices = [] + bwd_indices = [] + seq_lens = [] + seq_batch_indices = [] + offsets = [0] + + if 'vox2seq' not in globals(): + import vox2seq + + # Serialize the input + serialize_coords = tensor.coords[:, 1:].clone() + serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3) + if serialize_mode == SerializeMode.Z_ORDER: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2]) + elif serialize_mode == SerializeMode.HILBERT: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2]) + elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED: + code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2]) + else: + raise ValueError(f"Unknown serialize mode: {serialize_mode}") + + for bi, s in enumerate(tensor.layout): + num_points = s.stop - s.start + num_windows = (num_points + window_size - 1) // window_size + valid_window_size = num_points / num_windows + to_ordered = torch.argsort(code[s.start:s.stop]) + if num_windows == 1: + fwd_indices.append(to_ordered) + bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device))) + fwd_indices[-1] += s.start + bwd_indices[-1] += offsets[-1] + seq_lens.append(num_points) + seq_batch_indices.append(bi) + offsets.append(offsets[-1] + seq_lens[-1]) + else: + # Partition the input + offset = 0 + mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)] + split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)] + bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device) + for i in range(num_windows): + mid = mids[i] + valid_start = split[i] + valid_end = split[i + 1] + padded_start = math.floor(mid - 0.5 * window_size) + padded_end = padded_start + window_size + fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]) + offset += valid_start - padded_start + bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device)) + offset += padded_end - valid_start + fwd_indices[-1] += s.start + seq_lens.extend([window_size] * num_windows) + seq_batch_indices.extend([bi] * num_windows) + bwd_indices.append(bwd_index + offsets[-1]) + offsets.append(offsets[-1] + num_windows * window_size) + + fwd_indices = torch.cat(fwd_indices) + bwd_indices = torch.cat(bwd_indices) + + return fwd_indices, bwd_indices, seq_lens, seq_batch_indices + + +def sparse_serialized_scaled_dot_product_self_attention( + qkv: SparseTensor, + window_size: int, + serialize_mode: SerializeMode = SerializeMode.Z_ORDER, + shift_sequence: int = 0, + shift_window: Tuple[int, int, int] = (0, 0, 0) +) -> SparseTensor: + """ + Apply serialized scaled dot product self attention to a sparse tensor. + + Args: + qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + window_size (int): The window size to use. + serialize_mode (SerializeMode): The serialization mode to use. + shift_sequence (int): The shift of serialized sequence. + shift_window (Tuple[int, int, int]): The shift of serialized coordinates. + shift (int): The shift to use. 
+ """ + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + + serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices)) + else: + fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache + + M = fwd_indices.shape[0] + T = qkv.feats.shape[0] + H = qkv.feats.shape[2] + C = qkv.feats.shape[3] + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if DEBUG: + start = 0 + qkv_coords = qkv.coords[fwd_indices] + for i in range(len(seq_lens)): + assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch" + start += seq_lens[i] + + if all([seq_len == window_size for seq_len in seq_lens]): + B = len(seq_lens) + N = window_size + qkv_feats = qkv_feats.reshape(B, N, 3, H, C) + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=2) # [B, N, H, C] + out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C] + elif ATTN == 'flash_attn': + out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C] + else: + raise ValueError(f"Unknown attention module: {ATTN}") + out = out.reshape(B * N, H, C) # [M, H, C] + else: + if ATTN == 'xformers': + q, k, v = qkv_feats.unbind(dim=1) # [M, H, C] + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] + elif ATTN == 'flash_attn': + cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ + .to(qkv.device).int() + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + if DEBUG: + qkv_coords = qkv_coords[bwd_indices] + assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch" + + return qkv.replace(out) diff --git a/trellis/renderers/sh_utils.py b/trellis/renderers/sh_utils.py new file mode 100755 index 0000000..bbca7d1 --- /dev/null +++ b/trellis/renderers/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/trellis/representations/mesh/flexicubes/examples/render.py b/trellis/representations/mesh/flexicubes/examples/render.py new file mode 100644 index 0000000..6aecbda --- /dev/null +++ b/trellis/representations/mesh/flexicubes/examples/render.py @@ -0,0 +1,274 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import copy
+import math
+from ipywidgets import interactive, HBox, VBox, FloatLogSlider, IntSlider
+from IPython.display import display
+
+import torch
+import nvdiffrast.torch as dr
+import kaolin as kal
+import util
+
+###############################################################################
+# Functions adapted from https://github.com/NVlabs/nvdiffrec
+###############################################################################
+
+def get_random_camera_batch(batch_size, fovy=np.deg2rad(45), iter_res=[512, 512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
+    if use_kaolin:
+        camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(
+            *kal.ops.random.sample_spherical_coords((batch_size,), azimuth_low=0., azimuth_high=math.pi * 2,
+                                                    elevation_low=-math.pi / 2., elevation_high=math.pi / 2., device='cuda'),
+            cam_radius
+        ), dim=-1)
+        return kal.render.camera.Camera.from_args(
+            eye=camera_pos + torch.rand((batch_size, 1), device='cuda') * 0.5 - 0.25,
+            at=torch.zeros(batch_size, 3),
+            up=torch.tensor([[0., 1., 0.]]),
+            fov=fovy,
+            near=cam_near_far[0], far=cam_near_far[1],
+            height=iter_res[0], width=iter_res[1],
+            device='cuda'
+        )
+    else:
+        def get_random_camera():
+            proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
+            mv = util.translate(0, 0, -cam_radius) @ util.random_rotation_translation(0.25)
+            mvp = proj_mtx @ mv
+            return mv, mvp
+        mv_batch = []
+        mvp_batch = []
+        for i in range(batch_size):
+            mv, mvp = get_random_camera()
+            mv_batch.append(mv)
+            mvp_batch.append(mvp)
+        return torch.stack(mv_batch).to(device), torch.stack(mvp_batch).to(device)
+
+def get_rotate_camera(itr, fovy=np.deg2rad(45), iter_res=[512, 512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
+    if use_kaolin:
+        ang = (itr / 10) * np.pi * 2
+        camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(torch.tensor(ang), torch.tensor(0.4), -torch.tensor(cam_radius)))
+        return kal.render.camera.Camera.from_args(
+            eye=camera_pos,
+            at=torch.zeros(3),
+            up=torch.tensor([0., 1., 0.]),
+            fov=fovy,
+            near=cam_near_far[0], far=cam_near_far[1],
+            height=iter_res[0], width=iter_res[1],
+            device='cuda'
+        )
+    else:
+        proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
+
+        # Smooth rotation for display.
+        ang = (itr / 10) * np.pi * 2
+        mv = util.translate(0, 0, -cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
+        mvp = proj_mtx @ mv
+        return mv.to(device), mvp.to(device)
+
+glctx = dr.RasterizeGLContext()
+def render_mesh(mesh, camera, iter_res, return_types=["mask", "depth"], white_bg=False, wireframe_thickness=0.4):
+    vertices_camera = camera.extrinsics.transform(mesh.vertices)
+    face_vertices_camera = kal.ops.mesh.index_vertices_by_faces(
+        vertices_camera, mesh.faces
+    )
+
+    # Projection: nvdiffrast takes clip coordinates as input to apply barycentric perspective correction.
+    # Using `camera.intrinsics.transform(vertices_camera)` would return the normalized device coordinates.
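+    # Editor's note on the sign flip below: kaolin's projection matrix and
+    # nvdiffrast's OpenGL-style clip space appear to disagree on the direction of
+    # the image y-axis, so the y row of the projection is negated before
+    # rasterization; without it the rendered image would come out vertically flipped.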
+    proj = camera.projection_matrix().unsqueeze(1)
+    proj[:, :, 1, 1] = -proj[:, :, 1, 1]
+    homogeneous_vecs = kal.render.camera.up_to_homogeneous(
+        vertices_camera
+    )
+    vertices_clip = (proj @ homogeneous_vecs.unsqueeze(-1)).squeeze(-1)
+    faces_int = mesh.faces.int()
+
+    rast, _ = dr.rasterize(
+        glctx, vertices_clip, faces_int, iter_res)
+
+    out_dict = {}
+    for type in return_types:
+        if type == "mask":
+            img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
+        elif type == "depth":
+            img = dr.interpolate(homogeneous_vecs, rast, faces_int)[0]
+        elif type == "wireframe":
+            img = torch.logical_or(
+                torch.logical_or(rast[..., 0] < wireframe_thickness, rast[..., 1] < wireframe_thickness),
+                (rast[..., 0] + rast[..., 1]) > (1. - wireframe_thickness)
+            ).unsqueeze(-1)
+        elif type == "normals":
+            img = dr.interpolate(
+                mesh.face_normals.reshape(len(mesh), -1, 3), rast,
+                torch.arange(mesh.faces.shape[0] * 3, device='cuda', dtype=torch.int).reshape(-1, 3)
+            )[0]
+        if white_bg:
+            bg = torch.ones_like(img)
+            alpha = (rast[..., -1:] > 0).float()
+            img = torch.lerp(bg, img, alpha)
+        out_dict[type] = img
+
+    return out_dict
+
+def render_mesh_paper(mesh, mv, mvp, iter_res, return_types=["mask", "depth"], white_bg=False):
+    '''
+    The rendering function used to produce the results in the paper.
+    '''
+    v_pos_clip = util.xfm_points(mesh.vertices.unsqueeze(0), mvp)  # Rotate it to camera coordinates
+    rast, db = dr.rasterize(
+        dr.RasterizeGLContext(), v_pos_clip, mesh.faces.int(), iter_res)
+
+    out_dict = {}
+    for type in return_types:
+        if type == "mask":
+            img = dr.antialias((rast[..., -1:] > 0).float(), rast, v_pos_clip, mesh.faces.int())
+        elif type == "depth":
+            v_pos_cam = util.xfm_points(mesh.vertices.unsqueeze(0), mv)
+            img, _ = util.interpolate(v_pos_cam, rast, mesh.faces.int())
+        elif type == "normal":
+            normal_indices = (torch.arange(0, mesh.nrm.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
+            img, _ = util.interpolate(mesh.nrm.unsqueeze(0).contiguous(), rast, normal_indices.int())
+        elif type == "vertex_normal":
+            img, _ = util.interpolate(mesh.v_nrm.unsqueeze(0).contiguous(), rast, mesh.faces.int())
+            img = dr.antialias((img + 1) * 0.5, rast, v_pos_clip, mesh.faces.int())
+        if white_bg:
+            bg = torch.ones_like(img)
+            alpha = (rast[..., -1:] > 0).float()
+            img = torch.lerp(bg, img, alpha)
+        out_dict[type] = img
+    return out_dict
+
+class SplitVisualizer():
+    def __init__(self, lh_mesh, rh_mesh, height, width):
+        self.lh_mesh = lh_mesh
+        self.rh_mesh = rh_mesh
+        self.height = height
+        self.width = width
+        self.wireframe_thickness = 0.4
+
+    def render(self, camera):
+        lh_outputs = render_mesh(
+            self.lh_mesh, camera, (self.height, self.width),
+            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
+        )
+        rh_outputs = render_mesh(
+            self.rh_mesh, camera, (self.height, self.width),
+            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
+        )
+        outputs = {
+            k: torch.cat(
+                [lh_outputs[k][0].permute(1, 0, 2), rh_outputs[k][0].permute(1, 0, 2)],
+                dim=0
+            ).permute(1, 0, 2) for k in ["normals", "wireframe"]
+        }
+        return {
+            'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8),
+            'normals': outputs['normals']
+        }
+
+    def show(self, init_camera):
+        visualizer = kal.visualize.IpyTurntableVisualizer(
+            self.height, self.width * 2, copy.deepcopy(init_camera), self.render,
+            max_fps=24, world_up_axis=1)
+
+        def slider_callback(new_wireframe_thickness):
+            """ipywidgets sliders callback"""
+            with visualizer.out:  # so errors show up in the visualizer's output widget
+                self.wireframe_thickness = new_wireframe_thickness
+                # this is how we request a new update
+                visualizer.render_update()
+
+        wireframe_thickness_slider = FloatLogSlider(
+            value=self.wireframe_thickness,
+            base=10,
+            min=-3,
+            max=-0.4,
+            step=0.1,
+            description='wireframe_thickness',
+            continuous_update=True,
+            readout=True,
+            readout_format='.3f',
+        )
+
+        interactive_slider = interactive(
+            slider_callback,
+            new_wireframe_thickness=wireframe_thickness_slider,
+        )
+
+        full_output = VBox([visualizer.canvas, interactive_slider])
+        display(full_output, visualizer.out)
+
+class TimelineVisualizer():
+    def __init__(self, meshes, height, width):
+        self.meshes = meshes
+        self.height = height
+        self.width = width
+        self.wireframe_thickness = 0.4
+        self.idx = len(meshes) - 1
+
+    def render(self, camera):
+        outputs = render_mesh(
+            self.meshes[self.idx], camera, (self.height, self.width),
+            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
+        )
+        return {
+            'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8)[0],
+            'normals': outputs['normals'][0]
+        }
+
+    def show(self, init_camera):
+        visualizer = kal.visualize.IpyTurntableVisualizer(
+            self.height, self.width, copy.deepcopy(init_camera), self.render,
+            max_fps=24, world_up_axis=1)
+
+        def slider_callback(new_wireframe_thickness, new_idx):
+            """ipywidgets sliders callback"""
+            with visualizer.out:  # so errors show up in the visualizer's output widget
+                self.wireframe_thickness = new_wireframe_thickness
+                self.idx = new_idx
+                # this is how we request a new update
+                visualizer.render_update()
+
+        wireframe_thickness_slider = FloatLogSlider(
+            value=self.wireframe_thickness,
+            base=10,
+            min=-3,
+            max=-0.4,
+            step=0.1,
+            description='wireframe_thickness',
+            continuous_update=True,
+            readout=True,
+            readout_format='.3f',
+        )
+
+        idx_slider = IntSlider(
+            value=self.idx,
+            min=0,
+            max=len(self.meshes) - 1,
+            description='idx',
+            continuous_update=True,
+            readout=True
+        )
+
+        interactive_slider = interactive(
+            slider_callback,
+            new_wireframe_thickness=wireframe_thickness_slider,
+            new_idx=idx_slider
+        )
+        full_output = HBox([visualizer.canvas, interactive_slider])
+        display(full_output, visualizer.out)
diff --git a/trellis/utils/random_utils.py b/trellis/utils/random_utils.py
new file mode 100644
index 0000000..5b668c2
--- /dev/null
+++ b/trellis/utils/random_utils.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
+
+def radical_inverse(base, n):
+    # Reflect the base-`base` digits of n about the radix point, e.g. for base 2,
+    # n = 6 (110b) -> 0.011b = 0.375 (the van der Corput sequence).
+    val = 0
+    inv_base = 1.0 / base
+    inv_base_n = inv_base
+    while n > 0:
+        digit = n % base
+        val += digit * inv_base_n
+        n //= base
+        inv_base_n *= inv_base
+    return val
+
+def halton_sequence(dim, n):
+    # n-th point of the dim-dimensional Halton sequence (one prime base per axis).
+    return [radical_inverse(PRIMES[i], n) for i in range(dim)]
+
+def hammersley_sequence(dim, n, num_samples):
+    # Hammersley: first coordinate is n / N, the remaining ones come from Halton.
+    return [n / num_samples] + halton_sequence(dim - 1, n)
+
+def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
+    u, v = hammersley_sequence(2, n, num_samples)
+    u += offset[0] / num_samples
+    v += offset[1]
+    if remap:
+        u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
+    theta = np.arccos(1 - 2 * u) - np.pi / 2
+    phi = v * 2 * np.pi
+    return [phi, theta]
\ No newline at end of file
diff --git a/trellis/utils/render_utils.py b/trellis/utils/render_utils.py
new file mode 100644
index 0000000..c13d902
--- /dev/null
+++ b/trellis/utils/render_utils.py
@@ -0,0 +1,120 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+import utils3d
+from PIL import Image
+
+from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
+from ..representations import Octree, Gaussian, MeshExtractResult
+from ..modules import sparse as sp
+from .random_utils import sphere_hammersley_sequence
+
+
+def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
+    is_list = isinstance(yaws, list)
+    if not is_list:
+        yaws = [yaws]
+        pitchs = [pitchs]
+    if not isinstance(rs, list):
+        rs = [rs] * len(yaws)
+    if not isinstance(fovs, list):
+        fovs = [fovs] * len(yaws)
+    extrinsics = []
+    intrinsics = []
+    for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
+        fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
+        yaw = torch.tensor(float(yaw)).cuda()
+        pitch = torch.tensor(float(pitch)).cuda()
+        orig = torch.tensor([
+            torch.sin(yaw) * torch.cos(pitch),
+            torch.cos(yaw) * torch.cos(pitch),
+            torch.sin(pitch),
+        ]).cuda() * r
+        extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+        intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
+        extrinsics.append(extr)
+        intrinsics.append(intr)
+    if not is_list:
+        extrinsics = extrinsics[0]
+        intrinsics = intrinsics[0]
+    return extrinsics, intrinsics
+
+
+def get_renderer(sample, **kwargs):
+    if isinstance(sample, Octree):
+        renderer = OctreeRenderer()
+        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+        renderer.rendering_options.near = kwargs.get('near', 0.8)
+        renderer.rendering_options.far = kwargs.get('far', 1.6)
+        renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
+        renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
+        renderer.pipe.primitive = sample.primitive
+    elif isinstance(sample, Gaussian):
+        renderer = GaussianRenderer()
+        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+        renderer.rendering_options.near = kwargs.get('near', 0.8)
+        renderer.rendering_options.far = kwargs.get('far', 1.6)
+        renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
+        renderer.rendering_options.ssaa = kwargs.get('ssaa', 1)
+        renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
+        renderer.pipe.use_mip_gaussian = True
+    elif isinstance(sample, MeshExtractResult):
+        renderer = MeshRenderer()
+        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
+        renderer.rendering_options.near = kwargs.get('near', 1)
+        renderer.rendering_options.far = kwargs.get('far', 100)
+        renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
+    else:
+        raise ValueError(f'Unsupported sample type: {type(sample)}')
+    return renderer
+
+
+def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
+    renderer = get_renderer(sample, **options)
+    rets = {}
+    for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
+        if isinstance(sample, MeshExtractResult):
+            res = renderer.render(sample, extr, intr)
+            if 'normal' not in rets: rets['normal'] = []
+            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+        else:
+            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
+            if 'color' not in rets: rets['color'] = []
+            if 'depth' not in rets: rets['depth'] = []
+            rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
+            if 'percent_depth' in res:
+                rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
+            elif 'depth' in res:
+                rets['depth'].append(res['depth'].detach().cpu().numpy())
+            else:
+                rets['depth'].append(None)
+    return rets
+
+
+def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
+    yaws = torch.linspace(0, 2 * np.pi, num_frames)
+    pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, 2 * np.pi, num_frames))
+    yaws = yaws.tolist()
+    pitch = pitch.tolist()
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
+    return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
+
+
+def render_multiview(sample, resolution=512, nviews=30):
+    r = 2
+    fov = 40
+    cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
+    yaws = [cam[0] for cam in cams]
+    pitchs = [cam[1] for cam in cams]
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
+    res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
+    return res['color'], extrinsics, intrinsics
+
+
+def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
+    yaw = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
+    yaw_offset = offset[0]
+    yaw = [y + yaw_offset for y in yaw]
+    pitch = [offset[1] for _ in range(4)]
+    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
+    return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
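
Editor's note: a quick sanity check for the spherical-harmonics helpers in sh_utils.py above. This is a minimal sketch, not part of the patch; it assumes torch is installed, stores a base color in the DC term via RGB2SH, and applies the +0.5 shift used by 3D Gaussian Splatting to recover it:

    import torch
    from trellis.renderers.sh_utils import eval_sh, RGB2SH

    dirs = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)  # unit view directions
    sh = torch.zeros(8, 3, 9)              # degree-2 coeffs: (2 + 1) ** 2 = 9 per channel
    base = torch.rand(8, 3)
    sh[..., 0] = RGB2SH(base)              # base color goes into the DC term
    rgb = eval_sh(2, sh, dirs) + 0.5       # higher-order terms are zero, so rgb == base
    assert torch.allclose(rgb, base, atol=1e-6)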
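Similarly, a hedged usage sketch for the render_utils entry points above. It assumes a TRELLIS `Gaussian` sample named `sample` is already loaded on the GPU; `imageio` is only used here for illustration and is not a dependency introduced by this patch:

    import imageio
    from trellis.utils import render_utils

    # Turntable orbit; render_frames returns per-frame numpy images under 'color'
    frames = render_utils.render_video(sample, resolution=512, num_frames=120)
    imageio.mimsave('orbit.mp4', frames['color'], fps=30)

    # Hammersley-distributed viewpoints, e.g. for multiview supervision
    colors, extrinsics, intrinsics = render_utils.render_multiview(sample, resolution=512, nviews=30)

Using the Hammersley sequence here (rather than random sampling) keeps the `nviews` cameras close to evenly spread over the sphere, which is why render_multiview builds its yaw/pitch list from sphere_hammersley_sequence.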