This commit is contained in:
zcr
2026-03-17 11:38:02 +08:00
parent 046be2c797
commit 0571f65793
8 changed files with 1413 additions and 0 deletions


@@ -0,0 +1,15 @@
from typing import *
class GuidanceIntervalSamplerMixin:
"""
A mixin class for samplers that apply classifier-free guidance only within a given timestep interval.
"""
def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
if cfg_interval[0] <= t <= cfg_interval[1]:
pred = super()._inference_model(model, x_t, t, cond, **kwargs)
neg_pred = super()._inference_model(model, x_t, t, neg_cond, **kwargs)
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
else:
return super()._inference_model(model, x_t, t, cond, **kwargs)
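# Usage sketch (illustrative, not part of this commit): the mixin is meant to
# sit to the left of a concrete sampler in the MRO, so that
# super()._inference_model resolves to the base sampler's prediction.
# "_BaseSampler" below is a hypothetical stand-in for that base class.
#
#     class GuidedSampler(GuidanceIntervalSamplerMixin, _BaseSampler):
#         pass
#
# Within cfg_interval the two predictions are blended as
# (1 + s) * pred - s * neg_pred; outside it, the negative-condition branch is
# skipped entirely, saving one model evaluation per step.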


@@ -0,0 +1,231 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import math
from easydict import EasyDict as edict
import numpy as np
from ..representations.gaussian import Gaussian
from .sh_utils import eval_sh
import torch.nn.functional as F
def intrinsics_to_projection(
intrinsics: torch.Tensor,
near: float,
far: float,
) -> torch.Tensor:
"""
OpenCV intrinsics to OpenGL perspective matrix
Args:
intrinsics (torch.Tensor): [3, 3] OpenCV intrinsics matrix
near (float): near plane to clip
far (float): far plane to clip
Returns:
(torch.Tensor): [4, 4] OpenGL perspective matrix
"""
fx, fy = intrinsics[0, 0], intrinsics[1, 1]
cx, cy = intrinsics[0, 2], intrinsics[1, 2]
ret = torch.zeros((4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
ret[0, 0] = 2 * fx
ret[1, 1] = 2 * fy
ret[0, 2] = 2 * cx - 1
ret[1, 2] = - 2 * cy + 1
ret[2, 2] = far / (far - near)
ret[2, 3] = near * far / (near - far)
ret[3, 2] = 1.
return ret
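# Usage sketch (assumes intrinsics normalized by the image size, consistent
# with the 0.5-centered principal point used throughout this file):
#
#     K = torch.tensor([[1.0, 0.0, 0.5],
#                       [0.0, 1.0, 0.5],
#                       [0.0, 0.0, 1.0]])
#     proj = intrinsics_to_projection(K, near=0.1, far=100.0)  # [4, 4]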
def render(viewpoint_camera, pc: Gaussian, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None):
"""
Render the scene.
Background tensor (bg_color) must be on GPU!
"""
# lazy import
if 'GaussianRasterizer' not in globals():
from diff_gaussian_rasterization import GaussianRasterizer, GaussianRasterizationSettings
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
try:
screenspace_points.retain_grad()
except Exception:
pass
# Set up rasterization configuration
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
kernel_size = pipe.kernel_size
subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
raster_settings = GaussianRasterizationSettings(
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
kernel_size=kernel_size,
subpixel_offset=subpixel_offset,
bg=bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform,
sh_degree=pc.active_sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=pipe.debug
)
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
means3D = pc.get_xyz
means2D = screenspace_points
opacity = pc.get_opacity
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
# scaling / rotation by the rasterizer.
scales = None
rotations = None
cov3D_precomp = None
if pipe.compute_cov3D_python:
cov3D_precomp = pc.get_covariance(scaling_modifier)
else:
scales = pc.get_scaling
rotations = pc.get_rotation
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
shs = None
colors_precomp = None
if override_color is None:
if pipe.convert_SHs_python:
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
else:
shs = pc.get_features
else:
colors_precomp = override_color
# Rasterize visible Gaussians to image, obtain their radii (on screen).
rendered_image, radii = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
colors_precomp = colors_precomp,
opacities = opacity,
scales = scales,
rotations = rotations,
cov3D_precomp = cov3D_precomp
)
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return edict({"render": rendered_image,
"viewspace_points": screenspace_points,
"visibility_filter" : radii > 0,
"radii": radii})
class GaussianRenderer:
"""
Renderer for the Gaussian representation.
Args:
rendering_options (dict): Rendering options.
"""
def __init__(self, rendering_options={}) -> None:
self.pipe = edict({
"kernel_size": 0.1,
"convert_SHs_python": False,
"compute_cov3D_python": False,
"scale_modifier": 1.0,
"debug": False
})
self.rendering_options = edict({
"resolution": None,
"near": None,
"far": None,
"ssaa": 1,
"bg_color": 'random',
})
self.rendering_options.update(rendering_options)
self.bg_color = None
def render(
self,
gaussian: Gaussian,
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
colors_overwrite: torch.Tensor = None
) -> edict:
"""
Render the Gaussian model.
Args:
gaussian (Gaussian): Gaussian model to render
extrinsics (torch.Tensor): (4, 4) camera extrinsics
intrinsics (torch.Tensor): (3, 3) camera intrinsics
colors_overwrite (torch.Tensor): (N, 3) override color
Returns:
edict containing:
color (torch.Tensor): (3, H, W) rendered color image
"""
resolution = self.rendering_options["resolution"]
near = self.rendering_options["near"]
far = self.rendering_options["far"]
ssaa = self.rendering_options["ssaa"]
if self.rendering_options["bg_color"] == 'random':
self.bg_color = torch.zeros(3, dtype=torch.float32, device="cuda")
if np.random.rand() < 0.5:
self.bg_color += 1
else:
self.bg_color = torch.tensor(self.rendering_options["bg_color"], dtype=torch.float32, device="cuda")
view = extrinsics
perspective = intrinsics_to_projection(intrinsics, near, far)
camera = torch.inverse(view)[:3, 3]
focalx = intrinsics[0, 0]
focaly = intrinsics[1, 1]
fovx = 2 * torch.atan(0.5 / focalx)
fovy = 2 * torch.atan(0.5 / focaly)
camera_dict = edict({
"image_height": resolution * ssaa,
"image_width": resolution * ssaa,
"FoVx": fovx,
"FoVy": fovy,
"znear": near,
"zfar": far,
"world_view_transform": view.T.contiguous(),
"projection_matrix": perspective.T.contiguous(),
"full_proj_transform": (perspective @ view).T.contiguous(),
"camera_center": camera
})
# Render
render_ret = render(camera_dict, gaussian, self.pipe, self.bg_color, override_color=colors_overwrite, scaling_modifier=self.pipe.scale_modifier)
if ssaa > 1:
render_ret.render = F.interpolate(render_ret.render[None], size=(resolution, resolution), mode='bilinear', align_corners=False, antialias=True).squeeze()
ret = edict({
'color': render_ret['render']
})
return ret
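# Usage sketch (illustrative; `gaussian`, `extrinsics`, and `intrinsics` are
# assumed to be prepared elsewhere, with intrinsics in normalized form):
#
#     renderer = GaussianRenderer({'resolution': 512, 'near': 0.1,
#                                  'far': 100.0, 'ssaa': 2,
#                                  'bg_color': (1.0, 1.0, 1.0)})
#     out = renderer.render(gaussian, extrinsics, intrinsics)
#     image = out.color  # (3, 512, 512), rendered at 1024 then downsampled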


@@ -0,0 +1,133 @@
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import sys
from datetime import datetime
import numpy as np
import random
def inverse_sigmoid(x):
return torch.log(x/(1-x))
def PILtoTorch(pil_image, resolution):
resized_image_PIL = pil_image.resize(resolution)
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
if len(resized_image.shape) == 3:
return resized_image.permute(2, 0, 1)
else:
return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
def get_expon_lr_func(
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
"""
Copied from Plenoxels
Continuous learning rate decay function. Adapted from JaxNeRF
The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
is log-linearly interpolated elsewhere (equivalent to exponential decay).
If lr_delay_steps>0 then the learning rate will be scaled by some smooth
function of lr_delay_mult, such that the initial learning rate is
lr_init*lr_delay_mult at the beginning of optimization but will be eased back
to the normal learning rate when steps>lr_delay_steps.
:param lr_init: float, initial learning rate.
:param lr_final: float, final learning rate.
:param lr_delay_steps: int, number of delay/warmup steps.
:param lr_delay_mult: float, multiplier on the initial learning rate during the delay.
:param max_steps: int, the number of steps during optimization.
:return HoF which takes step as input
"""
def helper(step):
if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
# Disable this parameter
return 0.0
if lr_delay_steps > 0:
# A kind of reverse cosine decay.
delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
)
else:
delay_rate = 1.0
t = np.clip(step / max_steps, 0, 1)
log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
return delay_rate * log_lerp
return helper
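# Example: a schedule decaying log-linearly from 1e-3 to 1e-5 over 30k steps,
# with a 1k-step warmup eased in by lr_delay_mult; the rate rises toward
# lr_init as the delay ends, then decays to lr_final:
#
#     lr_fn = get_expon_lr_func(lr_init=1e-3, lr_final=1e-5,
#                               lr_delay_steps=1000, lr_delay_mult=0.01,
#                               max_steps=30000)
#     lr_fn(0)      # 1e-5 (= lr_init * lr_delay_mult)
#     lr_fn(30000)  # 1e-5 (= lr_final)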
def strip_lowerdiag(L):
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
uncertainty[:, 0] = L[:, 0, 0]
uncertainty[:, 1] = L[:, 0, 1]
uncertainty[:, 2] = L[:, 0, 2]
uncertainty[:, 3] = L[:, 1, 1]
uncertainty[:, 4] = L[:, 1, 2]
uncertainty[:, 5] = L[:, 2, 2]
return uncertainty
def strip_symmetric(sym):
return strip_lowerdiag(sym)
def build_rotation(r):
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
q = r / norm[:, None]
R = torch.zeros((q.size(0), 3, 3), device='cuda')
r = q[:, 0]
x = q[:, 1]
y = q[:, 2]
z = q[:, 3]
R[:, 0, 0] = 1 - 2 * (y*y + z*z)
R[:, 0, 1] = 2 * (x*y - r*z)
R[:, 0, 2] = 2 * (x*z + r*y)
R[:, 1, 0] = 2 * (x*y + r*z)
R[:, 1, 1] = 1 - 2 * (x*x + z*z)
R[:, 1, 2] = 2 * (y*z - r*x)
R[:, 2, 0] = 2 * (x*z - r*y)
R[:, 2, 1] = 2 * (y*z + r*x)
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
return R
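# Example: quaternions use (r, x, y, z) ordering with the scalar part first,
# and are normalized internally, so inputs need not be unit length:
#
#     q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])  # identity rotation
#     R = build_rotation(q.cuda())               # (1, 3, 3) identity matrix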
def build_scaling_rotation(s, r):
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
R = build_rotation(r)
L[:,0,0] = s[:,0]
L[:,1,1] = s[:,1]
L[:,2,2] = s[:,2]
L = R @ L
return L
def safe_state(silent):
old_f = sys.stdout
class F:
def __init__(self, silent):
self.silent = silent
def write(self, x):
if not self.silent:
if x.endswith("\n"):
old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
else:
old_f.write(x)
def flush(self):
old_f.flush()
sys.stdout = F(silent)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))


@@ -0,0 +1,93 @@
from typing import *
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
from ....utils import dist_utils
class ImageConditionedMixin:
"""
Mixin for image-conditioned models.
Args:
image_cond_model: The image conditioning model.
"""
def __init__(self, *args, image_cond_model: str = 'dinov2_vitl14_reg', **kwargs):
super().__init__(*args, **kwargs)
self.image_cond_model_name = image_cond_model
self.image_cond_model = None  # the model is initialized lazily
@staticmethod
def prepare_for_training(image_cond_model: str, **kwargs):
"""
Prepare for training.
"""
if hasattr(super(ImageConditionedMixin, ImageConditionedMixin), 'prepare_for_training'):
super(ImageConditionedMixin, ImageConditionedMixin).prepare_for_training(**kwargs)
# download the model
torch.hub.load('facebookresearch/dinov2', image_cond_model, pretrained=True)
def _init_image_cond_model(self):
"""
Initialize the image conditioning model.
"""
with dist_utils.local_master_first():
dinov2_model = torch.hub.load('facebookresearch/dinov2', self.image_cond_model_name, pretrained=True)
dinov2_model.eval().cuda()
transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
self.image_cond_model = {
'model': dinov2_model,
'transform': transform,
}
@torch.no_grad()
def encode_image(self, image: Union[torch.Tensor, List[Image.Image]]) -> torch.Tensor:
"""
Encode the image.
"""
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
image = [i.resize((518, 518), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).cuda()
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
if self.image_cond_model is None:
self._init_image_cond_model()
image = self.image_cond_model['transform'](image).cuda()
features = self.image_cond_model['model'](image, is_training=True)['x_prenorm']
patchtokens = F.layer_norm(features, features.shape[-1:])
return patchtokens
def get_cond(self, cond, **kwargs):
"""
Get the conditioning data.
"""
cond = self.encode_image(cond)
kwargs['neg_cond'] = torch.zeros_like(cond)
cond = super().get_cond(cond, **kwargs)
return cond
def get_inference_cond(self, cond, **kwargs):
"""
Get the conditioning data for inference.
"""
cond = self.encode_image(cond)
kwargs['neg_cond'] = torch.zeros_like(cond)
cond = super().get_inference_cond(cond, **kwargs)
return cond
def vis_cond(self, cond, **kwargs):
"""
Visualize the conditioning data.
"""
return {'image': {'value': cond, 'type': 'image'}}
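# Usage sketch (illustrative): the mixin composes to the left of a model class
# that provides get_cond / get_inference_cond; "_BaseModel" is a hypothetical
# stand-in for that class.
#
#     class ImageConditionedModel(ImageConditionedMixin, _BaseModel):
#         pass
#
#     model = ImageConditionedModel(image_cond_model='dinov2_vitl14_reg')
#     tokens = model.encode_image([Image.open('photo.png')])  # (1, N, C)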


@@ -0,0 +1,202 @@
import re
import numpy as np
import cv2
import torch
import contextlib
# Dictionary utils
def _dict_merge(dicta, dictb, prefix=''):
"""
Merge two dictionaries.
"""
assert isinstance(dicta, dict), 'input must be a dictionary'
assert isinstance(dictb, dict), 'input must be a dictionary'
dict_ = {}
all_keys = set(dicta.keys()).union(set(dictb.keys()))
for key in all_keys:
if key in dicta.keys() and key in dictb.keys():
if isinstance(dicta[key], dict) and isinstance(dictb[key], dict):
dict_[key] = _dict_merge(dicta[key], dictb[key], prefix=f'{prefix}.{key}')
else:
raise ValueError(f'Duplicate key {prefix}.{key} found in both dictionaries. Types: {type(dicta[key])}, {type(dictb[key])}')
elif key in dicta.keys():
dict_[key] = dicta[key]
else:
dict_[key] = dictb[key]
return dict_
def dict_merge(dicta, dictb):
"""
Merge two dictionaries.
"""
return _dict_merge(dicta, dictb, prefix='')
def dict_foreach(dic, func, special_func={}):
"""
Recursively apply a function to all non-dictionary leaf values in a dictionary.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
dic[key] = dict_foreach(dic[key], func, special_func)
else:
if key in special_func.keys():
dic[key] = special_func[key](dic[key])
else:
dic[key] = func(dic[key])
return dic
def dict_reduce(dicts, func, special_func={}):
"""
Reduce a list of dictionaries. Leaf values must be scalars.
"""
assert isinstance(dicts, list), 'input must be a list of dictionaries'
assert all([isinstance(d, dict) for d in dicts]), 'input must be a list of dictionaries'
assert len(dicts) > 0, 'input must be a non-empty list of dictionaries'
all_keys = set([key for dict_ in dicts for key in dict_.keys()])
reduced_dict = {}
for key in all_keys:
vlist = [dict_[key] for dict_ in dicts if key in dict_.keys()]
if isinstance(vlist[0], dict):
reduced_dict[key] = dict_reduce(vlist, func, special_func)
else:
if key in special_func.keys():
reduced_dict[key] = special_func[key](vlist)
else:
reduced_dict[key] = func(vlist)
return reduced_dict
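# Example: averaging per-step metric dicts while summing a counter key:
#
#     logs = [{'loss': 0.9, 'n': 2}, {'loss': 0.7, 'n': 3}]
#     dict_reduce(logs, np.mean, special_func={'n': np.sum})
#     # -> {'loss': 0.8, 'n': 5}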
def dict_any(dic, func):
"""
Return True if func evaluates to True for any non-dictionary leaf value.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
if dict_any(dic[key], func):
return True
else:
if func(dic[key]):
return True
return False
def dict_all(dic, func):
"""
Return True if func evaluates to True for every non-dictionary leaf value.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
for key in dic.keys():
if isinstance(dic[key], dict):
if not dict_all(dic[key], func):
return False
else:
if not func(dic[key]):
return False
return True
def dict_flatten(dic, sep='.'):
"""
Flatten a nested dictionary into a dictionary with no nested dictionaries.
"""
assert isinstance(dic, dict), 'input must be a dictionary'
flat_dict = {}
for key in dic.keys():
if isinstance(dic[key], dict):
sub_dict = dict_flatten(dic[key], sep=sep)
for sub_key in sub_dict.keys():
flat_dict[str(key) + sep + str(sub_key)] = sub_dict[sub_key]
else:
flat_dict[key] = dic[key]
return flat_dict
# Context utils
@contextlib.contextmanager
def nested_contexts(*contexts):
with contextlib.ExitStack() as stack:
for ctx in contexts:
stack.enter_context(ctx())
yield
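# Usage sketch: each argument is a zero-argument callable returning a context
# manager; they are entered in order and exited in reverse:
#
#     with nested_contexts(torch.no_grad, lambda: torch.autocast('cuda')):
#         ...  # both contexts active here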
# Image utils
def make_grid(images, nrow=None, ncol=None, aspect_ratio=None):
num_images = len(images)
if nrow is None and ncol is None:
if aspect_ratio is not None:
nrow = int(np.round(np.sqrt(num_images / aspect_ratio)))
else:
nrow = int(np.sqrt(num_images))
ncol = (num_images + nrow - 1) // nrow
elif nrow is None and ncol is not None:
nrow = (num_images + ncol - 1) // ncol
elif nrow is not None and ncol is None:
ncol = (num_images + nrow - 1) // nrow
else:
assert nrow * ncol >= num_images, 'nrow * ncol must be greater than or equal to the number of images'
if images[0].ndim == 2:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1]), dtype=images[0].dtype)
else:
grid = np.zeros((nrow * images[0].shape[0], ncol * images[0].shape[1], images[0].shape[2]), dtype=images[0].dtype)
for i, img in enumerate(images):
row = i // ncol
col = i % ncol
grid[row * img.shape[0]:(row + 1) * img.shape[0], col * img.shape[1]:(col + 1) * img.shape[1]] = img
return grid
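# Example: tiling eight same-sized (H, W, 3) frames into a 2x4 grid:
#
#     grid = make_grid(frames, nrow=2)               # ncol inferred as 4
#     grid = make_grid(frames, aspect_ratio=16 / 9)  # rows chosen to fit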
def notes_on_image(img, notes=None):
img = np.pad(img, ((0, 32), (0, 0), (0, 0)), 'constant', constant_values=0)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if notes is not None:
img = cv2.putText(img, notes, (0, img.shape[0] - 4), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
def save_image_with_notes(img, path, notes=None):
"""
Save an image with notes.
"""
if isinstance(img, torch.Tensor):
img = img.cpu().numpy().transpose(1, 2, 0)
if img.dtype == np.float32 or img.dtype == np.float64:
img = np.clip(img * 255, 0, 255).astype(np.uint8)
img = notes_on_image(img, notes)
cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
# debug utils
def atol(x, y):
"""
Absolute tolerance.
"""
return torch.abs(x - y)
def rtol(x, y):
"""
Relative tolerance.
"""
return torch.abs(x - y) / torch.clamp_min(torch.maximum(torch.abs(x), torch.abs(y)), 1e-12)
# print utils
def indent(s, n=4):
"""
Indent a string.
"""
lines = s.split('\n')
for i in range(1, len(lines)):
lines[i] = ' ' * n + lines[i]
return '\n'.join(lines)


@@ -0,0 +1,81 @@
from typing import *
import torch
import numpy as np
class AdaptiveGradClipper:
"""
Adaptive gradient clipping for training.
"""
def __init__(
self,
max_norm=None,
clip_percentile=95.0,
buffer_size=1000,
):
self.max_norm = max_norm
self.clip_percentile = clip_percentile
self.buffer_size = buffer_size
self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
self._max_norm = max_norm
self._buffer_ptr = 0
self._buffer_length = 0
def __repr__(self):
return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'
def state_dict(self):
return {
'grad_norm': self._grad_norm,
'max_norm': self._max_norm,
'buffer_ptr': self._buffer_ptr,
'buffer_length': self._buffer_length,
}
def load_state_dict(self, state_dict):
self._grad_norm = state_dict['grad_norm']
self._max_norm = state_dict['max_norm']
self._buffer_ptr = state_dict['buffer_ptr']
self._buffer_length = state_dict['buffer_length']
def log(self):
return {
'max_norm': self._max_norm,
}
def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
"""Clip the gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the parameter gradients (viewed as a single vector).
"""
max_norm = self._max_norm if self._max_norm is not None else float('inf')
grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach)
if torch.isfinite(grad_norm):
self._grad_norm[self._buffer_ptr] = grad_norm
self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
if self._buffer_length == self.buffer_size:
self._max_norm = np.percentile(self._grad_norm, self.clip_percentile)
self._max_norm = min(self._max_norm, self.max_norm) if self.max_norm is not None else self._max_norm
return grad_norm
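# Usage sketch inside a training loop (model/optimizer names are hypothetical):
#
#     clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0)
#     for batch in loader:
#         loss = model(batch).mean()
#         loss.backward()
#         clipper(model.parameters())  # clip and record the gradient norm
#         optimizer.step()
#         optimizer.zero_grad()
#
# Until the buffer fills, gradients are clipped at max_norm (or not at all
# when max_norm is None); once full, the threshold tracks the 95th percentile
# of recent finite gradient norms, capped at max_norm.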