1
This commit is contained in:
528
dataset_toolkits/blender_script/render.py
Normal file
528
dataset_toolkits/blender_script/render.py
Normal file
@@ -0,0 +1,528 @@
|
|||||||
|
import argparse, sys, os, math, re, glob
|
||||||
|
from typing import *
|
||||||
|
import bpy
|
||||||
|
from mathutils import Vector, Matrix
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
"""=============== BLENDER ==============="""
|
||||||
|
|
||||||
|
# Maps a lowercase file extension to the Blender operator that imports it.
# `load_object` looks up the importer here; "blend" uses `wm.append` and is
# called with different keyword arguments than the path-based importers.
IMPORT_FUNCTIONS: Dict[str, Callable] = {
    "obj": bpy.ops.import_scene.obj,
    "glb": bpy.ops.import_scene.gltf,
    "gltf": bpy.ops.import_scene.gltf,
    "usd": bpy.ops.import_scene.usd,
    "fbx": bpy.ops.import_scene.fbx,
    "stl": bpy.ops.import_mesh.stl,
    "usda": bpy.ops.import_scene.usda,
    "dae": bpy.ops.wm.collada_import,
    "ply": bpy.ops.import_mesh.ply,
    "abc": bpy.ops.wm.alembic_import,
    "blend": bpy.ops.wm.append,
}
|
||||||
|
|
||||||
|
# Maps Blender's image `file_format` identifiers to the file extensions the
# render outputs are saved with (used to locate and rename compositor outputs).
EXT = {
    'PNG': 'png',
    'JPEG': 'jpg',
    'OPEN_EXR': 'exr',
    'TIFF': 'tiff',
    'BMP': 'bmp',
    'HDR': 'hdr',
    'TARGA': 'tga'
}
|
||||||
|
|
||||||
|
def init_render(engine='CYCLES', resolution=512, geo_mode=False):
    """Configure render engine, output format and Cycles sampling.

    Args:
        engine (str): Render engine name, e.g. 'CYCLES' or 'BLENDER_EEVEE'.
        resolution (int): Square output resolution in pixels.
        geo_mode (bool): If True, use minimal sampling/bounces (geometry-only
            renders don't need high-quality shading).

    Returns:
        None
    """
    bpy.context.scene.render.engine = engine
    bpy.context.scene.render.resolution_x = resolution
    bpy.context.scene.render.resolution_y = resolution
    bpy.context.scene.render.resolution_percentage = 100
    bpy.context.scene.render.image_settings.file_format = 'PNG'
    # RGBA + transparent film so renders have an alpha channel instead of a
    # world background.
    bpy.context.scene.render.image_settings.color_mode = 'RGBA'
    bpy.context.scene.render.film_transparent = True

    bpy.context.scene.cycles.device = 'GPU'
    # 1 sample suffices in geo_mode because only geometry is of interest.
    bpy.context.scene.cycles.samples = 128 if not geo_mode else 1
    # Box filter of width 1 avoids smearing pixel values across neighbors
    # (important for data-generation renders).
    bpy.context.scene.cycles.filter_type = 'BOX'
    bpy.context.scene.cycles.filter_width = 1
    bpy.context.scene.cycles.diffuse_bounces = 1
    bpy.context.scene.cycles.glossy_bounces = 1
    bpy.context.scene.cycles.transparent_max_bounces = 3 if not geo_mode else 0
    bpy.context.scene.cycles.transmission_bounces = 3 if not geo_mode else 1
    bpy.context.scene.cycles.use_denoising = True

    # Enumerate devices, then force CUDA compute; assumes an NVIDIA GPU is
    # present — TODO confirm for the target render hosts.
    bpy.context.preferences.addons['cycles'].preferences.get_devices()
    bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
|
||||||
|
|
||||||
|
def init_nodes(save_depth=False, save_normal=False, save_albedo=False, save_mist=False):
    """Build compositor nodes that write the requested auxiliary passes to files.

    Args:
        save_depth (bool): Write a 16-bit BW PNG depth map (remapped to 0-1).
        save_normal (bool): Write a 16-bit RGB OpenEXR normal map.
        save_albedo (bool): Write an RGBA PNG diffuse-color (albedo) map.
        save_mist (bool): Write a 16-bit BW PNG mist (distance) map.

    Returns:
        Tuple[dict, dict]: (outputs, spec_nodes) where `outputs` maps pass name
        ('depth'/'normal'/'albedo'/'mist') to its CompositorNodeOutputFile, and
        `spec_nodes` holds nodes whose parameters are adjusted per view
        (currently only 'depth_map', the depth remapping node).
    """
    # Nothing requested: skip compositor setup entirely.
    if not any([save_depth, save_normal, save_albedo, save_mist]):
        return {}, {}
    outputs = {}
    spec_nodes = {}

    bpy.context.scene.use_nodes = True
    # Enable only the passes that will actually be written.
    bpy.context.scene.view_layers['View Layer'].use_pass_z = save_depth
    bpy.context.scene.view_layers['View Layer'].use_pass_normal = save_normal
    bpy.context.scene.view_layers['View Layer'].use_pass_diffuse_color = save_albedo
    bpy.context.scene.view_layers['View Layer'].use_pass_mist = save_mist

    nodes = bpy.context.scene.node_tree.nodes
    links = bpy.context.scene.node_tree.links
    # Start from an empty node tree.
    for n in nodes:
        nodes.remove(n)

    render_layers = nodes.new('CompositorNodeRLayers')

    if save_depth:
        depth_file_output = nodes.new('CompositorNodeOutputFile')
        depth_file_output.base_path = ''
        depth_file_output.file_slots[0].use_node_format = True
        depth_file_output.format.file_format = 'PNG'
        depth_file_output.format.color_depth = '16'
        depth_file_output.format.color_mode = 'BW'
        # Remap to 0-1
        map = nodes.new(type="CompositorNodeMapRange")
        # Inputs 1-4 are (from min, from max, to min, to max); the from-range
        # is overwritten per view in main() based on camera radius.
        map.inputs[1].default_value = 0  # (min value you will be getting)
        map.inputs[2].default_value = 10  # (max value you will be getting)
        map.inputs[3].default_value = 0  # (min value you will map to)
        map.inputs[4].default_value = 1  # (max value you will map to)

        links.new(render_layers.outputs['Depth'], map.inputs[0])
        links.new(map.outputs[0], depth_file_output.inputs[0])

        outputs['depth'] = depth_file_output
        spec_nodes['depth_map'] = map

    if save_normal:
        normal_file_output = nodes.new('CompositorNodeOutputFile')
        normal_file_output.base_path = ''
        normal_file_output.file_slots[0].use_node_format = True
        # EXR keeps the signed float normal components without quantization.
        normal_file_output.format.file_format = 'OPEN_EXR'
        normal_file_output.format.color_mode = 'RGB'
        normal_file_output.format.color_depth = '16'

        links.new(render_layers.outputs['Normal'], normal_file_output.inputs[0])

        outputs['normal'] = normal_file_output

    if save_albedo:
        albedo_file_output = nodes.new('CompositorNodeOutputFile')
        albedo_file_output.base_path = ''
        albedo_file_output.file_slots[0].use_node_format = True
        albedo_file_output.format.file_format = 'PNG'
        albedo_file_output.format.color_mode = 'RGBA'
        albedo_file_output.format.color_depth = '8'

        # Combine the diffuse-color pass with the render's alpha so the albedo
        # image is transparent outside the object.
        alpha_albedo = nodes.new('CompositorNodeSetAlpha')

        links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image'])
        links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha'])
        links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0])

        outputs['albedo'] = albedo_file_output

    if save_mist:
        # Mist encodes distance from the camera over [start, start + depth].
        bpy.data.worlds['World'].mist_settings.start = 0
        bpy.data.worlds['World'].mist_settings.depth = 10

        mist_file_output = nodes.new('CompositorNodeOutputFile')
        mist_file_output.base_path = ''
        mist_file_output.file_slots[0].use_node_format = True
        mist_file_output.format.file_format = 'PNG'
        mist_file_output.format.color_mode = 'BW'
        mist_file_output.format.color_depth = '16'

        links.new(render_layers.outputs['Mist'], mist_file_output.inputs[0])

        outputs['mist'] = mist_file_output

    return outputs, spec_nodes
|
||||||
|
|
||||||
|
def init_scene() -> None:
    """Resets the scene to a clean state.

    Removes every object, material, texture and image datablock so a fresh
    model can be loaded without leftovers from a previous scene.

    Returns:
        None
    """
    # Purge each datablock collection; iterate over a snapshot so removal
    # during iteration is safe.
    for datablocks in (bpy.data.objects, bpy.data.materials,
                       bpy.data.textures, bpy.data.images):
        for block in list(datablocks):
            datablocks.remove(block, do_unlink=True)
|
||||||
|
|
||||||
|
def init_camera():
    """Create the scene camera plus an origin-tracking constraint.

    The camera gets a TRACK_TO constraint targeting an empty at the origin, so
    moving the camera (per view in `main`) keeps it aimed at the object.

    Returns:
        bpy.types.Object: The newly created camera object (also set as the
        scene's active camera).
    """
    cam = bpy.data.objects.new('Camera', bpy.data.cameras.new('Camera'))
    bpy.context.collection.objects.link(cam)
    bpy.context.scene.camera = cam
    # Square 32mm sensor; together with `cam.data.lens` this determines FOV.
    cam.data.sensor_height = cam.data.sensor_width = 32
    cam_constraint = cam.constraints.new(type='TRACK_TO')
    cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
    cam_constraint.up_axis = 'UP_Y'
    # Empty at the origin acts as the look-at target.
    cam_empty = bpy.data.objects.new("Empty", None)
    cam_empty.location = (0, 0, 0)
    bpy.context.scene.collection.objects.link(cam_empty)
    cam_constraint.target = cam_empty
    return cam
|
||||||
|
|
||||||
|
def init_lighting():
    """Replace all scene lights with a fixed three-light setup.

    Creates a point key light plus large area lights above and below the
    origin, giving even illumination for turntable-style renders.

    Returns:
        dict: Mapping of light role name to the created light object.
    """
    # Clear existing lights
    bpy.ops.object.select_all(action="DESELECT")
    bpy.ops.object.select_by_type(type="LIGHT")
    bpy.ops.object.delete()

    # Create key light
    default_light = bpy.data.objects.new("Default_Light", bpy.data.lights.new("Default_Light", type="POINT"))
    bpy.context.collection.objects.link(default_light)
    default_light.data.energy = 1000
    default_light.location = (4, 1, 6)
    default_light.rotation_euler = (0, 0, 0)

    # create top light
    top_light = bpy.data.objects.new("Top_Light", bpy.data.lights.new("Top_Light", type="AREA"))
    bpy.context.collection.objects.link(top_light)
    top_light.data.energy = 10000
    top_light.location = (0, 0, 10)
    # Very large area light approximates a uniform sky dome from above.
    top_light.scale = (100, 100, 100)

    # create bottom light
    bottom_light = bpy.data.objects.new("Bottom_Light", bpy.data.lights.new("Bottom_Light", type="AREA"))
    bpy.context.collection.objects.link(bottom_light)
    bottom_light.data.energy = 1000
    bottom_light.location = (0, 0, -10)
    bottom_light.rotation_euler = (0, 0, 0)

    return {
        "default_light": default_light,
        "top_light": top_light,
        "bottom_light": bottom_light
    }
|
||||||
|
|
||||||
|
|
||||||
|
def load_object(object_path: str) -> None:
    """Loads a model with a supported file extension into the scene.

    Args:
        object_path (str): Path to the model file.

    Raises:
        ValueError: If the file extension is not supported.

    Returns:
        None
    """
    file_extension = object_path.split(".")[-1].lower()
    # Fix: `str.split` never returns None, so the old `is None` check was dead
    # code and unknown extensions crashed later with a KeyError. Validate
    # against the supported importers (usdz is handled specially below).
    if file_extension != "usdz" and file_extension not in IMPORT_FUNCTIONS:
        raise ValueError(f"Unsupported file type: {object_path}")

    if file_extension == "usdz":
        # install usdz io package
        dirname = os.path.dirname(os.path.realpath(__file__))
        usdz_package = os.path.join(dirname, "io_scene_usdz.zip")
        bpy.ops.preferences.addon_install(filepath=usdz_package)
        # enable it
        addon_name = "io_scene_usdz"
        bpy.ops.preferences.addon_enable(module=addon_name)
        # import the usdz
        from io_scene_usdz.import_usdz import import_usdz

        # Fix: the original passed an undefined name `context` (NameError);
        # the importer needs the active Blender context.
        import_usdz(bpy.context, filepath=object_path, materials=True, animations=True)
        return None

    # load from existing import functions
    import_function = IMPORT_FUNCTIONS[file_extension]

    print(f"Loading object from {object_path}")
    if file_extension == "blend":
        # `wm.append` pulls datablocks out of the .blend file.
        import_function(directory=object_path, link=False)
    elif file_extension in {"glb", "gltf"}:
        import_function(filepath=object_path, merge_vertices=True, import_shading='NORMALS')
    else:
        import_function(filepath=object_path)
|
||||||
|
|
||||||
|
def delete_invisible_objects() -> None:
    """Deletes all invisible objects in the scene.

    Objects hidden in the viewport or in renders are unhidden (so the delete
    operator can act on them), selected, and removed; hidden collections are
    removed as well.

    Returns:
        None
    """
    # bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")
    for obj in bpy.context.scene.objects:
        if obj.hide_viewport or obj.hide_render:
            # Unhide first: hidden objects can't be selected/deleted via ops.
            obj.hide_viewport = False
            obj.hide_render = False
            obj.hide_select = False
            obj.select_set(True)
    bpy.ops.object.delete()

    # Delete invisible collections
    invisible_collections = [col for col in bpy.data.collections if col.hide_viewport]
    for col in invisible_collections:
        bpy.data.collections.remove(col)
|
||||||
|
|
||||||
|
def split_mesh_normal():
    """Split custom normals of every mesh in the scene.

    Selects all meshes, enters edit mode and runs the split-normals operator
    on all faces, then restores object mode with nothing selected.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    # An active object is required before switching to edit mode;
    # assumes at least one mesh exists — raises IndexError otherwise.
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    bpy.ops.mesh.select_all(action='SELECT')
    bpy.ops.mesh.split_normals()
    bpy.ops.object.mode_set(mode='OBJECT')
    bpy.ops.object.select_all(action="DESELECT")
|
||||||
|
|
||||||
|
def delete_custom_normals():
    """Clear custom split-normals data from every mesh object.

    Returns:
        None
    """
    for this_obj in bpy.data.objects:
        if this_obj.type == "MESH":
            # The operator acts on the active object, so activate each mesh.
            bpy.context.view_layer.objects.active = this_obj
            bpy.ops.mesh.customdata_custom_splitnormals_clear()
|
||||||
|
|
||||||
|
def override_material():
    """Override all materials with a flat gray diffuse BSDF.

    Used in geo_mode so renders show pure geometry without original textures.

    Returns:
        None
    """
    new_mat = bpy.data.materials.new(name="Override0123456789")
    new_mat.use_nodes = True
    new_mat.node_tree.nodes.clear()
    bsdf = new_mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
    # Mid-gray, fully rough diffuse surface.
    bsdf.inputs[0].default_value = (0.5, 0.5, 0.5, 1)
    bsdf.inputs[1].default_value = 1
    output = new_mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
    new_mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
    # View-layer material override applies to every object at render time.
    bpy.context.scene.view_layers['View Layer'].material_override = new_mat
|
||||||
|
|
||||||
|
def unhide_all_objects() -> None:
    """Unhides all objects in the scene.

    Returns:
        None
    """
    # Clear the viewport hide flag on every scene object.
    for scene_obj in bpy.context.scene.objects:
        scene_obj.hide_set(False)
|
||||||
|
|
||||||
|
def convert_to_meshes() -> None:
    """Converts all objects in the scene to meshes.

    Selects everything (curves, surfaces, etc.) and runs the convert operator
    so later mesh-only steps (triangulation, export) see all geometry.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    # The convert operator needs an active object; assumes at least one mesh
    # exists — raises IndexError otherwise.
    bpy.context.view_layer.objects.active = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"][0]
    for obj in bpy.context.scene.objects:
        obj.select_set(True)
    bpy.ops.object.convert(target="MESH")
|
||||||
|
|
||||||
|
def triangulate_meshes() -> None:
    """Triangulates all meshes in the scene.

    Returns:
        None
    """
    bpy.ops.object.select_all(action="DESELECT")
    objs = [obj for obj in bpy.context.scene.objects if obj.type == "MESH"]
    # Edit mode requires an active object; assumes at least one mesh exists.
    bpy.context.view_layer.objects.active = objs[0]
    for obj in objs:
        obj.select_set(True)
    bpy.ops.object.mode_set(mode="EDIT")
    # Reveal hidden geometry so every face gets triangulated.
    bpy.ops.mesh.reveal()
    bpy.ops.mesh.select_all(action="SELECT")
    bpy.ops.mesh.quads_convert_to_tris(quad_method="BEAUTY", ngon_method="BEAUTY")
    bpy.ops.object.mode_set(mode="OBJECT")
    bpy.ops.object.select_all(action="DESELECT")
|
||||||
|
|
||||||
|
def scene_bbox() -> Tuple[Vector, Vector]:
    """Returns the bounding box of the scene.

    Taken from Shap-E rendering script
    (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)

    Returns:
        Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.

    Raises:
        RuntimeError: If the scene contains no mesh objects.
    """
    bbox_min = (math.inf,) * 3
    bbox_max = (-math.inf,) * 3
    found = False
    scene_meshes = [obj for obj in bpy.context.scene.objects.values() if isinstance(obj.data, bpy.types.Mesh)]
    for obj in scene_meshes:
        found = True
        for coord in obj.bound_box:
            # Transform each local-space bound-box corner into world space.
            coord = Vector(coord)
            coord = obj.matrix_world @ coord
            bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
            bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
    if not found:
        raise RuntimeError("no objects in scene to compute bounding box for")
    return Vector(bbox_min), Vector(bbox_max)
|
||||||
|
|
||||||
|
def normalize_scene() -> Tuple[float, Vector]:
    """Normalizes the scene by scaling and translating it to fit in a unit cube centered
    at the origin.

    Mostly taken from the Point-E / Shap-E rendering script
    (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
    but fix for multiple root objects: (see bug report here:
    https://github.com/openai/shap-e/pull/60).

    Returns:
        Tuple[float, Vector]: The scale factor and the offset applied to the scene.
    """
    scene_root_objects = [obj for obj in bpy.context.scene.objects.values() if not obj.parent]
    if len(scene_root_objects) > 1:
        # create an empty object to be used as a parent for all root objects
        scene = bpy.data.objects.new("ParentEmpty", None)
        bpy.context.scene.collection.objects.link(scene)

        # parent all root objects to the empty object
        for obj in scene_root_objects:
            obj.parent = scene
    else:
        # Single root: scale/translate it directly.
        scene = scene_root_objects[0]

    bbox_min, bbox_max = scene_bbox()
    # Uniform scale so the longest bbox edge becomes 1.
    scale = 1 / max(bbox_max - bbox_min)
    scene.scale = scene.scale * scale

    # Apply scale to matrix_world.
    bpy.context.view_layer.update()
    bbox_min, bbox_max = scene_bbox()
    # Recentre the (scaled) bounding box on the origin.
    offset = -(bbox_min + bbox_max) / 2
    scene.matrix_world.translation += offset
    bpy.ops.object.select_all(action="DESELECT")

    return scale, offset
|
||||||
|
|
||||||
|
def get_transform_matrix(obj: bpy.types.Object) -> list:
    """Return the object's world transform as a 4x4 row-major nested list.

    Scale is dropped: only the rotation and translation from
    ``matrix_world.decompose()`` are encoded.

    Args:
        obj: The Blender object whose transform is extracted.

    Returns:
        list: Four rows of four floats, last row [0, 0, 0, 1].
    """
    translation, rotation, _ = obj.matrix_world.decompose()
    rot3 = rotation.to_matrix()
    # Each of the first three rows is [R00 R01 R02 | t] for that axis.
    matrix = [[rot3[row][col] for col in range(3)] + [translation[row]]
              for row in range(3)]
    matrix.append([0, 0, 0, 1])
    return matrix
|
||||||
|
|
||||||
|
def main(arg):
    """Render the object described by `arg` from every requested view.

    Sets up the renderer, loads and normalizes the model, renders each view in
    `arg.views` (a JSON list of {yaw, pitch, radius, fov}), writes per-view
    images plus any auxiliary passes, and saves camera metadata to
    `transforms.json`. Optionally exports the triangulated mesh as PLY.

    Args:
        arg: Parsed argparse namespace (see the __main__ block for the flags).

    Returns:
        None
    """
    os.makedirs(arg.output_folder, exist_ok=True)

    # Initialize context
    init_render(engine=arg.engine, resolution=arg.resolution, geo_mode=arg.geo_mode)
    outputs, spec_nodes = init_nodes(
        save_depth=arg.save_depth,
        save_normal=arg.save_normal,
        save_albedo=arg.save_albedo,
        save_mist=arg.save_mist
    )
    if arg.object.endswith(".blend"):
        # .blend scenes are opened by Blender itself (passed on the command
        # line); just prune invisible objects.
        delete_invisible_objects()
    else:
        init_scene()
        load_object(arg.object)
        if arg.split_normal:
            split_mesh_normal()
        # delete_custom_normals()
    print('[INFO] Scene initialized.')

    # normalize scene
    scale, offset = normalize_scene()
    print('[INFO] Scene normalized.')

    # Initialize camera and lighting
    cam = init_camera()
    init_lighting()
    print('[INFO] Camera and lighting initialized.')

    # Override material
    if arg.geo_mode:
        override_material()

    # Create a list of views
    to_export = {
        # The normalized scene fits in the unit cube centered at the origin.
        "aabb": [[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        "scale": scale,
        "offset": [offset.x, offset.y, offset.z],
        "frames": []
    }
    views = json.loads(arg.views)
    for i, view in enumerate(views):
        # Spherical coordinates -> Cartesian camera position; the TRACK_TO
        # constraint keeps the camera aimed at the origin.
        cam.location = (
            view['radius'] * np.cos(view['yaw']) * np.cos(view['pitch']),
            view['radius'] * np.sin(view['yaw']) * np.cos(view['pitch']),
            view['radius'] * np.sin(view['pitch'])
        )
        # Focal length from FOV for the 32mm sensor (16 = sensor_width / 2).
        cam.data.lens = 16 / np.tan(view['fov'] / 2)

        if arg.save_depth:
            # Remap depth over the range actually occupied by the unit cube
            # at this camera distance (radius +/- half the cube diagonal).
            spec_nodes['depth_map'].inputs[1].default_value = view['radius'] - 0.5 * np.sqrt(3)
            spec_nodes['depth_map'].inputs[2].default_value = view['radius'] + 0.5 * np.sqrt(3)

        bpy.context.scene.render.filepath = os.path.join(arg.output_folder, f'{i:03d}.png')
        for name, output in outputs.items():
            output.file_slots[0].path = os.path.join(arg.output_folder, f'{i:03d}_{name}')

        # Render the scene
        bpy.ops.render.render(write_still=True)
        bpy.context.view_layer.update()
        for name, output in outputs.items():
            # File-output nodes append a frame number; rename to a clean path.
            ext = EXT[output.format.file_format]
            path = glob.glob(f'{output.file_slots[0].path}*.{ext}')[0]
            os.rename(path, f'{output.file_slots[0].path}.{ext}')

        # Save camera parameters
        metadata = {
            "file_path": f'{i:03d}.png',
            "camera_angle_x": view['fov'],
            "transform_matrix": get_transform_matrix(cam)
        }
        if arg.save_depth:
            metadata['depth'] = {
                'min': view['radius'] - 0.5 * np.sqrt(3),
                'max': view['radius'] + 0.5 * np.sqrt(3)
            }
        to_export["frames"].append(metadata)

    # Save the camera parameters
    with open(os.path.join(arg.output_folder, 'transforms.json'), 'w') as f:
        json.dump(to_export, f, indent=4)

    if arg.save_mesh:
        # triangulate meshes
        unhide_all_objects()
        convert_to_meshes()
        triangulate_meshes()
        print('[INFO] Meshes triangulated.')

        # export ply mesh
        bpy.ops.export_mesh.ply(filepath=os.path.join(arg.output_folder, 'mesh.ply'))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # This script is run by Blender (`blender -b -P render.py -- <flags>`);
    # argparse therefore only sees the arguments after the "--" separator.
    parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.')
    parser.add_argument('--views', type=str, help='JSON string of views. Contains a list of {yaw, pitch, radius, fov} object.')
    parser.add_argument('--object', type=str, help='Path to the 3D model file to be rendered.')
    parser.add_argument('--output_folder', type=str, default='/tmp', help='The path the output will be dumped to.')
    parser.add_argument('--resolution', type=int, default=512, help='Resolution of the images.')
    parser.add_argument('--engine', type=str, default='CYCLES', help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
    parser.add_argument('--geo_mode', action='store_true', help='Geometry mode for rendering.')
    parser.add_argument('--save_depth', action='store_true', help='Save the depth maps.')
    parser.add_argument('--save_normal', action='store_true', help='Save the normal maps.')
    parser.add_argument('--save_albedo', action='store_true', help='Save the albedo maps.')
    parser.add_argument('--save_mist', action='store_true', help='Save the mist distance maps.')
    parser.add_argument('--split_normal', action='store_true', help='Split the normals of the mesh.')
    parser.add_argument('--save_mesh', action='store_true', help='Save the mesh as a .ply file.')
    # Everything before "--" belongs to Blender itself.
    argv = sys.argv[sys.argv.index("--") + 1:]
    args = parser.parse_args(argv)

    main(args)
|
||||||
|
|
||||||
121
dataset_toolkits/render.py
Normal file
121
dataset_toolkits/render.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import copy
|
||||||
|
import sys
|
||||||
|
import importlib
|
||||||
|
import argparse
|
||||||
|
import pandas as pd
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
from functools import partial
|
||||||
|
from subprocess import DEVNULL, call
|
||||||
|
import numpy as np
|
||||||
|
from utils import sphere_hammersley_sequence
|
||||||
|
|
||||||
|
|
||||||
|
BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
|
||||||
|
BLENDER_INSTALLATION_PATH = '/tmp'
|
||||||
|
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
|
||||||
|
|
||||||
|
def _install_blender():
    """Download and unpack Blender 3.0.1 into BLENDER_INSTALLATION_PATH if absent.

    Installs the X/render runtime libraries Blender needs for headless use via
    apt-get; assumes a Debian/Ubuntu host with passwordless sudo — TODO confirm
    for the deployment environment.
    """
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
|
||||||
|
|
||||||
|
|
||||||
|
def _render(file_path, sha256, output_dir, num_views):
    """Render `num_views` views of one object via the Blender script.

    Args:
        file_path (str): Path to the 3D model file.
        sha256 (str): Content hash used as the per-object output folder name.
        output_dir (str): Root output directory; renders go to
            `<output_dir>/renders/<sha256>`.
        num_views (int): Number of camera views to render.

    Returns:
        dict | None: {'sha256': ..., 'rendered': True} on success, or None
        (implicitly) when the render produced no transforms.json.
    """
    output_folder = os.path.join(output_dir, 'renders', sha256)

    # Build camera {yaw, pitch, radius, fov}
    yaws = []
    pitchs = []
    # Random offset decorrelates the Hammersley view sets between objects.
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views, offset)
        yaws.append(y)
        pitchs.append(p)
    # Fixed distance and FOV for every view.
    radius = [2] * num_views
    fov = [40 / 180 * np.pi] * num_views
    views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]

    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--resolution', '512',
        '--output_folder', output_folder,
        '--engine', 'CYCLES',
        '--save_mesh',
    ]
    if file_path.endswith('.blend'):
        # .blend scenes are opened by Blender itself, so the path goes right
        # after the executable, before -b.
        args.insert(1, file_path)

    call(args, stdout=DEVNULL, stderr=DEVNULL)

    # Success is detected by the presence of the metadata file.
    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'rendered': True}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # First CLI argument selects the dataset module (datasets.<name>), which
    # supplies add_args() and foreach_instance().
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=150,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    # sys.argv[1] is the dataset name; parse only the flags that follow it.
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'renders'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        # No explicit instance list: take all downloaded objects, optionally
        # filtered by aesthetic score and by not having been rendered yet.
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'rendered' in metadata.columns:
            metadata = metadata[metadata['rendered'] == False]
    else:
        # Explicit instances: either a file with one sha256 per line, or a
        # comma-separated list on the command line.
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    # Shard the work across world_size ranks.
    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_render, output_dir=opt.output_dir, num_views=opt.num_views)
    rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    # Merge freshly rendered results with the already-done records.
    rendered = pd.concat([rendered, pd.DataFrame.from_records(records)])
    rendered.to_csv(os.path.join(opt.output_dir, f'rendered_{opt.rank}.csv'), index=False)
|
||||||
125
dataset_toolkits/render_cond.py
Normal file
125
dataset_toolkits/render_cond.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import copy
|
||||||
|
import sys
|
||||||
|
import importlib
|
||||||
|
import argparse
|
||||||
|
import pandas as pd
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
from functools import partial
|
||||||
|
from subprocess import DEVNULL, call
|
||||||
|
import numpy as np
|
||||||
|
from utils import sphere_hammersley_sequence
|
||||||
|
|
||||||
|
|
||||||
|
BLENDER_LINK = 'https://download.blender.org/release/Blender3.0/blender-3.0.1-linux-x64.tar.xz'
|
||||||
|
BLENDER_INSTALLATION_PATH = '/tmp'
|
||||||
|
BLENDER_PATH = f'{BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64/blender'
|
||||||
|
|
||||||
|
def _install_blender():
    """Download and unpack Blender 3.0.1 into BLENDER_INSTALLATION_PATH if absent.

    Installs the X/render runtime libraries Blender needs for headless use via
    apt-get; assumes a Debian/Ubuntu host with passwordless sudo — TODO confirm
    for the deployment environment.
    """
    if not os.path.exists(BLENDER_PATH):
        os.system('sudo apt-get update')
        os.system('sudo apt-get install -y libxrender1 libxi6 libxkbcommon-x11-0 libsm6')
        os.system(f'wget {BLENDER_LINK} -P {BLENDER_INSTALLATION_PATH}')
        os.system(f'tar -xvf {BLENDER_INSTALLATION_PATH}/blender-3.0.1-linux-x64.tar.xz -C {BLENDER_INSTALLATION_PATH}')
|
||||||
|
|
||||||
|
|
||||||
|
def _render_cond(file_path, sha256, output_dir, num_views):
    """Render `num_views` conditioning images of one object with Blender.

    Camera poses are drawn from a Hammersley sequence on the sphere
    (yaw/pitch) with a random offset, and radius/FOV are randomized so the
    unit-cube diagonal (sqrt(3)) stays in frame at every distance.

    Args:
        file_path: Path to the 3D asset (a .blend file is opened directly).
        sha256: Object identifier; used as the output sub-directory name.
        output_dir: Dataset root; images go to <output_dir>/renders_cond/<sha256>.
        num_views: Number of views to render.

    Returns:
        {'sha256': sha256, 'cond_rendered': True} when Blender produced
        transforms.json, otherwise None (implicit).
    """
    output_folder = os.path.join(output_dir, 'renders_cond', sha256)

    # Build camera {yaw, pitch, radius, fov}
    yaws = []
    pitchs = []
    offset = (np.random.rand(), np.random.rand())
    for i in range(num_views):
        y, p = sphere_hammersley_sequence(i, num_views, offset)
        yaws.append(y)
        pitchs.append(p)
    fov_min, fov_max = 10, 70
    radius_min = np.sqrt(3) / 2 / np.sin(fov_max / 360 * np.pi)
    radius_max = np.sqrt(3) / 2 / np.sin(fov_min / 360 * np.pi)
    k_min = 1 / radius_max**2
    k_max = 1 / radius_min**2
    # BUGFIX: sample exactly num_views values instead of 1,000,000 — the
    # zip() below only ever consumed the first num_views of them, so the
    # huge array was pure wasted work and memory.
    ks = np.random.uniform(k_min, k_max, (num_views,))
    radius = [1 / np.sqrt(k) for k in ks]
    fov = [2 * np.arcsin(np.sqrt(3) / 2 / r) for r in radius]
    views = [{'yaw': y, 'pitch': p, 'radius': r, 'fov': f} for y, p, r, f in zip(yaws, pitchs, radius, fov)]

    args = [
        BLENDER_PATH, '-b', '-P', os.path.join(os.path.dirname(__file__), 'blender_script', 'render.py'),
        '--',
        '--views', json.dumps(views),
        '--object', os.path.expanduser(file_path),
        '--output_folder', os.path.expanduser(output_folder),
        '--resolution', '1024',
    ]
    if file_path.endswith('.blend'):
        # .blend files are opened by Blender itself rather than imported
        # by the render script, so the path goes right after the binary.
        args.insert(1, file_path)

    call(args, stdout=DEVNULL)

    # transforms.json is written by the render script only on success.
    if os.path.exists(os.path.join(output_folder, 'transforms.json')):
        return {'sha256': sha256, 'cond_rendered': True}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # sys.argv[1] selects the dataset plugin (datasets.<name>), which must
    # provide add_args() and foreach_instance(); remaining args go to argparse.
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    parser.add_argument('--num_views', type=int, default=24,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=8)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'renders_cond'), exist_ok=True)

    # install blender
    print('Checking blender...', flush=True)
    _install_blender()

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        # Default selection: every object with a local copy, optionally
        # filtered by aesthetic score, skipping ones already rendered.
        metadata = metadata[metadata['local_path'].notna()]
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'cond_rendered' in metadata.columns:
            metadata = metadata[metadata['cond_rendered'] == False]
    else:
        # Explicit instance list: either a file of sha256s (one per line)
        # or a comma-separated list on the command line.
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    # Shard the remaining rows evenly across ranks.
    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'renders_cond', sha256, 'transforms.json')):
            records.append({'sha256': sha256, 'cond_rendered': True})
            metadata = metadata[metadata['sha256'] != sha256]

    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_render_cond, output_dir=opt.output_dir, num_views=opt.num_views)
    cond_rendered = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Rendering objects')
    # Merge freshly-rendered rows with the already-done records and persist
    # a per-rank CSV for later aggregation.
    cond_rendered = pd.concat([cond_rendered, pd.DataFrame.from_records(records)])
    cond_rendered.to_csv(os.path.join(opt.output_dir, f'cond_rendered_{opt.rank}.csv'), index=False)
|
||||||
116
render_model.py
Normal file
116
render_model.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import subprocess
|
||||||
|
|
||||||
|
import bpy
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from mathutils import Vector
|
||||||
|
|
||||||
|
|
||||||
|
def clear_scene():
    """Delete every object currently in the Blender scene."""
    bpy.ops.object.select_all(action='SELECT')
    bpy.ops.object.delete(use_global=False)
|
||||||
|
|
||||||
|
|
||||||
|
def import_model(model_path):
    """Import an .obj or .glb/.gltf model into the current scene.

    Raises:
        ValueError: For any other file extension.
    """
    extension = os.path.splitext(model_path)[1].lower()
    if extension in (".glb", ".gltf"):
        bpy.ops.import_scene.gltf(filepath=model_path)
    elif extension == ".obj":
        bpy.ops.wm.obj_import(filepath=model_path)
    else:
        raise ValueError(f"Unsupported format: {extension}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_scene_bbox():
    """Return (min_corner, max_corner) of all mesh objects in world space.

    Raises:
        RuntimeError: When the scene contains no mesh objects.
    """
    meshes = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH']
    if not meshes:
        raise RuntimeError("No mesh objects found")

    lo = Vector((float("inf"), float("inf"), float("inf")))
    hi = Vector((float("-inf"), float("-inf"), float("-inf")))

    # Fold every bounding-box corner of every mesh into the running extremes.
    for obj in meshes:
        for corner in obj.bound_box:
            world = obj.matrix_world @ Vector(corner)
            for axis in range(3):
                lo[axis] = min(lo[axis], world[axis])
                hi[axis] = max(hi[axis], world[axis])

    return lo, hi
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_model():
    """Center the scene's meshes at the origin and scale to fit a 2-unit cube."""
    min_corner, max_corner = get_scene_bbox()
    center = (min_corner + max_corner) / 2
    # Largest bounding-box extent across the three axes.
    size = max(max_corner.x - min_corner.x,
               max(max_corner.y - min_corner.y, max_corner.z - min_corner.z))

    objs = [obj for obj in bpy.context.scene.objects if obj.type == 'MESH']
    for obj in objs:
        obj.location -= center

    if size > 0:
        # Scale so the largest dimension becomes 2.0 (fits [-1, 1]).
        scale = 2.0 / size
        for obj in objs:
            # NOTE(review): obj.scale scales each object about its own origin;
            # for scenes with several meshes whose origins differ this shifts
            # them relative to each other — confirm single-mesh input only.
            obj.scale *= scale

    bpy.context.view_layer.update()
|
||||||
|
|
||||||
|
|
||||||
|
def add_camera():
    """Create a camera aimed at the world origin and make it the scene camera."""
    bpy.ops.object.camera_add(location=(2.5, -2.5, 2.0))
    camera = bpy.context.active_object
    # Vector from the camera towards the origin.
    look_direction = Vector((0, 0, 0)) - camera.location
    camera.rotation_euler = look_direction.to_track_quat('-Z', 'Y').to_euler()
    bpy.context.scene.camera = camera
|
||||||
|
|
||||||
|
|
||||||
|
def add_light():
    """Add a sun key light plus an area fill light to the scene."""
    # Key light.
    bpy.ops.object.light_add(type='SUN', location=(5, -5, 8))
    bpy.context.active_object.data.energy = 3.0

    # Fill light.
    bpy.ops.object.light_add(type='AREA', location=(3, -3, 3))
    fill = bpy.context.active_object
    fill.data.energy = 3000
    fill.data.size = 5
|
||||||
|
|
||||||
|
|
||||||
|
def setup_render(output_path, resolution=1024):
    """Configure Cycles to render a transparent square PNG to output_path."""
    scn = bpy.context.scene
    scn.render.engine = 'CYCLES'
    scn.cycles.samples = 128

    render_settings = scn.render
    render_settings.filepath = output_path
    render_settings.image_settings.file_format = 'PNG'
    render_settings.resolution_x = resolution
    render_settings.resolution_y = resolution
    render_settings.film_transparent = True
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Return (model_path, output_path) from the CLI args after Blender's `--`.

    Raises:
        ValueError: When fewer than two arguments follow the `--` separator.
    """
    argv = sys.argv
    script_args = argv[argv.index("--") + 1:] if "--" in argv else []
    if len(script_args) < 2:
        raise ValueError("Usage: blender -b -P render_model.py -- model_path output_path")
    model_path, output_path = script_args[0], script_args[1]
    return model_path, output_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_static_model_image():
    """Entry point: import, normalize, light, and render the model to a PNG.

    Reads (model_path, output_path) from the CLI arguments after `--`.
    """
    model_path, output_path = parse_args()
    clear_scene()
    import_model(model_path)
    normalize_model()
    add_camera()
    add_light()
    setup_render(output_path)
    bpy.ops.render.render(write_still=True)
    print(f"Saved to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Invoked as: blender -b -P render_model.py -- <model_path> <output_path>
    get_static_model_image()
|
||||||
439
server.py
Normal file
439
server.py
Normal file
@@ -0,0 +1,439 @@
|
|||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import imageio
|
||||||
|
import numpy as np
|
||||||
|
import trimesh
|
||||||
|
|
||||||
|
import litserve as ls
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
from glb2svg import glb_to_obj, obj_to_step, step_to_svg
|
||||||
|
from trellis.pipelines import TrellisImageTo3DPipeline
|
||||||
|
from trellis.utils import render_utils, postprocessing_utils
|
||||||
|
from utils.new_oss_client import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, minio_get_image, MINIO_BUCKET, upload_bytes, upload_local_file, download_from_minio
|
||||||
|
|
||||||
|
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_unique_name(original_name: str) -> str:
    """Build a collision-resistant filename: <stem>_<timestamp>_<hex8><ext>."""
    stem, ext = os.path.splitext(original_name)
    pieces = [
        stem,
        datetime.now().strftime("%Y%m%d_%H%M%S"),
        secrets.token_hex(4),  # 8 random hex characters
    ]
    return "_".join(pieces) + ext
|
||||||
|
|
||||||
|
|
||||||
|
def load_mesh(file_path):
    """
    Load an .obj or .glb/.gltf file and return its vertex data.

    Args:
        file_path: Path to the model file (.obj and .glb/.gltf supported).

    Returns:
        numpy.ndarray: Vertex array of shape (N, 3).

    Raises:
        ValueError: On an unsupported extension or when no vertices are found.
    """
    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == '.obj':
        # Load the OBJ file with trimesh.
        mesh = trimesh.load(file_path, file_type='obj')
    elif file_ext == '.glb' or file_ext == '.gltf':
        # Load the GLB/GLTF file with trimesh.
        mesh = trimesh.load(file_path, file_type='glb')
    else:
        raise ValueError(f"不支持的文件格式: {file_ext},仅支持.obj和.glb/.gltf")

    # Extract the vertex data.
    if isinstance(mesh, trimesh.Scene):
        # A scene: stack the vertices of every contained geometry.
        vertices = []
        for geom in mesh.geometry.values():
            vertices.append(geom.vertices)
        vertices = np.vstack(vertices)
    else:
        # A single geometry.
        vertices = mesh.vertices

    if len(vertices) == 0:
        raise ValueError("文件中未找到顶点数据")

    return vertices
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_mesh(file_path):
    """
    Analyze a 3D model file: centroid, bounding box, per-axis size and ratios.

    Args:
        file_path: Path to the model file (.obj or .glb/.gltf).

    Returns:
        dict: file format, vertex count, centroid, bounding box min/max,
        per-axis size, and per-axis size ratios (fraction and percentage).
    """
    # Load the model and get its (N, 3) vertex array.
    vertices = load_mesh(file_path)

    # Axis-aligned bounding box (per-axis min/max).
    min_coords = vertices.min(axis=0)
    max_coords = vertices.max(axis=0)

    # Centroid (mean of all vertices).
    centroid = vertices.mean(axis=0)

    # Size = bounding-box extents.
    size = max_coords - min_coords

    # Each axis's share of the summed extents.
    total_size = np.sum(size)
    # BUGFIX: the zero-size fallback used to be a plain Python list, which
    # broke the .tolist() call below and made (size_ratio * 100) perform
    # list repetition instead of scaling — use an ndarray instead.
    size_ratio = size / total_size if total_size != 0 else np.zeros(3)

    info = {
        # "file_path": file_path,
        "file_format": os.path.splitext(file_path)[1].lower(),
        "vertex_count": len(vertices),
        "centroid": centroid.tolist(),
        "bounding_box_min": min_coords.tolist(),
        "bounding_box_max": max_coords.tolist(),
        "size": size.tolist(),
        "size_ratio": size_ratio.tolist(),
        "size_ratio_percentage": (size_ratio * 100).tolist()
    }

    return info
|
||||||
|
|
||||||
|
|
||||||
|
def render_glb_preview(glb_path, output_path):
    """Render a static preview image of a GLB with headless Blender.

    Returns:
        output_path on success.

    Raises:
        RuntimeError: When the Blender subprocess exits non-zero.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    command = ["blender", "--background", "--python", "render_model.py",
               "--", glb_path, output_path]

    proc = subprocess.run(command, capture_output=True, text=True)

    if proc.returncode != 0:
        raise RuntimeError(
            f"Blender render failed\nstdout:\n{proc.stdout}\nstderr:\n{proc.stderr}"
        )

    return output_path
|
||||||
|
|
||||||
|
|
||||||
|
class TrellisAPI(ls.LitAPI):
    """LitServe endpoint: image(s) in MinIO -> textured 3D GLB via TRELLIS."""

    def setup(self, device):
        # 'native' spconv algorithm; setdefault preserves any explicit override.
        os.environ.setdefault("SPCONV_ALGO", "native")

        self.pipeline = TrellisImageTo3DPipeline.from_pretrained(
            "microsoft/TRELLIS-image-large"
        )
        self.pipeline.to(device)

    def decode_request(self, request):
        """Fetch input images from MinIO and collect sampling parameters.

        Each entry of request["image_paths"] is "<bucket>/<object_name>".
        """
        image_paths = request["image_paths"]
        images = []
        for path in image_paths:
            bucket = path.split('/')[0]
            object_name = path[path.find('/') + 1:]

            image = minio_get_image(minio_client, bucket, object_name)
            images.append(image)

        # Generation/export parameters with service-side defaults.
        params = {
            "file_name": uuid.uuid4().hex,
            "model": request.get("model", "single"),
            "seed": request.get("seed", 1),
            "steps_sparse": request.get("steps_sparse", 12),
            "cfg_sparse": request.get("cfg_sparse", 7.5),
            "steps_slat": request.get("steps_slat", 12),
            "cfg_slat": request.get("cfg_slat", 3.0),
            "simplify": request.get("simplify", 0.95),
            "texture_size": request.get("texture_size", 1024),
            "fps": request.get("fps", 30),
        }

        return images, params

    def predict(self, inputs):
        """Run the pipeline, export/upload the GLB, and render a preview."""
        images, params = inputs
        # "single" uses only the first image; any other value runs multi-image.
        if params["model"] == "single":
            outputs = self.pipeline.run(
                images[0],
                seed=params["seed"],
                sparse_structure_sampler_params={
                    "steps": params["steps_sparse"],
                    "cfg_strength": params["cfg_sparse"],
                },
                slat_sampler_params={
                    "steps": params["steps_slat"],
                    "cfg_strength": params["cfg_slat"],
                },
            )
        else:
            outputs = self.pipeline.run_multi_image(
                images,
                seed=params['seed'],
                sparse_structure_sampler_params={
                    "steps": params['steps_sparse'],
                    "cfg_strength": params['cfg_sparse'],
                },
                slat_sampler_params={
                    "steps": params['steps_slat'],
                    "cfg_strength": params['cfg_slat'],
                },
            )

        # video_path = self.upload_video(outputs, params)

        minio_glb_path, local_glb_path = self.upload_glb(outputs, params)

        glb_info = analyze_mesh(local_glb_path)

        # NOTE(review): this path is effectively unused —
        # get_static_model_image() generates its own local path internally;
        # confirm and remove one of the two.
        local_static_model_image_path = os.path.join("glb_output", generate_unique_name("static_model_image.png"))
        static_model_image = self.get_static_model_image(model_path=local_glb_path, output_path=local_static_model_image_path)

        return {
            "glb_path": minio_glb_path,
            "glb_static_img_path": static_model_image,
            "glb_info": glb_info,
        }

    def encode_response(self, output):
        # predict() already returns a JSON-serializable dict.
        return output

    def upload_video(self, outputs, params):
        """Render gaussian/radiance-field/mesh turntable videos and upload them.

        Currently unused by predict() (call is commented out there).
        """
        gaussian_name = f"3d_result/video/{params['file_name']}-gaussian.mp4"
        radiance_field_name = f"3d_result/video/{params['file_name']}-radiance_field.mp4"
        mesh_name = f"3d_result/video/{params['file_name']}-mesh.mp4"

        # gaussian video
        video = render_utils.render_video(outputs["gaussian"][0])["color"]
        buffer = BytesIO()
        imageio.mimsave(buffer, video, format="mp4", fps=params['fps'])
        gaussian_video_path = upload_bytes(
            buffer.getvalue(),
            gaussian_name,
            "video/mp4",
        )

        # radiance field video
        video = render_utils.render_video(outputs["radiance_field"][0])["color"]
        buffer = BytesIO()
        imageio.mimsave(buffer, video, format="mp4", fps=params['fps'])
        radiance_field_video_path = upload_bytes(
            buffer.getvalue(),
            radiance_field_name,
            "video/mp4",
        )

        # mesh video (normals rather than color)
        video = render_utils.render_video(outputs["mesh"][0])["normal"]
        buffer = BytesIO()
        imageio.mimsave(buffer, video, format="mp4", fps=params['fps'])
        mesh_path = upload_bytes(
            buffer.getvalue(),
            mesh_name,
            "video/mp4",
        )

        return {
            "gaussian": gaussian_video_path,
            "radiance_field": radiance_field_video_path,
            "mesh": mesh_path
        }

    def upload_glb(self, outputs, params):
        """Bake outputs to a textured GLB, save it locally, upload to MinIO.

        Returns:
            (minio_path, local_path) of the exported GLB.
        """
        file_name = f"3d_result/glb/{params['file_name']}.glb"
        local_glb_path = os.path.join("glb_output", generate_unique_name("sample.glb"))
        out_dir = os.path.dirname(local_glb_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        glb = postprocessing_utils.to_glb(
            outputs["gaussian"][0],
            outputs["mesh"][0],
            simplify=params['simplify'],
            texture_size=params['texture_size'],
        )

        glb.export(
            file_obj=local_glb_path,
            file_type="glb"
        )

        glb_path = upload_local_file(
            local_glb_path,
            file_name,
            "application/octet-stream",
        )
        return glb_path, local_glb_path

    def upload_ply(self, outputs, params):
        """Save the gaussian output as a PLY and upload it to MinIO."""
        file_name = f"3d_result/ply/{params['file_name']}.ply"

        with tempfile.NamedTemporaryFile(suffix=".ply") as tmp:
            outputs["gaussian"][0].save_ply(tmp.name)
            tmp.seek(0)

            ply_path = upload_bytes(
                tmp.read(),
                file_name,
                "application/octet-stream",
            )
        return {"ply": ply_path}

    def get_static_model_image(self, model_path, output_path):
        """Render a preview PNG for model_path and upload it; returns MinIO path.

        NOTE(review): the output_path parameter is ignored — a fresh local
        path is generated below; confirm intent and remove one of the two.
        """
        local_static_model_image_path = os.path.join(
            "glb_output",
            generate_unique_name("static_model_image.png")
        )

        print(f"model_path : {model_path}")
        print(f"local_static_model_image_path :{local_static_model_image_path}")
        output_path = render_glb_preview(model_path, local_static_model_image_path)

        static_model_image = self.upload_local_file(
            output_path,
            "png"
        )

        print(f"Saved to {static_model_image}")
        return static_model_image

    def upload_local_file(self, local_path, type):
        """
        Generic upload helper supporting SVG, PNG, OBJ, etc.

        Returns "<bucket>/<object_name>" on success, None on any failure.
        """
        object_name = f"3d_result/{type}/{uuid.uuid4().hex}.{type}"
        if not os.path.exists(local_path):
            print(f"错误: 文件 {local_path} 不存在")
            return None

        # Infer the Content-Type from the file extension,
        # e.g. .svg -> image/svg+xml, .png -> image/png.
        content_type, _ = mimetypes.guess_type(local_path)
        if content_type is None:
            content_type = "application/octet-stream"

        try:
            minio_client.fput_object(
                bucket_name=MINIO_BUCKET,
                object_name=object_name,
                file_path=local_path,
                content_type=content_type
            )
            print(f"成功上传 [{content_type}]: {object_name}")
            return f"{MINIO_BUCKET}/{object_name}"
        except Exception as e:
            print(f"上传失败: {e}")
            return None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelToThreeViews(ls.LitAPI):
    """LitServe endpoint: GLB in MinIO -> three-view drawing via OBJ/STEP/SVG."""

    def setup(self, device):
        # No model to load — the pipeline is pure file conversion.
        pass

    def upload_local_file(self, local_path, type):
        """
        Generic upload helper supporting SVG, PNG, OBJ, etc.

        NOTE(review): duplicated verbatim from TrellisAPI — consider moving
        to a shared module-level helper.
        """
        object_name = f"3d_result/{type}/{uuid.uuid4().hex}.{type}"
        if not os.path.exists(local_path):
            print(f"错误: 文件 {local_path} 不存在")
            return None

        # Infer the Content-Type from the file extension,
        # e.g. .svg -> image/svg+xml, .png -> image/png.
        content_type, _ = mimetypes.guess_type(local_path)
        if content_type is None:
            content_type = "application/octet-stream"

        try:
            minio_client.fput_object(
                bucket_name=MINIO_BUCKET,
                object_name=object_name,
                file_path=local_path,
                content_type=content_type
            )
            print(f"成功上传 [{content_type}]: {object_name}")
            return f"{MINIO_BUCKET}/{object_name}"
        except Exception as e:
            print(f"上传失败: {e}")
            return None

    def predict(self, request):
        """Convert a MinIO-hosted GLB: download -> OBJ -> STEP -> SVG -> upload."""
        minio_glb_path = request['minio_glb_path']

        # Working directories for the intermediate conversion artifacts.
        work_dir = f"glb_to_obj"
        os.makedirs(work_dir, exist_ok=True)

        glb_path = os.path.join(work_dir, f"model{uuid.uuid4().hex}.glb")
        step_dir = os.path.join(work_dir, "step")
        svg_dir = os.path.join(work_dir, "svg")
        os.makedirs(step_dir, exist_ok=True)
        os.makedirs(svg_dir, exist_ok=True)
        print(f"""
入参阶段:
input glb-obj minio-path:{minio_glb_path},\n
work_dir : {work_dir},glb_path : {glb_path},step_dir : {step_dir},svg_dir : {svg_dir}\n
""")
        print("=" * 10)

        print(f" 第一阶段 下载glb文件: ")
        # Stage 1: download the GLB from MinIO.
        glb_result = download_from_minio(object_path=minio_glb_path, local_path=glb_path)
        print(f" 下载结果 : {glb_result} \n")
        print("=" * 10)

        print(f" 第二阶段 glb -> obj: ")
        # Stage 2: convert GLB -> OBJ.
        obj_result = glb_to_obj(glb_result)
        print(f" glb -> obj 结果 : {obj_result} \n")
        print("=" * 10)

        print(f" 第三阶段 obj -> step: ")
        # Stage 3: convert OBJ -> STEP.
        step_result = obj_to_step(
            input_obj=obj_result,
            output_dir=step_dir,
            script_path="1_obj_to_step.py"
        )
        print(f" obj -> step 结果 : {step_result} \n")
        print("=" * 10)

        print(f" 第四阶段 step -> svg: ")
        # Stage 4: convert STEP -> SVG (also yields a combined PNG).
        combined_svg, combined_png = step_to_svg(
            step_path=step_result,
            out_dir=svg_dir
        )
        print(f" step -> svg 结果 : {combined_svg} \n")
        print("=" * 10)

        # Stage 5: upload the result.
        # NOTE(review): this uploads the PNG under an .svg object name/type —
        # confirm whether combined_svg was intended here.
        minio_svg_path = self.upload_local_file(combined_png, "svg")

        return {"minio_svg_path": minio_svg_path}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Two endpoints served by one LitServer instance on port 8120.
    trellis_api = TrellisAPI(api_path="/canvas/img_to_3D")
    model_to_three_api = ModelToThreeViews(api_path="/canvas/3d_to_3views")
    server = ls.LitServer([
        trellis_api,
        model_to_three_api])
    server.run(port=8120)
|
||||||
95
single_image_to_3D.py
Normal file
95
single_image_to_3D.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from trellis.pipelines import TrellisImageTo3DPipeline
|
||||||
|
from trellis.utils import render_utils, postprocessing_utils
|
||||||
|
|
||||||
|
def build_parser():
    """Build the CLI argument parser for single-image -> 3D generation."""
    parser = argparse.ArgumentParser("TRELLIS CLI: single image -> 3D")

    parser.add_argument("-i", "--image", required=True, help="Input image path")
    parser.add_argument("-o", "--out_dir", default="trellis_out", help="Output directory")

    # Sampler settings.
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--steps_sparse", type=int, default=12)
    parser.add_argument("--cfg_sparse", type=float, default=7.5)
    parser.add_argument("--steps_slat", type=int, default=12)
    parser.add_argument("--cfg_slat", type=float, default=3.0)

    # GLB post-processing.
    parser.add_argument("--simplify", type=float, default=0.95)
    parser.add_argument("--texture_size", type=int, default=1024)

    # Boolean toggles: each defaults to on and has a --no-* opt-out flag.
    for toggle in ("export_glb", "save_ply", "save_video"):
        parser.add_argument(f"--{toggle}", dest=toggle, action="store_true", default=True)
        parser.add_argument(f"--no-{toggle}", dest=toggle, action="store_false")

    # Turntable video settings.
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--video_gs_name", type=str, default="sample_gs.mp4")
    parser.add_argument("--video_rf_name", type=str, default="sample_rf.mp4")
    parser.add_argument("--video_mesh_name", type=str, default="sample_mesh.mp4")

    return parser
|
||||||
|
|
||||||
|
def main():
    """Run the single-image -> 3D pipeline from the command line."""
    args = build_parser().parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Optional env: 'native' spconv algorithm; setdefault keeps any override.
    os.environ.setdefault("SPCONV_ALGO", "native")

    pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
    pipeline.cuda()

    image = Image.open(args.image)

    # Two-stage generation: sparse-structure sampling, then structured latent.
    outputs = pipeline.run(
        image,
        seed=args.seed,
        sparse_structure_sampler_params={
            "steps": args.steps_sparse,
            "cfg_strength": args.cfg_sparse,
        },
        slat_sampler_params={
            "steps": args.steps_slat,
            "cfg_strength": args.cfg_slat,
        },
    )

    if args.save_video:
        # Imported lazily so video deps are only needed when saving videos.
        import imageio

        video = render_utils.render_video(outputs["gaussian"][0])["color"]
        imageio.mimsave(os.path.join(args.out_dir, args.video_gs_name), video, fps=args.fps)

        video = render_utils.render_video(outputs["radiance_field"][0])["color"]
        imageio.mimsave(os.path.join(args.out_dir, args.video_rf_name), video, fps=args.fps)

        # Mesh video uses the normal channel rather than color.
        video = render_utils.render_video(outputs["mesh"][0])["normal"]
        imageio.mimsave(os.path.join(args.out_dir, args.video_mesh_name), video, fps=args.fps)

    if args.export_glb:
        # Bake gaussians + mesh into a textured GLB.
        glb = postprocessing_utils.to_glb(
            outputs["gaussian"][0],
            outputs["mesh"][0],
            simplify=args.simplify,
            texture_size=args.texture_size,
        )
        glb.export(os.path.join(args.out_dir, "sample.glb"))

    if args.save_ply:
        outputs["gaussian"][0].save_ply(os.path.join(args.out_dir, "sample.ply"))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point.
    main()
|
||||||
193
trellis/modules/sparse/attention/serialized_attn.py
Executable file
193
trellis/modules/sparse/attention/serialized_attn.py
Executable file
@@ -0,0 +1,193 @@
|
|||||||
|
from typing import *
|
||||||
|
from enum import Enum
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from .. import SparseTensor
|
||||||
|
from .. import DEBUG, ATTN
|
||||||
|
|
||||||
|
if ATTN == 'xformers':
|
||||||
|
import xformers.ops as xops
|
||||||
|
elif ATTN == 'flash_attn':
|
||||||
|
import flash_attn
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown attention module: {ATTN}")
|
||||||
|
|
||||||
|
|
||||||
|
# Public API of this module.
__all__ = [
    'sparse_serialized_scaled_dot_product_self_attention',
]
|
||||||
|
|
||||||
|
|
||||||
|
class SerializeMode(Enum):
    """Space-filling-curve orderings used to serialize sparse voxel coords."""
    Z_ORDER = 0
    Z_ORDER_TRANSPOSED = 1
    HILBERT = 2
    HILBERT_TRANSPOSED = 3
|
||||||
|
|
||||||
|
|
||||||
|
# The four supported serialization orders (see SerializeMode).
SerializeModes = [
    SerializeMode.Z_ORDER,
    SerializeMode.Z_ORDER_TRANSPOSED,
    SerializeMode.HILBERT,
    SerializeMode.HILBERT_TRANSPOSED
]
|
||||||
|
|
||||||
|
|
||||||
|
def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Orders each batch's points along a space-filling curve and partitions the
    ordered sequence into fixed-size windows (padding by wrapping around when
    a batch does not divide evenly into windows).

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        fwd_indices (torch.Tensor): gather indices mapping original rows to the
            serialized (windowed) layout.
        bwd_indices (torch.Tensor): gather indices mapping serialized rows back
            to the original layout.
        seq_lens (List[int]): length of each serialized window.
        seq_batch_indices (List[int]): batch index of each window.
    """
    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    # Running start offset of each window in the serialized layout.
    offsets = [0]

    # Lazy import: vox2seq is only needed here. NOTE(review): the import binds a
    # function-local name, so `globals()` never sees it and the (cached) import
    # re-runs on every call unless vox2seq was imported at module level elsewhere.
    if 'vox2seq' not in globals():
        import vox2seq

    # Serialize the input
    serialize_coords = tensor.coords[:, 1:].clone()  # drop the batch column
    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
    if serialize_mode == SerializeMode.Z_ORDER:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.Z_ORDER_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='z_order', permute=[1, 0, 2])
    elif serialize_mode == SerializeMode.HILBERT:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[0, 1, 2])
    elif serialize_mode == SerializeMode.HILBERT_TRANSPOSED:
        code = vox2seq.encode(serialize_coords, mode='hilbert', permute=[1, 0, 2])
    else:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")

    # Process each batch item (`s` is the slice of this item's rows).
    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        num_windows = (num_points + window_size - 1) // window_size  # ceil division
        # Fractional "valid" window length so valid spans tile num_points exactly.
        valid_window_size = num_points / num_windows
        # Permutation sorting this batch's points along the curve.
        to_ordered = torch.argsort(code[s.start:s.stop])
        if num_windows == 1:
            # Single window: no padding, forward/backward maps are inverse permutations.
            fwd_indices.append(to_ordered)
            bwd_indices.append(torch.zeros_like(to_ordered).scatter_(0, to_ordered, torch.arange(num_points, device=tensor.device)))
            fwd_indices[-1] += s.start
            bwd_indices[-1] += offsets[-1]
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            offsets.append(offsets[-1] + seq_lens[-1])
        else:
            # Partition the input
            offset = 0
            # Window centers and valid-span boundaries in serialized order.
            mids = [(i + 0.5) * valid_window_size + shift_sequence for i in range(num_windows)]
            split = [math.floor(i * valid_window_size + shift_sequence) for i in range(num_windows + 1)]
            bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
            for i in range(num_windows):
                mid = mids[i]
                valid_start = split[i]
                valid_end = split[i + 1]
                # Pad each window to exactly window_size, wrapping via modulo.
                padded_start = math.floor(mid - 0.5 * window_size)
                padded_end = padded_start + window_size
                fwd_indices.append(to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points])
                offset += valid_start - padded_start
                # Only the valid (non-padding) span contributes to the backward map.
                bwd_index.scatter_(0, fwd_indices[-1][valid_start-padded_start:valid_end-padded_start], torch.arange(offset, offset + valid_end - valid_start, device=tensor.device))
                offset += padded_end - valid_start
                fwd_indices[-1] += s.start  # shift to global row indices after the scatter above
            seq_lens.extend([window_size] * num_windows)
            seq_batch_indices.extend([bi] * num_windows)
            bwd_indices.append(bwd_index + offsets[-1])
            offsets.append(offsets[-1] + num_windows * window_size)

    fwd_indices = torch.cat(fwd_indices)
    bwd_indices = torch.cat(bwd_indices)

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    Points are ordered along a space-filling curve, grouped into windows of
    `window_size`, and attention is computed independently within each window
    using the configured backend (xformers or flash_attn).

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: `qkv` with its features replaced by the attention output
        ([T, H, C], same coordinate layout as the input).
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # Serialization indices are deterministic per (mode, window, shifts), so they
    # are cached on the tensor's spatial cache and reused across layers/steps.
    serialization_spatial_cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]   # total serialized (padded) rows
    T = qkv.feats.shape[0]     # original rows (NOTE: M/T kept for shape documentation)
    H = qkv.feats.shape[2]     # attention heads
    C = qkv.feats.shape[3]     # channels per head

    qkv_feats = qkv.feats[fwd_indices]      # [M, 3, H, C]

    if DEBUG:
        # Sanity-check that every window contains rows from a single batch item.
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            assert (qkv_coords[start:start+seq_lens[i], 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        # Uniform windows: batch them as a dense [B, N, ...] attention call.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)                              # [M, H, C]
    else:
        # Ragged windows: use variable-length attention with per-window masking.
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)                       # [M, H, C]
            q = q.unsqueeze(0)                                      # [1, M, H, C]
            k = k.unsqueeze(0)                                      # [1, M, H, C]
            v = v.unsqueeze(0)                                      # [1, M, H, C]
            # Block-diagonal mask restricts attention to within each window.
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
        elif ATTN == 'flash_attn':
            # Cumulative sequence lengths delimiting windows, as flash_attn expects.
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]

    # Undo the serialization: gather back to the original row order (drops padding).
    out = out[bwd_indices]      # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
|
||||||
118
trellis/renderers/sh_utils.py
Executable file
118
trellis/renderers/sh_utils.py
Executable file
@@ -0,0 +1,118 @@
|
|||||||
|
# Copyright 2021 The PlenOctree Authors.
|
||||||
|
# Redistribution and use in source and binary forms, with or without
|
||||||
|
# modification, are permitted provided that the following conditions are met:
|
||||||
|
#
|
||||||
|
# 1. Redistributions of source code must retain the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer.
|
||||||
|
#
|
||||||
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
# this list of conditions and the following disclaimer in the documentation
|
||||||
|
# and/or other materials provided with the distribution.
|
||||||
|
#
|
||||||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||||
|
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
# POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Hardcoded real spherical-harmonic basis constants for degrees 0-4,
# used by eval_sh below. Cl holds the coefficients for degree l.
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]
|
||||||
|
|
||||||
|
|
||||||
|
def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Currently, 0-4 supported
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
        dirs: jnp.ndarray unit directions [..., 3]
    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    # Number of coefficients required for this degree.
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff

    # Degree 0: constant term.
    result = C0 * sh[..., 0]
    if deg > 0:
        # Keep a trailing singleton dim so products broadcast against sh[..., i].
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                  C1 * y * sh[..., 1] +
                  C1 * z * sh[..., 2] -
                  C1 * x * sh[..., 3])

        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                      C2[0] * xy * sh[..., 4] +
                      C2[1] * yz * sh[..., 5] +
                      C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                      C2[3] * xz * sh[..., 7] +
                      C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                          C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                          C3[1] * xy * z * sh[..., 10] +
                          C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
                          C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                          C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                          C3[5] * z * (xx - yy) * sh[..., 14] +
                          C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                              C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                              C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                              C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                              C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                              C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                              C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                              C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                              C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result
|
||||||
|
|
||||||
|
def RGB2SH(rgb):
    """Convert RGB values in [0, 1] to degree-0 SH coefficients."""
    centered = rgb - 0.5
    return centered / C0
|
||||||
|
|
||||||
|
def SH2RGB(sh):
    """Convert degree-0 SH coefficients back to RGB values in [0, 1]."""
    return 0.5 + sh * C0
|
||||||
274
trellis/representations/mesh/flexicubes/examples/render.py
Normal file
274
trellis/representations/mesh/flexicubes/examples/render.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
from ipywidgets import interactive, HBox, VBox, FloatLogSlider, IntSlider
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import nvdiffrast.torch as dr
|
||||||
|
import kaolin as kal
|
||||||
|
import util
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Functions adapted from https://github.com/NVlabs/nvdiffrec
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
def get_random_camera_batch(batch_size, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
    """Sample a batch of random cameras on a sphere around the origin.

    Args:
        batch_size: number of cameras to sample.
        fovy: vertical field of view in radians.
        iter_res: (height, width) render resolution.
        cam_near_far: (near, far) clip planes.
        cam_radius: nominal distance of the camera from the origin.
        device: target device for the non-kaolin path.
            NOTE(review): the kaolin path hard-codes 'cuda' and ignores `device`.
        use_kaolin: if True return a kaolin Camera; otherwise (mv, mvp) matrices.

    Returns:
        A kaolin Camera batch, or a tuple (mv, mvp) of stacked matrices.
    """
    if use_kaolin:
        # Uniformly sample azimuth/elevation on the sphere at radius cam_radius.
        camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(
            *kal.ops.random.sample_spherical_coords((batch_size,), azimuth_low=0., azimuth_high=math.pi * 2,
                                                    elevation_low=-math.pi / 2., elevation_high=math.pi / 2., device='cuda'),
            cam_radius
        ), dim=-1)
        return kal.render.camera.Camera.from_args(
            # Jitter the eye position by up to +-0.25 per axis.
            eye=camera_pos + torch.rand((batch_size, 1), device='cuda') * 0.5 - 0.25,
            at=torch.zeros(batch_size, 3),
            up=torch.tensor([[0., 1., 0.]]),
            fov=fovy,
            near=cam_near_far[0], far=cam_near_far[1],
            height=iter_res[0], width=iter_res[1],
            device='cuda'
        )
    else:
        def get_random_camera():
            # Random rotation + small translation, composed with a perspective projection.
            proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])
            mv = util.translate(0, 0, -cam_radius) @ util.random_rotation_translation(0.25)
            mvp = proj_mtx @ mv
            return mv, mvp
        mv_batch = []
        mvp_batch = []
        for i in range(batch_size):
            mv, mvp = get_random_camera()
            mv_batch.append(mv)
            mvp_batch.append(mvp)
        return torch.stack(mv_batch).to(device), torch.stack(mvp_batch).to(device)
|
||||||
|
|
||||||
|
def get_rotate_camera(itr, fovy = np.deg2rad(45), iter_res=[512,512], cam_near_far=[0.1, 1000.0], cam_radius=3.0, device="cuda", use_kaolin=True):
    """Return a camera orbiting the origin; one full turn every 10 iterations.

    Args:
        itr: iteration index driving the rotation angle.
        fovy, iter_res, cam_near_far, cam_radius: camera parameters.
        device: target device for the non-kaolin path.
            NOTE(review): the kaolin path hard-codes 'cuda' and ignores `device`.
        use_kaolin: if True return a kaolin Camera; otherwise (mv, mvp) matrices.
    """
    if use_kaolin:
        ang = (itr / 10) * np.pi * 2
        # Fixed elevation of 0.4 rad; negative radius mirrors the non-kaolin convention.
        camera_pos = torch.stack(kal.ops.coords.spherical2cartesian(torch.tensor(ang), torch.tensor(0.4), -torch.tensor(cam_radius)))
        return kal.render.camera.Camera.from_args(
            eye=camera_pos,
            at=torch.zeros(3),
            up=torch.tensor([0., 1., 0.]),
            fov=fovy,
            near=cam_near_far[0], far=cam_near_far[1],
            height=iter_res[0], width=iter_res[1],
            device='cuda'
        )
    else:
        proj_mtx = util.perspective(fovy, iter_res[1] / iter_res[0], cam_near_far[0], cam_near_far[1])

        # Smooth rotation for display.
        ang = (itr / 10) * np.pi * 2
        mv = util.translate(0, 0, -cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
        mvp = proj_mtx @ mv
        return mv.to(device), mvp.to(device)
|
||||||
|
|
||||||
|
# Shared nvdiffrast OpenGL rasterization context, created once at import time.
glctx = dr.RasterizeGLContext()
|
||||||
|
def render_mesh(mesh, camera, iter_res, return_types=("mask", "depth"), white_bg=False, wireframe_thickness=0.4):
    """Rasterize a mesh with nvdiffrast under a kaolin camera.

    Args:
        mesh: mesh with `vertices`, `faces`, and `face_normals` attributes.
        camera: kaolin Camera providing extrinsics and a projection matrix.
        iter_res: (height, width) rasterization resolution.
        return_types: subset of {"mask", "depth", "wireframe", "normals"}.
        white_bg: if True, composite each output over a white background.
        wireframe_thickness: barycentric threshold for the wireframe overlay.

    Returns:
        dict mapping each requested type to its rendered image tensor.

    Raises:
        ValueError: if `return_types` contains an unknown entry (previously
        this fell through to a NameError on `img`).
    """
    vertices_camera = camera.extrinsics.transform(mesh.vertices)
    # NOTE(review): result is unused; kept for parity with the original — confirm
    # before removing.
    face_vertices_camera = kal.ops.mesh.index_vertices_by_faces(
        vertices_camera, mesh.faces
    )

    # Projection: nvdiffrast takes clip coordinates as input to apply barycentric
    # perspective correction. Using `camera.intrinsics.transform(vertices_camera)`
    # would return normalized device coordinates instead.
    proj = camera.projection_matrix().unsqueeze(1)
    proj[:, :, 1, 1] = -proj[:, :, 1, 1]  # flip Y for nvdiffrast's clip-space convention
    homogeneous_vecs = kal.render.camera.up_to_homogeneous(
        vertices_camera
    )
    vertices_clip = (proj @ homogeneous_vecs.unsqueeze(-1)).squeeze(-1)
    faces_int = mesh.faces.int()

    rast, _ = dr.rasterize(
        glctx, vertices_clip, faces_int, iter_res)

    out_dict = {}
    # `return_type` instead of `type`: avoid shadowing the builtin.
    for return_type in return_types:
        if return_type == "mask":
            img = dr.antialias((rast[..., -1:] > 0).float(), rast, vertices_clip, faces_int)
        elif return_type == "depth":
            img = dr.interpolate(homogeneous_vecs, rast, faces_int)[0]
        elif return_type == "wireframe":
            # A pixel is "wire" when any barycentric coordinate is near 0 or 1.
            img = torch.logical_or(
                torch.logical_or(rast[..., 0] < wireframe_thickness, rast[..., 1] < wireframe_thickness),
                (rast[..., 0] + rast[..., 1]) > (1. - wireframe_thickness)
            ).unsqueeze(-1)
        elif return_type == "normals":
            img = dr.interpolate(
                mesh.face_normals.reshape(len(mesh), -1, 3), rast,
                torch.arange(mesh.faces.shape[0] * 3, device='cuda', dtype=torch.int).reshape(-1, 3)
            )[0]
        else:
            raise ValueError(f"Unknown return type: {return_type}")
        if white_bg:
            bg = torch.ones_like(img)
            alpha = (rast[..., -1:] > 0).float()
            img = torch.lerp(bg, img, alpha)
        out_dict[return_type] = img

    return out_dict
|
||||||
|
|
||||||
|
def render_mesh_paper(mesh, mv, mvp, iter_res, return_types=("mask", "depth"), white_bg=False):
    '''
    The rendering function used to produce the results in the paper.

    Args:
        mesh: mesh with `vertices`, `faces`, `nrm`, and `v_nrm` attributes.
        mv: model-view matrix (camera extrinsics).
        mvp: model-view-projection matrix.
        iter_res: (height, width) rasterization resolution.
        return_types: subset of {"mask", "depth", "normal", "vertex_normal"}.
        white_bg: if True, composite each output over a white background.

    Raises:
        ValueError: if `return_types` contains an unknown entry (previously
        this fell through to a NameError on `img`).
    '''
    faces_int = mesh.faces.int()  # hoisted: was recomputed in every branch
    v_pos_clip = util.xfm_points(mesh.vertices.unsqueeze(0), mvp)  # Rotate it to camera coordinates
    # NOTE(review): creates a fresh GL context per call, unlike render_mesh's
    # shared `glctx` — kept to preserve the paper-reproduction behavior.
    rast, db = dr.rasterize(
        dr.RasterizeGLContext(), v_pos_clip, faces_int, iter_res)

    out_dict = {}
    # `return_type` instead of `type`: avoid shadowing the builtin.
    for return_type in return_types:
        if return_type == "mask":
            img = dr.antialias((rast[..., -1:] > 0).float(), rast, v_pos_clip, faces_int)
        elif return_type == "depth":
            v_pos_cam = util.xfm_points(mesh.vertices.unsqueeze(0), mv)
            img, _ = util.interpolate(v_pos_cam, rast, faces_int)
        elif return_type == "normal":
            normal_indices = (torch.arange(0, mesh.nrm.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
            img, _ = util.interpolate(mesh.nrm.unsqueeze(0).contiguous(), rast, normal_indices.int())
        elif return_type == "vertex_normal":
            img, _ = util.interpolate(mesh.v_nrm.unsqueeze(0).contiguous(), rast, faces_int)
            img = dr.antialias((img + 1) * 0.5, rast, v_pos_clip, faces_int)
        else:
            raise ValueError(f"Unknown return type: {return_type}")
        if white_bg:
            bg = torch.ones_like(img)
            alpha = (rast[..., -1:] > 0).float()
            img = torch.lerp(bg, img, alpha)
        out_dict[return_type] = img
    return out_dict
|
||||||
|
|
||||||
|
class SplitVisualizer():
    """Interactive side-by-side turntable viewer comparing two meshes.

    Renders `lh_mesh` and `rh_mesh` next to each other with normal-colored
    shading and a wireframe overlay, inside a kaolin IpyTurntableVisualizer
    widget (notebook only).
    """

    def __init__(self, lh_mesh, rh_mesh, height, width):
        self.lh_mesh = lh_mesh  # mesh shown on the left half
        self.rh_mesh = rh_mesh  # mesh shown on the right half
        self.height = height    # per-view render height (pixels)
        self.width = width      # per-view render width (pixels)
        self.wireframe_thickness = 0.4  # overlay threshold, adjustable via slider

    def render(self, camera):
        """Render both meshes under `camera` and concatenate them horizontally.

        Returns:
            dict with 'img' (uint8 composite image) and 'normals'.
        """
        lh_outputs = render_mesh(
            self.lh_mesh, camera, (self.height, self.width),
            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
        )
        rh_outputs = render_mesh(
            self.rh_mesh, camera, (self.height, self.width),
            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
        )
        # Concatenate along width: permute to (W, H, C), cat, permute back.
        outputs = {
            k: torch.cat(
                [lh_outputs[k][0].permute(1, 0, 2), rh_outputs[k][0].permute(1, 0, 2)],
                dim=0
            ).permute(1, 0, 2) for k in ["normals", "wireframe"]
        }
        return {
            # Normals remapped from [-1, 1] to [0, 255], masked by the wireframe.
            'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8),
            'normals': outputs['normals']
        }

    def show(self, init_camera):
        """Display the interactive widget with a wireframe-thickness slider."""
        visualizer = kal.visualize.IpyTurntableVisualizer(
            self.height, self.width * 2, copy.deepcopy(init_camera), self.render,
            max_fps=24, world_up_axis=1)

        def slider_callback(new_wireframe_thickness):
            """ipywidgets sliders callback"""
            with visualizer.out: # This is in case of bug
                self.wireframe_thickness = new_wireframe_thickness
                # this is how we request a new update
                visualizer.render_update()

        wireframe_thickness_slider = FloatLogSlider(
            value=self.wireframe_thickness,
            base=10,
            min=-3,
            max=-0.4,
            step=0.1,
            description='wireframe_thickness',
            continuous_update=True,
            readout=True,
            readout_format='.3f',
        )

        interactive_slider = interactive(
            slider_callback,
            new_wireframe_thickness=wireframe_thickness_slider,
        )

        full_output = VBox([visualizer.canvas, interactive_slider])
        # NOTE(review): relies on the notebook-global `display` — IPython only.
        display(full_output, visualizer.out)
|
||||||
|
|
||||||
|
class TimelineVisualizer():
    """Interactive turntable viewer stepping through a sequence of meshes.

    Shows one mesh at a time (selected by an index slider) with
    normal-colored shading and a wireframe overlay, inside a kaolin
    IpyTurntableVisualizer widget (notebook only).
    """

    def __init__(self, meshes, height, width):
        self.meshes = meshes  # sequence of meshes to browse through
        self.height = height  # render height (pixels)
        self.width = width    # render width (pixels)
        self.wireframe_thickness = 0.4  # overlay threshold, adjustable via slider
        self.idx = len(meshes) - 1      # start at the last mesh in the timeline

    def render(self, camera):
        """Render the currently selected mesh under `camera`.

        Returns:
            dict with 'img' (uint8 composite image) and 'normals'.
        """
        outputs = render_mesh(
            self.meshes[self.idx], camera, (self.height, self.width),
            return_types=["normals", "wireframe"], wireframe_thickness=self.wireframe_thickness
        )

        return {
            # Normals remapped from [-1, 1] to [0, 255], masked by the wireframe.
            'img': (outputs['wireframe'] * ((outputs['normals'] + 1.) / 2.) * 255).to(torch.uint8)[0],
            'normals': outputs['normals'][0]
        }

    def show(self, init_camera):
        """Display the interactive widget with thickness and index sliders."""
        visualizer = kal.visualize.IpyTurntableVisualizer(
            self.height, self.width, copy.deepcopy(init_camera), self.render,
            max_fps=24, world_up_axis=1)

        def slider_callback(new_wireframe_thickness, new_idx):
            """ipywidgets sliders callback"""
            with visualizer.out: # This is in case of bug
                self.wireframe_thickness = new_wireframe_thickness
                self.idx = new_idx
                # this is how we request a new update
                visualizer.render_update()

        wireframe_thickness_slider = FloatLogSlider(
            value=self.wireframe_thickness,
            base=10,
            min=-3,
            max=-0.4,
            step=0.1,
            description='wireframe_thickness',
            continuous_update=True,
            readout=True,
            readout_format='.3f',
        )

        idx_slider = IntSlider(
            value=self.idx,
            min=0,
            max=len(self.meshes) - 1,
            description='idx',
            continuous_update=True,
            readout=True
        )

        interactive_slider = interactive(
            slider_callback,
            new_wireframe_thickness=wireframe_thickness_slider,
            new_idx=idx_slider
        )
        full_output = HBox([visualizer.canvas, interactive_slider])
        # NOTE(review): relies on the notebook-global `display` — IPython only.
        display(full_output, visualizer.out)
|
||||||
30
trellis/utils/random_utils.py
Normal file
30
trellis/utils/random_utils.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# First 16 primes, used as bases for the low-discrepancy Halton sequence below.
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
|
||||||
|
|
||||||
|
def radical_inverse(base, n):
    """Return the radical inverse of `n` in the given `base`.

    Mirrors the digits of `n` (in base `base`) about the radix point,
    yielding the van der Corput sequence value in [0, 1).
    """
    value = 0.0
    inv = 1.0 / base
    scale = inv  # weight of the current digit: inv ** (k + 1)
    while n:
        value += (n % base) * scale
        n //= base
        scale *= inv
    return value
|
||||||
|
|
||||||
|
def halton_sequence(dim, n):
    """Return the n-th point of the `dim`-dimensional Halton sequence.

    Uses the first `dim` primes as radical-inverse bases. The comprehension
    variable is renamed from `dim` to `d` — the original shadowed the
    parameter, which worked only by evaluation-order accident.
    """
    return [radical_inverse(PRIMES[d], n) for d in range(dim)]
|
||||||
|
|
||||||
|
def hammersley_sequence(dim, n, num_samples):
    """Return the n-th point of a `dim`-dimensional Hammersley set of
    `num_samples` points: first coordinate n/num_samples, the rest Halton."""
    point = [n / num_samples]
    point.extend(halton_sequence(dim - 1, n))
    return point
|
||||||
|
|
||||||
|
def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
    """Map the n-th 2D Hammersley point to spherical angles [phi, theta].

    Args:
        n: sample index.
        num_samples: total number of samples in the set.
        offset: (u, v) shift applied to the 2D point before mapping.
        remap: if True, warp u to redistribute samples
            (piecewise-linear remap around u = 0.25).

    Returns:
        [phi, theta]: azimuth in [0, 2*pi) and elevation in [-pi/2, pi/2].
    """
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    if remap:
        u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    # arccos(1 - 2u) gives an area-uniform polar angle; shift to elevation.
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]
|
||||||
120
trellis/utils/render_utils.py
Normal file
120
trellis/utils/render_utils.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
import utils3d
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..renderers import OctreeRenderer, GaussianRenderer, MeshRenderer
|
||||||
|
from ..representations import Octree, Gaussian, MeshExtractResult
|
||||||
|
from ..modules import sparse as sp
|
||||||
|
from .random_utils import sphere_hammersley_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, rs, fovs):
    """Build camera extrinsics/intrinsics from yaw, pitch, radius, and FOV.

    Args:
        yaws: yaw angle(s) in radians — scalar or list.
        pitchs: pitch angle(s) in radians — scalar or list (paired with yaws).
        rs: camera distance(s) from the origin — scalar broadcast to all views.
        fovs: field(s) of view in degrees — scalar broadcast to all views.

    Returns:
        (extrinsics, intrinsics): lists of CUDA tensors, or single tensors when
        `yaws` was a scalar.
    """
    is_list = isinstance(yaws, list)
    # Normalize everything to equal-length lists.
    if not is_list:
        yaws = [yaws]
        pitchs = [pitchs]
    if not isinstance(rs, list):
        rs = [rs] * len(yaws)
    if not isinstance(fovs, list):
        fovs = [fovs] * len(yaws)
    extrinsics = []
    intrinsics = []
    for yaw, pitch, r, fov in zip(yaws, pitchs, rs, fovs):
        fov = torch.deg2rad(torch.tensor(float(fov))).cuda()
        yaw = torch.tensor(float(yaw)).cuda()
        pitch = torch.tensor(float(pitch)).cuda()
        # Camera position on a sphere of radius r (z-up convention).
        orig = torch.tensor([
            torch.sin(yaw) * torch.cos(pitch),
            torch.cos(yaw) * torch.cos(pitch),
            torch.sin(pitch),
        ]).cuda() * r
        extr = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
        intr = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
        extrinsics.append(extr)
        intrinsics.append(intr)
    # Preserve scalar-in, scalar-out behavior.
    if not is_list:
        extrinsics = extrinsics[0]
        intrinsics = intrinsics[0]
    return extrinsics, intrinsics
|
||||||
|
|
||||||
|
|
||||||
|
def get_renderer(sample, **kwargs):
    """Construct a renderer configured for the representation type of `sample`.

    Args:
        sample: an Octree, Gaussian, or MeshExtractResult instance.
        **kwargs: rendering options (resolution, near, far, bg_color, ssaa,
            kernel_size) — unspecified options fall back to per-type defaults.

    Returns:
        A configured OctreeRenderer, GaussianRenderer, or MeshRenderer.

    Raises:
        ValueError: if `sample` is not one of the supported types.
    """
    if isinstance(sample, Octree):
        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 0.8)
        renderer.rendering_options.far = kwargs.get('far', 1.6)
        renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
        renderer.pipe.primitive = sample.primitive
    elif isinstance(sample, Gaussian):
        renderer = GaussianRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 0.8)
        renderer.rendering_options.far = kwargs.get('far', 1.6)
        renderer.rendering_options.bg_color = kwargs.get('bg_color', (0, 0, 0))
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 1)
        renderer.pipe.kernel_size = kwargs.get('kernel_size', 0.1)
        renderer.pipe.use_mip_gaussian = True
    elif isinstance(sample, MeshExtractResult):
        renderer = MeshRenderer()
        renderer.rendering_options.resolution = kwargs.get('resolution', 512)
        renderer.rendering_options.near = kwargs.get('near', 1)
        renderer.rendering_options.far = kwargs.get('far', 100)
        renderer.rendering_options.ssaa = kwargs.get('ssaa', 4)
    else:
        raise ValueError(f'Unsupported sample type: {type(sample)}')
    return renderer
|
||||||
|
|
||||||
|
|
||||||
|
def render_frames(sample, extrinsics, intrinsics, options={}, colors_overwrite=None, verbose=True, **kwargs):
    """Render one frame per (extrinsics, intrinsics) camera pair.

    Args:
        sample: representation to render (Octree, Gaussian, or MeshExtractResult).
        extrinsics: iterable of camera extrinsics.
        intrinsics: iterable of camera intrinsics (paired with extrinsics).
        options: renderer options forwarded to get_renderer.
            NOTE(review): mutable default — safe here since it is only read.
        colors_overwrite: optional color override, non-mesh renderers only.
        verbose: show a tqdm progress bar.

    Returns:
        dict of frame lists: 'normal' for meshes; 'color' and 'depth' otherwise
        (a depth entry is None when the renderer provides no depth output).
    """
    renderer = get_renderer(sample, **options)
    rets = {}
    for j, (extr, intr) in tqdm(enumerate(zip(extrinsics, intrinsics)), desc='Rendering', disable=not verbose):
        if isinstance(sample, MeshExtractResult):
            res = renderer.render(sample, extr, intr)
            if 'normal' not in rets: rets['normal'] = []
            # CHW float in [0, 1] -> HWC uint8.
            rets['normal'].append(np.clip(res['normal'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
        else:
            res = renderer.render(sample, extr, intr, colors_overwrite=colors_overwrite)
            if 'color' not in rets: rets['color'] = []
            if 'depth' not in rets: rets['depth'] = []
            rets['color'].append(np.clip(res['color'].detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8))
            # Prefer percentile depth when the renderer provides it.
            if 'percent_depth' in res:
                rets['depth'].append(res['percent_depth'].detach().cpu().numpy())
            elif 'depth' in res:
                rets['depth'].append(res['depth'].detach().cpu().numpy())
            else:
                rets['depth'].append(None)
    return rets
|
||||||
|
|
||||||
|
|
||||||
|
def render_video(sample, resolution=512, bg_color=(0, 0, 0), num_frames=300, r=2, fov=40, **kwargs):
    """Render a turntable video of `sample`.

    The camera makes one full yaw revolution while the pitch oscillates
    sinusoidally around 0.25 rad, at fixed radius `r` and field of view `fov`.

    Args:
        sample: representation to render.
        resolution: output resolution in pixels.
        bg_color: background color passed to the renderer.
        num_frames: number of frames in the revolution.
        r: camera distance from the origin.
        fov: field of view in degrees.
        **kwargs: forwarded to render_frames.

    Returns:
        dict of frame lists as produced by render_frames.
    """
    # Use the exact value of 2*pi (the original hard-coded 2 * 3.1415,
    # leaving a tiny gap at the end of the revolution).
    full_turn = 2 * np.pi
    yaws = torch.linspace(0, full_turn, num_frames)
    pitch = 0.25 + 0.5 * torch.sin(torch.linspace(0, full_turn, num_frames))
    yaws = yaws.tolist()
    pitch = pitch.tolist()
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitch, r, fov)
    return render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def render_multiview(sample, resolution=512, nviews=30):
    """Render `sample` from `nviews` quasi-uniform viewpoints on a sphere.

    Cameras are placed via the spherical Hammersley sequence at fixed
    radius 2 and 40-degree FOV.

    Returns:
        (colors, extrinsics, intrinsics): list of rendered color frames plus
        the camera parameters used for each view.
    """
    r = 2
    fov = 40
    # Low-discrepancy (phi, theta) camera directions.
    cams = [sphere_hammersley_sequence(i, nviews) for i in range(nviews)]
    yaws = [cam[0] for cam in cams]
    pitchs = [cam[1] for cam in cams]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaws, pitchs, r, fov)
    res = render_frames(sample, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': (0, 0, 0)})
    return res['color'], extrinsics, intrinsics
|
||||||
|
|
||||||
|
|
||||||
|
def render_snapshot(samples, resolution=512, bg_color=(0, 0, 0), offset=(-16 / 180 * np.pi, 20 / 180 * np.pi), r=10, fov=8, **kwargs):
    """Render four snapshot views (90 degrees apart in yaw) of `samples`.

    Args:
        samples: representation to render.
        resolution: output resolution in pixels.
        bg_color: background color passed to the renderer.
        offset: (yaw_offset, pitch) in radians applied to all four views.
        r: camera distance; fov: field of view in degrees (narrow by default).
        **kwargs: forwarded to render_frames.

    Returns:
        dict of frame lists as produced by render_frames.
    """
    yaw = [0, np.pi/2, np.pi, 3*np.pi/2]
    yaw_offset = offset[0]
    yaw = [y + yaw_offset for y in yaw]
    pitch = [offset[1] for _ in range(4)]
    extrinsics, intrinsics = yaw_pitch_r_fov_to_extrinsics_intrinsics(yaw, pitch, r, fov)
    return render_frames(samples, extrinsics, intrinsics, {'resolution': resolution, 'bg_color': bg_color}, **kwargs)
|
||||||
Reference in New Issue
Block a user