1
This commit is contained in:
15
trellis/pipelines/samplers/guidance_interval_mixin.py
Normal file
15
trellis/pipelines/samplers/guidance_interval_mixin.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import *
|
||||
|
||||
|
||||
class GuidanceIntervalSamplerMixin:
    """
    A mixin class for samplers that apply classifier-free guidance with interval.
    """

    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
        """Run the model, applying CFG only when t falls inside cfg_interval.

        Outside the interval the plain conditional prediction is returned.
        """
        lo, hi = cfg_interval
        if not (lo <= t <= hi):
            # Outside the guidance interval: no negative pass needed.
            return super()._inference_model(model, x_t, t, cond, **kwargs)
        pred = super()._inference_model(model, x_t, t, cond, **kwargs)
        neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
        # Classifier-free guidance extrapolation away from the negative prediction.
        return (1 + cfg_strength) * pred - cfg_strength * neg_pred
|
||||
231
trellis/renderers/gaussian_render.py
Normal file
231
trellis/renderers/gaussian_render.py
Normal file
@@ -0,0 +1,231 @@
|
||||
#
|
||||
# Copyright (C) 2023, Inria
|
||||
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
||||
# All rights reserved.
|
||||
#
|
||||
# This software is free for non-commercial, research and evaluation use
|
||||
# under the terms of the LICENSE.md file.
|
||||
#
|
||||
# For inquiries contact george.drettakis@inria.fr
|
||||
#
|
||||
|
||||
import torch
|
||||
import math
|
||||
from easydict import EasyDict as edict
|
||||
import numpy as np
|
||||
from ..representations.gaussian import Gaussian
|
||||
from .sh_utils import eval_sh
|
||||
import torch.nn.functional as F
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
|
||||
def intrinsics_to_projection(
    intrinsics: torch.Tensor,
    near: float,
    far: float,
) -> torch.Tensor:
    """
    OpenCV intrinsics to OpenGL perspective matrix

    Args:
        intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
        near (float): near plane to clip
        far (float): far plane to clip
    Returns:
        (torch.Tensor): [4, 4] OpenGL perspective matrix
    """
    focal_x, focal_y = intrinsics[0, 0], intrinsics[1, 1]
    center_x, center_y = intrinsics[0, 2], intrinsics[1, 2]
    proj = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    # Focal lengths scale x/y into clip space.
    proj[0, 0] = 2 * focal_x
    proj[1, 1] = 2 * focal_y
    # Principal-point offsets; y is negated for the OpenGL convention.
    proj[0, 2] = 2 * center_x - 1
    proj[1, 2] = 1 - 2 * center_y
    # Depth remapping terms for the [near, far] clip range.
    proj[2, 2] = far / (far - near)
    proj[2, 3] = near * far / (near - far)
    proj[3, 2] = 1.
    return proj
|
||||
|
||||
|
||||
def render(viewpoint_camera, pc : Gaussian, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
    """
    Render the scene.

    Background tensor (bg_color) must be on GPU!

    Args:
        viewpoint_camera: object exposing FoVx/FoVy, image_height/image_width,
            world_view_transform, full_proj_transform and camera_center.
        pc (Gaussian): the Gaussian point cloud to rasterize.
        pipe: pipeline options (kernel_size, convert_SHs_python,
            compute_cov3D_python, debug are read here).
        bg_color (torch.Tensor): (3,) background color.
        scaling_modifier (float): global multiplier applied to Gaussian scales.
        override_color: optional (N, 3) precomputed colors replacing SH shading.

    Returns:
        edict with keys: render, viewspace_points, visibility_filter, radii.
    """
    # lazy import
    if 'GaussianRasterizer' not in globals():
        from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except:
        # retain_grad() can fail (e.g. under no_grad); rendering is still valid,
        # we just lose screen-space gradients.
        pass
    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    kernel_size = pipe.kernel_size
    # Zero per-pixel jitter: no subpixel anti-aliasing offset.
    subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        kernel_size=kernel_size,
        subpixel_offset=subpixel_offset,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.convert_SHs_python:
            # Evaluate spherical harmonics along the per-point viewing direction.
            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            # Shift from the SH zero-mean convention and clamp to non-negative.
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_features
    else:
        colors_precomp = override_color

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    rendered_image, radii = rasterizer(
        means3D = means3D,
        means2D = means2D,
        shs = shs,
        colors_precomp = colors_precomp,
        opacities = opacity,
        scales = scales,
        rotations = rotations,
        cov3D_precomp = cov3D_precomp
    )

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return edict({"render": rendered_image,
                  "viewspace_points": screenspace_points,
                  "visibility_filter" : radii > 0,
                  "radii": radii})
|
||||
|
||||
|
||||
class GaussianRenderer:
    """
    Renderer for the Voxel representation.

    Args:
        rendering_options (dict): Rendering options. Recognized keys:
            resolution (int), near (float), far (float), ssaa (int
            supersampling factor), bg_color ('random' or a 3-value color).
    """

    def __init__(self, rendering_options=None) -> None:
        # Rasterization pipeline settings forwarded to `render(...)`.
        self.pipe = edict({
            "kernel_size": 0.1,
            "convert_SHs_python": False,
            "compute_cov3D_python": False,
            "scale_modifier": 1.0,
            "debug": False
        })
        self.rendering_options = edict({
            "resolution": None,
            "near": None,
            "far": None,
            "ssaa": 1,
            "bg_color": 'random',
        })
        # Fix: the default used to be a shared mutable `{}`; use a None
        # sentinel instead (passing a dict behaves exactly as before).
        self.rendering_options.update(rendering_options or {})
        self.bg_color = None

    def render(
        self,
        gausssian: Gaussian,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        colors_overwrite: torch.Tensor = None
    ) -> edict:
        """
        Render the gausssian.

        Args:
            gaussian : gaussianmodule
            extrinsics (torch.Tensor): (4, 4) camera extrinsics
            intrinsics (torch.Tensor): (3, 3) camera intrinsics
            colors_overwrite (torch.Tensor): (N, 3) override color

        Returns:
            edict containing:
                color (torch.Tensor): (3, H, W) rendered color image
        """
        resolution = self.rendering_options["resolution"]
        near = self.rendering_options["near"]
        far = self.rendering_options["far"]
        ssaa = self.rendering_options["ssaa"]

        if self.rendering_options["bg_color"] == 'random':
            # Coin-flip between a black and a white background each call.
            self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
            if np.random.rand() < 0.5:
                self.bg_color += 1
        else:
            self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")

        view = extrinsics
        perspective = intrinsics_to_projection(intrinsics, near, far)
        # Camera position is the translation column of the inverse extrinsics.
        camera = torch.inverse(view)[:3, 3]
        focalx = intrinsics[0, 0]
        focaly = intrinsics[1, 1]
        # Field of view from normalized focal lengths (image plane spans [0, 1]).
        fovx = 2 * torch.atan(0.5 / focalx)
        fovy = 2 * torch.atan(0.5 / focaly)

        camera_dict = edict({
            "image_height": resolution * ssaa,
            "image_width": resolution * ssaa,
            "FoVx": fovx,
            "FoVy": fovy,
            "znear": near,
            "zfar": far,
            "world_view_transform": view.T.contiguous(),
            "projection_matrix": perspective.T.contiguous(),
            "full_proj_transform": (perspective @ view).T.contiguous(),
            "camera_center": camera
        })

        # Render
        render_ret = render(camera_dict, gausssian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)

        if ssaa > 1:
            # Downsample the supersampled render back to the target resolution.
            render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()

        ret = edict({
            'color': render_ret['render']
        })
        return ret
|
||||
133
trellis/representations/gaussian/general_utils.py
Normal file
133
trellis/representations/gaussian/general_utils.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#
|
||||
# Copyright (C) 2023, Inria
|
||||
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
||||
# All rights reserved.
|
||||
#
|
||||
# This software is free for non-commercial, research and evaluation use
|
||||
# under the terms of the LICENSE.md file.
|
||||
#
|
||||
# For inquiries contact george.drettakis@inria.fr
|
||||
#
|
||||
|
||||
import torch
|
||||
import sys
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
def inverse_sigmoid(x):
    """Logit: the inverse of torch.sigmoid for x in (0, 1)."""
    return torch.log(x / (1 - x))
|
||||
|
||||
def PILtoTorch(pil_image, resolution):
    """Resize a PIL image and convert it to a (C, H, W) float tensor in [0, 1]."""
    resized = pil_image.resize(resolution)
    tensor = torch.from_numpy(np.array(resized)) / 255.0
    if len(tensor.shape) != 3:
        # Grayscale image: append a channel axis so the permute below works.
        tensor = tensor.unsqueeze(dim=-1)
    return tensor.permute(2, 0, 1)
|
||||
|
||||
def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """
    Copied from Plenoxels

    Continuous learning rate decay function. Adapted from JaxNeRF
    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
    function of lr_delay_mult, such that the initial learning rate is
    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
    to the normal learning rate when steps>lr_delay_steps.
    :param conf: config subtree 'lr' or similar
    :param max_steps: int, the number of steps during optimization.
    :return HoF which takes step as input
    """

    def helper(step):
        # Negative step or a fully-zero schedule disables this parameter.
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            return 0.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay: ramp from lr_delay_mult up to 1.
            warm = np.clip(step / lr_delay_steps, 0, 1)
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(0.5 * np.pi * warm)
        else:
            delay_rate = 1.0
        frac = np.clip(step / max_steps, 0, 1)
        # Log-linear interpolation between lr_init and lr_final.
        log_lerp = np.exp(np.log(lr_init) * (1 - frac) + np.log(lr_final) * frac)
        return delay_rate * log_lerp

    return helper
|
||||
|
||||
def strip_lowerdiag(L):
    """
    Pack the 6 unique entries of symmetric 3x3 matrices into vectors.

    Args:
        L (torch.Tensor): [N, 3, 3] batch of symmetric matrices.
    Returns:
        (torch.Tensor): [N, 6] entries in order (xx, xy, xz, yy, yz, zz).
    """
    # Fix: allocate on the input's device instead of hard-coding "cuda",
    # so the helper also works for CPU tensors (unchanged for CUDA inputs).
    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=L.device)

    uncertainty[:, 0] = L[:, 0, 0]
    uncertainty[:, 1] = L[:, 0, 1]
    uncertainty[:, 2] = L[:, 0, 2]
    uncertainty[:, 3] = L[:, 1, 1]
    uncertainty[:, 4] = L[:, 1, 2]
    uncertainty[:, 5] = L[:, 2, 2]
    return uncertainty
|
||||
|
||||
def strip_symmetric(sym):
    """Alias of strip_lowerdiag: pack [N, 3, 3] symmetric matrices into [N, 6]."""
    return strip_lowerdiag(sym)
|
||||
|
||||
def build_rotation(r):
    """
    Convert unit-normalized quaternions to rotation matrices.

    Args:
        r (torch.Tensor): [N, 4] quaternions in (w, x, y, z) order; they are
            normalized internally.
    Returns:
        (torch.Tensor): [N, 3, 3] rotation matrices.
    """
    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])

    q = r / norm[:, None]

    # Fix: allocate on the input's device instead of hard-coding 'cuda',
    # so the helper also works for CPU tensors (unchanged for CUDA inputs).
    R = torch.zeros((q.size(0), 3, 3), device=r.device)

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    # Standard quaternion-to-rotation-matrix expansion.
    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
    R[:, 0, 1] = 2 * (x*y - r*z)
    R[:, 0, 2] = 2 * (x*z + r*y)
    R[:, 1, 0] = 2 * (x*y + r*z)
    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
    R[:, 1, 2] = 2 * (y*z - r*x)
    R[:, 2, 0] = 2 * (x*z - r*y)
    R[:, 2, 1] = 2 * (y*z + r*x)
    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
    return R
|
||||
|
||||
def build_scaling_rotation(s, r):
    """
    Build L = R @ S, the factor of a Gaussian covariance (cov = L @ L^T).

    Args:
        s (torch.Tensor): [N, 3] per-axis scales.
        r (torch.Tensor): [N, 4] quaternions (w, x, y, z).
    Returns:
        (torch.Tensor): [N, 3, 3] rotation-times-scale matrices.
    """
    # NOTE(review): allocated on "cuda" unconditionally — inputs are assumed
    # to live on the GPU (matches build_rotation's hard-coded device).
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
    R = build_rotation(r)

    # Diagonal scale matrix.
    L[:,0,0] = s[:,0]
    L[:,1,1] = s[:,1]
    L[:,2,2] = s[:,2]

    L = R @ L
    return L
|
||||
|
||||
def safe_state(silent):
    """
    Replace stdout with a timestamping (optionally silencing) writer and seed
    all RNGs for reproducibility.

    Args:
        silent (bool): if True, drop all stdout output entirely.
    """
    old_f = sys.stdout
    class F:
        def __init__(self, silent):
            self.silent = silent

        def write(self, x):
            if not self.silent:
                if x.endswith("\n"):
                    # Append a "[dd/mm HH:MM:SS]" stamp to each completed line.
                    old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
                else:
                    old_f.write(x)

        def flush(self):
            old_f.flush()

    sys.stdout = F(silent)

    # Fixed seeds for python, numpy and torch; pins computation to GPU 0.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(torch.device("cuda:0"))
|
||||
93
trellis/trainers/flow_matching/mixins/image_conditioned.py
Normal file
93
trellis/trainers/flow_matching/mixins/image_conditioned.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ....utils import dist_utils
|
||||
|
||||
|
||||
class ImageConditionedMixin:
    """
    Mixin for image-conditioned models.

    Args:
        image_cond_model: The image conditioning model.
    """
    def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
        super().__init__(*args, **kwargs)
        # Name of the DINOv2 variant fetched from torch.hub.
        self.image_cond_model_name = image_cond_model
        self.image_cond_model = None    # the model is init lazily

    @staticmethod
    def prepare_for_training(image_cond_model: str, **kwargs):
        """
        Prepare for training.

        Pre-downloads the conditioning model so later lazy init hits the cache.
        """
        if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
            super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)
        # download the model
        torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)

    def _init_image_cond_model(self):
        """
        Initialize the image conditioning model.
        """
        # Serialize the hub download so only the local master fetches weights.
        with dist_utils.local_master_first():
            dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
        dinov2_model.eval().cuda()
        # ImageNet normalization; inputs are expected already in [0, 1].
        transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Stored in a plain dict so the model is not registered as a submodule.
        self.image_cond_model = {
            'model': dinov2_model,
            'transform': transform,
        }

    @torch.no_grad()
    def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
        """
        Encode the image.

        Accepts a pre-batched (B, C, H, W) tensor or a list of PIL images,
        and returns DINOv2 patch tokens (layer-normalized, no affine).
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            # 518 = 37 * 14, DINOv2's patch size times the token grid side.
            image = [i.resize((518, 518), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).cuda()
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        if self.image_cond_model is None:
            # Lazy init on first use.
            self._init_image_cond_model()
        image = self.image_cond_model['transform'](image).cuda()
        features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
        # Parameter-free layer norm over the feature dimension.
        patchtokens = F.layer_norm(features, features.shape[-1:])
        return patchtokens

    def get_cond(self, cond, **kwargs):
        """
        Get the conditioning data.
        """
        cond = self.encode_image(cond)
        # Zero embedding serves as the negative (unconditional) branch for CFG.
        kwargs['neg_cond'] = torch.zeros_like(cond)
        cond = super().get_cond(cond, **kwargs)
        return cond

    def get_inference_cond(self, cond, **kwargs):
        """
        Get the conditioning data for inference.
        """
        cond = self.encode_image(cond)
        kwargs['neg_cond'] = torch.zeros_like(cond)
        cond = super().get_inference_cond(cond, **kwargs)
        return cond

    def vis_cond(self, cond, **kwargs):
        """
        Visualize the conditioning data.
        """
        return {'image': {'value': cond, 'type': 'image'}}
|
||||
202
trellis/utils/general_utils.py
Normal file
202
trellis/utils/general_utils.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import re
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import contextlib
|
||||
|
||||
|
||||
# Dictionary utils
|
||||
def _dict_merge(dicta, dictb, prefix=''):
|
||||
"""
|
||||
Merge two dictionaries.
|
||||
"""
|
||||
assert isinstance(dicta, dict), 'input must be a dictionary'
|
||||
assert isinstance(dictb, dict), 'input must be a dictionary'
|
||||
dict_ = {}
|
||||
all_keys = set(dicta.keys()).union(set(dictb.keys()))
|
||||
for key in all_keys:
|
||||
if key in dicta.keys() and key in dictb.keys():
|
||||
if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
|
||||
dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
|
||||
else:
|
||||
raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
|
||||
elif key in dicta.keys():
|
||||
dict_[key] = dicta[key]
|
||||
else:
|
||||
dict_[key] = dictb[key]
|
||||
return dict_
|
||||
|
||||
|
||||
def dict_merge(dicta, dictb):
    """
    Merge two dictionaries.

    Public wrapper around _dict_merge; raises ValueError when the same
    non-dict key exists in both inputs.
    """
    return _dict_merge(dicta, dictb, prefix='')
|
||||
|
||||
|
||||
def dict_foreach(dic, func, special_func=None):
    """
    Recursively apply a function to all non-dictionary leaf values in a dictionary.

    Args:
        dic (dict): dictionary transformed in place (and returned).
        func (callable): applied to every leaf value.
        special_func (dict): optional per-key overrides used instead of `func`.
    Returns:
        (dict): the same dictionary with its leaves replaced.
    """
    # Fix: the default used to be a shared mutable `{}`; None sentinel instead.
    special_func = special_func or {}
    assert isinstance(dic, dict), 'input must be a dictionary'
    for key in dic.keys():
        if isinstance(dic[key], dict):
            # Fix: forward special_func into nested dictionaries — it was
            # previously dropped, unlike dict_reduce which propagates it.
            dic[key] = dict_foreach(dic[key], func, special_func)
        else:
            if key in special_func.keys():
                dic[key] = special_func[key](dic[key])
            else:
                dic[key] = func(dic[key])
    return dic
|
||||
|
||||
|
||||
def dict_reduce(dicts, func, special_func={}):
    """
    Reduce a list of dictionaries. Leaf values must be scalars.
    """
    assert isinstance(dicts, list), 'input must be a list of dictionaries'
    assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
    assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
    result = {}
    for key in {k for d in dicts for k in d.keys()}:
        # Collect the values present for this key across all dictionaries.
        values = [d[key] for d in dicts if key in d.keys()]
        if isinstance(values[0], dict):
            # Nested dictionaries are reduced recursively.
            result[key] = dict_reduce(values, func, special_func)
        elif key in special_func.keys():
            result[key] = special_func[key](values)
        else:
            result[key] = func(values)
    return result
|
||||
|
||||
|
||||
def dict_any(dic, func):
    """
    Return True if `func` is truthy for any leaf value in a nested dictionary.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    return any(
        dict_any(value, func) if isinstance(value, dict) else bool(func(value))
        for value in dic.values()
    )
|
||||
|
||||
|
||||
def dict_all(dic, func):
    """
    Return True if `func` is truthy for every leaf value in a nested dictionary.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    return all(
        dict_all(value, func) if isinstance(value, dict) else bool(func(value))
        for value in dic.values()
    )
|
||||
|
||||
|
||||
def dict_flatten(dic, sep='.'):
    """
    Flatten a nested dictionary into a dictionary with no nested dictionaries.
    """
    assert isinstance(dic, dict), 'input must be a dictionary'
    flat = {}
    for key, value in dic.items():
        if isinstance(value, dict):
            # Prefix every flattened sub-key with this key and the separator.
            for sub_key, sub_value in dict_flatten(value, sep=sep).items():
                flat[str(key) + sep + str(sub_key)] = sub_value
        else:
            flat[key] = value
    return flat
|
||||
|
||||
|
||||
# Context utils
|
||||
@contextlib.contextmanager
def nested_contexts(*contexts):
    """Enter context managers produced by each callable in order; exit in reverse."""
    with contextlib.ExitStack() as stack:
        for factory in contexts:
            # Each argument is a zero-arg callable returning a context manager.
            stack.enter_context(factory())
        yield
|
||||
|
||||
|
||||
# Image utils
|
||||
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
    """Tile a list of equally-sized images into one grid image (row-major order)."""
    num_images = len(images)
    if nrow is None and ncol is None:
        # Choose a row count, optionally targeting a grid aspect ratio.
        if aspect_ratio is not None:
            nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
        else:
            nrow = int(np.sqrt(num_images))
        ncol = (num_images + nrow - 1) // nrow
    elif nrow is None:
        nrow = (num_images + ncol - 1) // ncol
    elif ncol is None:
        ncol = (num_images + nrow - 1) // nrow
    else:
        assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'

    cell_h, cell_w = images[0].shape[0], images[0].shape[1]
    if images[0].ndim == 2:
        grid = np.zeros((nrow * cell_h, ncol * cell_w), dtype=images[0].dtype)
    else:
        grid = np.zeros((nrow * cell_h, ncol * cell_w, images[0].shape[2]), dtype=images[0].dtype)
    for idx, img in enumerate(images):
        r, c = divmod(idx, ncol)
        grid[r * cell_h:(r + 1) * cell_h, c * cell_w:(c + 1) * cell_w] = img
    return grid
|
||||
|
||||
|
||||
def notes_on_image(img, notes=None):
    """
    Append a 32-pixel black band below an RGB image and draw `notes` in it.

    Args:
        img (np.ndarray): (H, W, 3) RGB image.
        notes (str): optional text to draw; nothing is drawn when None.
    Returns:
        (np.ndarray): (H + 32, W, 3) RGB image with the caption band.
    """
    # Extend the canvas downward with a black strip for the caption.
    img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
    # OpenCV drawing works in BGR; convert, draw, then convert back.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if notes is not None:
        img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
|
||||
|
||||
|
||||
def save_image_with_notes(img, path, notes=None):
    """
    Save an image with notes.

    Args:
        img (torch.Tensor | np.ndarray): (C, H, W) tensor or (H, W, C) array;
            float images are assumed to be in [0, 1].
        path (str): output file path (format inferred by OpenCV).
        notes (str): optional caption drawn below the image.
    """
    if isinstance(img, torch.Tensor):
        # CHW tensor -> HWC numpy array.
        img = img.cpu().numpy().transpose(1, 2, 0)
    if img.dtype == np.float32 or img.dtype == np.float64:
        # Scale [0, 1] floats to uint8.
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
    img = notes_on_image(img, notes)
    # cv2.imwrite expects BGR ordering.
    cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
||||
|
||||
|
||||
# debug utils
|
||||
|
||||
def atol(x, y):
    """
    Absolute tolerance.
    """
    return (x - y).abs()
|
||||
|
||||
|
||||
def rtol(x, y):
    """
    Relative tolerance.
    """
    # Normalize by the larger magnitude, floored at 1e-12 to avoid 0/0.
    denom = torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
    return torch.abs(x - y) / denom
|
||||
|
||||
|
||||
# print utils
|
||||
def indent(s, n=4):
    """
    Indent a string.

    Every line except the first is prefixed with n spaces.
    """
    head, *tail = s.split('\n')
    return '\n'.join([head] + [' ' * n + line for line in tail])
|
||||
|
||||
81
trellis/utils/grad_clip_utils.py
Normal file
81
trellis/utils/grad_clip_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.utils
|
||||
|
||||
|
||||
class AdaptiveGradClipper:
    """
    Adaptive gradient clipping for training.

    Keeps a ring buffer of recently observed gradient norms; once the buffer
    fills, the clip threshold becomes a percentile of those norms (optionally
    capped by max_norm).
    """
    def __init__(
        self,
        max_norm=None,
        clip_percentile=95.0,
        buffer_size=1000,
    ):
        # Hard upper bound on the adaptive threshold (None = uncapped).
        self.max_norm = max_norm
        # Percentile of buffered norms used as the adaptive threshold.
        self.clip_percentile = clip_percentile
        # Number of recent gradient norms to remember.
        self.buffer_size = buffer_size

        # Ring buffer of finite gradient norms.
        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        # Current effective threshold; starts at max_norm (None = no clipping yet).
        self._max_norm = max_norm
        # Next write position in the ring buffer.
        self._buffer_ptr = 0
        # Number of valid entries (saturates at buffer_size).
        self._buffer_length = 0

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'

    def state_dict(self):
        # Everything needed to resume the adaptive threshold across restarts.
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        # Scalar(s) worth logging each step.
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
                fall back to the slow implementation for other device types.
                Default: ``None``

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        # No threshold yet -> clip with inf (i.e. measure without clipping).
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)

        # Record only finite norms; once the buffer is full, refresh the
        # threshold to the configured percentile (capped by max_norm if set).
        if torch.isfinite(grad_norm):
            self._grad_norm[self._buffer_ptr] = grad_norm
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            if self._buffer_length == self.buffer_size:
                self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
                self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm

        return grad_norm
|
||||
Reference in New Issue
Block a user