This commit is contained in:
zcr
2026-03-17 11:29:47 +08:00
parent a6d9bac6d0
commit 06659057c3
26 changed files with 3742 additions and 0 deletions

View File

@@ -0,0 +1,102 @@
{
"models": {
"denoiser": {
"name": "ElasticSLatFlowModel",
"args": {
"resolution": 64,
"in_channels": 8,
"out_channels": 8,
"model_channels": 1024,
"cond_channels": 1024,
"num_blocks": 24,
"num_heads": 16,
"mlp_ratio": 4,
"patch_size": 2,
"num_io_res_blocks": 2,
"io_block_channels": [128],
"pe_mode": "ape",
"qk_rms_norm": true,
"use_fp16": true
}
}
},
"dataset": {
"name": "ImageConditionedSLat",
"args": {
"latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
"min_aesthetic_score": 4.5,
"max_num_voxels": 32768,
"image_size": 518,
"normalization": {
"mean": [
-2.1687545776367188,
-0.004347046371549368,
-0.13352349400520325,
-0.08418072760105133,
-0.5271206498146057,
0.7238689064979553,
-1.1414450407028198,
1.2039363384246826
],
"std": [
2.377650737762451,
2.386378288269043,
2.124418020248413,
2.1748552322387695,
2.663944721221924,
2.371192216873169,
2.6217446327209473,
2.684523105621338
]
},
"pretrained_slat_dec": "microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
}
},
"trainer": {
"name": "ImageConditionedSparseFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 8,
"batch_split": 4,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 0.0001,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.5
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"p_uncond": 0.1,
"t_schedule": {
"name": "logitNormal",
"args": {
"mean": 1.0,
"std": 1.0
}
},
"sigma_min": 1e-5,
"image_cond_model": "dinov2_vitl14_reg"
}
}
}

View File

@@ -0,0 +1,101 @@
{
"models": {
"denoiser": {
"name": "ElasticSLatFlowModel",
"args": {
"resolution": 64,
"in_channels": 8,
"out_channels": 8,
"model_channels": 768,
"cond_channels": 768,
"num_blocks": 12,
"num_heads": 12,
"mlp_ratio": 4,
"patch_size": 2,
"num_io_res_blocks": 2,
"io_block_channels": [128],
"pe_mode": "ape",
"qk_rms_norm": true,
"use_fp16": true
}
}
},
"dataset": {
"name": "TextConditionedSLat",
"args": {
"latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
"min_aesthetic_score": 4.5,
"max_num_voxels": 32768,
"normalization": {
"mean": [
-2.1687545776367188,
-0.004347046371549368,
-0.13352349400520325,
-0.08418072760105133,
-0.5271206498146057,
0.7238689064979553,
-1.1414450407028198,
1.2039363384246826
],
"std": [
2.377650737762451,
2.386378288269043,
2.124418020248413,
2.1748552322387695,
2.663944721221924,
2.371192216873169,
2.6217446327209473,
2.684523105621338
]
},
"pretrained_slat_dec": "microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
}
},
"trainer": {
"name": "TextConditionedSparseFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 16,
"batch_split": 4,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 0.0001,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.5
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"p_uncond": 0.1,
"t_schedule": {
"name": "logitNormal",
"args": {
"mean": 1.0,
"std": 1.0
}
},
"sigma_min": 1e-5,
"text_cond_model": "openai/clip-vit-large-patch14"
}
}
}

View File

@@ -0,0 +1,101 @@
{
"models": {
"denoiser": {
"name": "ElasticSLatFlowModel",
"args": {
"resolution": 64,
"in_channels": 8,
"out_channels": 8,
"model_channels": 1024,
"cond_channels": 768,
"num_blocks": 24,
"num_heads": 16,
"mlp_ratio": 4,
"patch_size": 2,
"num_io_res_blocks": 2,
"io_block_channels": [128],
"pe_mode": "ape",
"qk_rms_norm": true,
"use_fp16": true
}
}
},
"dataset": {
"name": "TextConditionedSLat",
"args": {
"latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
"min_aesthetic_score": 4.5,
"max_num_voxels": 32768,
"normalization": {
"mean": [
-2.1687545776367188,
-0.004347046371549368,
-0.13352349400520325,
-0.08418072760105133,
-0.5271206498146057,
0.7238689064979553,
-1.1414450407028198,
1.2039363384246826
],
"std": [
2.377650737762451,
2.386378288269043,
2.124418020248413,
2.1748552322387695,
2.663944721221924,
2.371192216873169,
2.6217446327209473,
2.684523105621338
]
},
"pretrained_slat_dec": "microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
}
},
"trainer": {
"name": "TextConditionedSparseFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 8,
"batch_split": 4,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 0.0001,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.5
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"p_uncond": 0.1,
"t_schedule": {
"name": "logitNormal",
"args": {
"mean": 1.0,
"std": 1.0
}
},
"sigma_min": 1e-5,
"text_cond_model": "openai/clip-vit-large-patch14"
}
}
}

View File

@@ -0,0 +1,101 @@
{
"models": {
"denoiser": {
"name": "ElasticSLatFlowModel",
"args": {
"resolution": 64,
"in_channels": 8,
"out_channels": 8,
"model_channels": 1280,
"cond_channels": 768,
"num_blocks": 28,
"num_heads": 16,
"mlp_ratio": 4,
"patch_size": 2,
"num_io_res_blocks": 3,
"io_block_channels": [256],
"pe_mode": "ape",
"qk_rms_norm": true,
"use_fp16": true
}
}
},
"dataset": {
"name": "TextConditionedSLat",
"args": {
"latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
"min_aesthetic_score": 4.5,
"max_num_voxels": 32768,
"normalization": {
"mean": [
-2.1687545776367188,
-0.004347046371549368,
-0.13352349400520325,
-0.08418072760105133,
-0.5271206498146057,
0.7238689064979553,
-1.1414450407028198,
1.2039363384246826
],
"std": [
2.377650737762451,
2.386378288269043,
2.124418020248413,
2.1748552322387695,
2.663944721221924,
2.371192216873169,
2.6217446327209473,
2.684523105621338
]
},
"pretrained_slat_dec": "microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16"
}
},
"trainer": {
"name": "TextConditionedSparseFlowMatchingCFGTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 4,
"batch_split": 4,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 0.0001,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.5
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"p_uncond": 0.1,
"t_schedule": {
"name": "logitNormal",
"args": {
"mean": 1.0,
"std": 1.0
}
},
"sigma_min": 1e-5,
"text_cond_model": "openai/clip-vit-large-patch14"
}
}
}

View File

@@ -0,0 +1,71 @@
{
"models": {
"decoder": {
"name": "ElasticSLatRadianceFieldDecoder",
"args": {
"resolution": 64,
"model_channels": 768,
"latent_channels": 8,
"num_blocks": 12,
"num_heads": 12,
"mlp_ratio": 4,
"attn_mode": "swin",
"window_size": 8,
"use_fp16": true,
"representation_config": {
"rank": 16,
"dim": 8
}
}
}
},
"dataset": {
"name": "SLat2Render",
"args": {
"image_size": 512,
"latent_model": "dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16",
"min_aesthetic_score": 4.5,
"max_num_voxels": 32768
}
},
"trainer": {
"name": "SLatVaeRadianceFieldDecoderTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 4,
"batch_split": 2,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-4,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.5
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"loss_type": "l1",
"lambda_ssim": 0.2,
"lambda_lpips": 0.2
}
}
}

View File

@@ -0,0 +1,105 @@
{
"models": {
"encoder": {
"name": "ElasticSLatEncoder",
"args": {
"resolution": 64,
"in_channels": 1024,
"model_channels": 768,
"latent_channels": 8,
"num_blocks": 12,
"num_heads": 12,
"mlp_ratio": 4,
"attn_mode": "swin",
"window_size": 8,
"use_fp16": true
}
},
"decoder": {
"name": "ElasticSLatGaussianDecoder",
"args": {
"resolution": 64,
"model_channels": 768,
"latent_channels": 8,
"num_blocks": 12,
"num_heads": 12,
"mlp_ratio": 4,
"attn_mode": "swin",
"window_size": 8,
"use_fp16": true,
"representation_config": {
"lr": {
"_xyz": 1.0,
"_features_dc": 1.0,
"_opacity": 1.0,
"_scaling": 1.0,
"_rotation": 0.1
},
"perturb_offset": true,
"voxel_size": 1.5,
"num_gaussians": 32,
"2d_filter_kernel_size": 0.1,
"3d_filter_kernel_size": 9e-4,
"scaling_bias": 4e-3,
"opacity_bias": 0.1,
"scaling_activation": "softplus"
}
}
}
},
"dataset": {
"name": "SparseFeat2Render",
"args": {
"image_size": 512,
"model": "dinov2_vitl14_reg",
"resolution": 64,
"min_aesthetic_score": 4.5,
"max_num_voxels": 32768
}
},
"trainer": {
"name": "SLatVaeGaussianTrainer",
"args": {
"max_steps": 1000000,
"batch_size_per_gpu": 4,
"batch_split": 2,
"optimizer": {
"name": "AdamW",
"args": {
"lr": 1e-4,
"weight_decay": 0.0
}
},
"ema_rate": [
0.9999
],
"fp16_mode": "inflat_all",
"fp16_scale_growth": 0.001,
"elastic": {
"name": "LinearMemoryController",
"args": {
"target_ratio": 0.75,
"max_mem_ratio_start": 0.5
}
},
"grad_clip": {
"name": "AdaptiveGradClipper",
"args": {
"max_norm": 1.0,
"clip_percentile": 95
}
},
"i_log": 500,
"i_sample": 10000,
"i_save": 10000,
"loss_type": "l1",
"lambda_ssim": 0.2,
"lambda_lpips": 0.2,
"lambda_kl": 1e-06,
"regularizations": {
"lambda_vol": 10000.0,
"lambda_opacity": 0.001
}
}
}
}

View File

@@ -0,0 +1,92 @@
import os
import re
import argparse
import zipfile
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import pandas as pd
from utils import get_file_hash
def add_args(parser: argparse.ArgumentParser):
    """Register Toys4k-specific command-line arguments.

    Toys4k needs no extra options, so this hook is intentionally a no-op;
    it exists only for interface parity with the other dataset toolkits.
    """
def get_metadata(**kwargs):
    """Load the Toys4k split of the TRELLIS-500K metadata.

    Reads the CSV straight from the Hugging Face Hub (``hf://`` path) and
    returns it as a pandas DataFrame. Extra keyword arguments are accepted
    for interface parity with other dataset toolkits and ignored.
    """
    return pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/Toys4k.csv")
def download(metadata, output_dir, **kwargs) -> pd.DataFrame:
    """Extract Toys4k .blend files from a manually downloaded archive.

    The Toys4k archive cannot be fetched automatically, so the user must
    place ``toys4k_blend_files.zip`` under ``<output_dir>/raw`` first. Each
    entry listed in ``metadata`` (indexed by ``file_identifier``) is
    extracted in parallel, hashed, and kept only if the hash matches the
    ``sha256`` recorded in the metadata.

    Returns a DataFrame with columns ``['sha256', 'local_path']`` for the
    successfully extracted and verified files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # The archive must be supplied by hand; bail out with instructions if absent.
    if not os.path.exists(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')):
        print("\033[93m")  # switch terminal text to yellow
        print("Toys4k have to be downloaded manually")
        print(f"Please download the toys4k_blend_files.zip file and place it in the {output_dir}/raw directory")
        print("Visit https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k for more information")
        print("\033[0m")  # reset terminal color
        raise FileNotFoundError("toys4k_blend_files.zip not found")

    downloaded = {}
    metadata = metadata.set_index("file_identifier")
    with zipfile.ZipFile(os.path.join(output_dir, 'raw', 'toys4k_blend_files.zip')) as zip_ref:
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor, \
            tqdm(total=len(metadata), desc="Extracting") as pbar:
            def worker(instance: str) -> str:
                # Extract one archive entry and hash the resulting file so its
                # integrity can be verified against the metadata below.
                try:
                    zip_ref.extract(os.path.join('toys4k_blend_files', instance), os.path.join(output_dir, 'raw'))
                    sha256 = get_file_hash(os.path.join(output_dir, 'raw/toys4k_blend_files', instance))
                    pbar.update()
                    return sha256
                except Exception as e:
                    pbar.update()
                    print(f"Error extracting for {instance}: {e}")
                    return None

            # map() yields results lazily; shutdown(wait=True) ensures all
            # workers have finished before the results are consumed.
            sha256s = executor.map(worker, metadata.index)
            executor.shutdown(wait=True)
            # Keep only entries whose hash matches the published metadata.
            for k, sha256 in zip(metadata.index, sha256s):
                if sha256 is not None:
                    if sha256 == metadata.loc[k, "sha256"]:
                        downloaded[sha256] = os.path.join("raw/toys4k_blend_files", k)
                    else:
                        print(f"Error downloading {k}: sha256s do not match")

    return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
    """Apply ``func`` to every object listed in ``metadata`` using a thread pool.

    Each row of ``metadata`` must provide ``local_path`` and ``sha256``;
    ``func(file, sha256)`` is called with the absolute file path and may
    return a record dict (or None to skip). All returned records are
    collected into a DataFrame.

    Fixes over the previous version:
    - ``sha256`` is bound before the try block; previously a missing
      ``local_path``/``sha256`` key raised NameError inside the except clause.
    - the outer bare ``except:`` (which also swallowed KeyboardInterrupt and
      hid the cause) is narrowed to ``Exception`` and reports the error.
    """
    import os
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    # load metadata
    metadata = metadata.to_dict('records')

    # processing objects
    records = []
    max_workers = max_workers or os.cpu_count()
    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor, \
            tqdm(total=len(metadata), desc=desc) as pbar:
            def worker(metadatum):
                # Resolve the identifier first so error messages can always name it.
                sha256 = metadatum.get('sha256', '<unknown>')
                try:
                    local_path = metadatum['local_path']
                    file = os.path.join(output_dir, local_path)
                    record = func(file, sha256)
                    if record is not None:
                        # list.append is atomic under the GIL, safe across threads.
                        records.append(record)
                    pbar.update()
                except Exception as e:
                    print(f"Error processing object {sha256}: {e}")
                    pbar.update()

            executor.map(worker, metadata)
            executor.shutdown(wait=True)
    except Exception as e:
        print(f"Error happened during processing: {e}")

    return pd.DataFrame.from_records(records)

View File

@@ -0,0 +1 @@
pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless pandas open3d objaverse huggingface_hub open_clip_torch

43
dataset_toolkits/utils.py Normal file
View File

@@ -0,0 +1,43 @@
from typing import *
import hashlib
import numpy as np
def get_file_hash(file: str) -> str:
    """Return the SHA-256 hex digest of the file at ``file``.

    The file is streamed in 4 KiB chunks so arbitrarily large files can be
    hashed without loading them fully into memory.
    """
    hasher = hashlib.sha256()
    with open(file, "rb") as fp:
        while True:
            chunk = fp.read(4096)
            if not chunk:
                break
            hasher.update(chunk)
    return hasher.hexdigest()
# ===============LOW DISCREPANCY SEQUENCES================
# First 16 primes: one radix base per Halton dimension.
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]


def radical_inverse(base, n):
    """Return the radical inverse of integer ``n`` in radix ``base``.

    Mirrors the digits of ``n`` about the radix point, producing a value in
    [0, 1). E.g. radical_inverse(2, 3) == 0.75 (binary 11 -> 0.11).
    """
    val = 0
    inv_base = 1.0 / base
    inv_base_n = inv_base
    while n > 0:
        digit = n % base
        val += digit * inv_base_n
        n //= base
        inv_base_n *= inv_base
    return val


def halton_sequence(dim, n):
    """Return the ``n``-th point of the ``dim``-dimensional Halton sequence.

    Fix: the original comprehension reused ``dim`` as its loop variable,
    shadowing the parameter. Behavior was unchanged (the range was evaluated
    before the shadowing took effect), but the code was misleading; a
    distinct index is used instead.
    """
    return [radical_inverse(PRIMES[d], n) for d in range(dim)]


def hammersley_sequence(dim, n, num_samples):
    """Return the ``n``-th of ``num_samples`` points of the Hammersley set.

    The first coordinate is the equispaced n/num_samples; the remaining
    dim-1 coordinates come from the Halton sequence.
    """
    return [n / num_samples] + halton_sequence(dim - 1, n)


def sphere_hammersley_sequence(n, num_samples, offset=(0, 0)):
    """Map the ``n``-th 2D Hammersley point to spherical angles [phi, theta].

    ``offset`` perturbs the (u, v) sample before mapping. The piecewise
    remap of ``u`` biases the elevation distribution — presumably to avoid
    extreme viewing angles; intent not documented upstream (TODO confirm).
    """
    u, v = hammersley_sequence(2, n, num_samples)
    u += offset[0] / num_samples
    v += offset[1]
    # Piecewise-linear warp of u before the area-preserving arccos mapping.
    u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
    theta = np.arccos(1 - 2 * u) - np.pi / 2
    phi = v * 2 * np.pi
    return [phi, theta]

View File

@@ -0,0 +1,86 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
from functools import partial
import numpy as np
import open3d as o3d
import utils3d
def _voxelize(file, sha256, output_dir):
    """Voxelize one object's rendered mesh onto a 64^3 occupancy grid.

    Reads ``renders/<sha256>/mesh.ply`` under ``output_dir`` (note: the
    ``file`` argument is unused — the mesh path is derived from sha256;
    the parameter exists to match the foreach_instance callback signature),
    rasterizes it into 1/64-sized voxels within [-0.5, 0.5]^3, and writes
    the occupied voxel centers to ``voxels/<sha256>.ply``.

    Returns a metadata record: ``{'sha256', 'voxelized', 'num_voxels'}``.
    """
    mesh = o3d.io.read_triangle_mesh(os.path.join(output_dir, 'renders', sha256, 'mesh.ply'))
    # clamp vertices to the range [-0.5, 0.5]
    # (epsilon keeps boundary vertices strictly inside the voxelization bounds)
    vertices = np.clip(np.asarray(mesh.vertices), -0.5 + 1e-6, 0.5 - 1e-6)
    mesh.vertices = o3d.utility.Vector3dVector(vertices)
    # Rasterize the surface into a 64^3 grid over the unit cube.
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
    vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
    assert np.all(vertices >= 0) and np.all(vertices < 64), "Some vertices are out of bounds"
    # Convert integer grid indices to voxel-center coordinates in [-0.5, 0.5].
    vertices = (vertices + 0.5) / 64 - 0.5
    utils3d.io.write_ply(os.path.join(output_dir, 'voxels', f'{sha256}.ply'), vertices)
    return {'sha256': sha256, 'voxelized': True, 'num_voxels': len(vertices)}
if __name__ == '__main__':
    # The dataset-specific toolkit module (e.g. datasets/toys4k.py) supplies
    # foreach_instance() and any extra CLI arguments via add_args().
    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Directory to save the metadata')
    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
                        help='Filter objects with aesthetic score lower than this value')
    parser.add_argument('--instances', type=str, default=None,
                        help='Instances to process')
    # NOTE(review): --num_views is never used in this script — it appears
    # copied from the rendering toolkit; confirm before removing.
    parser.add_argument('--num_views', type=int, default=150,
                        help='Number of views to render')
    dataset_utils.add_args(parser)
    parser.add_argument('--rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--max_workers', type=int, default=None)
    opt = parser.parse_args(sys.argv[2:])
    opt = edict(vars(opt))

    os.makedirs(os.path.join(opt.output_dir, 'voxels'), exist_ok=True)

    # get file list
    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
        raise ValueError('metadata.csv not found')
    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
    if opt.instances is None:
        # No explicit list: select everything rendered but not yet voxelized.
        if opt.filter_low_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
        if 'rendered' not in metadata.columns:
            raise ValueError('metadata.csv does not have "rendered" column, please run "build_metadata.py" first')
        metadata = metadata[metadata['rendered'] == True]
        if 'voxelized' in metadata.columns:
            metadata = metadata[metadata['voxelized'] == False]
    else:
        # Explicit list: a file with one sha256 per line, or a comma-separated string.
        if os.path.exists(opt.instances):
            with open(opt.instances, 'r') as f:
                instances = f.read().splitlines()
        else:
            instances = opt.instances.split(',')
        metadata = metadata[metadata['sha256'].isin(instances)]

    # Shard the remaining work evenly across ranks for multi-process runs.
    start = len(metadata) * opt.rank // opt.world_size
    end = len(metadata) * (opt.rank + 1) // opt.world_size
    metadata = metadata[start:end]
    records = []

    # filter out objects that are already processed
    for sha256 in copy.copy(metadata['sha256'].values):
        if os.path.exists(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply')):
            pts = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
            records.append({'sha256': sha256, 'voxelized': True, 'num_voxels': len(pts)})
            metadata = metadata[metadata['sha256'] != sha256]
    print(f'Processing {len(metadata)} objects...')

    # process objects
    func = partial(_voxelize, output_dir=opt.output_dir)
    voxelized = dataset_utils.foreach_instance(metadata, opt.output_dir, func, max_workers=opt.max_workers, desc='Voxelizing')

    # Merge freshly processed results with already-done records;
    # each rank writes its own shard CSV (merged later by build_metadata).
    voxelized = pd.concat([voxelized, pd.DataFrame.from_records(records)])
    voxelized.to_csv(os.path.join(opt.output_dir, f'voxelized_{opt.rank}.csv'), index=False)

260
setup.sh Executable file
View File

@@ -0,0 +1,260 @@
# Read Arguments
# setup.sh is a feature-flag driven installer: each --flag enables one optional
# dependency group below. NOTE(review): the script uses `return` and
# `conda activate`, which behave correctly only when the script is *sourced*
# (`. ./setup.sh --basic`) — confirm it is not meant to be executed directly.
TEMP=`getopt -o h --long help,new-env,basic,train,xformers,flash-attn,diffoctreerast,vox2seq,spconv,mipgaussian,kaolin,nvdiffrast,demo -n 'setup.sh' -- "$@"`
eval set -- "$TEMP"
# Default every feature toggle to off.
HELP=false
NEW_ENV=false
BASIC=false
TRAIN=false
XFORMERS=false
FLASHATTN=false
DIFFOCTREERAST=false
VOX2SEQ=false
# NOTE(review): LINEAR_ASSIGNMENT is set but never read below — dead variable?
LINEAR_ASSIGNMENT=false
SPCONV=false
ERROR=false
MIPGAUSSIAN=false
KAOLIN=false
NVDIFFRAST=false
DEMO=false
# After `eval set --` the argument list always ends with `--`, so a count of 1
# means "no options given": show the help text.
if [ "$#" -eq 1 ] ; then
    HELP=true
fi
# Consume parsed options until the `--` terminator.
while true ; do
    case "$1" in
        -h|--help) HELP=true ; shift ;;
        --new-env) NEW_ENV=true ; shift ;;
        --basic) BASIC=true ; shift ;;
        --train) TRAIN=true ; shift ;;
        --xformers) XFORMERS=true ; shift ;;
        --flash-attn) FLASHATTN=true ; shift ;;
        --diffoctreerast) DIFFOCTREERAST=true ; shift ;;
        --vox2seq) VOX2SEQ=true ; shift ;;
        --spconv) SPCONV=true ; shift ;;
        --mipgaussian) MIPGAUSSIAN=true ; shift ;;
        --kaolin) KAOLIN=true ; shift ;;
        --nvdiffrast) NVDIFFRAST=true ; shift ;;
        --demo) DEMO=true ; shift ;;
        --) shift ; break ;;
        *) ERROR=true ; break ;;
    esac
done
# Unknown option: report it and fall through to the help text.
if [ "$ERROR" = true ] ; then
    echo "Error: Invalid argument"
    HELP=true
fi
if [ "$HELP" = true ] ; then
    echo "Usage: setup.sh [OPTIONS]"
    echo "Options:"
    echo "  -h, --help              Display this help message"
    echo "  --new-env               Create a new conda environment"
    echo "  --basic                 Install basic dependencies"
    echo "  --train                 Install training dependencies"
    echo "  --xformers              Install xformers"
    echo "  --flash-attn            Install flash-attn"
    echo "  --diffoctreerast        Install diffoctreerast"
    echo "  --vox2seq               Install vox2seq"
    echo "  --spconv                Install spconv"
    echo "  --mipgaussian           Install mip-splatting"
    echo "  --kaolin                Install kaolin"
    echo "  --nvdiffrast            Install nvdiffrast"
    echo "  --demo                  Install all dependencies for demo"
    # NOTE(review): `return` aborts only when sourced; if the script is
    # executed, bash prints an error here and keeps going. Use `exit` if
    # direct execution is the intended mode.
    return
fi
# Optionally create a fresh conda environment with pinned PyTorch (CUDA 11.8).
if [ "$NEW_ENV" = true ] ; then
    conda create -n trellis python=3.10
    conda activate trellis
    conda install pytorch==2.4.0 torchvision==0.19.0 pytorch-cuda=11.8 -c pytorch -c nvidia
fi
# Get system information
# Detect the installed PyTorch build and the accelerator platform
# (cuda / hip / cpu); the install sections below branch on these values
# to pick wheels that match the exact PyTorch + CUDA/ROCm combination.
WORKDIR=$(pwd)
PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
PLATFORM=$(python -c "import torch; print(('cuda' if torch.version.cuda else ('hip' if torch.version.hip else 'unknown')) if torch.cuda.is_available() else 'cpu')")
case $PLATFORM in
    cuda)
        CUDA_VERSION=$(python -c "import torch; print(torch.version.cuda)")
        CUDA_MAJOR_VERSION=$(echo $CUDA_VERSION | cut -d'.' -f1)
        CUDA_MINOR_VERSION=$(echo $CUDA_VERSION | cut -d'.' -f2)
        echo "[SYSTEM] PyTorch Version: $PYTORCH_VERSION, CUDA Version: $CUDA_VERSION"
        ;;
    hip)
        HIP_VERSION=$(python -c "import torch; print(torch.version.hip)")
        HIP_MAJOR_VERSION=$(echo $HIP_VERSION | cut -d'.' -f1)
        HIP_MINOR_VERSION=$(echo $HIP_VERSION | cut -d'.' -f2)
        # Install pytorch 2.4.1 for hip
        # ROCm needs this exact PyTorch build; also installs AMD's amd_smi
        # python bindings from the local ROCm installation.
        if [ "$PYTORCH_VERSION" != "2.4.1+rocm6.1" ] ; then
            echo "[SYSTEM] Installing PyTorch 2.4.1 for HIP ($PYTORCH_VERSION -> 2.4.1+rocm6.1)"
            pip install torch==2.4.1 torchvision==0.19.1 --index-url https://download.pytorch.org/whl/rocm6.1 --user
            mkdir -p /tmp/extensions
            sudo cp /opt/rocm/share/amd_smi /tmp/extensions/amd_smi -r
            cd /tmp/extensions/amd_smi
            sudo chmod -R 777 .
            pip install .
            cd $WORKDIR
            PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
        fi
        echo "[SYSTEM] PyTorch Version: $PYTORCH_VERSION, HIP Version: $HIP_VERSION"
        ;;
    *)
        ;;
esac
# --- Basic runtime dependencies (inference) ---
if [ "$BASIC" = true ] ; then
    pip install pillow imageio imageio-ffmpeg tqdm easydict opencv-python-headless scipy ninja rembg onnxruntime trimesh open3d xatlas pyvista pymeshfix igraph transformers
    # utils3d is pinned to a specific commit for reproducibility.
    pip install git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
fi
# --- Training extras; pillow-simd replaces pillow for faster JPEG decoding ---
if [ "$TRAIN" = true ] ; then
    pip install tensorboard pandas lpips
    pip uninstall -y pillow
    sudo apt install -y libjpeg-dev
    pip install pillow-simd
fi
# --- xformers: each wheel targets one exact PyTorch + CUDA/ROCm pairing ---
if [ "$XFORMERS" = true ] ; then
    # install xformers
    if [ "$PLATFORM" = "cuda" ] ; then
        if [ "$CUDA_VERSION" = "11.8" ] ; then
            case $PYTORCH_VERSION in
                2.0.1) pip install https://files.pythonhosted.org/packages/52/ca/82aeee5dcc24a3429ff5de65cc58ae9695f90f49fbba71755e7fab69a706/xformers-0.0.22-cp310-cp310-manylinux2014_x86_64.whl ;;
                2.1.0) pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.1.1) pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.1.2) pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.2.0) pip install xformers==0.0.24 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.2.1) pip install xformers==0.0.25 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.2.2) pip install xformers==0.0.25.post1 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.3.0) pip install xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.4.0) pip install xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.4.1) pip install xformers==0.0.28 --index-url https://download.pytorch.org/whl/cu118 ;;
                2.5.0) pip install xformers==0.0.28.post2 --index-url https://download.pytorch.org/whl/cu118 ;;
                *) echo "[XFORMERS] Unsupported PyTorch & CUDA version: $PYTORCH_VERSION & $CUDA_VERSION" ;;
            esac
        elif [ "$CUDA_VERSION" = "12.1" ] ; then
            case $PYTORCH_VERSION in
                2.1.0) pip install xformers==0.0.22.post7 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.1.1) pip install xformers==0.0.23 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.1.2) pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.2.0) pip install xformers==0.0.24 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.2.1) pip install xformers==0.0.25 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.2.2) pip install xformers==0.0.25.post1 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.3.0) pip install xformers==0.0.26.post1 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.4.0) pip install xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.4.1) pip install xformers==0.0.28 --index-url https://download.pytorch.org/whl/cu121 ;;
                2.5.0) pip install xformers==0.0.28.post2 --index-url https://download.pytorch.org/whl/cu121 ;;
                *) echo "[XFORMERS] Unsupported PyTorch & CUDA version: $PYTORCH_VERSION & $CUDA_VERSION" ;;
            esac
        elif [ "$CUDA_VERSION" = "12.4" ] ; then
            case $PYTORCH_VERSION in
                2.5.0) pip install xformers==0.0.28.post2 --index-url https://download.pytorch.org/whl/cu124 ;;
                *) echo "[XFORMERS] Unsupported PyTorch & CUDA version: $PYTORCH_VERSION & $CUDA_VERSION" ;;
            esac
        else
            # NOTE(review): message prints only the major version, though the
            # branches above match on the full CUDA_VERSION.
            echo "[XFORMERS] Unsupported CUDA version: $CUDA_MAJOR_VERSION"
        fi
    elif [ "$PLATFORM" = "hip" ] ; then
        case $PYTORCH_VERSION in
            2.4.1\+rocm6.1) pip install xformers==0.0.28 --index-url https://download.pytorch.org/whl/rocm6.1 ;;
            *) echo "[XFORMERS] Unsupported PyTorch version: $PYTORCH_VERSION" ;;
        esac
    else
        echo "[XFORMERS] Unsupported platform: $PLATFORM"
    fi
fi
# --- flash-attn: pip wheel on CUDA; built from ROCm fork source on HIP ---
if [ "$FLASHATTN" = true ] ; then
    if [ "$PLATFORM" = "cuda" ] ; then
        pip install flash-attn
    elif [ "$PLATFORM" = "hip" ] ; then
        echo "[FLASHATTN] Prebuilt binaries not found. Building from source..."
        mkdir -p /tmp/extensions
        git clone --recursive https://github.com/ROCm/flash-attention.git /tmp/extensions/flash-attention
        cd /tmp/extensions/flash-attention
        git checkout tags/v2.6.3-cktile
        GPU_ARCHS=gfx942 python setup.py install #MI300 series
        cd $WORKDIR
    else
        echo "[FLASHATTN] Unsupported platform: $PLATFORM"
    fi
fi
# --- kaolin: wheels indexed per PyTorch version ---
if [ "$KAOLIN" = true ] ; then
    # install kaolin
    if [ "$PLATFORM" = "cuda" ] ; then
        case $PYTORCH_VERSION in
            2.0.1) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.0.1_cu118.html;;
            2.1.0) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.1.0_cu118.html;;
            2.1.1) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.1.1_cu118.html;;
            2.2.0) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.2.0_cu118.html;;
            2.2.1) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.2.1_cu118.html;;
            2.2.2) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.2.2_cu118.html;;
            2.4.0) pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.4.0_cu121.html;;
            *) echo "[KAOLIN] Unsupported PyTorch version: $PYTORCH_VERSION" ;;
        esac
    else
        echo "[KAOLIN] Unsupported platform: $PLATFORM"
    fi
fi
# --- nvdiffrast: built from the upstream repo ---
if [ "$NVDIFFRAST" = true ] ; then
    if [ "$PLATFORM" = "cuda" ] ; then
        mkdir -p /tmp/extensions
        git clone https://github.com/NVlabs/nvdiffrast.git /tmp/extensions/nvdiffrast
        pip install /tmp/extensions/nvdiffrast
    else
        echo "[NVDIFFRAST] Unsupported platform: $PLATFORM"
    fi
fi
# --- diffoctreerast: octree radiance-field rasterizer, built from source ---
if [ "$DIFFOCTREERAST" = true ] ; then
    if [ "$PLATFORM" = "cuda" ] ; then
        mkdir -p /tmp/extensions
        git clone --recurse-submodules https://github.com/JeffreyXiang/diffoctreerast.git /tmp/extensions/diffoctreerast
        pip install /tmp/extensions/diffoctreerast
    else
        echo "[DIFFOCTREERAST] Unsupported platform: $PLATFORM"
    fi
fi
# --- mip-splatting: only its diff-gaussian-rasterization submodule is needed ---
if [ "$MIPGAUSSIAN" = true ] ; then
    if [ "$PLATFORM" = "cuda" ] ; then
        mkdir -p /tmp/extensions
        git clone https://github.com/autonomousvision/mip-splatting.git /tmp/extensions/mip-splatting
        pip install /tmp/extensions/mip-splatting/submodules/diff-gaussian-rasterization/
    else
        echo "[MIPGAUSSIAN] Unsupported platform: $PLATFORM"
    fi
fi
# --- vox2seq: vendored local extension, copied out of the repo before build ---
if [ "$VOX2SEQ" = true ] ; then
    if [ "$PLATFORM" = "cuda" ] ; then
        mkdir -p /tmp/extensions
        cp -r extensions/vox2seq /tmp/extensions/vox2seq
        pip install /tmp/extensions/vox2seq
    else
        echo "[VOX2SEQ] Unsupported platform: $PLATFORM"
    fi
fi
# --- spconv: wheel chosen by CUDA major version ---
if [ "$SPCONV" = true ] ; then
    # install spconv
    if [ "$PLATFORM" = "cuda" ] ; then
        case $CUDA_MAJOR_VERSION in
            11) pip install spconv-cu118 ;;
            12) pip install spconv-cu120 ;;
            *) echo "[SPCONV] Unsupported PyTorch CUDA version: $CUDA_MAJOR_VERSION" ;;
        esac
    else
        echo "[SPCONV] Unsupported platform: $PLATFORM"
    fi
fi
# --- Gradio demo dependencies ---
if [ "$DEMO" = true ] ; then
    pip install gradio==4.44.1 gradio_litmodel3d==0.0.1
fi

7
test.py Normal file
View File

@@ -0,0 +1,7 @@
# Quick sanity check: split an object-store style path into bucket and key.
object_path = "test/3d_result/glb/543570111d344552b080ff6f875e4e83.glb"
# Split only on the first "/": the part before is the bucket name,
# everything after is the object key within that bucket.
bucket_name, object_name = object_path.split("/", 1)
print(object_name)
print(bucket_name)

View File

@@ -0,0 +1,32 @@
import argparse
import os
import subprocess
def render_glb_preview(glb_path, output_path):
    """Render a static preview image of a .glb model via headless Blender.

    Launches ``blender --background`` with the ``render_model.py`` driver
    script, passing ``glb_path`` and ``output_path`` after the ``--``
    separator. Raises RuntimeError with the captured stdout/stderr when
    Blender exits non-zero; otherwise returns ``output_path``.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    command = [
        "blender", "--background",
        "--python", "render_model.py",
        "--", glb_path, output_path,
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    # NOTE(review): prints the full CompletedProcess — looks like a debug
    # leftover; confirm whether it should stay.
    print(result)
    if result.returncode != 0:
        raise RuntimeError(
            f"Blender render failed\nstdout:\n{result.stdout}\nstderr:\n{result.stderr}"
        )
    return output_path
if __name__ == '__main__':
    # Smoke test: render one sample GLB. Requires Blender on PATH and a
    # render_model.py driver script in the working directory.
    x = render_glb_preview(glb_path='glb_output/sample_20260316_113848_a956a189.glb',
                           output_path='glb_output/static_model_image_20260316_113856_57ad56d2.png')
    print(x)

158
train.py Normal file
View File

@@ -0,0 +1,158 @@
import os
import sys
import json
import glob
import argparse
from easydict import EasyDict as edict
import torch
import torch.multiprocessing as mp
import numpy as np
import random
from trellis import models, datasets, trainers
from trellis.utils.dist_utils import setup_dist
def find_ckpt(cfg):
    """Resolve which checkpoint step to resume training from.

    Reads ``cfg.load_dir`` and ``cfg.ckpt`` ('latest', 'none', or an integer
    string) and stores the resolved step in ``cfg.load_ckpt`` (None when
    starting from scratch). Returns the same config object.
    """
    # Load checkpoint
    cfg['load_ckpt'] = None
    if cfg.load_dir == '':
        # No resume directory configured: start fresh.
        return cfg
    if cfg.ckpt == 'latest':
        # Resume from the highest step among saved misc_*.pt files, if any.
        ckpt_files = glob.glob(os.path.join(cfg.load_dir, 'ckpts', 'misc_*.pt'))
        steps = [
            int(os.path.basename(path).split('step')[-1].split('.')[0])
            for path in ckpt_files
        ]
        if steps:
            cfg.load_ckpt = max(steps)
    elif cfg.ckpt != 'none':
        # An explicit step number was requested.
        cfg.load_ckpt = int(cfg.ckpt)
    return cfg
def setup_rng(rank):
    """Seed every RNG (torch CPU + CUDA, numpy, random) from the process rank.

    Using the rank as the seed makes each distributed worker deterministic
    while keeping the streams distinct across ranks.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all,
                    np.random.seed, random.seed):
        seed_fn(rank)
def get_model_summary(model):
    """Build a human-readable parameter table for a torch module.

    Lists every named parameter with its shape, dtype, and requires_grad
    flag, followed by total and trainable parameter counts. Returns the
    summary as a single newline-terminated string.
    """
    rows = [
        'Parameters:',
        '=' * 128,
        f'{"Name":<72}{"Shape":<32}{"Type":<16}{"Grad"}',
    ]
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        rows.append(f'{name:<72}{str(param.shape):<32}{str(param.dtype):<16}{param.requires_grad}')
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    rows.append('')  # blank line separating the table from the totals
    rows.append(f'Number of parameters: {total_params}')
    rows.append(f'Number of trainable parameters: {trainable_params}')
    return '\n'.join(rows) + '\n'
def main(local_rank, cfg):
    """Per-process training entry point.

    Sets up the distributed environment (if needed), seeds RNGs, builds the
    dataset, models and trainer from the config, then runs training.
    """
    # Global rank of this process across all nodes.
    rank = cfg.node_rank * cfg.num_gpus + local_rank
    world_size = cfg.num_nodes * cfg.num_gpus
    # Only initialize torch.distributed when more than one process runs.
    if world_size > 1:
        setup_dist(rank, local_rank, world_size, cfg.master_addr, cfg.master_port)
    # Seed rngs per-rank for distinct, reproducible streams.
    setup_rng(rank)
    # Dataset and models are looked up by name from the config.
    dataset = getattr(datasets, cfg.dataset.name)(cfg.data_dir, **cfg.dataset.args)
    model_dict = {}
    for name, model_cfg in cfg.models.items():
        model_dict[name] = getattr(models, model_cfg.name)(**model_cfg.args).cuda()
    # Rank 0 prints and saves a parameter summary for each backbone.
    if rank == 0:
        for name, backbone in model_dict.items():
            model_summary = get_model_summary(backbone)
            print(f'\n\nBackbone: {name}\n' + model_summary)
            with open(os.path.join(cfg.output_dir, f'{name}_model_summary.txt'), 'w') as fp:
                print(model_summary, file=fp)
    # Build the trainer; load_ckpt was resolved by find_ckpt() beforehand.
    trainer = getattr(trainers, cfg.trainer.name)(
        model_dict, dataset, **cfg.trainer.args,
        output_dir=cfg.output_dir, load_dir=cfg.load_dir, step=cfg.load_ckpt,
    )
    # --tryrun builds everything but skips the actual run.
    if not cfg.tryrun:
        if cfg.profile:
            trainer.profile()
        else:
            trainer.run()
def _launch(cfg):
    """Resolve the resume checkpoint, then run training on 1..N GPUs."""
    cfg = find_ckpt(cfg)
    if cfg.num_gpus > 1:
        mp.spawn(main, args=(cfg,), nprocs=cfg.num_gpus, join=True)
    else:
        main(0, cfg)


if __name__ == '__main__':
    # Arguments and config
    parser = argparse.ArgumentParser()
    ## config
    parser.add_argument('--config', type=str, required=True, help='Experiment config file')
    ## io and resume
    parser.add_argument('--output_dir', type=str, required=True, help='Output directory')
    parser.add_argument('--load_dir', type=str, default='', help='Load directory, default to output_dir')
    parser.add_argument('--ckpt', type=str, default='latest', help='Checkpoint step to resume training, default to latest')
    parser.add_argument('--data_dir', type=str, default='./data/', help='Data directory')
    parser.add_argument('--auto_retry', type=int, default=3, help='Number of retries on error')
    ## debug
    parser.add_argument('--tryrun', action='store_true', help='Try run without training')
    parser.add_argument('--profile', action='store_true', help='Profile training')
    ## multi-node and multi-gpu
    parser.add_argument('--num_nodes', type=int, default=1, help='Number of nodes')
    parser.add_argument('--node_rank', type=int, default=0, help='Node rank')
    parser.add_argument('--num_gpus', type=int, default=-1, help='Number of GPUs per node, default to all')
    parser.add_argument('--master_addr', type=str, default='localhost', help='Master address for distributed training')
    parser.add_argument('--master_port', type=str, default='12345', help='Port for distributed training')
    opt = parser.parse_args()
    opt.load_dir = opt.load_dir if opt.load_dir != '' else opt.output_dir
    opt.num_gpus = torch.cuda.device_count() if opt.num_gpus == -1 else opt.num_gpus

    ## Load config (use a context manager so the handle is closed promptly;
    ## the original `json.load(open(...))` leaked the file object)
    with open(opt.config, 'r') as f:
        config = json.load(f)

    ## Combine arguments and config; config values win on key collisions.
    cfg = edict()
    cfg.update(opt.__dict__)
    cfg.update(config)
    print('\n\nConfig:')
    print('=' * 80)
    print(json.dumps(cfg.__dict__, indent=4))

    # Prepare output directory (first node only)
    if cfg.node_rank == 0:
        os.makedirs(cfg.output_dir, exist_ok=True)
        ## Save command and config for reproducibility
        with open(os.path.join(cfg.output_dir, 'command.txt'), 'w') as fp:
            print(' '.join(['python'] + sys.argv), file=fp)
        with open(os.path.join(cfg.output_dir, 'config.json'), 'w') as fp:
            json.dump(config, fp, indent=4)

    # Run, retrying on error. Previously a failure on the final retry was
    # silently swallowed and the process exited with status 0; now the last
    # exception is re-raised so the failure is visible to the caller.
    if cfg.auto_retry == 0:
        _launch(cfg)
    else:
        for rty in range(cfg.auto_retry):
            try:
                _launch(cfg)
                break
            except Exception as e:
                print(f'Error: {e}')
                if rty + 1 == cfg.auto_retry:
                    raise
                print(f'Retrying ({rty + 1}/{cfg.auto_retry})...')

View File

@@ -0,0 +1,107 @@
import os
import json
from typing import Union
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import utils3d
from .components import StandardDatasetBase
from ..representations.octree import DfsOctree as Octree
from ..renderers import OctreeRenderer
class SparseStructure(StandardDatasetBase):
    """
    Sparse structure dataset.

    Loads per-instance voxel point clouds from .ply files and rasterizes
    them into dense binary occupancy grids.

    Args:
        roots (str): path to the dataset
        resolution (int): resolution of the voxel grid
        min_aesthetic_score (float): minimum aesthetic score of the instances to be included in the dataset
    """
    def __init__(self,
        roots,
        resolution: int = 64,
        min_aesthetic_score: float = 5.0,
    ):
        self.resolution = resolution
        self.min_aesthetic_score = min_aesthetic_score
        # Occupancy grids are binary, so sample values lie in [0, 1].
        self.value_range = (0, 1)
        super().__init__(roots)

    def filter_metadata(self, metadata):
        """Keep only voxelized instances above the aesthetic-score threshold.

        Returns the filtered metadata along with a dict of per-filter row
        counts for logging.
        """
        stats = {}
        metadata = metadata[metadata[f'voxelized']]
        stats['Voxelized'] = len(metadata)
        metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
        stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        return metadata, stats

    def get_instance(self, root, instance):
        """Load one instance's voxel positions and rasterize them into a
        dense (1, R, R, R) long tensor of 0/1 occupancies."""
        position = utils3d.io.read_ply(os.path.join(root, 'voxels', f'{instance}.ply'))[0]
        # Positions appear to lie in [-0.5, 0.5); shift/scale to integer
        # voxel indices in [0, resolution). TODO(review): confirm range.
        coords = ((torch.tensor(position) + 0.5) * self.resolution).int().contiguous()
        ss = torch.zeros(1, self.resolution, self.resolution, self.resolution, dtype=torch.long)
        ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
        return {'ss': ss}

    @torch.no_grad()
    def visualize_sample(self, ss: Union[torch.Tensor, dict]):
        """Render each occupancy grid in the batch from four random views.

        Returns a (B, 3, 1024, 1024) tensor: a 2x2 tile of 512x512 renders
        per sample. Requires CUDA.
        """
        ss = ss if isinstance(ss, torch.Tensor) else ss['ss']
        renderer = OctreeRenderer()
        renderer.rendering_options.resolution = 512
        renderer.rendering_options.near = 0.8
        renderer.rendering_options.far = 1.6
        renderer.rendering_options.bg_color = (0, 0, 0)
        renderer.rendering_options.ssaa = 4
        renderer.pipe.primitive = 'voxel'
        # Build camera: four yaws 90 degrees apart with one shared random
        # offset, each paired with an independent random pitch.
        yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaws_offset for y in yaws]
        pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
        exts = []
        ints = []
        # NOTE(review): the loop variable `pitch` shadows the pitch list;
        # harmless here since the list is not used after this loop.
        for yaw, pitch in zip(yaws, pitch):
            # Camera orbits the origin at distance 2, looking at the center.
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            fov = torch.deg2rad(torch.tensor(30)).cuda()
            extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
            intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
            exts.append(extrinsics)
            ints.append(intrinsics)
        images = []
        # Build each representation
        ss = ss.cuda()
        for i in range(ss.shape[0]):
            representation = Octree(
                depth=10,
                aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
                device='cuda',
                primitive='voxel',
                sh_degree=0,
                primitive_config={'solid': True},
            )
            # Occupied voxel centers, normalized back to [0, 1) coordinates;
            # positions double as per-voxel colors via colors_overwrite below.
            coords = torch.nonzero(ss[i, 0], as_tuple=False)
            representation.position = coords.float() / self.resolution
            representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
            # Tile the four 512x512 views into a single 1024x1024 image.
            image = torch.zeros(3, 1024, 1024).cuda()
            tile = [2, 2]
            for j, (ext, intr) in enumerate(zip(exts, ints)):
                res = renderer.render(representation, ext, intr, colors_overwrite=representation.position)
                image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
            images.append(image)
        return torch.stack(images)

View File

@@ -0,0 +1,135 @@
from typing import *
import torch
import math
from .. import SparseTensor
from .. import DEBUG, ATTN
if ATTN == 'xformers':
import xformers.ops as xops
elif ATTN == 'flash_attn':
import flash_attn
else:
raise ValueError(f"Unknown attention module: {ATTN}")
__all__ = [
'sparse_windowed_scaled_dot_product_self_attention',
]
def calc_window_partition(
    tensor: SparseTensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    Args:
        tensor (SparseTensor): The input tensor.
        window_size (int): The window size to use.
        shift_window (Tuple[int, ...]): The shift of serialized coordinates.

    Returns:
        (torch.Tensor): Forwards indices.
        (torch.Tensor): Backwards indices.
        (List[int]): Sequence lengths.
        (List[int]): Sequence batch indices.
    """
    # Number of spatial dimensions; coords column 0 is the batch index.
    DIM = tensor.coords.shape[1] - 1
    # Broadcast scalar arguments to one value per spatial dimension.
    shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
    window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size

    # Shift coordinates so window boundaries can straddle the grid alignment
    # (shifted-window attention).
    shifted_coords = tensor.coords.clone().detach()
    shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)

    # Window counts per dimension, and mixed-radix place values (OFFSET) that
    # fold (batch, window indices...) into a single scalar window id.
    MAX_COORDS = shifted_coords[:, 1:].max(dim=0).values.tolist()
    NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
    OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]

    # Map each point to its per-dimension window index, then to the scalar id.
    shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
    shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)

    # Sort points by window id; bwd_indices is the inverse permutation.
    fwd_indices = torch.argsort(shifted_indices)
    bwd_indices = torch.empty_like(fwd_indices)
    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)

    # Points per window. OFFSET[0] is the number of window ids per batch
    # element, so integer division recovers each window's batch index.
    seq_lens = torch.bincount(shifted_indices)
    seq_batch_indices = torch.arange(seq_lens.shape[0], device=tensor.device, dtype=torch.int32) // OFFSET[0]

    # Drop empty windows.
    mask = seq_lens != 0
    seq_lens = seq_lens[mask].tolist()
    seq_batch_indices = seq_batch_indices[mask].tolist()

    return fwd_indices, bwd_indices, seq_lens, seq_batch_indices
def sparse_windowed_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply windowed scaled dot product self attention to a sparse tensor.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: the input tensor with its features replaced by the
        attention output (same token order and coordinates).
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # The window partition depends only on coordinates, so it is cached on
    # the tensor and reused by layers sharing the same window settings.
    serialization_spatial_cache_name = f'window_partition_{window_size}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = calc_window_partition(qkv, window_size, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, seq_batch_indices))
    else:
        fwd_indices, bwd_indices, seq_lens, seq_batch_indices = serialization_spatial_cache

    M = fwd_indices.shape[0]   # total serialized tokens
    T = qkv.feats.shape[0]     # total input tokens
    H = qkv.feats.shape[2]     # attention heads
    C = qkv.feats.shape[3]     # channels per head

    # Gather features into window-contiguous order.
    qkv_feats = qkv.feats[fwd_indices]  # [M, 3, H, C]

    if DEBUG:
        # Sanity-check that each serialized window stays within one batch
        # element and within the spatial window extent.
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for i in range(len(seq_lens)):
            seq_coords = qkv_coords[start:start+seq_lens[i]]
            assert (seq_coords[:, 0] == seq_batch_indices[i]).all(), f"SparseWindowedScaledDotProductSelfAttention: batch index mismatch"
            assert (seq_coords[:, 1:].max(dim=0).values - seq_coords[:, 1:].min(dim=0).values < window_size).all(), \
                f"SparseWindowedScaledDotProductSelfAttention: window size exceeded"
            start += seq_lens[i]

    if all([seq_len == window_size for seq_len in seq_lens]):
        # Fast path: every window holds exactly window_size tokens, so the
        # tokens form a dense [B, N, ...] batch with no varlen machinery.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)                       # [B, N, H, C]
            out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)                              # [M, H, C]
    else:
        # Variable-length path: treat all windows as one packed sequence with
        # a block-diagonal mask (xformers) or cumulative lengths (flash-attn).
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)                       # [M, H, C]
            q = q.unsqueeze(0)                                      # [1, M, H, C]
            k = k.unsqueeze(0)                                      # [1, M, H, C]
            v = v.unsqueeze(0)                                      # [1, M, H, C]
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]

    # Scatter the outputs back to the original token order.
    out = out[bwd_indices]  # [T, H, C]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseWindowedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)

54
trellis/modules/utils.py Executable file
View File

@@ -0,0 +1,54 @@
import torch.nn as nn
from ..modules import sparse as sp
# Primitive layer types whose parameters are converted between float16 and
# float32 by convert_module_to_f16 / convert_module_to_f32 below.
FP16_MODULES = (
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose1d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
    sp.SparseConv3d,
    sp.SparseInverseConv3d,
    sp.SparseLinear,
)
def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.

    Intended for use with ``Module.apply``; non-primitive modules are
    left untouched.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.half()
def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().

    Intended for use with ``Module.apply``; non-primitive modules are
    left untouched.
    """
    if not isinstance(l, FP16_MODULES):
        return
    for param in l.parameters():
        param.data = param.data.float()
def zero_module(module):
    """
    Zero out every parameter of *module* in place and return the module.

    Uses detach() so the in-place write does not record a grad operation.
    """
    for param in module.parameters():
        param.detach().zero_()
    return module
def scale_module(module, scale):
    """
    Multiply every parameter of *module* by *scale* in place and return it.

    Uses detach() so the in-place write does not record a grad operation.
    """
    for param in module.parameters():
        param.detach().mul_(scale)
    return module
def modulate(x, shift, scale):
    """Scale and shift *x*: ``x * (1 + scale) + shift``, broadcasting the
    (B, C) modulation tensors over dim 1 of the (B, N, C) input."""
    scaled = x * (1 + scale.unsqueeze(1))
    return scaled + shift.unsqueeze(1)

View File

@@ -0,0 +1,375 @@
from typing import *
from contextlib import contextmanager
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from PIL import Image
import rembg
from .base import Pipeline
from . import samplers
from ..modules import sparse as sp
class TrellisImageTo3DPipeline(Pipeline):
    """
    Pipeline for inferring Trellis image-to-3D models.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured latent.
        image_cond_model (str): The name of the image conditioning model.
    """
    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        sparse_structure_sampler: samplers.Sampler = None,
        slat_sampler: samplers.Sampler = None,
        slat_normalization: dict = None,
        image_cond_model: str = None,
    ):
        # A bare instance (models=None) is used by from_pretrained() as an
        # empty shell whose attributes are filled in afterwards.
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.slat_sampler = slat_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler_params = {}
        self.slat_normalization = slat_normalization
        # rembg session is created lazily on first background removal.
        self.rembg_session = None
        self._init_image_cond_model(image_cond_model)

    @staticmethod
    def from_pretrained(path: str) -> "TrellisImageTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.
        """
        pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path)
        new_pipeline = TrellisImageTo3DPipeline()
        # Share state with the generic pipeline, then rebuild the samplers
        # and conditioning model from the saved pretrained args.
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
        new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']

        new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
        new_pipeline.slat_sampler_params = args['slat_sampler']['params']

        new_pipeline.slat_normalization = args['slat_normalization']

        new_pipeline._init_image_cond_model(args['image_cond_model'])

        return new_pipeline

    def _init_image_cond_model(self, name: str):
        """
        Initialize the image conditioning model (a DINOv2 backbone loaded
        from torch hub) and its input normalization transform.
        """
        dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True)
        dinov2_model.eval()
        self.models['image_cond_model'] = dinov2_model
        # Standard ImageNet normalization applied to already-scaled [0,1] input.
        transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.image_cond_model_transform = transform

    def preprocess_image(self, input: Image.Image) -> Image.Image:
        """
        Preprocess the input image: remove the background (unless a usable
        alpha channel already exists), crop to the foreground with 20%
        padding, resize to 518x518, and composite onto black.
        """
        # if has alpha channel, use it directly; otherwise, remove background
        has_alpha = False
        if input.mode == 'RGBA':
            alpha = np.array(input)[:, :, 3]
            # A fully-opaque alpha channel carries no matte information.
            if not np.all(alpha == 255):
                has_alpha = True
        if has_alpha:
            output = input
        else:
            input = input.convert('RGB')
            # Cap the longest side at 1024 px before running rembg.
            max_size = max(input.size)
            scale = min(1, 1024 / max_size)
            if scale < 1:
                input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
            if getattr(self, 'rembg_session', None) is None:
                self.rembg_session = rembg.new_session('u2net')
            output = rembg.remove(input, session=self.rembg_session)
        # Tight bounding box of sufficiently-opaque pixels (alpha > 80%).
        output_np = np.array(output)
        alpha = output_np[:, :, 3]
        bbox = np.argwhere(alpha > 0.8 * 255)
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        # Square crop around the foreground, padded by 20%.
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1.2)
        bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
        output = output.crop(bbox)  # type: ignore
        output = output.resize((518, 518), Image.Resampling.LANCZOS)
        # Premultiply alpha onto RGB (composite over black), drop alpha.
        output = np.array(output).astype(np.float32) / 255
        output = output[:, :, :3] * output[:, :, 3:4]
        output = Image.fromarray((output * 255).astype(np.uint8))
        return output

    @torch.no_grad()
    def encode_image(self, image: Union[torch.Tensor, list[Image.Image]]) -> torch.Tensor:
        """
        Encode the image.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image to encode

        Returns:
            torch.Tensor: The encoded features.
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            # Convert PIL images to a (B, 3, 518, 518) float tensor in [0, 1].
            image = [i.resize((518, 518), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).to(self.device)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        image = self.image_cond_model_transform(image).to(self.device)
        # Take DINOv2 pre-norm patch tokens and layer-normalize them.
        features = self.models['image_cond_model'](image, is_training=True)['x_prenorm']
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.

        Returns:
            dict: The conditioning information
        """
        cond = self.encode_image(image)
        # The negative (unconditional) branch for CFG is all-zero features.
        neg_cond = torch.zeros_like(cond)
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def sample_sparse_structure(
        self,
        cond: dict,
        num_samples: int = 1,
        sampler_params: dict = {},
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            num_samples (int): The number of samples to generate.
            sampler_params (dict): Additional parameters for the sampler.

        Returns:
            torch.Tensor: (N, 4) int coordinates (batch, x, y, z) of occupied voxels.
        """
        # Sample occupancy latent
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
        sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
        z_s = self.sparse_structure_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Decode occupancy latent: positive decoder logits mark occupied
        # voxels; keep (batch, x, y, z) and drop the channel column.
        decoder = self.models['sparse_structure_decoder']
        coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()

        return coords

    def decode_slat(
        self,
        slat: sp.SparseTensor,
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
    ) -> dict:
        """
        Decode the structured latent.

        Args:
            slat (sp.SparseTensor): The structured latent.
            formats (List[str]): The formats to decode the structured latent to.

        Returns:
            dict: The decoded structured latent.
        """
        ret = {}
        if 'mesh' in formats:
            ret['mesh'] = self.models['slat_decoder_mesh'](slat)
        if 'gaussian' in formats:
            ret['gaussian'] = self.models['slat_decoder_gs'](slat)
        if 'radiance_field' in formats:
            ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
        return ret

    def sample_slat(
        self,
        cond: dict,
        coords: torch.Tensor,
        sampler_params: dict = {},
    ) -> sp.SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict): Additional parameters for the sampler.
        """
        # Sample structured latent
        flow_model = self.models['slat_flow_model']
        noise = sp.SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )
        sampler_params = {**self.slat_sampler_params, **sampler_params}
        slat = self.slat_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Undo the training-time normalization of the latent features.
        std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
        slat = slat * std + mean

        return slat

    @torch.no_grad()
    def run(
        self,
        image: Image.Image,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: dict = {},
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
        preprocess_image: bool = True,
    ) -> dict:
        """
        Run the pipeline.

        Args:
            image (Image.Image): The image prompt.
            num_samples (int): The number of samples to generate.
            seed (int): The random seed.
            sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): The formats to decode the structured latent to.
            preprocess_image (bool): Whether to preprocess the image.
        """
        if preprocess_image:
            image = self.preprocess_image(image)
        cond = self.get_cond([image])
        torch.manual_seed(seed)
        # Stage 1: sparse structure; stage 2: structured latent on those voxels.
        coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
        slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

    @contextmanager
    def inject_sampler_multi_image(
        self,
        sampler_name: str,
        num_images: int,
        num_steps: int,
        mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
    ):
        """
        Inject a sampler with multiple images as condition.

        Temporarily monkey-patches the sampler's ``_inference_model``:
        'stochastic' rotates through the conditioning images one per step,
        'multidiffusion' averages the model predictions over all images.
        The original method is restored on exit.

        Args:
            sampler_name (str): The name of the sampler to inject.
            num_images (int): The number of images to condition on.
            num_steps (int): The number of steps to run the sampler for.
        """
        sampler = getattr(self, sampler_name)
        setattr(sampler, f'_old_inference_model', sampler._inference_model)

        if mode == 'stochastic':
            if num_images > num_steps:
                print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
                    "This may lead to performance degradation.\033[0m")

            # Round-robin image index per sampling step, consumed via pop(0).
            cond_indices = (np.arange(num_steps) % num_images).tolist()
            def _new_inference_model(self, model, x_t, t, cond, **kwargs):
                cond_idx = cond_indices.pop(0)
                cond_i = cond[cond_idx:cond_idx+1]
                return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs)

        elif mode =='multidiffusion':
            from .samplers import FlowEulerSampler
            def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
                # Inside the CFG interval, combine the averaged conditional
                # prediction with the unconditional one; outside it, just
                # average the conditional predictions.
                if cfg_interval[0] <= t <= cfg_interval[1]:
                    preds = []
                    for i in range(len(cond)):
                        preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
                    pred = sum(preds) / len(preds)
                    neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
                    return (1 + cfg_strength) * pred - cfg_strength * neg_pred
                else:
                    preds = []
                    for i in range(len(cond)):
                        preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
                    pred = sum(preds) / len(preds)
                    return pred

        else:
            raise ValueError(f"Unsupported mode: {mode}")

        # Bind the replacement as a method on this sampler instance only.
        sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))

        yield

        # Restore the original method.
        sampler._inference_model = sampler._old_inference_model
        delattr(sampler, f'_old_inference_model')

    @torch.no_grad()
    def run_multi_image(
        self,
        images: List[Image.Image],
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: dict = {},
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
        preprocess_image: bool = True,
        mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
    ) -> dict:
        """
        Run the pipeline with multiple images as condition

        Args:
            images (List[Image.Image]): The multi-view images of the assets
            num_samples (int): The number of samples to generate.
            seed (int): The random seed.
            sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): The formats to decode the structured latent to.
            preprocess_image (bool): Whether to preprocess the image.
            mode: How to combine the image conditions ('stochastic' or 'multidiffusion').
        """
        if preprocess_image:
            images = [self.preprocess_image(image) for image in images]
        cond = self.get_cond(images)
        # One shared negative condition row for the whole image set.
        cond['neg_cond'] = cond['neg_cond'][:1]
        torch.manual_seed(seed)
        # NOTE(review): assumes 'steps' is present in the merged sampler
        # params; .get() returns None otherwise -- confirm upstream defaults.
        ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps')
        with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode):
            coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
        slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps')
        with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode):
            slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

View File

@@ -0,0 +1,278 @@
from typing import *
import torch
import torch.nn as nn
import numpy as np
from transformers import CLIPTextModel, AutoTokenizer
import open3d as o3d
from .base import Pipeline
from . import samplers
from ..modules import sparse as sp
class TrellisTextTo3DPipeline(Pipeline):
"""
Pipeline for inferring Trellis text-to-3D models.
Args:
models (dict[str, nn.Module]): The models to use in the pipeline.
sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
slat_sampler (samplers.Sampler): The sampler for the structured latent.
slat_normalization (dict): The normalization parameters for the structured latent.
text_cond_model (str): The name of the text conditioning model.
"""
def __init__(
self,
models: dict[str, nn.Module] = None,
sparse_structure_sampler: samplers.Sampler = None,
slat_sampler: samplers.Sampler = None,
slat_normalization: dict = None,
text_cond_model: str = None,
):
if models is None:
return
super().__init__(models)
self.sparse_structure_sampler = sparse_structure_sampler
self.slat_sampler = slat_sampler
self.sparse_structure_sampler_params = {}
self.slat_sampler_params = {}
self.slat_normalization = slat_normalization
self._init_text_cond_model(text_cond_model)
@staticmethod
def from_pretrained(path: str) -> "TrellisTextTo3DPipeline":
"""
Load a pretrained model.
Args:
path (str): The path to the model. Can be either local path or a Hugging Face repository.
"""
pipeline = super(TrellisTextTo3DPipeline, TrellisTextTo3DPipeline).from_pretrained(path)
new_pipeline = TrellisTextTo3DPipeline()
new_pipeline.__dict__ = pipeline.__dict__
args = pipeline._pretrained_args
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
new_pipeline.slat_sampler_params = args['slat_sampler']['params']
new_pipeline.slat_normalization = args['slat_normalization']
new_pipeline._init_text_cond_model(args['text_cond_model'])
return new_pipeline
def _init_text_cond_model(self, name: str):
"""
Initialize the text conditioning model.
"""
# load model
model = CLIPTextModel.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)
model.eval()
model = model.cuda()
self.text_cond_model = {
'model': model,
'tokenizer': tokenizer,
}
self.text_cond_model['null_cond'] = self.encode_text([''])
@torch.no_grad()
def encode_text(self, text: List[str]) -> torch.Tensor:
"""
Encode the text.
"""
assert isinstance(text, list) and all(isinstance(t, str) for t in text), "text must be a list of strings"
encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
tokens = encoding['input_ids'].cuda()
embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
return embeddings
def get_cond(self, prompt: List[str]) -> dict:
"""
Get the conditioning information for the model.
Args:
prompt (List[str]): The text prompt.
Returns:
dict: The conditioning information
"""
cond = self.encode_text(prompt)
neg_cond = self.text_cond_model['null_cond']
return {
'cond': cond,
'neg_cond': neg_cond,
}
def sample_sparse_structure(
self,
cond: dict,
num_samples: int = 1,
sampler_params: dict = {},
) -> torch.Tensor:
"""
Sample sparse structures with the given conditioning.
Args:
cond (dict): The conditioning information.
num_samples (int): The number of samples to generate.
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample occupancy latent
flow_model = self.models['sparse_structure_flow_model']
reso = flow_model.resolution
noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
z_s = self.sparse_structure_sampler.sample(
flow_model,
noise,
**cond,
**sampler_params,
verbose=True
).samples
# Decode occupancy latent
decoder = self.models['sparse_structure_decoder']
coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
return coords
def decode_slat(
self,
slat: sp.SparseTensor,
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
) -> dict:
"""
Decode the structured latent.
Args:
slat (sp.SparseTensor): The structured latent.
formats (List[str]): The formats to decode the structured latent to.
Returns:
dict: The decoded structured latent.
"""
ret = {}
if 'mesh' in formats:
ret['mesh'] = self.models['slat_decoder_mesh'](slat)
if 'gaussian' in formats:
ret['gaussian'] = self.models['slat_decoder_gs'](slat)
if 'radiance_field' in formats:
ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
return ret
def sample_slat(
self,
cond: dict,
coords: torch.Tensor,
sampler_params: dict = {},
) -> sp.SparseTensor:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
coords (torch.Tensor): The coordinates of the sparse structure.
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample structured latent
flow_model = self.models['slat_flow_model']
noise = sp.SparseTensor(
feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
coords=coords,
)
sampler_params = {**self.slat_sampler_params, **sampler_params}
slat = self.slat_sampler.sample(
flow_model,
noise,
**cond,
**sampler_params,
verbose=True
).samples
std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
slat = slat * std + mean
return slat
@torch.no_grad()
def run(
self,
prompt: str,
num_samples: int = 1,
seed: int = 42,
sparse_structure_sampler_params: dict = {},
slat_sampler_params: dict = {},
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
) -> dict:
"""
Run the pipeline.
Args:
prompt (str): The text prompt.
num_samples (int): The number of samples to generate.
seed (int): The random seed.
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
formats (List[str]): The formats to decode the structured latent to.
"""
cond = self.get_cond([prompt])
torch.manual_seed(seed)
coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params)
slat = self.sample_slat(cond, coords, slat_sampler_params)
return self.decode_slat(slat, formats)
    def voxelize(self, mesh: o3d.geometry.TriangleMesh) -> torch.Tensor:
        """
        Voxelize a mesh on a 64^3 grid.

        NOTE: normalizes and overwrites the input mesh's vertices in place.

        Args:
            mesh (o3d.geometry.TriangleMesh): The mesh to voxelize.

        Returns:
            torch.Tensor: Integer (N, 3) grid indices of the occupied voxels, on CUDA.
        """
        vertices = np.asarray(mesh.vertices)
        # Normalize the mesh into the unit cube [-0.5, 0.5]^3.
        aabb = np.stack([vertices.min(0), vertices.max(0)])
        center = (aabb[0] + aabb[1]) / 2
        scale = (aabb[1] - aabb[0]).max()
        vertices = (vertices - center) / scale
        # Epsilon keeps boundary vertices strictly inside the voxelization bounds.
        vertices = np.clip(vertices, -0.5 + 1e-6, 0.5 - 1e-6)
        mesh.vertices = o3d.utility.Vector3dVector(vertices)
        voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh_within_bounds(mesh, voxel_size=1/64, min_bound=(-0.5, -0.5, -0.5), max_bound=(0.5, 0.5, 0.5))
        vertices = np.array([voxel.grid_index for voxel in voxel_grid.get_voxels()])
        return torch.tensor(vertices).int().cuda()
@torch.no_grad()
def run_variant(
self,
mesh: o3d.geometry.TriangleMesh,
prompt: str,
num_samples: int = 1,
seed: int = 42,
slat_sampler_params: dict = {},
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
) -> dict:
"""
Run the pipeline for making variants of an asset.
Args:
mesh (o3d.geometry.TriangleMesh): The base mesh.
prompt (str): The text prompt.
num_samples (int): The number of samples to generate.
seed (int): The random seed
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
formats (List[str]): The formats to decode the structured latent to.
"""
cond = self.get_cond([prompt])
coords = self.voxelize(mesh)
coords = torch.cat([
torch.arange(num_samples).repeat_interleave(coords.shape[0], 0)[:, None].int().cuda(),
coords.repeat(num_samples, 1)
], 1)
torch.manual_seed(seed)
slat = self.sample_slat(cond, coords, slat_sampler_params)
return self.decode_slat(slat, formats)

View File

@@ -0,0 +1,114 @@
## Flexible Isosurface Extraction for Gradient-Based Mesh Optimization (FlexiCubes)<br><sub>Official PyTorch implementation </sub>
![Teaser image](<images/teaser_top.png>)
FlexiCubes is a high-quality isosurface representation specifically designed for gradient-based mesh optimization with respect to geometric, visual, or even physical objectives. For more details, please refer to our [paper](https://arxiv.org/abs/2308.05371) and [project page](https://research.nvidia.com/labs/toronto-ai/flexicubes/).
## Highlights
* [Getting started](https://github.com/nv-tlabs/FlexiCubes#getting-started)
* [Basic workflow](https://github.com/nv-tlabs/FlexiCubes#example-usage)
* [nvdiffrec: image-based reconstruction example](https://github.com/NVlabs/nvdiffrec#news)
* [GET3D: generative AI example](https://github.com/nv-tlabs/GET3D#employing-flexicubes)
* [Bibtex](https://github.com/nv-tlabs/FlexiCubes#citation)
## Getting Started
The core functions of FlexiCubes are now in [Kaolin](https://github.com/NVIDIAGameWorks/kaolin/) starting from v0.15.0. See installation instructions [here](https://kaolin.readthedocs.io/en/latest/notes/installation.html) and API documentations [here](https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.conversions.html?highlight=flexicubes#kaolin.ops.conversions.FlexiCubes)
The original code of the paper is still visible in `flexicube.py`.
## Example Usage
### Gradient-Based Mesh Optimization
We provide examples demonstrating how to use FlexiCubes for reconstructing unknown meshes through gradient-based optimization. Specifically, starting from randomly initialized SDF, we optimize the shape towards the reference mesh by minimizing their geometric difference, measured by multiview mask and depth losses. This workflow is a simplified version of `nvdiffrec` with code largely borrowed from the [nvdiffrec GitHub](https://github.com/NVlabs/nvdiffrec). We use the same pipeline to conduct the analysis in Section 3 and the main experiments described in Section 5 of our paper. We provide a detailed tutorial in `examples/optimization.ipynb`, along with an optimization script in `examples/optimize.py` which accepts command-line arguments.
To run the examples, it is suggested to install the Conda environment as detailed below:
```sh
conda create -n flexicubes python=3.9
conda activate flexicubes
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
pip install imageio trimesh tqdm matplotlib torch_scatter ninja
pip install git+https://github.com/NVlabs/nvdiffrast/
pip install kaolin==0.15.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.12.0_cu113.html
```
Then download the dataset collected by [Myles et al.](https://vcg.isti.cnr.it/Publications/2014/MPZ14/) as follows. We include one shape in 'examples/data/inputmodels/block.obj' if you want to test without downloading the full dataset.
```sh
cd examples
python download_data.py
```
After downloading the data, run shape optimization with the following example command:
```sh
python optimize.py --ref_mesh data/inputmodels/block.obj --out_dir out/block
```
You can find visualization and output meshes in the `out/block`. Below, we show the initial and final shapes during optimization, with the reference shape on the right.
<img src="images/block_init.png" alt="block_init" width="80%" height="80%">
<img src="images/block_final.png" alt="block_final" width="80%" height="80%">
To further demonstrate the flexibility of our FlexiCubes representation, which can accommodate both reconstruction objectives and regularizers defined on the extracted mesh, you can add a developability regularizer (proposed by [Stein et al.](https://www.cs.cmu.edu/~kmcrane/Projects/DiscreteDevelopable/)) to the previous reconstruction pipeline to encourage fabricability from panels:
```sh
python optimize.py --ref_mesh data/inputmodels/david.obj --out_dir out/david_dev --develop_reg True --iter=1250
```
### Extract mesh from known signed distance field
While not its designated use case, our function can extract a mesh from a known Signed Distance Field (SDF) without optimization. Please refer to the tutorial found in `examples/extraction.ipynb` for details.
## Tips for using FlexiCubes
### Regularization losses:
We commonly use three regularizers in our mesh optimization pipelines, referenced in lines `L104-L106` in `examples/optimize.py`. The weights of these regularizers should be scaled according to your application objectives. Initially, it is suggested to employ low weights because strong regularization can hinder convergence. You can incrementally increase the weights if you notice artifacts appearing in the optimized meshes. Specifically:
* The loss function at `L104` helps to remove floaters in areas of the shape that are not supervised by the application objective, such as internal faces when using image supervision only.
* The L_dev loss at `L105` can be increased if you observe artifacts in flat areas, as illustrated in the image below.
* Generally, the L1 regularizer on flexible weights at `L106` does not have a significant impact during the optimization of a single shape. However, we found it to be effective in stabilizing training in generative pipelines such as GET3D.
<img src="images/ablate_L_dev.jpg" alt="Ablating L_dev" width="80%" height="80%">
### Resolution of voxel grid vs. tetrahedral grid:
If you are switching from our previous work, DMTet, it's important to note the difference in grid resolution when compared to FlexiCubes. In both implementations, the resolution is defined by the edge length: a grid resolution of `n` means the grid edge length is 1/n for both the voxel and tetrahedral grids. However, a tetrahedral grid with a resolution of `n` contains only `(n/2+1)³` grid vertices, in contrast to the `(n+1)³` vertices in a voxel grid. Consequently, if you are switching from DMTet to FlexiCubes while maintaining the same resolution, you will notice not only a denser output mesh but also a substantial increase in computational cost. To align the triangle count in the output meshes more closely, we recommend adopting a 4:5 resolution ratio between the voxel grid and the tetrahedral grid. For instance, in our paper, `64³` FlexiCubes generate approximately the same number of triangles as `80³` DMTet.
## Applications
FlexiCubes is now integrated into NVIDIA applications as a drop-in replacement for DMTet. You can visit their GitHub pages to see how FlexiCubes is used in advanced photogrammetry and 3D generative pipelines.
[Extracting Triangular 3D Models, Materials, and Lighting From Images (nvdiffrec)](https://github.com/NVlabs/nvdiffrec#news)
[GET3D: A Generative Model of High Quality 3D Textured Shapes Learned from Images](https://github.com/nv-tlabs/GET3D#employing-flexicubes)
## License
Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
This work is made available under the [Apache License 2.0](LICENSE.txt).
## Contributing
This project uses the [Developer Certificate of Origin 1.1](https://developercertificate.org/) to manage contributions.
By submitting a pull request, you certify that you agree to the terms of the DCO. See [DCO.txt](./DCO.txt) for details.
## Citation
```bibtex
@article{shen2023flexicubes,
author = {Shen, Tianchang and Munkberg, Jacob and Hasselgren, Jon and Yin, Kangxue and Wang, Zian
and Chen, Wenzheng and Gojcic, Zan and Fidler, Sanja and Sharp, Nicholas and Gao, Jun},
title = {Flexible Isosurface Extraction for Gradient-Based Mesh Optimization},
year = {2023},
issue_date = {August 2023},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
volume = {42},
number = {4},
issn = {0730-0301},
url = {https://doi.org/10.1145/3592430},
doi = {10.1145/3592430},
journal = {ACM Trans. Graph.},
month = {jul},
articleno = {37},
numpages = {16}
}
```

View File

@@ -0,0 +1,129 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import trimesh
import kaolin
import nvdiffrast.torch as dr
###############################################################################
# Functions adapted from https://github.com/NVlabs/nvdiffrec
###############################################################################
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Batched inner product along the last axis, keeping that axis (size 1)."""
    return (x * y).sum(dim=-1, keepdim=True)
def length(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor:
    """Euclidean norm along the last axis, kept as a size-1 trailing axis."""
    sq_norm = torch.sum(x * x, dim=-1, keepdim=True)
    # Clamping before sqrt keeps the gradient finite at zero-length vectors.
    return torch.sqrt(sq_norm.clamp_min(eps))
def safe_normalize(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor:
    """Normalize along the last axis; safe for (near-)zero vectors via eps."""
    norm = torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
    return x / norm
def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
    """OpenGL-style perspective projection matrix (4x4, float32, y flipped)."""
    half_h = np.tan(fovy / 2)
    proj = torch.zeros(4, 4, dtype=torch.float32, device=device)
    proj[0, 0] = 1.0 / (half_h * aspect)
    proj[1, 1] = -1.0 / half_h
    proj[2, 2] = -(f + n) / (f - n)
    proj[2, 3] = -(2 * f * n) / (f - n)
    proj[3, 2] = -1.0
    return proj
def translate(x, y, z, device=None):
    """Homogeneous 4x4 translation matrix for the offset (x, y, z)."""
    mat = torch.eye(4, dtype=torch.float32, device=device)
    mat[0, 3] = x
    mat[1, 3] = y
    mat[2, 3] = z
    return mat
@torch.no_grad()
def random_rotation_translation(t, device=None):
    """Random rigid transform: an orthonormal 3x3 basis (via Gram-Schmidt-like
    cross products) plus a uniform translation in [-t, t]^3, as 4x4 float32."""
    basis = np.random.normal(size=[3, 3])
    # Rebuild rows 1 and 2 so the three rows are mutually orthogonal.
    basis[1] = np.cross(basis[0], basis[2])
    basis[2] = np.cross(basis[0], basis[1])
    basis = basis / np.linalg.norm(basis, axis=1, keepdims=True)
    mat = np.zeros((4, 4))
    mat[:3, :3] = basis
    mat[3, 3] = 1.0
    mat[:3, 3] = np.random.uniform(-t, t, size=[3])
    return torch.tensor(mat, dtype=torch.float32, device=device)
def rotate_x(a, device=None):
    """Homogeneous 4x4 rotation about the x-axis by angle `a` (radians)."""
    sin_a, cos_a = np.sin(a), np.cos(a)
    mat = torch.eye(4, dtype=torch.float32, device=device)
    mat[1, 1] = cos_a
    mat[1, 2] = sin_a
    mat[2, 1] = -sin_a
    mat[2, 2] = cos_a
    return mat
def rotate_y(a, device=None):
    """Homogeneous 4x4 rotation about the y-axis by angle `a` (radians)."""
    sin_a, cos_a = np.sin(a), np.cos(a)
    mat = torch.eye(4, dtype=torch.float32, device=device)
    mat[0, 0] = cos_a
    mat[0, 2] = sin_a
    mat[2, 0] = -sin_a
    mat[2, 2] = cos_a
    return mat
class Mesh:
    """Minimal triangle-mesh container: per-vertex positions plus integer
    face indices, with on-demand per-face normals."""

    def __init__(self, vertices, faces):
        # vertices: (V, 3) float tensor; faces: (F, 3) long tensor of indices.
        self.vertices = vertices
        self.faces = faces

    def auto_normals(self):
        """Compute per-face unit normals and store them in `self.nrm`."""
        v0 = self.vertices[self.faces[:, 0], :]
        v1 = self.vertices[self.faces[:, 1], :]
        v2 = self.vertices[self.faces[:, 2], :]
        # dim=-1 is required: without it torch.cross picks the *first* size-3
        # dimension, which silently selects the batch axis whenever the face
        # count happens to be 3 (and implicit dim selection is deprecated).
        nrm = safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
        self.nrm = nrm
def load_mesh(path, device):
    """Load a triangle mesh from disk, center it on the origin, and scale its
    longest AABB edge to 1.8 (vertices end up within [-0.9, 0.9]^3)."""
    raw = trimesh.load(path)
    verts = torch.tensor(raw.vertices, device=device, dtype=torch.float)
    tris = torch.tensor(raw.faces, device=device, dtype=torch.long)
    lo = verts.min(dim=0)[0]
    hi = verts.max(dim=0)[0]
    factor = 1.8 / torch.max(hi - lo).item()
    verts = verts - (hi + lo) / 2  # center on origin
    verts = verts * factor         # rescale to [-0.9, 0.9]
    return Mesh(verts, tris)
def compute_sdf(points, vertices, faces):
    """Signed distance from query `points` to the mesh (positive outside,
    negative inside), using kaolin's unsigned distance plus a sign test."""
    tri_verts = kaolin.ops.mesh.index_vertices_by_faces(vertices.clone().unsqueeze(0), faces)
    dist = kaolin.metrics.trianglemesh.point_to_mesh_distance(points.unsqueeze(0), tri_verts)[0]
    with torch.no_grad():
        # check_sign is True for inside points -> map inside to -1, outside to +1.
        inside = kaolin.ops.mesh.check_sign(vertices.unsqueeze(0), faces, points.unsqueeze(0))
        sign = (inside < 1).float() * 2 - 1
    return (sign * dist).squeeze(0)
def sample_random_points(n, mesh):
    """Draw SDF query points: n//2 uniform samples in [-1, 1]^3 plus 500
    points jittered around the mesh surface (sigma 0.05).

    NOTE(review): despite the parameter name, this returns n//2 + 500 points,
    not n — the surface-sample count is hard-coded; confirm this is intended.
    Also hard-codes device='cuda'.
    """
    pts_random = (torch.rand((n//2,3),device='cuda') - 0.5) * 2
    pts_surface = kaolin.ops.mesh.sample_points(mesh.vertices.unsqueeze(0), mesh.faces, 500)[0].squeeze(0)
    pts_surface += torch.randn_like(pts_surface) * 0.05
    pts = torch.cat([pts_random, pts_surface])
    return pts
def xfm_points(points, matrix):
    '''Transform points.
    Args:
        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
    Returns:
        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
    '''
    # Append w=1 to each point, then right-multiply by the transposed matrix.
    homogeneous = torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0)
    out = homogeneous @ matrix.transpose(1, 2)
    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
    return out
def interpolate(attr, rast, attr_idx, rast_db=None):
    # Thin wrapper around nvdiffrast's interpolate: interpolates vertex
    # attributes over rasterized pixels. Attribute pixel derivatives ('all')
    # are only requested when a rasterizer derivative buffer is supplied.
    return dr.interpolate(
        attr, rast, attr_idx, rast_db=rast_db,
        diff_attrs=None if rast_db is None else 'all')

View File

@@ -0,0 +1,798 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lookup table for a dual-marching-cubes style mesher: dmc_table[case] holds
# up to 4 vertex groups for each of the 256 corner-sign configurations of a
# cube. Each group is a fixed-width list of 7 slots holding cube-edge indices
# (0-11) that contribute to one dual vertex; -1 entries are padding for
# unused slots/groups. (Edge-index semantics presumed from the marching-cubes
# convention — confirm against the consuming kernel.)
dmc_table = [
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
]
# num_vd_table[case] is the number of dual vertices emitted for that cube
# configuration — i.e. the count of non-empty groups in dmc_table[case]
# (0 for the all-inside/all-outside cases, up to 4 for the most split ones).
num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
check_table = [
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 0, 0, 194],
[1, -1, 0, 0, 193],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 1, 0, 164],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, -1, 0, 161],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 1, 152],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 1, 145],
[1, 0, 0, 1, 144],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, -1, 137],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 1, 0, 133],
[1, 0, 1, 0, 132],
[1, 1, 0, 0, 131],
[1, 1, 0, 0, 130],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 1, 100],
[0, 0, 0, 0, 0],
[1, 0, 0, 1, 98],
[0, 0, 0, 0, 0],
[1, 0, 0, 1, 96],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 1, 0, 88],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, -1, 0, 82],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 1, 0, 74],
[0, 0, 0, 0, 0],
[1, 0, 1, 0, 72],
[0, 0, 0, 0, 0],
[1, 0, 0, -1, 70],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, -1, 0, 0, 67],
[0, 0, 0, 0, 0],
[1, -1, 0, 0, 65],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 0, 0, 56],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, -1, 0, 0, 52],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 0, 0, 44],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 0, 0, 40],
[0, 0, 0, 0, 0],
[1, 0, 0, -1, 38],
[1, 0, -1, 0, 37],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, -1, 0, 33],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, -1, 0, 0, 28],
[0, 0, 0, 0, 0],
[1, 0, -1, 0, 26],
[1, 0, 0, -1, 25],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, -1, 0, 0, 20],
[0, 0, 0, 0, 0],
[1, 0, -1, 0, 18],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, -1, 9],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, -1, 6],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
]
# Lookup table with one 6-element row per 8-bit cube configuration.
# NOTE(review): presumably maps each of the cube's 6 tetrahedra to a dual-vertex
# group id for that corner-sign case (-1 = unused) -- confirm against the kernel
# that indexes it; values 8 and 12 appear only in all-equal rows.
tet_table = [
    [-1, -1, -1, -1, -1, -1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [4, 4, 4, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [4, 0, 0, 4, 4, -1],
    [1, 1, 1, 1, 1, 1],
    [4, 4, 4, 4, 4, 4],
    [0, 4, 0, 4, 4, -1],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [5, 5, 5, 5, 5, 5],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, -1, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, -1, 2, 4, 4, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 4, 4, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 4, 2, 4, 4, 2],
    [0, 4, 0, 4, 4, 0],
    [2, 0, 2, 0, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 5, 2, 5, 5, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 0, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1],
    [0, 1, 1, -1, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [4, 1, 1, 4, 4, 1],
    [0, 1, 1, 0, 0, 1],
    [4, 0, 0, 4, 4, 0],
    [2, 2, 2, 2, 2, 2],
    [-1, 1, 1, 4, 4, 1],
    [0, 1, 1, 4, 4, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [5, 1, 1, 5, 5, 1],
    [0, 1, 1, 0, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [8, 8, 8, 8, 8, 8],
    [1, 1, 1, 4, 4, 1],
    [0, 0, 0, 0, 0, 0],
    [4, 0, 0, 4, 4, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 4, 4, 1],
    [0, 4, 0, 4, 4, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 5, 5, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [5, 5, 5, 5, 5, 5],
    [6, 6, 6, 6, 6, 6],
    [6, -1, 0, 6, 0, 6],
    [6, 0, 0, 6, 0, 6],
    [6, 1, 1, 6, 1, 6],
    [4, 4, 4, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [4, 0, 0, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [6, 4, -1, 6, 4, 6],
    [6, 4, 0, 6, 4, 6],
    [6, 0, 0, 6, 0, 6],
    [6, 1, 1, 6, 1, 6],
    [5, 5, 5, 5, 5, 5],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 2, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 4, 2, 2, 4, 2],
    [0, 4, 0, 4, 4, 0],
    [2, 0, 2, 2, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [6, 1, 1, 6, -1, 6],
    [6, 1, 1, 6, 0, 6],
    [6, 0, 0, 6, 0, 6],
    [6, 2, 2, 6, 2, 6],
    [4, 1, 1, 4, 4, 1],
    [0, 1, 1, 0, 0, 1],
    [4, 0, 0, 4, 4, 4],
    [2, 2, 2, 2, 2, 2],
    [6, 1, 1, 6, 4, 6],
    [6, 1, 1, 6, 4, 6],
    [6, 0, 0, 6, 0, 6],
    [6, 2, 2, 6, 2, 6],
    [5, 1, 1, 5, 5, 1],
    [0, 1, 1, 0, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [6, 6, 6, 6, 6, 6],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 4, 1],
    [0, 4, 0, 4, 4, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 5, 0, 5, 0, 5],
    [5, 5, 5, 5, 5, 5],
    [5, 5, 5, 5, 5, 5],
    [0, 5, 0, 5, 0, 5],
    [-1, 5, 0, 5, 0, 5],
    [1, 5, 1, 5, 1, 5],
    [4, 5, -1, 5, 4, 5],
    [0, 5, 0, 5, 0, 5],
    [4, 5, 0, 5, 4, 5],
    [1, 5, 1, 5, 1, 5],
    [4, 4, 4, 4, 4, 4],
    [0, 4, 0, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [6, 6, 6, 6, 6, 6],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [2, 5, 2, 5, -1, 5],
    [0, 5, 0, 5, 0, 5],
    [2, 5, 2, 5, 0, 5],
    [1, 5, 1, 5, 1, 5],
    [2, 5, 2, 5, 4, 5],
    [0, 5, 0, 5, 0, 5],
    [2, 5, 2, 5, 4, 5],
    [1, 5, 1, 5, 1, 5],
    [2, 4, 2, 4, 4, 2],
    [0, 4, 0, 4, 4, 4],
    [2, 0, 2, 0, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 6, 2, 6, 6, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 0, 2, 0, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [4, 1, 1, 1, 4, 1],
    [0, 1, 1, 1, 0, 1],
    [4, 0, 0, 4, 4, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [5, 5, 5, 5, 5, 5],
    [1, 1, 1, 1, 4, 1],
    [0, 0, 0, 0, 0, 0],
    [4, 0, 0, 4, 4, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [6, 0, 0, 6, 0, 6],
    [0, 0, 0, 0, 0, 0],
    [6, 6, 6, 6, 6, 6],
    [5, 5, 5, 5, 5, 5],
    [5, 5, 0, 5, 0, 5],
    [5, 5, 0, 5, 0, 5],
    [5, 5, 1, 5, 1, 5],
    [4, 4, 4, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 0, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [4, 4, 4, 4, 4, 4],
    [4, 4, 0, 4, 4, 4],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [8, 8, 8, 8, 8, 8],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 0, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 1, 1, 4, 4, 1],
    [2, 2, 2, 2, 2, 2],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [2, 4, 2, 4, 4, 2],
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [2, 2, 2, 2, 2, 2],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [5, 5, 5, 5, 5, 5],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [4, 4, 4, 4, 4, 4],
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [12, 12, 12, 12, 12, 12]
]

View File

@@ -0,0 +1,61 @@
import torch
# The 8 corners of a unit cube, ordered (x, y, z) with x varying fastest.
cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
    1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.int)
# Offsets of the 6 face-adjacent neighbor cubes (+/- one step along each axis).
cube_neighbor = torch.tensor([[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]])
# The 12 cube edges as a flat list of (corner, corner) index pairs into cube_corners.
# NOTE(review): edge ordering presumably matches the extraction tables above -- confirm.
cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
                           2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, requires_grad=False)
def construct_dense_grid(res, device='cuda'):
    '''Build a dense cubical grid of resolution `res`: per-vertex integer coords and per-cube vertex indices.'''
    n_v = res + 1  # vertices per axis
    vert_ids = torch.arange(n_v ** 3, device=device)
    # ids of each cube's origin corner: drop the last vertex slice along every axis
    origin_ids = vert_ids.reshape(n_v, n_v, n_v)[:res, :res, :res].flatten()
    # flattened-index offset of each of the 8 corners of a cube
    corner_offsets = (cube_corners[:, 0] * n_v + cube_corners[:, 1]) * n_v + cube_corners[:, 2]
    cube_fx8 = origin_ids.unsqueeze(1) + corner_offsets.unsqueeze(0).to(device)
    # recover (x, y, z) integer coordinates from the flattened vertex id
    verts = torch.stack([vert_ids // (n_v ** 2), (vert_ids // n_v) % n_v, vert_ids % n_v], dim=1)
    return verts, cube_fx8
def construct_voxel_grid(coords):
    '''Expand voxel coordinates into their unique corner vertices plus per-voxel corner indices.'''
    corner_coords = coords.unsqueeze(1) + cube_corners.unsqueeze(0).to(coords)
    flat_corners = corner_coords.reshape(-1, 3)
    # deduplicate shared corners; inverse indices rebuild the 8-corner connectivity
    verts_unique, inverse = torch.unique(flat_corners, dim=0, return_inverse=True)
    cubes = inverse.reshape(-1, 8)
    return verts_unique, cubes
def cubes_to_verts(num_verts, cubes, value, reduce='mean'):
    """
    Scatter per-cube-corner values onto their shared vertices.

    Args:
        num_verts: total number of output vertices.
        cubes: [Vx8] vertex index for each cube corner.
        value: [Vx8xM] value carried by each cube corner.
        reduce: reduction applied when several corners hit one vertex.
    Operation:
        reduced[cubes[i][j]][k] += value[i][k]
    """
    n_channels = value.shape[2]
    reduced = torch.zeros(num_verts, n_channels, device=cubes.device)
    # broadcast each corner's vertex index across the channel dimension
    index = cubes.reshape(-1, 1).repeat(1, n_channels)
    source = value.reshape(-1, n_channels)
    return torch.scatter_reduce(reduced, 0, index, source, reduce=reduce, include_self=False)
def sparse_cube2verts(coords, feats, training=True):
    '''Move per-cube features onto shared grid vertices; optionally return a consistency loss.'''
    verts, cubes = construct_voxel_grid(coords)
    verts_feats = cubes_to_verts(verts.shape[0], cubes, feats)
    if training:
        # penalize disagreement between original per-corner feats and the averaged vertex feats
        con_loss = ((feats - verts_feats[cubes]) ** 2).mean()
    else:
        con_loss = 0.0
    return verts, verts_feats, con_loss
def get_dense_attrs(coords: torch.Tensor, feats: torch.Tensor, res: int, sdf_init=True):
    '''Scatter sparse per-voxel features into a flattened dense res^3 grid.'''
    n_feat = feats.shape[-1]
    dense = torch.zeros(res, res, res, n_feat, device=feats.device)
    if sdf_init:
        # channel 0 defaults to 1 ("outside") wherever no sparse entry overwrites it
        dense[..., 0] = 1
    dense[coords[:, 0], coords[:, 1], coords[:, 2]] = feats
    return dense.reshape(-1, n_feat)
def get_defomed_verts(v_pos: torch.Tensor, deform: torch.Tensor, res):
    '''Map integer grid positions to [-0.5, 0.5] coordinates, nudged by a bounded deformation.'''
    # tanh keeps each vertex strictly inside its half-cell, so neighboring cells cannot fold over
    offset = (1 - 1e-8) / (res * 2) * torch.tanh(deform)
    return v_pos / res - 0.5 + offset

View File

@@ -0,0 +1,68 @@
from typing import *
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
import torch
from transformers import AutoTokenizer, CLIPTextModel
from ....utils import dist_utils
class TextConditionedMixin:
    """
    Mixin that adds CLIP text conditioning on top of a trainer/model class.

    The CLIP text encoder and tokenizer are loaded lazily on first use and
    moved to GPU; the empty-string embedding is cached as the null condition
    for classifier-free guidance.

    Args:
        text_cond_model: HuggingFace name of the CLIP text model to load.
    """
    def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
        super().__init__(*args, **kwargs)
        self.text_cond_model_name = text_cond_model
        self.text_cond_model = None  # dict of model/tokenizer/null_cond; initialized lazily
    def _init_text_cond_model(self):
        """
        Load the CLIP text model and tokenizer, then cache the null (empty-string) embedding.
        """
        # load model; local_master_first serializes the download across ranks on one node
        with dist_utils.local_master_first():
            model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
            tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
        model.eval()
        model = model.cuda()
        self.text_cond_model = {
            'model': model,
            'tokenizer': tokenizer,
        }
        # cache the unconditional embedding used as the CFG negative condition
        self.text_cond_model['null_cond'] = self.encode_text([''])
    @torch.no_grad()
    def encode_text(self, text: List[str]) -> torch.Tensor:
        """
        Encode a batch of strings into CLIP last-hidden-state embeddings.

        Args:
            text: list of prompt strings.
        Returns:
            [B x 77 x D] tensor of token embeddings on GPU.
        """
        assert isinstance(text, list) and isinstance(text[0], str), "TextConditionedMixin only supports list of strings as cond"
        if self.text_cond_model is None:
            self._init_text_cond_model()
        # pad/truncate to CLIP's fixed 77-token context
        encoding = self.text_cond_model['tokenizer'](text, max_length=77, padding='max_length', truncation=True, return_tensors='pt')
        tokens = encoding['input_ids'].cuda()
        embeddings = self.text_cond_model['model'](input_ids=tokens).last_hidden_state
        return embeddings
    def get_cond(self, cond, **kwargs):
        """
        Encode the text condition and attach the null embedding as ``neg_cond`` for training.
        """
        cond = self.encode_text(cond)
        kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
        cond = super().get_cond(cond, **kwargs)
        return cond
    def get_inference_cond(self, cond, **kwargs):
        """
        Encode the text condition and attach the null embedding as ``neg_cond`` for inference.
        """
        cond = self.encode_text(cond)
        kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
        cond = super().get_inference_cond(cond, **kwargs)
        return cond

View File

@@ -0,0 +1,286 @@
from typing import *
import os
import copy
import functools
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
from ...modules import sparse as sp
from ...utils.general_utils import dict_reduce
from ...utils.data_utils import cycle, BalancedResumableSampler
from .flow_matching import FlowMatchingTrainer
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
from .mixins.text_conditioned import TextConditionedMixin
from .mixins.image_conditioned import ImageConditionedMixin
class SparseFlowMatchingTrainer(FlowMatchingTrainer):
    """
    Trainer for sparse diffusion model with flow matching objective.
    Args:
        models (dict[str, nn.Module]): Models to train.
        dataset (torch.utils.data.Dataset): Dataset.
        output_dir (str): Output directory.
        load_dir (str): Load directory.
        step (int): Step to load.
        batch_size (int): Batch size.
        batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
        batch_split (int): Split batch with gradient accumulation.
        max_steps (int): Max steps.
        optimizer (dict): Optimizer config.
        lr_scheduler (dict): Learning rate scheduler config.
        elastic (dict): Elastic memory management config.
        grad_clip (float or dict): Gradient clip config.
        ema_rate (float or list): Exponential moving average rates.
        fp16_mode (str): FP16 mode.
            - None: No FP16.
            - 'inflat_all': Hold a inflated fp32 master param for all params.
            - 'amp': Automatic mixed precision.
        fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
        finetune_ckpt (dict): Finetune checkpoint.
        log_param_stats (bool): Log parameter stats.
        i_print (int): Print interval.
        i_log (int): Log interval.
        i_sample (int): Sample interval.
        i_save (int): Save interval.
        i_ddpcheck (int): DDP check interval.
        t_schedule (dict): Time schedule for flow matching.
        sigma_min (float): Minimum noise level.
    """
    def prepare_dataloader(self, **kwargs):
        """
        Build the balanced resumable sampler and the training dataloader,
        wrapped in an infinite iterator.
        """
        self.data_sampler = BalancedResumableSampler(
            self.dataset,
            shuffle=True,
            batch_size=self.batch_size_per_gpu,
        )
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size_per_gpu,
            # spread available CPU workers evenly across the GPUs on this node
            num_workers=int(np.ceil(os.cpu_count() / torch.cuda.device_count())),
            pin_memory=True,
            drop_last=True,
            persistent_workers=True,
            # collate splits each batch into `batch_split` micro-batches for grad accumulation
            collate_fn=functools.partial(self.dataset.collate_fn, split_size=self.batch_split),
            sampler=self.data_sampler,
        )
        self.data_iterator = cycle(self.dataloader)
    def training_losses(
        self,
        x_0: sp.SparseTensor,
        cond=None,
        **kwargs
    ) -> Tuple[Dict, Dict]:
        """
        Compute training losses for a single timestep.
        Args:
            x_0: The [N x ... x C] sparse tensor of the inputs.
            cond: The [N x ...] tensor of additional conditions.
            kwargs: Additional arguments to pass to the backbone.
        Returns:
            a dict with the key "loss" containing a tensor of shape [N].
            may also contain other keys for different terms.
        """
        # noise shares x_0's sparsity pattern; only the features are resampled
        noise = x_0.replace(torch.randn_like(x_0.feats))
        t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
        x_t = self.diffuse(x_0, t, noise=noise)
        cond = self.get_cond(cond, **kwargs)
        # t is scaled by 1000 before the denoiser -- presumably to match its timestep embedding range; confirm
        pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
        assert pred.shape == noise.shape == x_0.shape
        target = self.get_v(x_0, noise, t)
        terms = edict()
        terms["mse"] = F.mse_loss(pred.feats, target.feats)
        terms["loss"] = terms["mse"]
        # log loss with time bins
        mse_per_instance = np.array([
            F.mse_loss(pred.feats[x_0.layout[i]], target.feats[x_0.layout[i]]).item()
            for i in range(x_0.shape[0])
        ])
        # bucket per-instance MSE into 10 uniform time bins over [0, 1] for logging
        time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
        for i in range(10):
            if (time_bin == i).sum() != 0:
                terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
        return terms, {}
    @torch.no_grad()
    def run_snapshot(
        self,
        num_samples: int,
        batch_size: int,
        verbose: bool = False,
    ) -> Dict:
        """
        Sample the model alongside ground-truth batches for visualization.

        Returns a dict of {name: {'value': ..., 'type': ...}} entries covering
        ground truth, samples, and condition visualizations.
        """
        dataloader = DataLoader(
            copy.deepcopy(self.dataset),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
        )
        # inference
        sampler = self.get_sampler()
        sample_gt = []
        sample = []
        cond_vis = []
        for i in range(0, num_samples, batch_size):
            batch = min(batch_size, num_samples - i)
            # NOTE(review): a fresh iterator is built each pass; with shuffle=True this
            # draws an independent random batch every time -- confirm intended
            data = next(iter(dataloader))
            data = {k: v[:batch].cuda() if not isinstance(v, list) else v[:batch] for k, v in data.items()}
            noise = data['x_0'].replace(torch.randn_like(data['x_0'].feats))
            sample_gt.append(data['x_0'])
            cond_vis.append(self.vis_cond(**data))
            del data['x_0']
            args = self.get_inference_cond(**data)
            res = sampler.sample(
                self.models['denoiser'],
                noise=noise,
                **args,
                steps=50, cfg_strength=3.0, verbose=verbose,
            )
            sample.append(res.samples)
        sample_gt = sp.sparse_cat(sample_gt)
        sample = sp.sparse_cat(sample)
        sample_dict = {
            'sample_gt': {'value': sample_gt, 'type': 'sample'},
            'sample': {'value': sample, 'type': 'sample'},
        }
        # merge per-batch condition visualizations into single tensors
        sample_dict.update(dict_reduce(cond_vis, None, {
            'value': lambda x: torch.cat(x, dim=0),
            'type': lambda x: x[0],
        }))
        return sample_dict
class SparseFlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, SparseFlowMatchingTrainer):
    """
    Trainer for sparse diffusion model with flow matching objective and
    classifier-free guidance.

    Behavior comes entirely from the bases: ``SparseFlowMatchingTrainer``
    supplies training/sampling, ``ClassifierFreeGuidanceMixin`` adds condition
    dropout. See ``SparseFlowMatchingTrainer`` for the shared constructor args.

    Additional args:
        p_uncond (float): Probability of dropping conditions.
    """
    pass
class TextConditionedSparseFlowMatchingCFGTrainer(TextConditionedMixin, SparseFlowMatchingCFGTrainer):
    """
    Trainer for sparse text-conditioned diffusion model with flow matching
    objective and classifier-free guidance.

    Behavior comes entirely from the bases: ``SparseFlowMatchingCFGTrainer``
    supplies CFG training, ``TextConditionedMixin`` encodes text prompts with
    CLIP. See ``SparseFlowMatchingTrainer`` for the shared constructor args.

    Additional args:
        p_uncond (float): Probability of dropping conditions.
        text_cond_model (str): Text conditioning model.
    """
    pass
class ImageConditionedSparseFlowMatchingCFGTrainer(ImageConditionedMixin, SparseFlowMatchingCFGTrainer):
    """
    Trainer for sparse image-conditioned diffusion model with flow matching
    objective and classifier-free guidance.

    Behavior comes entirely from the bases: ``SparseFlowMatchingCFGTrainer``
    supplies CFG training, ``ImageConditionedMixin`` encodes image conditions.
    See ``SparseFlowMatchingTrainer`` for the shared constructor args.

    Additional args:
        p_uncond (float): Probability of dropping conditions.
        image_cond_model (str): Image conditioning model.
    """
    pass

77
trellis/trainers/utils.py Normal file
View File

@@ -0,0 +1,77 @@
import torch.nn as nn
# FP16 utils
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
def make_master_params(model_params):
    """
    Build a single flat fp32 "master" parameter covering all model parameters.

    Returns:
        A one-element list holding the flattened ``nn.Parameter``.
    """
    flat = _flatten_dense_tensors([p.detach().float() for p in model_params])
    master = nn.Parameter(flat)
    master.requires_grad = True
    return [master]
def unflatten_master_params(model_params, master_params):
    """Split the flat master parameter back into per-parameter tensors shaped like ``model_params``."""
    flat = master_params[0].detach()
    return _unflatten_dense_tensors(flat, model_params)
def model_params_to_master_params(model_params, master_params):
    """Overwrite the flat master parameter with the current model parameter values."""
    flat = _flatten_dense_tensors([p.detach().float() for p in model_params])
    master_params[0].detach().copy_(flat)
def master_params_to_model_params(model_params, master_params):
    """Copy the flat master parameter's values back into each model parameter."""
    chunks = _unflatten_dense_tensors(master_params[0].detach(), model_params)
    for param, chunk in zip(model_params, chunks):
        param.detach().copy_(chunk)
def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().

    Assumes every parameter in ``model_params`` has a populated ``.grad``
    (i.e. backward has already run).
    """
    # ``.grad.data`` is legacy Variable-era API; ``.grad.detach()`` is the
    # supported equivalent and yields the same tensor values.
    master_params[0].grad = _flatten_dense_tensors(
        [param.grad.detach().float() for param in model_params]
    )
def zero_grad(model_params):
    """Zero every populated gradient in ``model_params`` in place (params without grads are skipped)."""
    for param in model_params:
        grad = param.grad
        if grad is None:
            continue
        # detach the grad from any autograd history before zeroing it in place
        if grad.grad_fn is not None:
            grad.detach_()
        else:
            grad.requires_grad_(False)
        grad.zero_()
# LR Schedulers
from torch.optim.lr_scheduler import LambdaLR
class LinearWarmupLRScheduler(LambdaLR):
    """Ramp the learning rate linearly from 1/warmup_steps of base up to base over ``warmup_steps`` steps."""
    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
    def lr_lambda(self, current_step):
        # multiplicative factor: grows linearly over the warmup window, then stays at 1
        if current_step >= self.warmup_steps:
            return 1.0
        return (current_step + 1) / self.warmup_steps