Add codeformer and update license
This commit is contained in:
26
basicsr/losses/__init__.py
Normal file
26
basicsr/losses/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import LOSS_REGISTRY
|
||||
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
|
||||
gradient_penalty_loss, r1_penalty)
|
||||
|
||||
__all__ = [
|
||||
'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
|
||||
'r1_penalty', 'g_path_regularize'
|
||||
]
|
||||
|
||||
|
||||
def build_loss(opt):
|
||||
"""Build loss from options.
|
||||
|
||||
Args:
|
||||
opt (dict): Configuration. It must constain:
|
||||
type (str): Model type.
|
||||
"""
|
||||
opt = deepcopy(opt)
|
||||
loss_type = opt.pop('type')
|
||||
loss = LOSS_REGISTRY.get(loss_type)(**opt)
|
||||
logger = get_root_logger()
|
||||
logger.info(f'Loss [{loss.__class__.__name__}] is created.')
|
||||
return loss
|
||||
95
basicsr/losses/loss_util.py
Normal file
95
basicsr/losses/loss_util.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import functools
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def reduce_loss(loss, reduction):
|
||||
"""Reduce loss as specified.
|
||||
|
||||
Args:
|
||||
loss (Tensor): Elementwise loss tensor.
|
||||
reduction (str): Options are 'none', 'mean' and 'sum'.
|
||||
|
||||
Returns:
|
||||
Tensor: Reduced loss tensor.
|
||||
"""
|
||||
reduction_enum = F._Reduction.get_enum(reduction)
|
||||
# none: 0, elementwise_mean:1, sum: 2
|
||||
if reduction_enum == 0:
|
||||
return loss
|
||||
elif reduction_enum == 1:
|
||||
return loss.mean()
|
||||
else:
|
||||
return loss.sum()
|
||||
|
||||
|
||||
def weight_reduce_loss(loss, weight=None, reduction='mean'):
|
||||
"""Apply element-wise weight and reduce loss.
|
||||
|
||||
Args:
|
||||
loss (Tensor): Element-wise loss.
|
||||
weight (Tensor): Element-wise weights. Default: None.
|
||||
reduction (str): Same as built-in losses of PyTorch. Options are
|
||||
'none', 'mean' and 'sum'. Default: 'mean'.
|
||||
|
||||
Returns:
|
||||
Tensor: Loss values.
|
||||
"""
|
||||
# if weight is specified, apply element-wise weight
|
||||
if weight is not None:
|
||||
assert weight.dim() == loss.dim()
|
||||
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
|
||||
loss = loss * weight
|
||||
|
||||
# if weight is not specified or reduction is sum, just reduce the loss
|
||||
if weight is None or reduction == 'sum':
|
||||
loss = reduce_loss(loss, reduction)
|
||||
# if reduction is mean, then compute mean over weight region
|
||||
elif reduction == 'mean':
|
||||
if weight.size(1) > 1:
|
||||
weight = weight.sum()
|
||||
else:
|
||||
weight = weight.sum() * loss.size(1)
|
||||
loss = loss.sum() / weight
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def weighted_loss(loss_func):
|
||||
"""Create a weighted version of a given loss function.
|
||||
|
||||
To use this decorator, the loss function must have the signature like
|
||||
`loss_func(pred, target, **kwargs)`. The function only needs to compute
|
||||
element-wise loss without any reduction. This decorator will add weight
|
||||
and reduction arguments to the function. The decorated function will have
|
||||
the signature like `loss_func(pred, target, weight=None, reduction='mean',
|
||||
**kwargs)`.
|
||||
|
||||
:Example:
|
||||
|
||||
>>> import torch
|
||||
>>> @weighted_loss
|
||||
>>> def l1_loss(pred, target):
|
||||
>>> return (pred - target).abs()
|
||||
|
||||
>>> pred = torch.Tensor([0, 2, 3])
|
||||
>>> target = torch.Tensor([1, 1, 1])
|
||||
>>> weight = torch.Tensor([1, 0, 1])
|
||||
|
||||
>>> l1_loss(pred, target)
|
||||
tensor(1.3333)
|
||||
>>> l1_loss(pred, target, weight)
|
||||
tensor(1.5000)
|
||||
>>> l1_loss(pred, target, reduction='none')
|
||||
tensor([1., 1., 2.])
|
||||
>>> l1_loss(pred, target, weight, reduction='sum')
|
||||
tensor(3.)
|
||||
"""
|
||||
|
||||
@functools.wraps(loss_func)
|
||||
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
|
||||
# get element-wise loss
|
||||
loss = loss_func(pred, target, **kwargs)
|
||||
loss = weight_reduce_loss(loss, weight, reduction)
|
||||
return loss
|
||||
|
||||
return wrapper
|
||||
455
basicsr/losses/losses.py
Normal file
455
basicsr/losses/losses.py
Normal file
@@ -0,0 +1,455 @@
|
||||
import math
|
||||
import lpips
|
||||
import torch
|
||||
from torch import autograd as autograd
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from basicsr.archs.vgg_arch import VGGFeatureExtractor
|
||||
from basicsr.utils.registry import LOSS_REGISTRY
|
||||
from .loss_util import weighted_loss
|
||||
|
||||
_reduction_modes = ['none', 'mean', 'sum']
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def l1_loss(pred, target):
|
||||
return F.l1_loss(pred, target, reduction='none')
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def mse_loss(pred, target):
|
||||
return F.mse_loss(pred, target, reduction='none')
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def charbonnier_loss(pred, target, eps=1e-12):
|
||||
return torch.sqrt((pred - target)**2 + eps)
|
||||
|
||||
|
||||
@LOSS_REGISTRY.register()
|
||||
class L1Loss(nn.Module):
|
||||
"""L1 (mean absolute error, MAE) loss.
|
||||
|
||||
Args:
|
||||
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
||||
reduction (str): Specifies the reduction to apply to the output.
|
||||
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=1.0, reduction='mean'):
|
||||
super(L1Loss, self).__init__()
|
||||
if reduction not in ['none', 'mean', 'sum']:
|
||||
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
|
||||
|
||||
self.loss_weight = loss_weight
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred, target, weight=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
||||
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
||||
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
||||
weights. Default: None.
|
||||
"""
|
||||
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
|
||||
|
||||
|
||||
@LOSS_REGISTRY.register()
|
||||
class MSELoss(nn.Module):
|
||||
"""MSE (L2) loss.
|
||||
|
||||
Args:
|
||||
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
|
||||
reduction (str): Specifies the reduction to apply to the output.
|
||||
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=1.0, reduction='mean'):
|
||||
super(MSELoss, self).__init__()
|
||||
if reduction not in ['none', 'mean', 'sum']:
|
||||
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
|
||||
|
||||
self.loss_weight = loss_weight
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred, target, weight=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
||||
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
||||
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
||||
weights. Default: None.
|
||||
"""
|
||||
return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
|
||||
|
||||
|
||||
@LOSS_REGISTRY.register()
|
||||
class CharbonnierLoss(nn.Module):
|
||||
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
|
||||
variant of L1Loss).
|
||||
|
||||
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
|
||||
Super-Resolution".
|
||||
|
||||
Args:
|
||||
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
|
||||
reduction (str): Specifies the reduction to apply to the output.
|
||||
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
|
||||
eps (float): A value used to control the curvature near zero.
|
||||
Default: 1e-12.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
|
||||
super(CharbonnierLoss, self).__init__()
|
||||
if reduction not in ['none', 'mean', 'sum']:
|
||||
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
|
||||
|
||||
self.loss_weight = loss_weight
|
||||
self.reduction = reduction
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, pred, target, weight=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
|
||||
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
|
||||
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
|
||||
weights. Default: None.
|
||||
"""
|
||||
return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
|
||||
|
||||
|
||||
@LOSS_REGISTRY.register()
|
||||
class WeightedTVLoss(L1Loss):
|
||||
"""Weighted TV loss.
|
||||
|
||||
Args:
|
||||
loss_weight (float): Loss weight. Default: 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_weight=1.0):
|
||||
super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
|
||||
|
||||
def forward(self, pred, weight=None):
|
||||
y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :])
|
||||
x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1])
|
||||
|
||||
loss = x_diff + y_diff
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
@LOSS_REGISTRY.register()
|
||||
class PerceptualLoss(nn.Module):
|
||||
"""Perceptual loss with commonly used style loss.
|
||||
|
||||
Args:
|
||||
layer_weights (dict): The weight for each layer of vgg feature.
|
||||
Here is an example: {'conv5_4': 1.}, which means the conv5_4
|
||||
feature layer (before relu5_4) will be extracted with weight
|
||||
1.0 in calculting losses.
|
||||
vgg_type (str): The type of vgg network used as feature extractor.
|
||||
Default: 'vgg19'.
|
||||
use_input_norm (bool): If True, normalize the input image in vgg.
|
||||
Default: True.
|
||||
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
||||
Default: False.
|
||||
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
|
||||
loss will be calculated and the loss will multiplied by the
|
||||
weight. Default: 1.0.
|
||||
style_weight (float): If `style_weight > 0`, the style loss will be
|
||||
calculated and the loss will multiplied by the weight.
|
||||
Default: 0.
|
||||
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
layer_weights,
|
||||
vgg_type='vgg19',
|
||||
use_input_norm=True,
|
||||
range_norm=False,
|
||||
perceptual_weight=1.0,
|
||||
style_weight=0.,
|
||||
criterion='l1'):
|
||||
super(PerceptualLoss, self).__init__()
|
||||
self.perceptual_weight = perceptual_weight
|
||||
self.style_weight = style_weight
|
||||
self.layer_weights = layer_weights
|
||||
self.vgg = VGGFeatureExtractor(
|
||||
layer_name_list=list(layer_weights.keys()),
|
||||
vgg_type=vgg_type,
|
||||
use_input_norm=use_input_norm,
|
||||
range_norm=range_norm)
|
||||
|
||||
self.criterion_type = criterion
|
||||
if self.criterion_type == 'l1':
|
||||
self.criterion = torch.nn.L1Loss()
|
||||
elif self.criterion_type == 'l2':
|
||||
self.criterion = torch.nn.L2loss()
|
||||
elif self.criterion_type == 'mse':
|
||||
self.criterion = torch.nn.MSELoss(reduction='mean')
|
||||
elif self.criterion_type == 'fro':
|
||||
self.criterion = None
|
||||
else:
|
||||
raise NotImplementedError(f'{criterion} criterion has not been supported.')
|
||||
|
||||
def forward(self, x, gt):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
|
||||
|
||||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
# extract vgg features
|
||||
x_features = self.vgg(x)
|
||||
gt_features = self.vgg(gt.detach())
|
||||
|
||||
# calculate perceptual loss
|
||||
if self.perceptual_weight > 0:
|
||||
percep_loss = 0
|
||||
for k in x_features.keys():
|
||||
if self.criterion_type == 'fro':
|
||||
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
|
||||
else:
|
||||
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
percep_loss *= self.perceptual_weight
|
||||
else:
|
||||
percep_loss = None
|
||||
|
||||
# calculate style loss
|
||||
if self.style_weight > 0:
|
||||
style_loss = 0
|
||||
for k in x_features.keys():
|
||||
if self.criterion_type == 'fro':
|
||||
style_loss += torch.norm(
|
||||
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
|
||||
else:
|
||||
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
|
||||
gt_features[k])) * self.layer_weights[k]
|
||||
style_loss *= self.style_weight
|
||||
else:
|
||||
style_loss = None
|
||||
|
||||
return percep_loss, style_loss
|
||||
|
||||
def _gram_mat(self, x):
|
||||
"""Calculate Gram matrix.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Gram matrix.
|
||||
"""
|
||||
n, c, h, w = x.size()
|
||||
features = x.view(n, c, w * h)
|
||||
features_t = features.transpose(1, 2)
|
||||
gram = features.bmm(features_t) / (c * h * w)
|
||||
return gram
|
||||
|
||||
|
||||
@LOSS_REGISTRY.register()
|
||||
class LPIPSLoss(nn.Module):
|
||||
def __init__(self,
|
||||
loss_weight=1.0,
|
||||
use_input_norm=True,
|
||||
range_norm=False,):
|
||||
super(LPIPSLoss, self).__init__()
|
||||
self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
|
||||
self.loss_weight = loss_weight
|
||||
self.use_input_norm = use_input_norm
|
||||
self.range_norm = range_norm
|
||||
|
||||
if self.use_input_norm:
|
||||
# the mean is for image with range [0, 1]
|
||||
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||
# the std is for image with range [0, 1]
|
||||
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||
|
||||
def forward(self, pred, target):
|
||||
if self.range_norm:
|
||||
pred = (pred + 1) / 2
|
||||
target = (target + 1) / 2
|
||||
if self.use_input_norm:
|
||||
pred = (pred - self.mean) / self.std
|
||||
target = (target - self.mean) / self.std
|
||||
lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
|
||||
return self.loss_weight * lpips_loss.mean()
|
||||
|
||||
|
||||
@LOSS_REGISTRY.register()
|
||||
class GANLoss(nn.Module):
|
||||
"""Define GAN loss.
|
||||
|
||||
Args:
|
||||
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
|
||||
real_label_val (float): The value for real label. Default: 1.0.
|
||||
fake_label_val (float): The value for fake label. Default: 0.0.
|
||||
loss_weight (float): Loss weight. Default: 1.0.
|
||||
Note that loss_weight is only for generators; and it is always 1.0
|
||||
for discriminators.
|
||||
"""
|
||||
|
||||
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
|
||||
super(GANLoss, self).__init__()
|
||||
self.gan_type = gan_type
|
||||
self.loss_weight = loss_weight
|
||||
self.real_label_val = real_label_val
|
||||
self.fake_label_val = fake_label_val
|
||||
|
||||
if self.gan_type == 'vanilla':
|
||||
self.loss = nn.BCEWithLogitsLoss()
|
||||
elif self.gan_type == 'lsgan':
|
||||
self.loss = nn.MSELoss()
|
||||
elif self.gan_type == 'wgan':
|
||||
self.loss = self._wgan_loss
|
||||
elif self.gan_type == 'wgan_softplus':
|
||||
self.loss = self._wgan_softplus_loss
|
||||
elif self.gan_type == 'hinge':
|
||||
self.loss = nn.ReLU()
|
||||
else:
|
||||
raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
|
||||
|
||||
def _wgan_loss(self, input, target):
|
||||
"""wgan loss.
|
||||
|
||||
Args:
|
||||
input (Tensor): Input tensor.
|
||||
target (bool): Target label.
|
||||
|
||||
Returns:
|
||||
Tensor: wgan loss.
|
||||
"""
|
||||
return -input.mean() if target else input.mean()
|
||||
|
||||
def _wgan_softplus_loss(self, input, target):
|
||||
"""wgan loss with soft plus. softplus is a smooth approximation to the
|
||||
ReLU function.
|
||||
|
||||
In StyleGAN2, it is called:
|
||||
Logistic loss for discriminator;
|
||||
Non-saturating loss for generator.
|
||||
|
||||
Args:
|
||||
input (Tensor): Input tensor.
|
||||
target (bool): Target label.
|
||||
|
||||
Returns:
|
||||
Tensor: wgan loss.
|
||||
"""
|
||||
return F.softplus(-input).mean() if target else F.softplus(input).mean()
|
||||
|
||||
def get_target_label(self, input, target_is_real):
|
||||
"""Get target label.
|
||||
|
||||
Args:
|
||||
input (Tensor): Input tensor.
|
||||
target_is_real (bool): Whether the target is real or fake.
|
||||
|
||||
Returns:
|
||||
(bool | Tensor): Target tensor. Return bool for wgan, otherwise,
|
||||
return Tensor.
|
||||
"""
|
||||
|
||||
if self.gan_type in ['wgan', 'wgan_softplus']:
|
||||
return target_is_real
|
||||
target_val = (self.real_label_val if target_is_real else self.fake_label_val)
|
||||
return input.new_ones(input.size()) * target_val
|
||||
|
||||
def forward(self, input, target_is_real, is_disc=False):
|
||||
"""
|
||||
Args:
|
||||
input (Tensor): The input for the loss module, i.e., the network
|
||||
prediction.
|
||||
target_is_real (bool): Whether the targe is real or fake.
|
||||
is_disc (bool): Whether the loss for discriminators or not.
|
||||
Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor: GAN loss value.
|
||||
"""
|
||||
if self.gan_type == 'hinge':
|
||||
if is_disc: # for discriminators in hinge-gan
|
||||
input = -input if target_is_real else input
|
||||
loss = self.loss(1 + input).mean()
|
||||
else: # for generators in hinge-gan
|
||||
loss = -input.mean()
|
||||
else: # other gan types
|
||||
target_label = self.get_target_label(input, target_is_real)
|
||||
loss = self.loss(input, target_label)
|
||||
|
||||
# loss_weight is always 1.0 for discriminators
|
||||
return loss if is_disc else loss * self.loss_weight
|
||||
|
||||
|
||||
def r1_penalty(real_pred, real_img):
|
||||
"""R1 regularization for discriminator. The core idea is to
|
||||
penalize the gradient on real data alone: when the
|
||||
generator distribution produces the true data distribution
|
||||
and the discriminator is equal to 0 on the data manifold, the
|
||||
gradient penalty ensures that the discriminator cannot create
|
||||
a non-zero gradient orthogonal to the data manifold without
|
||||
suffering a loss in the GAN game.
|
||||
|
||||
Ref:
|
||||
Eq. 9 in Which training methods for GANs do actually converge.
|
||||
"""
|
||||
grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
|
||||
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
|
||||
return grad_penalty
|
||||
|
||||
|
||||
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
|
||||
noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
|
||||
grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
|
||||
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
|
||||
|
||||
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
|
||||
|
||||
path_penalty = (path_lengths - path_mean).pow(2).mean()
|
||||
|
||||
return path_penalty, path_lengths.detach().mean(), path_mean.detach()
|
||||
|
||||
|
||||
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
|
||||
"""Calculate gradient penalty for wgan-gp.
|
||||
|
||||
Args:
|
||||
discriminator (nn.Module): Network for the discriminator.
|
||||
real_data (Tensor): Real input data.
|
||||
fake_data (Tensor): Fake input data.
|
||||
weight (Tensor): Weight tensor. Default: None.
|
||||
|
||||
Returns:
|
||||
Tensor: A tensor for gradient penalty.
|
||||
"""
|
||||
|
||||
batch_size = real_data.size(0)
|
||||
alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
|
||||
|
||||
# interpolate between real_data and fake_data
|
||||
interpolates = alpha * real_data + (1. - alpha) * fake_data
|
||||
interpolates = autograd.Variable(interpolates, requires_grad=True)
|
||||
|
||||
disc_interpolates = discriminator(interpolates)
|
||||
gradients = autograd.grad(
|
||||
outputs=disc_interpolates,
|
||||
inputs=interpolates,
|
||||
grad_outputs=torch.ones_like(disc_interpolates),
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
only_inputs=True)[0]
|
||||
|
||||
if weight is not None:
|
||||
gradients = gradients * weight
|
||||
|
||||
gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
|
||||
if weight is not None:
|
||||
gradients_penalty /= torch.mean(weight)
|
||||
|
||||
return gradients_penalty
|
||||
Reference in New Issue
Block a user