From 7531afd162ea73fb05aa1f4bad70a0464e8723d9 Mon Sep 17 00:00:00 2001 From: zcr Date: Tue, 17 Mar 2026 11:38:16 +0800 Subject: [PATCH] 1 --- dataset_toolkits/datasets/ObjaverseXL.py | 92 +++++ dataset_toolkits/stat_latent.py | 66 +++ multi_image_to_3D.py | 112 +++++ trellis/datasets/structured_latent.py | 217 ++++++++++ trellis/datasets/structured_latent2render.py | 160 ++++++++ trellis/models/structured_latent_flow.py | 276 +++++++++++++ trellis/modules/norm.py | 25 ++ trellis/modules/sparse/nonlinearity.py | 35 ++ trellis/modules/sparse/norm.py | 58 +++ trellis/renderers/octree_renderer.py | 300 ++++++++++++++ trellis/representations/octree/octree_dfs.py | 347 ++++++++++++++++ .../representations/radiance_field/strivec.py | 28 ++ .../vae/structured_latent_vae_gaussian.py | 275 +++++++++++++ .../vae/structured_latent_vae_mesh_dec.py | 382 ++++++++++++++++++ .../vae/structured_latent_vae_rf_dec.py | 223 ++++++++++ utils/new_oss_client.py | 140 +++++++ 16 files changed, 2736 insertions(+) create mode 100644 dataset_toolkits/datasets/ObjaverseXL.py create mode 100644 dataset_toolkits/stat_latent.py create mode 100644 multi_image_to_3D.py create mode 100644 trellis/datasets/structured_latent.py create mode 100644 trellis/datasets/structured_latent2render.py create mode 100644 trellis/models/structured_latent_flow.py create mode 100644 trellis/modules/norm.py create mode 100644 trellis/modules/sparse/nonlinearity.py create mode 100644 trellis/modules/sparse/norm.py create mode 100644 trellis/renderers/octree_renderer.py create mode 100644 trellis/representations/octree/octree_dfs.py create mode 100644 trellis/representations/radiance_field/strivec.py create mode 100644 trellis/trainers/vae/structured_latent_vae_gaussian.py create mode 100644 trellis/trainers/vae/structured_latent_vae_mesh_dec.py create mode 100644 trellis/trainers/vae/structured_latent_vae_rf_dec.py create mode 100644 utils/new_oss_client.py diff --git a/dataset_toolkits/datasets/ObjaverseXL.py b/dataset_toolkits/datasets/ObjaverseXL.py new file mode 100644 index 0000000..b2f5c76 --- /dev/null +++ b/dataset_toolkits/datasets/ObjaverseXL.py @@ -0,0 +1,92 @@ +import os +import argparse +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +import pandas as pd +import objaverse.xl as oxl +from utils import get_file_hash + + +def add_args(parser: argparse.ArgumentParser): + parser.add_argument('--source', type=str, default='sketchfab', + help='Data source to download annotations from (github, sketchfab)') + + +def get_metadata(source, **kwargs): + if source == 'sketchfab': + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv") + elif source == 'github': + metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv") + else: + raise ValueError(f"Invalid source: {source}") + return metadata + + +def download(metadata, output_dir, **kwargs): + os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True) + + # download annotations + annotations = oxl.get_annotations() + annotations = annotations[annotations['sha256'].isin(metadata['sha256'].values)] + + # download and render objects + file_paths = oxl.download_objects( + annotations, + download_dir=os.path.join(output_dir, "raw"), + save_repo_format="zip", + ) + + downloaded = {} + metadata = metadata.set_index("file_identifier") + for k, v in file_paths.items(): + sha256 = metadata.loc[k, "sha256"] + downloaded[sha256] = os.path.relpath(v, output_dir) + + return pd.DataFrame(downloaded.items(), 
columns=['sha256', 'local_path']) + + +def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame: + import os + from concurrent.futures import ThreadPoolExecutor + from tqdm import tqdm + import tempfile + import zipfile + + # load metadata + metadata = metadata.to_dict('records') + + # processing objects + records = [] + max_workers = max_workers or os.cpu_count() + try: + with ThreadPoolExecutor(max_workers=max_workers) as executor, \ + tqdm(total=len(metadata), desc=desc) as pbar: + def worker(metadatum): + try: + local_path = metadatum['local_path'] + sha256 = metadatum['sha256'] + if local_path.startswith('raw/github/repos/'): + path_parts = local_path.split('/') + file_name = os.path.join(*path_parts[5:]) + zip_file = os.path.join(output_dir, *path_parts[:5]) + with tempfile.TemporaryDirectory() as tmp_dir: + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(tmp_dir) + file = os.path.join(tmp_dir, file_name) + record = func(file, sha256) + else: + file = os.path.join(output_dir, local_path) + record = func(file, sha256) + if record is not None: + records.append(record) + pbar.update() + except Exception as e: + print(f"Error processing object {sha256}: {e}") + pbar.update() + + executor.map(worker, metadata) + executor.shutdown(wait=True) + except: + print("Error happened during processing.") + + return pd.DataFrame.from_records(records) diff --git a/dataset_toolkits/stat_latent.py b/dataset_toolkits/stat_latent.py new file mode 100644 index 0000000..7f27a06 --- /dev/null +++ b/dataset_toolkits/stat_latent.py @@ -0,0 +1,66 @@ +import os +import json +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from easydict import EasyDict as edict +from concurrent.futures import ThreadPoolExecutor + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--output_dir', type=str, required=True, + help='Directory to save the metadata') + parser.add_argument('--filter_low_aesthetic_score', type=float, default=None, + help='Filter objects with aesthetic score lower than this value') + parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16', + help='Latent model to use') + parser.add_argument('--num_samples', type=int, default=50000, + help='Number of samples to use for calculating stats') + opt = parser.parse_args() + opt = edict(vars(opt)) + + # get file list + if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')): + metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv')) + else: + raise ValueError('metadata.csv not found') + if opt.filter_low_aesthetic_score is not None: + metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score] + metadata = metadata[metadata[f'latent_{opt.model}'] == True] + sha256s = metadata['sha256'].values + sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False) + + # stats + means = [] + mean2s = [] + with ThreadPoolExecutor(max_workers=16) as executor, \ + tqdm(total=len(sha256s), desc="Extracting features") as pbar: + def worker(sha256): + try: + feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz')) + feats = feats['feats'] + means.append(feats.mean(axis=0)) + mean2s.append((feats ** 2).mean(axis=0)) + pbar.update() + except Exception as e: + print(f"Error extracting features for {sha256}: {e}") + pbar.update() + + executor.map(worker, sha256s) + executor.shutdown(wait=True) + + mean = 
np.array(means).mean(axis=0) + mean2 = np.array(mean2s).mean(axis=0) + std = np.sqrt(mean2 - mean ** 2) + + print('mean:', mean) + print('std:', std) + + with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f: + json.dump({ + 'mean': mean.tolist(), + 'std': std.tolist(), + }, f, indent=4) + \ No newline at end of file diff --git a/multi_image_to_3D.py b/multi_image_to_3D.py new file mode 100644 index 0000000..7731e43 --- /dev/null +++ b/multi_image_to_3D.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +#-*- coding: utf-8 -*- + +import os +import argparse +import numpy as np +import imageio +from PIL import Image + +from trellis.pipelines import TrellisImageTo3DPipeline +from trellis.utils import render_utils, postprocessing_utils + +def build_parser(): + p = argparse.ArgumentParser("TRELLIS CLI: multi-image -> 3D (video + optional GLB/PLY)") + + p.add_argument( + "-i", "--images", + nargs="+", + required=True, + help="Input image paths (space-separated), e.g. -i a.png b.png c.png" + ) + p.add_argument( + "-o", "--out_dir", + default="trellis_out_multi", + help="Output directory.") + p.add_argument("--seed", type=int, default=1) + + p.add_argument("--steps_sparse", type=int, default=12) + p.add_argument("--cfg_sparse", type=float, default=7.5) + p.add_argument("--steps_slat", type=int, default=12) + p.add_argument("--cfg_slat", type=float, default=3.0) + + # Render video (default True) + p.add_argument("--save_video", dest="save_video", action="store_true", default=True) + p.add_argument("--no-save_video", dest="save_video", action="store_false") + p.add_argument("--video_name", type=str, default="sample_multi.mp4") + p.add_argument("--fps", type=int, default=30) + + # Export GLB (default True) + p.add_argument("--export_glb", dest="export_glb", action="store_true", default=True) + p.add_argument("--no-export_glb", dest="export_glb", action="store_false") + p.add_argument("--glb_name", type=str, default="sample_multi.glb") + p.add_argument("--texture_size", type=int, default=1024) + p.add_argument("--simplify", type=float, default=0.95) + + # Save PLY (default True) + p.add_argument("--save_ply", dest="save_ply", action="store_true", default=True) + p.add_argument("--no-save_ply", dest="save_ply", action="store_false") + p.add_argument("--ply_name", type=str, default="sample_multi.ply") + + # Env passthrough (optional) + p.add_argument("--spconv_algo", type=str, default="native", choices=["native", "auto"]) + # p.add_argument("--attn_backend", type=str, default="", choices=["", "flash-attn", "xformers"]) + + return p + +def main(): + args = build_parser().parse_args() + os.makedirs(args.out_dir, exist_ok=True) + + os.environ["SPCONV_ALGO"] = args.spconv_algo + # if args.attn_backend: + # os.environ["ATTN_BACKEND"] = args.attn_backend + + pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large") + pipeline.cuda() + + images = [Image.open(p) for p in args.images] + + outputs = pipeline.run_multi_image( + images, + seed=args.seed, + sparse_structure_sampler_params={ + "steps": args.steps_sparse, + "cfg_strength": args.cfg_sparse, + }, + slat_sampler_params={ + "steps": args.steps_slat, + "cfg_strength": args.cfg_slat, + }, + ) + + # For multi-image, TRELLIS still returns list-like outputs; export the first asset by default. 
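+    # outputs behaves like a dict of lists here (one entry per generated sample);
+    # index 0 selects the single asset produced by this run.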
+ gs = outputs["gaussian"][0] + mesh = outputs["mesh"][0] + + if args.save_video: + video_gs = render_utils.render_video(gs)["color"] + video_mesh = render_utils.render_video(mesh)["normal"] + video = [np.concatenate([fg, fm], axis=1) for fg, fm in zip(video_gs, video_mesh)] + out_mp4 = os.path.join(args.out_dir, args.video_name) + imageio.mimsave(out_mp4, video, fps=args.fps) + print(f"[ok] saved video: {out_mp4}") + + if args.export_glb: + out_glb = os.path.join(args.out_dir, args.glb_name) + glb = postprocessing_utils.to_glb( + gs, + mesh, + simplify=args.simplify, + texture_size=args.texture_size, + ) + glb.export(out_glb) + print(f"[ok] exported glb: {out_glb}") + + if args.save_ply: + out_ply = os.path.join(args.out_dir, args.ply_name) + gs.save_ply(out_ply) + print(f"[ok] saved ply: {out_ply}") + +if __name__ == "__main__": + main() diff --git a/trellis/datasets/structured_latent.py b/trellis/datasets/structured_latent.py new file mode 100644 index 0000000..f8b8f24 --- /dev/null +++ b/trellis/datasets/structured_latent.py @@ -0,0 +1,217 @@ +import json +import os +from typing import * +import numpy as np +import torch +import utils3d.torch +from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin +from ..modules.sparse.basic import SparseTensor +from .. import models +from ..utils.render_utils import get_renderer +from ..utils.data_utils import load_balanced_group_indices + + +class SLatVisMixin: + def __init__( + self, + *args, + pretrained_slat_dec: str = 'microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.slat_dec = None + self.pretrained_slat_dec = pretrained_slat_dec + self.slat_dec_path = slat_dec_path + self.slat_dec_ckpt = slat_dec_ckpt + + def _loading_slat_dec(self): + if self.slat_dec is not None: + return + if self.slat_dec_path is not None: + cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r')) + decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args']) + ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt') + decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True)) + else: + decoder = models.from_pretrained(self.pretrained_slat_dec) + self.slat_dec = decoder.cuda().eval() + + def _delete_slat_dec(self): + del self.slat_dec + self.slat_dec = None + + @torch.no_grad() + def decode_latent(self, z, batch_size=4): + self._loading_slat_dec() + reps = [] + if self.normalization is not None: + z = z * self.std.to(z.device) + self.mean.to(z.device) + for i in range(0, z.shape[0], batch_size): + reps.append(self.slat_dec(z[i:i+batch_size])) + reps = sum(reps, []) + self._delete_slat_dec() + return reps + + @torch.no_grad() + def visualize_sample(self, x_0: Union[SparseTensor, dict]): + x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0'] + reps = self.decode_latent(x_0.cuda()) + + # Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + exts = [] + ints = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(40)).cuda() + extrinsics = 
utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + exts.append(extrinsics) + ints.append(intrinsics) + + renderer = get_renderer(reps[0]) + images = [] + for representation in reps: + image = torch.zeros(3, 1024, 1024).cuda() + tile = [2, 2] + for j, (ext, intr) in enumerate(zip(exts, ints)): + res = renderer.render(representation, ext, intr) + image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color'] + images.append(image) + images = torch.stack(images) + + return images + + +class SLat(SLatVisMixin, StandardDatasetBase): + """ + structured latent dataset + + Args: + roots (str): path to the dataset + latent_model (str): name of the latent model + min_aesthetic_score (float): minimum aesthetic score + max_num_voxels (int): maximum number of voxels + normalization (dict): normalization stats + pretrained_slat_dec (str): name of the pretrained slat decoder + slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec + slat_dec_ckpt (str): name of the slat decoder checkpoint + """ + def __init__(self, + roots: str, + *, + latent_model: str, + min_aesthetic_score: float = 5.0, + max_num_voxels: int = 32768, + normalization: Optional[dict] = None, + pretrained_slat_dec: str = 'microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16', + slat_dec_path: Optional[str] = None, + slat_dec_ckpt: Optional[str] = None, + ): + self.normalization = normalization + self.latent_model = latent_model + self.min_aesthetic_score = min_aesthetic_score + self.max_num_voxels = max_num_voxels + self.value_range = (0, 1) + + super().__init__( + roots, + pretrained_slat_dec=pretrained_slat_dec, + slat_dec_path=slat_dec_path, + slat_dec_ckpt=slat_dec_ckpt, + ) + + self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances] + + if self.normalization is not None: + self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1) + self.std = torch.tensor(self.normalization['std']).reshape(1, -1) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata[f'latent_{self.latent_model}']] + stats['With latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels] + stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata) + return metadata, stats + + def get_instance(self, root, instance): + data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + feats = torch.tensor(data['feats']).float() + if self.normalization is not None: + feats = (feats - self.mean) / self.std + return { + 'coords': coords, + 'feats': feats, + } + + @staticmethod + def collate_fn(batch, split_size=None): + if split_size is None: + group_idx = [list(range(len(batch)))] + else: + group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size) + packs = [] + for group in group_idx: + sub_batch = [batch[i] for i in group] + pack = {} + coords = [] + feats = [] + layout = [] + start = 0 + for i, b in enumerate(sub_batch): + coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1)) + feats.append(b['feats']) + 
layout.append(slice(start, start + b['coords'].shape[0])) + start += b['coords'].shape[0] + coords = torch.cat(coords) + feats = torch.cat(feats) + pack['x_0'] = SparseTensor( + coords=coords, + feats=feats, + ) + pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]]) + pack['x_0'].register_spatial_cache('layout', layout) + + # collate other data + keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']] + for k in keys: + if isinstance(sub_batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in sub_batch]) + elif isinstance(sub_batch[0][k], list): + pack[k] = sum([b[k] for b in sub_batch], []) + else: + pack[k] = [b[k] for b in sub_batch] + + packs.append(pack) + + if split_size is None: + return packs[0] + return packs + + +class TextConditionedSLat(TextConditionedMixin, SLat): + """ + Text conditioned structured latent dataset + """ + pass + + +class ImageConditionedSLat(ImageConditionedMixin, SLat): + """ + Image conditioned structured latent dataset + """ + pass diff --git a/trellis/datasets/structured_latent2render.py b/trellis/datasets/structured_latent2render.py new file mode 100644 index 0000000..466737a --- /dev/null +++ b/trellis/datasets/structured_latent2render.py @@ -0,0 +1,160 @@ +import os +from PIL import Image +import json +import numpy as np +import torch +import utils3d.torch +from ..modules.sparse.basic import SparseTensor +from .components import StandardDatasetBase + + +class SLat2Render(StandardDatasetBase): + """ + Dataset for Structured Latent and rendered images. + + Args: + roots (str): paths to the dataset + image_size (int): size of the image + latent_model (str): latent model name + min_aesthetic_score (float): minimum aesthetic score + max_num_voxels (int): maximum number of voxels + """ + def __init__( + self, + roots: str, + image_size: int, + latent_model: str, + min_aesthetic_score: float = 5.0, + max_num_voxels: int = 32768, + ): + self.image_size = image_size + self.latent_model = latent_model + self.min_aesthetic_score = min_aesthetic_score + self.max_num_voxels = max_num_voxels + self.value_range = (0, 1) + + super().__init__(roots) + + def filter_metadata(self, metadata): + stats = {} + metadata = metadata[metadata[f'latent_{self.latent_model}']] + stats['With latent'] = len(metadata) + metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score] + stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata) + metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels] + stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata) + return metadata, stats + + def _get_image(self, root, instance): + with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f: + metadata = json.load(f) + n_views = len(metadata['frames']) + view = np.random.randint(n_views) + metadata = metadata['frames'][view] + fov = metadata['camera_angle_x'] + intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov)) + c2w = torch.tensor(metadata['transform_matrix']) + c2w[:3, 1:3] *= -1 + extrinsics = torch.inverse(c2w) + + image_path = os.path.join(root, 'renders', instance, metadata['file_path']) + image = Image.open(image_path) + alpha = image.getchannel(3) + image = image.convert('RGB') + image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS) + image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0 + alpha = 
torch.tensor(np.array(alpha)).float() / 255.0 + + return { + 'image': image, + 'alpha': alpha, + 'extrinsics': extrinsics, + 'intrinsics': intrinsics, + } + + def _get_latent(self, root, instance): + data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz')) + coords = torch.tensor(data['coords']).int() + feats = torch.tensor(data['feats']).float() + return { + 'coords': coords, + 'feats': feats, + } + + @torch.no_grad() + def visualize_sample(self, sample: dict): + return sample['image'] + + @staticmethod + def collate_fn(batch): + pack = {} + coords = [] + for i, b in enumerate(batch): + coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1)) + coords = torch.cat(coords) + feats = torch.cat([b['feats'] for b in batch]) + pack['latents'] = SparseTensor( + coords=coords, + feats=feats, + ) + + # collate other data + keys = [k for k in batch[0].keys() if k not in ['coords', 'feats']] + for k in keys: + if isinstance(batch[0][k], torch.Tensor): + pack[k] = torch.stack([b[k] for b in batch]) + elif isinstance(batch[0][k], list): + pack[k] = sum([b[k] for b in batch], []) + else: + pack[k] = [b[k] for b in batch] + + return pack + + def get_instance(self, root, instance): + image = self._get_image(root, instance) + latent = self._get_latent(root, instance) + return { + **image, + **latent, + } + + +class Slat2RenderGeo(SLat2Render): + def __init__( + self, + roots: str, + image_size: int, + latent_model: str, + min_aesthetic_score: float = 5.0, + max_num_voxels: int = 32768, + ): + super().__init__( + roots, + image_size, + latent_model, + min_aesthetic_score, + max_num_voxels, + ) + + def _get_geo(self, root, instance): + verts, face = utils3d.io.read_ply(os.path.join(root, 'renders', instance, 'mesh.ply')) + mesh = { + "vertices" : torch.from_numpy(verts), + "faces" : torch.from_numpy(face), + } + return { + "mesh" : mesh, + } + + def get_instance(self, root, instance): + image = self._get_image(root, instance) + latent = self._get_latent(root, instance) + geo = self._get_geo(root, instance) + return { + **image, + **latent, + **geo, + } + + \ No newline at end of file diff --git a/trellis/models/structured_latent_flow.py b/trellis/models/structured_latent_flow.py new file mode 100644 index 0000000..4d6f61b --- /dev/null +++ b/trellis/models/structured_latent_flow.py @@ -0,0 +1,276 @@ +from typing import * +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32 +from ..modules.transformer import AbsolutePositionEmbedder +from ..modules.norm import LayerNorm32 +from ..modules import sparse as sp +from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock +from .sparse_structure_flow import TimestepEmbedder +from .sparse_elastic_mixin import SparseTransformerElasticMixin + + +class SparseResBlock3d(nn.Module): + def __init__( + self, + channels: int, + emb_channels: int, + out_channels: Optional[int] = None, + downsample: bool = False, + upsample: bool = False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.out_channels = out_channels or channels + self.downsample = downsample + self.upsample = upsample + + assert not (downsample and upsample), "Cannot downsample and upsample at the same time" + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, 
elementwise_affine=False, eps=1e-6) + self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3) + self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3)) + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(emb_channels, 2 * self.out_channels, bias=True), + ) + self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity() + self.updown = None + if self.downsample: + self.updown = sp.SparseDownsample(2) + elif self.upsample: + self.updown = sp.SparseUpsample(2) + + def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor: + if self.updown is not None: + x = self.updown(x) + return x + + def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor: + emb_out = self.emb_layers(emb).type(x.dtype) + scale, shift = torch.chunk(emb_out, 2, dim=1) + + x = self._updown(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + + return h + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + patch_size: int = 2, + num_io_res_blocks: int = 2, + io_block_channels: List[int] = None, + pe_mode: Literal["ape", "rope"] = "ape", + use_fp16: bool = False, + use_checkpoint: bool = False, + use_skip_connection: bool = True, + share_mod: bool = False, + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.patch_size = patch_size + self.num_io_res_blocks = num_io_res_blocks + self.io_block_channels = io_block_channels + self.pe_mode = pe_mode + self.use_fp16 = use_fp16 + self.use_checkpoint = use_checkpoint + self.use_skip_connection = use_skip_connection + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = torch.float16 if use_fp16 else torch.float32 + + if self.io_block_channels is not None: + assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2" + assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages" + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + if pe_mode == "ape": + self.pos_embedder = AbsolutePositionEmbedder(model_channels) + + self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0]) + + self.input_blocks = nn.ModuleList([]) + if io_block_channels is not None: + for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]): + self.input_blocks.extend([ + SparseResBlock3d( + chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + self.input_blocks.append( + SparseResBlock3d( + chs, + model_channels, + out_channels=next_chs, + 
downsample=True, + ) + ) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_blocks = nn.ModuleList([]) + if io_block_channels is not None: + for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))): + self.out_blocks.append( + SparseResBlock3d( + prev_chs * 2 if self.use_skip_connection else prev_chs, + model_channels, + out_channels=chs, + upsample=True, + ) + ) + self.out_blocks.extend([ + SparseResBlock3d( + chs * 2 if self.use_skip_connection else chs, + model_channels, + out_channels=chs, + ) + for _ in range(num_io_res_blocks-1) + ]) + + self.out_layer = sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], out_channels) + + self.initialize_weights() + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + """ + Return the device of the model. + """ + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.blocks.apply(convert_module_to_f16) + self.out_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self) -> None: + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.blocks.apply(convert_module_to_f32) + self.out_blocks.apply(convert_module_to_f32) + + def initialize_weights(self) -> None: + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + if self.share_mod: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + else: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.out_layer.weight, 0) + nn.init.constant_(self.out_layer.bias, 0) + + def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor: + h = self.input_layer(x).type(self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = t_emb.type(self.dtype) + cond = cond.type(self.dtype) + + skips = [] + # pack with input blocks + for block in self.input_blocks: + h = block(h, t_emb) + skips.append(h.feats) + + if self.pe_mode == "ape": + h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + # unpack with output blocks + for block, skip in zip(self.out_blocks, reversed(skips)): + if self.use_skip_connection: + h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb) + else: + h = block(h, t_emb) + + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h.type(x.dtype)) + return h + + 
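+# Illustrative usage of SLatFlowModel (constructor values below are placeholders, not a shipped config):
+#   model = SLatFlowModel(resolution=64, in_channels=8, model_channels=768, cond_channels=1024,
+#                         out_channels=8, num_blocks=12, patch_size=2, io_block_channels=[128])
+#   out = model(x, t, cond)  # x: sp.SparseTensor latent, t: timesteps, cond: conditioning tokens
+
+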
+class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel): + """ + SLat Flow Model with elastic memory management. + Used for training with low VRAM. + """ + pass diff --git a/trellis/modules/norm.py b/trellis/modules/norm.py new file mode 100644 index 0000000..0903572 --- /dev/null +++ b/trellis/modules/norm.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class GroupNorm32(nn.GroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x.float()).type(x.dtype) + + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + \ No newline at end of file diff --git a/trellis/modules/sparse/nonlinearity.py b/trellis/modules/sparse/nonlinearity.py new file mode 100644 index 0000000..f200098 --- /dev/null +++ b/trellis/modules/sparse/nonlinearity.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +from . import SparseTensor + +__all__ = [ + 'SparseReLU', + 'SparseSiLU', + 'SparseGELU', + 'SparseActivation' +] + + +class SparseReLU(nn.ReLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseSiLU(nn.SiLU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseGELU(nn.GELU): + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(super().forward(input.feats)) + + +class SparseActivation(nn.Module): + def __init__(self, activation: nn.Module): + super().__init__() + self.activation = activation + + def forward(self, input: SparseTensor) -> SparseTensor: + return input.replace(self.activation(input.feats)) + diff --git a/trellis/modules/sparse/norm.py b/trellis/modules/sparse/norm.py new file mode 100644 index 0000000..6b38a36 --- /dev/null +++ b/trellis/modules/sparse/norm.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +from . import SparseTensor +from . 
import DEBUG + +__all__ = [ + 'SparseGroupNorm', + 'SparseLayerNorm', + 'SparseGroupNorm32', + 'SparseLayerNorm32', +] + + +class SparseGroupNorm(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + if DEBUG: + assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch" + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseLayerNorm(nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input: SparseTensor) -> SparseTensor: + nfeats = torch.zeros_like(input.feats) + for k in range(input.shape[0]): + bfeats = input.feats[input.layout[k]] + bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1) + bfeats = super().forward(bfeats) + bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0) + nfeats[input.layout[k]] = bfeats + return input.replace(nfeats) + + +class SparseGroupNorm32(SparseGroupNorm): + """ + A GroupNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) + +class SparseLayerNorm32(SparseLayerNorm): + """ + A LayerNorm layer that converts to float32 before the forward pass. + """ + def forward(self, x: SparseTensor) -> SparseTensor: + return super().forward(x.float()).type(x.dtype) diff --git a/trellis/renderers/octree_renderer.py b/trellis/renderers/octree_renderer.py new file mode 100644 index 0000000..136069c --- /dev/null +++ b/trellis/renderers/octree_renderer.py @@ -0,0 +1,300 @@ +import numpy as np +import torch +import torch.nn.functional as F +import math +import cv2 +from scipy.stats import qmc +from easydict import EasyDict as edict +from ..representations.octree import DfsOctree + + +def intrinsics_to_projection( + intrinsics: torch.Tensor, + near: float, + far: float, + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix + near (float): near plane to clip + far (float): far plane to clip + Returns: + (torch.Tensor): [4, 4] OpenGL perspective matrix + """ + fx, fy = intrinsics[0, 0], intrinsics[1, 1] + cx, cy = intrinsics[0, 2], intrinsics[1, 2] + ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[0, 0] = 2 * fx + ret[1, 1] = 2 * fy + ret[0, 2] = 2 * cx - 1 + ret[1, 2] = - 2 * cy + 1 + ret[2, 2] = far / (far - near) + ret[2, 3] = near * far / (near - far) + ret[3, 2] = 1. + return ret + + +def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! 
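+
+    Args:
+        viewpoint_camera: camera providing image size, FoV, view/projection matrices and camera center.
+        octree (DfsOctree): octree representation to rasterize.
+        pipe: rendering options (with_distloss, jitter, debug).
+        bg_color (torch.Tensor): (3,) background color, on GPU.
+        scaling_modifier (float): global scale multiplier applied to the primitives.
+        used_rank (Optional[int]): number of ranks used for 'trivec'/'decoupoly' primitives; defaults to all.
+        colors_overwrite (Optional[torch.Tensor]): (N, 3) per-leaf colors overriding the SH features.
+        aux (Optional[dict]): auxiliary buffers forwarded to the rasterizer.
+        halton_sampler: Halton sampler forwarded to the trivec rasterizer, if any.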
+ """ + # lazy import + if 'OctreeTrivecRasterizer' not in globals(): + from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = edict( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=octree.active_sh_degree, + campos=viewpoint_camera.camera_center, + with_distloss=pipe.with_distloss, + jitter=pipe.jitter, + debug=pipe.debug, + ) + + positions = octree.get_xyz + if octree.primitive == "voxel": + densities = octree.get_density + elif octree.primitive == "gaussian": + opacities = octree.get_opacity + elif octree.primitive == "trivec": + trivecs = octree.get_trivec + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + elif octree.primitive == "decoupoly": + decoupolys_V, decoupolys_g = octree.get_decoupoly + densities = octree.get_density + raster_settings.density_shift = octree.density_shift + else: + raise ValueError(f"Unknown primitive {octree.primitive}") + depths = octree.get_depth + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + colors_precomp = None + shs = octree.get_features + if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None: + colors_precomp = colors_overwrite + shs = None + + ret = edict() + + if octree.primitive == "voxel": + renderer = OctreeVoxelRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, distloss = renderer( + positions = positions, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + ret['distloss'] = distloss + elif octree.primitive == "gaussian": + renderer = OctreeGaussianRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + opacities = opacities, + shs = shs, + colors_precomp = colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "trivec": + raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1] + renderer = OctreeTrivecRasterizer(raster_settings=raster_settings) + rgb, depth, alpha, percent_depth = renderer( + positions = positions, + trivecs = trivecs, + densities = densities, + shs = shs, + colors_precomp = colors_precomp, + colors_overwrite = colors_overwrite, + depths = depths, + aabb = octree.aabb, + aux = aux, + halton_sampler = halton_sampler, + ) + ret['percent_depth'] = percent_depth + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + elif octree.primitive == "decoupoly": + raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1] + renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings) + rgb, depth, alpha = renderer( + positions = positions, + decoupolys_V = decoupolys_V, + decoupolys_g = decoupolys_g, + densities = densities, + shs = shs, + colors_precomp = 
colors_precomp, + depths = depths, + aabb = octree.aabb, + aux = aux, + ) + ret['rgb'] = rgb + ret['depth'] = depth + ret['alpha'] = alpha + + return ret + + +class OctreeRenderer: + """ + Renderer for the Voxel representation. + + Args: + rendering_options (dict): Rendering options. + """ + + def __init__(self, rendering_options={}) -> None: + try: + import diffoctreerast + except ImportError: + print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m") + self.unsupported = True + else: + self.unsupported = False + + self.pipe = edict({ + "with_distloss": False, + "with_aux": False, + "scale_modifier": 1.0, + "used_rank": None, + "jitter": False, + "debug": False, + }) + self.rendering_options = edict({ + "resolution": None, + "near": None, + "far": None, + "ssaa": 1, + "bg_color": 'random', + }) + self.halton_sampler = qmc.Halton(2, scramble=False) + self.rendering_options.update(rendering_options) + self.bg_color = None + + def render( + self, + octree: DfsOctree, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + colors_overwrite: torch.Tensor = None, + ) -> edict: + """ + Render the octree. + + Args: + octree (Octree): octree + extrinsics (torch.Tensor): (4, 4) camera extrinsics + intrinsics (torch.Tensor): (3, 3) camera intrinsics + colors_overwrite (torch.Tensor): (N, 3) override color + + Returns: + edict containing: + color (torch.Tensor): (3, H, W) rendered color + depth (torch.Tensor): (H, W) rendered depth + alpha (torch.Tensor): (H, W) rendered alpha + distloss (Optional[torch.Tensor]): (H, W) rendered distance loss + percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth + aux (Optional[edict]): auxiliary tensors + """ + resolution = self.rendering_options["resolution"] + near = self.rendering_options["near"] + far = self.rendering_options["far"] + ssaa = self.rendering_options["ssaa"] + + if self.unsupported: + image = np.zeros((512, 512, 3), dtype=np.uint8) + text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0] + origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2 + image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA) + return { + 'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255, + } + + if self.rendering_options["bg_color"] == 'random': + self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda") + if np.random.rand() < 0.5: + self.bg_color += 1 + else: + self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda") + + if self.pipe["with_aux"]: + aux = { + 'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + 'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0, + } + for k in aux.keys(): + aux[k].requires_grad_() + aux[k].retain_grad() + else: + aux = None + + view = extrinsics + perspective = intrinsics_to_projection(intrinsics, near, far) + camera = torch.inverse(view)[:3, 3] + focalx = intrinsics[0, 0] + focaly = intrinsics[1, 1] + fovx = 2 * torch.atan(0.5 / focalx) + fovy = 2 * torch.atan(0.5 / focaly) + + camera_dict = edict({ + "image_height": resolution * ssaa, + "image_width": resolution * ssaa, + "FoVx": fovx, + "FoVy": fovy, + "znear": near, + "zfar": far, + "world_view_transform": view.T.contiguous(), + "projection_matrix": perspective.T.contiguous(), + "full_proj_transform": (perspective @ view).T.contiguous(), 
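+            # camera position in world space (translation column of the inverse view matrix)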
+ "camera_center": camera + }) + + # Render + render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler) + + if ssaa > 1: + render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + if hasattr(render_ret, 'percent_depth'): + render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze() + + ret = edict({ + 'color': render_ret.rgb, + 'depth': render_ret.depth, + 'alpha': render_ret.alpha, + }) + if self.pipe["with_distloss"] and 'distloss' in render_ret: + ret['distloss'] = render_ret.distloss + if self.pipe["with_aux"]: + ret['aux'] = aux + if hasattr(render_ret, 'percent_depth'): + ret['percent_depth'] = render_ret.percent_depth + return ret diff --git a/trellis/representations/octree/octree_dfs.py b/trellis/representations/octree/octree_dfs.py new file mode 100644 index 0000000..c2bd4dc --- /dev/null +++ b/trellis/representations/octree/octree_dfs.py @@ -0,0 +1,347 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DfsOctree: + """ + Sparse Voxel Octree (SVO) implementation for PyTorch. + Using Depth-First Search (DFS) order to store the octree. + DFS order suits rendering and ray tracing. + + The structure and data are separatedly stored. + Structure is stored as a continuous array, each element is a 3*32 bits descriptor. + |-----------------------------------------| + | 0:3 bits | 4:31 bits | + | leaf num | unused | + |-----------------------------------------| + | 0:31 bits | + | child ptr | + |-----------------------------------------| + | 0:31 bits | + | data ptr | + |-----------------------------------------| + Each element represents a non-leaf node in the octree. + The valid mask is used to indicate whether the children are valid. + The leaf mask is used to indicate whether the children are leaf nodes. + The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr. + The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr. + + There are also auxiliary arrays to store the additional structural information to facilitate parallel processing. + - Position: the position of the octree nodes. + - Depth: the depth of the octree nodes. + + Args: + depth (int): the depth of the octree. 
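+        aabb (list): axis-aligned bounding box of the volume, default [0, 0, 0, 1, 1, 1].
+        sh_degree (int): spherical harmonics degree used for the color features.
+        primitive (str): leaf primitive type, one of 'voxel', 'gaussian', 'trivec', 'decoupoly'.
+        primitive_config (dict): per-primitive options, e.g. 'rank'/'dim' for 'trivec',
+            'rank'/'degree' for 'decoupoly', 'solid' for 'voxel'.
+        device (str): device the octree tensors are allocated on.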
+ """ + + def __init__( + self, + depth, + aabb=[0,0,0,1,1,1], + sh_degree=2, + primitive='voxel', + primitive_config={}, + device='cuda', + ): + self.max_depth = depth + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + self.device = device + self.sh_degree = sh_degree + self.active_sh_degree = sh_degree + self.primitive = primitive + self.primitive_config = primitive_config + + self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device) + self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device) + self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device) + self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device) + self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device) + self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device) + self.depth[:, 0] = 1 + + self.data = ['position', 'depth'] + self.param_names = [] + + if primitive == 'voxel': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac'] + self.param_names += ['features_dc', 'features_ac'] + if not primitive_config.get('solid', False): + self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data.append('density') + self.param_names.append('density') + elif primitive == 'gaussian': + self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device) + self.data += ['features_dc', 'features_ac', 'opacity'] + self.param_names += ['features_dc', 'features_ac', 'opacity'] + elif primitive == 'trivec': + self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['trivec', 'density', 'features_dc', 'features_ac'] + self.param_names += ['trivec', 'density', 'features_dc', 'features_ac'] + elif primitive == 'decoupoly': + self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device) + self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device) + self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device) + self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device) + self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device) + self.density_shift = 0 + self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac'] + + self.setup_functions() + + def setup_functions(self): + self.density_activation = (lambda x: 
torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x) + self.opacity_activation = lambda x: torch.sigmoid(x - 6) + self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6 + self.color_activation = lambda x: torch.sigmoid(x) + + @property + def num_non_leaf_nodes(self): + return self.structure.shape[0] + + @property + def num_leaf_nodes(self): + return self.depth.shape[0] + + @property + def cur_depth(self): + return self.depth.max().item() + + @property + def occupancy(self): + return self.num_leaf_nodes / 8 ** self.cur_depth + + @property + def get_xyz(self): + return self.position + + @property + def get_depth(self): + return self.depth + + @property + def get_density(self): + if self.primitive == 'voxel' and self.primitive_config.get('solid', False): + return torch.full((self.position.shape[0], 1), torch.finfo(torch.float32).max, dtype=torch.float32, device=self.device) + return self.density_activation(self.density) + + @property + def get_opacity(self): + return self.opacity_activation(self.density) + + @property + def get_trivec(self): + return self.trivec + + @property + def get_decoupoly(self): + return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g + + @property + def get_color(self): + return self.color_activation(self.colors) + + @property + def get_features(self): + if self.sh_degree == 0: + return self.features_dc + return torch.cat([self.features_dc, self.features_ac], dim=-2) + + def state_dict(self): + ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'primitive_config': self.primitive_config, 'primitive': self.primitive} + if hasattr(self, 'density_shift'): + ret['density_shift'] = self.density_shift + for data in set(self.data + self.param_names): + if not isinstance(getattr(self, data), nn.Module): + ret[data] = getattr(self, data) + else: + ret[data] = getattr(self, data).state_dict() + return ret + + def load_state_dict(self, state_dict): + keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth'])) + for key in keys: + if key not in state_dict: + print(f"Warning: key {key} not found in the state_dict.") + continue + try: + if not isinstance(getattr(self, key), nn.Module): + setattr(self, key, state_dict[key]) + else: + getattr(self, key).load_state_dict(state_dict[key]) + except Exception as e: + print(e) + raise ValueError(f"Error loading key {key}.") + + def gather_from_leaf_children(self, data): + """ + Gather the data from the leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. + """ + leaf_cnt = self.structure[:, 0] + leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device) + for i in range(8): + if leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[leaf_cnt_masks[i], 2] + for j in range(i+1): + ret[leaf_cnt_masks[i]] += data[start + j] + return ret + + def gather_from_non_leaf_children(self, data): + """ + Gather the data from the non-leaf children. + + Args: + data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes. 
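+
+        Note: `data` is indexed through the child pointers (structure[:, 1]), i.e. by
+        non-leaf node indices; the result holds, for each node, the sum of `data`
+        over its non-leaf children.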
+ """ + non_leaf_cnt = 8 - self.structure[:, 0] + non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)] + ret = torch.zeros_like(data, device=self.device) + for i in range(8): + if non_leaf_cnt_masks[i].sum() == 0: + continue + start = self.structure[non_leaf_cnt_masks[i], 1] + for j in range(i+1): + ret[non_leaf_cnt_masks[i]] += data[start + j] + return ret + + def structure_control(self, mask): + """ + Control the structure of the octree. + + Args: + mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep. + """ + # Dont subdivide when the depth is the maximum. + mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0) + # Dont merge when the depth is the minimum. + mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0) + + # Gather control mask + structre_ctrl = self.gather_from_leaf_children(mask) + structre_ctrl[structre_ctrl==-8] = -1 + + new_leaf_num = self.structure[:, 0].clone() + # Modify the leaf num. + structre_valid = structre_ctrl >= 0 + new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes. + structre_delete = structre_ctrl < 0 + merged_nodes = self.gather_from_non_leaf_children(structre_delete.int()) + new_leaf_num += merged_nodes # Delete the merged nodes. + + # Update the structure array to allocate new nodes. + mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes. + mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes. + new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_structure_length = new_structre_idx[-1].item() + new_structre_idx = new_structre_idx[:-1] + new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device) + new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid] + + # Initialize the new nodes. + new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device) + new_node_mask[new_structre_idx[structre_valid]] = False + new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes. + new_node_num = new_node_mask.sum().item() + + # Rebuild child ptr. + non_leaf_cnt = 8 - new_structure[:, 0] + new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 1] = new_child_ptr + 1 + + # Rebuild data ptr with old data. 
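+        # Scatter the old per-node leaf counts to their new slots; the exclusive cumsum then
+        # gives, for each slot of the new structure, its starting offset into the old data arrays.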
+ leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device) + leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0]) + old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + + # Update the data array + subdivide_mask = mask == 1 + merge_mask = mask == -1 + data_valid = ~(subdivide_mask | merge_mask) + mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device) + mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes + mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes + mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes + mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes + new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0) + new_data_length = new_data_idx[-1].item() + new_data_idx = new_data_idx[:-1] + new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data} + for data in self.data: + new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid] + + # Rebuild data ptr + leaf_cnt = new_structure[:, 0] + new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]]) + new_structure[:, 2] = new_data_ptr + + # Initialize the new data array + ## For subdivide nodes + if subdivide_mask.sum() > 0: + subdivide_data_ptr = new_structure[new_node_mask, 2] + for data in self.data: + for i in range(8): + if data == 'position': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5 + scale = 2 ** (-1.0 - self.depth[subdivide_mask]) + new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale + elif data == 'depth': + new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask]))) + elif data == 'trivec': + offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5 + coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1) + axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1) + coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1 + new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True) + else: + new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask] + ## For merge nodes + if merge_mask.sum() > 0: + merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device) + merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]]) + for i in range(8): + merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i + old_merge_data_ptr = 
self.structure[structre_delete, 2] + for data in self.data: + if data == 'position': + scale = 2 ** (1.0 - self.depth[old_merge_data_ptr]) + new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5 + elif data == 'depth': + new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1 + elif data == 'opacity': + new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2) + elif data == 'trivec': + new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr] + else: + new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr] + + # Update the structure and data array + self.structure = new_structure + for data in self.data: + setattr(self, data, new_data[data]) + + # Save data array control temp variables + self.data_rearrange_buffer = { + 'subdivide_mask': subdivide_mask, + 'merge_mask': merge_mask, + 'data_valid': data_valid, + 'new_data_idx': new_data_idx, + 'new_data_length': new_data_length, + 'new_data': new_data + } diff --git a/trellis/representations/radiance_field/strivec.py b/trellis/representations/radiance_field/strivec.py new file mode 100644 index 0000000..8fc4b74 --- /dev/null +++ b/trellis/representations/radiance_field/strivec.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..octree import DfsOctree as Octree + + +class Strivec(Octree): + def __init__( + self, + resolution: int, + aabb: list, + sh_degree: int = 0, + rank: int = 8, + dim: int = 8, + device: str = "cuda", + ): + assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2" + self.resolution = resolution + depth = int(np.round(np.log2(resolution))) + super().__init__( + depth=depth, + aabb=aabb, + sh_degree=sh_degree, + primitive="trivec", + primitive_config={"rank": rank, "dim": dim}, + device=device, + ) diff --git a/trellis/trainers/vae/structured_latent_vae_gaussian.py b/trellis/trainers/vae/structured_latent_vae_gaussian.py new file mode 100644 index 0000000..29ff365 --- /dev/null +++ b/trellis/trainers/vae/structured_latent_vae_gaussian.py @@ -0,0 +1,275 @@ +from typing import * +import copy +import torch +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict +import utils3d.torch + +from ..basic import BasicTrainer +from ...representations import Gaussian +from ...renderers import GaussianRenderer +from ...modules.sparse import SparseTensor +from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips + + +class SLatVaeGaussianTrainer(BasicTrainer): + """ + Trainer for structured latent VAE. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. 
+ - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. Can be 'l1', 'l2' + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. + lambda_kl (float): KL loss weight. + regularizations (dict): Regularization config. + """ + + def __init__( + self, + *args, + loss_type: str = 'l1', + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + lambda_kl: float = 1e-6, + regularizations: Dict = {}, + **kwargs + ): + super().__init__(*args, **kwargs) + self.loss_type = loss_type + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + self.lambda_kl = lambda_kl + self.regularizations = regularizations + + self._init_renderer() + + def _init_renderer(self): + rendering_options = {"near" : 0.8, + "far" : 1.6, + "bg_color" : 'random'} + self.renderer = GaussianRenderer(rendering_options) + self.renderer.pipe.kernel_size = self.models['decoder'].rep_config['2d_filter_kernel_size'] + + def _render_batch(self, reps: List[Gaussian], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor: + """ + Render a batch of representations. + + Args: + reps: The dictionary of lists of representations. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + """ + ret = None + for i, representation in enumerate(reps): + render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i]) + if ret is None: + ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']} + for k, v in render_pack.items(): + ret[k].append(v) + ret['bg_color'].append(self.renderer.bg_color) + for k, v in ret.items(): + ret[k] = torch.stack(v, dim=0) + return ret + + @torch.no_grad() + def _get_status(self, z: SparseTensor, reps: List[Gaussian]) -> Dict: + xyz = torch.cat([g.get_xyz for g in reps], dim=0) + xyz_base = (z.coords[:, 1:].float() + 0.5) / self.models['decoder'].resolution - 0.5 + offset = xyz - xyz_base.unsqueeze(1).expand(-1, self.models['decoder'].rep_config['num_gaussians'], -1).reshape(-1, 3) + status = { + 'xyz': xyz, + 'offset': offset, + 'scale': torch.cat([g.get_scaling for g in reps], dim=0), + 'opacity': torch.cat([g.get_opacity for g in reps], dim=0), + } + + for k in list(status.keys()): + status[k] = { + 'mean': status[k].mean().item(), + 'max': status[k].max().item(), + 'min': status[k].min().item(), + } + + return status + + def _get_regularization_loss(self, reps: List[Gaussian]) -> Tuple[torch.Tensor, Dict]: + loss = 0.0 + terms = {} + if 'lambda_vol' in self.regularizations: + scales = torch.cat([g.get_scaling for g in reps], dim=0) # [N x 3] + volume = torch.prod(scales, dim=1) # [N] + terms[f'reg_vol'] = volume.mean() + loss = loss + self.regularizations['lambda_vol'] * terms[f'reg_vol'] + if 'lambda_opacity' in self.regularizations: + opacity = torch.cat([g.get_opacity for g in reps], dim=0) + terms[f'reg_opacity'] = (opacity - 1).pow(2).mean() + loss = loss + self.regularizations['lambda_opacity'] * terms[f'reg_opacity'] + return loss, terms + + def training_losses( + self, + feats: SparseTensor, + image: torch.Tensor, + alpha: torch.Tensor, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_aux: bool = False, + **kwargs + ) -> 
Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + feats: The [N x * x C] sparse tensor of features. + image: The [N x 3 x H x W] tensor of images. + alpha: The [N x H x W] tensor of alpha channels. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + return_aux: Whether to return auxiliary information. + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + """ + z, mean, logvar = self.training_models['encoder'](feats, sample_posterior=True, return_raw=True) + reps = self.training_models['decoder'](z) + self.renderer.rendering_options.resolution = image.shape[-1] + render_results = self._render_batch(reps, extrinsics, intrinsics) + + terms = edict(loss = 0.0, rec = 0.0) + + rec_image = render_results['color'] + gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None] + + if self.loss_type == 'l1': + terms["l1"] = l1_loss(rec_image, gt_image) + terms["rec"] = terms["rec"] + terms["l1"] + elif self.loss_type == 'l2': + terms["l2"] = l2_loss(rec_image, gt_image) + terms["rec"] = terms["rec"] + terms["l2"] + else: + raise ValueError(f"Invalid loss type: {self.loss_type}") + if self.lambda_ssim > 0: + terms["ssim"] = 1 - ssim(rec_image, gt_image) + terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"] + if self.lambda_lpips > 0: + terms["lpips"] = lpips(rec_image, gt_image) + terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"] + terms["loss"] = terms["loss"] + terms["rec"] + + terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1) + terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"] + + reg_loss, reg_terms = self._get_regularization_loss(reps) + terms.update(reg_terms) + terms["loss"] = terms["loss"] + reg_loss + + status = self._get_status(z, reps) + + if return_aux: + return terms, status, {'rec_image': rec_image, 'gt_image': gt_image} + return terms, status + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + ret_dict = {} + gt_images = [] + exts = [] + ints = [] + reps = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch].cuda() for k, v in data.items()} + gt_images.append(args['image'] * args['alpha'][:, None]) + exts.append(args['extrinsics']) + ints.append(args['intrinsics']) + z = self.models['encoder'](args['feats'], sample_posterior=True, return_raw=False) + reps.extend(self.models['decoder'](z)) + gt_images = torch.cat(gt_images, dim=0) + ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}}) + + # render single view + exts = torch.cat(exts, dim=0) + ints = torch.cat(ints, dim=0) + self.renderer.rendering_options.bg_color = (0, 0, 0) + self.renderer.rendering_options.resolution = gt_images.shape[-1] + render_results = self._render_batch(reps, exts, ints) + ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}}) + + # render multiview + self.renderer.rendering_options.resolution = 512 + ## Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset 
for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + ## render each view + miltiview_images = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1) + intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1) + render_results = self._render_batch(reps, extrinsics, intrinsics) + miltiview_images.append(render_results['color']) + + ## Concatenate views + miltiview_images = torch.cat([ + torch.cat(miltiview_images[:2], dim=-2), + torch.cat(miltiview_images[2:], dim=-2), + ], dim=-1) + ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}}) + + self.renderer.rendering_options.bg_color = 'random' + + return ret_dict diff --git a/trellis/trainers/vae/structured_latent_vae_mesh_dec.py b/trellis/trainers/vae/structured_latent_vae_mesh_dec.py new file mode 100644 index 0000000..f3c9a6b --- /dev/null +++ b/trellis/trainers/vae/structured_latent_vae_mesh_dec.py @@ -0,0 +1,382 @@ +from typing import * +import copy +import torch +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict +import utils3d.torch + +from ..basic import BasicTrainer +from ...representations import MeshExtractResult +from ...renderers import MeshRenderer +from ...modules.sparse import SparseTensor +from ...utils.loss_utils import l1_loss, smooth_l1_loss, ssim, lpips +from ...utils.data_utils import recursive_to_device + + +class SLatVaeMeshDecoderTrainer(BasicTrainer): + """ + Trainer for structured latent VAE Mesh Decoder. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. + step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. Can be 'l1', 'l2' + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. 
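+        lambda_depth (float): Depth loss weight.
+        lambda_tsdf (float): TSDF regularization loss weight.
+        lambda_color (float): Color loss weight; color supervision is used only when this is > 0.
+        depth_loss_type (str): Depth loss type. Can be 'l1', 'smooth_l1'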
+ """ + + def __init__( + self, + *args, + depth_loss_type: str = 'l1', + lambda_depth: int = 1, + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + lambda_tsdf: float = 0.01, + lambda_color: float = 0.1, + **kwargs + ): + super().__init__(*args, **kwargs) + self.depth_loss_type = depth_loss_type + self.lambda_depth = lambda_depth + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + self.lambda_tsdf = lambda_tsdf + self.lambda_color = lambda_color + self.use_color = self.lambda_color > 0 + + self._init_renderer() + + def _init_renderer(self): + rendering_options = {"near" : 1, + "far" : 3} + self.renderer = MeshRenderer(rendering_options, device=self.device) + + def _render_batch(self, reps: List[MeshExtractResult], extrinsics: torch.Tensor, intrinsics: torch.Tensor, + return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]: + """ + Render a batch of representations. + + Args: + reps: The dictionary of lists of representations. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color'] + + Returns: + a dict with + reg_loss : [N] tensor of regularization losses + mask : [N x 1 x H x W] tensor of rendered masks + normal : [N x 3 x H x W] tensor of rendered normals + depth : [N x 1 x H x W] tensor of rendered depths + """ + ret = {k : [] for k in return_types} + for i, rep in enumerate(reps): + out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types) + for k in out_dict: + ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k]) + for k in ret: + ret[k] = torch.stack(ret[k]) + return ret + + @staticmethod + def _tsdf_reg_loss(rep: MeshExtractResult, depth_map: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor: + # Calculate tsdf + with torch.no_grad(): + # Project points to camera and calculate pseudo-sdf as difference between gt depth and projected depth + projected_pts, pts_depth = utils3d.torch.project_cv(extrinsics=extrinsics, intrinsics=intrinsics, points=rep.tsdf_v) + projected_pts = (projected_pts - 0.5) * 2.0 + depth_map_res = depth_map.shape[1] + gt_depth = torch.nn.functional.grid_sample(depth_map.reshape(1, 1, depth_map_res, depth_map_res), + projected_pts.reshape(1, 1, -1, 2), mode='bilinear', padding_mode='border', align_corners=True) + pseudo_sdf = gt_depth.flatten() - pts_depth.flatten() + # Truncate pseudo-sdf + delta = 1 / rep.res * 3.0 + trunc_mask = pseudo_sdf > -delta + + # Loss + gt_tsdf = pseudo_sdf[trunc_mask] + tsdf = rep.tsdf_s.flatten()[trunc_mask] + gt_tsdf = torch.clamp(gt_tsdf, -delta, delta) + return torch.mean((tsdf - gt_tsdf) ** 2) + + def _calc_tsdf_loss(self, reps : list[MeshExtractResult], depth_maps, extrinsics, intrinsics) -> torch.Tensor: + tsdf_loss = 0.0 + for i, rep in enumerate(reps): + tsdf_loss += self._tsdf_reg_loss(rep, depth_maps[i], extrinsics[i], intrinsics[i]) + return tsdf_loss / len(reps) + + @torch.no_grad() + def _flip_normal(self, normal: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor: + """ + Flip normal to align with camera. 
+ """ + normal = normal * 2.0 - 1.0 + R = torch.zeros_like(extrinsics) + R[:, :3, :3] = extrinsics[:, :3, :3] + R[:, 3, 3] = 1.0 + view_dir = utils3d.torch.unproject_cv( + utils3d.torch.image_uv(*normal.shape[-2:], device=self.device).reshape(1, -1, 2), + torch.ones(*normal.shape[-2:], device=self.device).reshape(1, -1), + R, intrinsics + ).reshape(-1, *normal.shape[-2:], 3).permute(0, 3, 1, 2) + unflip = (normal * view_dir).sum(1, keepdim=True) < 0 + normal *= unflip * 2.0 - 1.0 + return (normal + 1.0) / 2.0 + + def _perceptual_loss(self, gt: torch.Tensor, pred: torch.Tensor, name: str) -> Dict[str, torch.Tensor]: + """ + Combination of L1, SSIM, and LPIPS loss. + """ + if gt.shape[1] != 3: + assert gt.shape[-1] == 3 + gt = gt.permute(0, 3, 1, 2) + if pred.shape[1] != 3: + assert pred.shape[-1] == 3 + pred = pred.permute(0, 3, 1, 2) + terms = { + f"{name}_loss" : l1_loss(gt, pred), + f"{name}_loss_ssim" : 1 - ssim(gt, pred), + f"{name}_loss_lpips" : lpips(gt, pred) + } + terms[f"{name}_loss_perceptual"] = terms[f"{name}_loss"] + terms[f"{name}_loss_ssim"] * self.lambda_ssim + terms[f"{name}_loss_lpips"] * self.lambda_lpips + return terms + + def geometry_losses( + self, + reps: List[MeshExtractResult], + mesh: List[Dict], + normal_map: torch.Tensor, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + ): + with torch.no_grad(): + gt_meshes = [] + for i in range(len(reps)): + gt_mesh = MeshExtractResult(mesh[i]['vertices'].to(self.device), mesh[i]['faces'].to(self.device)) + gt_meshes.append(gt_mesh) + target = self._render_batch(gt_meshes, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal']) + target['normal'] = self._flip_normal(target['normal'], extrinsics, intrinsics) + + terms = edict(geo_loss = 0.0) + if self.lambda_tsdf > 0: + tsdf_loss = self._calc_tsdf_loss(reps, target['depth'], extrinsics, intrinsics) + terms['tsdf_loss'] = tsdf_loss + terms['geo_loss'] += tsdf_loss * self.lambda_tsdf + + return_types = ['mask', 'depth', 'normal', 'normal_map'] if self.use_color else ['mask', 'depth', 'normal'] + buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types) + + success_mask = torch.tensor([rep.success for rep in reps], device=self.device) + if success_mask.sum() != 0: + for k, v in buffer.items(): + buffer[k] = v[success_mask] + for k, v in target.items(): + target[k] = v[success_mask] + + terms['mask_loss'] = l1_loss(buffer['mask'], target['mask']) + if self.depth_loss_type == 'l1': + terms['depth_loss'] = l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask']) + elif self.depth_loss_type == 'smooth_l1': + terms['depth_loss'] = smooth_l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'], beta=1.0 / (2 * reps[0].res)) + else: + raise ValueError(f"Unsupported depth loss type: {self.depth_loss_type}") + terms.update(self._perceptual_loss(buffer['normal'] * target['mask'], target['normal'] * target['mask'], 'normal')) + terms['geo_loss'] = terms['geo_loss'] + terms['mask_loss'] + terms['depth_loss'] * self.lambda_depth + terms['normal_loss_perceptual'] + if self.use_color and normal_map is not None: + terms.update(self._perceptual_loss(normal_map[success_mask], buffer['normal_map'], 'normal_map')) + terms['geo_loss'] = terms['geo_loss'] + terms['normal_map_loss_perceptual'] * self.lambda_color + + return terms + + def color_losses(self, reps, image, alpha, extrinsics, intrinsics): + terms = edict(color_loss = torch.tensor(0.0, device=self.device)) + buffer = self._render_batch(reps, extrinsics, 
intrinsics, return_types=['color']) + success_mask = torch.tensor([rep.success for rep in reps], device=self.device) + if success_mask.sum() != 0: + terms.update(self._perceptual_loss((image * alpha[:, None])[success_mask], buffer['color'][success_mask], 'color')) + terms['color_loss'] = terms['color_loss'] + terms['color_loss_perceptual'] * self.lambda_color + return terms + + def training_losses( + self, + latents: SparseTensor, + image: torch.Tensor, + alpha: torch.Tensor, + mesh: List[Dict], + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + normal_map: torch.Tensor = None, + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + latents: The [N x * x C] sparse latents + image: The [N x 3 x H x W] tensor of images. + alpha: The [N x H x W] tensor of alpha channels. + mesh: The list of dictionaries of meshes. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. + """ + reps = self.training_models['decoder'](latents) + self.renderer.rendering_options.resolution = image.shape[-1] + + terms = edict(loss = 0.0, rec = 0.0) + + terms['reg_loss'] = sum([rep.reg_loss for rep in reps]) / len(reps) + terms['loss'] = terms['loss'] + terms['reg_loss'] + + geo_terms = self.geometry_losses(reps, mesh, normal_map, extrinsics, intrinsics) + terms.update(geo_terms) + terms['loss'] = terms['loss'] + terms['geo_loss'] + + if self.use_color: + color_terms = self.color_losses(reps, image, alpha, extrinsics, intrinsics) + terms.update(color_terms) + terms['loss'] = terms['loss'] + terms['color_loss'] + + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + ret_dict = {} + gt_images = [] + gt_normal_maps = [] + gt_meshes = [] + exts = [] + ints = [] + reps = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = recursive_to_device(data, 'cuda') + gt_images.append(args['image'] * args['alpha'][:, None]) + if self.use_color and 'normal_map' in data: + gt_normal_maps.append(args['normal_map']) + gt_meshes.extend(args['mesh']) + exts.append(args['extrinsics']) + ints.append(args['intrinsics']) + reps.extend(self.models['decoder'](args['latents'])) + gt_images = torch.cat(gt_images, dim=0) + ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}}) + if self.use_color and gt_normal_maps: + gt_normal_maps = torch.cat(gt_normal_maps, dim=0) + ret_dict.update({f'gt_normal_map': {'value': gt_normal_maps, 'type': 'image'}}) + + # render single view + exts = torch.cat(exts, dim=0) + ints = torch.cat(ints, dim=0) + self.renderer.rendering_options.bg_color = (0, 0, 0) + self.renderer.rendering_options.resolution = gt_images.shape[-1] + gt_render_results = self._render_batch([ + MeshExtractResult(vertices=mesh['vertices'].to(self.device), faces=mesh['faces'].to(self.device)) + for mesh in gt_meshes + ], exts, ints, return_types=['normal']) + ret_dict.update({f'gt_normal': {'value': self._flip_normal(gt_render_results['normal'], exts, ints), 'type': 'image'}}) + return_types = ['normal'] + if self.use_color: + 
return_types.append('color') + if 'normal_map' in data: + return_types.append('normal_map') + render_results = self._render_batch(reps, exts, ints, return_types=return_types) + ret_dict.update({f'rec_normal': {'value': render_results['normal'], 'type': 'image'}}) + if 'color' in return_types: + ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}}) + if 'normal_map' in return_types: + ret_dict.update({f'rec_normal_map': {'value': render_results['normal_map'], 'type': 'image'}}) + + # render multiview + self.renderer.rendering_options.resolution = 512 + ## Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + ## render each view + multiview_normals = [] + multiview_normal_maps = [] + miltiview_images = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1) + intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1) + render_results = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types) + multiview_normals.append(render_results['normal']) + if 'color' in return_types: + miltiview_images.append(render_results['color']) + if 'normal_map' in return_types: + multiview_normal_maps.append(render_results['normal_map']) + + ## Concatenate views + multiview_normals = torch.cat([ + torch.cat(multiview_normals[:2], dim=-2), + torch.cat(multiview_normals[2:], dim=-2), + ], dim=-1) + ret_dict.update({f'multiview_normal': {'value': multiview_normals, 'type': 'image'}}) + if 'color' in return_types: + miltiview_images = torch.cat([ + torch.cat(miltiview_images[:2], dim=-2), + torch.cat(miltiview_images[2:], dim=-2), + ], dim=-1) + ret_dict.update({f'multiview_image': {'value': miltiview_images, 'type': 'image'}}) + if 'normal_map' in return_types: + multiview_normal_maps = torch.cat([ + torch.cat(multiview_normal_maps[:2], dim=-2), + torch.cat(multiview_normal_maps[2:], dim=-2), + ], dim=-1) + ret_dict.update({f'multiview_normal_map': {'value': multiview_normal_maps, 'type': 'image'}}) + + return ret_dict diff --git a/trellis/trainers/vae/structured_latent_vae_rf_dec.py b/trellis/trainers/vae/structured_latent_vae_rf_dec.py new file mode 100644 index 0000000..6021ea1 --- /dev/null +++ b/trellis/trainers/vae/structured_latent_vae_rf_dec.py @@ -0,0 +1,223 @@ +from typing import * +import copy +import torch +from torch.utils.data import DataLoader +import numpy as np +from easydict import EasyDict as edict +import utils3d.torch + +from ..basic import BasicTrainer +from ...representations import Strivec +from ...renderers import OctreeRenderer +from ...modules.sparse import SparseTensor +from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips + + +class SLatVaeRadianceFieldDecoderTrainer(BasicTrainer): + """ + Trainer for structured latent VAE Radiance Field Decoder. + + Args: + models (dict[str, nn.Module]): Models to train. + dataset (torch.utils.data.Dataset): Dataset. + output_dir (str): Output directory. + load_dir (str): Load directory. 
+ step (int): Step to load. + batch_size (int): Batch size. + batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored. + batch_split (int): Split batch with gradient accumulation. + max_steps (int): Max steps. + optimizer (dict): Optimizer config. + lr_scheduler (dict): Learning rate scheduler config. + elastic (dict): Elastic memory management config. + grad_clip (float or dict): Gradient clip config. + ema_rate (float or list): Exponential moving average rates. + fp16_mode (str): FP16 mode. + - None: No FP16. + - 'inflat_all': Hold a inflated fp32 master param for all params. + - 'amp': Automatic mixed precision. + fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation. + finetune_ckpt (dict): Finetune checkpoint. + log_param_stats (bool): Log parameter stats. + i_print (int): Print interval. + i_log (int): Log interval. + i_sample (int): Sample interval. + i_save (int): Save interval. + i_ddpcheck (int): DDP check interval. + + loss_type (str): Loss type. Can be 'l1', 'l2' + lambda_ssim (float): SSIM loss weight. + lambda_lpips (float): LPIPS loss weight. + """ + + def __init__( + self, + *args, + loss_type: str = 'l1', + lambda_ssim: float = 0.2, + lambda_lpips: float = 0.2, + **kwargs + ): + super().__init__(*args, **kwargs) + self.loss_type = loss_type + self.lambda_ssim = lambda_ssim + self.lambda_lpips = lambda_lpips + + self._init_renderer() + + def _init_renderer(self): + rendering_options = {"near" : 0.8, + "far" : 1.6, + "bg_color" : 'random'} + self.renderer = OctreeRenderer(rendering_options) + self.renderer.pipe.primitive = 'trivec' + + def _render_batch(self, reps: List[Strivec], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor: + """ + Render a batch of representations. + + Args: + reps: The dictionary of lists of representations. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + """ + ret = None + for i, representation in enumerate(reps): + render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i]) + if ret is None: + ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']} + for k, v in render_pack.items(): + ret[k].append(v) + ret['bg_color'].append(self.renderer.bg_color) + for k, v in ret.items(): + ret[k] = torch.stack(v, dim=0) + return ret + + def training_losses( + self, + latents: SparseTensor, + image: torch.Tensor, + alpha: torch.Tensor, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + return_aux: bool = False, + **kwargs + ) -> Tuple[Dict, Dict]: + """ + Compute training losses. + + Args: + latents: The [N x * x C] sparse latents + image: The [N x 3 x H x W] tensor of images. + alpha: The [N x H x W] tensor of alpha channels. + extrinsics: The [N x 4 x 4] tensor of extrinsics. + intrinsics: The [N x 3 x 3] tensor of intrinsics. + return_aux: Whether to return auxiliary information. + + Returns: + a dict with the key "loss" containing a scalar tensor. + may also contain other keys for different terms. 
+ """ + reps = self.training_models['decoder'](latents) + self.renderer.rendering_options.resolution = image.shape[-1] + render_results = self._render_batch(reps, extrinsics, intrinsics) + + terms = edict(loss = 0.0, rec = 0.0) + + rec_image = render_results['color'] + gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None] + + if self.loss_type == 'l1': + terms["l1"] = l1_loss(rec_image, gt_image) + terms["rec"] = terms["rec"] + terms["l1"] + elif self.loss_type == 'l2': + terms["l2"] = l2_loss(rec_image, gt_image) + terms["rec"] = terms["rec"] + terms["l2"] + else: + raise ValueError(f"Invalid loss type: {self.loss_type}") + if self.lambda_ssim > 0: + terms["ssim"] = 1 - ssim(rec_image, gt_image) + terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"] + if self.lambda_lpips > 0: + terms["lpips"] = lpips(rec_image, gt_image) + terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"] + terms["loss"] = terms["loss"] + terms["rec"] + + if return_aux: + return terms, {}, {'rec_image': rec_image, 'gt_image': gt_image} + return terms, {} + + @torch.no_grad() + def run_snapshot( + self, + num_samples: int, + batch_size: int, + verbose: bool = False, + ) -> Dict: + dataloader = DataLoader( + copy.deepcopy(self.dataset), + batch_size=batch_size, + shuffle=True, + num_workers=0, + collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None, + ) + + # inference + ret_dict = {} + gt_images = [] + exts = [] + ints = [] + reps = [] + for i in range(0, num_samples, batch_size): + batch = min(batch_size, num_samples - i) + data = next(iter(dataloader)) + args = {k: v[:batch].cuda() for k, v in data.items()} + gt_images.append(args['image'] * args['alpha'][:, None]) + exts.append(args['extrinsics']) + ints.append(args['intrinsics']) + reps.extend(self.models['decoder'](args['latents'])) + gt_images = torch.cat(gt_images, dim=0) + ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}}) + + # render single view + exts = torch.cat(exts, dim=0) + ints = torch.cat(ints, dim=0) + self.renderer.rendering_options.bg_color = (0, 0, 0) + self.renderer.rendering_options.resolution = gt_images.shape[-1] + render_results = self._render_batch(reps, exts, ints) + ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}}) + + # render multiview + self.renderer.rendering_options.resolution = 512 + ## Build camera + yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2] + yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4) + yaws = [y + yaws_offset for y in yaws] + pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)] + + ## render each view + miltiview_images = [] + for yaw, pitch in zip(yaws, pitch): + orig = torch.tensor([ + np.sin(yaw) * np.cos(pitch), + np.cos(yaw) * np.cos(pitch), + np.sin(pitch), + ]).float().cuda() * 2 + fov = torch.deg2rad(torch.tensor(30)).cuda() + extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda()) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov) + extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1) + intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1) + render_results = self._render_batch(reps, extrinsics, intrinsics) + miltiview_images.append(render_results['color']) + + ## Concatenate views + miltiview_images = torch.cat([ + torch.cat(miltiview_images[:2], dim=-2), + torch.cat(miltiview_images[2:], dim=-2), + ], dim=-1) + 
ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}})
+
+        self.renderer.rendering_options.bg_color = 'random'
+
+        return ret_dict
diff --git a/utils/new_oss_client.py b/utils/new_oss_client.py
new file mode 100644
index 0000000..4817579
--- /dev/null
+++ b/utils/new_oss_client.py
+import io
+import logging
+import mimetypes
+import os
+from io import BytesIO
+import urllib3
+from PIL import Image
+from minio import Minio
+
+logger = logging.getLogger()
+
+MINIO_URL = "www.minio-api.aida.com.hk"
+MINIO_ACCESS = "vXKFLSJkYeEq2DrSZvkB"
+MINIO_SECRET = "uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR"
+MINIO_SECURE = True
+MINIO_BUCKET = "test"
+
+
+class CustomRetry(urllib3.Retry):
+    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
+        # Call the parent class's increment method
+        new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
+        # Log the retry information
+        logger.info(f"Retrying connection: {method} {url}, error: {error}, retry count: {self.total - new_retry.total}")
+        return new_retry
+
+
+http_client = urllib3.PoolManager(
+    num_pools=20,
+    maxsize=50,
+    timeout=urllib3.Timeout(connect=2, read=30),
+    cert_reqs='CERT_REQUIRED',  # require certificate verification
+    retries=CustomRetry(
+        total=5,
+        backoff_factor=0.2,
+        status_forcelist=[500, 502, 503, 504],
+    ),
+)
+
+minio_client = Minio(
+    MINIO_URL,
+    access_key=MINIO_ACCESS,
+    secret_key=MINIO_SECRET,
+    secure=MINIO_SECURE,
+    http_client=http_client
+)
+
+
+def minio_get_image(client, bucket, object_name):
+    response = None
+    try:
+        response = client.get_object(bucket, object_name)
+        # Read the raw bytes directly
+        image_bytes = response.read()
+        image = Image.open(BytesIO(image_bytes)).convert("RGB")
+        return image
+    except Exception as e:
+        logger.error(f"Failed to read image from MinIO: {bucket}/{object_name} | {e}")
+        return None
+    finally:
+        if response:
+            response.close()
+            response.release_conn()
+
+
+def upload_bytes(data_bytes, object_name, content_type):
+    minio_client.put_object(
+        bucket_name=MINIO_BUCKET,
+        object_name=object_name,
+        data=BytesIO(data_bytes),
+        length=len(data_bytes),
+        content_type=content_type
+    )
+    return f"{MINIO_BUCKET}/{object_name}"
+
+
+def upload_local_file(file_path, object_name, content_type="application/octet-stream"):
+    """
+    Upload a file from the local disk to MinIO.
+    :param file_path: local file path (e.g. 'output/sample.obj')
+    :param object_name: storage path/filename inside MinIO
+    :param content_type: MIME type of the file
+    """
+    # Robustness check: make sure the local file actually exists
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"Local file not found: {file_path}")
+
+    # Use fput_object to stream the upload directly from disk instead of loading it into memory
+    minio_client.fput_object(
+        bucket_name=MINIO_BUCKET,
+        object_name=object_name,
+        file_path=file_path,
+        content_type=content_type
+    )
+
+    return f"{MINIO_BUCKET}/{object_name}"
+
+
+def download_from_minio(object_path, local_path):
+    """
+    Download a file from MinIO to the local disk.
+    """
+    from minio.error import S3Error
+
+    local_dir = os.path.dirname(local_path)
+    if local_dir:
+        os.makedirs(local_dir, exist_ok=True)
+
+    path_parts = object_path.split("/", 1)
+    bucket_name = path_parts[0]
+    object_name = path_parts[1]
+    try:
+        minio_client.fget_object(
+            bucket_name=bucket_name,
+            object_name=object_name,
+            file_path=local_path
+        )
+        if os.path.exists(local_path):
+            return local_path
+        else:
+            raise RuntimeError("Local path does not exist after download")
+    except S3Error as err:
+        if err.code == 'NoSuchKey':
+            raise FileNotFoundError(f"MinIO object does not exist: {bucket_name}/{object_name}")
+        raise
+    except Exception as e:
+        raise RuntimeError(f"MinIO download failed: {str(e)}")
+
+
+if __name__ == '__main__':
+    url = "fida-test/furniture/sketches/4449a66d-6267-43f7-86a2-1e42bd19ec61.png"
+    img = minio_get_image(client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:])
+    if img is not None:
+        img.show()
+        img.save("result.png")