1
This commit is contained in:
81
trellis/utils/grad_clip_utils.py
Normal file
81
trellis/utils/grad_clip_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import *
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.utils
|
||||
|
||||
|
||||
class AdaptiveGradClipper:
    """
    Adaptive gradient clipping for training.

    Keeps a ring buffer of the most recent finite gradient norms and, once
    the buffer is full, sets the clipping threshold to the
    ``clip_percentile``-th percentile of those norms, optionally capped by
    the fixed ``max_norm``. Before the buffer fills, the fixed ``max_norm``
    is used as-is (or no clipping at all when it is ``None``).
    """
    def __init__(
        self,
        max_norm=None,
        clip_percentile=95.0,
        buffer_size=1000,
    ):
        """
        Args:
            max_norm (float, optional): Hard upper bound for the adaptive
                threshold; also the threshold used before the buffer is full.
                ``None`` means no cap (and effectively no clipping until the
                buffer fills).
            clip_percentile (float): Percentile of the buffered gradient
                norms used as the adaptive clipping threshold.
            buffer_size (int): Number of recent gradient norms to track.
        """
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.buffer_size = buffer_size

        # Ring buffer of recent finite gradient norms.
        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        # Current effective clipping threshold; adapted once the buffer fills.
        self._max_norm = max_norm
        self._buffer_ptr = 0      # next write position in the ring buffer
        self._buffer_length = 0   # number of valid entries (<= buffer_size)

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'

    def state_dict(self):
        """Return the internal state for checkpointing."""
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        """Restore internal state produced by :meth:`state_dict`."""
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        """Return loggable scalars: the current effective clipping threshold."""
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
                fall back to the slow implementation for other device types.
                Default: ``None``

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        # No threshold yet (e.g. buffer not full and max_norm=None): clip at inf,
        # which is a no-op but still returns the total norm for bookkeeping.
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm=max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
            foreach=foreach,
        )

        # Only finite norms enter the statistics; nan/inf steps are skipped so a
        # single divergent batch cannot poison the adaptive threshold.
        if torch.isfinite(grad_norm):
            # .item() converts the (possibly CUDA) tensor to a Python float;
            # assigning a CUDA tensor into a numpy array raises TypeError.
            self._grad_norm[self._buffer_ptr] = grad_norm.item()
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            if self._buffer_length == self.buffer_size:
                # Adapt the threshold to the recent norm distribution; cast to
                # float so np.float64 scalars don't leak into state/log output.
                self._max_norm = float(np.percentile(self._grad_norm, self.clip_percentile))
                # Never exceed the user-provided hard cap.
                if self.max_norm is not None:
                    self._max_norm = min(self._max_norm, self.max_norm)

        return grad_norm
|
||||
Reference in New Issue
Block a user