zcr
2026-03-17 11:29:17 +08:00
parent 6c79bdb20f
commit 24e4c120be
25 changed files with 3895 additions and 0 deletions

Dockerfile Normal file

@@ -0,0 +1,62 @@
# Use the devel image so nvcc and the CUDA Toolkit are available
FROM nvidia/cuda:12.8.0-devel-ubuntu22.04
# Install basic system dependencies (clean the apt cache to save space)
RUN apt-get update && apt-get install -y --no-install-recommends \
wget \
git \
build-essential \
cmake \
ninja-build \
libglib2.0-0 \
libgl1 \
python3.10 \
python3.10-dev \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
# Install Miniconda (the latest official release is recommended)
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \
bash /tmp/miniconda.sh -b -p /opt/conda && \
rm /tmp/miniconda.sh && \
/opt/conda/bin/conda clean --all -y
# Add conda to PATH
ENV PATH="/opt/conda/bin:${PATH}"
# Accept the Anaconda channel ToS (key fix)
RUN conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \
conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r
#
## Create the environment
#RUN conda create -n trellis python=3.10 -y && \
# echo "conda activate trellis" >> ~/.bashrc
#
#SHELL ["/bin/bash", "-c"]
#
## After activating the environment, install PyTorch 2.8 nightly + CUDA 12.8
#RUN conda activate trellis && \
# conda install pytorch torchvision torchaudio pytorch-cuda=12.8 -c pytorch-nightly -c nvidia -y
#
## Install pytorch3d (prefer the fvcore channel; if unavailable, build from the git source)
#RUN conda activate trellis && \
# conda install pytorch3d -c fvcore -c pytorch -c nvidia -y || \
# pip install --no-cache-dir "git+https://github.com/facebookresearch/pytorch3d.git@stable"
#
## If you have the environment packed as trellis.tar.gz, you can COPY and unpack it (optional)
## COPY trellis.tar.gz /opt/
## RUN mkdir /opt/env && tar -xzf /opt/trellis.tar.gz -C /opt/env && /opt/env/bin/conda-unpack
#
## Install kaolin (your original Dockerfile had this)
#RUN conda activate trellis && \
# pip install kaolin==0.18.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.8.0_cu128.html
# Set PATH and the working directory
ENV PATH="/opt/conda/envs/trellis/bin:${PATH}"
WORKDIR /workspace
# Copy your code (if needed)
COPY . /workspace
# Default command: keep the container running, or swap in your startup script
CMD ["tail", "-f", "/dev/null"]


@@ -0,0 +1,52 @@
import os
import copy
import sys
import importlib
import argparse
import pandas as pd
from easydict import EasyDict as edict
if __name__ == '__main__':
dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
dataset_utils.add_args(parser)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args(sys.argv[2:])
opt = edict(vars(opt))
os.makedirs(opt.output_dir, exist_ok=True)
# get file list
if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
raise ValueError('metadata.csv not found')
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
if opt.instances is None:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if 'local_path' in metadata.columns:
metadata = metadata[metadata['local_path'].isna()]
else:
if os.path.exists(opt.instances):
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
else:
instances = opt.instances.split(',')
metadata = metadata[metadata['sha256'].isin(instances)]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
print(f'Processing {len(metadata)} objects...')
# process objects
downloaded = dataset_utils.download(metadata, **opt)
downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False)
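The rank/world_size slice above shards the metadata across workers: each worker gets a contiguous slice, slices never overlap, and together they cover every row even when the row count is not divisible by world_size. A minimal sketch of the arithmetic (illustrative only, not part of the script):

def shard_bounds(n: int, rank: int, world_size: int):
    # same formula as the script; floor division keeps shards contiguous
    return n * rank // world_size, n * (rank + 1) // world_size

# e.g. 10 rows over 3 workers -> (0, 3), (3, 6), (6, 10)
assert [shard_bounds(10, r, 3) for r in range(3)] == [(0, 3), (3, 6), (6, 10)]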


@@ -0,0 +1,127 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import copy
import json
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis.models as models
import trellis.modules.sparse as sp
torch.set_grad_enabled(False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
help='Feature model')
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16',
help='Pretrained encoder model')
parser.add_argument('--model_root', type=str, default='results',
help='Root directory of models')
parser.add_argument('--enc_model', type=str, default=None,
help='Encoder model. If specified, use it instead of the pretrained model')
parser.add_argument('--ckpt', type=str, default=None,
help='Checkpoint to load')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
if opt.enc_model is None:
latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
else:
latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
encoder.eval()
print(f'Loaded model from {ckpt_path}')
os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)
# get file list
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
else:
raise ValueError('metadata.csv not found')
if opt.instances is not None:
with open(opt.instances, 'r') as f:
sha256s = [line.strip() for line in f]
metadata = metadata[metadata['sha256'].isin(sha256s)]
else:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
if f'latent_{latent_name}' in metadata.columns:
metadata = metadata[metadata[f'latent_{latent_name}'] == False]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256s = list(metadata['sha256'].values)
for sha256 in copy.copy(sha256s):
if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
records.append({'sha256': sha256, f'latent_{latent_name}': True})
sha256s.remove(sha256)
# encode latents
load_queue = Queue(maxsize=4)
try:
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
ThreadPoolExecutor(max_workers=32) as saver_executor:
def loader(sha256):
try:
feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
load_queue.put((sha256, feats))
except Exception as e:
print(f"Error loading features for {sha256}: {e}")
loader_executor.map(loader, sha256s)
def saver(sha256, pack):
save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, f'latent_{latent_name}': True})
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
sha256, feats = load_queue.get()
feats = sp.SparseTensor(
feats = torch.from_numpy(feats['patchtokens']).float(),
coords = torch.cat([
torch.zeros(feats['patchtokens'].shape[0], 1).int(),
torch.from_numpy(feats['indices']).int(),
], dim=1),
).cuda()
latent = encoder(feats, sample_posterior=False)
assert torch.isfinite(latent.feats).all(), "Non-finite latent"
pack = {
'feats': latent.feats.cpu().numpy().astype(np.float32),
'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
}
saver_executor.submit(saver, sha256, pack)
saver_executor.shutdown(wait=True)
except Exception as e:
print(f"Error happened during processing: {e}")
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)
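A latent pack written by the saver above can be reloaded with NumPy. A sketch assuming the directory layout and key names used above; the bracketed path segments are placeholders:

import numpy as np

pack = np.load('latents/<latent_name>/<sha256>.npz')  # placeholder path segments
feats = pack['feats']    # [M, C] float32 latent features
coords = pack['coords']  # [M, 3] uint8 voxel coordinates (batch index stripped)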


@@ -0,0 +1,128 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import copy
import json
import argparse
import torch
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import trellis.models as models
torch.set_grad_enabled(False)
def get_voxels(instance):
position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0]
coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous()
ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
return ss
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
help='Pretrained encoder model')
parser.add_argument('--model_root', type=str, default='results',
help='Root directory of models')
parser.add_argument('--enc_model', type=str, default=None,
help='Encoder model. If specified, use it instead of the pretrained model')
parser.add_argument('--ckpt', type=str, default=None,
help='Checkpoint to load')
parser.add_argument('--resolution', type=int, default=64,
help='Resolution')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
if opt.enc_model is None:
latent_name = f'{opt.enc_pretrained.split("/")[-1]}'
encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
else:
latent_name = f'{opt.enc_model}_{opt.ckpt}'
cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
encoder.load_state_dict(torch.load(ckpt_path), strict=False)
encoder.eval()
print(f'Loaded model from {ckpt_path}')
os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True)
# get file list
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
else:
raise ValueError('metadata.csv not found')
if opt.instances is not None:
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
metadata = metadata[metadata['sha256'].isin(instances)]
else:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
metadata = metadata[metadata['voxelized'] == True]
if f'ss_latent_{latent_name}' in metadata.columns:
metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256s = list(metadata['sha256'].values)
for sha256 in copy.copy(sha256s):
if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')):
records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
sha256s.remove(sha256)
# encode latents
load_queue = Queue(maxsize=4)
try:
with ThreadPoolExecutor(max_workers=32) as loader_executor, \
ThreadPoolExecutor(max_workers=32) as saver_executor:
def loader(sha256):
try:
ss = get_voxels(sha256)[None].float()
load_queue.put((sha256, ss))
except Exception as e:
print(f"Error loading features for {sha256}: {e}")
loader_executor.map(loader, sha256s)
def saver(sha256, pack):
save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
sha256, ss = load_queue.get()
ss = ss.cuda().float()
latent = encoder(ss, sample_posterior=False)
assert torch.isfinite(latent).all(), "Non-finite latent"
pack = {
'mean': latent[0].cpu().numpy(),
}
saver_executor.submit(saver, sha256, pack)
saver_executor.shutdown(wait=True)
except Exception as e:
print(f"Error happened during processing: {e}")
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), index=False)
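get_voxels above maps point positions in [-0.5, 0.5) to integer voxel indices in [0, resolution). A quick sanity sketch of that mapping (illustration only):

import torch

resolution = 64
position = torch.tensor([[-0.5, 0.0, 0.499]])
coords = ((position + 0.5) * resolution).int()  # truncation, as in get_voxels
assert coords.tolist() == [[0, 32, 63]]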


@@ -0,0 +1,179 @@
import os
import copy
import sys
import json
import importlib
import argparse
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import utils3d
from tqdm import tqdm
from easydict import EasyDict as edict
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from torchvision import transforms
from PIL import Image
torch.set_grad_enabled(False)
def get_data(frames, sha256):
with ThreadPoolExecutor(max_workers=16) as executor:
def worker(view):
image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path'])
try:
image = Image.open(image_path)
except Exception as e:
print(f"Error loading image {image_path}: {e}")
return None
image = image.resize((518, 518), Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255
image = image[:, :, :3] * image[:, :, 3:]
image = torch.from_numpy(image).permute(2, 0, 1).float()
c2w = torch.tensor(view['transform_matrix'])
c2w[:3, 1:3] *= -1
extrinsics = torch.inverse(c2w)
fov = view['camera_angle_x']
intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
return {
'image': image,
'extrinsics': extrinsics,
'intrinsics': intrinsics
}
datas = executor.map(worker, frames)
for data in datas:
if data is not None:
yield data
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', type=str, required=True,
help='Directory to save the metadata')
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
help='Filter objects with aesthetic score lower than this value')
parser.add_argument('--model', type=str, default='dinov2_vitl14_reg',
help='Feature extraction model')
parser.add_argument('--instances', type=str, default=None,
help='Instances to process')
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--rank', type=int, default=0)
parser.add_argument('--world_size', type=int, default=1)
opt = parser.parse_args()
opt = edict(vars(opt))
feature_name = opt.model
os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True)
# load model
dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model)
dinov2_model.eval().cuda()
transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
n_patch = 518 // 14
# get file list
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
else:
raise ValueError('metadata.csv not found')
if opt.instances is not None:
with open(opt.instances, 'r') as f:
instances = f.read().splitlines()
metadata = metadata[metadata['sha256'].isin(instances)]
else:
if opt.filter_low_aesthetic_score is not None:
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
if f'feature_{feature_name}' in metadata.columns:
metadata = metadata[metadata[f'feature_{feature_name}'] == False]
metadata = metadata[metadata['voxelized'] == True]
metadata = metadata[metadata['rendered'] == True]
start = len(metadata) * opt.rank // opt.world_size
end = len(metadata) * (opt.rank + 1) // opt.world_size
metadata = metadata[start:end]
records = []
# filter out objects that are already processed
sha256s = list(metadata['sha256'].values)
for sha256 in copy.copy(sha256s):
if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')):
records.append({'sha256': sha256, f'feature_{feature_name}' : True})
sha256s.remove(sha256)
# extract features
load_queue = Queue(maxsize=4)
try:
with ThreadPoolExecutor(max_workers=8) as loader_executor, \
ThreadPoolExecutor(max_workers=8) as saver_executor:
def loader(sha256):
try:
with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f:
metadata = json.load(f)
frames = metadata['frames']
data = []
for datum in get_data(frames, sha256):
datum['image'] = transform(datum['image'])
data.append(datum)
positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
load_queue.put((sha256, data, positions))
except Exception as e:
print(f"Error loading data for {sha256}: {e}")
loader_executor.map(loader, sha256s)
def saver(sha256, pack, patchtokens, uv):
pack['patchtokens'] = F.grid_sample(
patchtokens,
uv.unsqueeze(1),
mode='bilinear',
align_corners=False,
).squeeze(2).permute(0, 2, 1).cpu().numpy()
pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16)
save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')
np.savez_compressed(save_path, **pack)
records.append({'sha256': sha256, f'feature_{feature_name}' : True})
for _ in tqdm(range(len(sha256s)), desc="Extracting features"):
sha256, data, positions = load_queue.get()
positions = torch.from_numpy(positions).float().cuda()
indices = ((positions + 0.5) * 64).long()
assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds"
n_views = len(data)
N = positions.shape[0]
pack = {
'indices': indices.cpu().numpy().astype(np.uint8),
}
patchtokens_lst = []
uv_lst = []
for i in range(0, n_views, opt.batch_size):
batch_data = data[i:i+opt.batch_size]
bs = len(batch_data)
batch_images = torch.stack([d['image'] for d in batch_data]).cuda()
batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda()
batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda()
features = dinov2_model(batch_images, is_training=True)
uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1
patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch)
patchtokens_lst.append(patchtokens)
uv_lst.append(uv)
patchtokens = torch.cat(patchtokens_lst, dim=0)
uv = torch.cat(uv_lst, dim=0)
# save features
saver_executor.submit(saver, sha256, pack, patchtokens, uv)
saver_executor.shutdown(wait=True)
except Exception as e:
print(f"Error happened during processing: {e}")
records = pd.DataFrame.from_records(records)
records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False)
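The saver above samples per-view DINOv2 patch tokens at the projected voxel locations with F.grid_sample. A shape-level sketch of that lookup (dummy tensors; the channel count and patch grid are illustrative, the real run uses 1024 channels on a 37x37 grid):

import torch
import torch.nn.functional as F

patchtokens = torch.randn(2, 8, 37, 37)  # [V, C, h, w] patch-token maps for V views
uv = torch.rand(2, 5, 2) * 2 - 1         # [V, N, 2] projected points in [-1, 1]
out = F.grid_sample(patchtokens, uv.unsqueeze(1), mode='bilinear', align_corners=False)
assert out.shape == (2, 8, 1, 5)         # squeezed/permuted to [V, N, C] in the saver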

docker-compose.yaml Normal file

@@ -0,0 +1,29 @@
services:
trellis:
image: my-trellis-with-blender:20260317 # the image you committed
container_name: trellis-dev
restart: unless-stopped
environment:
- NVIDIA_VISIBLE_DEVICES=all
- NVIDIA_DRIVER_CAPABILITIES=compute,utility,video
volumes:
- .:/workspace # mount the current directory at /workspace for development
ports:
- "7412:8120"
working_dir: /workspace
tty: true
stdin_open: true
# Modern GPU configuration (replaces the legacy environment-variable declaration; more standard)
runtime: nvidia
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all # use all available GPUs (or pin a number, e.g. count: 1)
capabilities: [ gpu, compute, video ]
command: >
bash -c "
/opt/conda/envs/trellis/bin/python -c 'import torch; print(torch.__version__, torch.cuda.is_available())' &&
tail -f /dev/null
"

example.py Normal file

@@ -0,0 +1,57 @@
import os
# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'.
# 'auto' is faster but will do benchmarking at the beginning.
# Recommended to set to 'native' if run only once.
import imageio
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
# Load a pipeline from a model folder or a Hugging Face model hub.
pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
pipeline.cuda()
# Load an image
image = Image.open("assets/example_image/T.png")
# Run the pipeline
outputs = pipeline.run(
image,
seed=1,
# Optional parameters
# sparse_structure_sampler_params={
# "steps": 12,
# "cfg_strength": 7.5,
# },
# slat_sampler_params={
# "steps": 12,
# "cfg_strength": 3,
# },
)
# outputs is a dictionary containing generated 3D assets in different formats:
# - outputs['gaussian']: a list of 3D Gaussians
# - outputs['radiance_field']: a list of radiance fields
# - outputs['mesh']: a list of meshes
# Render the outputs
video = render_utils.render_video(outputs['gaussian'][0])['color']
imageio.mimsave("sample_gs.mp4", video, fps=30)
video = render_utils.render_video(outputs['radiance_field'][0])['color']
imageio.mimsave("sample_rf.mp4", video, fps=30)
video = render_utils.render_video(outputs['mesh'][0])['normal']
imageio.mimsave("sample_mesh.mp4", video, fps=30)
# GLB files can be extracted from the outputs
glb = postprocessing_utils.to_glb(
outputs['gaussian'][0],
outputs['mesh'][0],
# Optional parameters
simplify=0.95, # Ratio of triangles to remove in the simplification process
texture_size=1024, # Size of the texture used for the GLB
)
glb.export("sample.glb")
# Save Gaussians as PLY files
outputs['gaussian'][0].save_ply("sample.ply")

example_multi_image.py Normal file

@@ -0,0 +1,46 @@
import os
# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'.
# 'auto' is faster but will do benchmarking at the beginning.
# Recommended to set to 'native' if run only once.
import numpy as np
import imageio
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.utils import render_utils
# Load a pipeline from a model folder or a Hugging Face model hub.
pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
pipeline.cuda()
# Load an image
images = [
Image.open("assets/example_multi_image/character_1.png"),
Image.open("assets/example_multi_image/character_2.png"),
Image.open("assets/example_multi_image/character_3.png"),
]
# Run the pipeline
outputs = pipeline.run_multi_image(
images,
seed=1,
# Optional parameters
sparse_structure_sampler_params={
"steps": 12,
"cfg_strength": 7.5,
},
slat_sampler_params={
"steps": 12,
"cfg_strength": 3,
},
)
# outputs is a dictionary containing generated 3D assets in different formats:
# - outputs['gaussian']: a list of 3D Gaussians
# - outputs['radiance_field']: a list of radiance fields
# - outputs['mesh']: a list of meshes
video_gs = render_utils.render_video(outputs['gaussian'][0])['color']
video_mesh = render_utils.render_video(outputs['mesh'][0])['normal']
video = [np.concatenate([frame_gs, frame_mesh], axis=1) for frame_gs, frame_mesh in zip(video_gs, video_mesh)]
imageio.mimsave("sample_multi.mp4", video, fps=30)

example_text.py Normal file

@@ -0,0 +1,53 @@
import os
# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'.
# 'auto' is faster but will do benchmarking at the beginning.
# Recommended to set to 'native' if run only once.
import imageio
from trellis.pipelines import TrellisTextTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
# Load a pipeline from a model folder or a Hugging Face model hub.
pipeline = TrellisTextTo3DPipeline.from_pretrained("microsoft/TRELLIS-text-xlarge")
pipeline.cuda()
# Run the pipeline
outputs = pipeline.run(
"A chair looking like a avocado.",
seed=1,
# Optional parameters
# sparse_structure_sampler_params={
# "steps": 12,
# "cfg_strength": 7.5,
# },
# slat_sampler_params={
# "steps": 12,
# "cfg_strength": 7.5,
# },
)
# outputs is a dictionary containing generated 3D assets in different formats:
# - outputs['gaussian']: a list of 3D Gaussians
# - outputs['radiance_field']: a list of radiance fields
# - outputs['mesh']: a list of meshes
# Render the outputs
video = render_utils.render_video(outputs['gaussian'][0])['color']
imageio.mimsave("sample_gs.mp4", video, fps=30)
video = render_utils.render_video(outputs['radiance_field'][0])['color']
imageio.mimsave("sample_rf.mp4", video, fps=30)
video = render_utils.render_video(outputs['mesh'][0])['normal']
imageio.mimsave("sample_mesh.mp4", video, fps=30)
# GLB files can be extracted from the outputs
glb = postprocessing_utils.to_glb(
outputs['gaussian'][0],
outputs['mesh'][0],
# Optional parameters
simplify=0.95, # Ratio of triangles to remove in the simplification process
texture_size=1024, # Size of the texture used for the GLB
)
glb.export("sample.glb")
# Save Gaussians as PLY files
outputs['gaussian'][0].save_ply("sample.ply")

example_variant.py Normal file

@@ -0,0 +1,41 @@
import os
# os.environ['ATTN_BACKEND'] = 'xformers' # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
os.environ['SPCONV_ALGO'] = 'native' # Can be 'native' or 'auto', default is 'auto'.
# 'auto' is faster but will do benchmarking at the beginning.
# Recommended to set to 'native' if run only once.
import imageio
import numpy as np
import open3d as o3d
from trellis.pipelines import TrellisTextTo3DPipeline
from trellis.utils import render_utils, postprocessing_utils
# Load a pipeline from a model folder or a Hugging Face model hub.
pipeline = TrellisTextTo3DPipeline.from_pretrained("microsoft/TRELLIS-text-xlarge")
pipeline.cuda()
# Load mesh to make variants
base_mesh = o3d.io.read_triangle_mesh("assets/T.ply")
# Run the pipeline
outputs = pipeline.run_variant(
base_mesh,
"Rugged, metallic texture with orange and white paint finish, suggesting a durable, industrial feel.",
seed=1,
# Optional parameters
# slat_sampler_params={
# "steps": 12,
# "cfg_strength": 7.5,
# },
)
# outputs is a dictionary containing generated 3D assets in different formats:
# - outputs['gaussian']: a list of 3D Gaussians
# - outputs['radiance_field']: a list of radiance fields
# - outputs['mesh']: a list of meshes
# Render the outputs
video_gs = render_utils.render_video(outputs['gaussian'][0])['color']
video_mesh = render_utils.render_video(outputs['mesh'][0])['normal']
video = [np.concatenate([frame_gs, frame_mesh], axis=1) for frame_gs, frame_mesh in zip(video_gs, video_mesh)]
imageio.mimsave("sample_variant.mp4", video, fps=30)


@@ -0,0 +1,131 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from ...utils.random_utils import hammersley_sequence
from .base import SparseTransformerBase
from ...representations import Gaussian
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
class SLatGaussianDecoder(SparseTransformerBase):
def __init__(
self,
resolution: int,
model_channels: int,
latent_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
window_size: int = 8,
pe_mode: Literal["ape", "rope"] = "ape",
use_fp16: bool = False,
use_checkpoint: bool = False,
qk_rms_norm: bool = False,
representation_config: dict = None,
):
super().__init__(
in_channels=latent_channels,
model_channels=model_channels,
num_blocks=num_blocks,
num_heads=num_heads,
num_head_channels=num_head_channels,
mlp_ratio=mlp_ratio,
attn_mode=attn_mode,
window_size=window_size,
pe_mode=pe_mode,
use_fp16=use_fp16,
use_checkpoint=use_checkpoint,
qk_rms_norm=qk_rms_norm,
)
self.resolution = resolution
self.rep_config = representation_config
self._calc_layout()
self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
self._build_perturbation()
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
def initialize_weights(self) -> None:
super().initialize_weights()
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
def _build_perturbation(self) -> None:
perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])]
perturbation = torch.tensor(perturbation).float() * 2 - 1
perturbation = perturbation / self.rep_config['voxel_size']
perturbation = torch.atanh(perturbation).to(self.device)
self.register_buffer('offset_perturbation', perturbation)
def _calc_layout(self) -> None:
self.layout = {
'_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
'_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3},
'_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3},
'_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4},
'_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']},
}
start = 0
for k, v in self.layout.items():
v['range'] = (start, start + v['size'])
start += v['size']
self.out_channels = start
def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]:
"""
Convert a batch of network outputs to 3D representations.
Args:
x: The [N x * x C] sparse tensor output by the network.
Returns:
list of representations
"""
ret = []
for i in range(x.shape[0]):
representation = Gaussian(
sh_degree=0,
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
mininum_kernel_size = self.rep_config['3d_filter_kernel_size'],
scaling_bias = self.rep_config['scaling_bias'],
opacity_bias = self.rep_config['opacity_bias'],
scaling_activation = self.rep_config['scaling_activation']
)
xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
for k, v in self.layout.items():
if k == '_xyz':
offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])
offset = offset * self.rep_config['lr'][k]
if self.rep_config['perturb_offset']:
offset = offset + self.offset_perturbation
offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size']
_xyz = xyz.unsqueeze(1) + offset
setattr(representation, k, _xyz.flatten(0, 1))
else:
feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
feats = feats * self.rep_config['lr'][k]
setattr(representation, k, feats)
ret.append(representation)
return ret
def forward(self, x: sp.SparseTensor) -> List[Gaussian]:
h = super().forward(x)
h = h.type(x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h)
return self.to_representation(h)
class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder):
"""
SLat VAE Gaussian decoder with elastic memory management.
Used for training with low VRAM.
"""
pass
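For intuition, a standalone sketch of the channel layout _calc_layout computes (hypothetical num_gaussians=32): each Gaussian attribute occupies a contiguous range of the decoder's output channels, and out_channels is their total.

rep = {'num_gaussians': 32}
sizes = {
    '_xyz':         rep['num_gaussians'] * 3,
    '_features_dc': rep['num_gaussians'] * 3,
    '_scaling':     rep['num_gaussians'] * 3,
    '_rotation':    rep['num_gaussians'] * 4,
    '_opacity':     rep['num_gaussians'] * 1,
}
start = 0
for k, size in sizes.items():
    print(k, '->', (start, start + size))  # the 'range' stored per attribute
    start += size
assert start == 32 * 14  # out_channels = 448 in this hypothetical config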


@@ -0,0 +1,176 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ...representations import MeshExtractResult
from ...representations.mesh import SparseFeatures2Mesh
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
class SparseSubdivideBlock3d(nn.Module):
"""
A 3D subdivide block that can subdivide the sparse tensor.
Args:
channels: channels in the inputs and outputs.
out_channels: if specified, the number of output channels.
num_groups: the number of groups for the group norm.
"""
def __init__(
self,
channels: int,
resolution: int,
out_channels: Optional[int] = None,
num_groups: int = 32
):
super().__init__()
self.channels = channels
self.resolution = resolution
self.out_resolution = resolution * 2
self.out_channels = out_channels or channels
self.act_layers = nn.Sequential(
sp.SparseGroupNorm32(num_groups, channels),
sp.SparseSiLU()
)
self.sub = sp.SparseSubdivide()
self.out_layers = nn.Sequential(
sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"),
sp.SparseGroupNorm32(num_groups, self.out_channels),
sp.SparseSiLU(),
zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}")
def forward(self, x: sp.SparseTensor) -> sp.SparseTensor:
"""
Apply the block to a sparse tensor.
Args:
x: an [N x C x ...] Tensor of features.
Returns:
an [N x C x ...] Tensor of outputs.
"""
h = self.act_layers(x)
h = self.sub(h)
x = self.sub(x)
h = self.out_layers(h)
h = h + self.skip_connection(x)
return h
class SLatMeshDecoder(SparseTransformerBase):
def __init__(
self,
resolution: int,
model_channels: int,
latent_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
window_size: int = 8,
pe_mode: Literal["ape", "rope"] = "ape",
use_fp16: bool = False,
use_checkpoint: bool = False,
qk_rms_norm: bool = False,
representation_config: dict = None,
):
super().__init__(
in_channels=latent_channels,
model_channels=model_channels,
num_blocks=num_blocks,
num_heads=num_heads,
num_head_channels=num_head_channels,
mlp_ratio=mlp_ratio,
attn_mode=attn_mode,
window_size=window_size,
pe_mode=pe_mode,
use_fp16=use_fp16,
use_checkpoint=use_checkpoint,
qk_rms_norm=qk_rms_norm,
)
self.resolution = resolution
self.rep_config = representation_config
self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False))
self.out_channels = self.mesh_extractor.feats_channels
self.upsample = nn.ModuleList([
SparseSubdivideBlock3d(
channels=model_channels,
resolution=resolution,
out_channels=model_channels // 4
),
SparseSubdivideBlock3d(
channels=model_channels // 4,
resolution=resolution * 2,
out_channels=model_channels // 8
)
])
self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels)
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
def initialize_weights(self) -> None:
super().initialize_weights()
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
def convert_to_fp16(self) -> None:
"""
Convert the torso of the model to float16.
"""
super().convert_to_fp16()
self.upsample.apply(convert_module_to_f16)
def convert_to_fp32(self) -> None:
"""
Convert the torso of the model to float32.
"""
super().convert_to_fp32()
self.upsample.apply(convert_module_to_f32)
def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
"""
Convert a batch of network outputs to 3D representations.
Args:
x: The [N x * x C] sparse tensor output by the network.
Returns:
list of representations
"""
ret = []
for i in range(x.shape[0]):
mesh = self.mesh_extractor(x[i], training=self.training)
ret.append(mesh)
return ret
def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]:
h = super().forward(x)
for block in self.upsample:
h = block(h)
h = h.type(x.dtype)
h = self.out_layer(h)
return self.to_representation(h)
class ElasticSLatMeshDecoder(SparseTransformerElasticMixin, SLatMeshDecoder):
"""
SLat VAE Mesh decoder with elastic memory management.
Used for training with low VRAM.
"""
pass
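Each SparseSubdivideBlock3d doubles the grid resolution, so the two upsample stages take the latent grid from resolution to resolution * 4, matching the res=self.resolution*4 passed to SparseFeatures2Mesh. A one-line sanity sketch:

res = 64
for _ in range(2):  # two SparseSubdivideBlock3d stages, one doubling each
    res *= 2
assert res == 64 * 4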


@@ -0,0 +1,113 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ...representations import Strivec
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
class SLatRadianceFieldDecoder(SparseTransformerBase):
def __init__(
self,
resolution: int,
model_channels: int,
latent_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
window_size: int = 8,
pe_mode: Literal["ape", "rope"] = "ape",
use_fp16: bool = False,
use_checkpoint: bool = False,
qk_rms_norm: bool = False,
representation_config: dict = None,
):
super().__init__(
in_channels=latent_channels,
model_channels=model_channels,
num_blocks=num_blocks,
num_heads=num_heads,
num_head_channels=num_head_channels,
mlp_ratio=mlp_ratio,
attn_mode=attn_mode,
window_size=window_size,
pe_mode=pe_mode,
use_fp16=use_fp16,
use_checkpoint=use_checkpoint,
qk_rms_norm=qk_rms_norm,
)
self.resolution = resolution
self.rep_config = representation_config
self._calc_layout()
self.out_layer = sp.SparseLinear(model_channels, self.out_channels)
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
def initialize_weights(self) -> None:
super().initialize_weights()
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
def _calc_layout(self) -> None:
self.layout = {
'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']},
'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']},
'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3},
}
start = 0
for k, v in self.layout.items():
v['range'] = (start, start + v['size'])
start += v['size']
self.out_channels = start
def to_representation(self, x: sp.SparseTensor) -> List[Strivec]:
"""
Convert a batch of network outputs to 3D representations.
Args:
x: The [N x * x C] sparse tensor output by the network.
Returns:
list of representations
"""
ret = []
for i in range(x.shape[0]):
representation = Strivec(
sh_degree=0,
resolution=self.resolution,
aabb=[-0.5, -0.5, -0.5, 1, 1, 1],
rank=self.rep_config['rank'],
dim=self.rep_config['dim'],
device='cuda',
)
representation.density_shift = 0.0
representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution
representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda')
for k, v in self.layout.items():
setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']))
representation.trivec = representation.trivec + 1
ret.append(representation)
return ret
def forward(self, x: sp.SparseTensor) -> List[Strivec]:
h = super().forward(x)
h = h.type(x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h)
return self.to_representation(h)
class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder):
"""
SLat VAE Radiance Field Decoder with elastic memory management.
Used for training with low VRAM.
"""
pass


@@ -0,0 +1,80 @@
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modules import sparse as sp
from .base import SparseTransformerBase
from ..sparse_elastic_mixin import SparseTransformerElasticMixin
class SLatEncoder(SparseTransformerBase):
def __init__(
self,
resolution: int,
in_channels: int,
model_channels: int,
latent_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin",
window_size: int = 8,
pe_mode: Literal["ape", "rope"] = "ape",
use_fp16: bool = False,
use_checkpoint: bool = False,
qk_rms_norm: bool = False,
):
super().__init__(
in_channels=in_channels,
model_channels=model_channels,
num_blocks=num_blocks,
num_heads=num_heads,
num_head_channels=num_head_channels,
mlp_ratio=mlp_ratio,
attn_mode=attn_mode,
window_size=window_size,
pe_mode=pe_mode,
use_fp16=use_fp16,
use_checkpoint=use_checkpoint,
qk_rms_norm=qk_rms_norm,
)
self.resolution = resolution
self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels)
self.initialize_weights()
if use_fp16:
self.convert_to_fp16()
def initialize_weights(self) -> None:
super().initialize_weights()
# Zero-out output layers:
nn.init.constant_(self.out_layer.weight, 0)
nn.init.constant_(self.out_layer.bias, 0)
def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False):
h = super().forward(x)
h = h.type(x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h)
# Sample from the posterior distribution
mean, logvar = h.feats.chunk(2, dim=-1)
if sample_posterior:
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)
else:
z = mean
z = h.replace(z)
if return_raw:
return z, mean, logvar
else:
return z
class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder):
"""
SLat VAE encoder with elastic memory management.
Used for training with low VRAM.
"""


@@ -0,0 +1,140 @@
from typing import *
import torch
import math
from . import DEBUG, BACKEND
if BACKEND == 'xformers':
import xformers.ops as xops
elif BACKEND == 'flash_attn':
import flash_attn
elif BACKEND == 'sdpa':
from torch.nn.functional import scaled_dot_product_attention as sdpa
elif BACKEND == 'naive':
pass
else:
raise ValueError(f"Unknown attention backend: {BACKEND}")
__all__ = [
'scaled_dot_product_attention',
]
def _naive_sdpa(q, k, v):
"""
Naive implementation of scaled dot product attention.
"""
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
scale_factor = 1 / math.sqrt(q.size(-1))
attn_weight = q @ k.transpose(-2, -1) * scale_factor
attn_weight = torch.softmax(attn_weight, dim=-1)
out = attn_weight @ v
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
return out
@overload
def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor:
"""
Apply scaled dot product attention.
Args:
qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs.
"""
...
@overload
def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
"""
Apply scaled dot product attention.
Args:
q (torch.Tensor): A [N, L, H, C] tensor containing Qs.
kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs.
"""
...
@overload
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Apply scaled dot product attention.
Args:
q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs.
k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks.
v (torch.Tensor): A [N, L, H, Co] tensor containing Vs.
Note:
k and v are assumed to have the same coordinate map.
"""
...
def scaled_dot_product_attention(*args, **kwargs):
arg_names_dict = {
1: ['qkv'],
2: ['q', 'kv'],
3: ['q', 'k', 'v']
}
num_all_args = len(args) + len(kwargs)
assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
for key in arg_names_dict[num_all_args][len(args):]:
assert key in kwargs, f"Missing argument {key}"
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs['qkv']
assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]"
device = qkv.device
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs['q']
kv = args[1] if len(args) > 1 else kwargs['kv']
assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
device = q.device
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs['q']
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
device = q.device
if BACKEND == 'xformers':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
out = xops.memory_efficient_attention(q, k, v)
elif BACKEND == 'flash_attn':
if num_all_args == 1:
out = flash_attn.flash_attn_qkvpacked_func(qkv)
elif num_all_args == 2:
out = flash_attn.flash_attn_kvpacked_func(q, kv)
elif num_all_args == 3:
out = flash_attn.flash_attn_func(q, k, v)
elif BACKEND == 'sdpa':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
out = sdpa(q, k, v) # [N, H, L, C]
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
elif BACKEND == 'naive':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
out = _naive_sdpa(q, k, v)
else:
raise ValueError(f"Unknown attention module: {BACKEND}")
return out
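A hedged usage sketch of the dispatcher above (shapes follow the overload docstrings; with BACKEND set to 'naive' or 'sdpa' this runs on CPU, while the flash_attn and xformers paths need CUDA tensors):

import torch

q = torch.randn(2, 16, 4, 32)  # [N, L, H, C]
k = torch.randn(2, 16, 4, 32)
v = torch.randn(2, 16, 4, 32)
out = scaled_dot_product_attention(q, k, v)
assert out.shape == (2, 16, 4, 32)  # [N, L, H, C]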


@@ -0,0 +1,215 @@
from typing import *
import torch
from .. import SparseTensor
from .. import DEBUG, ATTN
if ATTN == 'xformers':
import xformers.ops as xops
elif ATTN == 'flash_attn':
import flash_attn
else:
raise ValueError(f"Unknown attention module: {ATTN}")
__all__ = [
'sparse_scaled_dot_product_attention',
]
@overload
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (SparseTensor): A [N, L, H, C] dense tensor containing Qs.
kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
Note:
k and v are assumed to have the same coordinate map.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
"""
...
@overload
def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
"""
Apply scaled dot product attention to a sparse tensor.
Args:
q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
"""
...
def sparse_scaled_dot_product_attention(*args, **kwargs):
arg_names_dict = {
1: ['qkv'],
2: ['q', 'kv'],
3: ['q', 'k', 'v']
}
num_all_args = len(args) + len(kwargs)
assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3"
for key in arg_names_dict[num_all_args][len(args):]:
assert key in kwargs, f"Missing argument {key}"
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs['qkv']
assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}"
assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"
device = qkv.device
s = qkv
q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
kv_seqlen = q_seqlen
qkv = qkv.feats # [T, 3, H, C]
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs['q']
kv = args[1] if len(args) > 1 else kwargs['kv']
assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \
isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \
f"Invalid types, got {type(q)} and {type(kv)}"
assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}"
device = q.device
if isinstance(q, SparseTensor):
assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]"
s = q
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
q = q.feats # [T_Q, H, C]
else:
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]"
s = None
N, L, H, C = q.shape
q_seqlen = [L] * N
q = q.reshape(N * L, H, C) # [T_Q, H, C]
if isinstance(kv, SparseTensor):
assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]"
kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
kv = kv.feats # [T_KV, 2, H, C]
else:
assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]"
N, L, _, H, C = kv.shape
kv_seqlen = [L] * N
kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs['q']
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \
isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \
f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}"
assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}"
device = q.device
if isinstance(q, SparseTensor):
assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]"
s = q
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
q = q.feats # [T_Q, H, Ci]
else:
assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]"
s = None
N, L, H, CI = q.shape
q_seqlen = [L] * N
q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
if isinstance(k, SparseTensor):
assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]"
assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]"
kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
k = k.feats # [T_KV, H, Ci]
v = v.feats # [T_KV, H, Co]
else:
assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]"
assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]"
N, L, H, CI, CO = *k.shape, v.shape[-1]
kv_seqlen = [L] * N
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
if DEBUG:
if s is not None:
for i in range(s.shape[0]):
assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch"
if num_all_args in [2, 3]:
assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch"
if num_all_args == 3:
assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch"
assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch"
if ATTN == 'xformers':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1)
elif num_all_args == 2:
k, v = kv.unbind(dim=1)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
out = xops.memory_efficient_attention(q, k, v, mask)[0]
elif ATTN == 'flash_attn':
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args in [2, 3]:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
if num_all_args == 1:
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
elif num_all_args == 2:
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif num_all_args == 3:
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
else:
raise ValueError(f"Unknown attention module: {ATTN}")
if s is not None:
return s.replace(out)
else:
return out.reshape(N, L, H, -1)
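A small sketch of the cu_seqlens construction used in the flash_attn branch above: cumulative offsets delimiting each batch element's tokens in the packed [T, H, C] layout.

import torch

q_seqlen = [3, 5, 2]  # tokens per batch element, read off the sparse layout
cu_seqlens = torch.cat([torch.tensor([0]),
                        torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int()
assert cu_seqlens.tolist() == [0, 3, 8, 10]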


@@ -0,0 +1,201 @@
from typing import *
import torch
import numpy as np
from tqdm import tqdm
from easydict import EasyDict as edict
from .base import Sampler
from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin
from .guidance_interval_mixin import GuidanceIntervalSamplerMixin
class FlowEulerSampler(Sampler):
"""
Generate samples from a flow-matching model using Euler sampling.
Args:
sigma_min: The minimum scale of noise in flow.
"""
def __init__(
self,
sigma_min: float,
):
self.sigma_min = sigma_min
def _eps_to_xstart(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t)
def _xstart_to_eps(self, x_t, t, x_0):
assert x_t.shape == x_0.shape
return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t)
def _v_to_xstart_eps(self, x_t, t, v):
assert x_t.shape == v.shape
eps = (1 - t) * v + x_t
x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v
return x_0, eps
def _inference_model(self, model, x_t, t, cond=None, **kwargs):
t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32)
if cond is not None and cond.shape[0] == 1 and x_t.shape[0] > 1:
cond = cond.repeat(x_t.shape[0], *([1] * (len(cond.shape) - 1)))
return model(x_t, t, cond, **kwargs)
def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs):
pred_v = self._inference_model(model, x_t, t, cond, **kwargs)
pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v)
return pred_x_0, pred_eps, pred_v
@torch.no_grad()
def sample_once(
self,
model,
x_t,
t: float,
t_prev: float,
cond: Optional[Any] = None,
**kwargs
):
"""
Sample x_{t-1} from the model using Euler method.
Args:
model: The model to sample from.
x_t: The [N x C x ...] tensor of noisy inputs at time t.
t: The current timestep.
t_prev: The previous timestep.
cond: conditional information.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'pred_x_prev': x_{t-1}.
- 'pred_x_0': a prediction of x_0.
"""
pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs)
pred_x_prev = x_t - (t - t_prev) * pred_v
return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0})
@torch.no_grad()
def sample(
self,
model,
noise,
cond: Optional[Any] = None,
steps: int = 50,
rescale_t: float = 1.0,
verbose: bool = True,
**kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model_inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of predictions of x_t.
- 'pred_x_0': a list of predictions of x_0.
"""
sample = noise
t_seq = np.linspace(1, 0, steps + 1)
t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps))
ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []})
for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose):
out = self.sample_once(model, sample, t, t_prev, cond, **kwargs)
sample = out.pred_x_prev
ret.pred_x_t.append(out.pred_x_prev)
ret.pred_x_0.append(out.pred_x_0)
ret.samples = sample
return ret
class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler):
"""
Generate samples from a flow-matching model using Euler sampling with classifier-free guidance.
"""
@torch.no_grad()
def sample(
self,
model,
noise,
cond,
neg_cond,
steps: int = 50,
rescale_t: float = 1.0,
cfg_strength: float = 3.0,
verbose: bool = True,
**kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
neg_cond: negative conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
cfg_strength: The strength of classifier-free guidance.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of predictions of x_t.
- 'pred_x_0': a list of predictions of x_0.
"""
return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs)
class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler):
"""
Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval.
"""
@torch.no_grad()
def sample(
self,
model,
noise,
cond,
neg_cond,
steps: int = 50,
rescale_t: float = 1.0,
cfg_strength: float = 3.0,
cfg_interval: Tuple[float, float] = (0.0, 1.0),
verbose: bool = True,
**kwargs
):
"""
Generate samples from the model using Euler method.
Args:
model: The model to sample from.
noise: The initial noise tensor.
cond: conditional information.
neg_cond: negative conditional information.
steps: The number of steps to sample.
rescale_t: The rescale factor for t.
cfg_strength: The strength of classifier-free guidance.
cfg_interval: The interval for classifier-free guidance.
verbose: If True, show a progress bar.
**kwargs: Additional arguments for model inference.
Returns:
a dict containing the following
- 'samples': the model samples.
- 'pred_x_t': a list of predictions of x_t.
- 'pred_x_0': a list of predictions of x_0.
"""
return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs)
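A quick smoke test for the sampler above (the toy velocity model is an assumption for illustration; any callable with the model(x_t, t, cond) signature works):
# Toy check: with v = -x_t, the Euler update x_{t_prev} = x_t - (t - t_prev) * v
# shrinks the sample smoothly as t runs from 1 to 0.
import torch

class ToyModel(torch.nn.Module):
    def forward(self, x_t, t, cond):
        return -x_t                       # stand-in velocity field

sampler = FlowEulerSampler(sigma_min=1e-5)
out = sampler.sample(ToyModel(), noise=torch.randn(2, 4), steps=10, verbose=False)
print(out.samples.shape)                  # torch.Size([2, 4])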

View File

@@ -0,0 +1,209 @@
import torch
import numpy as np
from plyfile import PlyData, PlyElement
from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation
import utils3d
class Gaussian:
def __init__(
self,
aabb : list,
sh_degree : int = 0,
mininum_kernel_size : float = 0.0,
scaling_bias : float = 0.01,
opacity_bias : float = 0.1,
scaling_activation : str = "exp",
device='cuda'
):
self.init_params = {
'aabb': aabb,
'sh_degree': sh_degree,
'mininum_kernel_size': mininum_kernel_size,
'scaling_bias': scaling_bias,
'opacity_bias': opacity_bias,
'scaling_activation': scaling_activation,
}
self.sh_degree = sh_degree
self.active_sh_degree = sh_degree
self.mininum_kernel_size = mininum_kernel_size
self.scaling_bias = scaling_bias
self.opacity_bias = opacity_bias
self.scaling_activation_type = scaling_activation
self.device = device
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
self.setup_functions()
self._xyz = None
self._features_dc = None
self._features_rest = None
self._scaling = None
self._rotation = None
self._opacity = None
def setup_functions(self):
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
actual_covariance = L @ L.transpose(1, 2)
symm = strip_symmetric(actual_covariance)
return symm
if self.scaling_activation_type == "exp":
self.scaling_activation = torch.exp
self.inverse_scaling_activation = torch.log
elif self.scaling_activation_type == "softplus":
self.scaling_activation = torch.nn.functional.softplus
self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
self.covariance_activation = build_covariance_from_scaling_rotation
self.opacity_activation = torch.sigmoid
self.inverse_opacity_activation = inverse_sigmoid
self.rotation_activation = torch.nn.functional.normalize
self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
self.rots_bias = torch.zeros((4), device=self.device)
self.rots_bias[0] = 1
self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
@property
def get_scaling(self):
scales = self.scaling_activation(self._scaling + self.scale_bias)
scales = torch.square(scales) + self.mininum_kernel_size ** 2
scales = torch.sqrt(scales)
return scales
@property
def get_rotation(self):
return self.rotation_activation(self._rotation + self.rots_bias[None, :])
@property
def get_xyz(self):
return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3]
@property
def get_features(self):
return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc
@property
def get_opacity(self):
return self.opacity_activation(self._opacity + self.opacity_bias)
def get_covariance(self, scaling_modifier = 1):
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :])
def from_scaling(self, scales):
scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)
self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias
def from_rotation(self, rots):
self._rotation = rots - self.rots_bias[None, :]
def from_xyz(self, xyz):
self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
def from_features(self, features):
self._features_dc = features
def from_opacity(self, opacities):
self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
def construct_list_of_attributes(self):
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
# All channels except the 3 DC
for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
l.append('f_dc_{}'.format(i))
l.append('opacity')
for i in range(self._scaling.shape[1]):
l.append('scale_{}'.format(i))
for i in range(self._rotation.shape[1]):
l.append('rot_{}'.format(i))
return l
def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
xyz = self.get_xyz.detach().cpu().numpy()
normals = np.zeros_like(xyz)
f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
scale = torch.log(self.get_scaling).detach().cpu().numpy()
rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
if transform is not None:
transform = np.array(transform)
xyz = np.matmul(xyz, transform.T)
rotation = utils3d.numpy.quaternion_to_matrix(rotation)
rotation = np.matmul(transform, rotation)
rotation = utils3d.numpy.matrix_to_quaternion(rotation)
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
elements = np.empty(xyz.shape[0], dtype=dtype_full)
attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
elements[:] = list(map(tuple, attributes))
el = PlyElement.describe(elements, 'vertex')
PlyData([el]).write(path)
def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
plydata = PlyData.read(path)
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
np.asarray(plydata.elements[0]["y"]),
np.asarray(plydata.elements[0]["z"])), axis=1)
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
features_dc = np.zeros((xyz.shape[0], 3, 1))
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
if self.sh_degree > 0:
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
for idx, attr_name in enumerate(extra_f_names):
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.sh_degree + 1) ** 2 - 1))
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
scales = np.zeros((xyz.shape[0], len(scale_names)))
for idx, attr_name in enumerate(scale_names):
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
rots = np.zeros((xyz.shape[0], len(rot_names)))
for idx, attr_name in enumerate(rot_names):
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
if transform is not None:
transform = np.array(transform)
xyz = np.matmul(xyz, transform)
rots = utils3d.numpy.quaternion_to_matrix(rots)
rots = np.matmul(rots, transform)
rots = utils3d.numpy.matrix_to_quaternion(rots)
# convert to actual gaussian attributes
xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
if self.sh_degree > 0:
features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device))
scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
rots = torch.tensor(rots, dtype=torch.float, device=self.device)
# convert to _hidden attributes
self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
self._features_dc = features_dc
if self.sh_degree > 0:
self._features_rest = features_extra
else:
self._features_rest = None
self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias
self._rotation = rots - self.rots_bias[None, :]
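A small population-and-export sketch for the class above (counts, values, and the output path are illustrative; aabb is [origin_xyz, size_xyz] per the get_xyz de-normalization):
import torch

g = Gaussian(aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], device='cuda')
n = 100
g.from_xyz(torch.rand(n, 3, device='cuda') - 0.5)          # world-space centers
g.from_features(torch.rand(n, 1, 3, device='cuda'))        # DC color features
g.from_opacity(torch.full((n, 1), 0.9, device='cuda'))
g.from_scaling(torch.full((n, 3), 0.01, device='cuda'))
g.from_rotation(torch.tensor([[1.0, 0.0, 0.0, 0.0]], device='cuda').repeat(n, 1))
g.save_ply('gaussians.ply')                                # illustrative output path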

View File

@@ -0,0 +1,48 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
from zipfile import ZipFile
from tqdm import tqdm
import os
def download_file(url, output_path):
response = requests.get(url, stream=True)
response.raise_for_status()
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024  # 1 KiB
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(output_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
raise Exception("ERROR, something went wrong")
url = "https://vcg.isti.cnr.it/Publications/2014/MPZ14/inputmodels.zip"
zip_file_path = './data/inputmodels.zip'
os.makedirs('./data', exist_ok=True)
download_file(url, zip_file_path)
with ZipFile(zip_file_path, 'r') as zip_ref:
zip_ref.extractall('./data')
os.remove(zip_file_path)
print("Download and extraction complete.")

View File

@@ -0,0 +1,157 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import numpy as np
import torch
import nvdiffrast.torch as dr
import trimesh
import os
from util import *
import render
import loss
import imageio
import sys
sys.path.append('..')
from flexicubes import FlexiCubes
###############################################################################
# Functions adapted from https://github.com/NVlabs/nvdiffrec
###############################################################################
def lr_schedule(iter):
return max(0.0, 10**(-(iter)*0.0002))  # Exponential falloff from 1.0 to 0.1 over 5k iterations.
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='flexicubes optimization')
parser.add_argument('-o', '--out_dir', type=str, default=None)
parser.add_argument('-rm', '--ref_mesh', type=str)
parser.add_argument('-i', '--iter', type=int, default=1000)
parser.add_argument('-b', '--batch', type=int, default=8)
parser.add_argument('-r', '--train_res', nargs=2, type=int, default=[2048, 2048])
parser.add_argument('-lr', '--learning_rate', type=float, default=0.01)
parser.add_argument('--voxel_grid_res', type=int, default=64)
# argparse's type=bool treats any non-empty string as True; parse the flag explicitly
parser.add_argument('--sdf_loss', type=lambda x: x.lower() == 'true', default=True)
parser.add_argument('--develop_reg', type=lambda x: x.lower() == 'true', default=False)
parser.add_argument('--sdf_regularizer', type=float, default=0.2)
parser.add_argument('-dr', '--display_res', nargs=2, type=int, default=[512, 512])
parser.add_argument('-si', '--save_interval', type=int, default=20)
FLAGS = parser.parse_args()
device = 'cuda'
os.makedirs(FLAGS.out_dir, exist_ok=True)
glctx = dr.RasterizeGLContext()
# Load GT mesh
gt_mesh = load_mesh(FLAGS.ref_mesh, device)
gt_mesh.auto_normals() # compute face normals for visualization
# ==============================================================================================
# Create and initialize FlexiCubes
# ==============================================================================================
fc = FlexiCubes(device)
x_nx3, cube_fx8 = fc.construct_voxel_grid(FLAGS.voxel_grid_res)
x_nx3 *= 2 # scale up the grid so that it's larger than the target object
sdf = torch.rand_like(x_nx3[:,0]) - 0.1 # randomly init SDF
sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
# set per-cube learnable weights to zeros
weight = torch.zeros((cube_fx8.shape[0], 21), dtype=torch.float, device='cuda')
weight = torch.nn.Parameter(weight.clone().detach(), requires_grad=True)
deform = torch.nn.Parameter(torch.zeros_like(x_nx3), requires_grad=True)
# Retrieve all the edges of the voxel grid; these edges will be utilized to
# compute the regularization loss in subsequent steps of the process.
all_edges = cube_fx8[:, fc.cube_edges].reshape(-1, 2)
grid_edges = torch.unique(all_edges, dim=0)
# ==============================================================================================
# Setup optimizer
# ==============================================================================================
optimizer = torch.optim.Adam([sdf, weight, deform], lr=FLAGS.learning_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x))
# ==============================================================================================
# Train loop
# ==============================================================================================
for it in range(FLAGS.iter):
optimizer.zero_grad()
# sample random camera poses
mv, mvp = render.get_random_camera_batch(FLAGS.batch, iter_res=FLAGS.train_res, device=device, use_kaolin=False)
# render gt mesh
target = render.render_mesh_paper(gt_mesh, mv, mvp, FLAGS.train_res)
# extract and render FlexiCubes mesh
grid_verts = x_nx3 + (2-1e-8) / (FLAGS.voxel_grid_res * 2) * torch.tanh(deform)
vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:,:12], alpha_fx8=weight[:,12:20],
gamma_f=weight[:,20], training=True)
flexicubes_mesh = Mesh(vertices, faces)
buffers = render.render_mesh_paper(flexicubes_mesh, mv, mvp, FLAGS.train_res)
# evaluate reconstruction loss
mask_loss = (buffers['mask'] - target['mask']).abs().mean()
depth_loss = ((((buffers['depth'] - target['depth']) * target['mask'])**2).sum(-1) + 1e-8).sqrt().mean() * 10
t_iter = it / FLAGS.iter
sdf_weight = FLAGS.sdf_regularizer - (FLAGS.sdf_regularizer - FLAGS.sdf_regularizer/20)*min(1.0, 4.0 * t_iter)
reg_loss = loss.sdf_reg_loss(sdf, grid_edges).mean() * sdf_weight # Loss to eliminate internal floaters that are not visible
reg_loss += L_dev.mean() * 0.5
reg_loss += (weight[:,:20]).abs().mean() * 0.1
total_loss = mask_loss + depth_loss + reg_loss
if FLAGS.sdf_loss: # optionally add SDF loss to eliminate internal structures
with torch.no_grad():
pts = sample_random_points(1000, gt_mesh)
gt_sdf = compute_sdf(pts, gt_mesh.vertices, gt_mesh.faces)
pred_sdf = compute_sdf(pts, flexicubes_mesh.vertices, flexicubes_mesh.faces)
total_loss += torch.nn.functional.mse_loss(pred_sdf, gt_sdf) * 2e3
# optionally add developability regularizer, as described in paper section 5.2
if FLAGS.develop_reg:
reg_weight = max(0, t_iter - 0.8) * 5
if reg_weight > 0: # only applied after shape converges
reg_loss = loss.mesh_developable_reg(flexicubes_mesh).mean() * 10
reg_loss += (deform).abs().mean()
reg_loss += (weight[:,:20]).abs().mean()
total_loss = mask_loss + depth_loss + reg_loss
total_loss.backward()
optimizer.step()
scheduler.step()
if (it % FLAGS.save_interval == 0 or it == (FLAGS.iter-1)): # save normal image for visualization
with torch.no_grad():
# extract mesh with training=False
vertices, faces, L_dev = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta_fx12=weight[:,:12], alpha_fx8=weight[:,12:20],
gamma_f=weight[:,20], training=False)
flexicubes_mesh = Mesh(vertices, faces)
flexicubes_mesh.auto_normals() # compute face normals for visualization
mv, mvp = render.get_rotate_camera(it//FLAGS.save_interval, iter_res=FLAGS.display_res, device=device, use_kaolin=False)
val_buffers = render.render_mesh_paper(flexicubes_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
val_image = ((val_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
gt_buffers = render.render_mesh_paper(gt_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
gt_image = ((gt_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
imageio.imwrite(os.path.join(FLAGS.out_dir, '{:04d}.png'.format(it)), np.concatenate([val_image, gt_image], 1))
print(f"Optimization Step [{it}/{FLAGS.iter}], Loss: {total_loss.item():.4f}")
# ==============================================================================================
# Save output
# ==============================================================================================
mesh_np = trimesh.Trimesh(vertices = vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy(), process=False)
mesh_np.export(os.path.join(FLAGS.out_dir, 'output_mesh.obj'))
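As a quick numeric check of the lr_schedule defined at the top of this script (values computed directly from the formula itself, nothing assumed):
# lr multiplier 10 ** (-it * 2e-4): 1.0 at it=0, ~0.316 at it=2500, 0.1 at it=5000
for it in (0, 2500, 5000):
    print(it, max(0.0, 10 ** (-it * 0.0002)))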

View File

@@ -0,0 +1,390 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .tables import *
from kaolin.utils.testing import check_tensor
__all__ = [
'FlexiCubes'
]
class FlexiCubes:
def __init__(self, device="cuda"):
self.device = device
self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
self.num_vd_table = torch.tensor(num_vd_table,
dtype=torch.long, device=device, requires_grad=False)
self.check_table = torch.tensor(
check_table,
dtype=torch.long, device=device, requires_grad=False)
self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
self.quad_split_train = torch.tensor(
[0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
dtype=torch.long, device=device)
self.dir_faces_table = torch.tensor([
[[5, 4], [3, 2], [4, 5], [2, 3]],
[[5, 4], [1, 0], [4, 5], [0, 1]],
[[3, 2], [1, 0], [2, 3], [0, 1]]
], dtype=torch.long, device=device)
self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3,
weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False):
assert torch.is_tensor(voxelgrid_vertices) and \
check_tensor(voxelgrid_vertices, (None, 3), throw=False), \
"'voxelgrid_vertices' should be a tensor of shape (num_vertices, 3)"
num_vertices = voxelgrid_vertices.shape[0]
assert torch.is_tensor(scalar_field) and \
check_tensor(scalar_field, (num_vertices,), throw=False), \
"'scalar_field' should be a tensor of shape (num_vertices,)"
assert torch.is_tensor(cube_idx) and \
check_tensor(cube_idx, (None, 8), throw=False), \
"'cube_idx' should be a tensor of shape (num_cubes, 8)"
num_cubes = cube_idx.shape[0]
assert beta is None or (
torch.is_tensor(beta) and
check_tensor(beta, (num_cubes, 12), throw=False)
), "'beta' should be a tensor of shape (num_cubes, 12)"
assert alpha is None or (
torch.is_tensor(alpha) and
check_tensor(alpha, (num_cubes, 8), throw=False)
), "'alpha' should be a tensor of shape (num_cubes, 8)"
assert gamma_f is None or (
torch.is_tensor(gamma_f) and
check_tensor(gamma_f, (num_cubes,), throw=False)
), "'gamma_f' should be a tensor of shape (num_cubes,)"
surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx)
if surf_cubes.sum() == 0:
return (
torch.zeros((0, 3), device=self.device),
torch.zeros((0, 3), dtype=torch.long, device=self.device),
torch.zeros((0), device=self.device),
torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None
)
beta, alpha, gamma_f = self._normalize_weights(
beta, alpha, gamma_f, surf_cubes, weight_scale)
if voxelgrid_colors is not None:
voxelgrid_colors = torch.sigmoid(voxelgrid_colors)
case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution)
surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
scalar_field, cube_idx, surf_cubes
)
vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd(
voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field,
case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors)
vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate(
scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map,
vd_idx_map, surf_edges_mask, training, vd_color)
return vertices, faces, L_dev, vertices_color
def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
"""
Regularizer L_dev as in Equation 8
"""
dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
mean_l2 = torch.zeros_like(vd[:, 0])
mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
return mad
def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale):
"""
Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
"""
n_cubes = surf_cubes.shape[0]
if beta is not None:
beta = (torch.tanh(beta) * weight_scale + 1)
else:
beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
if alpha is not None:
alpha = (torch.tanh(alpha) * weight_scale + 1)
else:
alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
if gamma_f is not None:
gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2
else:
gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes]
@torch.no_grad()
def _get_case_id(self, occ_fx8, surf_cubes, res):
"""
Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
supplementary material. It should be noted that this function assumes a regular grid.
"""
case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
problem_config = self.check_table.to(self.device)[case_ids]
to_check = problem_config[..., 0] == 1
problem_config = problem_config[to_check]
if not isinstance(res, (list, tuple)):
res = [res, res, res]
# The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
# 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
# This allows efficient checking on adjacent cubes.
problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
vol_idx_problem = vol_idx[surf_cubes][to_check]
problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
within_range = (
vol_idx_problem_adj[..., 0] >= 0) & (
vol_idx_problem_adj[..., 0] < res[0]) & (
vol_idx_problem_adj[..., 1] >= 0) & (
vol_idx_problem_adj[..., 1] < res[1]) & (
vol_idx_problem_adj[..., 2] >= 0) & (
vol_idx_problem_adj[..., 2] < res[2])
vol_idx_problem = vol_idx_problem[within_range]
vol_idx_problem_adj = vol_idx_problem_adj[within_range]
problem_config = problem_config[within_range]
problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
# If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
to_invert = (problem_config_adj[..., 0] == 1)
idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
return case_ids
@torch.no_grad()
def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes):
"""
Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
and marks the cube edges with this index.
"""
occ_n = scalar_field < 0
all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2)
unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
unique_edges = unique_edges.long()
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
surf_edges_mask = mask_edges[_idx_map]
counts = counts[_idx_map]
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1
mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device)
# Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
# for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
idx_map = mapping[_idx_map]
surf_edges = unique_edges[mask_edges]
return surf_edges, idx_map, counts, surf_edges_mask
@torch.no_grad()
def _identify_surf_cubes(self, scalar_field, cube_idx):
"""
Identifies grid cubes that intersect with the underlying surface by checking if the signs at
all corners are not identical.
"""
occ_n = scalar_field < 0
occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8)
_occ_sum = torch.sum(occ_fx8, -1)
surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
return surf_cubes, occ_fx8
def _linear_interp(self, edges_weight, edges_x):
"""
Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
"""
edge_dim = edges_weight.dim() - 2
assert edges_weight.shape[edge_dim] == 2
edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)]
, edge_dim)
denominator = edges_weight.sum(edge_dim)
ue = (edges_x * edges_weight).sum(edge_dim) / denominator
return ue
def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale):
p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
c_bx3 = c_bx3.reshape(-1, 3)
A = norm_bxnx3
B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1)
A = torch.cat([A, A_reg], 1)
B = torch.cat([B, B_reg], 1)
dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
return dual_verts
def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field,
case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors):
"""
Computes the location of dual vertices as described in Section 4.2
"""
alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
if voxelgrid_colors is not None:
C = voxelgrid_colors.shape[-1]
surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C)
idx_map = idx_map.reshape(-1, 12)
num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
# if color is not None:
# vd_color = []
total_num_vd = 0
vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
for num in torch.unique(num_vd):
cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
curr_num_vd = cur_cubes.sum() * num
curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
curr_edge_group_to_vd = torch.arange(
curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
total_num_vd += curr_num_vd
curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
curr_mask = (curr_edge_group != -1)
edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
# if color is not None:
# vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3))
edge_group = torch.cat(edge_group)
edge_group_to_vd = torch.cat(edge_group_to_vd)
edge_group_to_cube = torch.cat(edge_group_to_cube)
vd_num_edges = torch.cat(vd_num_edges)
vd_gamma = torch.cat(vd_gamma)
# if color is not None:
# vd_color = torch.cat(vd_color)
# else:
# vd_color = None
vd = torch.zeros((total_num_vd, 3), device=self.device)
beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
zero_crossing_group = torch.index_select(
input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
ue_group = self._linear_interp(s_group * alpha_group, x_group)
beta_group = torch.gather(input=beta.reshape(-1), dim=0,
index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
# Interpolate colors using the same scheme as the dual vertices.
if voxelgrid_colors is not None:
vd_color = torch.zeros((total_num_vd, C), device=self.device)
c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C)
uc_group = self._linear_interp(s_group * alpha_group, c_group)
vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum
else:
vd_color = None
L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
12 + edge_group, src=v_idx[edge_group_to_vd])
return vd, L_dev, vd_gamma, vd_idx_map, vd_color
def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color):
"""
Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
triangles based on the gamma parameter, as described in Section 4.3.
"""
with torch.no_grad():
group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
group = idx_map.reshape(-1)[group_mask]
vd_idx = vd_idx_map[group_mask]
edge_indices, indices = torch.sort(group, stable=True)
quad_vd_idx = vd_idx[indices].reshape(-1, 4)
# Ensure all face directions point towards the positive SDF to maintain consistent winding.
s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
flip_mask = s_edges[:, 0] > 0
quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2]
gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3]
if not training:
mask = (gamma_02 > gamma_13)
faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
faces = faces.reshape(-1, 3)
else:
vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2
vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2
weight_sum = (gamma_02 + gamma_13) + 1e-8
vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
if vd_color is not None:
color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1])
color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2
color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2
color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
vd_color = torch.cat([vd_color, color_center])
vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
vd = torch.cat([vd, vd_center])
faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
return vd, faces, s_edges, edge_indices, vd_color
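A minimal extraction sketch for the class above (the sphere SDF and resolution are illustrative; construct_voxel_grid is assumed available on this class, as in the reference FlexiCubes API, and the grid is assumed to span roughly [-0.5, 0.5]):
import torch

fc = FlexiCubes(device='cuda')
x_nx3, cube_fx8 = fc.construct_voxel_grid(32)   # grid vertices, per-cube corner indices
sdf = x_nx3.norm(dim=-1) - 0.35                 # signed distance to a sphere
# Returns dual vertices, triangle faces, the L_dev regularizer, and per-vertex
# colors (None here, since no voxelgrid_colors were passed).
vertices, faces, L_dev, colors = fc(x_nx3, sdf, cube_fx8, 32)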

View File

@@ -0,0 +1,353 @@
from typing import *
import copy
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from easydict import EasyDict as edict
from ..basic import BasicTrainer
from ...pipelines import samplers
from ...utils.general_utils import dict_reduce
from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin
from .mixins.text_conditioned import TextConditionedMixin
from .mixins.image_conditioned import ImageConditionedMixin
class FlowMatchingTrainer(BasicTrainer):
"""
Trainer for diffusion model with flow matching objective.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold an inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
"""
def __init__(
self,
*args,
t_schedule: dict = {
'name': 'logitNormal',
'args': {
'mean': 0.0,
'std': 1.0,
}
},
sigma_min: float = 1e-5,
**kwargs
):
super().__init__(*args, **kwargs)
self.t_schedule = t_schedule
self.sigma_min = sigma_min
def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
Args:
x_0: The [N x C x ...] tensor of noiseless inputs.
t: The [N] tensor of diffusion steps [0-1].
noise: If specified, use this noise instead of generating new noise.
Returns:
x_t, the noisy version of x_0 under timestep t.
"""
if noise is None:
noise = torch.randn_like(x_0)
assert noise.shape == x_0.shape, "noise must have same shape as x_0"
t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)])
x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise
return x_t
def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
"""
Get original image from noisy version under timestep t.
"""
assert noise.shape == x_t.shape, "noise must have same shape as x_t"
t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)])
x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t)
return x_0
def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""
Compute the velocity of the diffusion process at time t.
"""
return (1 - self.sigma_min) * noise - x_0
def get_cond(self, cond, **kwargs):
"""
Get the conditioning data.
"""
return cond
def get_inference_cond(self, cond, **kwargs):
"""
Get the conditioning data for inference.
"""
return {'cond': cond, **kwargs}
def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler:
"""
Get the sampler for the diffusion process.
"""
return samplers.FlowEulerSampler(self.sigma_min)
def vis_cond(self, **kwargs):
"""
Visualize the conditioning data.
"""
return {}
def sample_t(self, batch_size: int) -> torch.Tensor:
"""
Sample timesteps.
"""
if self.t_schedule['name'] == 'uniform':
t = torch.rand(batch_size)
elif self.t_schedule['name'] == 'logitNormal':
mean = self.t_schedule['args']['mean']
std = self.t_schedule['args']['std']
t = torch.sigmoid(torch.randn(batch_size) * std + mean)
else:
raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}")
return t
def training_losses(
self,
x_0: torch.Tensor,
cond=None,
**kwargs
) -> Tuple[Dict, Dict]:
"""
Compute training losses for a single timestep.
Args:
x_0: The [N x C x ...] tensor of noiseless inputs.
cond: The [N x ...] tensor of additional conditions.
kwargs: Additional arguments to pass to the backbone.
Returns:
a dict with the key "loss" containing a tensor of shape [N].
May also contain other keys for different terms.
"""
noise = torch.randn_like(x_0)
t = self.sample_t(x_0.shape[0]).to(x_0.device).float()
x_t = self.diffuse(x_0, t, noise=noise)
cond = self.get_cond(cond, **kwargs)
pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs)
assert pred.shape == noise.shape == x_0.shape
target = self.get_v(x_0, noise, t)
terms = edict()
terms["mse"] = F.mse_loss(pred, target)
terms["loss"] = terms["mse"]
# log loss with time bins
mse_per_instance = np.array([
F.mse_loss(pred[i], target[i]).item()
for i in range(x_0.shape[0])
])
time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1
for i in range(10):
if (time_bin == i).sum() != 0:
terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()}
return terms, {}
@torch.no_grad()
def run_snapshot(
self,
num_samples: int,
batch_size: int,
verbose: bool = False,
) -> Dict:
dataloader = DataLoader(
copy.deepcopy(self.dataset),
batch_size=batch_size,
shuffle=True,
num_workers=0,
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
)
# inference
sampler = self.get_sampler()
sample_gt = []
sample = []
cond_vis = []
for i in range(0, num_samples, batch_size):
batch = min(batch_size, num_samples - i)
data = next(iter(dataloader))
data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()}
noise = torch.randn_like(data['x_0'])
sample_gt.append(data['x_0'])
cond_vis.append(self.vis_cond(**data))
del data['x_0']
args = self.get_inference_cond(**data)
res = sampler.sample(
self.models['denoiser'],
noise=noise,
**args,
steps=50, cfg_strength=3.0, verbose=verbose,
)
sample.append(res.samples)
sample_gt = torch.cat(sample_gt, dim=0)
sample = torch.cat(sample, dim=0)
sample_dict = {
'sample_gt': {'value': sample_gt, 'type': 'sample'},
'sample': {'value': sample, 'type': 'sample'},
}
sample_dict.update(dict_reduce(cond_vis, None, {
'value': lambda x: torch.cat(x, dim=0),
'type': lambda x: x[0],
}))
return sample_dict
class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer):
"""
Trainer for diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold an inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
"""
pass
class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer):
"""
Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold an inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
text_cond_model(str): Text conditioning model.
"""
pass
class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer):
"""
Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance.
Args:
models (dict[str, nn.Module]): Models to train.
dataset (torch.utils.data.Dataset): Dataset.
output_dir (str): Output directory.
load_dir (str): Load directory.
step (int): Step to load.
batch_size (int): Batch size.
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
batch_split (int): Split batch with gradient accumulation.
max_steps (int): Max steps.
optimizer (dict): Optimizer config.
lr_scheduler (dict): Learning rate scheduler config.
elastic (dict): Elastic memory management config.
grad_clip (float or dict): Gradient clip config.
ema_rate (float or list): Exponential moving average rates.
fp16_mode (str): FP16 mode.
- None: No FP16.
- 'inflat_all': Hold a inflated fp32 master param for all params.
- 'amp': Automatic mixed precision.
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
finetune_ckpt (dict): Finetune checkpoint.
log_param_stats (bool): Log parameter stats.
i_print (int): Print interval.
i_log (int): Log interval.
i_sample (int): Sample interval.
i_save (int): Save interval.
i_ddpcheck (int): DDP check interval.
t_schedule (dict): Time schedule for flow matching.
sigma_min (float): Minimum noise level.
p_uncond (float): Probability of dropping conditions.
image_cond_model (str): Image conditioning model.
"""
pass
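The flow-matching algebra used by these trainers can be sanity-checked standalone (no trainer construction; the formulas below are copied from diffuse, get_v, and reverse_diffuse):
import torch

sigma_min = 1e-5
x_0, noise = torch.randn(4, 8), torch.randn(4, 8)
t = torch.rand(4, 1) * 0.9                       # bounded away from t=1

x_t = (1 - t) * x_0 + (sigma_min + (1 - sigma_min) * t) * noise         # diffuse
v = (1 - sigma_min) * noise - x_0                                        # get_v target
x_0_rec = (x_t - (sigma_min + (1 - sigma_min) * t) * noise) / (1 - t)    # reverse_diffuse
assert torch.allclose(x_0_rec, x_0, atol=1e-5)
# Note d(x_t)/dt = (1 - sigma_min) * noise - x_0 = v, which is why the Euler
# sampler can integrate a model trained to regress v.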

View File

@@ -0,0 +1,93 @@
import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_dist(rank, local_rank, world_size, master_addr, master_port):
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = master_port
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(local_rank)
torch.cuda.set_device(local_rank)
dist.init_process_group('nccl', rank=rank, world_size=world_size)
def read_file_dist(path):
"""
Read a binary file in a distributed setting.
The file is read once by the rank-0 process and broadcast to all other processes.
Returns:
data (io.BytesIO): The binary data read from the file.
"""
if dist.is_initialized() and dist.get_world_size() > 1:
# read file
size = torch.LongTensor(1).cuda()
if dist.get_rank() == 0:
with open(path, 'rb') as f:
data = f.read()
data = torch.ByteTensor(
torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
).cuda()
size[0] = data.shape[0]
# broadcast size
dist.broadcast(size, src=0)
if dist.get_rank() != 0:
data = torch.ByteTensor(size[0].item()).cuda()
# broadcast data
dist.broadcast(data, src=0)
# convert to io.BytesIO
data = data.cpu().numpy().tobytes()
data = io.BytesIO(data)
return data
else:
with open(path, 'rb') as f:
data = f.read()
data = io.BytesIO(data)
return data
def unwrap_dist(model):
"""
Unwrap the model from distributed training.
"""
if isinstance(model, DDP):
return model.module
return model
@contextmanager
def master_first():
"""
A context manager that ensures master process executes first.
"""
if not dist.is_initialized():
yield
else:
if dist.get_rank() == 0:
yield
dist.barrier()
else:
dist.barrier()
yield
@contextmanager
def local_master_first():
"""
A context manager that ensures local master process executes first.
"""
if not dist.is_initialized():
yield
else:
if dist.get_rank() % torch.cuda.device_count() == 0:
yield
dist.barrier()
else:
dist.barrier()
yield
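A typical use of local_master_first (the path and download helper below are hypothetical): one process per node performs a slow filesystem side effect while its local peers wait at the barrier, then every rank reads the result.
import os
import torch

with local_master_first():
    # Only the local master reaches this body first; the other local ranks
    # block on the barrier until it finishes.
    if not os.path.exists('/tmp/weights.pt'):
        download_weights('/tmp/weights.pt')   # hypothetical helper
state = torch.load('/tmp/weights.pt')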

View File

@@ -0,0 +1,228 @@
from abc import abstractmethod
from contextlib import contextmanager
from typing import Tuple
import torch
import torch.nn as nn
import numpy as np
class MemoryController:
"""
Base class for memory management during training.
"""
_last_input_size = None
_last_mem_ratio = []
@contextmanager
def record(self):
pass
def update_run_states(self, input_size=None, mem_ratio=None):
if self._last_input_size is None:
self._last_input_size = input_size
elif self._last_input_size != input_size:
raise ValueError('Input size should not change across ElasticModules within a step.')
self._last_mem_ratio.append(mem_ratio)
@abstractmethod
def get_mem_ratio(self, input_size):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def log(self):
pass
class LinearMemoryController(MemoryController):
"""
A simple controller for memory management during training.
The memory usage is modeled as a linear function of:
- the number of input parameters
- the ratio of memory the model uses compared to the maximum usage (with no checkpointing)
memory_usage = k * input_size * mem_ratio + b
The controller tracks memory usage and returns the expected memory ratio
that keeps usage under a target fraction of the available memory.
"""
def __init__(
self,
buffer_size=1000,
update_every=500,
target_ratio=0.8,
available_memory=None,
max_mem_ratio_start=0.1,
params=None,
device=None
):
self.buffer_size = buffer_size
self.update_every = update_every
self.target_ratio = target_ratio
self.device = device or torch.cuda.current_device()
self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3
self._memory = np.zeros(buffer_size, dtype=np.float32)
self._input_size = np.zeros(buffer_size, dtype=np.float32)
self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
self._buffer_ptr = 0
self._buffer_length = 0
self._params = tuple(params) if params is not None else (0.0, 0.0)
self._max_mem_ratio = max_mem_ratio_start
self.step = 0
def __repr__(self):
return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'
def _add_sample(self, memory, input_size, mem_ratio):
self._memory[self._buffer_ptr] = memory
self._input_size[self._buffer_ptr] = input_size
self._mem_ratio[self._buffer_ptr] = mem_ratio
self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
@contextmanager
def record(self):
torch.cuda.reset_peak_memory_stats(self.device)
self._last_input_size = None
self._last_mem_ratio = []
yield
self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
self.step += 1
if self.step % self.update_every == 0:
self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
self._fit_params()
def _fit_params(self):
memory_usage = self._memory[:self._buffer_length]
input_size = self._input_size[:self._buffer_length]
mem_ratio = self._mem_ratio[:self._buffer_length]
x = input_size * mem_ratio
y = memory_usage
k, b = np.polyfit(x, y, 1)
self._params = (k, b)
# self._visualize()
def _visualize(self):
import matplotlib.pyplot as plt
memory_usage = self._memory[:self._buffer_length]
input_size = self._input_size[:self._buffer_length]
mem_ratio = self._mem_ratio[:self._buffer_length]
k, b = self._params
plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
x = np.array([0.0, 20000.0])
plt.plot(x, k * x + b, c='r')
plt.savefig(f'linear_memory_controller_{self.step}.png')
plt.cla()
def get_mem_ratio(self, input_size):
k, b = self._params
if k == 0: return np.random.rand() * self._max_mem_ratio
pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
return min(self._max_mem_ratio, max(0.0, pred))
def state_dict(self):
return {
'params': self._params,
}
def load_state_dict(self, state_dict):
self._params = tuple(state_dict['params'])
def log(self):
return {
'params/k': self._params[0],
'params/b': self._params[1],
'memory': self._last_memory,
'input_size': self._last_input_size,
'mem_ratio': self._last_mem_ratio,
}
class ElasticModule(nn.Module):
"""
Module for training with elastic memory management.
"""
def __init__(self):
super().__init__()
self._memory_controller: MemoryController = None
@abstractmethod
def _get_input_size(self, *args, **kwargs) -> int:
"""
Get the size of the input data.
Returns:
int: The size of the input data.
"""
pass
@abstractmethod
def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
"""
Forward with a given memory ratio.
"""
pass
def register_memory_controller(self, memory_controller: MemoryController):
self._memory_controller = memory_controller
def forward(self, *args, **kwargs):
if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
_, ret = self._forward_with_mem_ratio(*args, **kwargs)
else:
input_size = self._get_input_size(*args, **kwargs)
mem_ratio = self._memory_controller.get_mem_ratio(input_size)
mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
self._memory_controller.update_run_states(input_size, mem_ratio)
return ret
class ElasticModuleMixin:
"""
Mixin for training with elastic memory management.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._memory_controller: MemoryController = None
@abstractmethod
def _get_input_size(self, *args, **kwargs) -> int:
"""
Get the size of the input data.
Returns:
int: The size of the input data.
"""
pass
@abstractmethod
@contextmanager
def with_mem_ratio(self, mem_ratio=1.0) -> float:
"""
Context manager for training with a reduced memory ratio compared to the full memory usage.
Returns:
float: The exact memory ratio used during the forward pass.
"""
pass
def register_memory_controller(self, memory_controller: MemoryController):
self._memory_controller = memory_controller
def forward(self, *args, **kwargs):
if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
ret = super().forward(*args, **kwargs)
else:
input_size = self._get_input_size(*args, **kwargs)
mem_ratio = self._memory_controller.get_mem_ratio(input_size)
with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
ret = super().forward(*args, **kwargs)
self._memory_controller.update_run_states(input_size, exact_mem_ratio)
return ret
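# Mixin usage sketch (commented out; `SparseTransformer` is a stand-in name, not
# a class defined in this file). The mixin leaves forward() untouched and only
# asks the subclass how aggressively it can checkpoint via with_mem_ratio():
#
# class ElasticSparseTransformer(ElasticModuleMixin, SparseTransformer):
#     def _get_input_size(self, x):
#         return x.shape[0] * x.shape[1]
#
#     @contextmanager
#     def with_mem_ratio(self, mem_ratio=1.0):
#         num_ckpt = int(round(self.num_blocks * (1 - mem_ratio)))
#         self._num_ckpt_blocks = num_ckpt  # read by SparseTransformer.forward
#         try:
#             yield 1 - num_ckpt / self.num_blocks
#         finally:
#             self._num_ckpt_blocks = 0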

View File

@@ -0,0 +1,587 @@
from typing import *
import numpy as np
import torch
import utils3d
import nvdiffrast.torch as dr
from tqdm import tqdm
import trimesh
import trimesh.visual
import xatlas
import pyvista as pv
from pymeshfix import _meshfix
import igraph
import cv2
from PIL import Image
from .random_utils import sphere_hammersley_sequence
from .render_utils import render_multiview
from ..renderers import GaussianRenderer
from ..representations import Strivec, Gaussian, MeshExtractResult
@torch.no_grad()
def _fill_holes(
verts,
faces,
max_hole_size=0.04,
max_hole_nbe=32,
resolution=128,
num_views=500,
debug=False,
verbose=False
):
"""
Rasterize a mesh from multiple views and remove invisible faces.
Also includes postprocessing to:
1. Remove connected components that are have low visibility.
2. Mincut to remove faces at the inner side of the mesh connected to the outer side with a small hole.
Args:
verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
max_hole_size (float): Maximum area of a hole to fill.
resolution (int): Resolution of the rasterization.
num_views (int): Number of views to rasterize the mesh.
verbose (bool): Whether to print progress.
"""
# Construct cameras
yaws = []
pitchs = []
for i in range(num_views):
y, p = sphere_hammersley_sequence(i, num_views)
yaws.append(y)
pitchs.append(p)
yaws = torch.tensor(yaws).cuda()
pitchs = torch.tensor(pitchs).cuda()
radius = 2.0
fov = torch.deg2rad(torch.tensor(40)).cuda()
projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
views = []
for (yaw, pitch) in zip(yaws, pitchs):
orig = torch.tensor([
torch.sin(yaw) * torch.cos(pitch),
torch.cos(yaw) * torch.cos(pitch),
torch.sin(pitch),
]).cuda().float() * radius
view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
views.append(view)
views = torch.stack(views, dim=0)
# Rasterize
    visibility = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
rastctx = utils3d.torch.RastContext(backend='cuda')
for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
view = views[i]
buffers = utils3d.torch.rasterize_triangle_faces(
rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
)
face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
face_id = torch.unique(face_id).long()
        visibility[face_id] += 1
    visibility = visibility.float() / num_views
# Mincut
## construct outer faces
edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
for i in range(len(connected_components)):
        outer_face_indices[connected_components[i]] = visibility[connected_components[i]] > min(max(visibility[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
outer_face_indices = outer_face_indices.nonzero().reshape(-1)
## construct inner faces
    inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
if verbose:
tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
if inner_face_indices.shape[0] == 0:
return verts, faces
## Construct dual graph (faces as nodes, edges as edges)
dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
dual_edge2edge = edges[dual_edge2edge]
dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
if verbose:
tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
## solve mincut problem
### construct main graph
g = igraph.Graph()
g.add_vertices(faces.shape[0])
g.add_edges(dual_edges.cpu().numpy())
g.es['weight'] = dual_edges_weights.cpu().numpy()
### source and target
g.add_vertex('s')
g.add_vertex('t')
### connect invisible faces to source
g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
### connect outer faces to target
g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
### solve mincut
cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
if verbose:
        tqdm.write('Mincut solved, checking the cut')
### check if the cut is valid with each connected component
to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
if debug:
tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
valid_remove_cc = []
cutting_edges = []
for cc in to_remove_cc:
#### check if the connected component has low visibility
        visibility_median = visibility[remove_face_indices[cc]].median()
        if debug:
            tqdm.write(f'visibility_median: {visibility_median}')
        if visibility_median > 0.25:
            continue
        #### check if the cutting loop is small enough
cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
if len(cc_new_boundary_edge_indices) > 0:
cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
cc_new_boundary_edges_cc_area = []
for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
_e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
_e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
if debug:
cutting_edges.append(cc_new_boundary_edge_indices)
tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
if any([l > max_hole_size for l in cc_new_boundary_edges_cc_area]):
continue
valid_remove_cc.append(cc)
if debug:
face_v = verts[faces].mean(dim=1).cpu().numpy()
vis_dual_edges = dual_edges.cpu().numpy()
vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
if len(valid_remove_cc) > 0:
vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
vis_verts = verts.cpu().numpy()
vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
if len(valid_remove_cc) > 0:
remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
mask[remove_face_indices] = 0
faces = faces[mask]
faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
if verbose:
tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
else:
if verbose:
            tqdm.write('Removed 0 faces by mincut')
mesh = _meshfix.PyTMesh()
mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
verts, faces = mesh.return_arrays()
verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
return verts, faces
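def _example_mincut():
    """
    Toy version of the mincut construction used in _fill_holes (illustrative
    only): faces become graph vertices; invisible faces attach to the source
    's' and clearly visible faces to the sink 't'; the source side of the
    mincut marks the faces to remove.
    """
    g = igraph.Graph()
    g.add_vertices(4)                       # 4 "faces"
    g.add_edges([(0, 1), (1, 2), (2, 3)])   # dual-graph adjacency
    g.es['weight'] = [1.0, 0.1, 1.0]        # cheapest cut is between faces 1 and 2
    g.add_vertex('s')
    g.add_vertex('t')
    g.add_edges([(0, 's'), (3, 't')], attributes={'weight': [1.0, 1.0]})
    cut = g.mincut('s', 't', g.es['weight'])
    return [v for v in cut.partition[0] if v < 4]  # -> [0, 1]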
def postprocess_mesh(
    vertices: np.ndarray,
    faces: np.ndarray,
simplify: bool = True,
simplify_ratio: float = 0.9,
fill_holes: bool = True,
fill_holes_max_hole_size: float = 0.04,
fill_holes_max_hole_nbe: int = 32,
fill_holes_resolution: int = 1024,
fill_holes_num_views: int = 1000,
debug: bool = False,
verbose: bool = False,
):
"""
Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.
Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
        simplify_ratio (float): Ratio of faces to remove during simplification.
        fill_holes (bool): Whether to fill holes in the mesh.
        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
        fill_holes_resolution (int): Resolution of the rasterization.
        fill_holes_num_views (int): Number of views to rasterize the mesh.
        debug (bool): Whether to dump debug visualizations.
        verbose (bool): Whether to print progress.
"""
if verbose:
tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
# Simplify
if simplify and simplify_ratio > 0:
mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
if verbose:
tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
# Remove invisible faces
if fill_holes:
vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
vertices, faces = _fill_holes(
vertices, faces,
max_hole_size=fill_holes_max_hole_size,
max_hole_nbe=fill_holes_max_hole_nbe,
resolution=fill_holes_resolution,
num_views=fill_holes_num_views,
debug=debug,
verbose=verbose,
)
vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
if verbose:
tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
return vertices, faces
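# Self-contained decimation sketch (toy quad made of two triangles; hypothetical
# data). PyVista packs faces as a flat [count, i0, i1, i2, ...] array, which is
# why postprocess_mesh prepends a column of 3s before calling decimate():
def _example_decimate():
    verts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], dtype=np.float32)
    tris = np.array([[0, 1, 2], [1, 3, 2]], dtype=np.int64)
    mesh = pv.PolyData(verts, np.concatenate([np.full((tris.shape[0], 1), 3), tris], axis=1))
    mesh = mesh.decimate(0.5, progress_bar=False)  # ask to drop ~50% of faces
    return mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]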
def parametrize_mesh(vertices: np.ndarray, faces: np.ndarray):
    """
    Parametrize a mesh to a texture space, using xatlas.

    Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).

    Returns:
        Remapped vertices, remapped faces, and per-vertex UVs.
    """
vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
vertices = vertices[vmapping]
faces = indices
return vertices, faces, uvs
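# Usage note (sketch): xatlas may split vertices along UV seams, so the returned
# arrays can be larger than the input; vmapping re-indexes each new vertex to
# its source vertex.
#
# new_vertices, new_faces, uvs = parametrize_mesh(vertices, faces)
# assert uvs.shape[0] == new_vertices.shape[0]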
def bake_texture(
    vertices: np.ndarray,
    faces: np.ndarray,
    uvs: np.ndarray,
    observations: List[np.ndarray],
    masks: List[np.ndarray],
    extrinsics: List[np.ndarray],
    intrinsics: List[np.ndarray],
texture_size: int = 2048,
near: float = 0.1,
far: float = 10.0,
mode: Literal['fast', 'opt'] = 'opt',
lambda_tv: float = 1e-2,
verbose: bool = False,
):
"""
Bake texture to a mesh from multiple observations.
Args:
        vertices (np.ndarray): Vertices of the mesh. Shape (V, 3).
        faces (np.ndarray): Faces of the mesh. Shape (F, 3).
        uvs (np.ndarray): UV coordinates of the mesh. Shape (V, 2).
        observations (List[np.ndarray]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
        masks (List[np.ndarray]): List of masks. Each mask is a 2D image. Shape (H, W).
        extrinsics (List[np.ndarray]): List of extrinsics. Shape (4, 4).
        intrinsics (List[np.ndarray]): List of intrinsics. Shape (3, 3).
texture_size (int): Size of the texture.
near (float): Near plane of the camera.
far (float): Far plane of the camera.
mode (Literal['fast', 'opt']): Mode of texture baking.
lambda_tv (float): Weight of total variation loss in optimization.
verbose (bool): Whether to print progress.
"""
vertices = torch.tensor(vertices).cuda()
faces = torch.tensor(faces.astype(np.int32)).cuda()
uvs = torch.tensor(uvs).cuda()
observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
masks = [torch.tensor(m>0).bool().cuda() for m in masks]
views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]
if mode == 'fast':
texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
rastctx = utils3d.torch.RastContext(backend='cuda')
        for observation, obs_mask, view, projection in tqdm(zip(observations, masks, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
            with torch.no_grad():
                rast = utils3d.torch.rasterize_triangle_faces(
                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
                )
                uv_map = rast['uv'][0].detach().flip(0)
                # combine rasterization coverage with the current view's foreground mask
                mask = rast['mask'][0].detach().bool() & obs_mask
# nearest neighbor interpolation
uv_map = (uv_map * texture_size).floor().long()
obs = observation[mask]
uv_map = uv_map[mask]
idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))
mask = texture_weights > 0
texture[mask] /= texture_weights[mask][:, None]
texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)
# inpaint
mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
elif mode == 'opt':
rastctx = utils3d.torch.RastContext(backend='cuda')
        observations = [obs.flip(0) for obs in observations]
masks = [m.flip(0) for m in masks]
_uv = []
_uv_dr = []
for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
with torch.no_grad():
rast = utils3d.torch.rasterize_triangle_faces(
rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
)
_uv.append(rast['uv'].detach())
_uv_dr.append(rast['uv_dr'].detach())
texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
        def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
            return start_lr * (end_lr / start_lr) ** (step / total_steps)

        def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
def tv_loss(texture):
return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])
total_steps = 2500
with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
for step in range(total_steps):
optimizer.zero_grad()
selected = np.random.randint(0, len(views))
uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
render = dr.texture(texture, uv, uv_dr)[0]
loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
if lambda_tv > 0:
loss += lambda_tv * tv_loss(texture)
loss.backward()
optimizer.step()
# annealing
                optimizer.param_groups[0]['lr'] = cosine_annealing(optimizer, step, total_steps, 1e-2, 1e-5)
pbar.set_postfix({'loss': loss.item()})
pbar.update()
texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
mask = 1 - utils3d.torch.rasterize_triangle_faces(
rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
)['mask'][0].detach().cpu().numpy().astype(np.uint8)
texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
else:
raise ValueError(f'Unknown mode: {mode}')
return texture
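# Usage sketch (hypothetical inputs shaped as documented above): 'fast' splats
# observed colors into texels and averages them, while 'opt' optimizes the
# texture through nvdiffrast with a TV prior; both inpaint uncovered texels.
#
# texture = bake_texture(
#     vertices, faces, uvs,
#     observations, masks, extrinsics, intrinsics,
#     texture_size=1024, mode='fast', verbose=True,
# )
# Image.fromarray(texture).save('texture.png')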
def to_glb(
app_rep: Union[Strivec, Gaussian],
mesh: MeshExtractResult,
simplify: float = 0.95,
fill_holes: bool = True,
fill_holes_max_size: float = 0.04,
texture_size: int = 1024,
debug: bool = False,
verbose: bool = True,
) -> trimesh.Trimesh:
"""
Convert a generated asset to a glb file.
Args:
app_rep (Union[Strivec, Gaussian]): Appearance representation.
mesh (MeshExtractResult): Extracted mesh.
simplify (float): Ratio of faces to remove in simplification.
fill_holes (bool): Whether to fill holes in the mesh.
fill_holes_max_size (float): Maximum area of a hole to fill.
texture_size (int): Size of the texture.
debug (bool): Whether to print debug information.
verbose (bool): Whether to print progress.
"""
vertices = mesh.vertices.cpu().numpy()
faces = mesh.faces.cpu().numpy()
# mesh postprocess
vertices, faces = postprocess_mesh(
vertices, faces,
simplify=simplify > 0,
simplify_ratio=simplify,
fill_holes=fill_holes,
fill_holes_max_hole_size=fill_holes_max_size,
fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)),
fill_holes_resolution=1024,
fill_holes_num_views=1000,
debug=debug,
verbose=verbose,
)
# parametrize mesh
vertices, faces, uvs = parametrize_mesh(vertices, faces)
# bake texture
observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100)
masks = [np.any(observation > 0, axis=-1) for observation in observations]
extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))]
intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))]
texture = bake_texture(
vertices, faces, uvs,
observations, masks, extrinsics, intrinsics,
texture_size=texture_size, mode='opt',
lambda_tv=0.01,
verbose=verbose
)
texture = Image.fromarray(texture)
# rotate mesh (from z-up to y-up)
vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
material = trimesh.visual.material.PBRMaterial(
roughnessFactor=1.0,
baseColorTexture=texture,
baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
)
mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material))
return mesh
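# Usage sketch (hypothetical pipeline outputs `gs` and `mesh_result`):
#
# glb = to_glb(gs, mesh_result, simplify=0.95, texture_size=1024, verbose=True)
# glb.export('asset.glb')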
def simplify_gs(
gs: Gaussian,
simplify: float = 0.95,
verbose: bool = True,
):
"""
    Simplify 3D Gaussians.
    NOTE: this function is not used in the current implementation due to its unsatisfactory performance.
Args:
gs (Gaussian): 3D Gaussian.
simplify (float): Ratio of Gaussians to remove in simplification.
"""
if simplify <= 0:
return gs
# simplify
observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100)
observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations]
# Following https://arxiv.org/pdf/2411.06019
renderer = GaussianRenderer({
"resolution": 1024,
"near": 0.8,
"far": 1.6,
"ssaa": 1,
"bg_color": (0,0,0),
})
new_gs = Gaussian(**gs.init_params)
new_gs._features_dc = gs._features_dc.clone()
new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None
new_gs._opacity = torch.nn.Parameter(gs._opacity.clone())
new_gs._rotation = torch.nn.Parameter(gs._rotation.clone())
new_gs._scaling = torch.nn.Parameter(gs._scaling.clone())
new_gs._xyz = torch.nn.Parameter(gs._xyz.clone())
start_lr = [1e-4, 1e-3, 5e-3, 0.025]
end_lr = [1e-6, 1e-5, 5e-5, 0.00025]
optimizer = torch.optim.Adam([
{"params": new_gs._xyz, "lr": start_lr[0]},
{"params": new_gs._rotation, "lr": start_lr[1]},
{"params": new_gs._scaling, "lr": start_lr[2]},
{"params": new_gs._opacity, "lr": start_lr[3]},
], lr=start_lr[0])
    def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
        return start_lr * (end_lr / start_lr) ** (step / total_steps)

    def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
        return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
_zeta = new_gs.get_opacity.clone().detach().squeeze()
_lambda = torch.zeros_like(_zeta)
_delta = 1e-7
_interval = 10
num_target = int((1 - simplify) * _zeta.shape[0])
with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar:
for i in range(2500):
# prune
if i % 100 == 0:
mask = new_gs.get_opacity.squeeze() > 0.05
mask = torch.nonzero(mask).squeeze()
new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask])
new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask])
new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask])
new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask])
new_gs._features_dc = new_gs._features_dc[mask]
new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None
_zeta = _zeta[mask]
_lambda = _lambda[mask]
# update optimizer state
for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]):
stored_state = optimizer.state[param_group['params'][0]]
if 'exp_avg' in stored_state:
stored_state['exp_avg'] = stored_state['exp_avg'][mask]
stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask]
del optimizer.state[param_group['params'][0]]
param_group['params'][0] = new_param
optimizer.state[param_group['params'][0]] = stored_state
opacity = new_gs.get_opacity.squeeze()
            # sparsify
if i % _interval == 0:
_zeta = _lambda + opacity.detach()
if opacity.shape[0] > num_target:
index = _zeta.topk(num_target)[1]
_m = torch.ones_like(_zeta, dtype=torch.bool)
_m[index] = 0
_zeta[_m] = 0
_lambda = _lambda + opacity.detach() - _zeta
# sample a random view
view_idx = np.random.randint(len(observations))
observation = observations[view_idx]
extrinsic = extrinsics[view_idx]
intrinsic = intrinsics[view_idx]
color = renderer.render(new_gs, extrinsic, intrinsic)['color']
rgb_loss = torch.nn.functional.l1_loss(color, observation)
loss = rgb_loss + \
_delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# update lr
for j in range(len(optimizer.param_groups)):
                optimizer.param_groups[j]['lr'] = cosine_annealing(optimizer, i, 2500, start_lr[j], end_lr[j])
pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
pbar.update()
new_gs._xyz = new_gs._xyz.data
new_gs._rotation = new_gs._rotation.data
new_gs._scaling = new_gs._scaling.data
new_gs._opacity = new_gs._opacity.data
return new_gs
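# Toy version of the sparsification step above (illustrative, CPU-only): the
# auxiliary variable _zeta is the projection of (_lambda + opacity) onto the
# set of at-most-num_target non-zero entries, i.e. keep the top-k opacities and
# zero the rest; _lambda then accumulates the residual.
def _example_topk_projection():
    opacity = torch.tensor([0.9, 0.2, 0.7, 0.05, 0.4])
    num_target = 2
    zeta = opacity.clone()
    index = zeta.topk(num_target)[1]
    m = torch.ones_like(zeta, dtype=torch.bool)
    m[index] = 0
    zeta[m] = 0
    return zeta  # tensor([0.9000, 0.0000, 0.7000, 0.0000, 0.0000])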