feat sketch 提取接口
fix
This commit is contained in:
49
app/service/image2sketch/models/__init__.py
Normal file
49
app/service/image2sketch/models/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import importlib
|
||||
|
||||
from app.service.image2sketch.models import unpaired_model as modellib
|
||||
from .base_model import BaseModel
|
||||
|
||||
|
||||
def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py" and return its model class.

    The looked-up class must be named <model_name>Model (underscores stripped,
    case-insensitive) and be a subclass of BaseModel.

    Parameters:
        model_name (str) -- the model name, e.g. 'unpaired'

    Returns:
        the matching model class
    """
    # NOTE(review): the per-model dynamic import is disabled; every model is
    # looked up in the statically imported `unpaired_model` module instead.
    # model_filename = "." + model_name + "_model"
    # modellib = importlib.import_module(model_filename)
    model_filename = model_name + '_model'  # only used in the error message below
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        # Guard with isinstance so a non-class attribute that happens to match
        # the target name cannot make issubclass() raise TypeError.
        if name.lower() == target_model_name.lower() \
                and isinstance(cls, type) and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        # BUGFIX: `model_filename` was referenced here without ever being
        # assigned (its assignment was commented out above), so the error
        # path itself crashed with NameError instead of printing this message.
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model
|
||||
|
||||
|
||||
def get_option_setter(model_name):
    """Look up the model class for *model_name* and hand back its
    <modify_commandline_options> static method."""
    return find_model_using_name(model_name).modify_commandline_options
|
||||
|
||||
|
||||
def create_model(opt):
    """Create a model instance given the parsed options.

    This function wraps the class lookup and instantiation, and is the main
    interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from .models import create_model
        >>> model = create_model(opt)
    """
    model_cls = find_model_using_name(opt.model)
    instance = model_cls(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
|
||||
230
app/service/image2sketch/models/base_model.py
Normal file
230
app/service/image2sketch/models/base_model.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import os
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from . import networks
|
||||
|
||||
|
||||
class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.

    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>.
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one
               optimizer for each network; use itertools.chain to group networks updated together.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration."""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            # Resume either from a specific iteration checkpoint or from an epoch label.
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time."""
        for name in self.model_names:
            if isinstance(name, str):
                # networks are stored as attributes named 'net' + <model name>
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps the <forward> function in no_grad() so we don't save
        intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization."""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data."""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch."""
        old_lr = self.optimizers[0].param_groups[0]['lr']
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                # ReduceLROnPlateau needs the monitored metric.
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate %.7f -> %.7f' % (old_lr, lr))

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML."""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # Save a CPU-side state dict (portable across machines),
                    # then move the network back onto its GPU.
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to PyTorch 0.4).

        Recursively walks `keys` (a checkpoint key split on '.') down the module
        tree and pops InstanceNorm running stats / batch counters that newer
        PyTorch versions no longer store.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture.

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations.

        Parameters:
            nets (network list) -- a list of networks
            requires_grad (bool) -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
|
||||
354
app/service/image2sketch/models/layer.py
Normal file
354
app/service/image2sketch/models/layer.py
Normal file
@@ -0,0 +1,354 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class CNR2d(nn.Module):
    """Conv2d -> Norm2d -> ReLU (-> Dropout2d) block."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # An empty-list bias means "decide automatically": batch norm carries
        # its own shift term, so the convolution bias is redundant there.
        if bias == []:
            bias = norm != 'bnorm'

        modules = [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]
        if norm != []:
            modules.append(Norm2d(nch_out, norm))
        if relu != []:
            modules.append(ReLU(relu))
        if drop != []:
            modules.append(nn.Dropout2d(drop))

        self.cbr = nn.Sequential(*modules)

    def forward(self, x):
        return self.cbr(x)
|
||||
|
||||
|
||||
class DECNR2d(nn.Module):
    """Deconv2d -> Norm2d -> ReLU (-> Dropout2d) block."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # Empty-list bias means "decide automatically": batch norm already has
        # a learnable shift, so the transposed-conv bias is redundant there.
        if bias == []:
            bias = norm != 'bnorm'

        modules = [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)]
        if norm != []:
            modules.append(Norm2d(nch_out, norm))
        if relu != []:
            modules.append(ReLU(relu))
        if drop != []:
            modules.append(nn.Dropout2d(drop))

        self.decbr = nn.Sequential(*modules)

    def forward(self, x):
        return self.decbr(x)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Two-conv residual block: output = x + conv(conv(x))."""

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # NOTE(review): bias is resolved here but never forwarded to the CNR2d
        # sub-blocks; kept for parity with the original implementation.
        if bias == []:
            bias = norm != 'bnorm'

        # First conv keeps the activation; explicit Padding so the conv itself
        # uses padding=0 and the chosen padding_mode is honoured.
        modules = [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu),
        ]

        if drop != []:
            modules.append(nn.Dropout2d(drop))

        # Second conv has no activation (relu=[]), as usual for residual blocks.
        modules += [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[]),
        ]

        self.resblk = nn.Sequential(*modules)

    def forward(self, x):
        return x + self.resblk(x)
|
||||
|
||||
|
||||
class ResBlock_cat(nn.Module):
    """Residual block whose first conv consumes [x, y] concatenated on channels."""

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # NOTE(review): bias is resolved but not forwarded to CNR2d, matching
        # the sibling ResBlock.
        if bias == []:
            bias = norm != 'bnorm'

        # First conv takes the channel-concatenated pair, hence nch_in * 2 inputs.
        modules = [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in * 2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu),
        ]

        if drop != []:
            modules.append(nn.Dropout2d(drop))

        # Second conv has no activation (relu=[]).
        modules += [
            Padding(padding, padding_mode=padding_mode),
            CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[]),
        ]

        self.resblk = nn.Sequential(*modules)

    def forward(self, x, y):
        return x + self.resblk(torch.cat([x, y], dim=1))
|
||||
|
||||
class LinearBlock(nn.Module):
    """Fully connected layer followed by optional normalization and activation."""

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True

        # Fully connected layer; 'sn' wraps it in spectral normalization.
        if norm == 'sn':
            # NOTE(review): SpectralNorm is not defined or imported in this
            # module — this branch raises NameError if selected; verify import.
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # Normalization layer (None when disabled or handled by 'sn' above).
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            # NOTE(review): LayerNorm is not defined in this module either —
            # presumably a custom class elsewhere; verify before using 'ln'.
            self.norm = LayerNorm(norm_dim)
        elif norm in ('none', 'sn'):
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Activation function (None when disabled).
        if activation == 'none':
            self.activation = None
        else:
            factories = {
                'relu': lambda: nn.ReLU(inplace=True),
                'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
                'prelu': nn.PReLU,
                'selu': lambda: nn.SELU(inplace=True),
                'tanh': nn.Tanh,
            }
            assert activation in factories, "Unsupported activation: {}".format(activation)
            self.activation = factories[activation]()

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
|
||||
|
||||
class MLP(nn.Module):
    """Multi-layer perceptron built from LinearBlock units.

    Layout: input_dim -> dim -> (n_blk - 2 hidden blocks of width dim) -> output_dim,
    with no normalization/activation on the output block.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        blocks = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        blocks.extend(LinearBlock(dim, dim, norm=norm, activation=activ) for _ in range(n_blk - 2))
        blocks.append(LinearBlock(dim, output_dim, norm='none', activation='none'))  # no output activations
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        # Flatten everything past the batch dimension before the first FC layer.
        return self.model(x.view(x.size(0), -1))
|
||||
|
||||
class CNR1d(nn.Module):
    """Linear -> Norm2d -> ReLU (-> Dropout2d) block."""

    def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]):
        super().__init__()

        # Batch norm supplies its own shift, so the linear bias is dropped there.
        bias = norm != 'bnorm'

        modules = [nn.Linear(nch_in, nch_out, bias=bias)]
        if norm != []:
            # NOTE(review): Norm2d wraps 2d norms (BatchNorm2d/InstanceNorm2d),
            # which expect 4d input — verify this is intended after nn.Linear.
            modules.append(Norm2d(nch_out, norm))
        if relu != []:
            modules.append(ReLU(relu))
        if drop != []:
            modules.append(nn.Dropout2d(drop))

        self.cbr = nn.Sequential(*modules)

    def forward(self, x):
        return self.cbr(x)
|
||||
|
||||
|
||||
class Conv2d(nn.Module):
    """Thin wrapper around nn.Conv2d with this module's default hyperparameters."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)

    def forward(self, x):
        return self.conv(x)
|
||||
|
||||
|
||||
class Deconv2d(nn.Module):
    """Thin wrapper around nn.ConvTranspose2d with this module's default hyperparameters."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride,
                                         padding=padding, output_padding=output_padding, bias=bias)

    def forward(self, x):
        return self.deconv(x)
|
||||
|
||||
|
||||
class Linear(nn.Module):
    """Thin wrapper around nn.Linear."""

    def __init__(self, nch_in, nch_out):
        super().__init__()
        self.linear = nn.Linear(nch_in, nch_out)

    def forward(self, x):
        return self.linear(x)
|
||||
|
||||
|
||||
class Norm2d(nn.Module):
    """BatchNorm2d ('bnorm') or InstanceNorm2d ('inorm') selected by a mode string."""

    def __init__(self, nch, norm_mode):
        super().__init__()
        norms = {'bnorm': nn.BatchNorm2d, 'inorm': nn.InstanceNorm2d}
        # An unknown mode leaves self.norm unset, so forward raises
        # AttributeError — same behaviour as the original if/elif chain.
        if norm_mode in norms:
            self.norm = norms[norm_mode](nch)

    def forward(self, x):
        return self.norm(x)
|
||||
|
||||
|
||||
class ReLU(nn.Module):
    """ReLU when the slope is 0, LeakyReLU(slope) when it is positive; always in-place."""

    def __init__(self, relu):
        super().__init__()
        if relu == 0:
            self.relu = nn.ReLU(True)
        elif relu > 0:
            self.relu = nn.LeakyReLU(relu, True)
        # A negative slope leaves self.relu unset, so forward raises
        # AttributeError — same behaviour as the original.

    def forward(self, x):
        return self.relu(x)
|
||||
|
||||
|
||||
class Padding(nn.Module):
    """2d padding with a selectable mode: reflection, replication, constant, or zeros."""

    def __init__(self, padding, padding_mode='zeros', value=0):
        super().__init__()
        pads = {
            'reflection': lambda: nn.ReflectionPad2d(padding),
            'replication': lambda: nn.ReplicationPad2d(padding),
            'constant': lambda: nn.ConstantPad2d(padding, value),  # `value` only used here
            'zeros': lambda: nn.ZeroPad2d(padding),
        }
        # Unknown mode leaves self.padding unset (AttributeError in forward),
        # matching the original if/elif chain.
        if padding_mode in pads:
            self.padding = pads[padding_mode]()

    def forward(self, x):
        return self.padding(x)
|
||||
|
||||
|
||||
class Pooling2d(nn.Module):
    """Downsampling by average pooling, max pooling, or a strided convolution.

    `nch` is only required for the 'conv' variant. The parameter name `type`
    shadows the builtin but is kept for keyword-caller compatibility.
    """

    def __init__(self, nch=[], pool=2, type='avg'):
        super().__init__()
        poolers = {
            'avg': lambda: nn.AvgPool2d(pool),
            'max': lambda: nn.MaxPool2d(pool),
            'conv': lambda: nn.Conv2d(nch, nch, kernel_size=pool, stride=pool),
        }
        # Lazy factories: only the selected pooler is constructed, so an empty
        # nch is harmless unless type == 'conv' (as in the original).
        if type in poolers:
            self.pooling = poolers[type]()

    def forward(self, x):
        return self.pooling(x)
|
||||
|
||||
|
||||
class UnPooling2d(nn.Module):
    """Upsampling by nearest/bilinear interpolation or a transposed convolution.

    Parameters:
        nch  -- channel count, only needed for the 'conv' variant
        pool -- upscaling factor
        type -- 'nearest' | 'bilinear' | 'conv' (name kept for caller compatibility)
    """

    def __init__(self, nch=[], pool=2, type='nearest'):
        super().__init__()

        if type == 'nearest':
            # BUGFIX: align_corners=True was passed here, but align_corners is
            # only valid for interpolating modes (bilinear/bicubic/...); with
            # mode='nearest' F.interpolate raises ValueError at forward time,
            # breaking the default configuration of this module.
            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')
        elif type == 'bilinear':
            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
        elif type == 'conv':
            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)

    def forward(self, x):
        return self.unpooling(x)
|
||||
|
||||
|
||||
class Concat(nn.Module):
    """Pad x1 to x2's spatial size, then concatenate [x2, x1] along channels
    (U-Net style skip-connection merge)."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        dh = x2.size(2) - x1.size(2)
        dw = x2.size(3) - x1.size(3)

        # Split any odd difference between the two sides (left/right, top/bottom).
        x1 = F.pad(x1, [dw // 2, dw - dw // 2,
                        dh // 2, dh - dh // 2])

        return torch.cat([x2, x1], dim=1)
|
||||
|
||||
|
||||
class TV1dLoss(nn.Module):
    """1d total-variation loss: mean absolute difference of neighbouring
    elements along dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.mean(torch.abs(input[:, :-1] - input[:, 1:]))
|
||||
|
||||
|
||||
class TV2dLoss(nn.Module):
    """2d total-variation loss: mean absolute horizontal difference plus mean
    absolute vertical difference over an NCHW tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        horizontal = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:]))
        vertical = torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
        return horizontal + vertical
|
||||
|
||||
|
||||
class SSIM2dLoss(nn.Module):
    """Placeholder for an SSIM loss.

    NOTE(review): this is an unimplemented stub — forward ignores both inputs
    and always returns 0. The parameter name 'targer' looks like a typo for
    'target'; it is kept as-is because renaming would change the keyword
    interface for callers.
    """

    def __init__(self):
        super(SSIM2dLoss, self).__init__()

    def forward(self, input, targer):
        loss = 0
        return loss
|
||||
|
||||
734
app/service/image2sketch/models/networks.py
Normal file
734
app/service/image2sketch/models/networks.py
Normal file
@@ -0,0 +1,734 @@
|
||||
import functools
|
||||
|
||||
from torch.nn import init
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
from .layer import *
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Helper Functions
|
||||
###############################################################################
|
||||
|
||||
|
||||
class Identity(nn.Module):
    """No-op module: forward returns its input unchanged (used as the
    stand-in layer when norm_type == 'none' in get_norm_layer)."""

    def forward(self, x):
        return x
|
||||
|
||||
|
||||
def get_norm_layer(norm_type='instance'):
    """Return a normalization layer factory.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters and do not track running statistics.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        # The factory ignores its channel argument and yields a pass-through.
        def norm_layer(x):
            return Identity()
        return norm_layer
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
||||
|
||||
|
||||
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler.

    Parameters:
        optimizer        -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError -- if opt.lr_policy is not one of the policies above.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    else:
        # BUGFIX: the exception was previously *returned* rather than raised,
        # so an unknown policy silently handed callers an exception instance
        # in place of a scheduler; the message was also never %-formatted.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
|
||||
|
||||
|
||||
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    'normal' was used in the original pix2pix and CycleGAN paper, but xavier
    and kaiming might work better for some applications.
    """

    def init_func(m):  # applied to every submodule by net.apply below
        classname = m.__class__.__name__
        is_conv_or_linear = hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1)
        if is_conv_or_linear:
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
|
||||
|
||||
|
||||
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support);
    2. initialize the network weights.

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns:
        the initialized network.
    """
    if gpu_ids:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        # Wrap for multi-GPU execution; the wrapper (not the raw net) is returned.
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type, init_gain=init_gain)
    return net
|
||||
|
||||
|
||||
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create, initialize, and return a generator selected by the `netG` name.

    Parameters:
        input_nc / output_nc -- input/output channel counts passed to the generator
        ngf                  -- base generator filter count
        netG (str)           -- 'ref_unpair_cbam_cat' | 'ref_unpair_recon' | 'triplet'
        norm (str)           -- normalization name; see NOTE below
        use_dropout          -- NOTE(review): accepted but never used in this function
        init_type/init_gain/gpu_ids -- forwarded to init_net

    Raises:
        NotImplementedError -- for an unrecognized netG name.
    """
    net = None
    # NOTE(review): norm_layer is computed but never passed to the generators
    # below (they hard-code norm='inorm'); the call still validates `norm`
    # (get_norm_layer raises on unknown names), so it is kept.
    norm_layer = get_norm_layer(norm_type=norm)

    # ref_unpair / triplet are defined elsewhere in this package — TODO confirm import.
    if netG == 'ref_unpair_cbam_cat':
        net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_cbam_cat')
    elif netG == 'ref_unpair_recon':
        net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_recon')
    elif netG == 'triplet':
        net = triplet(input_nc, output_nc, ngf, norm='inorm')

    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
|
||||
|
||||
|
||||
class AdaIN(nn.Module):
    """Adaptive instance normalization: re-style x with y's per-channel statistics."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        eps = 1e-5

        # Per-sample, per-channel statistics over the spatial dimensions,
        # broadcast back to (N, C, 1, 1). eps keeps the division finite for
        # spatially-constant channels.
        mu_x = torch.mean(x, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1)
        mu_y = torch.mean(y, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1)
        sd_x = torch.std(x, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1) + eps
        sd_y = torch.std(y, dim=[2, 3]).unsqueeze(-1).unsqueeze(-1) + eps

        # Whiten x, then re-colour with y's scale and shift.
        return (x - mu_x) / sd_x * sd_y + mu_y
|
||||
|
||||
|
||||
class HED(nn.Module):
|
||||
def __init__(self):
|
||||
super(HED, self).__init__()
|
||||
|
||||
self.moduleVggOne = nn.Sequential(
|
||||
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False)
|
||||
)
|
||||
|
||||
self.moduleVggTwo = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False)
|
||||
)
|
||||
|
||||
self.moduleVggThr = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False)
|
||||
)
|
||||
|
||||
self.moduleVggFou = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False)
|
||||
)
|
||||
|
||||
self.moduleVggFiv = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False),
|
||||
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(inplace=False)
|
||||
)
|
||||
|
||||
self.moduleScoreOne = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
|
||||
self.moduleScoreTwo = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
|
||||
self.moduleScoreThr = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
|
||||
self.moduleScoreFou = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
|
||||
self.moduleScoreFiv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.moduleCombine = nn.Sequential(
|
||||
nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, tensorInput):
    """Run the HED edge detector on a batch of RGB images in [0, 1].

    The input is rescaled to [0, 255], reordered to BGR, shifted by the
    Caffe VGG channel means, passed through five VGG stages; each stage's
    side output is reduced to one channel, upsampled to the input
    resolution, and fused by a learned 1x1 combination.
    """
    # Caffe-style preprocessing: [0,1] RGB -> mean-subtracted BGR in [0,255].
    scaled = tensorInput * 255.0
    preprocessed = torch.cat([
        scaled[:, 2:3, :, :] - 104.00698793,
        scaled[:, 1:2, :, :] - 116.66876762,
        scaled[:, 0:1, :, :] - 122.67891434,
    ], 1)

    # Five VGG feature stages, each consuming the previous stage's output.
    stage_feats = []
    feat = preprocessed
    for stage in (self.moduleVggOne, self.moduleVggTwo, self.moduleVggThr,
                  self.moduleVggFou, self.moduleVggFiv):
        feat = stage(feat)
        stage_feats.append(feat)

    # One-channel side output per stage, upsampled back to the input size.
    target_size = (preprocessed.size(2), preprocessed.size(3))
    scorers = (self.moduleScoreOne, self.moduleScoreTwo, self.moduleScoreThr,
               self.moduleScoreFou, self.moduleScoreFiv)
    side_outputs = [
        nn.functional.interpolate(input=scorer(f), size=target_size,
                                  mode='bilinear', align_corners=False)
        for scorer, f in zip(scorers, stage_feats)
    ]

    # Learned fusion of the five side outputs followed by a sigmoid.
    return self.moduleCombine(torch.cat(side_outputs, 1))
|
||||
|
||||
|
||||
def define_HED(init_weights_, gpu_ids_=[]):
    """Build an HED edge-detection network and optionally load weights.

    Parameters:
        init_weights_ -- path to a saved state dict, or None to skip loading
        gpu_ids_ -- list of GPU ids; when non-empty the network is moved to
                    gpu_ids_[0] and wrapped in DataParallel

    Returns:
        The (possibly DataParallel-wrapped) HED network.
    """
    net = HED()

    if len(gpu_ids_) > 0:
        assert torch.cuda.is_available()
        net.to(gpu_ids_[0])
        net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs

    if init_weights_ is not None:  # identity check instead of `== None`
        device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
        print('Loading model from: %s' % init_weights_)
        state_dict = torch.load(init_weights_, map_location=str(device))
        # DataParallel keeps the wrapped model under .module
        target = net.module if isinstance(net, torch.nn.DataParallel) else net
        target.load_state_dict(state_dict)
        print('load the weights successfully')

    return net
|
||||
|
||||
|
||||
def define_styletps(init_weights_, gpu_ids_=[], shape=False):
    """Build the triplet style-embedding network and optionally load weights.

    Parameters:
        init_weights_ -- path to a saved state dict, or None to skip loading
        gpu_ids_ -- list of GPU ids; when non-empty the network is moved to
                    gpu_ids_[0] and wrapped in DataParallel
        shape -- when True no network is constructed and None is returned
                 (the shape variant is not implemented here)

    Returns:
        The (possibly DataParallel-wrapped) triplet network, or None when
        shape is True.
    """
    net = None
    if not shape:
        net = triplet()
        if len(gpu_ids_) > 0:
            assert torch.cuda.is_available()
            net.to(gpu_ids_[0])
            net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs

        if init_weights_ is not None:  # identity check instead of `== None`
            device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
            print('Loading model from: %s' % init_weights_)
            state_dict = torch.load(init_weights_, map_location=str(device))
            # DataParallel keeps the wrapped model under .module
            target = net.module if isinstance(net, torch.nn.DataParallel) else net
            target.load_state_dict(state_dict)
            print('load the weights successfully')

    return net
|
||||
|
||||
|
||||
class triplet(nn.Module):
    """Shared-weight embedding network for triplet training.

    Each of the three inputs (anchor / positive / negative) is passed
    through the same convolutional stack, globally pooled, and projected
    to a 128-dimensional embedding.
    """

    def __init__(self):  # mnblk=4
        super(triplet, self).__init__()

        # self.channels = nch_in
        self.nch_in = 1
        self.nch_out = 1
        self.nch_ker = 64
        self.norm = 'bnorm'
        # self.nblk = nblk

        # Batch norm carries its own affine shift, so convs preceding it
        # do not need a bias term.
        self.bias = self.norm != 'bnorm'

        self.conv0 = CNR2d(self.nch_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
        self.conv1 = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.conv2 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)

        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(256, 128)

    def _embed(self, t):
        """Map one image tensor to its 128-d embedding (weights shared)."""
        t = self.conv2(self.conv1(self.conv0(t)))
        t = self.final_pool(t)
        t = torch.flatten(t, 1)
        return self.linear(t)

    def forward(self, x, y, z):
        # All three branches run through the identical embedding pipeline.
        return self._embed(x), self._embed(y), self._embed(z)
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Multi-layer perceptron built from LinearBlock layers.

    n_blk layers total: input_dim -> dim, then (n_blk - 2) hidden
    dim -> dim layers, then dim -> output_dim with no final activation.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        layers = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        layers.extend(
            LinearBlock(dim, dim, norm=norm, activation=activ)
            for _ in range(n_blk - 2)
        )
        # no output activations
        layers.append(LinearBlock(dim, output_dim, norm='none', activation='none'))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten everything but the batch dimension before the MLP.
        return self.model(x.view(x.size(0), -1))
|
||||
|
||||
|
||||
class ref_unpair(nn.Module):
    """Reference-guided sketch generator.

    Encodes a content image and a style reference with separate encoders,
    aligns content statistics to the style via AdaIN and, in
    'ref_unpair_cbam_cat' mode, fuses a CBAM-gated variant through
    concatenating residual blocks before decoding to a tanh-bounded output.
    """

    def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm', nblk=4, status='ref_unpair'):
        super(ref_unpair, self).__init__()

        # NOTE(review): the nch_ker argument is immediately overwritten.
        nch_ker = 64
        # self.channels = nch_in
        self.nch_in = nch_in
        self.nchs_in = 1  # style branch always takes a 1-channel input
        self.status = status

        # Output channels are derived from status; the nch_out constructor
        # argument is ignored (NOTE(review): confirm this is intended).
        if self.status == 'ref_unpair_recon':
            self.nch_out = 3
            self.nch_in = 1
        else:
            self.nch_out = 1

        self.nch_ker = nch_ker
        self.norm = norm
        self.nblk = nblk
        self.dec0 = []

        if status == 'ref_unpair_cbam_cat':
            # channel- and spatial-attention gates used by the fusion path
            self.cbam_c = CBAM(nch_ker * 8, 16, 3, cbam_status="channel")
            self.cbam_s = CBAM(nch_ker * 8, 16, 3, cbam_status="spatial")

        # Style encoder: 1 -> 8*nch_ker channels, three stride-2 downsamplings.
        self.enc1_s = CNR2d(self.nchs_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
        self.enc2_s = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.enc3_s = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.enc4_s = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)

        # Batch norm has its own affine shift, so conv bias is redundant.
        if norm == 'bnorm':
            self.bias = False
        else:
            self.bias = True

        # Content encoder: mirrors the style encoder's layout.
        self.enc1_c = CNR2d(self.nch_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
        self.enc2_c = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.enc3_c = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
        self.enc4_c = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)

        if status == 'ref_unpair_cbam_cat':
            # Residual blocks that take the CBAM-fused features as a second input.
            self.res_cat1 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
            self.res_cat2 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
            self.res_cat3 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
            self.res_cat4 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')

        # Plain residual stack for the non-fusion modes.
        if self.nblk and status != 'ref_unpair_cbam_cat':
            res = []
            for i in range(self.nblk):
                res += [ResBlock(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')]
            self.res1 = nn.Sequential(*res)

        # Decoder: three stride-2 upsamplings back to input resolution.
        # self.dec0 += [DECNR2d(16 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
        self.dec0 += [DECNR2d(8 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
        self.dec0 += [DECNR2d(4 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
        self.dec0 += [DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
        self.dec0 += [DECNR2d(1 * self.nch_ker, 1 * self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)]
        self.dec0 += [nn.Conv2d(1 * self.nch_ker, self.nch_out, kernel_size=3, stride=1, padding=1)]

        self.dec = nn.Sequential(*self.dec0)

    def forward(self, content, style):
        """Generate a sketch from a content image and a style reference.

        Returns a tanh-bounded tensor with self.nch_out channels.
        """
        # Encode the content image down to 8*nch_ker channels.
        content_cs = self.enc1_c(content)
        content_cs = self.enc2_c(content_cs)
        content_cs = self.enc3_c(content_cs)
        content_cs = self.enc4_c(content_cs)
        # content_cs = self.enc5_c(content_cs)

        if self.status == 'ref_unpair_cbam_cat':
            # Spatial-attention-gated content features (residual add).
            cbam_content_cs = self.cbam_s(content_cs)
            sp_content_cs = content_cs + cbam_content_cs

            # Encode the style reference with the style encoder.
            style_cs = self.enc1_s(style)
            style_cs = self.enc2_s(style_cs)
            style_cs = self.enc3_s(style_cs)
            style_cs = self.enc4_s(style_cs)

            # Channel-attention-gated style features (residual add).
            cbam_style_cs = self.cbam_c(style_cs)
            ch_style_cs = style_cs + cbam_style_cs

            # AdaIN on both the raw and the attention-gated feature pairs.
            content_output = self.adaptive_instance_normalization(content_cs, style_cs)
            cbam_content_output = self.adaptive_instance_normalization(sp_content_cs, ch_style_cs)

            # Fuse the two streams through four concatenating residual blocks.
            content_output = self.res_cat1(content_output, cbam_content_output)
            content_output = self.res_cat2(content_output, cbam_content_output)
            content_output = self.res_cat3(content_output, cbam_content_output)
            content_output = self.res_cat4(content_output, cbam_content_output)

        else:
            content_output = content_cs

        if self.nblk and self.status != 'ref_unpair_cbam_cat':
            # NOTE(review): the residual stack's output is assigned to
            # content_cs and never used again — self.dec below consumes
            # content_output, so self.res1 is effectively dead here.
            # Confirm whether this should read
            # `content_output = self.res1(content_output)`.
            content_cs = self.res1(content_output)

        content_output = self.dec(content_output)

        content_output = torch.tanh(content_output)

        return content_output

    def calc_mean_std(self, feat, eps=1e-5):
        """Per-sample, per-channel mean and std of a (N, C, H, W) feature map."""
        # eps is a small value added to the variance to avoid divide-by-zero.
        size = feat.size()
        assert (len(size) == 4)
        N, C = size[:2]
        feat_var = feat.view(N, C, -1).var(dim=2) + eps
        feat_std = feat_var.sqrt().view(N, C, 1, 1)
        feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
        return feat_mean, feat_std

    def adaptive_instance_normalization(self, content_feat, style_feat):
        """AdaIN: re-normalize content statistics to match the style's."""
        assert (content_feat.size()[:2] == style_feat.size()[:2])
        size = content_feat.size()
        style_mean, style_std = self.calc_mean_std(style_feat)
        content_mean, content_std = self.calc_mean_std(content_feat)

        # Whiten content features, then re-color with style mean/std.
        normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
        return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
|
||||
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create and initialize a discriminator.

    Parameters:
        input_nc -- number of channels in input images
        ndf -- number of filters in the first conv layer
        netD -- architecture name: 'basic' | 'n_layers' | 'pixel'
        n_layers_D -- conv layer count, used only when netD == 'n_layers'
        norm -- normalization layer name
        init_type / init_gain -- weight-initialization scheme and scale
        gpu_ids -- GPUs to place the network on

    Returns the initialized discriminator.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':
        # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':
        # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':
        # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Classes
|
||||
##############################################################################
|
||||
class GANLoss(nn.Module):
    """Define different GAN objectives.

    Abstracts away creating a target-label tensor with the same size as
    the discriminator's prediction. Supports 'lsgan', 'vanilla', and
    'wgangp'. Note: do not put a sigmoid at the end of the discriminator —
    LSGAN needs none and vanilla uses BCEWithLogitsLoss.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- GAN objective: 'vanilla', 'lsgan', or 'wgangp'
            target_real_label (float) -- label value for real images
            target_fake_label (float) -- label value for fake images
        """
        super(GANLoss, self).__init__()
        # Buffers move with .to(device) but are never trained.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real/fake label broadcast to the prediction's shape."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Compute the GAN loss for a discriminator prediction."""
        if self.gan_mode in ['lsgan', 'vanilla']:
            target = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target)
        elif self.gan_mode == 'wgangp':
            # Wasserstein critic: maximize score on real, minimize on fake.
            loss = -prediction.mean() if target_is_real else prediction.mean()
        return loss
|
||||
|
||||
|
||||
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the WGAN-GP gradient penalty.

    Parameters:
        netD -- the discriminator (critic) network
        real_data / fake_data -- real and generated batches
        device -- torch device for intermediate tensors
        type -- 'real', 'fake', or 'mixed' (random interpolation of the two)
        constant -- target gradient norm (usually 1.0)
        lambda_gp -- penalty weight; <= 0 disables the penalty

    Returns:
        (penalty, flattened per-sample gradients), or (0.0, None) when
        lambda_gp <= 0.
    """
    if not lambda_gp > 0.0:
        return 0.0, None

    # Pick the points at which the gradient norm is penalized.
    if type == 'real':
        samples = real_data
    elif type == 'fake':
        samples = fake_data
    elif type == 'mixed':
        # One mixing coefficient per sample, broadcast over all elements.
        alpha = torch.rand(real_data.shape[0], 1, device=device)
        alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
        samples = alpha * real_data + ((1 - alpha) * fake_data)
    else:
        raise NotImplementedError('{} not implemented'.format(type))

    samples.requires_grad_(True)
    critic_scores = netD(samples)
    grad_tuple = torch.autograd.grad(
        outputs=critic_scores,
        inputs=samples,
        grad_outputs=torch.ones(critic_scores.size()).to(device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    gradients = grad_tuple[0].view(real_data.size(0), -1)  # flat the data
    # Small eps keeps norm() differentiable at zero gradients.
    gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return gradient_penalty, gradients
|
||||
|
||||
|
||||
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        # BatchNorm2d has affine parameters, so conv bias is only needed
        # with InstanceNorm.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kernel, pad = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Gradually increase the filter count, halving resolution each step.
        mult = 1
        for n in range(1, n_layers):
            prev, mult = mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * prev, ndf * mult, kernel_size=kernel, stride=2, padding=pad, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more stride-1 conv at the widest width.
        prev, mult = mult, min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * prev, ndf * mult, kernel_size=kernel, stride=1, padding=pad, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ]

        # output 1 channel prediction map
        layers += [nn.Conv2d(ndf * mult, 1, kernel_size=kernel, stride=1, padding=pad)]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
|
||||
|
||||
|
||||
class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters in the last conv layer
            norm_layer     -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        # BatchNorm2d has affine parameters, so conv bias is only needed
        # with InstanceNorm.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # All convolutions are 1x1: every pixel is classified independently.
        self.net = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        )

    def forward(self, input):
        """Standard forward."""
        return self.net(input)
|
||||
|
||||
|
||||
class CBAM(nn.Module):
    """Convolutional Block Attention Module wrapper.

    Depending on cbam_status, applies channel attention, spatial attention,
    or both ('cbam') to the input feature map and returns the gated result.
    """

    def __init__(self, n_channels_in, reduction_ratio, kernel_size, cbam_status):
        """
        Parameters:
            n_channels_in   -- number of input channels
            reduction_ratio -- channel-attention bottleneck reduction factor
            kernel_size     -- spatial-attention conv kernel size (must be odd)
            cbam_status     -- 'cbam', 'spatial', or 'channel'
        """
        super(CBAM, self).__init__()
        self.n_channels_in = n_channels_in
        self.reduction_ratio = reduction_ratio
        self.kernel_size = kernel_size
        self.channel_attention = ChannelAttention_nopara(n_channels_in, reduction_ratio)
        self.spatial_attention = SpatialAttention_nopara(kernel_size)
        self.status = cbam_status

    def forward(self, x):
        ## We don't use cbam in this version
        if self.status == "cbam":
            # channel gate first, then spatial gate on the gated features
            chan_att = self.channel_attention(x)
            fp = chan_att * x
            spat_att = self.spatial_attention(fp)
            fpp = spat_att * fp
        elif self.status == "spatial":
            spat_att = self.spatial_attention(x)  # * s_para_1d
            fpp = spat_att * x
        elif self.status == "channel":
            chan_att = self.channel_attention(x)  # * c_para_1d
            fpp = chan_att * x
        else:
            # Previously an unrecognized status crashed with a cryptic
            # UnboundLocalError on fpp; fail fast with a clear message.
            raise ValueError('unknown cbam_status: %s' % self.status)

        return fpp  # ,c_wgt,s_wgt
|
||||
|
||||
|
||||
class SpatialAttention_nopara(nn.Module):
    """Spatial attention: a sigmoid gate over locations.

    Max- and average-pools the input across the channel axis, stacks the
    two maps, runs one conv, and broadcasts the sigmoid result back to the
    input's channel count.
    """

    def __init__(self, kernel_size):
        super(SpatialAttention_nopara, self).__init__()
        self.kernel_size = kernel_size
        assert kernel_size % 2 == 1, "Odd kernel size required"
        # 'same' padding for the odd kernel keeps spatial size unchanged.
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=int((kernel_size - 1) / 2))

    def forward(self, x):
        # Channel-wise max and mean maps, concatenated into a 2-channel input.
        pooled = torch.cat([self.agg_channel(x, "max"), self.agg_channel(x, "avg")], dim=1)
        attention = self.conv(pooled)
        # Broadcast the single-channel map across all input channels.
        attention = attention.repeat(1, x.size()[1], 1, 1)
        return torch.sigmoid(attention)

    def agg_channel(self, x, pool="max"):
        """Pool x over its channel dimension; returns shape (b, 1, h, w)."""
        b, c, h, w = x.size()
        # Move channels last so 1-d pooling reduces over them.
        flat = x.view(b, c, h * w).permute(0, 2, 1)
        if pool == "max":
            flat = F.max_pool1d(flat, c)
        elif pool == "avg":
            flat = F.avg_pool1d(flat, c)
        return flat.permute(0, 2, 1).view(b, 1, h, w)
|
||||
|
||||
|
||||
class ChannelAttention_nopara(nn.Module):
    """Channel attention: per-channel sigmoid gates from global pooling.

    Global average- and max-pooled channel descriptors are pushed through a
    shared bottleneck MLP, summed, and squashed to (0, 1); the result is a
    (b, c, 1, 1) gate to be multiplied with the input.
    """

    def __init__(self, n_channels_in, reduction_ratio):
        super(ChannelAttention_nopara, self).__init__()
        self.n_channels_in = n_channels_in
        self.reduction_ratio = reduction_ratio
        self.middle_layer_size = int(self.n_channels_in / float(self.reduction_ratio))
        # Shared bottleneck MLP applied to both pooled descriptors.
        self.bottleneck = nn.Sequential(
            nn.Linear(self.n_channels_in, self.middle_layer_size),
            nn.ReLU(),
            nn.Linear(self.middle_layer_size, self.n_channels_in),
        )

    def forward(self, x):
        # Collapse each channel to one value via global avg / max pooling.
        spatial = (x.size()[2], x.size()[3])
        avg_vec = F.avg_pool2d(x, spatial).view(x.size()[0], -1)
        max_vec = F.max_pool2d(x, spatial).view(x.size()[0], -1)
        gate = torch.sigmoid(self.bottleneck(avg_vec) + self.bottleneck(max_vec))
        # out = sig_pool.repeat(1,1,kernel[0], kernel[1])
        return gate.unsqueeze(2).unsqueeze(3)
|
||||
86
app/service/image2sketch/models/perceptual.py
Normal file
86
app/service/image2sketch/models/perceptual.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual (feature-reconstruction) loss based on VGG16 features.

    Compares input and target activations from four VGG16 slices with L1
    loss; optionally also compares Gram matrices for the layers listed in
    style_layers.
    """

    def __init__(self, resize=True):
        """
        Parameters:
            resize -- when True, bilinearly resize inputs to 224x224 (the
                      resolution VGG16 was trained at) before comparison.
        """
        super(VGGPerceptualLoss, self).__init__()
        # Build VGG16 once and slice its feature extractor (the original
        # instantiated the full pretrained network four times).
        vgg_features = torchvision.models.vgg16(pretrained=True).features
        blocks = [
            vgg_features[:4].eval(),
            vgg_features[4:9].eval(),
            vgg_features[9:16].eval(),
            vgg_features[16:23].eval(),
        ]
        # Freeze the VGG weights. The original iterated `for p in bl`,
        # which yields sub-MODULES, so setting .requires_grad on them had
        # no effect and the backbone stayed trainable.
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # Buffers, not nn.Parameter: the ImageNet statistics are constants
        # and must never be updated by an optimizer. state_dict keys are
        # unchanged, so existing checkpoints still load.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.resize = resize

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        """Return the summed L1 feature loss (plus optional Gram/style loss)."""
        # Grayscale inputs are tiled to three channels for VGG.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrices capture channel co-activation statistics.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
|
||||
|
||||
class VGGstyleLoss(torch.nn.Module):
    """VGG16-based style/perceptual loss (duplicate of VGGPerceptualLoss).

    Compares input and target activations from four VGG16 slices with L1
    loss; optionally also compares Gram matrices for the layers listed in
    style_layers.
    """

    def __init__(self, resize=True):
        """
        Parameters:
            resize -- when True, bilinearly resize inputs to 224x224 (the
                      resolution VGG16 was trained at) before comparison.
        """
        super(VGGstyleLoss, self).__init__()
        # Build VGG16 once and slice its feature extractor (the original
        # instantiated the full pretrained network four times).
        vgg_features = torchvision.models.vgg16(pretrained=True).features
        blocks = [
            vgg_features[:4].eval(),
            vgg_features[4:9].eval(),
            vgg_features[9:16].eval(),
            vgg_features[16:23].eval(),
        ]
        # Freeze the VGG weights. The original iterated `for p in bl`,
        # which yields sub-MODULES, so setting .requires_grad on them had
        # no effect and the backbone stayed trainable.
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        # Buffers, not nn.Parameter: the ImageNet statistics are constants
        # and must never be updated by an optimizer. state_dict keys are
        # unchanged, so existing checkpoints still load.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.resize = resize

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        """Return the summed L1 feature loss (plus optional Gram/style loss)."""
        # Grayscale inputs are tiled to three channels for VGG.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrices capture channel co-activation statistics.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss
|
||||
82
app/service/image2sketch/models/template_model.py
Normal file
82
app/service/image2sketch/models/template_model.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
|
||||
|
||||
class TemplateModel(BaseModel):
    """Template showing how to implement a custom model in this framework.

    Trains a generator netG to map data_A to data_B with an L1 regression
    loss weighted by --lambda_regression.

    NOTE(review): whether BaseModel.get_current_losses prepends 'loss_' to
    each entry of self.loss_names determines if 'loss_G' below should read
    just 'G' — confirm against base_model.py.
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new model-specific options and rewrite default values for existing options.

        Parameters:
            parser -- the option parser
            is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.set_defaults(dataset_mode='aligned')  # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
        if is_train:
            parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss')  # You can define new arguments for this model.

        return parser

    def __init__(self, opt):
        """Initialize this model class.

        Parameters:
            opt -- training/test options

        A few things can be done here.
        - (required) call the initialization function of BaseModel
        - define loss function, visualization images, model names, and optimizers
        """
        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
        # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
        self.loss_names = ['loss_G']
        # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
        self.visual_names = ['data_A', 'data_B', 'output']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
        # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
        self.model_names = ['G']
        # define networks; you can use opt.isTrain to specify different behaviors for training and test.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
        if self.isTrain:  # only defined during training time
            # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
            # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
            self.criterionLoss = torch.nn.L1Loss()
            # define and initialize optimizers. You can define one optimizer for each network.
            # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
            self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer]

        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
        """
        AtoB = self.opt.direction == 'AtoB'  # use <direction> to swap data_A and data_B
        self.data_A = input['A' if AtoB else 'B'].to(self.device)  # get image data A
        self.data_B = input['B' if AtoB else 'A'].to(self.device)  # get image data B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']  # get image paths

    def forward(self):
        """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
        self.output = self.netG(self.data_A)  # generate output image given the input data_A

    def backward(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # caculate the intermediate results if necessary; here self.output has been computed during function <forward>
        # calculate loss given the input and intermediate results
        self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
        self.loss_G.backward()  # calculate gradients of network G w.r.t. loss_G

    def optimize_parameters(self):
        """Update network weights; it will be called in every training iteration."""
        self.forward()  # first call forward to calculate intermediate results
        self.optimizer.zero_grad()  # clear network G's existing gradients
        self.backward()  # calculate gradients for network G
        self.optimizer.step()  # update gradients for network G
|
||||
45
app/service/image2sketch/models/test_model.py
Normal file
45
app/service/image2sketch/models/test_model.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
|
||||
|
||||
class TestModel(BaseModel):
    """Generate CycleGAN results for a single direction at test time.

    Automatically forces '--dataset_mode single' so that images are loaded
    from one collection only. See the test instructions for details.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add test-only options; refuses to run in training mode."""
        assert not is_train, 'TestModel cannot be used during training time'
        parser.set_defaults(dataset_mode='single')
        parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')

        return parser

    def __init__(self, opt):
        """Build the single generator used for one-direction inference.

        Parameters:
            opt: option object; must have isTrain == False.
        """
        assert(not opt.isTrain)
        BaseModel.__init__(self, opt)
        # Nothing to report during test: <BaseModel.get_current_losses> sees no entries.
        self.loss_names = []
        # Images shown/saved by <BaseModel.get_current_visuals>.
        self.visual_names = ['real', 'fake']
        # Only the generator checkpoint is handled by save/load_networks.
        suffix = opt.model_suffix
        self.model_names = ['G' + suffix]
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        # Expose the generator as self.netG_[suffix] so that
        # <BaseModel.load_networks> can locate it by name.
        setattr(self, 'netG' + suffix, self.netG)

    def set_input(self, input):
        """Unpack the one-sided batch: only domain-A images and their paths."""
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Run forward pass: fake = G(real)."""
        self.fake = self.netG(self.real)

    def optimize_parameters(self):
        """No optimization for test model."""
        pass
68
app/service/image2sketch/models/triplet_model.py
Normal file
68
app/service/image2sketch/models/triplet_model.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from util.image_pool import ImagePool
|
||||
|
||||
|
||||
class TripletModel(BaseModel):
    """Train a generator with a triplet (anchor/positive/negative) margin loss.

    Uses the 'triplet' generator architecture and dataset mode; the triplet
    margin loss on the three embeddings is the only loss optimized.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Set triplet-specific defaults and add training-only options."""
        parser.set_defaults(norm='batch', netG='triplet', dataset_mode='triplet')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def __init__(self, opt):
        """Build the embedding network, losses, and optimizer.

        Parameters:
            opt: option object produced by the options parser.
        """
        BaseModel.__init__(self, opt)

        # Losses printed by <BaseModel.get_current_losses>.
        self.loss_names = ['G_triplet']
        # Visuals shown by <BaseModel.get_current_visuals>.
        self.visual_names = ['x', 'y']

        # The generator is the only network saved/loaded, in both train and
        # test phases.  (The original code assigned the identical list in
        # both branches of an isTrain check; collapsed here.)
        self.model_names = ['G']
        self.netG = networks.define_G(1, 1, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)  # buffer of previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # buffer of previously generated images

            # NOTE(review): the image pools, criterionGAN and criterionL1 are
            # never used in this class — presumably leftovers from a GAN
            # variant; confirm before removing.
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            self.triplet = torch.nn.TripletMarginLoss(margin=3.0)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Unpack the triplet batch (A/B swapped per <direction>, plus C) onto the device."""
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.real_C = input['C'].to(self.device)

        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Embed the three inputs; (x, y, z) feed the triplet loss as (anchor, positive, negative)."""
        self.x, self.y, self.z = self.netG(self.real_A, self.real_B, self.real_C)

    def backward_G(self):
        """Compute the triplet loss and backpropagate its gradients."""
        self.loss_G_triplet_1 = self.triplet(self.x, self.y, self.z)
        self.loss_G_triplet = self.loss_G_triplet_1

        self.loss_G = self.loss_G_triplet
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one optimization step for G.

        NOTE(review): <forward> is not called here; the training loop is
        expected to have invoked it beforehand — confirm against train.py.
        """
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
144
app/service/image2sketch/models/unpaired_model.py
Normal file
144
app/service/image2sketch/models/unpaired_model.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import torch
|
||||
|
||||
from . import networks
|
||||
from .base_model import BaseModel
|
||||
from .perceptual import VGGPerceptualLoss
|
||||
from ..util.image_pool import ImagePool
|
||||
|
||||
|
||||
class UnpairedModel(BaseModel):
    """Unpaired image-to-sketch translation model.

    Trains two generators against one discriminator: netG_A produces the
    sketch-style content output from a real photo, and netG_B reconstructs
    the input from that output.  An HED edge network and a contrastive
    "styletps" network provide auxiliary supervision during training.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Set architecture/dataset defaults and add training-only options."""
        parser.set_defaults(norm='batch', netG='ref_unpair_cbam_cat', netG2='ref_unpair_recon', dataset_mode='unaligned')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            # NOTE(review): lambda_L1 is declared but not referenced in this
            # class; the fixed weights below (10, decay_lambda) are used instead.
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def __init__(self, opt):
        """Build generators, discriminator, auxiliary networks, losses and optimizers.

        Parameters:
            opt: option object produced by the options parser.
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'G_L1_1', 'G_Rec', 'G_line', 'D_real', 'D_fake']
        # Images shown/saved by <BaseModel.get_current_visuals>.
        self.visual_names = ['real_A', 'content_output', 'real_B']

        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D']
        else:  # during test time, only load G
            self.model_names = ['G_A', 'G_B']
        # define networks (both generator and discriminator)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG2, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
            # D operates on 1-channel (sketch) images.
            self.netD = networks.define_D(1, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            # Pretrained auxiliary networks loaded from fixed checkpoint paths.
            self.styletps = networks.define_styletps(init_weights_='./checkpoints/contrastive_pretrained.pth', gpu_ids_=self.gpu_ids, shape=False)
            self.HED = networks.define_HED(init_weights_='./checkpoints/network-bsds500.pytorch', gpu_ids_=self.gpu_ids)

        if self.isTrain:  # define discriminators
            # NOTE(review): netD_A and netD_B are constructed but never used,
            # optimized, or listed in model_names — presumably leftovers;
            # confirm before removal (their creation consumes RNG state).
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            # NOTE(review): both pools are unused below (pool_size defaults to 0).
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1_1 = torch.nn.L1Loss()
            # NOTE(review): criterionL1_2/criterionL1_3 and per_loss_1 are not
            # referenced in backward_G — verify whether they can be dropped.
            self.criterionL1_2 = torch.nn.L1Loss()
            self.criterionL1_3 = torch.nn.L1Loss()
            self.per_loss_1 = VGGPerceptualLoss().to(self.device)
            self.per_loss_2 = VGGPerceptualLoss().to(self.device)
            self.per_loss_3 = VGGPerceptualLoss().to(self.device)

            # One Adam optimizer per generator, one for the discriminator.
            self.optimizer_GA = torch.optim.Adam(self.netG_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_GB = torch.optim.Adam(self.netG_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_GA)
            self.optimizers.append(self.optimizer_GB)

            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        # self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        # Sketch content from the photo, conditioned on the reference image.
        self.content_output = self.netG_A(self.real_A, self.real_B)
        # Cycle-style reconstruction of the input from the content output.
        self.rec_output = self.netG_B(self.content_output, self.content_output)

    def update_process(self, epoch, total_epoch):
        """Record training progress; backward_G uses it to decay loss weights."""
        self.epoch_count = epoch
        self.epoch_count_total = total_epoch

    def backward_D(self):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network) -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = self.netD(self.real_B)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # Fake: detach so no gradients flow back into the generator.
        pred_fake = self.netD(self.content_output.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # Adversarial term: D should classify the content output as real.
        pred_fake = self.netD(self.content_output)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Edge maps from the (frozen) HED network for input and reconstruction.
        self.content_output_line = self.HED(self.real_A)
        self.rec_output_line = self.HED(self.rec_output)
        # Contrastive features of output vs. reference style image.
        self.t1, self.t2, _ = self.styletps(self.content_output, self.real_B, self.real_B)

        # Linear decay from 5 down to 0.5 over training; requires
        # <update_process> to have been called first (sets epoch counters).
        decay_lambda = 5 - ((self.epoch_count * 4.5) / self.epoch_count_total)
        self.loss_G_L1_1 = self.criterionL1_1(self.t1, self.t2) * 10
        self.loss_G_Rec = self.per_loss_2(self.real_A, self.rec_output) * decay_lambda
        self.loss_G_line = self.per_loss_3(self.content_output_line, self.rec_output_line) * decay_lambda

        self.loss_G = self.loss_G_GAN + self.loss_G_L1_1 + self.loss_G_Rec + self.loss_G_line
        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training iteration: update D, then both generators."""
        self.forward()  # compute fake images: G(A)
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()  # set D's gradients to zero
        self.backward_D()  # calculate gradients for backward_D_unsuper
        self.optimizer_D.step()  # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_GA.zero_grad()  # set G's gradients to zero
        self.optimizer_GB.zero_grad()  # set G's gradients to zero
        self.backward_G()  # calculate gradients for G
        self.optimizer_GA.step()  # update G's weights
        self.optimizer_GB.step()  # update G's weights
Reference in New Issue
Block a user