Add codeformer and update license
25
basicsr/archs/__init__.py
Normal file
@@ -0,0 +1,25 @@
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY

__all__ = ['build_network']

# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]


def build_network(opt):
    opt = deepcopy(opt)
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net
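For orientation, build_network is driven by a plain option dict whose 'type' key names a class registered in ARCH_REGISTRY; the remaining keys become constructor keyword arguments. A minimal usage sketch (the option values below are illustrative, not part of this commit):

    # 'type' selects the registered class; the other keys go to its constructor
    net_opt = {'type': 'RRDBNet', 'num_in_ch': 3, 'num_out_ch': 3, 'scale': 4}
    net = build_network(net_opt)  # equivalent to RRDBNet(num_in_ch=3, num_out_ch=3, scale=4)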
245
basicsr/archs/arcface_arch.py
Normal file
@@ -0,0 +1,245 @@
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY


def conv3x3(inplanes, outplanes, stride=1):
    """A simple wrapper for 3x3 convolution with padding.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
    """
    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Args:
        block (str): Block used in the ArcFace architecture.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        if block == 'IRBlock':
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x
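A hypothetical instantiation sketch: with layers=(2, 2, 2, 2) this is a ResNet-18-style backbone. Since conv1 takes a single input channel and fc5 expects 512*8*8 features, it operates on 128x128 grayscale crops (128 -> 64 after the maxpool -> 8 after three stride-2 stages) and emits a 512-D identity embedding:

    import torch  # not imported by this file; needed only for this sketch

    arcface = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).eval()
    embedding = arcface(torch.randn(1, 1, 128, 128))  # -> (1, 512); eval() avoids BatchNorm1d's batch-size-1 error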
318
basicsr/archs/arch_util.py
Normal file
@@ -0,0 +1,318 @@
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger


@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as the default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

    # TODO, what if align_corners=False
    return output
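As a reading aid, a hypothetical sanity check for flow_warp (names illustrative): an all-zero flow yields the identity sampling grid, so the output reproduces the input up to interpolation error.

    x = torch.randn(2, 3, 16, 16)
    out = flow_warp(x, torch.zeros(2, 16, 16, 2))  # out ≈ x: zero displacement samples at the original grid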
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
                downsampling, the ratio should be smaller than 1.0 (i.e.,
                ratio < 1.0). For upsampling, the ratio should be larger than
                1.0 (i.e., ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
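A small shape sketch for pixel_unshuffle (values illustrative): every scale x scale spatial block is folded into the channel dimension, exactly inverting nn.PixelShuffle.

    x = torch.arange(16.).view(1, 1, 4, 4)
    y = pixel_unshuffle(x, scale=2)       # shape (1, 4, 2, 2): channels x4, spatial /2
    assert nn.PixelShuffle(2)(y).equal(x)  # round-trips through PixelShuffle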
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    feature to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        out = self.conv_offset(feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            logger = get_root_logger()
            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        else:
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        low = norm_cdf((a - mean) / std)
        up = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [low, up], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * low - 1, 2 * up - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


# From PyTorch
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
276
basicsr/archs/codeformer_arch.py
Normal file
@@ -0,0 +1,276 @@
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List

from basicsr.archs.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY


def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have similar color and illumination
    to those in the degraded features.

    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
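A hypothetical sketch of what adaptive_instance_normalization guarantees: the output keeps the content features' spatial structure while carrying the style features' per-channel statistics.

    content = torch.randn(1, 4, 8, 8) * 3 + 5
    style = torch.randn(1, 4, 8, 8)
    out = adaptive_instance_normalization(content, style)
    # per-channel mean/std of `out` now match those of `style` (up to eps)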
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of the feedforward model (MLP)
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

        # self attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt


class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2 * in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out


@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize', 'generator']):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd * 2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after second residual block for > 16, before attn layer for == 16
        self.fuse_encoder_block = {'512': 2, '256': 5, '128': 8, '64': 11, '32': 14, '16': 18}
        # after first residual block for > 16, before attn layer for == 16
        self.fuse_generator_block = {'16': 6, '32': 9, '64': 12, '128': 15, '256': 18, '512': 21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb)  # (hw)bn
        logits = logits.permute(1, 0, 2)  # (hw)bn -> b(hw)n

        if code_only:  # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0, 2, 1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach()  # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list:  # fuse after i-th block
                f_size = str(x.shape[-1])
                if w > 0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
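A hypothetical inference sketch: w in [0, 1] controls the fidelity/quality trade-off by scaling how strongly the encoder features are fused back into the fixed generator (w=0 relies purely on the predicted code tokens).

    model = CodeFormer().eval()
    with torch.no_grad():
        restored, logits, lq_feat = model(torch.randn(1, 3, 512, 512), w=0.5, adain=True)
    # restored: (1, 3, 512, 512); logits: (1, 256, 1024), one row per 16x16 latent token over the codebook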
119
basicsr/archs/rrdbnet_arch.py
Normal file
@@ -0,0 +1,119 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import default_init_weights, make_layer, pixel_unshuffle


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x


@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Blocks, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upsampling scale. Supported: 1, 2, 4. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
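A hypothetical instantiation sketch: for scale 4 the input passes straight through; for scale 2 or 1 it is first pixel-unshuffled, so the RRDB trunk always runs at one quarter of the output resolution.

    net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4)
    sr = net(torch.randn(1, 3, 64, 64))  # -> (1, 3, 256, 256): two x2 nearest+conv upsampling steps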
161
basicsr/archs/vgg_arch.py
Normal file
@@ -0,0 +1,161 @@
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg

from basicsr.utils.registry import ARCH_REGISTRY

VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}


def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    names_bn = []
    for name in names:
        names_bn.append(name)
        if 'conv' in name:
            position = name.replace('conv', '')
            names_bn.append('bn' + position)
    return names_bn


@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether to use
    normalization on the input features and the type of VGG network. Note
    that the pretrained path must match the VGG type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must be in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        self.vgg_net = nn.Sequential(modified_net)

        if not requires_grad:
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = {}

        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output
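A hypothetical usage sketch for a perceptual-loss setup: request the activation names up front and get back a dict of features (downloads the torchvision weights if the local pretrained path is absent).

    vgg = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'])
    feats = vgg(torch.rand(1, 3, 224, 224))  # dict: {'relu1_1': ..., 'relu2_1': ..., 'relu3_1': ...}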
435
basicsr/archs/vqgan_arch.py
Normal file
@@ -0,0 +1,435 @@
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY


def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


@torch.jit.script
def swish(x):
    return x * torch.sigmoid(x)


# Define VQVAE classes
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        # [0-1], higher score, higher confidence
        min_encoding_scores = torch.exp(-min_encoding_scores / 10)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "min_encoding_scores": min_encoding_scores,
            "mean_distance": mean_distance
        }

    def get_codebook_feat(self, indices, shape):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1, 1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q
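A hypothetical shape sketch for VectorQuantizer: it consumes a (B, emb_dim, H, W) latent and returns the straight-through quantized latent plus the codebook/commitment loss and bookkeeping stats.

    vq = VectorQuantizer(codebook_size=1024, emb_dim=256, beta=0.25)
    z_q, loss, stats = vq(torch.randn(2, 256, 16, 16))
    # z_q: (2, 256, 16, 16); stats['min_encoding_indices']: (2*16*16, 1) nearest-code ids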
class GumbelQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        hard = self.straight_through if self.training else True

        logits = self.proj(z)

        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)

        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {
            "min_encoding_indices": min_encoding_indices
        }


class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        # use self.out_channels below so the out_channels=None default works
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class Encoder(nn.Module):
    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,) + tuple(ch_mult)

        blocks = []
        # initial convolution
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention on smaller res (16x16)
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


class Generator(nn.Module):
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions - 1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)


    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta  # 0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError('Wrong params!')


    def forward(self, x):
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats
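A hypothetical sketch mirroring the configuration CodeFormer builds on: 512px input, nf=64, ch_mult [1, 2, 2, 4, 4, 8] (five downsamples, so a 16x16 latent) and a 1024-entry codebook.

    vqgan = VQAutoEncoder(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], 1024)
    recon, codebook_loss, quant_stats = vqgan(torch.randn(1, 3, 512, 512))  # recon: (1, 3, 512, 512)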
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        return self.main(x)