1
This commit is contained in:
77
trellis/trainers/utils.py
Normal file
77
trellis/trainers/utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# FP16 utils
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
|
||||
def make_master_params(model_params):
|
||||
"""
|
||||
Copy model parameters into a inflated tensor of full-precision parameters.
|
||||
"""
|
||||
master_params = _flatten_dense_tensors(
|
||||
[param.detach().float() for param in model_params]
|
||||
)
|
||||
master_params = nn.Parameter(master_params)
|
||||
master_params.requires_grad = True
|
||||
return [master_params]
|
||||
|
||||
|
||||
def unflatten_master_params(model_params, master_params):
|
||||
"""
|
||||
Unflatten the master parameters to look like model_params.
|
||||
"""
|
||||
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
|
||||
|
||||
|
||||
def model_params_to_master_params(model_params, master_params):
|
||||
"""
|
||||
Copy the model parameter data into the master parameters.
|
||||
"""
|
||||
master_params[0].detach().copy_(
|
||||
_flatten_dense_tensors([param.detach().float() for param in model_params])
|
||||
)
|
||||
|
||||
|
||||
def master_params_to_model_params(model_params, master_params):
|
||||
"""
|
||||
Copy the master parameter data back into the model parameters.
|
||||
"""
|
||||
for param, master_param in zip(
|
||||
model_params, _unflatten_dense_tensors(master_params[0].detach(), model_params)
|
||||
):
|
||||
param.detach().copy_(master_param)
|
||||
|
||||
|
||||
def model_grads_to_master_grads(model_params, master_params):
|
||||
"""
|
||||
Copy the gradients from the model parameters into the master parameters
|
||||
from make_master_params().
|
||||
"""
|
||||
master_params[0].grad = _flatten_dense_tensors(
|
||||
[param.grad.data.detach().float() for param in model_params]
|
||||
)
|
||||
|
||||
|
||||
def zero_grad(model_params):
|
||||
for param in model_params:
|
||||
if param.grad is not None:
|
||||
if param.grad.grad_fn is not None:
|
||||
param.grad.detach_()
|
||||
else:
|
||||
param.grad.requires_grad_(False)
|
||||
param.grad.zero_()
|
||||
|
||||
|
||||
# LR Schedulers
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
class LinearWarmupLRScheduler(LambdaLR):
|
||||
def __init__(self, optimizer, warmup_steps, last_epoch=-1):
|
||||
self.warmup_steps = warmup_steps
|
||||
super(LinearWarmupLRScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
|
||||
|
||||
def lr_lambda(self, current_step):
|
||||
if current_step < self.warmup_steps:
|
||||
return float(current_step + 1) / self.warmup_steps
|
||||
return 1.0
|
||||
|
||||
Reference in New Issue
Block a user