diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..4d3893f
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,62 @@
+# Use the devel image to make sure nvcc + the CUDA Toolkit are available
+FROM nvidia/cuda:12.8.0-devel-ubuntu22.04
+
+# Install basic system dependencies (clean the apt cache to save space)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    wget \
+    git \
+    build-essential \
+    cmake \
+    ninja-build \
+    libglib2.0-0 \
+    libgl1 \
+    python3.10 \
+    python3.10-dev \
+    python3-pip \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Miniconda (the latest official release is recommended)
+RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \
+    bash /tmp/miniconda.sh -b -p /opt/conda && \
+    rm /tmp/miniconda.sh && \
+    /opt/conda/bin/conda clean --all -y
+
+# Add conda to PATH
+ENV PATH="/opt/conda/bin:${PATH}"
+
+# Accept the ToS (the key fix)
+RUN conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main && \
+    conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r
+#
+## Create the environment
+#RUN conda create -n trellis python=3.10 -y && \
+#    echo "conda activate trellis" >> ~/.bashrc
+#
+#SHELL ["/bin/bash", "-c"]
+#
+## After activating the environment, install PyTorch 2.8 nightly + CUDA 12.8
+#RUN conda activate trellis && \
+#    conda install pytorch torchvision torchaudio pytorch-cuda=12.8 -c pytorch-nightly -c nvidia -y
+#
+## Install pytorch3d (prefer the fvcore channel; fall back to building from the git source)
+#RUN conda activate trellis && \
+#    conda install pytorch3d -c fvcore -c pytorch -c nvidia -y || \
+#    pip install --no-cache-dir "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+#
+## If you have an environment packed as trellis.tar.gz, you can COPY and unpack it here (optional)
+## COPY trellis.tar.gz /opt/
+## RUN mkdir /opt/env && tar -xzf /opt/trellis.tar.gz -C /opt/env && /opt/env/bin/conda-unpack
+#
+## Install kaolin (your original Dockerfile had this)
+#RUN conda activate trellis && \
+#    pip install kaolin==0.18.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.8.0_cu128.html
+
+# Set PATH and the working directory
+ENV PATH="/opt/conda/envs/trellis/bin:${PATH}"
+WORKDIR /workspace
+
+# Copy your code (if needed)
+COPY . /workspace
+
+# Default command: keep the container running, or swap in your startup script
+CMD ["tail", "-f", "/dev/null"]
\ No newline at end of file
diff --git a/dataset_toolkits/download.py b/dataset_toolkits/download.py
new file mode 100644
index 0000000..36e684f
--- /dev/null
+++ b/dataset_toolkits/download.py
@@ -0,0 +1,52 @@
+import os
+import copy
+import sys
+import importlib
+import argparse
+import pandas as pd
+from easydict import EasyDict as edict
+
+if __name__ == '__main__':
+    dataset_utils = importlib.import_module(f'datasets.{sys.argv[1]}')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--output_dir', type=str, required=True,
+                        help='Directory to save the metadata')
+    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
+                        help='Filter objects with aesthetic score lower than this value')
+    parser.add_argument('--instances', type=str, default=None,
+                        help='Instances to process')
+    dataset_utils.add_args(parser)
+    parser.add_argument('--rank', type=int, default=0)
+    parser.add_argument('--world_size', type=int, default=1)
+    opt = parser.parse_args(sys.argv[2:])
+    opt = edict(vars(opt))
+
+    os.makedirs(opt.output_dir, exist_ok=True)
+
+    # get file list
+    if not os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
+        raise ValueError('metadata.csv not found')
+    metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    if opt.instances is None:
+        if opt.filter_low_aesthetic_score is not None:
+            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
+        if 'local_path' in metadata.columns:
+            metadata = metadata[metadata['local_path'].isna()]
+    else:
+        if os.path.exists(opt.instances):
+            with open(opt.instances, 'r') as f:
+                instances = f.read().splitlines()
+        else:
+            instances = opt.instances.split(',')
+        metadata = metadata[metadata['sha256'].isin(instances)]
+
+    start = len(metadata) * opt.rank // opt.world_size
+    end = len(metadata) * (opt.rank + 1) // opt.world_size
+    metadata = metadata[start:end]
+
+    print(f'Processing {len(metadata)} objects...')
+
+    # process objects
+    downloaded = dataset_utils.download(metadata, **opt)
+    downloaded.to_csv(os.path.join(opt.output_dir, f'downloaded_{opt.rank}.csv'), index=False)
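download.py shards the metadata across workers with integer arithmetic, so each rank gets a contiguous, non-overlapping slice and the union of slices covers every row. A minimal sketch of the same arithmetic (the helper name and values are illustrative, not part of the patch):

```python
# Contiguous sharding as used by the toolkit scripts: worker `rank` of
# `world_size` processes rows [n*rank//world_size, n*(rank+1)//world_size).
def shard_bounds(n: int, rank: int, world_size: int) -> tuple[int, int]:
    return n * rank // world_size, n * (rank + 1) // world_size

# Illustrative check: 10 rows over 3 workers -> (0, 3), (3, 6), (6, 10);
# slices are disjoint and cover every row even when n % world_size != 0.
assert [shard_bounds(10, r, 3) for r in range(3)] == [(0, 3), (3, 6), (6, 10)]
```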
diff --git a/dataset_toolkits/encode_latent.py b/dataset_toolkits/encode_latent.py
new file mode 100644
index 0000000..a5b078f
--- /dev/null
+++ b/dataset_toolkits/encode_latent.py
@@ -0,0 +1,127 @@
+import os
+import sys
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+import copy
+import json
+import argparse
+import torch
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from easydict import EasyDict as edict
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
+
+import trellis.models as models
+import trellis.modules.sparse as sp
+
+
+torch.set_grad_enabled(False)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--output_dir', type=str, required=True,
+                        help='Directory to save the metadata')
+    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
+                        help='Filter objects with aesthetic score lower than this value')
+    parser.add_argument('--feat_model', type=str, default='dinov2_vitl14_reg',
+                        help='Feature model')
+    parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/slat_enc_swin8_B_64l8_fp16',
+                        help='Pretrained encoder model')
+    parser.add_argument('--model_root', type=str, default='results',
+                        help='Root directory of models')
+    parser.add_argument('--enc_model', type=str, default=None,
+                        help='Encoder model. If specified, use this model instead of the pretrained model')
+    parser.add_argument('--ckpt', type=str, default=None,
+                        help='Checkpoint to load')
+    parser.add_argument('--instances', type=str, default=None,
+                        help='Instances to process')
+    parser.add_argument('--rank', type=int, default=0)
+    parser.add_argument('--world_size', type=int, default=1)
+    opt = parser.parse_args()
+    opt = edict(vars(opt))
+
+    if opt.enc_model is None:
+        latent_name = f'{opt.feat_model}_{opt.enc_pretrained.split("/")[-1]}'
+        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
+    else:
+        latent_name = f'{opt.feat_model}_{opt.enc_model}_{opt.ckpt}'
+        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
+        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
+        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
+        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
+        encoder.eval()
+        print(f'Loaded model from {ckpt_path}')
+
+    os.makedirs(os.path.join(opt.output_dir, 'latents', latent_name), exist_ok=True)
+
+    # get file list
+    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
+        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    else:
+        raise ValueError('metadata.csv not found')
+    if opt.instances is not None:
+        with open(opt.instances, 'r') as f:
+            sha256s = [line.strip() for line in f]
+        metadata = metadata[metadata['sha256'].isin(sha256s)]
+    else:
+        if opt.filter_low_aesthetic_score is not None:
+            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
+        metadata = metadata[metadata[f'feature_{opt.feat_model}'] == True]
+        if f'latent_{latent_name}' in metadata.columns:
+            metadata = metadata[metadata[f'latent_{latent_name}'] == False]
+
+    start = len(metadata) * opt.rank // opt.world_size
+    end = len(metadata) * (opt.rank + 1) // opt.world_size
+    metadata = metadata[start:end]
+    records = []
+
+    # filter out objects that are already processed
+    sha256s = list(metadata['sha256'].values)
+    for sha256 in copy.copy(sha256s):
+        if os.path.exists(os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')):
+            records.append({'sha256': sha256, f'latent_{latent_name}': True})
+            sha256s.remove(sha256)
+
+    # encode latents
+    load_queue = Queue(maxsize=4)
+    try:
+        with ThreadPoolExecutor(max_workers=32) as loader_executor, \
+            ThreadPoolExecutor(max_workers=32) as saver_executor:
+            def loader(sha256):
+                try:
+                    feats = np.load(os.path.join(opt.output_dir, 'features', opt.feat_model, f'{sha256}.npz'))
+                    load_queue.put((sha256, feats))
+                except Exception as e:
+                    print(f"Error loading features for {sha256}: {e}")
+            loader_executor.map(loader, sha256s)
+
+            def saver(sha256, pack):
+                save_path = os.path.join(opt.output_dir, 'latents', latent_name, f'{sha256}.npz')
+                np.savez_compressed(save_path, **pack)
+                records.append({'sha256': sha256, f'latent_{latent_name}': True})
+
+            for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
+                sha256, feats = load_queue.get()
+                feats = sp.SparseTensor(
+                    feats = torch.from_numpy(feats['patchtokens']).float(),
+                    coords = torch.cat([
+                        torch.zeros(feats['patchtokens'].shape[0], 1).int(),
+                        torch.from_numpy(feats['indices']).int(),
+                    ], dim=1),
+                ).cuda()
+                latent = encoder(feats, sample_posterior=False)
+                assert torch.isfinite(latent.feats).all(), "Non-finite latent"
+                pack = {
+                    'feats': latent.feats.cpu().numpy().astype(np.float32),
+                    'coords': latent.coords[:, 1:].cpu().numpy().astype(np.uint8),
+                }
+                saver_executor.submit(saver, sha256, pack)
+
+            saver_executor.shutdown(wait=True)
+    except Exception as e:
+        print(f"Error happened during processing: {e}")
+
+    records = pd.DataFrame.from_records(records)
+    records.to_csv(os.path.join(opt.output_dir, f'latent_{latent_name}_{opt.rank}.csv'), index=False)
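encode_latent.py writes one `.npz` per object holding `feats` (float32 latent features) and `coords` (uint8 voxel indices, which is safe because the latent grid resolution is 64). A hedged sketch of reading one back (`load_latent` and `path` are hypothetical names, not part of the patch):

```python
# Minimal sketch: read a packed latent back into a (coords, feats) pair.
# Files are written by encode_latent.py as
# <output_dir>/latents/<latent_name>/<sha256>.npz.
import numpy as np

def load_latent(path: str):
    pack = np.load(path)
    feats = pack['feats']      # [num_voxels, latent_channels], float32
    coords = pack['coords']    # [num_voxels, 3], uint8 grid indices (< 64)
    assert feats.shape[0] == coords.shape[0]
    return coords, feats
```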
diff --git a/dataset_toolkits/encode_ss_latent.py b/dataset_toolkits/encode_ss_latent.py
new file mode 100644
index 0000000..a59549e
--- /dev/null
+++ b/dataset_toolkits/encode_ss_latent.py
@@ -0,0 +1,128 @@
+import os
+import sys
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+import copy
+import json
+import argparse
+import torch
+import numpy as np
+import pandas as pd
+import utils3d
+from tqdm import tqdm
+from easydict import EasyDict as edict
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
+
+import trellis.models as models
+
+
+torch.set_grad_enabled(False)
+
+
+def get_voxels(instance):
+    position = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{instance}.ply'))[0]
+    coords = ((torch.tensor(position) + 0.5) * opt.resolution).int().contiguous()
+    ss = torch.zeros(1, opt.resolution, opt.resolution, opt.resolution, dtype=torch.long)
+    ss[:, coords[:, 0], coords[:, 1], coords[:, 2]] = 1
+    return ss
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--output_dir', type=str, required=True,
+                        help='Directory to save the metadata')
+    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
+                        help='Filter objects with aesthetic score lower than this value')
+    parser.add_argument('--enc_pretrained', type=str, default='microsoft/TRELLIS-image-large/ckpts/ss_enc_conv3d_16l8_fp16',
+                        help='Pretrained encoder model')
+    parser.add_argument('--model_root', type=str, default='results',
+                        help='Root directory of models')
+    parser.add_argument('--enc_model', type=str, default=None,
+                        help='Encoder model. If specified, use this model instead of the pretrained model')
+    parser.add_argument('--ckpt', type=str, default=None,
+                        help='Checkpoint to load')
+    parser.add_argument('--resolution', type=int, default=64,
+                        help='Resolution')
+    parser.add_argument('--instances', type=str, default=None,
+                        help='Instances to process')
+    parser.add_argument('--rank', type=int, default=0)
+    parser.add_argument('--world_size', type=int, default=1)
+    opt = parser.parse_args()
+    opt = edict(vars(opt))
+
+    if opt.enc_model is None:
+        latent_name = f'{opt.enc_pretrained.split("/")[-1]}'
+        encoder = models.from_pretrained(opt.enc_pretrained).eval().cuda()
+    else:
+        latent_name = f'{opt.enc_model}_{opt.ckpt}'
+        cfg = edict(json.load(open(os.path.join(opt.model_root, opt.enc_model, 'config.json'), 'r')))
+        encoder = getattr(models, cfg.models.encoder.name)(**cfg.models.encoder.args).cuda()
+        ckpt_path = os.path.join(opt.model_root, opt.enc_model, 'ckpts', f'encoder_{opt.ckpt}.pt')
+        encoder.load_state_dict(torch.load(ckpt_path), strict=False)
+        encoder.eval()
+        print(f'Loaded model from {ckpt_path}')
+
+    os.makedirs(os.path.join(opt.output_dir, 'ss_latents', latent_name), exist_ok=True)
+
+    # get file list
+    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
+        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    else:
+        raise ValueError('metadata.csv not found')
+    if opt.instances is not None:
+        with open(opt.instances, 'r') as f:
+            instances = f.read().splitlines()
+        metadata = metadata[metadata['sha256'].isin(instances)]
+    else:
+        if opt.filter_low_aesthetic_score is not None:
+            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
+        metadata = metadata[metadata['voxelized'] == True]
+        if f'ss_latent_{latent_name}' in metadata.columns:
+            metadata = metadata[metadata[f'ss_latent_{latent_name}'] == False]
+
+    start = len(metadata) * opt.rank // opt.world_size
+    end = len(metadata) * (opt.rank + 1) // opt.world_size
+    metadata = metadata[start:end]
+    records = []
+
+    # filter out objects that are already processed
+    sha256s = list(metadata['sha256'].values)
+    for sha256 in copy.copy(sha256s):
+        if os.path.exists(os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')):
+            records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
+            sha256s.remove(sha256)
+
+    # encode latents
+    load_queue = Queue(maxsize=4)
+    try:
+        with ThreadPoolExecutor(max_workers=32) as loader_executor, \
+            ThreadPoolExecutor(max_workers=32) as saver_executor:
+            def loader(sha256):
+                try:
+                    ss = get_voxels(sha256)[None].float()
+                    load_queue.put((sha256, ss))
+                except Exception as e:
+                    print(f"Error loading features for {sha256}: {e}")
+            loader_executor.map(loader, sha256s)
+
+            def saver(sha256, pack):
+                save_path = os.path.join(opt.output_dir, 'ss_latents', latent_name, f'{sha256}.npz')
+                np.savez_compressed(save_path, **pack)
+                records.append({'sha256': sha256, f'ss_latent_{latent_name}': True})
+
+            for _ in tqdm(range(len(sha256s)), desc="Extracting latents"):
+                sha256, ss = load_queue.get()
+                ss = ss.cuda().float()
+                latent = encoder(ss, sample_posterior=False)
+                assert torch.isfinite(latent).all(), "Non-finite latent"
+                pack = {
+                    'mean': latent[0].cpu().numpy(),
+                }
+                saver_executor.submit(saver, sha256, pack)
+
+            saver_executor.shutdown(wait=True)
+    except Exception as e:
+        print(f"Error happened during processing: {e}")
+
+    records = pd.DataFrame.from_records(records)
+    records.to_csv(os.path.join(opt.output_dir, f'ss_latent_{latent_name}_{opt.rank}.csv'), index=False)
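get_voxels() in encode_ss_latent.py assumes vertex positions in [-0.5, 0.5) and converts them to cells of a resolution³ occupancy grid via (p + 0.5) * resolution. A small worked check of that mapping (values are illustrative):

```python
# Sketch of the coordinate convention used by get_voxels(): positions live in
# [-0.5, 0.5) and map to integer cells of a `resolution`^3 occupancy grid.
import torch

resolution = 64                                      # the script's --resolution default
position = torch.tensor([[-0.5, 0.0, 0.4921875]])    # illustrative vertex
coords = ((position + 0.5) * resolution).int()       # -> [[0, 32, 63]]
assert (coords >= 0).all() and (coords < resolution).all()
```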
diff --git a/dataset_toolkits/extract_feature.py b/dataset_toolkits/extract_feature.py
new file mode 100644
index 0000000..1ded2fe
--- /dev/null
+++ b/dataset_toolkits/extract_feature.py
@@ -0,0 +1,179 @@
+import os
+import copy
+import sys
+import json
+import importlib
+import argparse
+import torch
+import torch.nn.functional as F
+import numpy as np
+import pandas as pd
+import utils3d
+from tqdm import tqdm
+from easydict import EasyDict as edict
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
+from torchvision import transforms
+from PIL import Image
+
+
+torch.set_grad_enabled(False)
+
+
+def get_data(frames, sha256):
+    with ThreadPoolExecutor(max_workers=16) as executor:
+        def worker(view):
+            image_path = os.path.join(opt.output_dir, 'renders', sha256, view['file_path'])
+            try:
+                image = Image.open(image_path)
+            except Exception as e:
+                print(f"Error loading image {image_path}: {e}")
+                return None
+            image = image.resize((518, 518), Image.Resampling.LANCZOS)
+            image = np.array(image).astype(np.float32) / 255
+            image = image[:, :, :3] * image[:, :, 3:]
+            image = torch.from_numpy(image).permute(2, 0, 1).float()
+
+            c2w = torch.tensor(view['transform_matrix'])
+            c2w[:3, 1:3] *= -1
+            extrinsics = torch.inverse(c2w)
+            fov = view['camera_angle_x']
+            intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
+
+            return {
+                'image': image,
+                'extrinsics': extrinsics,
+                'intrinsics': intrinsics
+            }
+
+        datas = executor.map(worker, frames)
+        for data in datas:
+            if data is not None:
+                yield data
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--output_dir', type=str, required=True,
+                        help='Directory to save the metadata')
+    parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
+                        help='Filter objects with aesthetic score lower than this value')
+    parser.add_argument('--model', type=str, default='dinov2_vitl14_reg',
+                        help='Feature extraction model')
+    parser.add_argument('--instances', type=str, default=None,
+                        help='Instances to process')
+    parser.add_argument('--batch_size', type=int, default=16)
+    parser.add_argument('--rank', type=int, default=0)
+    parser.add_argument('--world_size', type=int, default=1)
+    opt = parser.parse_args()
+    opt = edict(vars(opt))
+
+    feature_name = opt.model
+    os.makedirs(os.path.join(opt.output_dir, 'features', feature_name), exist_ok=True)
+
+    # load model
+    dinov2_model = torch.hub.load('facebookresearch/dinov2', opt.model)
+    dinov2_model.eval().cuda()
+    transform = transforms.Compose([
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+    n_patch = 518 // 14
+
+    # get file list
+    if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
+        metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
+    else:
+        raise ValueError('metadata.csv not found')
+    if opt.instances is not None:
+        with open(opt.instances, 'r') as f:
+            instances = f.read().splitlines()
+        metadata = metadata[metadata['sha256'].isin(instances)]
+    else:
+        if opt.filter_low_aesthetic_score is not None:
+            metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
+        if f'feature_{feature_name}' in metadata.columns:
+            metadata = metadata[metadata[f'feature_{feature_name}'] == False]
+        metadata = metadata[metadata['voxelized'] == True]
+        metadata = metadata[metadata['rendered'] == True]
+
+    start = len(metadata) * opt.rank // opt.world_size
+    end = len(metadata) * (opt.rank + 1) // opt.world_size
+    metadata = metadata[start:end]
+    records = []
+
+    # filter out objects that are already processed
+    sha256s = list(metadata['sha256'].values)
+    for sha256 in copy.copy(sha256s):
+        if os.path.exists(os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')):
+            records.append({'sha256': sha256, f'feature_{feature_name}': True})
+            sha256s.remove(sha256)
+
+    # extract features
+    load_queue = Queue(maxsize=4)
+    try:
+        with ThreadPoolExecutor(max_workers=8) as loader_executor, \
+            ThreadPoolExecutor(max_workers=8) as saver_executor:
+            def loader(sha256):
+                try:
+                    with open(os.path.join(opt.output_dir, 'renders', sha256, 'transforms.json'), 'r') as f:
+                        metadata = json.load(f)
+                    frames = metadata['frames']
+                    data = []
+                    for datum in get_data(frames, sha256):
+                        datum['image'] = transform(datum['image'])
+                        data.append(datum)
+                    positions = utils3d.io.read_ply(os.path.join(opt.output_dir, 'voxels', f'{sha256}.ply'))[0]
+                    load_queue.put((sha256, data, positions))
+                except Exception as e:
+                    print(f"Error loading data for {sha256}: {e}")
+
+            loader_executor.map(loader, sha256s)
+
+            def saver(sha256, pack, patchtokens, uv):
+                pack['patchtokens'] = F.grid_sample(
+                    patchtokens,
+                    uv.unsqueeze(1),
+                    mode='bilinear',
+                    align_corners=False,
+                ).squeeze(2).permute(0, 2, 1).cpu().numpy()
+                pack['patchtokens'] = np.mean(pack['patchtokens'], axis=0).astype(np.float16)
+                save_path = os.path.join(opt.output_dir, 'features', feature_name, f'{sha256}.npz')
+                np.savez_compressed(save_path, **pack)
+                records.append({'sha256': sha256, f'feature_{feature_name}': True})
+
+            for _ in tqdm(range(len(sha256s)), desc="Extracting features"):
+                sha256, data, positions = load_queue.get()
+                positions = torch.from_numpy(positions).float().cuda()
+                indices = ((positions + 0.5) * 64).long()
+                assert torch.all(indices >= 0) and torch.all(indices < 64), "Some vertices are out of bounds"
+                n_views = len(data)
+                N = positions.shape[0]
+                pack = {
+                    'indices': indices.cpu().numpy().astype(np.uint8),
+                }
+                patchtokens_lst = []
+                uv_lst = []
+                for i in range(0, n_views, opt.batch_size):
+                    batch_data = data[i:i+opt.batch_size]
+                    bs = len(batch_data)
+                    batch_images = torch.stack([d['image'] for d in batch_data]).cuda()
+                    batch_extrinsics = torch.stack([d['extrinsics'] for d in batch_data]).cuda()
+                    batch_intrinsics = torch.stack([d['intrinsics'] for d in batch_data]).cuda()
+                    features = dinov2_model(batch_images, is_training=True)
+                    uv = utils3d.torch.project_cv(positions, batch_extrinsics, batch_intrinsics)[0] * 2 - 1
+                    patchtokens = features['x_prenorm'][:, dinov2_model.num_register_tokens + 1:].permute(0, 2, 1).reshape(bs, 1024, n_patch, n_patch)
+                    patchtokens_lst.append(patchtokens)
+                    uv_lst.append(uv)
+                patchtokens = torch.cat(patchtokens_lst, dim=0)
+                uv = torch.cat(uv_lst, dim=0)
+
+                # save features
+                saver_executor.submit(saver, sha256, pack, patchtokens, uv)
+
+            saver_executor.shutdown(wait=True)
+    except Exception as e:
+        print(f"Error happened during processing: {e}")
+
+    records = pd.DataFrame.from_records(records)
+    records.to_csv(os.path.join(opt.output_dir, f'feature_{feature_name}_{opt.rank}.csv'), index=False)
\ No newline at end of file
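extract_feature.py aggregates DINOv2 patch tokens onto voxels by projecting each voxel center into every rendered view (project_cv, remapped to [-1, 1]) and bilinearly sampling the patch-token grid there with F.grid_sample, then averaging over views. A minimal sketch of that sampling step with toy shapes (names and values are illustrative):

```python
import torch
import torch.nn.functional as F

n_views, C, n_patch, N = 4, 1024, 37, 4096     # 37 = 518 // 14 patches per side
patchtokens = torch.randn(n_views, C, n_patch, n_patch)
uv = torch.rand(n_views, N, 2) * 2 - 1         # projected voxel centers in [-1, 1]

# Sample one feature per (view, voxel), then average over views, as the saver does.
sampled = F.grid_sample(patchtokens, uv.unsqueeze(1), mode='bilinear', align_corners=False)
features = sampled.squeeze(2).permute(0, 2, 1).mean(dim=0)   # [N, C]
assert features.shape == (N, C)
```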
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 0000000..fdd30c0
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,29 @@
+services:
+  trellis:
+    image: my-trellis-with-blender:20260317  # the image you committed
+    container_name: trellis-dev
+    restart: unless-stopped
+    environment:
+      - NVIDIA_VISIBLE_DEVICES=all
+      - NVIDIA_DRIVER_CAPABILITIES=compute,utility,video
+    volumes:
+      - .:/workspace  # mount the current directory at /workspace for development
+    ports:
+      - "7412:8120"
+    working_dir: /workspace
+    tty: true
+    stdin_open: true
+    # Current GPU configuration style (replaces the legacy environment-variable declaration)
+    runtime: nvidia
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all  # use all available GPUs; a number also works, e.g. count: 1
+              capabilities: [ gpu, compute, video ]
+    command: >
+      bash -c "
+      /opt/conda/envs/trellis/bin/python -c 'import torch; print(torch.__version__, torch.cuda.is_available())' &&
+      tail -f /dev/null
+      "
\ No newline at end of file
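The compose `command` already runs a one-line sanity check before parking the container. An equivalent standalone check (a sketch; the file name check_gpu.py is hypothetical, and it would run inside the trellis environment) is:

```python
# check_gpu.py -- hypothetical smoke test mirroring the compose `command`.
import torch

print(torch.__version__)                  # expect a CUDA 12.8 build
print(torch.cuda.is_available())          # expect True with `runtime: nvidia`
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # confirms the container sees a GPU
```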
diff --git a/example.py b/example.py
new file mode 100644
index 0000000..21033c1
--- /dev/null
+++ b/example.py
@@ -0,0 +1,57 @@
+import os
+# os.environ['ATTN_BACKEND'] = 'xformers'   # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
+os.environ['SPCONV_ALGO'] = 'native'        # Can be 'native' or 'auto', default is 'auto'.
+                                            # 'auto' is faster but will do benchmarking at the beginning.
+                                            # Recommended to set to 'native' if run only once.
+
+import imageio
+from PIL import Image
+from trellis.pipelines import TrellisImageTo3DPipeline
+from trellis.utils import render_utils, postprocessing_utils
+
+# Load a pipeline from a model folder or a Hugging Face model hub.
+pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
+pipeline.cuda()
+
+# Load an image
+image = Image.open("assets/example_image/T.png")
+
+# Run the pipeline
+outputs = pipeline.run(
+    image,
+    seed=1,
+    # Optional parameters
+    # sparse_structure_sampler_params={
+    #     "steps": 12,
+    #     "cfg_strength": 7.5,
+    # },
+    # slat_sampler_params={
+    #     "steps": 12,
+    #     "cfg_strength": 3,
+    # },
+)
+# outputs is a dictionary containing generated 3D assets in different formats:
+# - outputs['gaussian']: a list of 3D Gaussians
+# - outputs['radiance_field']: a list of radiance fields
+# - outputs['mesh']: a list of meshes
+
+# Render the outputs
+video = render_utils.render_video(outputs['gaussian'][0])['color']
+imageio.mimsave("sample_gs.mp4", video, fps=30)
+video = render_utils.render_video(outputs['radiance_field'][0])['color']
+imageio.mimsave("sample_rf.mp4", video, fps=30)
+video = render_utils.render_video(outputs['mesh'][0])['normal']
+imageio.mimsave("sample_mesh.mp4", video, fps=30)
+
+# GLB files can be extracted from the outputs
+glb = postprocessing_utils.to_glb(
+    outputs['gaussian'][0],
+    outputs['mesh'][0],
+    # Optional parameters
+    simplify=0.95,          # Ratio of triangles to remove in the simplification process
+    texture_size=1024,      # Size of the texture used for the GLB
+)
+glb.export("sample.glb")
+
+# Save Gaussians as PLY files
+outputs['gaussian'][0].save_ply("sample.ply")
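After example.py runs, a quick way to eyeball the exports is to load them back. A sketch using open3d (which example_variant.py below already imports; reading the Gaussian PLY recovers only point positions, not the splat attributes):

```python
import open3d as o3d

pcd = o3d.io.read_point_cloud("sample.ply")      # Gaussian centers as a point cloud
print(len(pcd.points))                           # number of Gaussians
mesh = o3d.io.read_triangle_mesh("sample.glb")   # Open3D can usually parse glTF/GLB meshes
print(len(mesh.vertices), len(mesh.triangles))
```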
diff --git a/example_multi_image.py b/example_multi_image.py
new file mode 100644
index 0000000..3b33306
--- /dev/null
+++ b/example_multi_image.py
@@ -0,0 +1,46 @@
+import os
+# os.environ['ATTN_BACKEND'] = 'xformers'   # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
+os.environ['SPCONV_ALGO'] = 'native'        # Can be 'native' or 'auto', default is 'auto'.
+                                            # 'auto' is faster but will do benchmarking at the beginning.
+                                            # Recommended to set to 'native' if run only once.
+
+import numpy as np
+import imageio
+from PIL import Image
+from trellis.pipelines import TrellisImageTo3DPipeline
+from trellis.utils import render_utils
+
+# Load a pipeline from a model folder or a Hugging Face model hub.
+pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
+pipeline.cuda()
+
+# Load the images
+images = [
+    Image.open("assets/example_multi_image/character_1.png"),
+    Image.open("assets/example_multi_image/character_2.png"),
+    Image.open("assets/example_multi_image/character_3.png"),
+]
+
+# Run the pipeline
+outputs = pipeline.run_multi_image(
+    images,
+    seed=1,
+    # Optional parameters
+    sparse_structure_sampler_params={
+        "steps": 12,
+        "cfg_strength": 7.5,
+    },
+    slat_sampler_params={
+        "steps": 12,
+        "cfg_strength": 3,
+    },
+)
+# outputs is a dictionary containing generated 3D assets in different formats:
+# - outputs['gaussian']: a list of 3D Gaussians
+# - outputs['radiance_field']: a list of radiance fields
+# - outputs['mesh']: a list of meshes
+
+video_gs = render_utils.render_video(outputs['gaussian'][0])['color']
+video_mesh = render_utils.render_video(outputs['mesh'][0])['normal']
+video = [np.concatenate([frame_gs, frame_mesh], axis=1) for frame_gs, frame_mesh in zip(video_gs, video_mesh)]
+imageio.mimsave("sample_multi.mp4", video, fps=30)
diff --git a/example_text.py b/example_text.py
new file mode 100644
index 0000000..31f57e8
--- /dev/null
+++ b/example_text.py
@@ -0,0 +1,53 @@
+import os
+# os.environ['ATTN_BACKEND'] = 'xformers'   # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
+os.environ['SPCONV_ALGO'] = 'native'        # Can be 'native' or 'auto', default is 'auto'.
+                                            # 'auto' is faster but will do benchmarking at the beginning.
+                                            # Recommended to set to 'native' if run only once.
+
+import imageio
+from trellis.pipelines import TrellisTextTo3DPipeline
+from trellis.utils import render_utils, postprocessing_utils
+
+# Load a pipeline from a model folder or a Hugging Face model hub.
+pipeline = TrellisTextTo3DPipeline.from_pretrained("microsoft/TRELLIS-text-xlarge")
+pipeline.cuda()
+
+# Run the pipeline
+outputs = pipeline.run(
+    "A chair looking like an avocado.",
+    seed=1,
+    # Optional parameters
+    # sparse_structure_sampler_params={
+    #     "steps": 12,
+    #     "cfg_strength": 7.5,
+    # },
+    # slat_sampler_params={
+    #     "steps": 12,
+    #     "cfg_strength": 7.5,
+    # },
+)
+# outputs is a dictionary containing generated 3D assets in different formats:
+# - outputs['gaussian']: a list of 3D Gaussians
+# - outputs['radiance_field']: a list of radiance fields
+# - outputs['mesh']: a list of meshes
+
+# Render the outputs
+video = render_utils.render_video(outputs['gaussian'][0])['color']
+imageio.mimsave("sample_gs.mp4", video, fps=30)
+video = render_utils.render_video(outputs['radiance_field'][0])['color']
+imageio.mimsave("sample_rf.mp4", video, fps=30)
+video = render_utils.render_video(outputs['mesh'][0])['normal']
+imageio.mimsave("sample_mesh.mp4", video, fps=30)
+
+# GLB files can be extracted from the outputs
+glb = postprocessing_utils.to_glb(
+    outputs['gaussian'][0],
+    outputs['mesh'][0],
+    # Optional parameters
+    simplify=0.95,          # Ratio of triangles to remove in the simplification process
+    texture_size=1024,      # Size of the texture used for the GLB
+)
+glb.export("sample.glb")
+
+# Save Gaussians as PLY files
+outputs['gaussian'][0].save_ply("sample.ply")
diff --git a/example_variant.py b/example_variant.py
new file mode 100644
index 0000000..637f319
--- /dev/null
+++ b/example_variant.py
@@ -0,0 +1,41 @@
+import os
+# os.environ['ATTN_BACKEND'] = 'xformers'   # Can be 'flash-attn' or 'xformers', default is 'flash-attn'
+os.environ['SPCONV_ALGO'] = 'native'        # Can be 'native' or 'auto', default is 'auto'.
+ # 'auto' is faster but will do benchmarking at the beginning. + # Recommended to set to 'native' if run only once. + +import imageio +import numpy as np +import open3d as o3d +from trellis.pipelines import TrellisTextTo3DPipeline +from trellis.utils import render_utils, postprocessing_utils + +# Load a pipeline from a model folder or a Hugging Face model hub. +pipeline = TrellisTextTo3DPipeline.from_pretrained("microsoft/TRELLIS-text-xlarge") +pipeline.cuda() + +# Load mesh to make variants +base_mesh = o3d.io.read_triangle_mesh("assets/T.ply") + +# Run the pipeline +outputs = pipeline.run_variant( + base_mesh, + "Rugged, metallic texture with orange and white paint finish, suggesting a durable, industrial feel.", + seed=1, + # Optional parameters + # slat_sampler_params={ + # "steps": 12, + # "cfg_strength": 7.5, + # }, +) +# outputs is a dictionary containing generated 3D assets in different formats: +# - outputs['gaussian']: a list of 3D Gaussians +# - outputs['radiance_field']: a list of radiance fields +# - outputs['mesh']: a list of meshes + +# Render the outputs +video_gs = render_utils.render_video(outputs['gaussian'][0])['color'] +video_mesh = render_utils.render_video(outputs['mesh'][0])['normal'] +video = [np.concatenate([frame_gs, frame_mesh], axis=1) for frame_gs, frame_mesh in zip(video_gs, video_mesh)] +imageio.mimsave("sample_variant.mp4", video, fps=30) + diff --git a/trellis/models/structured_latent_vae/decoder_gs.py b/trellis/models/structured_latent_vae/decoder_gs.py new file mode 100644 index 0000000..47304f3 --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_gs.py @@ -0,0 +1,131 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from ...utils.random_utils import hammersley_sequence +from .base import SparseTransformerBase +from ...representations import Gaussian +from ..sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SLatGaussianDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + self._build_perturbation() + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _build_perturbation(self) -> None: + perturbation = [hammersley_sequence(3, i, self.rep_config['num_gaussians']) for i in range(self.rep_config['num_gaussians'])] + perturbation = 
torch.tensor(perturbation).float() * 2 - 1 + perturbation = perturbation / self.rep_config['voxel_size'] + perturbation = torch.atanh(perturbation).to(self.device) + self.register_buffer('offset_perturbation', perturbation) + + def _calc_layout(self) -> None: + self.layout = { + '_xyz' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_features_dc' : {'shape': (self.rep_config['num_gaussians'], 1, 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_scaling' : {'shape': (self.rep_config['num_gaussians'], 3), 'size': self.rep_config['num_gaussians'] * 3}, + '_rotation' : {'shape': (self.rep_config['num_gaussians'], 4), 'size': self.rep_config['num_gaussians'] * 4}, + '_opacity' : {'shape': (self.rep_config['num_gaussians'], 1), 'size': self.rep_config['num_gaussians']}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Gaussian]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Gaussian( + sh_degree=0, + aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], + mininum_kernel_size = self.rep_config['3d_filter_kernel_size'], + scaling_bias = self.rep_config['scaling_bias'], + opacity_bias = self.rep_config['opacity_bias'], + scaling_activation = self.rep_config['scaling_activation'] + ) + xyz = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + for k, v in self.layout.items(): + if k == '_xyz': + offset = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']) + offset = offset * self.rep_config['lr'][k] + if self.rep_config['perturb_offset']: + offset = offset + self.offset_perturbation + offset = torch.tanh(offset) / self.resolution * 0.5 * self.rep_config['voxel_size'] + _xyz = xyz.unsqueeze(1) + offset + setattr(representation, k, _xyz.flatten(0, 1)) + else: + feats = x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) + feats = feats * self.rep_config['lr'][k] + setattr(representation, k, feats) + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Gaussian]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) + + +class ElasticSLatGaussianDecoder(SparseTransformerElasticMixin, SLatGaussianDecoder): + """ + Slat VAE Gaussian decoder with elastic memory management. + Used for training with low VRAM. 
+ """ + pass diff --git a/trellis/models/structured_latent_vae/decoder_mesh.py b/trellis/models/structured_latent_vae/decoder_mesh.py new file mode 100644 index 0000000..f9f5659 --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_mesh.py @@ -0,0 +1,176 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import MeshExtractResult +from ...representations.mesh import SparseFeatures2Mesh +from ..sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SparseSubdivideBlock3d(nn.Module): + """ + A 3D subdivide block that can subdivide the sparse tensor. + + Args: + channels: channels in the inputs and outputs. + out_channels: if specified, the number of output channels. + num_groups: the number of groups for the group norm. + """ + def __init__( + self, + channels: int, + resolution: int, + out_channels: Optional[int] = None, + num_groups: int = 32 + ): + super().__init__() + self.channels = channels + self.resolution = resolution + self.out_resolution = resolution * 2 + self.out_channels = out_channels or channels + + self.act_layers = nn.Sequential( + sp.SparseGroupNorm32(num_groups, channels), + sp.SparseSiLU() + ) + + self.sub = sp.SparseSubdivide() + + self.out_layers = nn.Sequential( + sp.SparseConv3d(channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}"), + sp.SparseGroupNorm32(num_groups, self.out_channels), + sp.SparseSiLU(), + zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3, indice_key=f"res_{self.out_resolution}")), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = sp.SparseConv3d(channels, self.out_channels, 1, indice_key=f"res_{self.out_resolution}") + + def forward(self, x: sp.SparseTensor) -> sp.SparseTensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Args: + x: an [N x C x ...] Tensor of features. + Returns: + an [N x C x ...] Tensor of outputs. 
+ """ + h = self.act_layers(x) + h = self.sub(h) + x = self.sub(x) + h = self.out_layers(h) + h = h + self.skip_connection(x) + return h + + +class SLatMeshDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self.mesh_extractor = SparseFeatures2Mesh(res=self.resolution*4, use_color=self.rep_config.get('use_color', False)) + self.out_channels = self.mesh_extractor.feats_channels + self.upsample = nn.ModuleList([ + SparseSubdivideBlock3d( + channels=model_channels, + resolution=resolution, + out_channels=model_channels // 4 + ), + SparseSubdivideBlock3d( + channels=model_channels // 4, + resolution=resolution * 2, + out_channels=model_channels // 8 + ) + ]) + self.out_layer = sp.SparseLinear(model_channels // 8, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + super().convert_to_fp16() + self.upsample.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + super().convert_to_fp32() + self.upsample.apply(convert_module_to_f32) + + def to_representation(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. + + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + mesh = self.mesh_extractor(x[i], training=self.training) + ret.append(mesh) + return ret + + def forward(self, x: sp.SparseTensor) -> List[MeshExtractResult]: + h = super().forward(x) + for block in self.upsample: + h = block(h) + h = h.type(x.dtype) + h = self.out_layer(h) + return self.to_representation(h) + + +class ElasticSLatMeshDecoder(SparseTransformerElasticMixin, SLatMeshDecoder): + """ + Slat VAE Mesh decoder with elastic memory management. + Used for training with low VRAM. 
+ """ + pass diff --git a/trellis/models/structured_latent_vae/decoder_rf.py b/trellis/models/structured_latent_vae/decoder_rf.py new file mode 100644 index 0000000..bb18094 --- /dev/null +++ b/trellis/models/structured_latent_vae/decoder_rf.py @@ -0,0 +1,113 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ...representations import Strivec +from ..sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SLatRadianceFieldDecoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + representation_config: dict = None, + ): + super().__init__( + in_channels=latent_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.rep_config = representation_config + self._calc_layout() + self.out_layer = sp.SparseLinear(model_channels, self.out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def _calc_layout(self) -> None: + self.layout = { + 'trivec': {'shape': (self.rep_config['rank'], 3, self.rep_config['dim']), 'size': self.rep_config['rank'] * 3 * self.rep_config['dim']}, + 'density': {'shape': (self.rep_config['rank'],), 'size': self.rep_config['rank']}, + 'features_dc': {'shape': (self.rep_config['rank'], 1, 3), 'size': self.rep_config['rank'] * 3}, + } + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + self.out_channels = start + + def to_representation(self, x: sp.SparseTensor) -> List[Strivec]: + """ + Convert a batch of network outputs to 3D representations. + + Args: + x: The [N x * x C] sparse tensor output by the network. 
+ + Returns: + list of representations + """ + ret = [] + for i in range(x.shape[0]): + representation = Strivec( + sh_degree=0, + resolution=self.resolution, + aabb=[-0.5, -0.5, -0.5, 1, 1, 1], + rank=self.rep_config['rank'], + dim=self.rep_config['dim'], + device='cuda', + ) + representation.density_shift = 0.0 + representation.position = (x.coords[x.layout[i]][:, 1:].float() + 0.5) / self.resolution + representation.depth = torch.full((representation.position.shape[0], 1), int(np.log2(self.resolution)), dtype=torch.uint8, device='cuda') + for k, v in self.layout.items(): + setattr(representation, k, x.feats[x.layout[i]][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape'])) + representation.trivec = representation.trivec + 1 + ret.append(representation) + return ret + + def forward(self, x: sp.SparseTensor) -> List[Strivec]: + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return self.to_representation(h) + + +class ElasticSLatRadianceFieldDecoder(SparseTransformerElasticMixin, SLatRadianceFieldDecoder): + """ + Slat VAE Radiance Field Decoder with elastic memory management. + Used for training with low VRAM. + """ + pass diff --git a/trellis/models/structured_latent_vae/encoder.py b/trellis/models/structured_latent_vae/encoder.py new file mode 100644 index 0000000..77f6fb1 --- /dev/null +++ b/trellis/models/structured_latent_vae/encoder.py @@ -0,0 +1,80 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +from ...modules import sparse as sp +from .base import SparseTransformerBase +from ..sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SLatEncoder(SparseTransformerBase): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + latent_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + attn_mode: Literal["full", "shift_window", "shift_sequence", "shift_order", "swin"] = "swin", + window_size: int = 8, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + qk_rms_norm: bool = False, + ): + super().__init__( + in_channels=in_channels, + model_channels=model_channels, + num_blocks=num_blocks, + num_heads=num_heads, + num_head_channels=num_head_channels, + mlp_ratio=mlp_ratio, + attn_mode=attn_mode, + window_size=window_size, + pe_mode=pe_mode, + use_fp16=use_fp16, + use_checkpoint=use_checkpoint, + qk_rms_norm=qk_rms_norm, + ) + self.resolution = resolution + self.out_layer = sp.SparseLinear(model_channels, 2 * latent_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + def initialize_weights(self) -> None: + super().initialize_weights() + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, sample_posterior=True, return_raw=False): + h = super().forward(x) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z + + +class ElasticSLatEncoder(SparseTransformerElasticMixin, SLatEncoder): + """ + SLat VAE encoder with elastic 
memory management. + Used for training with low VRAM. + """ diff --git a/trellis/modules/attention/full_attn.py b/trellis/modules/attention/full_attn.py new file mode 100755 index 0000000..d9ebf63 --- /dev/null +++ b/trellis/modules/attention/full_attn.py @@ -0,0 +1,140 @@ +from typing import * +import torch +import math +from . import DEBUG, BACKEND + +if BACKEND == 'xformers': + import xformers.ops as xops +elif BACKEND == 'flash_attn': + import flash_attn +elif BACKEND == 'sdpa': + from torch.nn.functional import scaled_dot_product_attention as sdpa +elif BACKEND == 'naive': + pass +else: + raise ValueError(f"Unknown attention backend: {BACKEND}") + + +__all__ = [ + 'scaled_dot_product_attention', +] + + +def _naive_sdpa(q, k, v): + """ + Naive implementation of scaled dot product attention. + """ + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + scale_factor = 1 / math.sqrt(q.size(-1)) + attn_weight = q @ k.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight @ v + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + return out + + +@overload +def scaled_dot_product_attention(qkv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + qkv (torch.Tensor): A [N, L, 3, H, C] tensor containing Qs, Ks, and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, C] tensor containing Qs. + kv (torch.Tensor): A [N, L, 2, H, C] tensor containing Ks and Vs. + """ + ... + +@overload +def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """ + Apply scaled dot product attention. + + Args: + q (torch.Tensor): A [N, L, H, Ci] tensor containing Qs. + k (torch.Tensor): A [N, L, H, Ci] tensor containing Ks. + v (torch.Tensor): A [N, L, H, Co] tensor containing Vs. + + Note: + k and v are assumed to have the same coordinate map. + """ + ... 
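The three overloads above document the call forms accepted by the dispatcher defined next: a packed qkv tensor, q plus packed kv, or separate q/k/v, routed to whichever backend was selected at import time. A hedged usage sketch (shapes are illustrative; a CUDA device and the fp16 dtype that flash-attn requires are assumed):

```python
import torch

# N=2 sequences of L=4096 tokens, H=16 heads, C=64 channels per head.
qkv = torch.randn(2, 4096, 3, 16, 64, device='cuda', dtype=torch.float16)

out1 = scaled_dot_product_attention(qkv)                          # packed qkv
out2 = scaled_dot_product_attention(qkv[:, :, 0], qkv[:, :, 1:])  # q + packed kv
q, k, v = qkv.unbind(dim=2)
out3 = scaled_dot_product_attention(q, k, v)                      # separate q, k, v
# Each call returns an [N, L, H, C] tensor: [2, 4096, 16, 64].
```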
+ +def scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert len(qkv.shape) == 5 and qkv.shape[2] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, L, 3, H, C]" + device = qkv.device + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + device = q.device + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + device = q.device + + if BACKEND == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif BACKEND == 'flash_attn': + if num_all_args == 1: + out = flash_attn.flash_attn_qkvpacked_func(qkv) + elif num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif BACKEND == 'sdpa': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif BACKEND == 'naive': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = _naive_sdpa(q, k, v) + else: + raise ValueError(f"Unknown attention module: {BACKEND}") + + return out diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py new file mode 100755 index 0000000..e9e27ae --- /dev/null +++ b/trellis/modules/sparse/attention/full_attn.py @@ -0,0 +1,215 @@ +from typing import * +import torch +from .. import SparseTensor +from .. import DEBUG, ATTN + +if ATTN == 'xformers': + import xformers.ops as xops +elif ATTN == 'flash_attn': + import flash_attn +else: + raise ValueError(f"Unknown attention module: {ATTN}") + + +__all__ = [ + 'sparse_scaled_dot_product_attention', +] + + +@overload +def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor: + """ + Apply scaled dot product attention to a sparse tensor. + + Args: + qkv (SparseTensor): A [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs. + """ + ... 
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, kv: Union[SparseTensor, torch.Tensor]) -> SparseTensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (SparseTensor): A [N, *, H, C] sparse tensor containing Qs.
+        kv (SparseTensor or torch.Tensor): A [N, *, 2, H, C] sparse tensor or a [N, L, 2, H, C] dense tensor containing Ks and Vs.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: torch.Tensor, kv: SparseTensor) -> torch.Tensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (torch.Tensor): A [N, L, H, C] dense tensor containing Qs.
+        kv (SparseTensor): A [N, *, 2, H, C] sparse tensor containing Ks and Vs.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, k: SparseTensor, v: SparseTensor) -> SparseTensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
+        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
+        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
+
+    Note:
+        k and v are assumed to have the same coordinate map.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: SparseTensor, k: torch.Tensor, v: torch.Tensor) -> SparseTensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (SparseTensor): A [N, *, H, Ci] sparse tensor containing Qs.
+        k (torch.Tensor): A [N, L, H, Ci] dense tensor containing Ks.
+        v (torch.Tensor): A [N, L, H, Co] dense tensor containing Vs.
+    """
+    ...
+
+@overload
+def sparse_scaled_dot_product_attention(q: torch.Tensor, k: SparseTensor, v: SparseTensor) -> torch.Tensor:
+    """
+    Apply scaled dot product attention to a sparse tensor.
+
+    Args:
+        q (torch.Tensor): A [N, L, H, Ci] dense tensor containing Qs.
+        k (SparseTensor): A [N, *, H, Ci] sparse tensor containing Ks.
+        v (SparseTensor): A [N, *, H, Co] sparse tensor containing Vs.
+    """
+    ...
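The overloads above all funnel into the dispatcher below, which flattens sparse operands into one token stream and derives per-sample sequence lengths from each SparseTensor's layout; those become a BlockDiagonalMask (xformers) or cu_seqlens offsets (flash-attn), so ragged batches are attended without padding. A hedged sketch of the cu_seqlens arithmetic it relies on (token counts are illustrative):

```python
import torch

# Per-sample token counts, as taken from a SparseTensor layout.
q_seqlen = [1523, 980, 2741]

# Cumulative offsets into the flattened [sum(q_seqlen), H, C] token stream, as
# expected by flash_attn's varlen kernels: sample i spans
# [cu_seqlens[i], cu_seqlens[i+1]).
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int()
assert cu_seqlens.tolist() == [0, 1523, 2503, 5244]
```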
+ +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + assert num_all_args in arg_names_dict, f"Invalid number of arguments, got {num_all_args}, expected 1, 2, or 3" + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + assert isinstance(qkv, SparseTensor), f"qkv must be a SparseTensor, got {type(qkv)}" + assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]" + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + assert isinstance(q, SparseTensor) and isinstance(kv, (SparseTensor, torch.Tensor)) or \ + isinstance(q, torch.Tensor) and isinstance(kv, SparseTensor), \ + f"Invalid types, got {type(q)} and {type(kv)}" + assert q.shape[0] == kv.shape[0], f"Batch size mismatch, got {q.shape[0]} and {kv.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, C]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, C]" + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, SparseTensor): + assert len(kv.shape) == 4 and kv.shape[1] == 2, f"Invalid shape for kv, got {kv.shape}, expected [N, *, 2, H, C]" + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + assert len(kv.shape) == 5, f"Invalid shape for kv, got {kv.shape}, expected [N, L, 2, H, C]" + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + assert isinstance(q, SparseTensor) and isinstance(k, (SparseTensor, torch.Tensor)) and type(k) == type(v) or \ + isinstance(q, torch.Tensor) and isinstance(k, SparseTensor) and isinstance(v, SparseTensor), \ + f"Invalid types, got {type(q)}, {type(k)}, and {type(v)}" + assert q.shape[0] == k.shape[0] == v.shape[0], f"Batch size mismatch, got {q.shape[0]}, {k.shape[0]}, and {v.shape[0]}" + device = q.device + + if isinstance(q, SparseTensor): + assert len(q.shape) == 3, f"Invalid shape for q, got {q.shape}, expected [N, *, H, Ci]" + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + assert len(q.shape) == 4, f"Invalid shape for q, got {q.shape}, expected [N, L, H, Ci]" + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, SparseTensor): + assert len(k.shape) == 3, f"Invalid shape for k, got {k.shape}, expected [N, *, H, Ci]" + assert len(v.shape) == 3, f"Invalid shape for v, got {v.shape}, expected [N, *, H, Co]" + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # 
[T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + assert len(k.shape) == 4, f"Invalid shape for k, got {k.shape}, expected [N, L, H, Ci]" + assert len(v.shape) == 4, f"Invalid shape for v, got {v.shape}, expected [N, L, H, Co]" + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if DEBUG: + if s is not None: + for i in range(s.shape[0]): + assert (s.coords[s.layout[i]] == i).all(), f"SparseScaledDotProductSelfAttention: batch index mismatch" + if num_all_args in [2, 3]: + assert q.shape[:2] == [1, sum(q_seqlen)], f"SparseScaledDotProductSelfAttention: q shape mismatch" + if num_all_args == 3: + assert k.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: k shape mismatch" + assert v.shape[:2] == [1, sum(kv_seqlen)], f"SparseScaledDotProductSelfAttention: v shape mismatch" + + if ATTN == 'xformers': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif ATTN == 'flash_attn': + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + else: + raise ValueError(f"Unknown attention module: {ATTN}") + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/trellis/pipelines/samplers/flow_euler.py b/trellis/pipelines/samplers/flow_euler.py new file mode 100644 index 0000000..cd463f2 --- /dev/null +++ b/trellis/pipelines/samplers/flow_euler.py @@ -0,0 +1,201 @@ +from typing import * +import torch +import numpy as np +from tqdm import tqdm +from easydict import EasyDict as edict +from .base import Sampler +from .classifier_free_guidance_mixin import ClassifierFreeGuidanceSamplerMixin +from .guidance_interval_mixin import GuidanceIntervalSamplerMixin + + +class FlowEulerSampler(Sampler): + """ + Generate samples from a flow-matching model using Euler sampling. + + Args: + sigma_min: The minimum scale of noise in flow. 
+ """ + def __init__( + self, + sigma_min: float, + ): + self.sigma_min = sigma_min + + def _eps_to_xstart(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * eps) / (1 - t) + + def _xstart_to_eps(self, x_t, t, x_0): + assert x_t.shape == x_0.shape + return (x_t - (1 - t) * x_0) / (self.sigma_min + (1 - self.sigma_min) * t) + + def _v_to_xstart_eps(self, x_t, t, v): + assert x_t.shape == v.shape + eps = (1 - t) * v + x_t + x_0 = (1 - self.sigma_min) * x_t - (self.sigma_min + (1 - self.sigma_min) * t) * v + return x_0, eps + + def _inference_model(self, model, x_t, t, cond=None, **kwargs): + t = torch.tensor([1000 * t] * x_t.shape[0], device=x_t.device, dtype=torch.float32) + if cond is not None and cond.shape[0] == 1 and x_t.shape[0] > 1: + cond = cond.repeat(x_t.shape[0], *([1] * (len(cond.shape) - 1))) + return model(x_t, t, cond, **kwargs) + + def _get_model_prediction(self, model, x_t, t, cond=None, **kwargs): + pred_v = self._inference_model(model, x_t, t, cond, **kwargs) + pred_x_0, pred_eps = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v) + return pred_x_0, pred_eps, pred_v + + @torch.no_grad() + def sample_once( + self, + model, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_x_prev = x_t - (t - t_prev) * pred_v + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0}) + + @torch.no_grad() + def sample( + self, + model, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + for t, t_prev in tqdm(t_pairs, desc="Sampling", disable=not verbose): + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + +class FlowEulerCfgSampler(ClassifierFreeGuidanceSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance. 
+ """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, **kwargs) + + +class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSampler): + """ + Generate samples from a flow-matching model using Euler sampling with classifier-free guidance and interval. + """ + @torch.no_grad() + def sample( + self, + model, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. 
+ """ + return super().sample(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) diff --git a/trellis/representations/gaussian/gaussian_model.py b/trellis/representations/gaussian/gaussian_model.py new file mode 100755 index 0000000..54ba16f --- /dev/null +++ b/trellis/representations/gaussian/gaussian_model.py @@ -0,0 +1,209 @@ +import torch +import numpy as np +from plyfile import PlyData, PlyElement +from .general_utils import inverse_sigmoid, strip_symmetric, build_scaling_rotation +import utils3d + + +class Gaussian: + def __init__( + self, + aabb : list, + sh_degree : int = 0, + mininum_kernel_size : float = 0.0, + scaling_bias : float = 0.01, + opacity_bias : float = 0.1, + scaling_activation : str = "exp", + device='cuda' + ): + self.init_params = { + 'aabb': aabb, + 'sh_degree': sh_degree, + 'mininum_kernel_size': mininum_kernel_size, + 'scaling_bias': scaling_bias, + 'opacity_bias': opacity_bias, + 'scaling_activation': scaling_activation, + } + + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.mininum_kernel_size = mininum_kernel_size + self.scaling_bias = scaling_bias + self.opacity_bias = opacity_bias + self.scaling_activation_type = scaling_activation + self.device = device + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.setup_functions() + + self._xyz = None + self._features_dc = None + self._features_rest = None + self._scaling = None + self._rotation = None + self._opacity = None + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + if self.scaling_activation_type == "exp": + self.scaling_activation = torch.exp + self.inverse_scaling_activation = torch.log + elif self.scaling_activation_type == "softplus": + self.scaling_activation = torch.nn.functional.softplus + self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x)) + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).cuda() + self.rots_bias = torch.zeros((4)).cuda() + self.rots_bias[0] = 1 + self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).cuda() + + @property + def get_scaling(self): + scales = self.scaling_activation(self._scaling + self.scale_bias) + scales = torch.square(scales) + self.mininum_kernel_size ** 2 + scales = torch.sqrt(scales) + return scales + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation + self.rots_bias[None, :]) + + @property + def get_xyz(self): + return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3] + + @property + def get_features(self): + return torch.cat((self._features_dc, self._features_rest), dim=2) if self._features_rest is not None else self._features_dc + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity + self.opacity_bias) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :]) + + def from_scaling(self, scales): + scales = 
torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)
+        self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias
+
+    def from_rotation(self, rots):
+        self._rotation = rots - self.rots_bias[None, :]
+
+    def from_xyz(self, xyz):
+        self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
+
+    def from_features(self, features):
+        self._features_dc = features
+
+    def from_opacity(self, opacities):
+        self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
+
+    def construct_list_of_attributes(self):
+        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+        # All channels except the 3 DC
+        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
+            l.append('f_dc_{}'.format(i))
+        l.append('opacity')
+        for i in range(self._scaling.shape[1]):
+            l.append('scale_{}'.format(i))
+        for i in range(self._rotation.shape[1]):
+            l.append('rot_{}'.format(i))
+        return l
+
+    def save_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
+        xyz = self.get_xyz.detach().cpu().numpy()
+        normals = np.zeros_like(xyz)
+        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+        opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
+        scale = torch.log(self.get_scaling).detach().cpu().numpy()
+        rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
+
+        if transform is not None:
+            transform = np.array(transform)
+            xyz = np.matmul(xyz, transform.T)
+            rotation = utils3d.numpy.quaternion_to_matrix(rotation)
+            rotation = np.matmul(transform, rotation)
+            rotation = utils3d.numpy.matrix_to_quaternion(rotation)
+
+        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
+
+        elements = np.empty(xyz.shape[0], dtype=dtype_full)
+        attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
+        elements[:] = list(map(tuple, attributes))
+        el = PlyElement.describe(elements, 'vertex')
+        PlyData([el]).write(path)
+
+    def load_ply(self, path, transform=[[1, 0, 0], [0, 0, -1], [0, 1, 0]]):
+        plydata = PlyData.read(path)
+
+        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+                        np.asarray(plydata.elements[0]["y"]),
+                        np.asarray(plydata.elements[0]["z"])), axis=1)
+        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+        features_dc = np.zeros((xyz.shape[0], 3, 1))
+        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+        if self.sh_degree > 0:
+            extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+            extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
+            assert len(extra_f_names)==3*(self.sh_degree + 1) ** 2 - 3
+            features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+            for idx, attr_name in enumerate(extra_f_names):
+                features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+            # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+            features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.sh_degree + 1) ** 2 - 1))
+
+        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
+        scales = np.zeros((xyz.shape[0], len(scale_names)))
+        for idx, attr_name in enumerate(scale_names):
+            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
+        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
+        rots = np.zeros((xyz.shape[0], len(rot_names)))
+        for idx, attr_name in enumerate(rot_names):
+            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+        if transform is not None:
+            transform = np.array(transform)
+            xyz = np.matmul(xyz, transform)
+            rots = utils3d.numpy.quaternion_to_matrix(rots)
+            rots = np.matmul(rots, transform)
+            rots = utils3d.numpy.matrix_to_quaternion(rots)
+
+        # convert to actual gaussian attributes
+        xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
+        features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
+        if self.sh_degree > 0:
+            features_extra = torch.tensor(features_extra, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
+        opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device))
+        scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
+        rots = torch.tensor(rots, dtype=torch.float, device=self.device)
+
+        # convert to _hidden attributes
+        self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]
+        self._features_dc = features_dc
+        if self.sh_degree > 0:
+            self._features_rest = features_extra
+        else:
+            self._features_rest = None
+        self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias
+        self._scaling = self.inverse_scaling_activation(torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)) - self.scale_bias
+        self._rotation = rots - self.rots_bias[None, :]
+
\ No newline at end of file
diff --git a/trellis/representations/mesh/flexicubes/examples/download_data.py b/trellis/representations/mesh/flexicubes/examples/download_data.py
new file mode 100644
index 0000000..fa1bddd
--- /dev/null
+++ b/trellis/representations/mesh/flexicubes/examples/download_data.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
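An aside on the Gaussian representation just defined: with sh_degree=0, save_ply writes a single 'vertex' element whose float32 fields follow construct_list_of_attributes. A sketch of inspecting such a file (the filename is hypothetical, and the field list assumes the usual three scale and four rotation components):

    from plyfile import PlyData
    ply = PlyData.read('gaussians.ply')  # a file produced by Gaussian.save_ply
    print([p.name for p in ply['vertex'].properties])
    # expected: ['x', 'y', 'z', 'nx', 'ny', 'nz', 'f_dc_0', 'f_dc_1', 'f_dc_2',
    #            'opacity', 'scale_0', 'scale_1', 'scale_2', 'rot_0', 'rot_1', 'rot_2', 'rot_3']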
+import requests +from zipfile import ZipFile +from tqdm import tqdm +import os + +def download_file(url, output_path): + response = requests.get(url, stream=True) + response.raise_for_status() + total_size_in_bytes = int(response.headers.get('content-length', 0)) + block_size = 1024 #1 Kibibyte + progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + + with open(output_path, 'wb') as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise Exception("ERROR, something went wrong") + + +url = "https://vcg.isti.cnr.it/Publications/2014/MPZ14/inputmodels.zip" +zip_file_path = './data/inputmodels.zip' + +os.makedirs('./data', exist_ok=True) + +download_file(url, zip_file_path) + +with ZipFile(zip_file_path, 'r') as zip_ref: + zip_ref.extractall('./data') + +os.remove(zip_file_path) + +print("Download and extraction complete.") diff --git a/trellis/representations/mesh/flexicubes/examples/optimize.py b/trellis/representations/mesh/flexicubes/examples/optimize.py new file mode 100644 index 0000000..2177101 --- /dev/null +++ b/trellis/representations/mesh/flexicubes/examples/optimize.py @@ -0,0 +1,157 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import numpy as np +import torch +import nvdiffrast.torch as dr +import trimesh +import os +from util import * +import render +import loss +import imageio + +import sys +sys.path.append('..') +from flexicubes import FlexiCubes + +############################################################################### +# Functions adapted from https://github.com/NVlabs/nvdiffrec +############################################################################### + +def lr_schedule(iter): + return max(0.0, 10**(-(iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs. 
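The schedule above decays the multiplier by one decade over 5000 iterations; a quick standalone check of its endpoints (not part of the example script):

    for it in (0, 2500, 5000):
        print(it, max(0.0, 10 ** (-it * 0.0002)))
    # 0 -> 1.0, 2500 -> ~0.316, 5000 -> 0.1; the LambdaLR scheduler set up below
    # multiplies the base learning rate (0.01 by default) by this factor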
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='flexicubes optimization')
+    parser.add_argument('-o', '--out_dir', type=str, default=None)
+    parser.add_argument('-rm', '--ref_mesh', type=str)
+
+    parser.add_argument('-i', '--iter', type=int, default=1000)
+    parser.add_argument('-b', '--batch', type=int, default=8)
+    parser.add_argument('-r', '--train_res', nargs=2, type=int, default=[2048, 2048])
+    parser.add_argument('-lr', '--learning_rate', type=float, default=0.01)
+    parser.add_argument('--voxel_grid_res', type=int, default=64)
+
+    parser.add_argument('--sdf_loss', type=lambda x: str(x).lower() in ('1', 'true', 'yes'), default=True)  # type=bool would parse "False" as True
+    parser.add_argument('--develop_reg', type=lambda x: str(x).lower() in ('1', 'true', 'yes'), default=False)
+    parser.add_argument('--sdf_regularizer', type=float, default=0.2)
+
+    parser.add_argument('-dr', '--display_res', nargs=2, type=int, default=[512, 512])
+    parser.add_argument('-si', '--save_interval', type=int, default=20)
+    FLAGS = parser.parse_args()
+    device = 'cuda'
+
+    os.makedirs(FLAGS.out_dir, exist_ok=True)
+    glctx = dr.RasterizeGLContext()
+
+    # Load GT mesh
+    gt_mesh = load_mesh(FLAGS.ref_mesh, device)
+    gt_mesh.auto_normals() # compute face normals for visualization
+
+    # ==============================================================================================
+    #  Create and initialize FlexiCubes
+    # ==============================================================================================
+    fc = FlexiCubes(device)
+    x_nx3, cube_fx8 = fc.construct_voxel_grid(FLAGS.voxel_grid_res)
+    x_nx3 *= 2 # scale up the grid so that it's larger than the target object
+
+    sdf = torch.rand_like(x_nx3[:,0]) - 0.1 # randomly init SDF
+    sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
+    # set per-cube learnable weights to zeros
+    weight = torch.zeros((cube_fx8.shape[0], 21), dtype=torch.float, device='cuda')
+    weight = torch.nn.Parameter(weight.clone().detach(), requires_grad=True)
+    deform = torch.nn.Parameter(torch.zeros_like(x_nx3), requires_grad=True)
+
+    # Retrieve all the edges of the voxel grid; these edges will be utilized to
+    # compute the regularization loss in subsequent steps of the process.
+    all_edges = cube_fx8[:, fc.cube_edges].reshape(-1, 2)
+    grid_edges = torch.unique(all_edges, dim=0)
+
+    # ==============================================================================================
+    #  Setup optimizer
+    # ==============================================================================================
+    optimizer = torch.optim.Adam([sdf, weight, deform], lr=FLAGS.learning_rate)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x))
+
+    # ==============================================================================================
+    #  Train loop
+    # ==============================================================================================
+    for it in range(FLAGS.iter):
+        optimizer.zero_grad()
+        # sample random camera poses
+        mv, mvp = render.get_random_camera_batch(FLAGS.batch, iter_res=FLAGS.train_res, device=device, use_kaolin=False)
+        # render gt mesh
+        target = render.render_mesh_paper(gt_mesh, mv, mvp, FLAGS.train_res)
+        # extract and render FlexiCubes mesh
+        grid_verts = x_nx3 + (2-1e-8) / (FLAGS.voxel_grid_res * 2) * torch.tanh(deform)
+        vertices, faces, L_dev, _ = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta=weight[:,:12], alpha=weight[:,12:20],
+                                       gamma_f=weight[:,20], training=True)  # this repo's FlexiCubes takes beta/alpha and also returns vertex colors
+        flexicubes_mesh = Mesh(vertices, faces)
+        buffers = render.render_mesh_paper(flexicubes_mesh, mv, mvp, FLAGS.train_res)
+
+        # evaluate reconstruction loss
+        mask_loss = (buffers['mask'] - target['mask']).abs().mean()
+        depth_loss = (((((buffers['depth'] - (target['depth']))* target['mask'])**2).sum(-1)+1e-8)).sqrt().mean() * 10
+
+        t_iter = it / FLAGS.iter
+        sdf_weight = FLAGS.sdf_regularizer - (FLAGS.sdf_regularizer - FLAGS.sdf_regularizer/20)*min(1.0, 4.0 * t_iter)
+        reg_loss = loss.sdf_reg_loss(sdf, grid_edges).mean() * sdf_weight # Loss to eliminate internal floaters that are not visible
+        reg_loss += L_dev.mean() * 0.5
+        reg_loss += (weight[:,:20]).abs().mean() * 0.1
+        total_loss = mask_loss + depth_loss + reg_loss
+
+        if FLAGS.sdf_loss: # optionally add SDF loss to eliminate internal structures
+            with torch.no_grad():
+                pts = sample_random_points(1000, gt_mesh)
+                gt_sdf = compute_sdf(pts, gt_mesh.vertices, gt_mesh.faces)
+            pred_sdf = compute_sdf(pts, flexicubes_mesh.vertices, flexicubes_mesh.faces)
+            total_loss += torch.nn.functional.mse_loss(pred_sdf, gt_sdf) * 2e3
+
+        # optionally add developability regularizer, as described in paper section 5.2
+        if FLAGS.develop_reg:
+            reg_weight = max(0, t_iter - 0.8) * 5
+            if reg_weight > 0: # only applied after shape converges
+                reg_loss = loss.mesh_developable_reg(flexicubes_mesh).mean() * 10
+                reg_loss += (deform).abs().mean()
+                reg_loss += (weight[:,:20]).abs().mean()
+                total_loss = mask_loss + depth_loss + reg_loss
+
+        total_loss.backward()
+        optimizer.step()
+        scheduler.step()
+
+        if (it % FLAGS.save_interval == 0 or it == (FLAGS.iter-1)): # save normal image for visualization
+            with torch.no_grad():
+                # extract mesh with training=False
+                vertices, faces, L_dev, _ = fc(grid_verts, sdf, cube_fx8, FLAGS.voxel_grid_res, beta=weight[:,:12], alpha=weight[:,12:20],
+                                               gamma_f=weight[:,20], training=False)
+                flexicubes_mesh = Mesh(vertices, faces)
+
+                flexicubes_mesh.auto_normals() # compute face normals for visualization
+                mv, mvp = render.get_rotate_camera(it//FLAGS.save_interval, iter_res=FLAGS.display_res, device=device, use_kaolin=False)
+                val_buffers = render.render_mesh_paper(flexicubes_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
+                val_image = ((val_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
+
+                gt_buffers = render.render_mesh_paper(gt_mesh, mv.unsqueeze(0), mvp.unsqueeze(0), FLAGS.display_res, return_types=["normal"], white_bg=True)
+                gt_image = ((gt_buffers["normal"][0].detach().cpu().numpy()+1)/2*255).astype(np.uint8)
+                imageio.imwrite(os.path.join(FLAGS.out_dir, '{:04d}.png'.format(it)), np.concatenate([val_image, gt_image], 1))
+            print(f"Optimization Step [{it}/{FLAGS.iter}], Loss: {total_loss.item():.4f}")
+
+    # ==============================================================================================
+    #  Save output
+    # ==============================================================================================
+    mesh_np = trimesh.Trimesh(vertices = vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy(), process=False)
+    mesh_np.export(os.path.join(FLAGS.out_dir, 'output_mesh.obj'))
\ No newline at end of file
diff --git a/trellis/representations/mesh/flexicubes/flexicubes.py b/trellis/representations/mesh/flexicubes/flexicubes.py
new file mode 100644
index 0000000..93e47b6
--- /dev/null
+++ b/trellis/representations/mesh/flexicubes/flexicubes.py
@@ -0,0 +1,390 @@
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
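Before the implementation, a minimal usage sketch of the class that follows, mirroring the example script above (the construct_voxel_grid return convention and the shapes are assumptions taken from optimize.py, not verified against code outside this patch):

    import torch
    fc = FlexiCubes(device='cuda')
    x_nx3, cube_fx8 = fc.construct_voxel_grid(64)     # grid vertex positions and per-cube corner indices
    sdf = torch.randn(x_nx3.shape[0], device='cuda')  # a made-up scalar field; negative means inside
    vertices, faces, L_dev, colors = fc(x_nx3, sdf, cube_fx8, 64)
    # with beta/alpha/gamma_f left as None they default to ones, and colors is None
    # unless voxelgrid_colors is passed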
+import torch +from .tables import * +from kaolin.utils.testing import check_tensor + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + def __init__(self, device="cuda"): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + + def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3, + weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False): + assert torch.is_tensor(voxelgrid_vertices) and \ + check_tensor(voxelgrid_vertices, (None, 3), throw=False), \ + "'voxelgrid_vertices' should be a tensor of shape (num_vertices, 3)" + num_vertices = voxelgrid_vertices.shape[0] + assert torch.is_tensor(scalar_field) and \ + check_tensor(scalar_field, (num_vertices,), throw=False), \ + "'scalar_field' should be a tensor of shape (num_vertices,)" + assert torch.is_tensor(cube_idx) and \ + check_tensor(cube_idx, (None, 8), throw=False), \ + "'cube_idx' should be a tensor of shape (num_cubes, 8)" + num_cubes = cube_idx.shape[0] + assert beta is None or ( + torch.is_tensor(beta) and + check_tensor(beta, (num_cubes, 12), throw=False) + ), "'beta' should be a tensor of shape (num_cubes, 12)" + assert alpha is None or ( + torch.is_tensor(alpha) and + check_tensor(alpha, (num_cubes, 8), throw=False) + ), "'alpha' should be a tensor of shape (num_cubes, 8)" + assert gamma_f is None or ( + torch.is_tensor(gamma_f) and + check_tensor(gamma_f, (num_cubes,), throw=False) + ), "'gamma_f' should be a tensor of shape (num_cubes,)" + + surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx) + if surf_cubes.sum() == 0: + return ( + torch.zeros((0, 3), device=self.device), + torch.zeros((0, 3), dtype=torch.long, device=self.device), + torch.zeros((0), device=self.device), + torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None + ) + beta, alpha, gamma_f = self._normalize_weights( + beta, alpha, gamma_f, surf_cubes, weight_scale) + + if voxelgrid_colors is not None: + voxelgrid_colors = 
torch.sigmoid(voxelgrid_colors) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges( + scalar_field, cube_idx, surf_cubes + ) + + vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd( + voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors) + vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate( + scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, + vd_idx_map, surf_edges_mask, training, vd_color) + return vertices, faces, L_dev, vertices_color + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta is not None: + beta = (torch.tanh(beta) * weight_scale + 1) + else: + beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha is not None: + alpha = (torch.tanh(alpha) * weight_scale + 1) + else: + alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. 
+ problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = scalar_field < 0 + all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, scalar_field, cube_idx): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = scalar_field < 0 + occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. 
+ """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)] + , edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field, + case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + if voxelgrid_colors is not None: + C = voxelgrid_colors.shape[-1] + surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + # if color is not None: + # vd_color = [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + # if color is not None: + # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3)) + + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = 
torch.cat(vd_gamma) + # if color is not None: + # vd_color = torch.cat(vd_color) + # else: + # vd_color = None + + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + + ''' + interpolate colors use the same method as dual vertices + ''' + if voxelgrid_colors is not None: + vd_color = torch.zeros((total_num_vd, C), device=self.device) + c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C) + uc_group = self._linear_interp(s_group * alpha_group, c_group) + vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum + else: + vd_color = None + + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map, vd_color + + def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. 
+ s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2] + gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3] + if not training: + mask = (gamma_02 > gamma_13) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2 + vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + + if vd_color is not None: + color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1]) + color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2 + color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2 + color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1) + vd_color = torch.cat([vd_color, color_center]) + + + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices, vd_color \ No newline at end of file diff --git a/trellis/trainers/flow_matching/flow_matching.py b/trellis/trainers/flow_matching/flow_matching.py new file mode 100644 index 0000000..1df8efa --- /dev/null +++ b/trellis/trainers/flow_matching/flow_matching.py @@ -0,0 +1,353 @@ +from typing import * +import copy +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict + +from ..basic import BasicTrainer +from ...pipelines import samplers +from ...utils.general_utils import dict_reduce +from .mixins.classifier_free_guidance import ClassifierFreeGuidanceMixin +from .mixins.text_conditioned import TextConditionedMixin +from .mixins.image_conditioned import ImageConditionedMixin + + +class FlowMatchingTrainer(BasicTrainer): + """ + Trainer for diffusion model with flow matching objective. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. 
+ - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + """ + def __init__( + self, + *args, + t_schedule: dict = { + 'name': 'logitNormal', + 'args': { + 'mean': 0.0, + 'std': 1.0, + } + }, + sigma_min: float = 1e-5, + **kwargs + ): + super().__init__(*args, **kwargs) + self.t_schedule = t_schedule + self.sigma_min = sigma_min + + def diffuse(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + t: The [N] tensor of diffusion steps [0-1]. + noise: If specified, use this noise instead of generating new noise. + + Returns: + x_t, the noisy version of x_0 under timestep t. + """ + if noise is None: + noise = torch.randn_like(x_0) + assert noise.shape == x_0.shape, "noise must have same shape as x_0" + + t = t.view(-1, *[1 for _ in range(len(x_0.shape) - 1)]) + x_t = (1 - t) * x_0 + (self.sigma_min + (1 - self.sigma_min) * t) * noise + + return x_t + + def reverse_diffuse(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: + """ + Get original image from noisy version under timestep t. + """ + assert noise.shape == x_t.shape, "noise must have same shape as x_t" + t = t.view(-1, *[1 for _ in range(len(x_t.shape) - 1)]) + x_0 = (x_t - (self.sigma_min + (1 - self.sigma_min) * t) * noise) / (1 - t) + return x_0 + + def get_v(self, x_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Compute the velocity of the diffusion process at time t. + """ + return (1 - self.sigma_min) * noise - x_0 + + def get_cond(self, cond, **kwargs): + """ + Get the conditioning data. + """ + return cond + + def get_inference_cond(self, cond, **kwargs): + """ + Get the conditioning data for inference. + """ + return {'cond': cond, **kwargs} + + def get_sampler(self, **kwargs) -> samplers.FlowEulerSampler: + """ + Get the sampler for the diffusion process. + """ + return samplers.FlowEulerSampler(self.sigma_min) + + def vis_cond(self, **kwargs): + """ + Visualize the conditioning data. + """ + return {} + + def sample_t(self, batch_size: int) -> torch.Tensor: + """ + Sample timesteps. + """ + if self.t_schedule['name'] == 'uniform': + t = torch.rand(batch_size) + elif self.t_schedule['name'] == 'logitNormal': + mean = self.t_schedule['args']['mean'] + std = self.t_schedule['args']['std'] + t = torch.sigmoid(torch.randn(batch_size) * std + mean) + else: + raise ValueError(f"Unknown t_schedule: {self.t_schedule['name']}") + return t + + def training_losses( + self, + x_0: torch.Tensor, + cond=None, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses for a single timestep. + + Args: + x_0: The [N x C x ...] tensor of noiseless inputs. + cond: The [N x ...] tensor of additional conditions. + kwargs: Additional arguments to pass to the backbone. + + Returns: + a dict with the key "loss" containing a tensor of shape [N]. + may also contain other keys for different terms. 
+ """ + noise = torch.randn_like(x_0) + t = self.sample_t(x_0.shape[0]).to(x_0.device).float() + x_t = self.diffuse(x_0, t, noise=noise) + cond = self.get_cond(cond, **kwargs) + + pred = self.training_models['denoiser'](x_t, t * 1000, cond, **kwargs) + assert pred.shape == noise.shape == x_0.shape + target = self.get_v(x_0, noise, t) + terms = edict() + terms["mse"] = F.mse_loss(pred, target) + terms["loss"] = terms["mse"] + + # log loss with time bins + mse_per_instance = np.array([ + F.mse_loss(pred[i], target[i]).item() + for i in range(x_0.shape[0]) + ]) + time_bin = np.digitize(t.cpu().numpy(), np.linspace(0, 1, 11)) - 1 + for i in range(10): + if (time_bin == i).sum() != 0: + terms[f"bin_{i}"] = {"mse": mse_per_instance[time_bin == i].mean()} + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + sampler = self.get_sampler() + sample_gt = [] + sample = [] + cond_vis = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + data = {k: v[:batch].cuda() if isinstance(v, torch.Tensor) else v[:batch] for k, v in data.items()} + noise = torch.randn_like(data['x_0']) + sample_gt.append(data['x_0']) + cond_vis.append(self.vis_cond(**data)) + del data['x_0'] + args = self.get_inference_cond(**data) + res = sampler.sample( + self.models['denoiser'], + noise=noise, + **args, + steps=50, cfg_strength=3.0, verbose=verbose, + ) + sample.append(res.samples) + + sample_gt = torch.cat(sample_gt, dim=0) + sample = torch.cat(sample, dim=0) + sample_dict = { + 'sample_gt': {'value': sample_gt, 'type': 'sample'}, + 'sample': {'value': sample, 'type': 'sample'}, + } + sample_dict.update(dict_reduce(cond_vis, None, { + 'value': lambda x: torch.cat(x, dim=0), + 'type': lambda x: x[0], + })) + + return sample_dict + + +class FlowMatchingCFGTrainer(ClassifierFreeGuidanceMixin, FlowMatchingTrainer): + """ + Trainer for diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. 
+ sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + """ + pass + + +class TextConditionedFlowMatchingCFGTrainer(TextConditionedMixin, FlowMatchingCFGTrainer): + """ + Trainer for text-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + text_cond_model(str): Text conditioning model. + """ + pass + + +class ImageConditionedFlowMatchingCFGTrainer(ImageConditionedMixin, FlowMatchingCFGTrainer): + """ + Trainer for image-conditioned diffusion model with flow matching objective and classifier-free guidance. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + t_schedule (dict): Time schedule for flow matching. + sigma_min (float): Minimum noise level. + p_uncond (float): Probability of dropping conditions. + image_cond_model (str): Image conditioning model. 
+ """ + pass diff --git a/trellis/utils/dist_utils.py b/trellis/utils/dist_utils.py new file mode 100644 index 0000000..348799c --- /dev/null +++ b/trellis/utils/dist_utils.py @@ -0,0 +1,93 @@ +import os +import io +from contextlib import contextmanager +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + + +def setup_dist(rank, local_rank, world_size, master_addr, master_port): + os.environ['MASTER_ADDR'] = master_addr + os.environ['MASTER_PORT'] = master_port + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = str(local_rank) + torch.cuda.set_device(local_rank) + dist.init_process_group('nccl', rank=rank, world_size=world_size) + + +def read_file_dist(path): + """ + Read the binary file distributedly. + File is only read once by the rank 0 process and broadcasted to other processes. + + Returns: + data (io.BytesIO): The binary data read from the file. + """ + if dist.is_initialized() and dist.get_world_size() > 1: + # read file + size = torch.LongTensor(1).cuda() + if dist.get_rank() == 0: + with open(path, 'rb') as f: + data = f.read() + data = torch.ByteTensor( + torch.UntypedStorage.from_buffer(data, dtype=torch.uint8) + ).cuda() + size[0] = data.shape[0] + # broadcast size + dist.broadcast(size, src=0) + if dist.get_rank() != 0: + data = torch.ByteTensor(size[0].item()).cuda() + # broadcast data + dist.broadcast(data, src=0) + # convert to io.BytesIO + data = data.cpu().numpy().tobytes() + data = io.BytesIO(data) + return data + else: + with open(path, 'rb') as f: + data = f.read() + data = io.BytesIO(data) + return data + + +def unwrap_dist(model): + """ + Unwrap the model from distributed training. + """ + if isinstance(model, DDP): + return model.module + return model + + +@contextmanager +def master_first(): + """ + A context manager that ensures master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + + +@contextmanager +def local_master_first(): + """ + A context manager that ensures local master process executes first. + """ + if not dist.is_initialized(): + yield + else: + if dist.get_rank() % torch.cuda.device_count() == 0: + yield + dist.barrier() + else: + dist.barrier() + yield + \ No newline at end of file diff --git a/trellis/utils/elastic_utils.py b/trellis/utils/elastic_utils.py new file mode 100644 index 0000000..cba3cf8 --- /dev/null +++ b/trellis/utils/elastic_utils.py @@ -0,0 +1,228 @@ +from abc import abstractmethod +from contextlib import contextmanager +from typing import Tuple +import torch +import torch.nn as nn +import numpy as np + + +class MemoryController: + """ + Base class for memory management during training. + """ + + _last_input_size = None + _last_mem_ratio = [] + + @contextmanager + def record(self): + pass + + def update_run_states(self, input_size=None, mem_ratio=None): + if self._last_input_size is None: + self._last_input_size = input_size + elif self._last_input_size!= input_size: + raise ValueError(f'Input size should not change for different ElasticModules.') + self._last_mem_ratio.append(mem_ratio) + + @abstractmethod + def get_mem_ratio(self, input_size): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def log(self): + pass + + +class LinearMemoryController(MemoryController): + """ + A simple controller for memory management during training. 
+
+
+class LinearMemoryController(MemoryController):
+    """
+    A simple controller for memory management during training.
+    The memory usage is modeled as a linear function of:
+        - the size of the input
+        - the ratio of memory the model uses compared to the maximum usage (with no checkpointing)
+            memory_usage = k * input_size * mem_ratio + b
+    The controller tracks the memory usage and predicts the memory ratio that
+    keeps the usage under a target fraction of the available memory.
+    """
+    def __init__(
+        self,
+        buffer_size=1000,
+        update_every=500,
+        target_ratio=0.8,
+        available_memory=None,
+        max_mem_ratio_start=0.1,
+        params=None,
+        device=None
+    ):
+        self.buffer_size = buffer_size
+        self.update_every = update_every
+        self.target_ratio = target_ratio
+        self.device = device or torch.cuda.current_device()
+        self.available_memory = available_memory or torch.cuda.get_device_properties(self.device).total_memory / 1024**3
+
+        self._memory = np.zeros(buffer_size, dtype=np.float32)
+        self._input_size = np.zeros(buffer_size, dtype=np.float32)
+        self._mem_ratio = np.zeros(buffer_size, dtype=np.float32)
+        self._buffer_ptr = 0
+        self._buffer_length = 0
+        self._params = tuple(params) if params is not None else (0.0, 0.0)
+        self._max_mem_ratio = max_mem_ratio_start
+        self.step = 0
+
+    def __repr__(self):
+        return f'LinearMemoryController(target_ratio={self.target_ratio}, available_memory={self.available_memory})'
+
+    def _add_sample(self, memory, input_size, mem_ratio):
+        self._memory[self._buffer_ptr] = memory
+        self._input_size[self._buffer_ptr] = input_size
+        self._mem_ratio[self._buffer_ptr] = mem_ratio
+        self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
+        self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
+
+    @contextmanager
+    def record(self):
+        torch.cuda.reset_peak_memory_stats(self.device)
+        self._last_input_size = None
+        self._last_mem_ratio = []
+        yield
+        self._last_memory = torch.cuda.max_memory_allocated(self.device) / 1024**3
+        self._last_mem_ratio = sum(self._last_mem_ratio) / len(self._last_mem_ratio)
+        self._add_sample(self._last_memory, self._last_input_size, self._last_mem_ratio)
+        self.step += 1
+        if self.step % self.update_every == 0:
+            self._max_mem_ratio = min(1.0, self._max_mem_ratio + 0.1)
+            self._fit_params()
+
+    def _fit_params(self):
+        memory_usage = self._memory[:self._buffer_length]
+        input_size = self._input_size[:self._buffer_length]
+        mem_ratio = self._mem_ratio[:self._buffer_length]
+
+        x = input_size * mem_ratio
+        y = memory_usage
+        k, b = np.polyfit(x, y, 1)
+        self._params = (k, b)
+        # self._visualize()
+
+    def _visualize(self):
+        import matplotlib.pyplot as plt
+        memory_usage = self._memory[:self._buffer_length]
+        input_size = self._input_size[:self._buffer_length]
+        mem_ratio = self._mem_ratio[:self._buffer_length]
+        k, b = self._params
+
+        plt.scatter(input_size * mem_ratio, memory_usage, c=mem_ratio, cmap='viridis')
+        x = np.array([0.0, 20000.0])
+        plt.plot(x, k * x + b, c='r')
+        plt.savefig(f'linear_memory_controller_{self.step}.png')
+        plt.cla()
+
+    def get_mem_ratio(self, input_size):
+        k, b = self._params
+        if k == 0:
+            # no model fitted yet; explore randomly below the warm-up cap
+            return np.random.rand() * self._max_mem_ratio
+        pred = (self.available_memory * self.target_ratio - b) / (k * input_size)
+        return min(self._max_mem_ratio, max(0.0, pred))
+
+    def state_dict(self):
+        return {
+            'params': self._params,
+        }
+
+    def load_state_dict(self, state_dict):
+        self._params = tuple(state_dict['params'])
+
+    def log(self):
+        return {
+            'params/k': self._params[0],
+            'params/b': self._params[1],
+            'memory': self._last_memory,
+            'input_size': self._last_input_size,
+            'mem_ratio': self._last_mem_ratio,
+        }
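+
+
+# Editorial example (illustrative numbers only): with fitted params k = 2e-4
+# GB per unit and b = 4 GB on an 80 GB device with target_ratio = 0.8, an
+# input of size 200000 yields
+#     pred = (0.8 * 80 - 4) / (2e-4 * 200000) = 1.5,
+# which is then clamped to [0, _max_mem_ratio]; _max_mem_ratio itself warms
+# up from 0.1 to 1.0 as training progresses.
+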
+
+
+class ElasticModule(nn.Module):
+    """
+    Module for training with elastic memory management.
+    """
+    def __init__(self):
+        super().__init__()
+        self._memory_controller: MemoryController = None
+
+    @abstractmethod
+    def _get_input_size(self, *args, **kwargs) -> int:
+        """
+        Get the size of the input data.
+
+        Returns:
+            int: The size of the input data.
+        """
+        pass
+
+    @abstractmethod
+    def _forward_with_mem_ratio(self, *args, mem_ratio=0.0, **kwargs) -> Tuple[float, Tuple]:
+        """
+        Forward with a given memory ratio.
+        """
+        pass
+
+    def register_memory_controller(self, memory_controller: MemoryController):
+        self._memory_controller = memory_controller
+
+    def forward(self, *args, **kwargs):
+        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
+            _, ret = self._forward_with_mem_ratio(*args, **kwargs)
+        else:
+            input_size = self._get_input_size(*args, **kwargs)
+            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
+            mem_ratio, ret = self._forward_with_mem_ratio(*args, mem_ratio=mem_ratio, **kwargs)
+            self._memory_controller.update_run_states(input_size, mem_ratio)
+        return ret
+
+
+class ElasticModuleMixin:
+    """
+    Mixin for training with elastic memory management.
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._memory_controller: MemoryController = None
+
+    @abstractmethod
+    def _get_input_size(self, *args, **kwargs) -> int:
+        """
+        Get the size of the input data.
+
+        Returns:
+            int: The size of the input data.
+        """
+        pass
+
+    @abstractmethod
+    @contextmanager
+    def with_mem_ratio(self, mem_ratio=1.0) -> float:
+        """
+        Context manager for running the forward pass with a reduced memory
+        ratio compared to the full (no-checkpointing) memory usage.
+
+        Yields:
+            float: The exact memory ratio used during the forward pass.
+        """
+        pass
+
+    def register_memory_controller(self, memory_controller: MemoryController):
+        self._memory_controller = memory_controller
+
+    def forward(self, *args, **kwargs):
+        if self._memory_controller is None or not torch.is_grad_enabled() or not self.training:
+            ret = super().forward(*args, **kwargs)
+        else:
+            input_size = self._get_input_size(*args, **kwargs)
+            mem_ratio = self._memory_controller.get_mem_ratio(input_size)
+            with self.with_mem_ratio(mem_ratio) as exact_mem_ratio:
+                ret = super().forward(*args, **kwargs)
+            self._memory_controller.update_run_states(input_size, exact_mem_ratio)
+        return ret
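
A concrete (hypothetical) ElasticModuleMixin implementation might map mem_ratio to the number of gradient-checkpointed blocks; all names below are illustrative sketches under that assumption, not part of the repository:

    import torch
    from contextlib import contextmanager
    from torch.utils.checkpoint import checkpoint

    class ToyBackbone(torch.nn.Module):
        def __init__(self, dim=64, depth=4):
            super().__init__()
            self.blocks = torch.nn.ModuleList(torch.nn.Linear(dim, dim) for _ in range(depth))
            self._ckpt = 0  # number of blocks to checkpoint during the next forward

        def forward(self, x):
            for i, blk in enumerate(self.blocks):
                if i < self._ckpt and self.training:
                    x = checkpoint(blk, x, use_reentrant=False)  # trade compute for memory
                else:
                    x = blk(x)
            return x

    class ElasticToyBackbone(ElasticModuleMixin, ToyBackbone):  # mixin wraps forward()
        def _get_input_size(self, x):
            return x.numel()

        @contextmanager
        def with_mem_ratio(self, mem_ratio=1.0):
            n = round((1 - mem_ratio) * len(self.blocks))
            self._ckpt = n
            yield 1 - n / len(self.blocks)  # report the ratio actually realized after rounding
            self._ckpt = 0
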
diff --git a/trellis/utils/postprocessing_utils.py b/trellis/utils/postprocessing_utils.py
new file mode 100644
index 0000000..0a8d9fb
--- /dev/null
+++ b/trellis/utils/postprocessing_utils.py
@@ -0,0 +1,587 @@
+from typing import *
+import numpy as np
+import torch
+import utils3d
+import nvdiffrast.torch as dr
+from tqdm import tqdm
+import trimesh
+import trimesh.visual
+import xatlas
+import pyvista as pv
+from pymeshfix import _meshfix
+import igraph
+import cv2
+from PIL import Image
+from .random_utils import sphere_hammersley_sequence
+from .render_utils import render_multiview
+from ..renderers import GaussianRenderer
+from ..representations import Strivec, Gaussian, MeshExtractResult
+
+
+@torch.no_grad()
+def _fill_holes(
+    verts,
+    faces,
+    max_hole_size=0.04,
+    max_hole_nbe=32,
+    resolution=128,
+    num_views=500,
+    debug=False,
+    verbose=False
+):
+    """
+    Rasterize a mesh from multiple views and remove invisible faces.
+    Also includes postprocessing to:
+        1. Remove connected components that have low visibility.
+        2. Run a mincut to remove faces on the inner side of the mesh that are
+           connected to the outer side only through a small hole.
+
+    Args:
+        verts (torch.Tensor): Vertices of the mesh. Shape (V, 3).
+        faces (torch.Tensor): Faces of the mesh. Shape (F, 3).
+        max_hole_size (float): Maximum area of a hole to fill.
+        max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
+        resolution (int): Resolution of the rasterization.
+        num_views (int): Number of views to rasterize the mesh.
+        debug (bool): Whether to dump debug visualizations.
+        verbose (bool): Whether to print progress.
+    """
+    # Construct cameras
+    yaws = []
+    pitchs = []
+    for i in range(num_views):
+        y, p = sphere_hammersley_sequence(i, num_views)
+        yaws.append(y)
+        pitchs.append(p)
+    yaws = torch.tensor(yaws).cuda()
+    pitchs = torch.tensor(pitchs).cuda()
+    radius = 2.0
+    fov = torch.deg2rad(torch.tensor(40)).cuda()
+    projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
+    views = []
+    for (yaw, pitch) in zip(yaws, pitchs):
+        orig = torch.tensor([
+            torch.sin(yaw) * torch.cos(pitch),
+            torch.cos(yaw) * torch.cos(pitch),
+            torch.sin(pitch),
+        ]).cuda().float() * radius
+        view = utils3d.torch.view_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
+        views.append(view)
+    views = torch.stack(views, dim=0)
+
+    # Rasterize the mesh from every view and count how often each face is seen
+    visibility = torch.zeros(faces.shape[0], dtype=torch.int32, device=verts.device)
+    rastctx = utils3d.torch.RastContext(backend='cuda')
+    for i in tqdm(range(views.shape[0]), total=views.shape[0], disable=not verbose, desc='Rasterizing'):
+        view = views[i]
+        buffers = utils3d.torch.rasterize_triangle_faces(
+            rastctx, verts[None], faces, resolution, resolution, view=view, projection=projection
+        )
+        face_id = buffers['face_id'][0][buffers['mask'][0] > 0.95] - 1
+        face_id = torch.unique(face_id).long()
+        visibility[face_id] += 1
+    visibility = visibility.float() / num_views
+
+    # Mincut
+    ## construct outer faces
+    edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
+    boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
+    connected_components = utils3d.torch.compute_connected_components(faces, edges, face2edge)
+    outer_face_indices = torch.zeros(faces.shape[0], dtype=torch.bool, device=faces.device)
+    for i in range(len(connected_components)):
+        outer_face_indices[connected_components[i]] = visibility[connected_components[i]] > min(max(visibility[connected_components[i]].quantile(0.75).item(), 0.25), 0.5)
+    outer_face_indices = outer_face_indices.nonzero().reshape(-1)
+
+    ## construct inner faces
+    inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
+    if verbose:
+        tqdm.write(f'Found {inner_face_indices.shape[0]} invisible faces')
+    if inner_face_indices.shape[0] == 0:
+        return verts, faces
+
+    ## construct dual graph (faces as nodes, shared edges as edges)
+    dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(face2edge)
+    dual_edge2edge = edges[dual_edge2edge]
+    dual_edges_weights = torch.norm(verts[dual_edge2edge[:, 0]] - verts[dual_edge2edge[:, 1]], dim=1)
+    if verbose:
+        tqdm.write(f'Dual graph: {dual_edges.shape[0]} edges')
+
+    ## solve mincut problem
+    ### construct main graph
+    g = igraph.Graph()
+    g.add_vertices(faces.shape[0])
+    g.add_edges(dual_edges.cpu().numpy())
+    g.es['weight'] = dual_edges_weights.cpu().numpy()
+
+    ### source and target
+    g.add_vertex('s')
+    g.add_vertex('t')
+
+    ### connect invisible faces to source
+    g.add_edges([(f, 's') for f in inner_face_indices], attributes={'weight': torch.ones(inner_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
+
+    ### connect outer faces to target
+    g.add_edges([(f, 't') for f in outer_face_indices], attributes={'weight': torch.ones(outer_face_indices.shape[0], dtype=torch.float32).cpu().numpy()})
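+
+    ### Editorial note: the setup above makes an s-t mincut separate invisible
+    ### interior faces (tied to the source) from confidently visible exterior
+    ### faces (tied to the sink), preferring to cut along short dual edges,
+    ### i.e. through small openings in the surface.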
+
+    ### solve mincut
+    cut = g.mincut('s', 't', (np.array(g.es['weight']) * 1000).tolist())
+    remove_face_indices = torch.tensor([v for v in cut.partition[0] if v < faces.shape[0]], dtype=torch.long, device=faces.device)
+    if verbose:
+        tqdm.write('Mincut solved, checking the cut')
+
+    ### check if the cut is valid for each connected component
+    to_remove_cc = utils3d.torch.compute_connected_components(faces[remove_face_indices])
+    if debug:
+        tqdm.write(f'Number of connected components of the cut: {len(to_remove_cc)}')
+    valid_remove_cc = []
+    cutting_edges = []
+    for cc in to_remove_cc:
+        #### check if the connected component has low visibility
+        visibility_median = visibility[remove_face_indices[cc]].median()
+        if debug:
+            tqdm.write(f'visibility_median: {visibility_median}')
+        if visibility_median > 0.25:
+            continue
+
+        #### check if the cutting loop is small enough
+        cc_edge_indices, cc_edges_degree = torch.unique(face2edge[remove_face_indices[cc]], return_counts=True)
+        cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
+        cc_new_boundary_edge_indices = cc_boundary_edge_indices[~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)]
+        if len(cc_new_boundary_edge_indices) > 0:
+            cc_new_boundary_edge_cc = utils3d.torch.compute_edge_connected_components(edges[cc_new_boundary_edge_indices])
+            cc_new_boundary_edges_cc_center = [verts[edges[cc_new_boundary_edge_indices[edge_cc]]].mean(dim=1).mean(dim=0) for edge_cc in cc_new_boundary_edge_cc]
+            cc_new_boundary_edges_cc_area = []
+            for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
+                _e1 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]] - cc_new_boundary_edges_cc_center[i]
+                _e2 = verts[edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]] - cc_new_boundary_edges_cc_center[i]
+                cc_new_boundary_edges_cc_area.append(torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum() * 0.5)
+            if debug:
+                cutting_edges.append(cc_new_boundary_edge_indices)
+                tqdm.write(f'Area of the cutting loop: {cc_new_boundary_edges_cc_area}')
+            if any(l > max_hole_size for l in cc_new_boundary_edges_cc_area):
+                continue
+
+        valid_remove_cc.append(cc)
+
+    if debug:
+        face_v = verts[faces].mean(dim=1).cpu().numpy()
+        vis_dual_edges = dual_edges.cpu().numpy()
+        vis_colors = np.zeros((faces.shape[0], 3), dtype=np.uint8)
+        vis_colors[inner_face_indices.cpu().numpy()] = [0, 0, 255]
+        vis_colors[outer_face_indices.cpu().numpy()] = [0, 255, 0]
+        vis_colors[remove_face_indices.cpu().numpy()] = [255, 0, 255]
+        if len(valid_remove_cc) > 0:
+            vis_colors[remove_face_indices[torch.cat(valid_remove_cc)].cpu().numpy()] = [255, 0, 0]
+        utils3d.io.write_ply('dbg_dual.ply', face_v, edges=vis_dual_edges, vertex_colors=vis_colors)
+
+        vis_verts = verts.cpu().numpy()
+        vis_edges = edges[torch.cat(cutting_edges)].cpu().numpy()
+        utils3d.io.write_ply('dbg_cut.ply', vis_verts, edges=vis_edges)
+
+    if len(valid_remove_cc) > 0:
+        remove_face_indices = remove_face_indices[torch.cat(valid_remove_cc)]
+        mask = torch.ones(faces.shape[0], dtype=torch.bool, device=faces.device)
+        mask[remove_face_indices] = 0
+        faces = faces[mask]
+        faces, verts = utils3d.torch.remove_unreferenced_vertices(faces, verts)
+        if verbose:
+            tqdm.write(f'Removed {(~mask).sum()} faces by mincut')
+    else:
+        if verbose:
+            tqdm.write('Removed 0 faces by mincut')
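+
+    ### Fill the small boundary loops left behind by the cut (and any other
+    ### small holes, up to max_hole_nbe boundary edges) with pymeshfix.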
+    mesh = _meshfix.PyTMesh()
+    mesh.load_array(verts.cpu().numpy(), faces.cpu().numpy())
+    mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
+    verts, faces = mesh.return_arrays()
+    verts, faces = torch.tensor(verts, device='cuda', dtype=torch.float32), torch.tensor(faces, device='cuda', dtype=torch.int32)
+
+    return verts, faces
+
+
+def postprocess_mesh(
+    vertices: np.array,
+    faces: np.array,
+    simplify: bool = True,
+    simplify_ratio: float = 0.9,
+    fill_holes: bool = True,
+    fill_holes_max_hole_size: float = 0.04,
+    fill_holes_max_hole_nbe: int = 32,
+    fill_holes_resolution: int = 1024,
+    fill_holes_num_views: int = 1000,
+    debug: bool = False,
+    verbose: bool = False,
+):
+    """
+    Postprocess a mesh by simplifying, removing invisible faces, and removing isolated pieces.
+
+    Args:
+        vertices (np.array): Vertices of the mesh. Shape (V, 3).
+        faces (np.array): Faces of the mesh. Shape (F, 3).
+        simplify (bool): Whether to simplify the mesh, using quadric edge collapse.
+        simplify_ratio (float): Ratio of faces to remove in simplification
+            (pyvista's decimate target_reduction).
+        fill_holes (bool): Whether to fill holes in the mesh.
+        fill_holes_max_hole_size (float): Maximum area of a hole to fill.
+        fill_holes_max_hole_nbe (int): Maximum number of boundary edges of a hole to fill.
+        fill_holes_resolution (int): Resolution of the rasterization.
+        fill_holes_num_views (int): Number of views to rasterize the mesh.
+        debug (bool): Whether to dump debug visualizations during hole filling.
+        verbose (bool): Whether to print progress.
+    """
+
+    if verbose:
+        tqdm.write(f'Before postprocess: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+    # Simplify
+    if simplify and simplify_ratio > 0:
+        mesh = pv.PolyData(vertices, np.concatenate([np.full((faces.shape[0], 1), 3), faces], axis=1))
+        mesh = mesh.decimate(simplify_ratio, progress_bar=verbose)
+        vertices, faces = mesh.points, mesh.faces.reshape(-1, 4)[:, 1:]
+        if verbose:
+            tqdm.write(f'After decimate: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+    # Remove invisible faces and fill the resulting holes
+    if fill_holes:
+        vertices, faces = torch.tensor(vertices).cuda(), torch.tensor(faces.astype(np.int32)).cuda()
+        vertices, faces = _fill_holes(
+            vertices, faces,
+            max_hole_size=fill_holes_max_hole_size,
+            max_hole_nbe=fill_holes_max_hole_nbe,
+            resolution=fill_holes_resolution,
+            num_views=fill_holes_num_views,
+            debug=debug,
+            verbose=verbose,
+        )
+        vertices, faces = vertices.cpu().numpy(), faces.cpu().numpy()
+        if verbose:
+            tqdm.write(f'After remove invisible faces: {vertices.shape[0]} vertices, {faces.shape[0]} faces')
+
+    return vertices, faces
+
+
+def parametrize_mesh(vertices: np.array, faces: np.array):
+    """
+    Parametrize a mesh to a texture space, using xatlas.
+
+    Args:
+        vertices (np.array): Vertices of the mesh. Shape (V, 3).
+        faces (np.array): Faces of the mesh. Shape (F, 3).
+    """
+
+    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
+
+    vertices = vertices[vmapping]
+    faces = indices
+
+    return vertices, faces, uvs
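+
+# Editorial note: xatlas duplicates vertices along UV seams, so the
+# parametrized mesh may have more vertices than the input; `vmapping` maps
+# each new vertex back to its source index, which is why positions are
+# re-indexed above. Hypothetical sanity check:
+#
+#     v2, f2, uv = parametrize_mesh(vertices, faces)
+#     assert uv.shape[0] == v2.shape[0] >= vertices.shape[0]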
+
+
+def bake_texture(
+    vertices: np.array,
+    faces: np.array,
+    uvs: np.array,
+    observations: List[np.array],
+    masks: List[np.array],
+    extrinsics: List[np.array],
+    intrinsics: List[np.array],
+    texture_size: int = 2048,
+    near: float = 0.1,
+    far: float = 10.0,
+    mode: Literal['fast', 'opt'] = 'opt',
+    lambda_tv: float = 1e-2,
+    verbose: bool = False,
+):
+    """
+    Bake texture to a mesh from multiple observations.
+
+    Args:
+        vertices (np.array): Vertices of the mesh. Shape (V, 3).
+        faces (np.array): Faces of the mesh. Shape (F, 3).
+        uvs (np.array): UV coordinates of the mesh. Shape (V, 2).
+        observations (List[np.array]): List of observations. Each observation is a 2D image. Shape (H, W, 3).
+        masks (List[np.array]): List of masks. Each mask is a 2D image. Shape (H, W).
+        extrinsics (List[np.array]): List of extrinsics. Shape (4, 4).
+        intrinsics (List[np.array]): List of intrinsics. Shape (3, 3).
+        texture_size (int): Size of the texture.
+        near (float): Near plane of the camera.
+        far (float): Far plane of the camera.
+        mode (Literal['fast', 'opt']): Mode of texture baking.
+        lambda_tv (float): Weight of total variation loss in optimization.
+        verbose (bool): Whether to print progress.
+    """
+    vertices = torch.tensor(vertices).cuda()
+    faces = torch.tensor(faces.astype(np.int32)).cuda()
+    uvs = torch.tensor(uvs).cuda()
+    observations = [torch.tensor(obs / 255.0).float().cuda() for obs in observations]
+    masks = [torch.tensor(m > 0).bool().cuda() for m in masks]
+    views = [utils3d.torch.extrinsics_to_view(torch.tensor(extr).cuda()) for extr in extrinsics]
+    projections = [utils3d.torch.intrinsics_to_perspective(torch.tensor(intr).cuda(), near, far) for intr in intrinsics]
+
+    if mode == 'fast':
+        texture = torch.zeros((texture_size * texture_size, 3), dtype=torch.float32).cuda()
+        texture_weights = torch.zeros((texture_size * texture_size), dtype=torch.float32).cuda()
+        rastctx = utils3d.torch.RastContext(backend='cuda')
+        for observation, view_mask, view, projection in tqdm(zip(observations, masks, views, projections), total=len(observations), disable=not verbose, desc='Texture baking (fast)'):
+            with torch.no_grad():
+                rast = utils3d.torch.rasterize_triangle_faces(
+                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
+                )
+                uv_map = rast['uv'][0].detach().flip(0)
+                # intersect the rasterization mask with the mask of this view
+                # (bug fix: the original used masks[0] for every view)
+                mask = rast['mask'][0].detach().bool() & view_mask
+
+                # nearest neighbor interpolation
+                uv_map = (uv_map * texture_size).floor().long()
+                obs = observation[mask]
+                uv_map = uv_map[mask]
+                idx = uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
+                texture = texture.scatter_add(0, idx.view(-1, 1).expand(-1, 3), obs)
+                texture_weights = texture_weights.scatter_add(0, idx, torch.ones((obs.shape[0]), dtype=torch.float32, device=texture.device))
+
+        mask = texture_weights > 0
+        texture[mask] /= texture_weights[mask][:, None]
+        texture = np.clip(texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, 0, 255).astype(np.uint8)
+
+        # inpaint texels that received no observation
+        mask = (texture_weights == 0).cpu().numpy().astype(np.uint8).reshape(texture_size, texture_size)
+        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
+
+    elif mode == 'opt':
+        rastctx = utils3d.torch.RastContext(backend='cuda')
+        observations = [obs.flip(0) for obs in observations]
+        masks = [m.flip(0) for m in masks]
+        _uv = []
+        _uv_dr = []
+        for observation, view, projection in tqdm(zip(observations, views, projections), total=len(views), disable=not verbose, desc='Texture baking (opt): UV'):
+            with torch.no_grad():
+                rast = utils3d.torch.rasterize_triangle_faces(
+                    rastctx, vertices[None], faces, observation.shape[1], observation.shape[0], uv=uvs[None], view=view, projection=projection
+                )
+                _uv.append(rast['uv'].detach())
+                _uv_dr.append(rast['uv_dr'].detach())
+
+        texture = torch.nn.Parameter(torch.zeros((1, texture_size, texture_size, 3), dtype=torch.float32).cuda())
+        optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
+
+        def exp_annealing(optimizer, step, total_steps, start_lr, end_lr):
+            return start_lr * (end_lr / start_lr) ** (step / total_steps)
+
+        def cosine_annealing(optimizer, step, total_steps, start_lr, end_lr):
+            return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
+
+        def tv_loss(texture):
+            return torch.nn.functional.l1_loss(texture[:, :-1, :, :], texture[:, 1:, :, :]) + \
+                   torch.nn.functional.l1_loss(texture[:, :, :-1, :], texture[:, :, 1:, :])
+
+        total_steps = 2500
+        with tqdm(total=total_steps, disable=not verbose, desc='Texture baking (opt): optimizing') as pbar:
+            for step in range(total_steps):
+                optimizer.zero_grad()
+                selected = np.random.randint(0, len(views))
+                uv, uv_dr, observation, mask = _uv[selected], _uv_dr[selected], observations[selected], masks[selected]
+                render = dr.texture(texture, uv, uv_dr)[0]
+                loss = torch.nn.functional.l1_loss(render[mask], observation[mask])
+                if lambda_tv > 0:
+                    loss += lambda_tv * tv_loss(texture)
+                loss.backward()
+                optimizer.step()
+                # anneal the learning rate
+                optimizer.param_groups[0]['lr'] = cosine_annealing(optimizer, step, total_steps, 1e-2, 1e-5)
+                pbar.set_postfix({'loss': loss.item()})
+                pbar.update()
+        texture = np.clip(texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
+        mask = 1 - utils3d.torch.rasterize_triangle_faces(
+            rastctx, (uvs * 2 - 1)[None], faces, texture_size, texture_size
+        )['mask'][0].detach().cpu().numpy().astype(np.uint8)
+        texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
+    else:
+        raise ValueError(f'Unknown mode: {mode}')
+
+    return texture
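+
+# Editorial usage sketch (hypothetical variables; see to_glb below for the
+# real call site): 'fast' does a one-shot nearest-neighbor bake, 'opt' runs
+# the differentiable-rendering optimization above.
+#
+#     tex = bake_texture(vertices, faces, uvs,
+#                        observations, masks, extrinsics, intrinsics,
+#                        texture_size=1024, mode='fast', verbose=True)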
+ """ + vertices = mesh.vertices.cpu().numpy() + faces = mesh.faces.cpu().numpy() + + # mesh postprocess + vertices, faces = postprocess_mesh( + vertices, faces, + simplify=simplify > 0, + simplify_ratio=simplify, + fill_holes=fill_holes, + fill_holes_max_hole_size=fill_holes_max_size, + fill_holes_max_hole_nbe=int(250 * np.sqrt(1-simplify)), + fill_holes_resolution=1024, + fill_holes_num_views=1000, + debug=debug, + verbose=verbose, + ) + + # parametrize mesh + vertices, faces, uvs = parametrize_mesh(vertices, faces) + + # bake texture + observations, extrinsics, intrinsics = render_multiview(app_rep, resolution=1024, nviews=100) + masks = [np.any(observation > 0, axis=-1) for observation in observations] + extrinsics = [extrinsics[i].cpu().numpy() for i in range(len(extrinsics))] + intrinsics = [intrinsics[i].cpu().numpy() for i in range(len(intrinsics))] + texture = bake_texture( + vertices, faces, uvs, + observations, masks, extrinsics, intrinsics, + texture_size=texture_size, mode='opt', + lambda_tv=0.01, + verbose=verbose + ) + texture = Image.fromarray(texture) + + # rotate mesh (from z-up to y-up) + vertices = vertices @ np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + material = trimesh.visual.material.PBRMaterial( + roughnessFactor=1.0, + baseColorTexture=texture, + baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8) + ) + mesh = trimesh.Trimesh(vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, material=material)) + return mesh + + +def simplify_gs( + gs: Gaussian, + simplify: float = 0.95, + verbose: bool = True, +): + """ + Simplify 3D Gaussians + NOTE: this function is not used in the current implementation for the unsatisfactory performance. + + Args: + gs (Gaussian): 3D Gaussian. + simplify (float): Ratio of Gaussians to remove in simplification. 
+ """ + if simplify <= 0: + return gs + + # simplify + observations, extrinsics, intrinsics = render_multiview(gs, resolution=1024, nviews=100) + observations = [torch.tensor(obs / 255.0).float().cuda().permute(2, 0, 1) for obs in observations] + + # Following https://arxiv.org/pdf/2411.06019 + renderer = GaussianRenderer({ + "resolution": 1024, + "near": 0.8, + "far": 1.6, + "ssaa": 1, + "bg_color": (0,0,0), + }) + new_gs = Gaussian(**gs.init_params) + new_gs._features_dc = gs._features_dc.clone() + new_gs._features_rest = gs._features_rest.clone() if gs._features_rest is not None else None + new_gs._opacity = torch.nn.Parameter(gs._opacity.clone()) + new_gs._rotation = torch.nn.Parameter(gs._rotation.clone()) + new_gs._scaling = torch.nn.Parameter(gs._scaling.clone()) + new_gs._xyz = torch.nn.Parameter(gs._xyz.clone()) + + start_lr = [1e-4, 1e-3, 5e-3, 0.025] + end_lr = [1e-6, 1e-5, 5e-5, 0.00025] + optimizer = torch.optim.Adam([ + {"params": new_gs._xyz, "lr": start_lr[0]}, + {"params": new_gs._rotation, "lr": start_lr[1]}, + {"params": new_gs._scaling, "lr": start_lr[2]}, + {"params": new_gs._opacity, "lr": start_lr[3]}, + ], lr=start_lr[0]) + + def exp_anealing(optimizer, step, total_steps, start_lr, end_lr): + return start_lr * (end_lr / start_lr) ** (step / total_steps) + + def cosine_anealing(optimizer, step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + + _zeta = new_gs.get_opacity.clone().detach().squeeze() + _lambda = torch.zeros_like(_zeta) + _delta = 1e-7 + _interval = 10 + num_target = int((1 - simplify) * _zeta.shape[0]) + + with tqdm(total=2500, disable=not verbose, desc='Simplifying Gaussian') as pbar: + for i in range(2500): + # prune + if i % 100 == 0: + mask = new_gs.get_opacity.squeeze() > 0.05 + mask = torch.nonzero(mask).squeeze() + new_gs._xyz = torch.nn.Parameter(new_gs._xyz[mask]) + new_gs._rotation = torch.nn.Parameter(new_gs._rotation[mask]) + new_gs._scaling = torch.nn.Parameter(new_gs._scaling[mask]) + new_gs._opacity = torch.nn.Parameter(new_gs._opacity[mask]) + new_gs._features_dc = new_gs._features_dc[mask] + new_gs._features_rest = new_gs._features_rest[mask] if new_gs._features_rest is not None else None + _zeta = _zeta[mask] + _lambda = _lambda[mask] + # update optimizer state + for param_group, new_param in zip(optimizer.param_groups, [new_gs._xyz, new_gs._rotation, new_gs._scaling, new_gs._opacity]): + stored_state = optimizer.state[param_group['params'][0]] + if 'exp_avg' in stored_state: + stored_state['exp_avg'] = stored_state['exp_avg'][mask] + stored_state['exp_avg_sq'] = stored_state['exp_avg_sq'][mask] + del optimizer.state[param_group['params'][0]] + param_group['params'][0] = new_param + optimizer.state[param_group['params'][0]] = stored_state + + opacity = new_gs.get_opacity.squeeze() + + # sparisfy + if i % _interval == 0: + _zeta = _lambda + opacity.detach() + if opacity.shape[0] > num_target: + index = _zeta.topk(num_target)[1] + _m = torch.ones_like(_zeta, dtype=torch.bool) + _m[index] = 0 + _zeta[_m] = 0 + _lambda = _lambda + opacity.detach() - _zeta + + # sample a random view + view_idx = np.random.randint(len(observations)) + observation = observations[view_idx] + extrinsic = extrinsics[view_idx] + intrinsic = intrinsics[view_idx] + + color = renderer.render(new_gs, extrinsic, intrinsic)['color'] + rgb_loss = torch.nn.functional.l1_loss(color, observation) + loss = rgb_loss + \ + _delta * torch.sum(torch.pow(_lambda + opacity - _zeta, 2)) + + 
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # update lr
+            for j in range(len(optimizer.param_groups)):
+                optimizer.param_groups[j]['lr'] = cosine_annealing(optimizer, i, 2500, start_lr[j], end_lr[j])
+
+            pbar.set_postfix({'loss': rgb_loss.item(), 'num': opacity.shape[0], 'lambda': _lambda.mean().item()})
+            pbar.update()
+
+    new_gs._xyz = new_gs._xyz.data
+    new_gs._rotation = new_gs._rotation.data
+    new_gs._scaling = new_gs._scaling.data
+    new_gs._opacity = new_gs._opacity.data
+
+    return new_gs
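
A note on the sparsification above: the _zeta/_lambda updates implement the ADMM-style opacity sparsification of the paper cited in the code (arXiv:2411.06019). Every _interval steps, _zeta is the top-k projection of _lambda + opacity onto the target budget, _lambda accumulates the residual, and the _delta-weighted quadratic penalty pulls opacities toward that projection. A hypothetical invocation, given a generated Gaussian representation:

    slim = simplify_gs(gaussian_rep, simplify=0.95, verbose=True)   # keep ~5% of the primitives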