1
This commit is contained in:
92
dataset_toolkits/datasets/ObjaverseXL.py
Normal file
92
dataset_toolkits/datasets/ObjaverseXL.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pandas as pd
|
||||||
|
import objaverse.xl as oxl
|
||||||
|
from utils import get_file_hash
|
||||||
|
|
||||||
|
|
||||||
|
def add_args(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument('--source', type=str, default='sketchfab',
|
||||||
|
help='Data source to download annotations from (github, sketchfab)')
|
||||||
|
|
||||||
|
|
||||||
|
def get_metadata(source, **kwargs):
|
||||||
|
if source == 'sketchfab':
|
||||||
|
metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_sketchfab.csv")
|
||||||
|
elif source == 'github':
|
||||||
|
metadata = pd.read_csv("hf://datasets/JeffreyXiang/TRELLIS-500K/ObjaverseXL_github.csv")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid source: {source}")
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def download(metadata, output_dir, **kwargs):
|
||||||
|
os.makedirs(os.path.join(output_dir, 'raw'), exist_ok=True)
|
||||||
|
|
||||||
|
# download annotations
|
||||||
|
annotations = oxl.get_annotations()
|
||||||
|
annotations = annotations[annotations['sha256'].isin(metadata['sha256'].values)]
|
||||||
|
|
||||||
|
# download and render objects
|
||||||
|
file_paths = oxl.download_objects(
|
||||||
|
annotations,
|
||||||
|
download_dir=os.path.join(output_dir, "raw"),
|
||||||
|
save_repo_format="zip",
|
||||||
|
)
|
||||||
|
|
||||||
|
downloaded = {}
|
||||||
|
metadata = metadata.set_index("file_identifier")
|
||||||
|
for k, v in file_paths.items():
|
||||||
|
sha256 = metadata.loc[k, "sha256"]
|
||||||
|
downloaded[sha256] = os.path.relpath(v, output_dir)
|
||||||
|
|
||||||
|
return pd.DataFrame(downloaded.items(), columns=['sha256', 'local_path'])
|
||||||
|
|
||||||
|
|
||||||
|
def foreach_instance(metadata, output_dir, func, max_workers=None, desc='Processing objects') -> pd.DataFrame:
|
||||||
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from tqdm import tqdm
|
||||||
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
# load metadata
|
||||||
|
metadata = metadata.to_dict('records')
|
||||||
|
|
||||||
|
# processing objects
|
||||||
|
records = []
|
||||||
|
max_workers = max_workers or os.cpu_count()
|
||||||
|
try:
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor, \
|
||||||
|
tqdm(total=len(metadata), desc=desc) as pbar:
|
||||||
|
def worker(metadatum):
|
||||||
|
try:
|
||||||
|
local_path = metadatum['local_path']
|
||||||
|
sha256 = metadatum['sha256']
|
||||||
|
if local_path.startswith('raw/github/repos/'):
|
||||||
|
path_parts = local_path.split('/')
|
||||||
|
file_name = os.path.join(*path_parts[5:])
|
||||||
|
zip_file = os.path.join(output_dir, *path_parts[:5])
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
||||||
|
zip_ref.extractall(tmp_dir)
|
||||||
|
file = os.path.join(tmp_dir, file_name)
|
||||||
|
record = func(file, sha256)
|
||||||
|
else:
|
||||||
|
file = os.path.join(output_dir, local_path)
|
||||||
|
record = func(file, sha256)
|
||||||
|
if record is not None:
|
||||||
|
records.append(record)
|
||||||
|
pbar.update()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing object {sha256}: {e}")
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
executor.map(worker, metadata)
|
||||||
|
executor.shutdown(wait=True)
|
||||||
|
except:
|
||||||
|
print("Error happened during processing.")
|
||||||
|
|
||||||
|
return pd.DataFrame.from_records(records)
|
||||||
66
dataset_toolkits/stat_latent.py
Normal file
66
dataset_toolkits/stat_latent.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--output_dir', type=str, required=True,
|
||||||
|
help='Directory to save the metadata')
|
||||||
|
parser.add_argument('--filter_low_aesthetic_score', type=float, default=None,
|
||||||
|
help='Filter objects with aesthetic score lower than this value')
|
||||||
|
parser.add_argument('--model', type=str, default='dinov2_vitl14_reg_slat_enc_swin8_B_64l8_fp16',
|
||||||
|
help='Latent model to use')
|
||||||
|
parser.add_argument('--num_samples', type=int, default=50000,
|
||||||
|
help='Number of samples to use for calculating stats')
|
||||||
|
opt = parser.parse_args()
|
||||||
|
opt = edict(vars(opt))
|
||||||
|
|
||||||
|
# get file list
|
||||||
|
if os.path.exists(os.path.join(opt.output_dir, 'metadata.csv')):
|
||||||
|
metadata = pd.read_csv(os.path.join(opt.output_dir, 'metadata.csv'))
|
||||||
|
else:
|
||||||
|
raise ValueError('metadata.csv not found')
|
||||||
|
if opt.filter_low_aesthetic_score is not None:
|
||||||
|
metadata = metadata[metadata['aesthetic_score'] >= opt.filter_low_aesthetic_score]
|
||||||
|
metadata = metadata[metadata[f'latent_{opt.model}'] == True]
|
||||||
|
sha256s = metadata['sha256'].values
|
||||||
|
sha256s = np.random.choice(sha256s, min(opt.num_samples, len(sha256s)), replace=False)
|
||||||
|
|
||||||
|
# stats
|
||||||
|
means = []
|
||||||
|
mean2s = []
|
||||||
|
with ThreadPoolExecutor(max_workers=16) as executor, \
|
||||||
|
tqdm(total=len(sha256s), desc="Extracting features") as pbar:
|
||||||
|
def worker(sha256):
|
||||||
|
try:
|
||||||
|
feats = np.load(os.path.join(opt.output_dir, 'latents', opt.model, f'{sha256}.npz'))
|
||||||
|
feats = feats['feats']
|
||||||
|
means.append(feats.mean(axis=0))
|
||||||
|
mean2s.append((feats ** 2).mean(axis=0))
|
||||||
|
pbar.update()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error extracting features for {sha256}: {e}")
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
executor.map(worker, sha256s)
|
||||||
|
executor.shutdown(wait=True)
|
||||||
|
|
||||||
|
mean = np.array(means).mean(axis=0)
|
||||||
|
mean2 = np.array(mean2s).mean(axis=0)
|
||||||
|
std = np.sqrt(mean2 - mean ** 2)
|
||||||
|
|
||||||
|
print('mean:', mean)
|
||||||
|
print('std:', std)
|
||||||
|
|
||||||
|
with open(os.path.join(opt.output_dir, 'latents', opt.model, 'stats.json'), 'w') as f:
|
||||||
|
json.dump({
|
||||||
|
'mean': mean.tolist(),
|
||||||
|
'std': std.tolist(),
|
||||||
|
}, f, indent=4)
|
||||||
|
|
||||||
112
multi_image_to_3D.py
Normal file
112
multi_image_to_3D.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#-*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import imageio
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from trellis.pipelines import TrellisImageTo3DPipeline
|
||||||
|
from trellis.utils import render_utils, postprocessing_utils
|
||||||
|
|
||||||
|
def build_parser():
|
||||||
|
p = argparse.ArgumentParser("TRELLIS CLI: multi-image -> 3D (video + optional GLB/PLY)")
|
||||||
|
|
||||||
|
p.add_argument(
|
||||||
|
"-i", "--images",
|
||||||
|
nargs="+",
|
||||||
|
required=True,
|
||||||
|
help="Input image paths (space-separated), e.g. -i a.png b.png c.png"
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"-o", "--out_dir",
|
||||||
|
default="trellis_out_multi",
|
||||||
|
help="Output directory.")
|
||||||
|
p.add_argument("--seed", type=int, default=1)
|
||||||
|
|
||||||
|
p.add_argument("--steps_sparse", type=int, default=12)
|
||||||
|
p.add_argument("--cfg_sparse", type=float, default=7.5)
|
||||||
|
p.add_argument("--steps_slat", type=int, default=12)
|
||||||
|
p.add_argument("--cfg_slat", type=float, default=3.0)
|
||||||
|
|
||||||
|
# Render video (default True)
|
||||||
|
p.add_argument("--save_video", dest="save_video", action="store_true", default=True)
|
||||||
|
p.add_argument("--no-save_video", dest="save_video", action="store_false")
|
||||||
|
p.add_argument("--video_name", type=str, default="sample_multi.mp4")
|
||||||
|
p.add_argument("--fps", type=int, default=30)
|
||||||
|
|
||||||
|
# Export GLB (default True)
|
||||||
|
p.add_argument("--export_glb", dest="export_glb", action="store_true", default=True)
|
||||||
|
p.add_argument("--no-export_glb", dest="export_glb", action="store_false")
|
||||||
|
p.add_argument("--glb_name", type=str, default="sample_multi.glb")
|
||||||
|
p.add_argument("--texture_size", type=int, default=1024)
|
||||||
|
p.add_argument("--simplify", type=float, default=0.95)
|
||||||
|
|
||||||
|
# Save PLY (default True)
|
||||||
|
p.add_argument("--save_ply", dest="save_ply", action="store_true", default=True)
|
||||||
|
p.add_argument("--no-save_ply", dest="save_ply", action="store_false")
|
||||||
|
p.add_argument("--ply_name", type=str, default="sample_multi.ply")
|
||||||
|
|
||||||
|
# Env passthrough (optional)
|
||||||
|
p.add_argument("--spconv_algo", type=str, default="native", choices=["native", "auto"])
|
||||||
|
# p.add_argument("--attn_backend", type=str, default="", choices=["", "flash-attn", "xformers"])
|
||||||
|
|
||||||
|
return p
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = build_parser().parse_args()
|
||||||
|
os.makedirs(args.out_dir, exist_ok=True)
|
||||||
|
|
||||||
|
os.environ["SPCONV_ALGO"] = args.spconv_algo
|
||||||
|
# if args.attn_backend:
|
||||||
|
# os.environ["ATTN_BACKEND"] = args.attn_backend
|
||||||
|
|
||||||
|
pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
|
||||||
|
pipeline.cuda()
|
||||||
|
|
||||||
|
images = [Image.open(p) for p in args.images]
|
||||||
|
|
||||||
|
outputs = pipeline.run_multi_image(
|
||||||
|
images,
|
||||||
|
seed=args.seed,
|
||||||
|
sparse_structure_sampler_params={
|
||||||
|
"steps": args.steps_sparse,
|
||||||
|
"cfg_strength": args.cfg_sparse,
|
||||||
|
},
|
||||||
|
slat_sampler_params={
|
||||||
|
"steps": args.steps_slat,
|
||||||
|
"cfg_strength": args.cfg_slat,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# For multi-image, TRELLIS still returns list-like outputs; export the first asset by default.
|
||||||
|
gs = outputs["gaussian"][0]
|
||||||
|
mesh = outputs["mesh"][0]
|
||||||
|
|
||||||
|
if args.save_video:
|
||||||
|
video_gs = render_utils.render_video(gs)["color"]
|
||||||
|
video_mesh = render_utils.render_video(mesh)["normal"]
|
||||||
|
video = [np.concatenate([fg, fm], axis=1) for fg, fm in zip(video_gs, video_mesh)]
|
||||||
|
out_mp4 = os.path.join(args.out_dir, args.video_name)
|
||||||
|
imageio.mimsave(out_mp4, video, fps=args.fps)
|
||||||
|
print(f"[ok] saved video: {out_mp4}")
|
||||||
|
|
||||||
|
if args.export_glb:
|
||||||
|
out_glb = os.path.join(args.out_dir, args.glb_name)
|
||||||
|
glb = postprocessing_utils.to_glb(
|
||||||
|
gs,
|
||||||
|
mesh,
|
||||||
|
simplify=args.simplify,
|
||||||
|
texture_size=args.texture_size,
|
||||||
|
)
|
||||||
|
glb.export(out_glb)
|
||||||
|
print(f"[ok] exported glb: {out_glb}")
|
||||||
|
|
||||||
|
if args.save_ply:
|
||||||
|
out_ply = os.path.join(args.out_dir, args.ply_name)
|
||||||
|
gs.save_ply(out_ply)
|
||||||
|
print(f"[ok] saved ply: {out_ply}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
217
trellis/datasets/structured_latent.py
Normal file
217
trellis/datasets/structured_latent.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import *
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import utils3d.torch
|
||||||
|
from .components import StandardDatasetBase, TextConditionedMixin, ImageConditionedMixin
|
||||||
|
from ..modules.sparse.basic import SparseTensor
|
||||||
|
from .. import models
|
||||||
|
from ..utils.render_utils import get_renderer
|
||||||
|
from ..utils.data_utils import load_balanced_group_indices
|
||||||
|
|
||||||
|
|
||||||
|
class SLatVisMixin:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
pretrained_slat_dec: str = 'microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
|
||||||
|
slat_dec_path: Optional[str] = None,
|
||||||
|
slat_dec_ckpt: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.slat_dec = None
|
||||||
|
self.pretrained_slat_dec = pretrained_slat_dec
|
||||||
|
self.slat_dec_path = slat_dec_path
|
||||||
|
self.slat_dec_ckpt = slat_dec_ckpt
|
||||||
|
|
||||||
|
def _loading_slat_dec(self):
|
||||||
|
if self.slat_dec is not None:
|
||||||
|
return
|
||||||
|
if self.slat_dec_path is not None:
|
||||||
|
cfg = json.load(open(os.path.join(self.slat_dec_path, 'config.json'), 'r'))
|
||||||
|
decoder = getattr(models, cfg['models']['decoder']['name'])(**cfg['models']['decoder']['args'])
|
||||||
|
ckpt_path = os.path.join(self.slat_dec_path, 'ckpts', f'decoder_{self.slat_dec_ckpt}.pt')
|
||||||
|
decoder.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
|
||||||
|
else:
|
||||||
|
decoder = models.from_pretrained(self.pretrained_slat_dec)
|
||||||
|
self.slat_dec = decoder.cuda().eval()
|
||||||
|
|
||||||
|
def _delete_slat_dec(self):
|
||||||
|
del self.slat_dec
|
||||||
|
self.slat_dec = None
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode_latent(self, z, batch_size=4):
|
||||||
|
self._loading_slat_dec()
|
||||||
|
reps = []
|
||||||
|
if self.normalization is not None:
|
||||||
|
z = z * self.std.to(z.device) + self.mean.to(z.device)
|
||||||
|
for i in range(0, z.shape[0], batch_size):
|
||||||
|
reps.append(self.slat_dec(z[i:i+batch_size]))
|
||||||
|
reps = sum(reps, [])
|
||||||
|
self._delete_slat_dec()
|
||||||
|
return reps
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def visualize_sample(self, x_0: Union[SparseTensor, dict]):
|
||||||
|
x_0 = x_0 if isinstance(x_0, SparseTensor) else x_0['x_0']
|
||||||
|
reps = self.decode_latent(x_0.cuda())
|
||||||
|
|
||||||
|
# Build camera
|
||||||
|
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
||||||
|
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
||||||
|
yaws = [y + yaws_offset for y in yaws]
|
||||||
|
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
||||||
|
|
||||||
|
exts = []
|
||||||
|
ints = []
|
||||||
|
for yaw, pitch in zip(yaws, pitch):
|
||||||
|
orig = torch.tensor([
|
||||||
|
np.sin(yaw) * np.cos(pitch),
|
||||||
|
np.cos(yaw) * np.cos(pitch),
|
||||||
|
np.sin(pitch),
|
||||||
|
]).float().cuda() * 2
|
||||||
|
fov = torch.deg2rad(torch.tensor(40)).cuda()
|
||||||
|
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
||||||
|
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
||||||
|
exts.append(extrinsics)
|
||||||
|
ints.append(intrinsics)
|
||||||
|
|
||||||
|
renderer = get_renderer(reps[0])
|
||||||
|
images = []
|
||||||
|
for representation in reps:
|
||||||
|
image = torch.zeros(3, 1024, 1024).cuda()
|
||||||
|
tile = [2, 2]
|
||||||
|
for j, (ext, intr) in enumerate(zip(exts, ints)):
|
||||||
|
res = renderer.render(representation, ext, intr)
|
||||||
|
image[:, 512 * (j // tile[1]):512 * (j // tile[1] + 1), 512 * (j % tile[1]):512 * (j % tile[1] + 1)] = res['color']
|
||||||
|
images.append(image)
|
||||||
|
images = torch.stack(images)
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
class SLat(SLatVisMixin, StandardDatasetBase):
|
||||||
|
"""
|
||||||
|
structured latent dataset
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots (str): path to the dataset
|
||||||
|
latent_model (str): name of the latent model
|
||||||
|
min_aesthetic_score (float): minimum aesthetic score
|
||||||
|
max_num_voxels (int): maximum number of voxels
|
||||||
|
normalization (dict): normalization stats
|
||||||
|
pretrained_slat_dec (str): name of the pretrained slat decoder
|
||||||
|
slat_dec_path (str): path to the slat decoder, if given, will override the pretrained_slat_dec
|
||||||
|
slat_dec_ckpt (str): name of the slat decoder checkpoint
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
roots: str,
|
||||||
|
*,
|
||||||
|
latent_model: str,
|
||||||
|
min_aesthetic_score: float = 5.0,
|
||||||
|
max_num_voxels: int = 32768,
|
||||||
|
normalization: Optional[dict] = None,
|
||||||
|
pretrained_slat_dec: str = 'microsoft/TRELLIS-image-large/ckpts/slat_dec_gs_swin8_B_64l8gs32_fp16',
|
||||||
|
slat_dec_path: Optional[str] = None,
|
||||||
|
slat_dec_ckpt: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.normalization = normalization
|
||||||
|
self.latent_model = latent_model
|
||||||
|
self.min_aesthetic_score = min_aesthetic_score
|
||||||
|
self.max_num_voxels = max_num_voxels
|
||||||
|
self.value_range = (0, 1)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
roots,
|
||||||
|
pretrained_slat_dec=pretrained_slat_dec,
|
||||||
|
slat_dec_path=slat_dec_path,
|
||||||
|
slat_dec_ckpt=slat_dec_ckpt,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.loads = [self.metadata.loc[sha256, 'num_voxels'] for _, sha256 in self.instances]
|
||||||
|
|
||||||
|
if self.normalization is not None:
|
||||||
|
self.mean = torch.tensor(self.normalization['mean']).reshape(1, -1)
|
||||||
|
self.std = torch.tensor(self.normalization['std']).reshape(1, -1)
|
||||||
|
|
||||||
|
def filter_metadata(self, metadata):
|
||||||
|
stats = {}
|
||||||
|
metadata = metadata[metadata[f'latent_{self.latent_model}']]
|
||||||
|
stats['With latent'] = len(metadata)
|
||||||
|
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
||||||
|
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
||||||
|
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
|
||||||
|
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
|
||||||
|
return metadata, stats
|
||||||
|
|
||||||
|
def get_instance(self, root, instance):
|
||||||
|
data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
|
||||||
|
coords = torch.tensor(data['coords']).int()
|
||||||
|
feats = torch.tensor(data['feats']).float()
|
||||||
|
if self.normalization is not None:
|
||||||
|
feats = (feats - self.mean) / self.std
|
||||||
|
return {
|
||||||
|
'coords': coords,
|
||||||
|
'feats': feats,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collate_fn(batch, split_size=None):
|
||||||
|
if split_size is None:
|
||||||
|
group_idx = [list(range(len(batch)))]
|
||||||
|
else:
|
||||||
|
group_idx = load_balanced_group_indices([b['coords'].shape[0] for b in batch], split_size)
|
||||||
|
packs = []
|
||||||
|
for group in group_idx:
|
||||||
|
sub_batch = [batch[i] for i in group]
|
||||||
|
pack = {}
|
||||||
|
coords = []
|
||||||
|
feats = []
|
||||||
|
layout = []
|
||||||
|
start = 0
|
||||||
|
for i, b in enumerate(sub_batch):
|
||||||
|
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
|
||||||
|
feats.append(b['feats'])
|
||||||
|
layout.append(slice(start, start + b['coords'].shape[0]))
|
||||||
|
start += b['coords'].shape[0]
|
||||||
|
coords = torch.cat(coords)
|
||||||
|
feats = torch.cat(feats)
|
||||||
|
pack['x_0'] = SparseTensor(
|
||||||
|
coords=coords,
|
||||||
|
feats=feats,
|
||||||
|
)
|
||||||
|
pack['x_0']._shape = torch.Size([len(group), *sub_batch[0]['feats'].shape[1:]])
|
||||||
|
pack['x_0'].register_spatial_cache('layout', layout)
|
||||||
|
|
||||||
|
# collate other data
|
||||||
|
keys = [k for k in sub_batch[0].keys() if k not in ['coords', 'feats']]
|
||||||
|
for k in keys:
|
||||||
|
if isinstance(sub_batch[0][k], torch.Tensor):
|
||||||
|
pack[k] = torch.stack([b[k] for b in sub_batch])
|
||||||
|
elif isinstance(sub_batch[0][k], list):
|
||||||
|
pack[k] = sum([b[k] for b in sub_batch], [])
|
||||||
|
else:
|
||||||
|
pack[k] = [b[k] for b in sub_batch]
|
||||||
|
|
||||||
|
packs.append(pack)
|
||||||
|
|
||||||
|
if split_size is None:
|
||||||
|
return packs[0]
|
||||||
|
return packs
|
||||||
|
|
||||||
|
|
||||||
|
class TextConditionedSLat(TextConditionedMixin, SLat):
|
||||||
|
"""
|
||||||
|
Text conditioned structured latent dataset
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ImageConditionedSLat(ImageConditionedMixin, SLat):
|
||||||
|
"""
|
||||||
|
Image conditioned structured latent dataset
|
||||||
|
"""
|
||||||
|
pass
|
||||||
160
trellis/datasets/structured_latent2render.py
Normal file
160
trellis/datasets/structured_latent2render.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import utils3d.torch
|
||||||
|
from ..modules.sparse.basic import SparseTensor
|
||||||
|
from .components import StandardDatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
class SLat2Render(StandardDatasetBase):
|
||||||
|
"""
|
||||||
|
Dataset for Structured Latent and rendered images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots (str): paths to the dataset
|
||||||
|
image_size (int): size of the image
|
||||||
|
latent_model (str): latent model name
|
||||||
|
min_aesthetic_score (float): minimum aesthetic score
|
||||||
|
max_num_voxels (int): maximum number of voxels
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
roots: str,
|
||||||
|
image_size: int,
|
||||||
|
latent_model: str,
|
||||||
|
min_aesthetic_score: float = 5.0,
|
||||||
|
max_num_voxels: int = 32768,
|
||||||
|
):
|
||||||
|
self.image_size = image_size
|
||||||
|
self.latent_model = latent_model
|
||||||
|
self.min_aesthetic_score = min_aesthetic_score
|
||||||
|
self.max_num_voxels = max_num_voxels
|
||||||
|
self.value_range = (0, 1)
|
||||||
|
|
||||||
|
super().__init__(roots)
|
||||||
|
|
||||||
|
def filter_metadata(self, metadata):
|
||||||
|
stats = {}
|
||||||
|
metadata = metadata[metadata[f'latent_{self.latent_model}']]
|
||||||
|
stats['With latent'] = len(metadata)
|
||||||
|
metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
|
||||||
|
stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
|
||||||
|
metadata = metadata[metadata['num_voxels'] <= self.max_num_voxels]
|
||||||
|
stats[f'Num voxels <= {self.max_num_voxels}'] = len(metadata)
|
||||||
|
return metadata, stats
|
||||||
|
|
||||||
|
def _get_image(self, root, instance):
|
||||||
|
with open(os.path.join(root, 'renders', instance, 'transforms.json')) as f:
|
||||||
|
metadata = json.load(f)
|
||||||
|
n_views = len(metadata['frames'])
|
||||||
|
view = np.random.randint(n_views)
|
||||||
|
metadata = metadata['frames'][view]
|
||||||
|
fov = metadata['camera_angle_x']
|
||||||
|
intrinsics = utils3d.torch.intrinsics_from_fov_xy(torch.tensor(fov), torch.tensor(fov))
|
||||||
|
c2w = torch.tensor(metadata['transform_matrix'])
|
||||||
|
c2w[:3, 1:3] *= -1
|
||||||
|
extrinsics = torch.inverse(c2w)
|
||||||
|
|
||||||
|
image_path = os.path.join(root, 'renders', instance, metadata['file_path'])
|
||||||
|
image = Image.open(image_path)
|
||||||
|
alpha = image.getchannel(3)
|
||||||
|
image = image.convert('RGB')
|
||||||
|
image = image.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
||||||
|
alpha = alpha.resize((self.image_size, self.image_size), Image.Resampling.LANCZOS)
|
||||||
|
image = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 255.0
|
||||||
|
alpha = torch.tensor(np.array(alpha)).float() / 255.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
'image': image,
|
||||||
|
'alpha': alpha,
|
||||||
|
'extrinsics': extrinsics,
|
||||||
|
'intrinsics': intrinsics,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_latent(self, root, instance):
|
||||||
|
data = np.load(os.path.join(root, 'latents', self.latent_model, f'{instance}.npz'))
|
||||||
|
coords = torch.tensor(data['coords']).int()
|
||||||
|
feats = torch.tensor(data['feats']).float()
|
||||||
|
return {
|
||||||
|
'coords': coords,
|
||||||
|
'feats': feats,
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def visualize_sample(self, sample: dict):
|
||||||
|
return sample['image']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collate_fn(batch):
|
||||||
|
pack = {}
|
||||||
|
coords = []
|
||||||
|
for i, b in enumerate(batch):
|
||||||
|
coords.append(torch.cat([torch.full((b['coords'].shape[0], 1), i, dtype=torch.int32), b['coords']], dim=-1))
|
||||||
|
coords = torch.cat(coords)
|
||||||
|
feats = torch.cat([b['feats'] for b in batch])
|
||||||
|
pack['latents'] = SparseTensor(
|
||||||
|
coords=coords,
|
||||||
|
feats=feats,
|
||||||
|
)
|
||||||
|
|
||||||
|
# collate other data
|
||||||
|
keys = [k for k in batch[0].keys() if k not in ['coords', 'feats']]
|
||||||
|
for k in keys:
|
||||||
|
if isinstance(batch[0][k], torch.Tensor):
|
||||||
|
pack[k] = torch.stack([b[k] for b in batch])
|
||||||
|
elif isinstance(batch[0][k], list):
|
||||||
|
pack[k] = sum([b[k] for b in batch], [])
|
||||||
|
else:
|
||||||
|
pack[k] = [b[k] for b in batch]
|
||||||
|
|
||||||
|
return pack
|
||||||
|
|
||||||
|
def get_instance(self, root, instance):
|
||||||
|
image = self._get_image(root, instance)
|
||||||
|
latent = self._get_latent(root, instance)
|
||||||
|
return {
|
||||||
|
**image,
|
||||||
|
**latent,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Slat2RenderGeo(SLat2Render):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
roots: str,
|
||||||
|
image_size: int,
|
||||||
|
latent_model: str,
|
||||||
|
min_aesthetic_score: float = 5.0,
|
||||||
|
max_num_voxels: int = 32768,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
roots,
|
||||||
|
image_size,
|
||||||
|
latent_model,
|
||||||
|
min_aesthetic_score,
|
||||||
|
max_num_voxels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_geo(self, root, instance):
|
||||||
|
verts, face = utils3d.io.read_ply(os.path.join(root, 'renders', instance, 'mesh.ply'))
|
||||||
|
mesh = {
|
||||||
|
"vertices" : torch.from_numpy(verts),
|
||||||
|
"faces" : torch.from_numpy(face),
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"mesh" : mesh,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_instance(self, root, instance):
|
||||||
|
image = self._get_image(root, instance)
|
||||||
|
latent = self._get_latent(root, instance)
|
||||||
|
geo = self._get_geo(root, instance)
|
||||||
|
return {
|
||||||
|
**image,
|
||||||
|
**latent,
|
||||||
|
**geo,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
276
trellis/models/structured_latent_flow.py
Normal file
276
trellis/models/structured_latent_flow.py
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
from typing import *
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from ..modules.utils import zero_module, convert_module_to_f16, convert_module_to_f32
|
||||||
|
from ..modules.transformer import AbsolutePositionEmbedder
|
||||||
|
from ..modules.norm import LayerNorm32
|
||||||
|
from ..modules import sparse as sp
|
||||||
|
from ..modules.sparse.transformer import ModulatedSparseTransformerCrossBlock
|
||||||
|
from .sparse_structure_flow import TimestepEmbedder
|
||||||
|
from .sparse_elastic_mixin import SparseTransformerElasticMixin
|
||||||
|
|
||||||
|
|
||||||
|
class SparseResBlock3d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels: int,
|
||||||
|
emb_channels: int,
|
||||||
|
out_channels: Optional[int] = None,
|
||||||
|
downsample: bool = False,
|
||||||
|
upsample: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.emb_channels = emb_channels
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
self.downsample = downsample
|
||||||
|
self.upsample = upsample
|
||||||
|
|
||||||
|
assert not (downsample and upsample), "Cannot downsample and upsample at the same time"
|
||||||
|
|
||||||
|
self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
|
||||||
|
self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.conv1 = sp.SparseConv3d(channels, self.out_channels, 3)
|
||||||
|
self.conv2 = zero_module(sp.SparseConv3d(self.out_channels, self.out_channels, 3))
|
||||||
|
self.emb_layers = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(emb_channels, 2 * self.out_channels, bias=True),
|
||||||
|
)
|
||||||
|
self.skip_connection = sp.SparseLinear(channels, self.out_channels) if channels != self.out_channels else nn.Identity()
|
||||||
|
self.updown = None
|
||||||
|
if self.downsample:
|
||||||
|
self.updown = sp.SparseDownsample(2)
|
||||||
|
elif self.upsample:
|
||||||
|
self.updown = sp.SparseUpsample(2)
|
||||||
|
|
||||||
|
def _updown(self, x: sp.SparseTensor) -> sp.SparseTensor:
|
||||||
|
if self.updown is not None:
|
||||||
|
x = self.updown(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x: sp.SparseTensor, emb: torch.Tensor) -> sp.SparseTensor:
|
||||||
|
emb_out = self.emb_layers(emb).type(x.dtype)
|
||||||
|
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
||||||
|
|
||||||
|
x = self._updown(x)
|
||||||
|
h = x.replace(self.norm1(x.feats))
|
||||||
|
h = h.replace(F.silu(h.feats))
|
||||||
|
h = self.conv1(h)
|
||||||
|
h = h.replace(self.norm2(h.feats)) * (1 + scale) + shift
|
||||||
|
h = h.replace(F.silu(h.feats))
|
||||||
|
h = self.conv2(h)
|
||||||
|
h = h + self.skip_connection(x)
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class SLatFlowModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
resolution: int,
|
||||||
|
in_channels: int,
|
||||||
|
model_channels: int,
|
||||||
|
cond_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_blocks: int,
|
||||||
|
num_heads: Optional[int] = None,
|
||||||
|
num_head_channels: Optional[int] = 64,
|
||||||
|
mlp_ratio: float = 4,
|
||||||
|
patch_size: int = 2,
|
||||||
|
num_io_res_blocks: int = 2,
|
||||||
|
io_block_channels: List[int] = None,
|
||||||
|
pe_mode: Literal["ape", "rope"] = "ape",
|
||||||
|
use_fp16: bool = False,
|
||||||
|
use_checkpoint: bool = False,
|
||||||
|
use_skip_connection: bool = True,
|
||||||
|
share_mod: bool = False,
|
||||||
|
qk_rms_norm: bool = False,
|
||||||
|
qk_rms_norm_cross: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.resolution = resolution
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.model_channels = model_channels
|
||||||
|
self.cond_channels = cond_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.num_blocks = num_blocks
|
||||||
|
self.num_heads = num_heads or model_channels // num_head_channels
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_io_res_blocks = num_io_res_blocks
|
||||||
|
self.io_block_channels = io_block_channels
|
||||||
|
self.pe_mode = pe_mode
|
||||||
|
self.use_fp16 = use_fp16
|
||||||
|
self.use_checkpoint = use_checkpoint
|
||||||
|
self.use_skip_connection = use_skip_connection
|
||||||
|
self.share_mod = share_mod
|
||||||
|
self.qk_rms_norm = qk_rms_norm
|
||||||
|
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||||
|
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||||
|
|
||||||
|
if self.io_block_channels is not None:
|
||||||
|
assert int(np.log2(patch_size)) == np.log2(patch_size), "Patch size must be a power of 2"
|
||||||
|
assert np.log2(patch_size) == len(io_block_channels), "Number of IO ResBlocks must match the number of stages"
|
||||||
|
|
||||||
|
self.t_embedder = TimestepEmbedder(model_channels)
|
||||||
|
if share_mod:
|
||||||
|
self.adaLN_modulation = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(model_channels, 6 * model_channels, bias=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
if pe_mode == "ape":
|
||||||
|
self.pos_embedder = AbsolutePositionEmbedder(model_channels)
|
||||||
|
|
||||||
|
self.input_layer = sp.SparseLinear(in_channels, model_channels if io_block_channels is None else io_block_channels[0])
|
||||||
|
|
||||||
|
self.input_blocks = nn.ModuleList([])
|
||||||
|
if io_block_channels is not None:
|
||||||
|
for chs, next_chs in zip(io_block_channels, io_block_channels[1:] + [model_channels]):
|
||||||
|
self.input_blocks.extend([
|
||||||
|
SparseResBlock3d(
|
||||||
|
chs,
|
||||||
|
model_channels,
|
||||||
|
out_channels=chs,
|
||||||
|
)
|
||||||
|
for _ in range(num_io_res_blocks-1)
|
||||||
|
])
|
||||||
|
self.input_blocks.append(
|
||||||
|
SparseResBlock3d(
|
||||||
|
chs,
|
||||||
|
model_channels,
|
||||||
|
out_channels=next_chs,
|
||||||
|
downsample=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
ModulatedSparseTransformerCrossBlock(
|
||||||
|
model_channels,
|
||||||
|
cond_channels,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
mlp_ratio=self.mlp_ratio,
|
||||||
|
attn_mode='full',
|
||||||
|
use_checkpoint=self.use_checkpoint,
|
||||||
|
use_rope=(pe_mode == "rope"),
|
||||||
|
share_mod=self.share_mod,
|
||||||
|
qk_rms_norm=self.qk_rms_norm,
|
||||||
|
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||||
|
)
|
||||||
|
for _ in range(num_blocks)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.out_blocks = nn.ModuleList([])
|
||||||
|
if io_block_channels is not None:
|
||||||
|
for chs, prev_chs in zip(reversed(io_block_channels), [model_channels] + list(reversed(io_block_channels[1:]))):
|
||||||
|
self.out_blocks.append(
|
||||||
|
SparseResBlock3d(
|
||||||
|
prev_chs * 2 if self.use_skip_connection else prev_chs,
|
||||||
|
model_channels,
|
||||||
|
out_channels=chs,
|
||||||
|
upsample=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.out_blocks.extend([
|
||||||
|
SparseResBlock3d(
|
||||||
|
chs * 2 if self.use_skip_connection else chs,
|
||||||
|
model_channels,
|
||||||
|
out_channels=chs,
|
||||||
|
)
|
||||||
|
for _ in range(num_io_res_blocks-1)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.out_layer = sp.SparseLinear(model_channels if io_block_channels is None else io_block_channels[0], out_channels)
|
||||||
|
|
||||||
|
self.initialize_weights()
|
||||||
|
if use_fp16:
|
||||||
|
self.convert_to_fp16()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
"""
|
||||||
|
Return the device of the model.
|
||||||
|
"""
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def convert_to_fp16(self) -> None:
|
||||||
|
"""
|
||||||
|
Convert the torso of the model to float16.
|
||||||
|
"""
|
||||||
|
self.input_blocks.apply(convert_module_to_f16)
|
||||||
|
self.blocks.apply(convert_module_to_f16)
|
||||||
|
self.out_blocks.apply(convert_module_to_f16)
|
||||||
|
|
||||||
|
def convert_to_fp32(self) -> None:
|
||||||
|
"""
|
||||||
|
Convert the torso of the model to float32.
|
||||||
|
"""
|
||||||
|
self.input_blocks.apply(convert_module_to_f32)
|
||||||
|
self.blocks.apply(convert_module_to_f32)
|
||||||
|
self.out_blocks.apply(convert_module_to_f32)
|
||||||
|
|
||||||
|
def initialize_weights(self) -> None:
|
||||||
|
# Initialize transformer layers:
|
||||||
|
def _basic_init(module):
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(module.weight)
|
||||||
|
if module.bias is not None:
|
||||||
|
nn.init.constant_(module.bias, 0)
|
||||||
|
self.apply(_basic_init)
|
||||||
|
|
||||||
|
# Initialize timestep embedding MLP:
|
||||||
|
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||||
|
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||||
|
|
||||||
|
# Zero-out adaLN modulation layers in DiT blocks:
|
||||||
|
if self.share_mod:
|
||||||
|
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
||||||
|
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
||||||
|
else:
|
||||||
|
for block in self.blocks:
|
||||||
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||||
|
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||||
|
|
||||||
|
# Zero-out output layers:
|
||||||
|
nn.init.constant_(self.out_layer.weight, 0)
|
||||||
|
nn.init.constant_(self.out_layer.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x: sp.SparseTensor, t: torch.Tensor, cond: torch.Tensor) -> sp.SparseTensor:
|
||||||
|
h = self.input_layer(x).type(self.dtype)
|
||||||
|
t_emb = self.t_embedder(t)
|
||||||
|
if self.share_mod:
|
||||||
|
t_emb = self.adaLN_modulation(t_emb)
|
||||||
|
t_emb = t_emb.type(self.dtype)
|
||||||
|
cond = cond.type(self.dtype)
|
||||||
|
|
||||||
|
skips = []
|
||||||
|
# pack with input blocks
|
||||||
|
for block in self.input_blocks:
|
||||||
|
h = block(h, t_emb)
|
||||||
|
skips.append(h.feats)
|
||||||
|
|
||||||
|
if self.pe_mode == "ape":
|
||||||
|
h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype)
|
||||||
|
for block in self.blocks:
|
||||||
|
h = block(h, t_emb, cond)
|
||||||
|
|
||||||
|
# unpack with output blocks
|
||||||
|
for block, skip in zip(self.out_blocks, reversed(skips)):
|
||||||
|
if self.use_skip_connection:
|
||||||
|
h = block(h.replace(torch.cat([h.feats, skip], dim=1)), t_emb)
|
||||||
|
else:
|
||||||
|
h = block(h, t_emb)
|
||||||
|
|
||||||
|
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
||||||
|
h = self.out_layer(h.type(x.dtype))
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class ElasticSLatFlowModel(SparseTransformerElasticMixin, SLatFlowModel):
|
||||||
|
"""
|
||||||
|
SLat Flow Model with elastic memory management.
|
||||||
|
Used for training with low VRAM.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
25
trellis/modules/norm.py
Normal file
25
trellis/modules/norm.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm32(nn.LayerNorm):
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNorm32(nn.GroupNorm):
|
||||||
|
"""
|
||||||
|
A GroupNorm layer that converts to float32 before the forward pass.
|
||||||
|
"""
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelLayerNorm32(LayerNorm32):
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
DIM = x.dim()
|
||||||
|
x = x.permute(0, *range(2, DIM), 1).contiguous()
|
||||||
|
x = super().forward(x)
|
||||||
|
x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous()
|
||||||
|
return x
|
||||||
|
|
||||||
35
trellis/modules/sparse/nonlinearity.py
Normal file
35
trellis/modules/sparse/nonlinearity.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from . import SparseTensor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'SparseReLU',
|
||||||
|
'SparseSiLU',
|
||||||
|
'SparseGELU',
|
||||||
|
'SparseActivation'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SparseReLU(nn.ReLU):
|
||||||
|
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||||
|
return input.replace(super().forward(input.feats))
|
||||||
|
|
||||||
|
|
||||||
|
class SparseSiLU(nn.SiLU):
|
||||||
|
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||||
|
return input.replace(super().forward(input.feats))
|
||||||
|
|
||||||
|
|
||||||
|
class SparseGELU(nn.GELU):
|
||||||
|
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||||
|
return input.replace(super().forward(input.feats))
|
||||||
|
|
||||||
|
|
||||||
|
class SparseActivation(nn.Module):
|
||||||
|
def __init__(self, activation: nn.Module):
|
||||||
|
super().__init__()
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||||
|
return input.replace(self.activation(input.feats))
|
||||||
|
|
||||||
58
trellis/modules/sparse/norm.py
Normal file
58
trellis/modules/sparse/norm.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from . import SparseTensor
|
||||||
|
from . import DEBUG
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'SparseGroupNorm',
|
||||||
|
'SparseLayerNorm',
|
||||||
|
'SparseGroupNorm32',
|
||||||
|
'SparseLayerNorm32',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SparseGroupNorm(nn.GroupNorm):
|
||||||
|
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
|
||||||
|
super(SparseGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
|
||||||
|
|
||||||
|
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||||
|
nfeats = torch.zeros_like(input.feats)
|
||||||
|
for k in range(input.shape[0]):
|
||||||
|
if DEBUG:
|
||||||
|
assert (input.coords[input.layout[k], 0] == k).all(), f"SparseGroupNorm: batch index mismatch"
|
||||||
|
bfeats = input.feats[input.layout[k]]
|
||||||
|
bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
|
||||||
|
bfeats = super().forward(bfeats)
|
||||||
|
bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
|
||||||
|
nfeats[input.layout[k]] = bfeats
|
||||||
|
return input.replace(nfeats)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseLayerNorm(nn.LayerNorm):
|
||||||
|
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
|
||||||
|
super(SparseLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine)
|
||||||
|
|
||||||
|
def forward(self, input: SparseTensor) -> SparseTensor:
|
||||||
|
nfeats = torch.zeros_like(input.feats)
|
||||||
|
for k in range(input.shape[0]):
|
||||||
|
bfeats = input.feats[input.layout[k]]
|
||||||
|
bfeats = bfeats.permute(1, 0).reshape(1, input.shape[1], -1)
|
||||||
|
bfeats = super().forward(bfeats)
|
||||||
|
bfeats = bfeats.reshape(input.shape[1], -1).permute(1, 0)
|
||||||
|
nfeats[input.layout[k]] = bfeats
|
||||||
|
return input.replace(nfeats)
|
||||||
|
|
||||||
|
|
||||||
|
class SparseGroupNorm32(SparseGroupNorm):
|
||||||
|
"""
|
||||||
|
A GroupNorm layer that converts to float32 before the forward pass.
|
||||||
|
"""
|
||||||
|
def forward(self, x: SparseTensor) -> SparseTensor:
|
||||||
|
return super().forward(x.float()).type(x.dtype)
|
||||||
|
|
||||||
|
class SparseLayerNorm32(SparseLayerNorm):
|
||||||
|
"""
|
||||||
|
A LayerNorm layer that converts to float32 before the forward pass.
|
||||||
|
"""
|
||||||
|
def forward(self, x: SparseTensor) -> SparseTensor:
|
||||||
|
return super().forward(x.float()).type(x.dtype)
|
||||||
300
trellis/renderers/octree_renderer.py
Normal file
300
trellis/renderers/octree_renderer.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
import cv2
|
||||||
|
from scipy.stats import qmc
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
from ..representations.octree import DfsOctree
|
||||||
|
|
||||||
|
|
||||||
|
def intrinsics_to_projection(
|
||||||
|
intrinsics: torch.Tensor,
|
||||||
|
near: float,
|
||||||
|
far: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
OpenCV intrinsics to OpenGL perspective matrix
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
|
||||||
|
near (float): near plane to clip
|
||||||
|
far (float): far plane to clip
|
||||||
|
Returns:
|
||||||
|
(torch.Tensor): [4, 4] OpenGL perspective matrix
|
||||||
|
"""
|
||||||
|
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
|
||||||
|
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
|
||||||
|
ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
|
||||||
|
ret[0, 0] = 2 * fx
|
||||||
|
ret[1, 1] = 2 * fy
|
||||||
|
ret[0, 2] = 2 * cx - 1
|
||||||
|
ret[1, 2] = - 2 * cy + 1
|
||||||
|
ret[2, 2] = far / (far - near)
|
||||||
|
ret[2, 3] = near * far / (near - far)
|
||||||
|
ret[3, 2] = 1.
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def render(viewpoint_camera, octree : DfsOctree, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, used_rank = None, colors_overwrite = None, aux=None, halton_sampler=None):
|
||||||
|
"""
|
||||||
|
Render the scene.
|
||||||
|
|
||||||
|
Background tensor (bg_color) must be on GPU!
|
||||||
|
"""
|
||||||
|
# lazy import
|
||||||
|
if 'OctreeTrivecRasterizer' not in globals():
|
||||||
|
from diffoctreerast import OctreeVoxelRasterizer, OctreeGaussianRasterizer, OctreeTrivecRasterizer, OctreeDecoupolyRasterizer
|
||||||
|
|
||||||
|
# Set up rasterization configuration
|
||||||
|
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
|
||||||
|
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
|
||||||
|
|
||||||
|
raster_settings = edict(
|
||||||
|
image_height=int(viewpoint_camera.image_height),
|
||||||
|
image_width=int(viewpoint_camera.image_width),
|
||||||
|
tanfovx=tanfovx,
|
||||||
|
tanfovy=tanfovy,
|
||||||
|
bg=bg_color,
|
||||||
|
scale_modifier=scaling_modifier,
|
||||||
|
viewmatrix=viewpoint_camera.world_view_transform,
|
||||||
|
projmatrix=viewpoint_camera.full_proj_transform,
|
||||||
|
sh_degree=octree.active_sh_degree,
|
||||||
|
campos=viewpoint_camera.camera_center,
|
||||||
|
with_distloss=pipe.with_distloss,
|
||||||
|
jitter=pipe.jitter,
|
||||||
|
debug=pipe.debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
positions = octree.get_xyz
|
||||||
|
if octree.primitive == "voxel":
|
||||||
|
densities = octree.get_density
|
||||||
|
elif octree.primitive == "gaussian":
|
||||||
|
opacities = octree.get_opacity
|
||||||
|
elif octree.primitive == "trivec":
|
||||||
|
trivecs = octree.get_trivec
|
||||||
|
densities = octree.get_density
|
||||||
|
raster_settings.density_shift = octree.density_shift
|
||||||
|
elif octree.primitive == "decoupoly":
|
||||||
|
decoupolys_V, decoupolys_g = octree.get_decoupoly
|
||||||
|
densities = octree.get_density
|
||||||
|
raster_settings.density_shift = octree.density_shift
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown primitive {octree.primitive}")
|
||||||
|
depths = octree.get_depth
|
||||||
|
|
||||||
|
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
|
||||||
|
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
|
||||||
|
colors_precomp = None
|
||||||
|
shs = octree.get_features
|
||||||
|
if octree.primitive in ["voxel", "gaussian"] and colors_overwrite is not None:
|
||||||
|
colors_precomp = colors_overwrite
|
||||||
|
shs = None
|
||||||
|
|
||||||
|
ret = edict()
|
||||||
|
|
||||||
|
if octree.primitive == "voxel":
|
||||||
|
renderer = OctreeVoxelRasterizer(raster_settings=raster_settings)
|
||||||
|
rgb, depth, alpha, distloss = renderer(
|
||||||
|
positions = positions,
|
||||||
|
densities = densities,
|
||||||
|
shs = shs,
|
||||||
|
colors_precomp = colors_precomp,
|
||||||
|
depths = depths,
|
||||||
|
aabb = octree.aabb,
|
||||||
|
aux = aux,
|
||||||
|
)
|
||||||
|
ret['rgb'] = rgb
|
||||||
|
ret['depth'] = depth
|
||||||
|
ret['alpha'] = alpha
|
||||||
|
ret['distloss'] = distloss
|
||||||
|
elif octree.primitive == "gaussian":
|
||||||
|
renderer = OctreeGaussianRasterizer(raster_settings=raster_settings)
|
||||||
|
rgb, depth, alpha = renderer(
|
||||||
|
positions = positions,
|
||||||
|
opacities = opacities,
|
||||||
|
shs = shs,
|
||||||
|
colors_precomp = colors_precomp,
|
||||||
|
depths = depths,
|
||||||
|
aabb = octree.aabb,
|
||||||
|
aux = aux,
|
||||||
|
)
|
||||||
|
ret['rgb'] = rgb
|
||||||
|
ret['depth'] = depth
|
||||||
|
ret['alpha'] = alpha
|
||||||
|
elif octree.primitive == "trivec":
|
||||||
|
raster_settings.used_rank = used_rank if used_rank is not None else trivecs.shape[1]
|
||||||
|
renderer = OctreeTrivecRasterizer(raster_settings=raster_settings)
|
||||||
|
rgb, depth, alpha, percent_depth = renderer(
|
||||||
|
positions = positions,
|
||||||
|
trivecs = trivecs,
|
||||||
|
densities = densities,
|
||||||
|
shs = shs,
|
||||||
|
colors_precomp = colors_precomp,
|
||||||
|
colors_overwrite = colors_overwrite,
|
||||||
|
depths = depths,
|
||||||
|
aabb = octree.aabb,
|
||||||
|
aux = aux,
|
||||||
|
halton_sampler = halton_sampler,
|
||||||
|
)
|
||||||
|
ret['percent_depth'] = percent_depth
|
||||||
|
ret['rgb'] = rgb
|
||||||
|
ret['depth'] = depth
|
||||||
|
ret['alpha'] = alpha
|
||||||
|
elif octree.primitive == "decoupoly":
|
||||||
|
raster_settings.used_rank = used_rank if used_rank is not None else decoupolys_V.shape[1]
|
||||||
|
renderer = OctreeDecoupolyRasterizer(raster_settings=raster_settings)
|
||||||
|
rgb, depth, alpha = renderer(
|
||||||
|
positions = positions,
|
||||||
|
decoupolys_V = decoupolys_V,
|
||||||
|
decoupolys_g = decoupolys_g,
|
||||||
|
densities = densities,
|
||||||
|
shs = shs,
|
||||||
|
colors_precomp = colors_precomp,
|
||||||
|
depths = depths,
|
||||||
|
aabb = octree.aabb,
|
||||||
|
aux = aux,
|
||||||
|
)
|
||||||
|
ret['rgb'] = rgb
|
||||||
|
ret['depth'] = depth
|
||||||
|
ret['alpha'] = alpha
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class OctreeRenderer:
|
||||||
|
"""
|
||||||
|
Renderer for the Voxel representation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rendering_options (dict): Rendering options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, rendering_options={}) -> None:
|
||||||
|
try:
|
||||||
|
import diffoctreerast
|
||||||
|
except ImportError:
|
||||||
|
print("\033[93m[WARNING] diffoctreerast is not installed. The renderer will be disabled.\033[0m")
|
||||||
|
self.unsupported = True
|
||||||
|
else:
|
||||||
|
self.unsupported = False
|
||||||
|
|
||||||
|
self.pipe = edict({
|
||||||
|
"with_distloss": False,
|
||||||
|
"with_aux": False,
|
||||||
|
"scale_modifier": 1.0,
|
||||||
|
"used_rank": None,
|
||||||
|
"jitter": False,
|
||||||
|
"debug": False,
|
||||||
|
})
|
||||||
|
self.rendering_options = edict({
|
||||||
|
"resolution": None,
|
||||||
|
"near": None,
|
||||||
|
"far": None,
|
||||||
|
"ssaa": 1,
|
||||||
|
"bg_color": 'random',
|
||||||
|
})
|
||||||
|
self.halton_sampler = qmc.Halton(2, scramble=False)
|
||||||
|
self.rendering_options.update(rendering_options)
|
||||||
|
self.bg_color = None
|
||||||
|
|
||||||
|
def render(
|
||||||
|
self,
|
||||||
|
octree: DfsOctree,
|
||||||
|
extrinsics: torch.Tensor,
|
||||||
|
intrinsics: torch.Tensor,
|
||||||
|
colors_overwrite: torch.Tensor = None,
|
||||||
|
) -> edict:
|
||||||
|
"""
|
||||||
|
Render the octree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
octree (Octree): octree
|
||||||
|
extrinsics (torch.Tensor): (4, 4) camera extrinsics
|
||||||
|
intrinsics (torch.Tensor): (3, 3) camera intrinsics
|
||||||
|
colors_overwrite (torch.Tensor): (N, 3) override color
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
edict containing:
|
||||||
|
color (torch.Tensor): (3, H, W) rendered color
|
||||||
|
depth (torch.Tensor): (H, W) rendered depth
|
||||||
|
alpha (torch.Tensor): (H, W) rendered alpha
|
||||||
|
distloss (Optional[torch.Tensor]): (H, W) rendered distance loss
|
||||||
|
percent_depth (Optional[torch.Tensor]): (H, W) rendered percent depth
|
||||||
|
aux (Optional[edict]): auxiliary tensors
|
||||||
|
"""
|
||||||
|
resolution = self.rendering_options["resolution"]
|
||||||
|
near = self.rendering_options["near"]
|
||||||
|
far = self.rendering_options["far"]
|
||||||
|
ssaa = self.rendering_options["ssaa"]
|
||||||
|
|
||||||
|
if self.unsupported:
|
||||||
|
image = np.zeros((512, 512, 3), dtype=np.uint8)
|
||||||
|
text_bbox = cv2.getTextSize("Unsupported", cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0]
|
||||||
|
origin = (512 - text_bbox[0]) // 2, (512 - text_bbox[1]) // 2
|
||||||
|
image = cv2.putText(image, "Unsupported", origin, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3, cv2.LINE_AA)
|
||||||
|
return {
|
||||||
|
'color': torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.rendering_options["bg_color"] == 'random':
|
||||||
|
self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
|
||||||
|
if np.random.rand() < 0.5:
|
||||||
|
self.bg_color += 1
|
||||||
|
else:
|
||||||
|
self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
|
if self.pipe["with_aux"]:
|
||||||
|
aux = {
|
||||||
|
'grad_color2': torch.zeros((octree.num_leaf_nodes, 3), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
|
||||||
|
'contributions': torch.zeros((octree.num_leaf_nodes, 1), dtype=torch.float32, requires_grad=True, device="cuda") + 0,
|
||||||
|
}
|
||||||
|
for k in aux.keys():
|
||||||
|
aux[k].requires_grad_()
|
||||||
|
aux[k].retain_grad()
|
||||||
|
else:
|
||||||
|
aux = None
|
||||||
|
|
||||||
|
view = extrinsics
|
||||||
|
perspective = intrinsics_to_projection(intrinsics, near, far)
|
||||||
|
camera = torch.inverse(view)[:3, 3]
|
||||||
|
focalx = intrinsics[0, 0]
|
||||||
|
focaly = intrinsics[1, 1]
|
||||||
|
fovx = 2 * torch.atan(0.5 / focalx)
|
||||||
|
fovy = 2 * torch.atan(0.5 / focaly)
|
||||||
|
|
||||||
|
camera_dict = edict({
|
||||||
|
"image_height": resolution * ssaa,
|
||||||
|
"image_width": resolution * ssaa,
|
||||||
|
"FoVx": fovx,
|
||||||
|
"FoVy": fovy,
|
||||||
|
"znear": near,
|
||||||
|
"zfar": far,
|
||||||
|
"world_view_transform": view.T.contiguous(),
|
||||||
|
"projection_matrix": perspective.T.contiguous(),
|
||||||
|
"full_proj_transform": (perspective @ view).T.contiguous(),
|
||||||
|
"camera_center": camera
|
||||||
|
})
|
||||||
|
|
||||||
|
# Render
|
||||||
|
render_ret = render(camera_dict, octree, self.pipe, self.bg_color, aux=aux, colors_overwrite=colors_overwrite, scaling_modifier=self.pipe.scale_modifier, used_rank=self.pipe.used_rank, halton_sampler=self.halton_sampler)
|
||||||
|
|
||||||
|
if ssaa > 1:
|
||||||
|
render_ret.rgb = F.interpolate(render_ret.rgb[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
|
||||||
|
render_ret.depth = F.interpolate(render_ret.depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
|
||||||
|
render_ret.alpha = F.interpolate(render_ret.alpha[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
|
||||||
|
if hasattr(render_ret, 'percent_depth'):
|
||||||
|
render_ret.percent_depth = F.interpolate(render_ret.percent_depth[None, None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
|
||||||
|
|
||||||
|
ret = edict({
|
||||||
|
'color': render_ret.rgb,
|
||||||
|
'depth': render_ret.depth,
|
||||||
|
'alpha': render_ret.alpha,
|
||||||
|
})
|
||||||
|
if self.pipe["with_distloss"] and 'distloss' in render_ret:
|
||||||
|
ret['distloss'] = render_ret.distloss
|
||||||
|
if self.pipe["with_aux"]:
|
||||||
|
ret['aux'] = aux
|
||||||
|
if hasattr(render_ret, 'percent_depth'):
|
||||||
|
ret['percent_depth'] = render_ret.percent_depth
|
||||||
|
return ret
|
||||||
347
trellis/representations/octree/octree_dfs.py
Normal file
347
trellis/representations/octree/octree_dfs.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class DfsOctree:
|
||||||
|
"""
|
||||||
|
Sparse Voxel Octree (SVO) implementation for PyTorch.
|
||||||
|
Using Depth-First Search (DFS) order to store the octree.
|
||||||
|
DFS order suits rendering and ray tracing.
|
||||||
|
|
||||||
|
The structure and data are separatedly stored.
|
||||||
|
Structure is stored as a continuous array, each element is a 3*32 bits descriptor.
|
||||||
|
|-----------------------------------------|
|
||||||
|
| 0:3 bits | 4:31 bits |
|
||||||
|
| leaf num | unused |
|
||||||
|
|-----------------------------------------|
|
||||||
|
| 0:31 bits |
|
||||||
|
| child ptr |
|
||||||
|
|-----------------------------------------|
|
||||||
|
| 0:31 bits |
|
||||||
|
| data ptr |
|
||||||
|
|-----------------------------------------|
|
||||||
|
Each element represents a non-leaf node in the octree.
|
||||||
|
The valid mask is used to indicate whether the children are valid.
|
||||||
|
The leaf mask is used to indicate whether the children are leaf nodes.
|
||||||
|
The child ptr is used to point to the first non-leaf child. Non-leaf children descriptors are stored continuously from the child ptr.
|
||||||
|
The data ptr is used to point to the data of leaf children. Leaf children data are stored continuously from the data ptr.
|
||||||
|
|
||||||
|
There are also auxiliary arrays to store the additional structural information to facilitate parallel processing.
|
||||||
|
- Position: the position of the octree nodes.
|
||||||
|
- Depth: the depth of the octree nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
depth (int): the depth of the octree.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
depth,
|
||||||
|
aabb=[0,0,0,1,1,1],
|
||||||
|
sh_degree=2,
|
||||||
|
primitive='voxel',
|
||||||
|
primitive_config={},
|
||||||
|
device='cuda',
|
||||||
|
):
|
||||||
|
self.max_depth = depth
|
||||||
|
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||||
|
self.device = device
|
||||||
|
self.sh_degree = sh_degree
|
||||||
|
self.active_sh_degree = sh_degree
|
||||||
|
self.primitive = primitive
|
||||||
|
self.primitive_config = primitive_config
|
||||||
|
|
||||||
|
self.structure = torch.tensor([[8, 1, 0]], dtype=torch.int32, device=self.device)
|
||||||
|
self.position = torch.zeros((8, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.depth = torch.zeros((8, 1), dtype=torch.uint8, device=self.device)
|
||||||
|
self.position[:, 0] = torch.tensor([0.25, 0.75, 0.25, 0.75, 0.25, 0.75, 0.25, 0.75], device=self.device)
|
||||||
|
self.position[:, 1] = torch.tensor([0.25, 0.25, 0.75, 0.75, 0.25, 0.25, 0.75, 0.75], device=self.device)
|
||||||
|
self.position[:, 2] = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.75, 0.75, 0.75, 0.75], device=self.device)
|
||||||
|
self.depth[:, 0] = 1
|
||||||
|
|
||||||
|
self.data = ['position', 'depth']
|
||||||
|
self.param_names = []
|
||||||
|
|
||||||
|
if primitive == 'voxel':
|
||||||
|
self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.data += ['features_dc', 'features_ac']
|
||||||
|
self.param_names += ['features_dc', 'features_ac']
|
||||||
|
if not primitive_config.get('solid', False):
|
||||||
|
self.density = torch.zeros((8, 1), dtype=torch.float32, device=self.device)
|
||||||
|
self.data.append('density')
|
||||||
|
self.param_names.append('density')
|
||||||
|
elif primitive == 'gaussian':
|
||||||
|
self.features_dc = torch.zeros((8, 1, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.features_ac = torch.zeros((8, (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.opacity = torch.zeros((8, 1), dtype=torch.float32, device=self.device)
|
||||||
|
self.data += ['features_dc', 'features_ac', 'opacity']
|
||||||
|
self.param_names += ['features_dc', 'features_ac', 'opacity']
|
||||||
|
elif primitive == 'trivec':
|
||||||
|
self.trivec = torch.zeros((8, primitive_config['rank'], 3, primitive_config['dim']), dtype=torch.float32, device=self.device)
|
||||||
|
self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device)
|
||||||
|
self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.density_shift = 0
|
||||||
|
self.data += ['trivec', 'density', 'features_dc', 'features_ac']
|
||||||
|
self.param_names += ['trivec', 'density', 'features_dc', 'features_ac']
|
||||||
|
elif primitive == 'decoupoly':
|
||||||
|
self.decoupoly_V = torch.zeros((8, primitive_config['rank'], 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.decoupoly_g = torch.zeros((8, primitive_config['rank'], primitive_config['degree']), dtype=torch.float32, device=self.device)
|
||||||
|
self.density = torch.zeros((8, primitive_config['rank']), dtype=torch.float32, device=self.device)
|
||||||
|
self.features_dc = torch.zeros((8, primitive_config['rank'], 1, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.features_ac = torch.zeros((8, primitive_config['rank'], (sh_degree+1)**2-1, 3), dtype=torch.float32, device=self.device)
|
||||||
|
self.density_shift = 0
|
||||||
|
self.data += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac']
|
||||||
|
self.param_names += ['decoupoly_V', 'decoupoly_g', 'density', 'features_dc', 'features_ac']
|
||||||
|
|
||||||
|
self.setup_functions()
|
||||||
|
|
||||||
|
def setup_functions(self):
|
||||||
|
self.density_activation = (lambda x: torch.exp(x - 2)) if self.primitive != 'trivec' else (lambda x: x)
|
||||||
|
self.opacity_activation = lambda x: torch.sigmoid(x - 6)
|
||||||
|
self.inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + 6
|
||||||
|
self.color_activation = lambda x: torch.sigmoid(x)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_non_leaf_nodes(self):
|
||||||
|
return self.structure.shape[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_leaf_nodes(self):
|
||||||
|
return self.depth.shape[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cur_depth(self):
|
||||||
|
return self.depth.max().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def occupancy(self):
|
||||||
|
return self.num_leaf_nodes / 8 ** self.cur_depth
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_xyz(self):
|
||||||
|
return self.position
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_depth(self):
|
||||||
|
return self.depth
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_density(self):
|
||||||
|
if self.primitive == 'voxel' and self.primitive_config.get('solid', False):
|
||||||
|
return torch.full((self.position.shape[0], 1), torch.finfo(torch.float32).max, dtype=torch.float32, device=self.device)
|
||||||
|
return self.density_activation(self.density)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_opacity(self):
|
||||||
|
return self.opacity_activation(self.density)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_trivec(self):
|
||||||
|
return self.trivec
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_decoupoly(self):
|
||||||
|
return F.normalize(self.decoupoly_V, dim=-1), self.decoupoly_g
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_color(self):
|
||||||
|
return self.color_activation(self.colors)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def get_features(self):
|
||||||
|
if self.sh_degree == 0:
|
||||||
|
return self.features_dc
|
||||||
|
return torch.cat([self.features_dc, self.features_ac], dim=-2)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
ret = {'structure': self.structure, 'position': self.position, 'depth': self.depth, 'sh_degree': self.sh_degree, 'active_sh_degree': self.active_sh_degree, 'primitive_config': self.primitive_config, 'primitive': self.primitive}
|
||||||
|
if hasattr(self, 'density_shift'):
|
||||||
|
ret['density_shift'] = self.density_shift
|
||||||
|
for data in set(self.data + self.param_names):
|
||||||
|
if not isinstance(getattr(self, data), nn.Module):
|
||||||
|
ret[data] = getattr(self, data)
|
||||||
|
else:
|
||||||
|
ret[data] = getattr(self, data).state_dict()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
keys = list(set(self.data + self.param_names + list(state_dict.keys()) + ['structure', 'position', 'depth']))
|
||||||
|
for key in keys:
|
||||||
|
if key not in state_dict:
|
||||||
|
print(f"Warning: key {key} not found in the state_dict.")
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if not isinstance(getattr(self, key), nn.Module):
|
||||||
|
setattr(self, key, state_dict[key])
|
||||||
|
else:
|
||||||
|
getattr(self, key).load_state_dict(state_dict[key])
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
raise ValueError(f"Error loading key {key}.")
|
||||||
|
|
||||||
|
def gather_from_leaf_children(self, data):
|
||||||
|
"""
|
||||||
|
Gather the data from the leaf children.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes.
|
||||||
|
"""
|
||||||
|
leaf_cnt = self.structure[:, 0]
|
||||||
|
leaf_cnt_masks = [leaf_cnt == i for i in range(1, 9)]
|
||||||
|
ret = torch.zeros((self.num_non_leaf_nodes,), dtype=data.dtype, device=self.device)
|
||||||
|
for i in range(8):
|
||||||
|
if leaf_cnt_masks[i].sum() == 0:
|
||||||
|
continue
|
||||||
|
start = self.structure[leaf_cnt_masks[i], 2]
|
||||||
|
for j in range(i+1):
|
||||||
|
ret[leaf_cnt_masks[i]] += data[start + j]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def gather_from_non_leaf_children(self, data):
|
||||||
|
"""
|
||||||
|
Gather the data from the non-leaf children.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (torch.Tensor): the data to gather. The first dimension should be the number of leaf nodes.
|
||||||
|
"""
|
||||||
|
non_leaf_cnt = 8 - self.structure[:, 0]
|
||||||
|
non_leaf_cnt_masks = [non_leaf_cnt == i for i in range(1, 9)]
|
||||||
|
ret = torch.zeros_like(data, device=self.device)
|
||||||
|
for i in range(8):
|
||||||
|
if non_leaf_cnt_masks[i].sum() == 0:
|
||||||
|
continue
|
||||||
|
start = self.structure[non_leaf_cnt_masks[i], 1]
|
||||||
|
for j in range(i+1):
|
||||||
|
ret[non_leaf_cnt_masks[i]] += data[start + j]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def structure_control(self, mask):
|
||||||
|
"""
|
||||||
|
Control the structure of the octree.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask (torch.Tensor): the mask to control the structure. 1 for subdivide, -1 for merge, 0 for keep.
|
||||||
|
"""
|
||||||
|
# Dont subdivide when the depth is the maximum.
|
||||||
|
mask[self.depth.squeeze() == self.max_depth] = torch.clamp_max(mask[self.depth.squeeze() == self.max_depth], 0)
|
||||||
|
# Dont merge when the depth is the minimum.
|
||||||
|
mask[self.depth.squeeze() == 1] = torch.clamp_min(mask[self.depth.squeeze() == 1], 0)
|
||||||
|
|
||||||
|
# Gather control mask
|
||||||
|
structre_ctrl = self.gather_from_leaf_children(mask)
|
||||||
|
structre_ctrl[structre_ctrl==-8] = -1
|
||||||
|
|
||||||
|
new_leaf_num = self.structure[:, 0].clone()
|
||||||
|
# Modify the leaf num.
|
||||||
|
structre_valid = structre_ctrl >= 0
|
||||||
|
new_leaf_num[structre_valid] -= structre_ctrl[structre_valid] # Add the new nodes.
|
||||||
|
structre_delete = structre_ctrl < 0
|
||||||
|
merged_nodes = self.gather_from_non_leaf_children(structre_delete.int())
|
||||||
|
new_leaf_num += merged_nodes # Delete the merged nodes.
|
||||||
|
|
||||||
|
# Update the structure array to allocate new nodes.
|
||||||
|
mem_offset = torch.zeros((self.num_non_leaf_nodes + 1,), dtype=torch.int32, device=self.device)
|
||||||
|
mem_offset.index_add_(0, self.structure[structre_valid, 1], structre_ctrl[structre_valid]) # Add the new nodes.
|
||||||
|
mem_offset[:-1] -= structre_delete.int() # Delete the merged nodes.
|
||||||
|
new_structre_idx = torch.arange(0, self.num_non_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0)
|
||||||
|
new_structure_length = new_structre_idx[-1].item()
|
||||||
|
new_structre_idx = new_structre_idx[:-1]
|
||||||
|
new_structure = torch.empty((new_structure_length, 3), dtype=torch.int32, device=self.device)
|
||||||
|
new_structure[new_structre_idx[structre_valid], 0] = new_leaf_num[structre_valid]
|
||||||
|
|
||||||
|
# Initialize the new nodes.
|
||||||
|
new_node_mask = torch.ones((new_structure_length,), dtype=torch.bool, device=self.device)
|
||||||
|
new_node_mask[new_structre_idx[structre_valid]] = False
|
||||||
|
new_structure[new_node_mask, 0] = 8 # Initialize to all leaf nodes.
|
||||||
|
new_node_num = new_node_mask.sum().item()
|
||||||
|
|
||||||
|
# Rebuild child ptr.
|
||||||
|
non_leaf_cnt = 8 - new_structure[:, 0]
|
||||||
|
new_child_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), non_leaf_cnt.cumsum(0)[:-1]])
|
||||||
|
new_structure[:, 1] = new_child_ptr + 1
|
||||||
|
|
||||||
|
# Rebuild data ptr with old data.
|
||||||
|
leaf_cnt = torch.zeros((new_structure_length,), dtype=torch.int32, device=self.device)
|
||||||
|
leaf_cnt.index_add_(0, new_structre_idx, self.structure[:, 0])
|
||||||
|
old_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]])
|
||||||
|
|
||||||
|
# Update the data array
|
||||||
|
subdivide_mask = mask == 1
|
||||||
|
merge_mask = mask == -1
|
||||||
|
data_valid = ~(subdivide_mask | merge_mask)
|
||||||
|
mem_offset = torch.zeros((self.num_leaf_nodes + 1,), dtype=torch.int32, device=self.device)
|
||||||
|
mem_offset.index_add_(0, old_data_ptr[new_node_mask], torch.full((new_node_num,), 8, dtype=torch.int32, device=self.device)) # Add data array for new nodes
|
||||||
|
mem_offset[:-1] -= subdivide_mask.int() # Delete data elements for subdivide nodes
|
||||||
|
mem_offset[:-1] -= merge_mask.int() # Delete data elements for merge nodes
|
||||||
|
mem_offset.index_add_(0, self.structure[structre_valid, 2], merged_nodes[structre_valid]) # Add data elements for merge nodes
|
||||||
|
new_data_idx = torch.arange(0, self.num_leaf_nodes + 1, dtype=torch.int32, device=self.device) + mem_offset.cumsum(0)
|
||||||
|
new_data_length = new_data_idx[-1].item()
|
||||||
|
new_data_idx = new_data_idx[:-1]
|
||||||
|
new_data = {data: torch.empty((new_data_length,) + getattr(self, data).shape[1:], dtype=getattr(self, data).dtype, device=self.device) for data in self.data}
|
||||||
|
for data in self.data:
|
||||||
|
new_data[data][new_data_idx[data_valid]] = getattr(self, data)[data_valid]
|
||||||
|
|
||||||
|
# Rebuild data ptr
|
||||||
|
leaf_cnt = new_structure[:, 0]
|
||||||
|
new_data_ptr = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), leaf_cnt.cumsum(0)[:-1]])
|
||||||
|
new_structure[:, 2] = new_data_ptr
|
||||||
|
|
||||||
|
# Initialize the new data array
|
||||||
|
## For subdivide nodes
|
||||||
|
if subdivide_mask.sum() > 0:
|
||||||
|
subdivide_data_ptr = new_structure[new_node_mask, 2]
|
||||||
|
for data in self.data:
|
||||||
|
for i in range(8):
|
||||||
|
if data == 'position':
|
||||||
|
offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) - 0.5
|
||||||
|
scale = 2 ** (-1.0 - self.depth[subdivide_mask])
|
||||||
|
new_data['position'][subdivide_data_ptr + i] = self.position[subdivide_mask] + offset * scale
|
||||||
|
elif data == 'depth':
|
||||||
|
new_data['depth'][subdivide_data_ptr + i] = self.depth[subdivide_mask] + 1
|
||||||
|
elif data == 'opacity':
|
||||||
|
new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(torch.sqrt(self.opacity_activation(self.opacity[subdivide_mask])))
|
||||||
|
elif data == 'trivec':
|
||||||
|
offset = torch.tensor([i // 4, (i // 2) % 2, i % 2], dtype=torch.float32, device=self.device) * 0.5
|
||||||
|
coord = (torch.linspace(0, 0.5, self.trivec.shape[-1], dtype=torch.float32, device=self.device)[None] + offset[:, None]).reshape(1, 3, self.trivec.shape[-1], 1)
|
||||||
|
axis = torch.linspace(0, 1, 3, dtype=torch.float32, device=self.device).reshape(1, 3, 1, 1).repeat(1, 1, self.trivec.shape[-1], 1)
|
||||||
|
coord = torch.stack([coord, axis], dim=3).reshape(1, 3, self.trivec.shape[-1], 2).expand(self.trivec[subdivide_mask].shape[0], -1, -1, -1) * 2 - 1
|
||||||
|
new_data['trivec'][subdivide_data_ptr + i] = F.grid_sample(self.trivec[subdivide_mask], coord, align_corners=True)
|
||||||
|
else:
|
||||||
|
new_data[data][subdivide_data_ptr + i] = getattr(self, data)[subdivide_mask]
|
||||||
|
## For merge nodes
|
||||||
|
if merge_mask.sum() > 0:
|
||||||
|
merge_data_ptr = torch.empty((merged_nodes.sum().item(),), dtype=torch.int32, device=self.device)
|
||||||
|
merge_nodes_cumsum = torch.cat([torch.zeros((1,), dtype=torch.int32, device=self.device), merged_nodes.cumsum(0)[:-1]])
|
||||||
|
for i in range(8):
|
||||||
|
merge_data_ptr[merge_nodes_cumsum[merged_nodes > i] + i] = new_structure[new_structre_idx[merged_nodes > i], 2] + i
|
||||||
|
old_merge_data_ptr = self.structure[structre_delete, 2]
|
||||||
|
for data in self.data:
|
||||||
|
if data == 'position':
|
||||||
|
scale = 2 ** (1.0 - self.depth[old_merge_data_ptr])
|
||||||
|
new_data['position'][merge_data_ptr] = ((self.position[old_merge_data_ptr] + 0.5) / scale).floor() * scale + 0.5 * scale - 0.5
|
||||||
|
elif data == 'depth':
|
||||||
|
new_data['depth'][merge_data_ptr] = self.depth[old_merge_data_ptr] - 1
|
||||||
|
elif data == 'opacity':
|
||||||
|
new_data['opacity'][subdivide_data_ptr + i] = self.inverse_opacity_activation(self.opacity_activation(self.opacity[subdivide_mask])**2)
|
||||||
|
elif data == 'trivec':
|
||||||
|
new_data['trivec'][merge_data_ptr] = self.trivec[old_merge_data_ptr]
|
||||||
|
else:
|
||||||
|
new_data[data][merge_data_ptr] = getattr(self, data)[old_merge_data_ptr]
|
||||||
|
|
||||||
|
# Update the structure and data array
|
||||||
|
self.structure = new_structure
|
||||||
|
for data in self.data:
|
||||||
|
setattr(self, data, new_data[data])
|
||||||
|
|
||||||
|
# Save data array control temp variables
|
||||||
|
self.data_rearrange_buffer = {
|
||||||
|
'subdivide_mask': subdivide_mask,
|
||||||
|
'merge_mask': merge_mask,
|
||||||
|
'data_valid': data_valid,
|
||||||
|
'new_data_idx': new_data_idx,
|
||||||
|
'new_data_length': new_data_length,
|
||||||
|
'new_data': new_data
|
||||||
|
}
|
||||||
28
trellis/representations/radiance_field/strivec.py
Normal file
28
trellis/representations/radiance_field/strivec.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
from ..octree import DfsOctree as Octree
|
||||||
|
|
||||||
|
|
||||||
|
class Strivec(Octree):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
resolution: int,
|
||||||
|
aabb: list,
|
||||||
|
sh_degree: int = 0,
|
||||||
|
rank: int = 8,
|
||||||
|
dim: int = 8,
|
||||||
|
device: str = "cuda",
|
||||||
|
):
|
||||||
|
assert np.log2(resolution) % 1 == 0, "Resolution must be a power of 2"
|
||||||
|
self.resolution = resolution
|
||||||
|
depth = int(np.round(np.log2(resolution)))
|
||||||
|
super().__init__(
|
||||||
|
depth=depth,
|
||||||
|
aabb=aabb,
|
||||||
|
sh_degree=sh_degree,
|
||||||
|
primitive="trivec",
|
||||||
|
primitive_config={"rank": rank, "dim": dim},
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
275
trellis/trainers/vae/structured_latent_vae_gaussian.py
Normal file
275
trellis/trainers/vae/structured_latent_vae_gaussian.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
from typing import *
|
||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import numpy as np
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
import utils3d.torch
|
||||||
|
|
||||||
|
from ..basic import BasicTrainer
|
||||||
|
from ...representations import Gaussian
|
||||||
|
from ...renderers import GaussianRenderer
|
||||||
|
from ...modules.sparse import SparseTensor
|
||||||
|
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
|
||||||
|
|
||||||
|
|
||||||
|
class SLatVaeGaussianTrainer(BasicTrainer):
|
||||||
|
"""
|
||||||
|
Trainer for structured latent VAE.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
models (dict[str, nn.Module]): Models to train.
|
||||||
|
dataset (torch.utils.data.Dataset): Dataset.
|
||||||
|
output_dir (str): Output directory.
|
||||||
|
load_dir (str): Load directory.
|
||||||
|
step (int): Step to load.
|
||||||
|
batch_size (int): Batch size.
|
||||||
|
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||||
|
batch_split (int): Split batch with gradient accumulation.
|
||||||
|
max_steps (int): Max steps.
|
||||||
|
optimizer (dict): Optimizer config.
|
||||||
|
lr_scheduler (dict): Learning rate scheduler config.
|
||||||
|
elastic (dict): Elastic memory management config.
|
||||||
|
grad_clip (float or dict): Gradient clip config.
|
||||||
|
ema_rate (float or list): Exponential moving average rates.
|
||||||
|
fp16_mode (str): FP16 mode.
|
||||||
|
- None: No FP16.
|
||||||
|
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||||
|
- 'amp': Automatic mixed precision.
|
||||||
|
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||||
|
finetune_ckpt (dict): Finetune checkpoint.
|
||||||
|
log_param_stats (bool): Log parameter stats.
|
||||||
|
i_print (int): Print interval.
|
||||||
|
i_log (int): Log interval.
|
||||||
|
i_sample (int): Sample interval.
|
||||||
|
i_save (int): Save interval.
|
||||||
|
i_ddpcheck (int): DDP check interval.
|
||||||
|
|
||||||
|
loss_type (str): Loss type. Can be 'l1', 'l2'
|
||||||
|
lambda_ssim (float): SSIM loss weight.
|
||||||
|
lambda_lpips (float): LPIPS loss weight.
|
||||||
|
lambda_kl (float): KL loss weight.
|
||||||
|
regularizations (dict): Regularization config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
loss_type: str = 'l1',
|
||||||
|
lambda_ssim: float = 0.2,
|
||||||
|
lambda_lpips: float = 0.2,
|
||||||
|
lambda_kl: float = 1e-6,
|
||||||
|
regularizations: Dict = {},
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.loss_type = loss_type
|
||||||
|
self.lambda_ssim = lambda_ssim
|
||||||
|
self.lambda_lpips = lambda_lpips
|
||||||
|
self.lambda_kl = lambda_kl
|
||||||
|
self.regularizations = regularizations
|
||||||
|
|
||||||
|
self._init_renderer()
|
||||||
|
|
||||||
|
def _init_renderer(self):
|
||||||
|
rendering_options = {"near" : 0.8,
|
||||||
|
"far" : 1.6,
|
||||||
|
"bg_color" : 'random'}
|
||||||
|
self.renderer = GaussianRenderer(rendering_options)
|
||||||
|
self.renderer.pipe.kernel_size = self.models['decoder'].rep_config['2d_filter_kernel_size']
|
||||||
|
|
||||||
|
def _render_batch(self, reps: List[Gaussian], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Render a batch of representations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reps: The dictionary of lists of representations.
|
||||||
|
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
||||||
|
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
||||||
|
"""
|
||||||
|
ret = None
|
||||||
|
for i, representation in enumerate(reps):
|
||||||
|
render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
|
||||||
|
if ret is None:
|
||||||
|
ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
|
||||||
|
for k, v in render_pack.items():
|
||||||
|
ret[k].append(v)
|
||||||
|
ret['bg_color'].append(self.renderer.bg_color)
|
||||||
|
for k, v in ret.items():
|
||||||
|
ret[k] = torch.stack(v, dim=0)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _get_status(self, z: SparseTensor, reps: List[Gaussian]) -> Dict:
|
||||||
|
xyz = torch.cat([g.get_xyz for g in reps], dim=0)
|
||||||
|
xyz_base = (z.coords[:, 1:].float() + 0.5) / self.models['decoder'].resolution - 0.5
|
||||||
|
offset = xyz - xyz_base.unsqueeze(1).expand(-1, self.models['decoder'].rep_config['num_gaussians'], -1).reshape(-1, 3)
|
||||||
|
status = {
|
||||||
|
'xyz': xyz,
|
||||||
|
'offset': offset,
|
||||||
|
'scale': torch.cat([g.get_scaling for g in reps], dim=0),
|
||||||
|
'opacity': torch.cat([g.get_opacity for g in reps], dim=0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in list(status.keys()):
|
||||||
|
status[k] = {
|
||||||
|
'mean': status[k].mean().item(),
|
||||||
|
'max': status[k].max().item(),
|
||||||
|
'min': status[k].min().item(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return status
|
||||||
|
|
||||||
|
def _get_regularization_loss(self, reps: List[Gaussian]) -> Tuple[torch.Tensor, Dict]:
|
||||||
|
loss = 0.0
|
||||||
|
terms = {}
|
||||||
|
if 'lambda_vol' in self.regularizations:
|
||||||
|
scales = torch.cat([g.get_scaling for g in reps], dim=0) # [N x 3]
|
||||||
|
volume = torch.prod(scales, dim=1) # [N]
|
||||||
|
terms[f'reg_vol'] = volume.mean()
|
||||||
|
loss = loss + self.regularizations['lambda_vol'] * terms[f'reg_vol']
|
||||||
|
if 'lambda_opacity' in self.regularizations:
|
||||||
|
opacity = torch.cat([g.get_opacity for g in reps], dim=0)
|
||||||
|
terms[f'reg_opacity'] = (opacity - 1).pow(2).mean()
|
||||||
|
loss = loss + self.regularizations['lambda_opacity'] * terms[f'reg_opacity']
|
||||||
|
return loss, terms
|
||||||
|
|
||||||
|
def training_losses(
|
||||||
|
self,
|
||||||
|
feats: SparseTensor,
|
||||||
|
image: torch.Tensor,
|
||||||
|
alpha: torch.Tensor,
|
||||||
|
extrinsics: torch.Tensor,
|
||||||
|
intrinsics: torch.Tensor,
|
||||||
|
return_aux: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Tuple[Dict, Dict]:
|
||||||
|
"""
|
||||||
|
Compute training losses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feats: The [N x * x C] sparse tensor of features.
|
||||||
|
image: The [N x 3 x H x W] tensor of images.
|
||||||
|
alpha: The [N x H x W] tensor of alpha channels.
|
||||||
|
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
||||||
|
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
||||||
|
return_aux: Whether to return auxiliary information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a dict with the key "loss" containing a scalar tensor.
|
||||||
|
may also contain other keys for different terms.
|
||||||
|
"""
|
||||||
|
z, mean, logvar = self.training_models['encoder'](feats, sample_posterior=True, return_raw=True)
|
||||||
|
reps = self.training_models['decoder'](z)
|
||||||
|
self.renderer.rendering_options.resolution = image.shape[-1]
|
||||||
|
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
||||||
|
|
||||||
|
terms = edict(loss = 0.0, rec = 0.0)
|
||||||
|
|
||||||
|
rec_image = render_results['color']
|
||||||
|
gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
|
||||||
|
|
||||||
|
if self.loss_type == 'l1':
|
||||||
|
terms["l1"] = l1_loss(rec_image, gt_image)
|
||||||
|
terms["rec"] = terms["rec"] + terms["l1"]
|
||||||
|
elif self.loss_type == 'l2':
|
||||||
|
terms["l2"] = l2_loss(rec_image, gt_image)
|
||||||
|
terms["rec"] = terms["rec"] + terms["l2"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid loss type: {self.loss_type}")
|
||||||
|
if self.lambda_ssim > 0:
|
||||||
|
terms["ssim"] = 1 - ssim(rec_image, gt_image)
|
||||||
|
terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
|
||||||
|
if self.lambda_lpips > 0:
|
||||||
|
terms["lpips"] = lpips(rec_image, gt_image)
|
||||||
|
terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
|
||||||
|
terms["loss"] = terms["loss"] + terms["rec"]
|
||||||
|
|
||||||
|
terms["kl"] = 0.5 * torch.mean(mean.pow(2) + logvar.exp() - logvar - 1)
|
||||||
|
terms["loss"] = terms["loss"] + self.lambda_kl * terms["kl"]
|
||||||
|
|
||||||
|
reg_loss, reg_terms = self._get_regularization_loss(reps)
|
||||||
|
terms.update(reg_terms)
|
||||||
|
terms["loss"] = terms["loss"] + reg_loss
|
||||||
|
|
||||||
|
status = self._get_status(z, reps)
|
||||||
|
|
||||||
|
if return_aux:
|
||||||
|
return terms, status, {'rec_image': rec_image, 'gt_image': gt_image}
|
||||||
|
return terms, status
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def run_snapshot(
|
||||||
|
self,
|
||||||
|
num_samples: int,
|
||||||
|
batch_size: int,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> Dict:
|
||||||
|
dataloader = DataLoader(
|
||||||
|
copy.deepcopy(self.dataset),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0,
|
||||||
|
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# inference
|
||||||
|
ret_dict = {}
|
||||||
|
gt_images = []
|
||||||
|
exts = []
|
||||||
|
ints = []
|
||||||
|
reps = []
|
||||||
|
for i in range(0, num_samples, batch_size):
|
||||||
|
batch = min(batch_size, num_samples - i)
|
||||||
|
data = next(iter(dataloader))
|
||||||
|
args = {k: v[:batch].cuda() for k, v in data.items()}
|
||||||
|
gt_images.append(args['image'] * args['alpha'][:, None])
|
||||||
|
exts.append(args['extrinsics'])
|
||||||
|
ints.append(args['intrinsics'])
|
||||||
|
z = self.models['encoder'](args['feats'], sample_posterior=True, return_raw=False)
|
||||||
|
reps.extend(self.models['decoder'](z))
|
||||||
|
gt_images = torch.cat(gt_images, dim=0)
|
||||||
|
ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
|
||||||
|
|
||||||
|
# render single view
|
||||||
|
exts = torch.cat(exts, dim=0)
|
||||||
|
ints = torch.cat(ints, dim=0)
|
||||||
|
self.renderer.rendering_options.bg_color = (0, 0, 0)
|
||||||
|
self.renderer.rendering_options.resolution = gt_images.shape[-1]
|
||||||
|
render_results = self._render_batch(reps, exts, ints)
|
||||||
|
ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
|
||||||
|
|
||||||
|
# render multiview
|
||||||
|
self.renderer.rendering_options.resolution = 512
|
||||||
|
## Build camera
|
||||||
|
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
||||||
|
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
||||||
|
yaws = [y + yaws_offset for y in yaws]
|
||||||
|
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
||||||
|
|
||||||
|
## render each view
|
||||||
|
miltiview_images = []
|
||||||
|
for yaw, pitch in zip(yaws, pitch):
|
||||||
|
orig = torch.tensor([
|
||||||
|
np.sin(yaw) * np.cos(pitch),
|
||||||
|
np.cos(yaw) * np.cos(pitch),
|
||||||
|
np.sin(pitch),
|
||||||
|
]).float().cuda() * 2
|
||||||
|
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
||||||
|
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
||||||
|
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
||||||
|
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
||||||
|
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
||||||
|
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
||||||
|
miltiview_images.append(render_results['color'])
|
||||||
|
|
||||||
|
## Concatenate views
|
||||||
|
miltiview_images = torch.cat([
|
||||||
|
torch.cat(miltiview_images[:2], dim=-2),
|
||||||
|
torch.cat(miltiview_images[2:], dim=-2),
|
||||||
|
], dim=-1)
|
||||||
|
ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}})
|
||||||
|
|
||||||
|
self.renderer.rendering_options.bg_color = 'random'
|
||||||
|
|
||||||
|
return ret_dict
|
||||||
382
trellis/trainers/vae/structured_latent_vae_mesh_dec.py
Normal file
382
trellis/trainers/vae/structured_latent_vae_mesh_dec.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
from typing import *
|
||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import numpy as np
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
import utils3d.torch
|
||||||
|
|
||||||
|
from ..basic import BasicTrainer
|
||||||
|
from ...representations import MeshExtractResult
|
||||||
|
from ...renderers import MeshRenderer
|
||||||
|
from ...modules.sparse import SparseTensor
|
||||||
|
from ...utils.loss_utils import l1_loss, smooth_l1_loss, ssim, lpips
|
||||||
|
from ...utils.data_utils import recursive_to_device
|
||||||
|
|
||||||
|
|
||||||
|
class SLatVaeMeshDecoderTrainer(BasicTrainer):
|
||||||
|
"""
|
||||||
|
Trainer for structured latent VAE Mesh Decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
models (dict[str, nn.Module]): Models to train.
|
||||||
|
dataset (torch.utils.data.Dataset): Dataset.
|
||||||
|
output_dir (str): Output directory.
|
||||||
|
load_dir (str): Load directory.
|
||||||
|
step (int): Step to load.
|
||||||
|
batch_size (int): Batch size.
|
||||||
|
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||||
|
batch_split (int): Split batch with gradient accumulation.
|
||||||
|
max_steps (int): Max steps.
|
||||||
|
optimizer (dict): Optimizer config.
|
||||||
|
lr_scheduler (dict): Learning rate scheduler config.
|
||||||
|
elastic (dict): Elastic memory management config.
|
||||||
|
grad_clip (float or dict): Gradient clip config.
|
||||||
|
ema_rate (float or list): Exponential moving average rates.
|
||||||
|
fp16_mode (str): FP16 mode.
|
||||||
|
- None: No FP16.
|
||||||
|
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||||
|
- 'amp': Automatic mixed precision.
|
||||||
|
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||||
|
finetune_ckpt (dict): Finetune checkpoint.
|
||||||
|
log_param_stats (bool): Log parameter stats.
|
||||||
|
i_print (int): Print interval.
|
||||||
|
i_log (int): Log interval.
|
||||||
|
i_sample (int): Sample interval.
|
||||||
|
i_save (int): Save interval.
|
||||||
|
i_ddpcheck (int): DDP check interval.
|
||||||
|
|
||||||
|
loss_type (str): Loss type. Can be 'l1', 'l2'
|
||||||
|
lambda_ssim (float): SSIM loss weight.
|
||||||
|
lambda_lpips (float): LPIPS loss weight.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
depth_loss_type: str = 'l1',
|
||||||
|
lambda_depth: int = 1,
|
||||||
|
lambda_ssim: float = 0.2,
|
||||||
|
lambda_lpips: float = 0.2,
|
||||||
|
lambda_tsdf: float = 0.01,
|
||||||
|
lambda_color: float = 0.1,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.depth_loss_type = depth_loss_type
|
||||||
|
self.lambda_depth = lambda_depth
|
||||||
|
self.lambda_ssim = lambda_ssim
|
||||||
|
self.lambda_lpips = lambda_lpips
|
||||||
|
self.lambda_tsdf = lambda_tsdf
|
||||||
|
self.lambda_color = lambda_color
|
||||||
|
self.use_color = self.lambda_color > 0
|
||||||
|
|
||||||
|
self._init_renderer()
|
||||||
|
|
||||||
|
def _init_renderer(self):
|
||||||
|
rendering_options = {"near" : 1,
|
||||||
|
"far" : 3}
|
||||||
|
self.renderer = MeshRenderer(rendering_options, device=self.device)
|
||||||
|
|
||||||
|
def _render_batch(self, reps: List[MeshExtractResult], extrinsics: torch.Tensor, intrinsics: torch.Tensor,
|
||||||
|
return_types=['mask', 'normal', 'depth']) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Render a batch of representations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reps: The dictionary of lists of representations.
|
||||||
|
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
||||||
|
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
||||||
|
return_types: vary in ['mask', 'normal', 'depth', 'normal_map', 'color']
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a dict with
|
||||||
|
reg_loss : [N] tensor of regularization losses
|
||||||
|
mask : [N x 1 x H x W] tensor of rendered masks
|
||||||
|
normal : [N x 3 x H x W] tensor of rendered normals
|
||||||
|
depth : [N x 1 x H x W] tensor of rendered depths
|
||||||
|
"""
|
||||||
|
ret = {k : [] for k in return_types}
|
||||||
|
for i, rep in enumerate(reps):
|
||||||
|
out_dict = self.renderer.render(rep, extrinsics[i], intrinsics[i], return_types=return_types)
|
||||||
|
for k in out_dict:
|
||||||
|
ret[k].append(out_dict[k][None] if k in ['mask', 'depth'] else out_dict[k])
|
||||||
|
for k in ret:
|
||||||
|
ret[k] = torch.stack(ret[k])
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tsdf_reg_loss(rep: MeshExtractResult, depth_map: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Calculate tsdf
|
||||||
|
with torch.no_grad():
|
||||||
|
# Project points to camera and calculate pseudo-sdf as difference between gt depth and projected depth
|
||||||
|
projected_pts, pts_depth = utils3d.torch.project_cv(extrinsics=extrinsics, intrinsics=intrinsics, points=rep.tsdf_v)
|
||||||
|
projected_pts = (projected_pts - 0.5) * 2.0
|
||||||
|
depth_map_res = depth_map.shape[1]
|
||||||
|
gt_depth = torch.nn.functional.grid_sample(depth_map.reshape(1, 1, depth_map_res, depth_map_res),
|
||||||
|
projected_pts.reshape(1, 1, -1, 2), mode='bilinear', padding_mode='border', align_corners=True)
|
||||||
|
pseudo_sdf = gt_depth.flatten() - pts_depth.flatten()
|
||||||
|
# Truncate pseudo-sdf
|
||||||
|
delta = 1 / rep.res * 3.0
|
||||||
|
trunc_mask = pseudo_sdf > -delta
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
gt_tsdf = pseudo_sdf[trunc_mask]
|
||||||
|
tsdf = rep.tsdf_s.flatten()[trunc_mask]
|
||||||
|
gt_tsdf = torch.clamp(gt_tsdf, -delta, delta)
|
||||||
|
return torch.mean((tsdf - gt_tsdf) ** 2)
|
||||||
|
|
||||||
|
def _calc_tsdf_loss(self, reps : list[MeshExtractResult], depth_maps, extrinsics, intrinsics) -> torch.Tensor:
|
||||||
|
tsdf_loss = 0.0
|
||||||
|
for i, rep in enumerate(reps):
|
||||||
|
tsdf_loss += self._tsdf_reg_loss(rep, depth_maps[i], extrinsics[i], intrinsics[i])
|
||||||
|
return tsdf_loss / len(reps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _flip_normal(self, normal: torch.Tensor, extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Flip normal to align with camera.
|
||||||
|
"""
|
||||||
|
normal = normal * 2.0 - 1.0
|
||||||
|
R = torch.zeros_like(extrinsics)
|
||||||
|
R[:, :3, :3] = extrinsics[:, :3, :3]
|
||||||
|
R[:, 3, 3] = 1.0
|
||||||
|
view_dir = utils3d.torch.unproject_cv(
|
||||||
|
utils3d.torch.image_uv(*normal.shape[-2:], device=self.device).reshape(1, -1, 2),
|
||||||
|
torch.ones(*normal.shape[-2:], device=self.device).reshape(1, -1),
|
||||||
|
R, intrinsics
|
||||||
|
).reshape(-1, *normal.shape[-2:], 3).permute(0, 3, 1, 2)
|
||||||
|
unflip = (normal * view_dir).sum(1, keepdim=True) < 0
|
||||||
|
normal *= unflip * 2.0 - 1.0
|
||||||
|
return (normal + 1.0) / 2.0
|
||||||
|
|
||||||
|
def _perceptual_loss(self, gt: torch.Tensor, pred: torch.Tensor, name: str) -> Dict[str, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Combination of L1, SSIM, and LPIPS loss.
|
||||||
|
"""
|
||||||
|
if gt.shape[1] != 3:
|
||||||
|
assert gt.shape[-1] == 3
|
||||||
|
gt = gt.permute(0, 3, 1, 2)
|
||||||
|
if pred.shape[1] != 3:
|
||||||
|
assert pred.shape[-1] == 3
|
||||||
|
pred = pred.permute(0, 3, 1, 2)
|
||||||
|
terms = {
|
||||||
|
f"{name}_loss" : l1_loss(gt, pred),
|
||||||
|
f"{name}_loss_ssim" : 1 - ssim(gt, pred),
|
||||||
|
f"{name}_loss_lpips" : lpips(gt, pred)
|
||||||
|
}
|
||||||
|
terms[f"{name}_loss_perceptual"] = terms[f"{name}_loss"] + terms[f"{name}_loss_ssim"] * self.lambda_ssim + terms[f"{name}_loss_lpips"] * self.lambda_lpips
|
||||||
|
return terms
|
||||||
|
|
||||||
|
def geometry_losses(
|
||||||
|
self,
|
||||||
|
reps: List[MeshExtractResult],
|
||||||
|
mesh: List[Dict],
|
||||||
|
normal_map: torch.Tensor,
|
||||||
|
extrinsics: torch.Tensor,
|
||||||
|
intrinsics: torch.Tensor,
|
||||||
|
):
|
||||||
|
with torch.no_grad():
|
||||||
|
gt_meshes = []
|
||||||
|
for i in range(len(reps)):
|
||||||
|
gt_mesh = MeshExtractResult(mesh[i]['vertices'].to(self.device), mesh[i]['faces'].to(self.device))
|
||||||
|
gt_meshes.append(gt_mesh)
|
||||||
|
target = self._render_batch(gt_meshes, extrinsics, intrinsics, return_types=['mask', 'depth', 'normal'])
|
||||||
|
target['normal'] = self._flip_normal(target['normal'], extrinsics, intrinsics)
|
||||||
|
|
||||||
|
terms = edict(geo_loss = 0.0)
|
||||||
|
if self.lambda_tsdf > 0:
|
||||||
|
tsdf_loss = self._calc_tsdf_loss(reps, target['depth'], extrinsics, intrinsics)
|
||||||
|
terms['tsdf_loss'] = tsdf_loss
|
||||||
|
terms['geo_loss'] += tsdf_loss * self.lambda_tsdf
|
||||||
|
|
||||||
|
return_types = ['mask', 'depth', 'normal', 'normal_map'] if self.use_color else ['mask', 'depth', 'normal']
|
||||||
|
buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
|
||||||
|
|
||||||
|
success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
|
||||||
|
if success_mask.sum() != 0:
|
||||||
|
for k, v in buffer.items():
|
||||||
|
buffer[k] = v[success_mask]
|
||||||
|
for k, v in target.items():
|
||||||
|
target[k] = v[success_mask]
|
||||||
|
|
||||||
|
terms['mask_loss'] = l1_loss(buffer['mask'], target['mask'])
|
||||||
|
if self.depth_loss_type == 'l1':
|
||||||
|
terms['depth_loss'] = l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'])
|
||||||
|
elif self.depth_loss_type == 'smooth_l1':
|
||||||
|
terms['depth_loss'] = smooth_l1_loss(buffer['depth'] * target['mask'], target['depth'] * target['mask'], beta=1.0 / (2 * reps[0].res))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported depth loss type: {self.depth_loss_type}")
|
||||||
|
terms.update(self._perceptual_loss(buffer['normal'] * target['mask'], target['normal'] * target['mask'], 'normal'))
|
||||||
|
terms['geo_loss'] = terms['geo_loss'] + terms['mask_loss'] + terms['depth_loss'] * self.lambda_depth + terms['normal_loss_perceptual']
|
||||||
|
if self.use_color and normal_map is not None:
|
||||||
|
terms.update(self._perceptual_loss(normal_map[success_mask], buffer['normal_map'], 'normal_map'))
|
||||||
|
terms['geo_loss'] = terms['geo_loss'] + terms['normal_map_loss_perceptual'] * self.lambda_color
|
||||||
|
|
||||||
|
return terms
|
||||||
|
|
||||||
|
def color_losses(self, reps, image, alpha, extrinsics, intrinsics):
|
||||||
|
terms = edict(color_loss = torch.tensor(0.0, device=self.device))
|
||||||
|
buffer = self._render_batch(reps, extrinsics, intrinsics, return_types=['color'])
|
||||||
|
success_mask = torch.tensor([rep.success for rep in reps], device=self.device)
|
||||||
|
if success_mask.sum() != 0:
|
||||||
|
terms.update(self._perceptual_loss((image * alpha[:, None])[success_mask], buffer['color'][success_mask], 'color'))
|
||||||
|
terms['color_loss'] = terms['color_loss'] + terms['color_loss_perceptual'] * self.lambda_color
|
||||||
|
return terms
|
||||||
|
|
||||||
|
def training_losses(
|
||||||
|
self,
|
||||||
|
latents: SparseTensor,
|
||||||
|
image: torch.Tensor,
|
||||||
|
alpha: torch.Tensor,
|
||||||
|
mesh: List[Dict],
|
||||||
|
extrinsics: torch.Tensor,
|
||||||
|
intrinsics: torch.Tensor,
|
||||||
|
normal_map: torch.Tensor = None,
|
||||||
|
) -> Tuple[Dict, Dict]:
|
||||||
|
"""
|
||||||
|
Compute training losses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latents: The [N x * x C] sparse latents
|
||||||
|
image: The [N x 3 x H x W] tensor of images.
|
||||||
|
alpha: The [N x H x W] tensor of alpha channels.
|
||||||
|
mesh: The list of dictionaries of meshes.
|
||||||
|
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
||||||
|
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a dict with the key "loss" containing a scalar tensor.
|
||||||
|
may also contain other keys for different terms.
|
||||||
|
"""
|
||||||
|
reps = self.training_models['decoder'](latents)
|
||||||
|
self.renderer.rendering_options.resolution = image.shape[-1]
|
||||||
|
|
||||||
|
terms = edict(loss = 0.0, rec = 0.0)
|
||||||
|
|
||||||
|
terms['reg_loss'] = sum([rep.reg_loss for rep in reps]) / len(reps)
|
||||||
|
terms['loss'] = terms['loss'] + terms['reg_loss']
|
||||||
|
|
||||||
|
geo_terms = self.geometry_losses(reps, mesh, normal_map, extrinsics, intrinsics)
|
||||||
|
terms.update(geo_terms)
|
||||||
|
terms['loss'] = terms['loss'] + terms['geo_loss']
|
||||||
|
|
||||||
|
if self.use_color:
|
||||||
|
color_terms = self.color_losses(reps, image, alpha, extrinsics, intrinsics)
|
||||||
|
terms.update(color_terms)
|
||||||
|
terms['loss'] = terms['loss'] + terms['color_loss']
|
||||||
|
|
||||||
|
return terms, {}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def run_snapshot(
|
||||||
|
self,
|
||||||
|
num_samples: int,
|
||||||
|
batch_size: int,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> Dict:
|
||||||
|
dataloader = DataLoader(
|
||||||
|
copy.deepcopy(self.dataset),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0,
|
||||||
|
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# inference
|
||||||
|
ret_dict = {}
|
||||||
|
gt_images = []
|
||||||
|
gt_normal_maps = []
|
||||||
|
gt_meshes = []
|
||||||
|
exts = []
|
||||||
|
ints = []
|
||||||
|
reps = []
|
||||||
|
for i in range(0, num_samples, batch_size):
|
||||||
|
batch = min(batch_size, num_samples - i)
|
||||||
|
data = next(iter(dataloader))
|
||||||
|
args = recursive_to_device(data, 'cuda')
|
||||||
|
gt_images.append(args['image'] * args['alpha'][:, None])
|
||||||
|
if self.use_color and 'normal_map' in data:
|
||||||
|
gt_normal_maps.append(args['normal_map'])
|
||||||
|
gt_meshes.extend(args['mesh'])
|
||||||
|
exts.append(args['extrinsics'])
|
||||||
|
ints.append(args['intrinsics'])
|
||||||
|
reps.extend(self.models['decoder'](args['latents']))
|
||||||
|
gt_images = torch.cat(gt_images, dim=0)
|
||||||
|
ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
|
||||||
|
if self.use_color and gt_normal_maps:
|
||||||
|
gt_normal_maps = torch.cat(gt_normal_maps, dim=0)
|
||||||
|
ret_dict.update({f'gt_normal_map': {'value': gt_normal_maps, 'type': 'image'}})
|
||||||
|
|
||||||
|
# render single view
|
||||||
|
exts = torch.cat(exts, dim=0)
|
||||||
|
ints = torch.cat(ints, dim=0)
|
||||||
|
self.renderer.rendering_options.bg_color = (0, 0, 0)
|
||||||
|
self.renderer.rendering_options.resolution = gt_images.shape[-1]
|
||||||
|
gt_render_results = self._render_batch([
|
||||||
|
MeshExtractResult(vertices=mesh['vertices'].to(self.device), faces=mesh['faces'].to(self.device))
|
||||||
|
for mesh in gt_meshes
|
||||||
|
], exts, ints, return_types=['normal'])
|
||||||
|
ret_dict.update({f'gt_normal': {'value': self._flip_normal(gt_render_results['normal'], exts, ints), 'type': 'image'}})
|
||||||
|
return_types = ['normal']
|
||||||
|
if self.use_color:
|
||||||
|
return_types.append('color')
|
||||||
|
if 'normal_map' in data:
|
||||||
|
return_types.append('normal_map')
|
||||||
|
render_results = self._render_batch(reps, exts, ints, return_types=return_types)
|
||||||
|
ret_dict.update({f'rec_normal': {'value': render_results['normal'], 'type': 'image'}})
|
||||||
|
if 'color' in return_types:
|
||||||
|
ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
|
||||||
|
if 'normal_map' in return_types:
|
||||||
|
ret_dict.update({f'rec_normal_map': {'value': render_results['normal_map'], 'type': 'image'}})
|
||||||
|
|
||||||
|
# render multiview
|
||||||
|
self.renderer.rendering_options.resolution = 512
|
||||||
|
## Build camera
|
||||||
|
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
||||||
|
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
||||||
|
yaws = [y + yaws_offset for y in yaws]
|
||||||
|
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
||||||
|
|
||||||
|
## render each view
|
||||||
|
multiview_normals = []
|
||||||
|
multiview_normal_maps = []
|
||||||
|
miltiview_images = []
|
||||||
|
for yaw, pitch in zip(yaws, pitch):
|
||||||
|
orig = torch.tensor([
|
||||||
|
np.sin(yaw) * np.cos(pitch),
|
||||||
|
np.cos(yaw) * np.cos(pitch),
|
||||||
|
np.sin(pitch),
|
||||||
|
]).float().cuda() * 2
|
||||||
|
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
||||||
|
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
||||||
|
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
||||||
|
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
||||||
|
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
||||||
|
render_results = self._render_batch(reps, extrinsics, intrinsics, return_types=return_types)
|
||||||
|
multiview_normals.append(render_results['normal'])
|
||||||
|
if 'color' in return_types:
|
||||||
|
miltiview_images.append(render_results['color'])
|
||||||
|
if 'normal_map' in return_types:
|
||||||
|
multiview_normal_maps.append(render_results['normal_map'])
|
||||||
|
|
||||||
|
## Concatenate views
|
||||||
|
multiview_normals = torch.cat([
|
||||||
|
torch.cat(multiview_normals[:2], dim=-2),
|
||||||
|
torch.cat(multiview_normals[2:], dim=-2),
|
||||||
|
], dim=-1)
|
||||||
|
ret_dict.update({f'multiview_normal': {'value': multiview_normals, 'type': 'image'}})
|
||||||
|
if 'color' in return_types:
|
||||||
|
miltiview_images = torch.cat([
|
||||||
|
torch.cat(miltiview_images[:2], dim=-2),
|
||||||
|
torch.cat(miltiview_images[2:], dim=-2),
|
||||||
|
], dim=-1)
|
||||||
|
ret_dict.update({f'multiview_image': {'value': miltiview_images, 'type': 'image'}})
|
||||||
|
if 'normal_map' in return_types:
|
||||||
|
multiview_normal_maps = torch.cat([
|
||||||
|
torch.cat(multiview_normal_maps[:2], dim=-2),
|
||||||
|
torch.cat(multiview_normal_maps[2:], dim=-2),
|
||||||
|
], dim=-1)
|
||||||
|
ret_dict.update({f'multiview_normal_map': {'value': multiview_normal_maps, 'type': 'image'}})
|
||||||
|
|
||||||
|
return ret_dict
|
||||||
223
trellis/trainers/vae/structured_latent_vae_rf_dec.py
Normal file
223
trellis/trainers/vae/structured_latent_vae_rf_dec.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
from typing import *
|
||||||
|
import copy
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import numpy as np
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
import utils3d.torch
|
||||||
|
|
||||||
|
from ..basic import BasicTrainer
|
||||||
|
from ...representations import Strivec
|
||||||
|
from ...renderers import OctreeRenderer
|
||||||
|
from ...modules.sparse import SparseTensor
|
||||||
|
from ...utils.loss_utils import l1_loss, l2_loss, ssim, lpips
|
||||||
|
|
||||||
|
|
||||||
|
class SLatVaeRadianceFieldDecoderTrainer(BasicTrainer):
|
||||||
|
"""
|
||||||
|
Trainer for structured latent VAE Radiance Field Decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
models (dict[str, nn.Module]): Models to train.
|
||||||
|
dataset (torch.utils.data.Dataset): Dataset.
|
||||||
|
output_dir (str): Output directory.
|
||||||
|
load_dir (str): Load directory.
|
||||||
|
step (int): Step to load.
|
||||||
|
batch_size (int): Batch size.
|
||||||
|
batch_size_per_gpu (int): Batch size per GPU. If specified, batch_size will be ignored.
|
||||||
|
batch_split (int): Split batch with gradient accumulation.
|
||||||
|
max_steps (int): Max steps.
|
||||||
|
optimizer (dict): Optimizer config.
|
||||||
|
lr_scheduler (dict): Learning rate scheduler config.
|
||||||
|
elastic (dict): Elastic memory management config.
|
||||||
|
grad_clip (float or dict): Gradient clip config.
|
||||||
|
ema_rate (float or list): Exponential moving average rates.
|
||||||
|
fp16_mode (str): FP16 mode.
|
||||||
|
- None: No FP16.
|
||||||
|
- 'inflat_all': Hold a inflated fp32 master param for all params.
|
||||||
|
- 'amp': Automatic mixed precision.
|
||||||
|
fp16_scale_growth (float): Scale growth for FP16 gradient backpropagation.
|
||||||
|
finetune_ckpt (dict): Finetune checkpoint.
|
||||||
|
log_param_stats (bool): Log parameter stats.
|
||||||
|
i_print (int): Print interval.
|
||||||
|
i_log (int): Log interval.
|
||||||
|
i_sample (int): Sample interval.
|
||||||
|
i_save (int): Save interval.
|
||||||
|
i_ddpcheck (int): DDP check interval.
|
||||||
|
|
||||||
|
loss_type (str): Loss type. Can be 'l1', 'l2'
|
||||||
|
lambda_ssim (float): SSIM loss weight.
|
||||||
|
lambda_lpips (float): LPIPS loss weight.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
loss_type: str = 'l1',
|
||||||
|
lambda_ssim: float = 0.2,
|
||||||
|
lambda_lpips: float = 0.2,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.loss_type = loss_type
|
||||||
|
self.lambda_ssim = lambda_ssim
|
||||||
|
self.lambda_lpips = lambda_lpips
|
||||||
|
|
||||||
|
self._init_renderer()
|
||||||
|
|
||||||
|
def _init_renderer(self):
|
||||||
|
rendering_options = {"near" : 0.8,
|
||||||
|
"far" : 1.6,
|
||||||
|
"bg_color" : 'random'}
|
||||||
|
self.renderer = OctreeRenderer(rendering_options)
|
||||||
|
self.renderer.pipe.primitive = 'trivec'
|
||||||
|
|
||||||
|
def _render_batch(self, reps: List[Strivec], extrinsics: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Render a batch of representations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reps: The dictionary of lists of representations.
|
||||||
|
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
||||||
|
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
||||||
|
"""
|
||||||
|
ret = None
|
||||||
|
for i, representation in enumerate(reps):
|
||||||
|
render_pack = self.renderer.render(representation, extrinsics[i], intrinsics[i])
|
||||||
|
if ret is None:
|
||||||
|
ret = {k: [] for k in list(render_pack.keys()) + ['bg_color']}
|
||||||
|
for k, v in render_pack.items():
|
||||||
|
ret[k].append(v)
|
||||||
|
ret['bg_color'].append(self.renderer.bg_color)
|
||||||
|
for k, v in ret.items():
|
||||||
|
ret[k] = torch.stack(v, dim=0)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def training_losses(
|
||||||
|
self,
|
||||||
|
latents: SparseTensor,
|
||||||
|
image: torch.Tensor,
|
||||||
|
alpha: torch.Tensor,
|
||||||
|
extrinsics: torch.Tensor,
|
||||||
|
intrinsics: torch.Tensor,
|
||||||
|
return_aux: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Tuple[Dict, Dict]:
|
||||||
|
"""
|
||||||
|
Compute training losses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latents: The [N x * x C] sparse latents
|
||||||
|
image: The [N x 3 x H x W] tensor of images.
|
||||||
|
alpha: The [N x H x W] tensor of alpha channels.
|
||||||
|
extrinsics: The [N x 4 x 4] tensor of extrinsics.
|
||||||
|
intrinsics: The [N x 3 x 3] tensor of intrinsics.
|
||||||
|
return_aux: Whether to return auxiliary information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a dict with the key "loss" containing a scalar tensor.
|
||||||
|
may also contain other keys for different terms.
|
||||||
|
"""
|
||||||
|
reps = self.training_models['decoder'](latents)
|
||||||
|
self.renderer.rendering_options.resolution = image.shape[-1]
|
||||||
|
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
||||||
|
|
||||||
|
terms = edict(loss = 0.0, rec = 0.0)
|
||||||
|
|
||||||
|
rec_image = render_results['color']
|
||||||
|
gt_image = image * alpha[:, None] + (1 - alpha[:, None]) * render_results['bg_color'][..., None, None]
|
||||||
|
|
||||||
|
if self.loss_type == 'l1':
|
||||||
|
terms["l1"] = l1_loss(rec_image, gt_image)
|
||||||
|
terms["rec"] = terms["rec"] + terms["l1"]
|
||||||
|
elif self.loss_type == 'l2':
|
||||||
|
terms["l2"] = l2_loss(rec_image, gt_image)
|
||||||
|
terms["rec"] = terms["rec"] + terms["l2"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid loss type: {self.loss_type}")
|
||||||
|
if self.lambda_ssim > 0:
|
||||||
|
terms["ssim"] = 1 - ssim(rec_image, gt_image)
|
||||||
|
terms["rec"] = terms["rec"] + self.lambda_ssim * terms["ssim"]
|
||||||
|
if self.lambda_lpips > 0:
|
||||||
|
terms["lpips"] = lpips(rec_image, gt_image)
|
||||||
|
terms["rec"] = terms["rec"] + self.lambda_lpips * terms["lpips"]
|
||||||
|
terms["loss"] = terms["loss"] + terms["rec"]
|
||||||
|
|
||||||
|
if return_aux:
|
||||||
|
return terms, {}, {'rec_image': rec_image, 'gt_image': gt_image}
|
||||||
|
return terms, {}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def run_snapshot(
|
||||||
|
self,
|
||||||
|
num_samples: int,
|
||||||
|
batch_size: int,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> Dict:
|
||||||
|
dataloader = DataLoader(
|
||||||
|
copy.deepcopy(self.dataset),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0,
|
||||||
|
collate_fn=self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# inference
|
||||||
|
ret_dict = {}
|
||||||
|
gt_images = []
|
||||||
|
exts = []
|
||||||
|
ints = []
|
||||||
|
reps = []
|
||||||
|
for i in range(0, num_samples, batch_size):
|
||||||
|
batch = min(batch_size, num_samples - i)
|
||||||
|
data = next(iter(dataloader))
|
||||||
|
args = {k: v[:batch].cuda() for k, v in data.items()}
|
||||||
|
gt_images.append(args['image'] * args['alpha'][:, None])
|
||||||
|
exts.append(args['extrinsics'])
|
||||||
|
ints.append(args['intrinsics'])
|
||||||
|
reps.extend(self.models['decoder'](args['latents']))
|
||||||
|
gt_images = torch.cat(gt_images, dim=0)
|
||||||
|
ret_dict.update({f'gt_image': {'value': gt_images, 'type': 'image'}})
|
||||||
|
|
||||||
|
# render single view
|
||||||
|
exts = torch.cat(exts, dim=0)
|
||||||
|
ints = torch.cat(ints, dim=0)
|
||||||
|
self.renderer.rendering_options.bg_color = (0, 0, 0)
|
||||||
|
self.renderer.rendering_options.resolution = gt_images.shape[-1]
|
||||||
|
render_results = self._render_batch(reps, exts, ints)
|
||||||
|
ret_dict.update({f'rec_image': {'value': render_results['color'], 'type': 'image'}})
|
||||||
|
|
||||||
|
# render multiview
|
||||||
|
self.renderer.rendering_options.resolution = 512
|
||||||
|
## Build camera
|
||||||
|
yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
|
||||||
|
yaws_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
|
||||||
|
yaws = [y + yaws_offset for y in yaws]
|
||||||
|
pitch = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
|
||||||
|
|
||||||
|
## render each view
|
||||||
|
miltiview_images = []
|
||||||
|
for yaw, pitch in zip(yaws, pitch):
|
||||||
|
orig = torch.tensor([
|
||||||
|
np.sin(yaw) * np.cos(pitch),
|
||||||
|
np.cos(yaw) * np.cos(pitch),
|
||||||
|
np.sin(pitch),
|
||||||
|
]).float().cuda() * 2
|
||||||
|
fov = torch.deg2rad(torch.tensor(30)).cuda()
|
||||||
|
extrinsics = utils3d.torch.extrinsics_look_at(orig, torch.tensor([0, 0, 0]).float().cuda(), torch.tensor([0, 0, 1]).float().cuda())
|
||||||
|
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov, fov)
|
||||||
|
extrinsics = extrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
||||||
|
intrinsics = intrinsics.unsqueeze(0).expand(num_samples, -1, -1)
|
||||||
|
render_results = self._render_batch(reps, extrinsics, intrinsics)
|
||||||
|
miltiview_images.append(render_results['color'])
|
||||||
|
|
||||||
|
## Concatenate views
|
||||||
|
miltiview_images = torch.cat([
|
||||||
|
torch.cat(miltiview_images[:2], dim=-2),
|
||||||
|
torch.cat(miltiview_images[2:], dim=-2),
|
||||||
|
], dim=-1)
|
||||||
|
ret_dict.update({f'miltiview_image': {'value': miltiview_images, 'type': 'image'}})
|
||||||
|
|
||||||
|
self.renderer.rendering_options.bg_color = 'random'
|
||||||
|
|
||||||
|
return ret_dict
|
||||||
140
utils/new_oss_client.py
Normal file
140
utils/new_oss_client.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
from io import BytesIO
|
||||||
|
import urllib3
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
MINIO_URL = "www.minio-api.aida.com.hk"
|
||||||
|
MINIO_ACCESS = "vXKFLSJkYeEq2DrSZvkB"
|
||||||
|
MINIO_SECRET = "uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR"
|
||||||
|
MINIO_SECURE = True
|
||||||
|
MINIO_BUCKET = "test"
|
||||||
|
|
||||||
|
from minio import Minio
|
||||||
|
import urllib3
|
||||||
|
|
||||||
|
|
||||||
|
class CustomRetry(urllib3.Retry):
|
||||||
|
def increment(self, method=None, url=None, response=None, error=None, **kwargs):
|
||||||
|
# 调用父类的 increment 方法
|
||||||
|
new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
|
||||||
|
# 打印重试信息
|
||||||
|
logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
|
||||||
|
return new_retry
|
||||||
|
|
||||||
|
|
||||||
|
http_client = urllib3.PoolManager(
|
||||||
|
num_pools=20,
|
||||||
|
maxsize=50,
|
||||||
|
timeout=urllib3.Timeout(connect=2, read=30),
|
||||||
|
cert_reqs='CERT_REQUIRED', # 需要证书验证
|
||||||
|
retries=CustomRetry(
|
||||||
|
total=5,
|
||||||
|
backoff_factor=0.2,
|
||||||
|
status_forcelist=[500, 502, 503, 504],
|
||||||
|
),
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
minio_client = Minio(
|
||||||
|
MINIO_URL,
|
||||||
|
access_key=MINIO_ACCESS,
|
||||||
|
secret_key=MINIO_SECRET,
|
||||||
|
secure=MINIO_SECURE,
|
||||||
|
http_client=http_client
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def minio_get_image(client, bucket, object_name):
|
||||||
|
response = None
|
||||||
|
try:
|
||||||
|
response = client.get_object(bucket, object_name)
|
||||||
|
# 直接读取 bytes
|
||||||
|
image_bytes = response.read()
|
||||||
|
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
||||||
|
return image
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取 MinIO 图片失败: {bucket}/{object_name} | {e}")
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
if response:
|
||||||
|
response.close()
|
||||||
|
response.release_conn()
|
||||||
|
|
||||||
|
|
||||||
|
def upload_bytes(data_bytes, object_name, content_type):
|
||||||
|
minio_client.put_object(
|
||||||
|
bucket_name=MINIO_BUCKET,
|
||||||
|
object_name=object_name,
|
||||||
|
data=BytesIO(data_bytes),
|
||||||
|
length=len(data_bytes),
|
||||||
|
content_type=content_type
|
||||||
|
)
|
||||||
|
return f"{MINIO_BUCKET}/{object_name}"
|
||||||
|
|
||||||
|
|
||||||
|
def upload_local_file(file_path, object_name, content_type="application/octet-stream"):
|
||||||
|
"""
|
||||||
|
将本地磁盘上的文件上传到 MinIO
|
||||||
|
:param file_path: 本地文件路径 (如: 'output/sample.obj')
|
||||||
|
:param object_name: MinIO 中的存储路径/文件名
|
||||||
|
:param content_type: 文件 MIME 类型
|
||||||
|
"""
|
||||||
|
# 健壮性检查:确保本地文件确实存在
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"本地文件未找到: {file_path}")
|
||||||
|
|
||||||
|
# 使用 fput_object 直接从磁盘流式上传,不占用过多内存
|
||||||
|
minio_client.fput_object(
|
||||||
|
bucket_name=MINIO_BUCKET,
|
||||||
|
object_name=object_name,
|
||||||
|
file_path=file_path,
|
||||||
|
content_type=content_type
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"{MINIO_BUCKET}/{object_name}"
|
||||||
|
|
||||||
|
|
||||||
|
def download_from_minio(object_path, local_path):
|
||||||
|
"""
|
||||||
|
从 MinIO 下载文件到本地磁盘
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from minio.error import S3Error
|
||||||
|
|
||||||
|
local_dir = os.path.dirname(local_path)
|
||||||
|
if local_dir:
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
|
||||||
|
path_parts = object_path.split("/", 1)
|
||||||
|
bucket_name = path_parts[0]
|
||||||
|
object_name = path_parts[1]
|
||||||
|
try:
|
||||||
|
minio_client.fget_object(
|
||||||
|
bucket_name=bucket_name,
|
||||||
|
object_name=object_name,
|
||||||
|
file_path=local_path
|
||||||
|
)
|
||||||
|
if os.path.exists(local_path):
|
||||||
|
return local_path
|
||||||
|
else:
|
||||||
|
raise RuntimeError("下载文件后本地路径不存在")
|
||||||
|
except S3Error as err:
|
||||||
|
if err.code == 'NoSuchKey':
|
||||||
|
raise FileNotFoundError(f"MinIO 对象不存在: {bucket_name}/{object_name}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"MinIO 下载失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
url = "fida-test/furniture/sketches/4449a66d-6267-43f7-86a2-1e42bd19ec61.png"
|
||||||
|
read_type = "2"
|
||||||
|
img = minio_get_image(client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:])
|
||||||
|
img.show()
|
||||||
|
img.save("result.png")
|
||||||
Reference in New Issue
Block a user