355 lines
11 KiB
Python
355 lines
11 KiB
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch.nn.functional as F
|
||
|
|
|
||
|
|
|
||
|
|
class CNR2d(nn.Module):
    """Conv2d -> normalization -> activation -> dropout stack, stored as ``self.cbr``.

    The stages are appended in that fixed order. The empty list ``[]`` is the
    "not specified / disabled" marker for ``norm``, ``relu``, ``drop`` and
    ``bias``; the list defaults are only compared against, never mutated.

    Args:
        nch_in: Input channel count, forwarded to ``Conv2d``.
        nch_out: Output channel count, forwarded to ``Conv2d`` and ``Norm2d``.
        kernel_size: Forwarded to ``Conv2d``.
        stride: Forwarded to ``Conv2d``.
        padding: Forwarded to ``Conv2d``.
        norm: Mode handed to ``Norm2d``; ``[]`` skips the normalization stage.
        relu: Value handed to ``ReLU``; ``[]`` skips the activation stage.
        drop: Value handed to ``nn.Dropout2d``; ``[]`` skips the dropout stage.
        bias: Convolution bias flag. When left as ``[]`` it resolves to
            ``False`` for ``norm == 'bnorm'`` and ``True`` otherwise.
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # Resolve the default bias: off with 'bnorm' (presumably because batch
        # norm brings its own shift term), on for every other norm setting.
        if bias == []:
            bias = norm != 'bnorm'

        stages = [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]

        if norm != []:
            stages.append(Norm2d(nch_out, norm))

        if relu != []:
            stages.append(ReLU(relu))

        if drop != []:
            stages.append(nn.Dropout2d(drop))

        self.cbr = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the assembled stack and return the result."""
        return self.cbr(x)
|
||
|
|
|
||
|
|
|
||
|
|
class DECNR2d(nn.Module):
    """Deconv2d -> normalization -> activation -> dropout stack, stored as ``self.decbr``.

    The stages are appended in that fixed order. The empty list ``[]`` is the
    "not specified / disabled" marker for ``norm``, ``relu``, ``drop`` and
    ``bias``; the list defaults are only compared against, never mutated.

    Args:
        nch_in: Input channel count, forwarded to ``Deconv2d``.
        nch_out: Output channel count, forwarded to ``Deconv2d`` and ``Norm2d``.
        kernel_size: Forwarded to ``Deconv2d``.
        stride: Forwarded to ``Deconv2d``.
        padding: Forwarded to ``Deconv2d``.
        output_padding: Forwarded to ``Deconv2d``.
        norm: Mode handed to ``Norm2d``; ``[]`` skips the normalization stage.
        relu: Value handed to ``ReLU``; ``[]`` skips the activation stage.
        drop: Value handed to ``nn.Dropout2d``; ``[]`` skips the dropout stage.
        bias: Deconvolution bias flag. When left as ``[]`` it resolves to
            ``False`` for ``norm == 'bnorm'`` and ``True`` otherwise.
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # Resolve the default bias: off with 'bnorm', on for anything else.
        if bias == []:
            bias = norm != 'bnorm'

        stages = [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)]

        if norm != []:
            stages.append(Norm2d(nch_out, norm))

        if relu != []:
            stages.append(ReLU(relu))

        if drop != []:
            stages.append(nn.Dropout2d(drop))

        self.decbr = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the assembled stack and return the result."""
        return self.decbr(x)
|
||
|
|
|
||
|
|
|
||
|
|
class ResBlock(nn.Module):
    """Residual block computing ``x + f(x)``.

    ``f`` (stored as ``self.resblk``) is: explicit padding -> ``CNR2d`` with
    activation -> optional ``nn.Dropout2d`` -> explicit padding -> ``CNR2d``
    without activation. Both ``CNR2d`` layers use ``padding=0`` because the
    padding is applied by the preceding ``Padding`` module.

    The addition in ``forward`` needs ``f(x)`` to be shape-compatible with
    ``x``; nothing here enforces it (e.g. ``nch_in == nch_out``, and a
    kernel/stride/padding combination that keeps the spatial size).

    Args:
        nch_in: Channels of the block input.
        nch_out: Channels produced by both convolutions.
        kernel_size: Forwarded to both ``CNR2d`` layers.
        stride: Forwarded to both ``CNR2d`` layers.
        padding: Amount handed to both ``Padding`` modules.
        padding_mode: Mode handed to both ``Padding`` modules.
        norm: Norm mode forwarded to both ``CNR2d`` layers.
        relu: Activation setting of the first ``CNR2d`` only.
        drop: Value handed to ``nn.Dropout2d`` between the convolutions;
            ``[]`` disables dropout.
        bias: Convolution bias flag for both layers. When left as ``[]`` it
            resolves to ``False`` for ``norm == 'bnorm'`` and ``True`` otherwise.
    """

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            bias = norm != 'bnorm'

        layers = []

        # 1st conv. `bias` is forwarded explicitly: it used to be resolved here
        # and then dropped, so a caller-supplied value had no effect.
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu, bias=bias)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        # 2nd conv. Its input is the first conv's output, i.e. `nch_out`
        # channels (identical to the old `nch_in` whenever the two match).
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_out, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[], bias=bias)]

        self.resblk = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch applied to ``x``."""
        return x + self.resblk(x)
|
||
|
|
|
||
|
|
|
||
|
|
class ResBlock_cat(nn.Module):
    """Residual block whose branch sees ``x`` concatenated with a second input.

    ``forward(x, y)`` returns ``x + f(cat([x, y], dim=1))``. ``f`` (stored as
    ``self.resblk``) is: explicit padding -> ``CNR2d`` with activation ->
    optional ``nn.Dropout2d`` -> explicit padding -> ``CNR2d`` without
    activation. The first convolution is declared with ``nch_in * 2`` input
    channels, so ``x`` and ``y`` together must provide that many channels.

    The addition needs ``f(...)`` to be shape-compatible with ``x``; nothing
    here enforces it.

    Args:
        nch_in: Channels of ``x`` (the first conv expects ``2 * nch_in``).
        nch_out: Channels produced by both convolutions.
        kernel_size: Forwarded to both ``CNR2d`` layers.
        stride: Forwarded to both ``CNR2d`` layers.
        padding: Amount handed to both ``Padding`` modules.
        padding_mode: Mode handed to both ``Padding`` modules.
        norm: Norm mode forwarded to both ``CNR2d`` layers.
        relu: Activation setting of the first ``CNR2d`` only.
        drop: Value handed to ``nn.Dropout2d`` between the convolutions;
            ``[]`` disables dropout.
        bias: Convolution bias flag for both layers. When left as ``[]`` it
            resolves to ``False`` for ``norm == 'bnorm'`` and ``True`` otherwise.
    """

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            bias = norm != 'bnorm'

        layers = []

        # 1st conv, fed with the channel-wise concatenation of x and y.
        # `bias` is forwarded explicitly: it used to be resolved and then
        # dropped, so a caller-supplied value had no effect.
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in*2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu, bias=bias)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        # 2nd conv. Its input is the first conv's output, i.e. `nch_out`
        # channels (identical to the old `nch_in` whenever the two match).
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_out, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[], bias=bias)]

        self.resblk = nn.Sequential(*layers)

    def forward(self, x, y):
        """Return ``x`` plus the residual branch applied to ``cat([x, y], dim=1)``."""
        output = x + self.resblk(torch.cat([x, y], dim=1))
        return output
|
||
|
|
|
||
|
|
class LinearBlock(nn.Module):
    """Fully connected layer followed by optional normalization and activation.

    Args:
        input_dim: Input feature count of the linear layer.
        output_dim: Output feature count of the linear layer (also the size
            handed to the normalization layer).
        norm: One of ``'bn'`` (``nn.BatchNorm1d``), ``'in'``
            (``nn.InstanceNorm1d``), ``'ln'`` (``LayerNorm``), ``'none'`` or
            ``'sn'``. ``'sn'`` wraps the linear layer in ``SpectralNorm`` and
            adds no separate normalization layer.
        activation: One of ``'relu'``, ``'lrelu'`` (slope 0.2), ``'prelu'``,
            ``'selu'``, ``'tanh'`` or ``'none'``.

    Raises:
        AssertionError: If ``norm`` or ``activation`` is not one of the
            supported names.

    NOTE(review): ``SpectralNorm`` and ``LayerNorm`` are not defined in the
    visible part of this file; they must be provided elsewhere for
    ``norm='sn'`` / ``norm='ln'`` to work -- TODO confirm.
    """

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True

        # initialize fully connected layer
        fc = nn.Linear(input_dim, output_dim, bias=use_bias)
        self.fc = SpectralNorm(fc) if norm == 'sn' else fc

        # initialize normalization
        # Unsupported names raise explicitly instead of via `assert 0`, which
        # `python -O` strips (leaving the attribute unset until forward()).
        # AssertionError is kept so existing handlers still match.
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            raise AssertionError("Unsupported normalization: {}".format(norm))

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            raise AssertionError("Unsupported activation: {}".format(activation))

    def forward(self, x):
        """Apply the linear layer, then normalization and activation if configured."""
        out = self.fc(x)
        # Compare against None rather than relying on module truthiness.
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
|
||
|
|
|
||
|
|
class MLP(nn.Module):
    """Multi-layer perceptron built from ``LinearBlock`` layers.

    Layout: one ``input_dim -> dim`` block, ``n_blk - 2`` hidden ``dim -> dim``
    blocks, and a final ``dim -> output_dim`` block with no normalization and
    no activation. The first and last block are always created, so at least
    two blocks exist even when ``n_blk`` is below 2.

    Args:
        input_dim: Feature count of the flattened input.
        output_dim: Feature count of the output.
        dim: Width of the hidden layers.
        n_blk: Nominal total number of blocks.
        norm: Normalization name forwarded to all but the last block.
        activ: Activation name forwarded to all but the last block.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()

        # Collect the blocks in a local list; `self.model` is assigned once.
        blocks = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        blocks += [LinearBlock(dim, dim, norm=norm, activation=activ) for _ in range(n_blk - 2)]
        blocks += [LinearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Flatten every dimension after the first and run the block stack."""
        # reshape() gives the same result as view() for contiguous input and
        # also accepts non-contiguous tensors, where view() raises.
        return self.model(x.reshape(x.size(0), -1))
|
||
|
|
|
||
|
|
class CNR1d(nn.Module):
    """Linear -> normalization -> activation -> dropout stack, stored as ``self.cbr``.

    The stages are appended in that fixed order. The empty list ``[]`` disables
    ``norm``, ``relu`` or ``drop``. The linear layer has no bias when
    ``norm == 'bnorm'`` and a bias otherwise.

    NOTE(review): the normalization and dropout stages reuse ``Norm2d`` and
    ``nn.Dropout2d`` even though they follow an ``nn.Linear``; confirm the
    input rank this block is meant for -- TODO.

    Args:
        nch_in: Input feature count of the linear layer.
        nch_out: Output feature count of the linear layer; also handed to ``Norm2d``.
        norm: Mode handed to ``Norm2d``; ``[]`` skips the normalization stage.
        relu: Value handed to ``ReLU``; ``[]`` skips the activation stage.
        drop: Value handed to ``nn.Dropout2d``; ``[]`` skips the dropout stage.
    """

    def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]):
        super().__init__()

        use_bias = norm != 'bnorm'

        stages = [nn.Linear(nch_in, nch_out, bias=use_bias)]

        if norm != []:
            stages.append(Norm2d(nch_out, norm))

        if relu != []:
            stages.append(ReLU(relu))

        if drop != []:
            stages.append(nn.Dropout2d(drop))

        self.cbr = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the assembled stack and return the result."""
        return self.cbr(x)
|
||
|
|
|
||
|
|
|
||
|
|
class Conv2d(nn.Module):
    """Module wrapper that owns a single ``nn.Conv2d`` as ``self.conv``.

    Args:
        nch_in: Input channel count.
        nch_out: Output channel count.
        kernel_size: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.Conv2d``.
        bias: Forwarded to ``nn.Conv2d``.
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.conv = nn.Conv2d(nch_in, nch_out, **conv_options)

    def forward(self, x):
        """Apply the wrapped convolution to ``x``."""
        convolved = self.conv(x)
        return convolved
|
||
|
|
|
||
|
|
|
||
|
|
class Deconv2d(nn.Module):
    """Module wrapper that owns a single ``nn.ConvTranspose2d`` as ``self.deconv``.

    Args:
        nch_in: Input channel count.
        nch_out: Output channel count.
        kernel_size: Forwarded to ``nn.ConvTranspose2d``.
        stride: Forwarded to ``nn.ConvTranspose2d``.
        padding: Forwarded to ``nn.ConvTranspose2d``.
        output_padding: Forwarded to ``nn.ConvTranspose2d``.
        bias: Forwarded to ``nn.ConvTranspose2d``.
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True):
        super().__init__()
        deconv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=bias,
        )
        self.deconv = nn.ConvTranspose2d(nch_in, nch_out, **deconv_options)

        # A disabled alternative was sketched here earlier: bilinear
        # nn.Upsample(scale_factor=2) -> nn.ReflectionPad2d(1) -> 3x3 nn.Conv2d
        # wrapped in nn.Sequential. It is not in use.

    def forward(self, x):
        """Apply the wrapped transposed convolution to ``x``."""
        upsampled = self.deconv(x)
        return upsampled
|
||
|
|
|
||
|
|
|
||
|
|
class Linear(nn.Module):
    """Module wrapper that owns a single ``nn.Linear`` as ``self.linear``.

    Args:
        nch_in: Input feature count.
        nch_out: Output feature count.
    """

    def __init__(self, nch_in, nch_out):
        super().__init__()
        self.linear = nn.Linear(in_features=nch_in, out_features=nch_out)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        projected = self.linear(x)
        return projected
|
||
|
|
|
||
|
|
|
||
|
|
class Norm2d(nn.Module):
    """Selects a 2-D normalization layer by name and stores it as ``self.norm``.

    Args:
        nch: Channel count handed to the normalization layer.
        norm_mode: ``'bnorm'`` for ``nn.BatchNorm2d`` or ``'inorm'`` for
            ``nn.InstanceNorm2d``.

    Raises:
        ValueError: If ``norm_mode`` is neither ``'bnorm'`` nor ``'inorm'``.
    """

    def __init__(self, nch, norm_mode):
        super(Norm2d, self).__init__()
        if norm_mode == 'bnorm':
            self.norm = nn.BatchNorm2d(nch)
        elif norm_mode == 'inorm':
            self.norm = nn.InstanceNorm2d(nch)
        else:
            # Previously an unknown mode left `self.norm` unset, which only
            # surfaced later as an AttributeError inside forward().
            raise ValueError("Unsupported norm_mode: {!r} (expected 'bnorm' or 'inorm')".format(norm_mode))

    def forward(self, x):
        """Apply the selected normalization layer to ``x``."""
        return self.norm(x)
|
||
|
|
|
||
|
|
|
||
|
|
class ReLU(nn.Module):
    """Selects an in-place ReLU or LeakyReLU from a single slope value.

    Args:
        relu: Negative slope. A value above 0 builds
            ``nn.LeakyReLU(relu, inplace=True)``; exactly 0 builds
            ``nn.ReLU(inplace=True)``.

    Raises:
        ValueError: If ``relu`` is neither positive nor zero (e.g. negative).
    """

    def __init__(self, relu):
        super(ReLU, self).__init__()
        if relu > 0:
            self.relu = nn.LeakyReLU(relu, True)
        elif relu == 0:
            self.relu = nn.ReLU(True)
        else:
            # Previously such a value left `self.relu` unset, which only
            # surfaced later as an AttributeError inside forward().
            raise ValueError("Unsupported relu slope: {!r} (expected a value >= 0)".format(relu))

    def forward(self, x):
        """Apply the selected activation to ``x`` (in place)."""
        return self.relu(x)
|
||
|
|
|
||
|
|
|
||
|
|
class Padding(nn.Module):
    """Selects a 2-D padding layer by name and stores it as ``self.padding``.

    Args:
        padding: Padding size handed to the selected layer.
        padding_mode: ``'reflection'`` (``nn.ReflectionPad2d``),
            ``'replication'`` (``nn.ReplicationPad2d``), ``'constant'``
            (``nn.ConstantPad2d``) or ``'zeros'`` (``nn.ZeroPad2d``).
        value: Fill value; only used with ``padding_mode='constant'``.

    Raises:
        ValueError: If ``padding_mode`` is not one of the four names above.
    """

    def __init__(self, padding, padding_mode='zeros', value=0):
        super(Padding, self).__init__()
        if padding_mode == 'reflection':
            self.padding = nn.ReflectionPad2d(padding)
        elif padding_mode == 'replication':
            self.padding = nn.ReplicationPad2d(padding)
        elif padding_mode == 'constant':
            self.padding = nn.ConstantPad2d(padding, value)
        elif padding_mode == 'zeros':
            self.padding = nn.ZeroPad2d(padding)
        else:
            # Previously an unknown mode left `self.padding` unset, which only
            # surfaced later as an AttributeError inside forward().
            raise ValueError("Unsupported padding_mode: {!r}".format(padding_mode))

    def forward(self, x):
        """Apply the selected padding layer to ``x``."""
        return self.padding(x)
|
||
|
|
|
||
|
|
|
||
|
|
class Pooling2d(nn.Module):
    """Selects a 2-D downsampling layer by name and stores it as ``self.pooling``.

    Args:
        nch: Channel count; only used by ``type='conv'``, where it is both the
            input and output channel count and therefore must be given.
        pool: Window size; for ``type='conv'`` also the stride.
        type: ``'avg'`` (``nn.AvgPool2d``), ``'max'`` (``nn.MaxPool2d``) or
            ``'conv'`` (``nn.Conv2d`` with ``kernel_size=stride=pool``).
            The parameter name shadows the builtin but is kept because
            callers may pass it by keyword.

    Raises:
        ValueError: If ``type`` is not one of the three names above.
    """

    def __init__(self, nch=[], pool=2, type='avg'):
        super().__init__()

        if type == 'avg':
            self.pooling = nn.AvgPool2d(pool)
        elif type == 'max':
            self.pooling = nn.MaxPool2d(pool)
        elif type == 'conv':
            self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool)
        else:
            # Previously an unknown type left `self.pooling` unset, which only
            # surfaced later as an AttributeError inside forward().
            raise ValueError("Unsupported pooling type: {!r}".format(type))

    def forward(self, x):
        """Apply the selected pooling layer to ``x``."""
        return self.pooling(x)
|
||
|
|
|
||
|
|
|
||
|
|
class UnPooling2d(nn.Module):
    """Selects a 2-D upsampling layer by name and stores it as ``self.unpooling``.

    Args:
        nch: Channel count; only used by ``type='conv'``, where it is both the
            input and output channel count and therefore must be given.
        pool: Scale factor; for ``type='conv'`` the kernel size and stride.
        type: ``'nearest'`` (``nn.Upsample`` in nearest mode), ``'bilinear'``
            (``nn.Upsample`` in bilinear mode with ``align_corners=True``) or
            ``'conv'`` (``nn.ConvTranspose2d`` with ``kernel_size=stride=pool``).
            The parameter name shadows the builtin but is kept because
            callers may pass it by keyword.

    Raises:
        ValueError: If ``type`` is not one of the three names above.
    """

    def __init__(self, nch=[], pool=2, type='nearest'):
        super().__init__()

        if type == 'nearest':
            # `align_corners` is only valid for the interpolating modes;
            # passing it with 'nearest' makes the forward pass raise, so the
            # default configuration could never run. It is omitted here.
            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')
        elif type == 'bilinear':
            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
        elif type == 'conv':
            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)
        else:
            # Previously an unknown type left `self.unpooling` unset, which
            # only surfaced later as an AttributeError inside forward().
            raise ValueError("Unsupported unpooling type: {!r}".format(type))

    def forward(self, x):
        """Apply the selected upsampling layer to ``x``."""
        return self.unpooling(x)
|
||
|
|
|
||
|
|
|
||
|
|
class Concat(nn.Module):
    """Pads ``x1`` to the size of ``x2`` along dims 2 and 3, then concatenates.

    The size difference along each of the two dims is split between both
    sides: the leading side gets ``gap // 2`` and the trailing side gets the
    remainder. The result is ``torch.cat([x2, x1], dim=1)``, i.e. ``x2``
    comes first along dim 1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        """Return ``x2`` and the padded ``x1`` concatenated along dim 1."""
        gap_dim2 = x2.size(2) - x1.size(2)
        gap_dim3 = x2.size(3) - x1.size(3)

        lead_dim3 = gap_dim3 // 2
        lead_dim2 = gap_dim2 // 2

        # F.pad takes the pad amounts for the last dimension first.
        pad_amounts = [lead_dim3, gap_dim3 - lead_dim3, lead_dim2, gap_dim2 - lead_dim2]
        x1 = F.pad(x1, pad_amounts)

        return torch.cat([x2, x1], dim=1)
|
||
|
|
|
||
|
|
|
||
|
|
class TV1dLoss(nn.Module):
    """Total-variation style loss along dim 1.

    ``forward`` returns the mean absolute difference between neighbouring
    entries of ``input`` along dim 1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return the mean of ``|input[:, :-1] - input[:, 1:]|``."""
        # `input` shadows the builtin but is part of the callable signature.
        neighbour_gap = input[:, :-1] - input[:, 1:]
        return neighbour_gap.abs().mean()
|
||
|
|
|
||
|
|
|
||
|
|
class TV2dLoss(nn.Module):
    """Total-variation style loss along dims 3 and 2 of a 4-D input.

    ``forward`` returns the mean absolute neighbour difference along dim 3
    plus the mean absolute neighbour difference along dim 2.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return the sum of the two per-dimension mean absolute differences."""
        # `input` shadows the builtin but is part of the callable signature.
        along_dim3 = (input[:, :, :, :-1] - input[:, :, :, 1:]).abs().mean()
        along_dim2 = (input[:, :, :-1, :] - input[:, :, 1:, :]).abs().mean()
        return along_dim3 + along_dim2
|
||
|
|
|
||
|
|
|
||
|
|
class SSIM2dLoss(nn.Module):
    """Placeholder loss: ``forward`` ignores both arguments and returns ``0``.

    NOTE(review): no SSIM computation is implemented here. The ``targer``
    spelling is kept because it is part of the method signature.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, targer):
        """Return the constant ``0`` regardless of ``input`` and ``targer``."""
        return 0
|
||
|
|
|