diff --git a/.gitignore b/.gitignore
index 87a4934..fe14af2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,3 +136,5 @@ app/logs/*
 *.log
 *.jpg
 /qodana.yaml
+.pth
+.pytorch
\ No newline at end of file
diff --git a/app/api/api_image2sketch.py b/app/api/api_image2sketch.py
new file mode 100644
index 0000000..98d94ee
--- /dev/null
+++ b/app/api/api_image2sketch.py
@@ -0,0 +1,36 @@
+import json
+import logging
+
+from fastapi import APIRouter, HTTPException
+
+from app.schemas.image2sketch import Image2SketchModel
+from app.schemas.response_template import ResponseModel
+from app.service.image2sketch.server import Image2SketchServer
+
+router = APIRouter()
+logger = logging.getLogger()
+
+
+@router.post("/image2sketch")
+def image2sketch(request_item: Image2SketchModel):
+    """
+    Create a request body with the following parameters:
+    - **image_url**: MinIO or S3 URL of the source image
+    - **sketch_bucket**: bucket the generated sketch is written to
+    - **sketch_name**: object name for the generated sketch
+
+    Example payload:
+    {
+        "image_url": "test/real_Top_971fe3085a69f31f3e66c225eabb0eea.jpg_Img.jpg",
+        "sketch_bucket": "test",
+        "sketch_name": "12341556-89.jpg"
+    }
+    """
+    try:
+        logger.info(f"image2sketch request item: {json.dumps(request_item.dict())}")
+        service = Image2SketchServer(request_item)
+        sketch_url = service.get_result()
+    except Exception as e:
+        logger.warning(f"image2sketch run exception: {e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=sketch_url)
diff --git a/app/api/api_route.py b/app/api/api_route.py
index c2bd2d2..8bcbe44 100644
--- a/app/api/api_route.py
+++ b/app/api/api_route.py
@@ -1,14 +1,14 @@
 from fastapi import APIRouter
 
-from app.api import api_test
-from app.api import api_super_resolution
-from app.api import api_generate_image
 from app.api import api_attribute_retrieve
-from app.api import api_design
 from app.api import api_chat_robot
-from app.api import api_prompt_generation
+from app.api import api_design
 from app.api import api_design_pre_processing
-
+from app.api import api_generate_image
+from app.api import api_image2sketch
+from app.api import api_prompt_generation
+from app.api import api_super_resolution
+from app.api import api_test
 
 router = APIRouter()
 
@@ -20,3 +20,4 @@ router.include_router(api_design.router, tags=['design'], prefix="/api")
 router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
 router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
 router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
+router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
diff --git a/app/schemas/image2sketch.py b/app/schemas/image2sketch.py
new file mode 100644
index 0000000..a124739
--- /dev/null
+++ b/app/schemas/image2sketch.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+class Image2SketchModel(BaseModel):
+    image_url: str
+    sketch_bucket: str
+    sketch_name: str
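Note: a minimal sketch of calling the new endpoint once the app is running. The host/port and the shape of the response envelope are assumptions, not part of this diff:

# Hypothetical client call for the /api/image2sketch endpoint (host/port assumed).
import requests

payload = {
    "image_url": "test/real_Top_971fe3085a69f31f3e66c225eabb0eea.jpg_Img.jpg",
    "sketch_bucket": "test",
    "sketch_name": "12341556-89.jpg",
}
resp = requests.post("http://localhost:8000/api/image2sketch", json=payload)
resp.raise_for_status()
print(resp.json())  # ResponseModel envelope; `data` carries the sketch URL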
diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py
index 1b4f33c..4cbff8f 100644
--- a/app/service/design_pre_processing/service.py
+++ b/app/service/design_pre_processing/service.py
@@ -151,10 +151,10 @@ class DesignPreprocessing:
         # run inference to get the keypoints
         sketch['keypoint_result'] = self.keypoint_cache(sketch)
         if sketch['site'] == 'up':
-            _, seg_cache = self.load_seg_result(sketch['image_id'])
+            _, seg_cache = self.load_seg_result(sketch['obj'])
             if not _:
                 # run inference to get the segmentation result
-                seg_result = get_seg_result(sketch["image_id"], sketch['image_obj'])[0]
+                seg_result = get_seg_result(sketch["image_id"], sketch['obj'])[0]
                 self.save_seg_result(seg_result, sketch['image_id'])
 
         if IF_DEBUG_SHOW:
diff --git a/app/service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg b/app/service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg
new file mode 100644
index 0000000..0347322
Binary files /dev/null and b/app/service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg differ
diff --git a/app/service/image2sketch/infer.py b/app/service/image2sketch/infer.py
new file mode 100644
index 0000000..266b37c
--- /dev/null
+++ b/app/service/image2sketch/infer.py
@@ -0,0 +1,89 @@
+import os
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+from .models import create_model
+
+
+def tensor2im(input_image, imtype=np.uint8):
+    if not isinstance(input_image, np.ndarray):
+        if isinstance(input_image, torch.Tensor):  # get the data from a variable
+            image_tensor = input_image.data
+        else:
+            return input_image
+        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
+        if image_numpy.shape[0] == 1:  # grayscale to RGB
+            image_numpy = np.tile(image_numpy, (3, 1, 1))
+        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
+    else:  # if it is a numpy array, do nothing
+        image_numpy = input_image
+    return image_numpy.astype(imtype)
+
+
+def save_image(image_numpy, image_path, w, h):
+    """Save a numpy image to disk.
+
+    Parameters:
+        image_numpy (numpy array) -- input numpy array
+        image_path (str)          -- the path of the image
+        w, h (int)                -- target width and height
+    """
+    image_pil = Image.fromarray(image_numpy)
+    image_pil = image_pil.resize((w, h))
+    image_pil.save(image_path)
+
+
+def save_img(image_tensor, w, h, filename):
+    image_pil = tensor2im(image_tensor)
+    save_image(image_pil, filename, w, h)
+    print("Image saved as {}".format(filename))
+
+
+def load_img(filepath):
+    img = Image.open(filepath).convert('L')
+    # print(img.size)
+    width = img.size[0]
+    height = img.size[1]
+    # img = img.resize((512, 512), Image.BICUBIC)
+    return img, width, height
+
+
+if __name__ == '__main__':
+    # sample inputs (unused below; kept for reference)
+    img_A = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testA/real_Dress_732caedc416a0cbfedd0e6528040eac7.jpg_Img.jpg"
+    img_B = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testC/styleA.png"
+    from opt import Config
+
+    opt = Config()  # get test options
+    # hard-code some parameters for test
+    opt.num_threads = 0  # test code only supports num_threads = 0
+    opt.batch_size = 1  # test code only supports batch_size = 1
+    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
+    opt.no_flip = True  # no flip; comment this line if results on flipped images are needed.
+    opt.display_id = -1  # no visdom display; the test code saves the results to an HTML file.
+    device = torch.device("cuda:0")
+    model = create_model(opt)  # create a model given opt.model and other options
+    model.setup(opt)
+    transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
+    transform = transforms.Compose(transform_list)
+    if opt.eval:
+        model.eval()
+    data = {}
+    print(os.getcwd())
+    reference, _, _ = load_img(r"E:\workspace\trinity_client_aida\app\service\image2sketch\datasets\ref_unpair\testC\styleA.png")
+    style_img = transform(reference)
+    data['B'] = style_img
+    data['B'] = data['B'].unsqueeze(0).to(device)
+    A = Image.open(r"E:\workspace\trinity_client_aida\app\service\image2sketch\datasets\ref_unpair\testA\real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg")
+    width = A.size[0]
+    height = A.size[1]
+    # data['A'] = A.resize((512, 512))
+    data['A'] = transform(A)
+    data['A'] = data['A'].unsqueeze(0).to(device)
+    model.set_input(data)
+    model.test()  # run inference
+    visuals = model.get_current_visuals()  # get image results
+    save_img(visuals['content_output'].cpu(), width, height, "result/result.jpg")
diff --git a/app/service/image2sketch/models/__init__.py b/app/service/image2sketch/models/__init__.py
new file mode 100644
index 0000000..809105c
--- /dev/null
+++ b/app/service/image2sketch/models/__init__.py
@@ -0,0 +1,49 @@
+from app.service.image2sketch.models import unpaired_model as modellib
+from .base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+    """Look up the class named "[model_name]model" in unpaired_model.py.
+
+    The matching class has to be a subclass of BaseModel,
+    and the lookup is case-insensitive.
+    """
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        if name.lower() == target_model_name.lower() \
+                and issubclass(cls, BaseModel):
+            model = cls
+
+    if model is None:
+        print("In unpaired_model.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % target_model_name)
+        exit(0)
+
+    return model
+
+
+def get_option_setter(model_name):
+    """Return the static method <modify_commandline_options> of the model class."""
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    """Create a model given the option.
+
+    This is the main interface between this package and 'train.py'/'test.py'.
+
+    Example:
+        >>> from .models import create_model
+        >>> model = create_model(opt)
+    """
+    model = find_model_using_name(opt.model)
+    instance = model(opt)
+    print("model [%s] was created" % type(instance).__name__)
+    return instance
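A minimal sketch of how create_model is typically driven. The option fields below are inferred from how infer.py and BaseModel read `opt`; the concrete model class in unpaired_model.py will expect additional fields, so treat this as illustrative only:

# Hypothetical minimal options object (field names inferred from usage, values made up).
class Opt:
    model = 'unpaired'           # matched case-insensitively against classes in unpaired_model.py
    gpu_ids = []                 # empty -> CPU
    isTrain = False
    continue_train = False
    checkpoints_dir = './checkpoints'
    name = 'ref2sketch'
    preprocess = 'resize'
    epoch = 'latest'
    load_iter = 0
    verbose = False

from app.service.image2sketch.models import create_model

model = create_model(Opt())     # resolves the class by name and instantiates it
model.setup(Opt())              # in test mode this loads 'latest_net_<name>.pth' and prints the nets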
diff --git a/app/service/image2sketch/models/base_model.py b/app/service/image2sketch/models/base_model.py
new file mode 100644
index 0000000..6de961b
--- /dev/null
+++ b/app/service/image2sketch/models/base_model.py
@@ -0,0 +1,230 @@
+import os
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+    """This class is an abstract base class (ABC) for models.
+    To create a subclass, you need to implement the following five functions:
+        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
+        -- <set_input>:                     unpack data from dataset and apply preprocessing.
+        -- <forward>:                       produce intermediate results.
+        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
+        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
+    """
+
+    def __init__(self, opt):
+        """Initialize the BaseModel class.
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+        When creating your custom class, you need to implement your own initialization.
+        In this function, you should first call <BaseModel.__init__(self, opt)>.
+        Then, you need to define four lists:
+            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
+            -- self.model_names (str list):         define networks used in our training.
+            -- self.visual_names (str list):        specify the images that you want to display and save.
+            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+        """
+        self.opt = opt
+        self.gpu_ids = opt.gpu_ids
+        self.isTrain = opt.isTrain
+        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
+        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
+        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+            torch.backends.cudnn.benchmark = True
+        self.loss_names = []
+        self.model_names = []
+        self.visual_names = []
+        self.optimizers = []
+        self.image_paths = []
+        self.metric = 0  # used for learning rate policy 'plateau'
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new model-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input (dict): includes the data itself and its metadata information.
+        """
+        pass
+
+    @abstractmethod
+    def forward(self):
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        pass
+
+    @abstractmethod
+    def optimize_parameters(self):
+        """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        pass
+
+    def setup(self, opt):
+        """Load and print networks; create schedulers
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        if self.isTrain:
+            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+        if not self.isTrain or opt.continue_train:
+            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
+            self.load_networks(load_suffix)
+        self.print_networks(opt.verbose)
+
+    def eval(self):
+        """Make models eval mode during test time"""
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                net.eval()
+    def test(self):
+        """Forward function used in test time.
+
+        This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
+        It also calls <compute_visuals> to produce additional visualization results.
+        """
+        with torch.no_grad():
+            self.forward()
+            self.compute_visuals()
+
+    def compute_visuals(self):
+        """Calculate additional output images for visdom and HTML visualization"""
+        pass
+
+    def get_image_paths(self):
+        """Return image paths that are used to load current data"""
+        return self.image_paths
+
+    def update_learning_rate(self):
+        """Update learning rates for all the networks; called at the end of every epoch"""
+        old_lr = self.optimizers[0].param_groups[0]['lr']
+        for scheduler in self.schedulers:
+            if self.opt.lr_policy == 'plateau':
+                scheduler.step(self.metric)
+            else:
+                scheduler.step()
+
+        lr = self.optimizers[0].param_groups[0]['lr']
+        print('learning rate %.7f -> %.7f' % (old_lr, lr))
+
+    def get_current_visuals(self):
+        """Return visualization images. train.py will display these images with visdom, and save the images to an HTML"""
+        visual_ret = OrderedDict()
+        for name in self.visual_names:
+            if isinstance(name, str):
+                visual_ret[name] = getattr(self, name)
+        return visual_ret
+
+    def get_current_losses(self):
+        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
+        errors_ret = OrderedDict()
+        for name in self.loss_names:
+            if isinstance(name, str):
+                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
+        return errors_ret
+
+    def save_networks(self, epoch):
+        """Save all the networks to the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        for name in self.model_names:
+            if isinstance(name, str):
+                save_filename = '%s_net_%s.pth' % (epoch, name)
+                save_path = os.path.join(self.save_dir, save_filename)
+                net = getattr(self, 'net' + name)
+
+                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+                    torch.save(net.module.cpu().state_dict(), save_path)
+                    net.cuda(self.gpu_ids[0])
+                else:
+                    torch.save(net.cpu().state_dict(), save_path)
+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'num_batches_tracked'):
+                state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+    def load_networks(self, epoch):
+        """Load all the networks from the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        for name in self.model_names:
+            if isinstance(name, str):
+                load_filename = '%s_net_%s.pth' % (epoch, name)
+                load_path = os.path.join(self.save_dir, load_filename)
+                net = getattr(self, 'net' + name)
+                if isinstance(net, torch.nn.DataParallel):
+                    net = net.module
+                print('loading the model from %s' % load_path)
+                # if you are using PyTorch newer than 0.4 (e.g., built from
+                # GitHub source), you can remove str() on self.device
+                state_dict = torch.load(load_path, map_location=str(self.device))
+                if hasattr(state_dict, '_metadata'):
+                    del state_dict._metadata
+
+                # patch InstanceNorm checkpoints prior to 0.4
+                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in the loop
+                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                net.load_state_dict(state_dict)
+
+    def print_networks(self, verbose):
+        """Print the total number of parameters in the network and (if verbose) network architecture
+
+        Parameters:
+            verbose (bool) -- if verbose: print the network architecture
+        """
+        print('---------- Networks initialized -------------')
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                num_params = 0
+                for param in net.parameters():
+                    num_params += param.numel()
+                if verbose:
+                    print(net)
+                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+        print('-----------------------------------------------')
+
+    def set_requires_grad(self, nets, requires_grad=False):
+        """Set requires_grad=False for all the networks to avoid unnecessary computations
+
+        Parameters:
+            nets (network list)  -- a list of networks
+            requires_grad (bool) -- whether the networks require gradients or not
+        """
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
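The save/load pair above writes one '<epoch>_net_<name>.pth' file per entry in model_names. A minimal standalone sketch of the same round trip; the checkpoint directory and names are illustrative:

# Illustrative round trip matching the BaseModel checkpoint convention (paths hypothetical,
# directory assumed to exist).
import torch
import torch.nn as nn

net = nn.Conv2d(1, 64, kernel_size=7, padding=3)
torch.save(net.state_dict(), 'checkpoints/ref2sketch/latest_net_G.pth')  # save_networks('latest') for model_names=['G']

state_dict = torch.load('checkpoints/ref2sketch/latest_net_G.pth', map_location='cpu')
net.load_state_dict(state_dict)  # load_networks('latest') additionally patches old InstanceNorm buffers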
diff --git a/app/service/image2sketch/models/layer.py b/app/service/image2sketch/models/layer.py
new file mode 100644
index 0000000..df96a35
--- /dev/null
+++ b/app/service/image2sketch/models/layer.py
@@ -0,0 +1,354 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CNR2d(nn.Module):
+    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]):
+        super().__init__()
+
+        if bias == []:
+            if norm == 'bnorm':
+                bias = False
+            else:
+                bias = True
+
+        layers = []
+        layers += [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]
+
+        if norm != []:
+            layers += [Norm2d(nch_out, norm)]
+
+        if relu != []:
+            layers += [ReLU(relu)]
+
+        if drop != []:
+            layers += [nn.Dropout2d(drop)]
+
+        self.cbr = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.cbr(x)
+
+
+class DECNR2d(nn.Module):
+    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]):
+        super().__init__()
+
+        if bias == []:
+            if norm == 'bnorm':
+                bias = False
+            else:
+                bias = True
+
+        layers = []
+        layers += [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)]
+
+        if norm != []:
+            layers += [Norm2d(nch_out, norm)]
+
+        if relu != []:
+            layers += [ReLU(relu)]
+
+        if drop != []:
+            layers += [nn.Dropout2d(drop)]
+
+        self.decbr = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.decbr(x)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
+        super().__init__()
+
+        if bias == []:
+            if norm == 'bnorm':
+                bias = False
+            else:
+                bias = True
+
+        layers = []
+
+        # 1st conv
+        layers += [Padding(padding, padding_mode=padding_mode)]
+        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)]
+
+        if drop != []:
+            layers += [nn.Dropout2d(drop)]
+
+        # 2nd conv
+        layers += [Padding(padding, padding_mode=padding_mode)]
+        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])]
+
+        self.resblk = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return x + self.resblk(x)
+
+
+class ResBlock_cat(nn.Module):
+    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
+        super().__init__()
+
+        if bias == []:
+            if norm == 'bnorm':
+                bias = False
+            else:
+                bias = True
+
+        layers = []
+
+        # 1st conv (takes the concatenation of the two inputs, hence nch_in * 2)
+        layers += [Padding(padding, padding_mode=padding_mode)]
+        layers += [CNR2d(nch_in * 2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)]
+
+        if drop != []:
+            layers += [nn.Dropout2d(drop)]
+
+        # 2nd conv
+        layers += [Padding(padding, padding_mode=padding_mode)]
+        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])]
+
+        self.resblk = nn.Sequential(*layers)
+
+    def forward(self, x, y):
+        output = x + self.resblk(torch.cat([x, y], dim=1))
+        return output
+
+
+class LinearBlock(nn.Module):
+    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
+        super(LinearBlock, self).__init__()
+        use_bias = True
+        # initialize fully connected layer
+        if norm == 'sn':
+            # no local SpectralNorm helper is defined in this repo; use torch's built-in spectral_norm
+            self.fc = nn.utils.spectral_norm(nn.Linear(input_dim, output_dim, bias=use_bias))
+        else:
+            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
+
+        # initialize normalization
+        norm_dim = output_dim
+        if norm == 'bn':
+            self.norm = nn.BatchNorm1d(norm_dim)
+        elif norm == 'in':
+            self.norm = nn.InstanceNorm1d(norm_dim)
+        elif norm == 'ln':
+            self.norm = nn.LayerNorm(norm_dim)  # no local LayerNorm helper is defined; use torch's
+        elif norm == 'none' or norm == 'sn':
+            self.norm = None
+        else:
+            assert 0, "Unsupported normalization: {}".format(norm)
+
+        # initialize activation
+        if activation == 'relu':
+            self.activation = nn.ReLU(inplace=True)
+        elif activation == 'lrelu':
+            self.activation = nn.LeakyReLU(0.2, inplace=True)
+        elif activation == 'prelu':
+            self.activation = nn.PReLU()
+        elif activation == 'selu':
+            self.activation = nn.SELU(inplace=True)
+        elif activation == 'tanh':
+            self.activation = nn.Tanh()
+        elif activation == 'none':
+            self.activation = None
+        else:
+            assert 0, "Unsupported activation: {}".format(activation)
+
+    def forward(self, x):
+        out = self.fc(x)
+        if self.norm:
+            out = self.norm(out)
+        if self.activation:
+            out = self.activation(out)
+        return out
+
+
+class MLP(nn.Module):
+    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
+        super(MLP, self).__init__()
+        self.model = []
+        self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
+        for i in range(n_blk - 2):
+            self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
+        self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
+        self.model = nn.Sequential(*self.model)
+
+    def forward(self, x):
+        return self.model(x.view(x.size(0), -1))
+class CNR1d(nn.Module):
+    def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]):
+        super().__init__()
+
+        if norm == 'bnorm':
+            bias = False
+        else:
+            bias = True
+
+        layers = []
+        layers += [nn.Linear(nch_in, nch_out, bias=bias)]
+
+        if norm != []:
+            layers += [Norm2d(nch_out, norm)]
+
+        if relu != []:
+            layers += [ReLU(relu)]
+
+        if drop != []:
+            layers += [nn.Dropout2d(drop)]
+
+        self.cbr = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.cbr(x)
+
+
+class Conv2d(nn.Module):
+    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True):
+        super(Conv2d, self).__init__()
+        self.conv = nn.Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Deconv2d(nn.Module):
+    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True):
+        super(Deconv2d, self).__init__()
+        self.deconv = nn.ConvTranspose2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)
+
+        # layers = [nn.Upsample(scale_factor=2, mode='bilinear'),
+        #           nn.ReflectionPad2d(1),
+        #           nn.Conv2d(nch_in, nch_out, kernel_size=3, stride=1, padding=0)]
+        #
+        # self.deconv = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.deconv(x)
+
+
+class Linear(nn.Module):
+    def __init__(self, nch_in, nch_out):
+        super(Linear, self).__init__()
+        self.linear = nn.Linear(nch_in, nch_out)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+class Norm2d(nn.Module):
+    def __init__(self, nch, norm_mode):
+        super(Norm2d, self).__init__()
+        if norm_mode == 'bnorm':
+            self.norm = nn.BatchNorm2d(nch)
+        elif norm_mode == 'inorm':
+            self.norm = nn.InstanceNorm2d(nch)
+
+    def forward(self, x):
+        return self.norm(x)
+
+
+class ReLU(nn.Module):
+    def __init__(self, relu):
+        super(ReLU, self).__init__()
+        if relu > 0:
+            self.relu = nn.LeakyReLU(relu, True)
+        elif relu == 0:
+            self.relu = nn.ReLU(True)
+
+    def forward(self, x):
+        return self.relu(x)
+class Padding(nn.Module):
+    def __init__(self, padding, padding_mode='zeros', value=0):
+        super(Padding, self).__init__()
+        if padding_mode == 'reflection':
+            self.padding = nn.ReflectionPad2d(padding)
+        elif padding_mode == 'replication':
+            self.padding = nn.ReplicationPad2d(padding)
+        elif padding_mode == 'constant':
+            self.padding = nn.ConstantPad2d(padding, value)
+        elif padding_mode == 'zeros':
+            self.padding = nn.ZeroPad2d(padding)
+
+    def forward(self, x):
+        return self.padding(x)
+
+
+class Pooling2d(nn.Module):
+    def __init__(self, nch=[], pool=2, type='avg'):
+        super().__init__()
+
+        if type == 'avg':
+            self.pooling = nn.AvgPool2d(pool)
+        elif type == 'max':
+            self.pooling = nn.MaxPool2d(pool)
+        elif type == 'conv':
+            self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool)
+
+    def forward(self, x):
+        return self.pooling(x)
+
+
+class UnPooling2d(nn.Module):
+    def __init__(self, nch=[], pool=2, type='nearest'):
+        super().__init__()
+
+        if type == 'nearest':
+            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')  # align_corners is only valid for interpolating modes
+        elif type == 'bilinear':
+            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
+        elif type == 'conv':
+            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)
+
+    def forward(self, x):
+        return self.unpooling(x)
+
+
+class Concat(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x1, x2):
+        diffy = x2.size()[2] - x1.size()[2]
+        diffx = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2,
+                        diffy // 2, diffy - diffy // 2])
+
+        return torch.cat([x2, x1], dim=1)
+
+
+class TV1dLoss(nn.Module):
+    def __init__(self):
+        super(TV1dLoss, self).__init__()
+
+    def forward(self, input):
+        # loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \
+        #        torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
+        loss = torch.mean(torch.abs(input[:, :-1] - input[:, 1:]))
+
+        return loss
+
+
+class TV2dLoss(nn.Module):
+    def __init__(self):
+        super(TV2dLoss, self).__init__()
+
+    def forward(self, input):
+        loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \
+               torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
+        return loss
+
+
+class SSIM2dLoss(nn.Module):
+    def __init__(self):
+        super(SSIM2dLoss, self).__init__()
+
+    def forward(self, input, target):
+        loss = 0  # placeholder: SSIM loss not implemented
+        return loss
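These Conv-Norm-ReLU wrappers are the building blocks the generators below are assembled from. A quick sketch of stacking them into a small encoder; the shapes are illustrative:

# Illustrative encoder stack built from the CNR2d block above.
import torch
from app.service.image2sketch.models.layer import CNR2d

enc = torch.nn.Sequential(
    CNR2d(1, 64, kernel_size=7, stride=1, padding=3, norm='inorm', relu=0.0),    # 1x256x256 -> 64x256x256
    CNR2d(64, 128, kernel_size=4, stride=2, padding=1, norm='inorm', relu=0.0),  # -> 128x128x128
    CNR2d(128, 256, kernel_size=4, stride=2, padding=1, norm='inorm', relu=0.0), # -> 256x64x64
)
feat = enc(torch.rand(1, 1, 256, 256))
print(feat.shape)  # torch.Size([1, 256, 64, 64])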
+ """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == 'none': + def norm_layer(x): + return Identity() + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For 'linear', we keep the same learning rate for the first epochs + and linearly decay the rate to zero over the next epochs. + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. + """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) + return lr_l + + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. 
+def init_weights(net, init_type='normal', init_gain=0.02):
+    """Initialize network weights.
+
+    Parameters:
+        net (network)     -- network to be initialized
+        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+
+    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+    work better for some applications. Feel free to try yourself.
+    """
+
+    def init_func(m):  # define the initialization function
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+            if init_type == 'normal':
+                init.normal_(m.weight.data, 0.0, init_gain)
+            elif init_type == 'xavier':
+                init.xavier_normal_(m.weight.data, gain=init_gain)
+            elif init_type == 'kaiming':
+                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                init.orthogonal_(m.weight.data, gain=init_gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+            if hasattr(m, 'bias') and m.bias is not None:
+                init.constant_(m.bias.data, 0.0)
+        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+            init.normal_(m.weight.data, 1.0, init_gain)
+            init.constant_(m.bias.data, 0.0)
+
+    print('initialize network with %s' % init_type)
+    net.apply(init_func)  # apply the initialization function
+
+
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
+    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
+
+    Parameters:
+        net (network)      -- the network to be initialized
+        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
+        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+    Return an initialized network.
+    """
+    if len(gpu_ids) > 0:
+        assert (torch.cuda.is_available())
+        net.to(gpu_ids[0])
+        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
+    init_weights(net, init_type, init_gain=init_gain)
+    return net
+
+
+def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
+    net = None
+    norm_layer = get_norm_layer(norm_type=norm)
+
+    if netG == 'ref_unpair_cbam_cat':
+        net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_cbam_cat')
+    elif netG == 'ref_unpair_recon':
+        net = ref_unpair(input_nc, output_nc, ngf, norm='inorm', status='ref_unpair_recon')
+    elif netG == 'triplet':
+        net = triplet()  # triplet() takes no constructor arguments (see class definition below)
+    else:
+        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
+    return init_net(net, init_type, init_gain, gpu_ids)
+
+
+class AdaIN(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        eps = 1e-5
+        mean_x = torch.mean(x, dim=[2, 3])
+        mean_y = torch.mean(y, dim=[2, 3])
+
+        std_x = torch.std(x, dim=[2, 3])
+        std_y = torch.std(y, dim=[2, 3])
+
+        mean_x = mean_x.unsqueeze(-1).unsqueeze(-1)
+        mean_y = mean_y.unsqueeze(-1).unsqueeze(-1)
+
+        std_x = std_x.unsqueeze(-1).unsqueeze(-1) + eps
+        std_y = std_y.unsqueeze(-1).unsqueeze(-1) + eps
+
+        out = (x - mean_x) / std_x * std_y + mean_y
+
+        return out
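A quick numerical sanity check of the AdaIN module above: after transfer, each channel of the output should carry (approximately) the style input's mean and std:

# Sanity check: AdaIN output statistics should match the style tensor's per-channel stats.
import torch
from app.service.image2sketch.models.networks import AdaIN

adain = AdaIN()
content = torch.randn(1, 8, 32, 32)
style = 3.0 * torch.randn(1, 8, 32, 32) + 1.0
out = adain(content, style)
print(out.mean(dim=[2, 3])[0, 0], style.mean(dim=[2, 3])[0, 0])  # approximately equal
print(out.std(dim=[2, 3])[0, 0], style.std(dim=[2, 3])[0, 0])    # approximately equal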
+class HED(nn.Module):
+    def __init__(self):
+        super(HED, self).__init__()
+
+        self.moduleVggOne = nn.Sequential(
+            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False)
+        )
+
+        self.moduleVggTwo = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False)
+        )
+
+        self.moduleVggThr = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False)
+        )
+
+        self.moduleVggFou = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False)
+        )
+
+        self.moduleVggFiv = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(inplace=False)
+        )
+
+        self.moduleScoreOne = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.moduleScoreTwo = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.moduleScoreThr = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.moduleScoreFou = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.moduleScoreFiv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
+
+        self.moduleCombine = nn.Sequential(
+            nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
+            nn.Sigmoid()
+        )
+
+    def forward(self, tensorInput):
+        tensorBlue = (tensorInput[:, 2:3, :, :] * 255.0) - 104.00698793
+        tensorGreen = (tensorInput[:, 1:2, :, :] * 255.0) - 116.66876762
+        tensorRed = (tensorInput[:, 0:1, :, :] * 255.0) - 122.67891434
+        tensorInput = torch.cat([tensorBlue, tensorGreen, tensorRed], 1)
+
+        tensorVggOne = self.moduleVggOne(tensorInput)
+        tensorVggTwo = self.moduleVggTwo(tensorVggOne)
+        tensorVggThr = self.moduleVggThr(tensorVggTwo)
+        tensorVggFou = self.moduleVggFou(tensorVggThr)
+        tensorVggFiv = self.moduleVggFiv(tensorVggFou)
+
+        tensorScoreOne = self.moduleScoreOne(tensorVggOne)
+        tensorScoreTwo = self.moduleScoreTwo(tensorVggTwo)
+        tensorScoreThr = self.moduleScoreThr(tensorVggThr)
+        tensorScoreFou = self.moduleScoreFou(tensorVggFou)
+        tensorScoreFiv = self.moduleScoreFiv(tensorVggFiv)
+
+        tensorScoreOne = nn.functional.interpolate(input=tensorScoreOne, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
+        tensorScoreTwo = nn.functional.interpolate(input=tensorScoreTwo, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
+        tensorScoreThr = nn.functional.interpolate(input=tensorScoreThr, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
+        tensorScoreFou = nn.functional.interpolate(input=tensorScoreFou, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
+        tensorScoreFiv = nn.functional.interpolate(input=tensorScoreFiv, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
+
+        return self.moduleCombine(torch.cat([tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv], 1))
+        # return self.moduleCombine(torch.cat([tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreOne, tensorScoreTwo], 1))
+        # return torch.sigmoid(tensorScoreOne), torch.sigmoid(tensorScoreTwo), torch.sigmoid(tensorScoreThr), torch.sigmoid(tensorScoreFou), torch.sigmoid(tensorScoreFiv), self.moduleCombine(torch.cat([tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv], 1))
+        # return torch.sigmoid(tensorScoreTwo)
+
+
+def define_HED(init_weights_, gpu_ids_=[]):
+    net = HED()
+
+    if len(gpu_ids_) > 0:
+        assert (torch.cuda.is_available())
+        net.to(gpu_ids_[0])
+        net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs
+
+    if init_weights_ is not None:
+        device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
+        print('Loading model from: %s' % init_weights_)
+        state_dict = torch.load(init_weights_, map_location=str(device))
+        if isinstance(net, torch.nn.DataParallel):
+            net.module.load_state_dict(state_dict)
+        else:
+            net.load_state_dict(state_dict)
+        print('loaded the weights successfully')
+
+    return net
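A minimal sketch of running the HED edge detector defined above. Passing None keeps random weights (so the output is only shape-meaningful); a real checkpoint path would be supplied in practice:

# Hypothetical usage of define_HED; pass a checkpoint path instead of None for real edges.
import torch
from app.service.image2sketch.models.networks import define_HED

hed = define_HED(None, [])                    # None -> no pretrained weights, [] -> CPU
with torch.no_grad():
    edges = hed(torch.rand(1, 3, 256, 256))   # RGB input in [0, 1]
print(edges.shape)                            # torch.Size([1, 1, 256, 256]), edge probability map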
+def define_styletps(init_weights_, gpu_ids_=[], shape=False):
+    net = None
+    if shape == False:
+        net = triplet()
+    if len(gpu_ids_) > 0:
+        assert (torch.cuda.is_available())
+        net.to(gpu_ids_[0])
+        net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs
+
+    if init_weights_ is not None:
+        device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
+        print('Loading model from: %s' % init_weights_)
+        state_dict = torch.load(init_weights_, map_location=str(device))
+        if isinstance(net, torch.nn.DataParallel):
+            net.module.load_state_dict(state_dict)
+        else:
+            net.load_state_dict(state_dict)
+        print('loaded the weights successfully')
+
+    return net
+
+
+class triplet(nn.Module):
+    def __init__(self):  # mnblk=4
+        super(triplet, self).__init__()
+
+        self.nch_in = 1
+        self.nch_out = 1
+        self.nch_ker = 64
+        self.norm = 'bnorm'
+
+        if self.norm == 'bnorm':
+            self.bias = False
+        else:
+            self.bias = True
+
+        self.conv0 = CNR2d(self.nch_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
+        self.conv1 = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+        self.conv2 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+
+        self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.linear = nn.Linear(256, 128)
+
+    def forward(self, x, y, z):
+        # the three branches share the same weights (siamese triplet encoder)
+        x = self.conv0(x)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.final_pool(x)
+        x = torch.flatten(x, 1)
+        x = self.linear(x)
+
+        y = self.conv0(y)
+        y = self.conv1(y)
+        y = self.conv2(y)
+        y = self.final_pool(y)
+        y = torch.flatten(y, 1)
+        y = self.linear(y)
+
+        z = self.conv0(z)
+        z = self.conv1(z)
+        z = self.conv2(z)
+        z = self.final_pool(z)
+        z = torch.flatten(z, 1)
+        z = self.linear(z)
+
+        return x, y, z
+
+
+class MLP(nn.Module):
+    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
+        super(MLP, self).__init__()
+        self.model = []
+        self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
+        for i in range(n_blk - 2):
+            self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
+        self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
+        self.model = nn.Sequential(*self.model)
+
+    def forward(self, x):
+        return self.model(x.view(x.size(0), -1))
+class ref_unpair(nn.Module):
+    def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm', nblk=4, status='ref_unpair'):
+        super(ref_unpair, self).__init__()
+
+        nch_ker = 64
+        self.nch_in = nch_in
+        self.nchs_in = 1
+        self.status = status
+
+        if self.status == 'ref_unpair_recon':
+            self.nch_out = 3
+            self.nch_in = 1
+        else:
+            self.nch_out = 1
+
+        self.nch_ker = nch_ker
+        self.norm = norm
+        self.nblk = nblk
+        self.dec0 = []
+
+        if status == 'ref_unpair_cbam_cat':
+            self.cbam_c = CBAM(nch_ker * 8, 16, 3, cbam_status="channel")
+            self.cbam_s = CBAM(nch_ker * 8, 16, 3, cbam_status="spatial")
+
+        self.enc1_s = CNR2d(self.nchs_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
+        self.enc2_s = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+        self.enc3_s = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+        self.enc4_s = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+
+        if norm == 'bnorm':
+            self.bias = False
+        else:
+            self.bias = True
+
+        self.enc1_c = CNR2d(self.nch_in, self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
+        self.enc2_c = CNR2d(self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+        self.enc3_c = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+        self.enc4_c = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
+
+        if status == 'ref_unpair_cbam_cat':
+            self.res_cat1 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
+            self.res_cat2 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
+            self.res_cat3 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
+            self.res_cat4 = ResBlock_cat(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')
+
+        if self.nblk and status != 'ref_unpair_cbam_cat':
+            res = []
+            for i in range(self.nblk):
+                res += [ResBlock(8 * self.nch_ker, 8 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')]
+            self.res1 = nn.Sequential(*res)
+
+        # self.dec0 += [DECNR2d(16 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
+        self.dec0 += [DECNR2d(8 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
+        self.dec0 += [DECNR2d(4 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
+        self.dec0 += [DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)]
+        self.dec0 += [DECNR2d(1 * self.nch_ker, 1 * self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)]
+        self.dec0 += [nn.Conv2d(1 * self.nch_ker, self.nch_out, kernel_size=3, stride=1, padding=1)]
+
+        self.dec = nn.Sequential(*self.dec0)
+
+    def forward(self, content, style):
+
+        content_cs = self.enc1_c(content)
+        content_cs = self.enc2_c(content_cs)
+        content_cs = self.enc3_c(content_cs)
+        content_cs = self.enc4_c(content_cs)
+
+        if self.status == 'ref_unpair_cbam_cat':
+            cbam_content_cs = self.cbam_s(content_cs)
+            sp_content_cs = content_cs + cbam_content_cs
+
+            style_cs = self.enc1_s(style)
+            style_cs = self.enc2_s(style_cs)
+            style_cs = self.enc3_s(style_cs)
+            style_cs = self.enc4_s(style_cs)
+
+            cbam_style_cs = self.cbam_c(style_cs)
+            ch_style_cs = style_cs + cbam_style_cs
+
+            content_output = self.adaptive_instance_normalization(content_cs, style_cs)
+            cbam_content_output = self.adaptive_instance_normalization(sp_content_cs, ch_style_cs)
+
+            content_output = self.res_cat1(content_output, cbam_content_output)
+            content_output = self.res_cat2(content_output, cbam_content_output)
+            content_output = self.res_cat3(content_output, cbam_content_output)
+            content_output = self.res_cat4(content_output, cbam_content_output)
+        else:
+            content_output = content_cs
+
+        if self.nblk and self.status != 'ref_unpair_cbam_cat':
+            content_output = self.res1(content_output)  # feed the residual blocks' output forward (was discarded before)
+
+        content_output = self.dec(content_output)
+
+        content_output = torch.tanh(content_output)
+
+        return content_output
+    def calc_mean_std(self, feat, eps=1e-5):
+        # eps is a small value added to the variance to avoid divide-by-zero.
+        size = feat.size()
+        assert (len(size) == 4)
+        N, C = size[:2]
+        feat_var = feat.view(N, C, -1).var(dim=2) + eps
+        feat_std = feat_var.sqrt().view(N, C, 1, 1)
+        feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+        return feat_mean, feat_std
+
+    def adaptive_instance_normalization(self, content_feat, style_feat):
+        assert (content_feat.size()[:2] == style_feat.size()[:2])
+        size = content_feat.size()
+        style_mean, style_std = self.calc_mean_std(style_feat)
+        content_mean, content_std = self.calc_mean_std(content_feat)
+
+        normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+        return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
+    net = None
+    norm_layer = get_norm_layer(norm_type=norm)
+
+    if netD == 'basic':  # default PatchGAN classifier
+        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+    elif netD == 'n_layers':  # more options
+        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+    elif netD == 'pixel':  # classify if each pixel is real or fake
+        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+    else:
+        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
+    return init_net(net, init_type, init_gain, gpu_ids)
+
+
+##############################################################################
+# Classes
+##############################################################################
+class GANLoss(nn.Module):
+    """Define different GAN objectives.
+
+    The GANLoss class abstracts away the need to create the target label tensor
+    that has the same size as the input.
+    """
+
+    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+        """Initialize the GANLoss class.
+
+        Parameters:
+            gan_mode (str)            -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+            target_real_label (bool)  -- label for a real image
+            target_fake_label (bool)  -- label of a fake image
+
+        Note: Do not use sigmoid as the last layer of Discriminator.
+        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
+        """
+        super(GANLoss, self).__init__()
+        self.register_buffer('real_label', torch.tensor(target_real_label))
+        self.register_buffer('fake_label', torch.tensor(target_fake_label))
+        self.gan_mode = gan_mode
+        if gan_mode == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif gan_mode == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        elif gan_mode in ['wgangp']:
+            self.loss = None
+        else:
+            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+    def get_target_tensor(self, prediction, target_is_real):
+        if target_is_real:
+            target_tensor = self.real_label
+        else:
+            target_tensor = self.fake_label
+        return target_tensor.expand_as(prediction)
+
+    def __call__(self, prediction, target_is_real):
+        if self.gan_mode in ['lsgan', 'vanilla']:
+            target_tensor = self.get_target_tensor(prediction, target_is_real)
+            loss = self.loss(prediction, target_tensor)
+        elif self.gan_mode == 'wgangp':
+            if target_is_real:
+                loss = -prediction.mean()
+            else:
+                loss = prediction.mean()
+        return loss
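GANLoss expands the scalar label to the prediction's shape, so discriminator outputs of any spatial size work unchanged. A small sketch with a PatchGAN-sized prediction map:

# Illustrative LSGAN loss on a PatchGAN-sized prediction map.
import torch
from app.service.image2sketch.models.networks import GANLoss

criterion = GANLoss('lsgan')
pred_fake = torch.randn(1, 1, 30, 30)         # discriminator output for a generated image
loss_G = criterion(pred_fake, True)           # generator wants fakes scored as real
loss_D_fake = criterion(pred_fake.detach(), False)
print(loss_G.item(), loss_D_fake.item())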
+def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
+    if lambda_gp > 0.0:
+        if type == 'real':  # either use real images, fake images, or a linear interpolation of the two.
+            interpolatesv = real_data
+        elif type == 'fake':
+            interpolatesv = fake_data
+        elif type == 'mixed':
+            alpha = torch.rand(real_data.shape[0], 1, device=device)
+            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
+            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
+        else:
+            raise NotImplementedError('{} not implemented'.format(type))
+        interpolatesv.requires_grad_(True)
+        disc_interpolates = netD(interpolatesv)
+        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
+                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                                        create_graph=True, retain_graph=True, only_inputs=True)
+        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
+        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
+        return gradient_penalty, gradients
+    else:
+        return 0.0, None
+
+
+class NLayerDiscriminator(nn.Module):
+    """Defines a PatchGAN discriminator"""
+
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
+        """Construct a PatchGAN discriminator
+
+        Parameters:
+            input_nc (int)  -- the number of channels in input images
+            ndf (int)       -- the number of filters in the last conv layer
+            n_layers (int)  -- the number of conv layers in the discriminator
+            norm_layer      -- normalization layer
+        """
+        super(NLayerDiscriminator, self).__init__()
+        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
+            use_bias = norm_layer.func == nn.InstanceNorm2d
+        else:
+            use_bias = norm_layer == nn.InstanceNorm2d
+        kw = 4
+        padw = 1
+        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+        nf_mult = 1
+        nf_mult_prev = 1
+        for n in range(1, n_layers):  # gradually increase the number of filters
+            nf_mult_prev = nf_mult
+            nf_mult = min(2 ** n, 8)
+            sequence += [
+                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                norm_layer(ndf * nf_mult),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        nf_mult_prev = nf_mult
+        nf_mult = min(2 ** n_layers, 8)
+        sequence += [
+            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+            norm_layer(ndf * nf_mult),
+            nn.LeakyReLU(0.2, True)
+        ]
+
+        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
+        self.model = nn.Sequential(*sequence)
+
+    def forward(self, input):
+        """Standard forward."""
+        return self.model(input)
+
+
+class PixelDiscriminator(nn.Module):
+    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
+
+    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
+        """Construct a 1x1 PatchGAN discriminator
+
+        Parameters:
+            input_nc (int)  -- the number of channels in input images
+            ndf (int)       -- the number of filters in the last conv layer
+            norm_layer      -- normalization layer
+        """
+        super(PixelDiscriminator, self).__init__()
+        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
+            use_bias = norm_layer.func == nn.InstanceNorm2d
+        else:
+            use_bias = norm_layer == nn.InstanceNorm2d
+
+        self.net = [
+            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
+            norm_layer(ndf * 2),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
+
+        self.net = nn.Sequential(*self.net)
+
+    def forward(self, input):
+        """Standard forward."""
+        return self.net(input)
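For the 'wgangp' mode, cal_gradient_penalty above is called on the discriminator with real and fake batches; a minimal sketch with made-up tensors:

# Illustrative WGAN-GP penalty computation using the helpers above.
import torch
from app.service.image2sketch.models.networks import NLayerDiscriminator, cal_gradient_penalty

netD = NLayerDiscriminator(input_nc=1)
real = torch.rand(2, 1, 64, 64)
fake = torch.rand(2, 1, 64, 64)
gp, _ = cal_gradient_penalty(netD, real, fake, device='cpu', type='mixed', lambda_gp=10.0)
gp.backward()  # would be added to the discriminator loss during training
print(gp.item())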
+class CBAM(nn.Module):
+    def __init__(self, n_channels_in, reduction_ratio, kernel_size, cbam_status):
+        super(CBAM, self).__init__()
+        self.n_channels_in = n_channels_in
+        self.reduction_ratio = reduction_ratio
+        self.kernel_size = kernel_size
+        self.channel_attention = ChannelAttention_nopara(n_channels_in, reduction_ratio)
+        self.spatial_attention = SpatialAttention_nopara(kernel_size)
+        self.status = cbam_status
+
+    def forward(self, x):
+        # We don't use the combined cbam mode in this version
+        if self.status == "cbam":
+            chan_att = self.channel_attention(x)
+            fp = chan_att * x
+            spat_att = self.spatial_attention(fp)
+            fpp = spat_att * fp
+
+        if self.status == "spatial":
+            spat_att = self.spatial_attention(x)
+            fpp = spat_att * x
+        if self.status == "channel":
+            chan_att = self.channel_attention(x)
+            fpp = chan_att * x
+
+        return fpp
+
+
+class SpatialAttention_nopara(nn.Module):
+    def __init__(self, kernel_size):
+        super(SpatialAttention_nopara, self).__init__()
+        self.kernel_size = kernel_size
+        assert kernel_size % 2 == 1, "Odd kernel size required"
+        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=int((kernel_size - 1) / 2))
+
+    def forward(self, x):
+        max_pool = self.agg_channel(x, "max")
+        avg_pool = self.agg_channel(x, "avg")
+        pool = torch.cat([max_pool, avg_pool], dim=1)
+        conv = self.conv(pool)
+        conv = conv.repeat(1, x.size()[1], 1, 1)
+        att = torch.sigmoid(conv)
+        return att
+
+    def agg_channel(self, x, pool="max"):
+        b, c, h, w = x.size()
+        x = x.view(b, c, h * w)
+        x = x.permute(0, 2, 1)
+        if pool == "max":
+            x = F.max_pool1d(x, c)
+        elif pool == "avg":
+            x = F.avg_pool1d(x, c)
+        x = x.permute(0, 2, 1)
+        x = x.view(b, 1, h, w)
+        return x
+
+
+class ChannelAttention_nopara(nn.Module):
+    def __init__(self, n_channels_in, reduction_ratio):
+        super(ChannelAttention_nopara, self).__init__()
+        self.n_channels_in = n_channels_in
+        self.reduction_ratio = reduction_ratio
+        self.middle_layer_size = int(self.n_channels_in / float(self.reduction_ratio))
+        self.bottleneck = nn.Sequential(
+            nn.Linear(self.n_channels_in, self.middle_layer_size),
+            nn.ReLU(),
+            nn.Linear(self.middle_layer_size, self.n_channels_in)
+        )
+
+    def forward(self, x):
+        kernel = (x.size()[2], x.size()[3])
+        avg_pool = F.avg_pool2d(x, kernel)
+        max_pool = F.max_pool2d(x, kernel)
+        avg_pool = avg_pool.view(avg_pool.size()[0], -1)
+        max_pool = max_pool.view(max_pool.size()[0], -1)
+        avg_pool_bck = self.bottleneck(avg_pool)
+        max_pool_bck = self.bottleneck(max_pool)
+        pool_sum = avg_pool_bck + max_pool_bck
+        sig_pool = torch.sigmoid(pool_sum)
+        sig_pool = sig_pool.unsqueeze(2).unsqueeze(3)
+        # out = sig_pool.repeat(1, 1, kernel[0], kernel[1])
+
+        return sig_pool
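The generator above instantiates CBAM twice: channel attention on the style features and spatial attention on the content features. A quick shape check with illustrative sizes:

# Shape sanity check for the attention blocks above (sizes illustrative).
import torch
from app.service.image2sketch.models.networks import CBAM

feat = torch.rand(1, 512, 64, 64)             # an nch_ker * 8 feature map
chan = CBAM(512, 16, 3, cbam_status="channel")(feat)
spat = CBAM(512, 16, 3, cbam_status="spatial")(feat)
print(chan.shape, spat.shape)                  # both torch.Size([1, 512, 64, 64])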
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) + for bl in blocks: + for p in bl: + p.requires_grad = False + self.blocks = torch.nn.ModuleList(blocks) + self.transform = torch.nn.functional.interpolate + self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) + self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) + self.resize = resize + + def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]): + if input.shape[1] != 3: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + input = (input-self.mean) / self.std + target = (target-self.mean) / self.std + if self.resize: + input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) + target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) + loss = 0.0 + x = input + y = target + for i, block in enumerate(self.blocks): + x = block(x) + y = block(y) + if i in feature_layers: + loss += torch.nn.functional.l1_loss(x, y) + if i in style_layers: + act_x = x.reshape(x.shape[0], x.shape[1], -1) + act_y = y.reshape(y.shape[0], y.shape[1], -1) + gram_x = act_x @ act_x.permute(0, 2, 1) + gram_y = act_y @ act_y.permute(0, 2, 1) + loss += torch.nn.functional.l1_loss(gram_x, gram_y) + return loss + +class VGGstyleLoss(torch.nn.Module): + def __init__(self, resize=True): + super(VGGstyleLoss, self).__init__() + blocks = [] + blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) + blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) + for bl in blocks: + for p in bl: + p.requires_grad = False + self.blocks = torch.nn.ModuleList(blocks) + self.transform = torch.nn.functional.interpolate + self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)) + self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)) + self.resize = resize + + def forward(self, input, target, feature_layers=[0,1,2,3], style_layers=[]): + if input.shape[1] != 3: + input = input.repeat(1, 3, 1, 1) + target = target.repeat(1, 3, 1, 1) + input = (input-self.mean) / self.std + target = (target-self.mean) / self.std + if self.resize: + input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) + target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) + loss = 0.0 + x = input + y = target + for i, block in enumerate(self.blocks): + x = block(x) + y = block(y) + if i in feature_layers: + loss += torch.nn.functional.l1_loss(x, y) + if i in style_layers: + act_x = x.reshape(x.shape[0], x.shape[1], -1) + act_y = y.reshape(y.shape[0], y.shape[1], -1) + gram_x = act_x @ act_x.permute(0, 2, 1) + gram_y = act_y @ act_y.permute(0, 2, 1) + loss += torch.nn.functional.l1_loss(gram_x, gram_y) + return loss diff --git a/app/service/image2sketch/models/template_model.py b/app/service/image2sketch/models/template_model.py new file mode 100644 index 0000000..45c68b2 --- /dev/null +++ b/app/service/image2sketch/models/template_model.py @@ -0,0 +1,82 @@ +import torch +from .base_model import BaseModel +from . import networks + + +class TemplateModel(BaseModel): + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new model-specific options and rewrite default values for existing options. 
+ + Parameters: + parser -- the option parser + is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset. + if is_train: + parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model. + + return parser + + def __init__(self, opt): + """Initialize this model class. + + Parameters: + opt -- training/test options + + A few things can be done here. + - (required) call the initialization function of BaseModel + - define loss function, visualization images, model names, and optimizers + """ + BaseModel.__init__(self, opt) # call the initialization method of BaseModel + # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk. + self.loss_names = ['loss_G'] + # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images. + self.visual_names = ['data_A', 'data_B', 'output'] + # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks. + # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them. + self.model_names = ['G'] + # define networks; you can use opt.isTrain to specify different behaviors for training and test. + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids) + if self.isTrain: # only defined during training time + # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. + # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device) + self.criterionLoss = torch.nn.L1Loss() + # define and initialize optimizers. You can define one optimizer for each network. + # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers = [self.optimizer] + + # Our program will automatically call to define schedulers, load networks, and print networks + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input: a dictionary that contains the data itself and its metadata information. + """ + AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B + self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A + self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B + self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths + + def forward(self): + """Run forward pass. 
This will be called by both functions <optimize_parameters> and <test>.""" + self.output = self.netG(self.data_A) # generate output image given the input data_A + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + # calculate the intermediate results if necessary; here self.output has been computed during function <forward> + # calculate loss given the input and intermediate results + self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression + self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.forward() # first call forward to calculate intermediate results + self.optimizer.zero_grad() # clear network G's existing gradients + self.backward() # calculate gradients for network G + self.optimizer.step() # update gradients for network G diff --git a/app/service/image2sketch/models/test_model.py b/app/service/image2sketch/models/test_model.py new file mode 100644 index 0000000..2f70821 --- /dev/null +++ b/app/service/image2sketch/models/test_model.py @@ -0,0 +1,45 @@ +from .base_model import BaseModel +from . import networks + + +class TestModel(BaseModel): + """ This TestModel can be used to generate CycleGAN results for only one direction. + This model will automatically set '--dataset_mode single', which only loads the images from one collection. + + See the test instruction for more details. + """ + @staticmethod + def modify_commandline_options(parser, is_train=True): + assert not is_train, 'TestModel cannot be used during training time' + parser.set_defaults(dataset_mode='single') + parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.') + + return parser + + def __init__(self, opt): + assert(not opt.isTrain) + BaseModel.__init__(self, opt) + # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> + self.loss_names = [] + # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> + self.visual_names = ['real', 'fake'] + # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> + self.model_names = ['G' + opt.model_suffix] # only generator is needed. + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, + opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + + # assigns the model to self.netG_[suffix] so that it can be loaded + # please see <BaseModel.load_networks> + setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self. + + def set_input(self, input): + self.real = input['A'].to(self.device) + self.image_paths = input['A_paths'] + + def forward(self): + """Run forward pass.""" + self.fake = self.netG(self.real) # G(real) + + def optimize_parameters(self): + """No optimization for test model.""" + pass diff --git a/app/service/image2sketch/models/triplet_model.py b/app/service/image2sketch/models/triplet_model.py new file mode 100644 index 0000000..a667d49 --- /dev/null +++ b/app/service/image2sketch/models/triplet_model.py @@ -0,0 +1,68 @@ +import torch +from .base_model import BaseModel +from . 
import networks +from util.image_pool import ImagePool + + +class TripletModel(BaseModel): + + @staticmethod + def modify_commandline_options(parser, is_train=True): + parser.set_defaults(norm='batch', netG='triplet', dataset_mode='triplet') + if is_train: + parser.set_defaults(pool_size=0, gan_mode='vanilla') + parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') + + return parser + + def __init__(self, opt): + + BaseModel.__init__(self, opt) + + self.loss_names = ['G_triplet'] + self.visual_names = ['x','y'] + + if self.isTrain: + self.model_names = ['G'] + else: + self.model_names = ['G'] + self.netG = networks.define_G(1, 1, opt.ngf, opt.netG, opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + + + if self.isTrain: + self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images + self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images + + self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + + self.triplet = torch.nn.TripletMarginLoss(margin=3.0) + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_G) + + def set_input(self, input): + AtoB = self.opt.direction == 'AtoB' + self.real_A = input['A' if AtoB else 'B'].to(self.device) + self.real_B = input['B' if AtoB else 'A'].to(self.device) + self.real_C = input['C'].to(self.device) + + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + + + def forward(self): + self.x,self.y,self.z = self.netG(self.real_A,self.real_B,self.real_C) + + + def backward_G(self): + self.loss_G_triplet_1 = self.triplet(self.x,self.y,self.z) + self.loss_G_triplet = self.loss_G_triplet_1 + + self.loss_G = self.loss_G_triplet + self.loss_G.backward() + + def optimize_parameters(self): + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() diff --git a/app/service/image2sketch/models/unpaired_model.py b/app/service/image2sketch/models/unpaired_model.py new file mode 100644 index 0000000..9c043ca --- /dev/null +++ b/app/service/image2sketch/models/unpaired_model.py @@ -0,0 +1,144 @@ +import torch + +from . import networks +from .base_model import BaseModel +from .perceptual import VGGPerceptualLoss +from ..util.image_pool import ImagePool + + +class UnpairedModel(BaseModel): + + @staticmethod + def modify_commandline_options(parser, is_train=True): + parser.set_defaults(norm='batch', netG='ref_unpair_cbam_cat', netG2='ref_unpair_recon', dataset_mode='unaligned') + if is_train: + parser.set_defaults(pool_size=0, gan_mode='vanilla') + parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') + + return parser + + def __init__(self, opt): + BaseModel.__init__(self, opt) + # specify the training losses you want to print out. 
The training/test scripts will call + self.loss_names = ['G_GAN', 'G_L1_1', 'G_Rec', 'G_line', 'D_real', 'D_fake'] + self.visual_names = ['real_A', 'content_output', 'real_B'] + + if self.isTrain: + self.model_names = ['G_A', 'G_B', 'D'] + else: # during test time, only load G + self.model_names = ['G_A', 'G_B'] + # define networks (both generator and discriminator) + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + self.netG_B = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG2, opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc + self.netD = networks.define_D(1, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + self.styletps = networks.define_styletps(init_weights_='./checkpoints/contrastive_pretrained.pth', gpu_ids_=self.gpu_ids, shape=False) + self.HED = networks.define_HED(init_weights_='./checkpoints/network-bsds500.pytorch', gpu_ids_=self.gpu_ids) + + if self.isTrain: # define discriminators + self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: + self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images + self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images + # define loss functions + self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) + self.criterionL1_1 = torch.nn.L1Loss() + self.criterionL1_2 = torch.nn.L1Loss() + self.criterionL1_3 = torch.nn.L1Loss() + self.per_loss_1 = VGGPerceptualLoss().to(self.device) + self.per_loss_2 = VGGPerceptualLoss().to(self.device) + self.per_loss_3 = VGGPerceptualLoss().to(self.device) + + self.optimizer_GA = torch.optim.Adam(self.netG_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_GB = torch.optim.Adam(self.netG_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_GA) + self.optimizers.append(self.optimizer_GB) + + self.optimizers.append(self.optimizer_D) + + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): include the data itself and its metadata information. + + The option 'direction' can be used to swap images in domain A and domain B. 
+ """ + AtoB = self.opt.direction == 'AtoB' + self.real_A = input['A' if AtoB else 'B'].to(self.device) + self.real_B = input['B' if AtoB else 'A'].to(self.device) + # self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + """Run forward pass; called by both functions and .""" + self.content_output = self.netG_A(self.real_A, self.real_B) + self.rec_output = self.netG_B(self.content_output, self.content_output) + + def update_process(self, epoch, total_epoch): + self.epoch_count = epoch + self.epoch_count_total = total_epoch + + def backward_D(self): + """Calculate GAN loss for the discriminator + + Parameters: + netD (network) -- the discriminator D + real (tensor array) -- real images + fake (tensor array) -- images generated by a generator + + Return the discriminator loss. + We also call loss_D.backward() to calculate the gradients. + """ + # Real + pred_real = self.netD(self.real_B) + self.loss_D_real = self.criterionGAN(pred_real, True) + # Fake + pred_fake = self.netD(self.content_output.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Combined loss and calculate gradients + loss_D = (self.loss_D_real + self.loss_D_fake) * 0.5 + loss_D.backward() + return loss_D + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + + pred_fake = self.netD(self.content_output) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + + self.content_output_line = self.HED(self.real_A) + self.rec_output_line = self.HED(self.rec_output) + self.t1, self.t2, _ = self.styletps(self.content_output, self.real_B, self.real_B) + + decay_lambda = 5 - ((self.epoch_count * 4.5) / self.epoch_count_total) + self.loss_G_L1_1 = self.criterionL1_1(self.t1, self.t2) * 10 + self.loss_G_Rec = self.per_loss_2(self.real_A, self.rec_output) * decay_lambda + self.loss_G_line = self.per_loss_3(self.content_output_line, self.rec_output_line) * decay_lambda + + self.loss_G = self.loss_G_GAN + self.loss_G_L1_1 + self.loss_G_Rec + self.loss_G_line + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() # compute fake images: G(A) + # update D + self.set_requires_grad(self.netD, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for backward_D_unsuper + self.optimizer_D.step() # update D's weights + # update G + self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G + self.optimizer_GA.zero_grad() # set G's gradients to zero + self.optimizer_GB.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + self.optimizer_GA.step() # udpate G's weights + self.optimizer_GB.step() # udpate G's weights diff --git a/app/service/image2sketch/opt.py b/app/service/image2sketch/opt.py new file mode 100644 index 0000000..7af09e1 --- /dev/null +++ b/app/service/image2sketch/opt.py @@ -0,0 +1,45 @@ +class Config: + def __init__(self): + # 基本参数 + self.dataroot = "service/image2sketch/datasets/ref_unpair" + self.name = 'semi_unpair' + self.gpu_ids = [0] + self.checkpoints_dir = 'service/image2sketch/checkpoints/' + # 模型参数 + self.model = 'unpaired' + self.input_nc = 3 + self.output_nc = 3 + self.ngf = 64 + self.ndf = 64 + self.netD = 'basic' + self.netG = 'ref_unpair_cbam_cat' + self.netG2 = 'ref_unpair_recon' + self.n_layers_D = 3 + self.norm = 'instance' + self.init_type = 'normal' + self.init_gain = 0.02 + self.no_dropout = False # 对应 `--no_dropout` + # 数据集参数 + self.dataset_mode = 'single' + self.direction = 
'AtoB' + self.serial_batches = True # corresponds to `--serial_batches` + self.num_threads = 4 + self.batch_size = 4 + self.load_size = 512 + self.crop_size = 512 + self.max_dataset_size = float("inf") + self.preprocess = 'resize_and_crop' + self.no_flip = False # corresponds to `--no_flip` + self.display_winsize = 256 + # extra options + self.epoch = '100' + self.load_iter = 0 + self.verbose = False # corresponds to `--verbose` + self.suffix = '' + self.isTrain = False + self.results_dir = 'service/image2sketch/results' + self.aspect_ratio = 1.0 + self.phase = 'test' + self.eval = False + self.num_test = 1000 diff --git a/app/service/image2sketch/server.py b/app/service/image2sketch/server.py new file mode 100644 index 0000000..accd4b8 --- /dev/null +++ b/app/service/image2sketch/server.py @@ -0,0 +1,79 @@ +import logging + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image + +from app.schemas.image2sketch import Image2SketchModel +from app.service.image2sketch.infer import tensor2im +from app.service.image2sketch.models import create_model +from app.service.image2sketch.opt import Config +from app.service.utils.oss_client import oss_get_image, oss_upload_image + +logger = logging.getLogger() + + +class Image2SketchServer: + def __init__(self, request_data): + self.image_url = request_data.image_url + self.sketch_bucket = request_data.sketch_bucket + self.sketch_name = request_data.sketch_name + self.opt = Config() + self.opt.num_threads = 0 # test code only supports num_threads = 0 + self.opt.batch_size = 1 # test code only supports batch_size = 1 + self.opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. + self.opt.no_flip = True # no flip; comment this line if results on flipped images are needed. + self.opt.display_id = -1 # no visdom display; the test code saves the results to an HTML file. 
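+ # The lines below build the input dict expected by UnpairedModel.set_input: 'A' is the photo fetched from OSS and 'B' is a fixed grayscale style reference shipped with this change; both are mapped to [-1, 1] by the Normalize transform and given a batch dimension of 1.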
+ self.data = {} + device = torch.device("cuda:0") + self.model = create_model(self.opt) + self.model.setup(self.opt) + transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] + transform = transforms.Compose(transform_list) + # style reference added in this change; the path is relative to the app/ working directory, matching the Config paths above + style_img = Image.open("service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg").convert('L') + style_img = transform(style_img) + self.data['B'] = style_img + self.data['B'] = self.data['B'].unsqueeze(0).to(device) + A, self.width, self.height = self.get_image(self.image_url) + self.data['A'] = transform(A) + self.data['A'] = self.data['A'].unsqueeze(0).to(device) + + def get_result(self): + self.model.set_input(self.data) + self.model.test() # run inference + visuals = self.model.get_current_visuals() # get image results + image_numpy = tensor2im(visuals['content_output'].cpu()) + image_bytes = cv2.imencode(".jpg", image_numpy)[1].tobytes() + req = oss_upload_image(bucket=self.sketch_bucket, object_name=self.sketch_name, image_bytes=image_bytes) + return f"{req.bucket_name}/{req.object_name}" + + def get_image(self, image_url): + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + image = image.convert('RGB') + width = image.size[0] + height = image.size[1] + return image, width, height + + +if __name__ == '__main__': + data = Image2SketchModel(image_url="test/real_Dress_790b2c6e370644e134df7abdfe7e54d9.jpg_Img.jpg", sketch_bucket="test", sketch_name="test123.jpg") + server = Image2SketchServer(data) + sketch_url = server.get_result() + print(sketch_url) diff --git a/app/service/image2sketch/util/__init__.py b/app/service/image2sketch/util/__init__.py new file mode 100644 index 0000000..ae36f63 --- /dev/null +++ b/app/service/image2sketch/util/__init__.py @@ -0,0 +1 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" diff --git a/app/service/image2sketch/util/get_data.py b/app/service/image2sketch/util/get_data.py new file mode 100644 index 0000000..97edc3c --- /dev/null +++ b/app/service/image2sketch/util/get_data.py @@ -0,0 +1,110 @@ +from __future__ import print_function +import os +import tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """A Python script for downloading CycleGAN or pix2pix datasets. + + Parameters: + technique (str) -- One of: 'cyclegan' or 'pix2pix'. + verbose (bool) -- If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + Alternatively, you can use bash scripts: 'scripts/download_pix2pix_model.sh' + and 'scripts/download_cyclegan_model.sh'. 
+ """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Parameters: + save_path (str) -- A directory to save the data to. + dataset (str) -- (optional). A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full (str) -- the absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/app/service/image2sketch/util/html.py b/app/service/image2sketch/util/html.py new file mode 100644 index 0000000..cc3262a --- /dev/null +++ b/app/service/image2sketch/util/html.py @@ -0,0 +1,86 @@ +import dominate +from dominate.tags import meta, h3, table, tr, td, p, a, img, br +import os + + +class HTML: + """This HTML class allows us to save images and write texts into a single HTML file. + + It consists of functions such as (add a text header to the HTML file), + (add a row of images to the HTML file), and (save the HTML to the disk). + It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. + """ + + def __init__(self, web_dir, title, refresh=0): + """Initialize the HTML classes + + Parameters: + web_dir (str) -- a directory that stores the webpage. 
HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/ + title (str) -- the webpage name + refresh (int) -- how often the website refresh itself; if 0; no refreshing + """ + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + + self.doc = dominate.document(title=title) + if refresh > 0: + with self.doc.head: + meta(http_equiv="refresh", content=str(refresh)) + + def get_image_dir(self): + """Return the directory that stores images""" + return self.img_dir + + def add_header(self, text): + """Insert a header to the HTML file + + Parameters: + text (str) -- the header text + """ + with self.doc: + h3(text) + + def add_images(self, ims, txts, links, width=400): + """add images to the HTML file + + Parameters: + ims (str list) -- a list of image paths + txts (str list) -- a list of image names shown on the website + links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page + """ + self.t = table(border=1, style="table-layout: fixed;") # Insert a table + self.doc.add(self.t) + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + img(style="width:%dpx" % width, src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + """save the current content to the HTML file""" + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': # we show an example usage here. + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims, txts, links = [], [], [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/app/service/image2sketch/util/image_pool.py b/app/service/image2sketch/util/image_pool.py new file mode 100644 index 0000000..6d086f8 --- /dev/null +++ b/app/service/image2sketch/util/image_pool.py @@ -0,0 +1,54 @@ +import random +import torch + + +class ImagePool(): + """This class implements an image buffer that stores previously generated images. + + This buffer enables us to update discriminators using a history of generated images + rather than the ones produced by the latest generators. + """ + + def __init__(self, pool_size): + """Initialize the ImagePool class + + Parameters: + pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created + """ + self.pool_size = pool_size + if self.pool_size > 0: # create an empty pool + self.num_imgs = 0 + self.images = [] + + def query(self, images): + """Return an image from the pool. + + Parameters: + images: the latest generated images from the generator + + Returns images from the buffer. + + By 50/100, the buffer will return input images. + By 50/100, the buffer will return images previously stored in the buffer, + and insert the current images to the buffer. 
+ """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer + random_id = random.randint(0, self.pool_size - 1) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_images.append(image) + return_images = torch.cat(return_images, 0) # collect all the images and return + return return_images diff --git a/app/service/image2sketch/util/util.py b/app/service/image2sketch/util/util.py new file mode 100644 index 0000000..b050c13 --- /dev/null +++ b/app/service/image2sketch/util/util.py @@ -0,0 +1,103 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +def tensor2im(input_image, imtype=np.uint8): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array + if image_numpy.shape[0] == 1: # grayscale to RGB + image_numpy = np.tile(image_numpy, (3, 1, 1)) + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + + image_pil = Image.fromarray(image_numpy) + h, w, _ = image_numpy.shape + + if aspect_ratio > 1.0: + image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + if aspect_ratio < 1.0: + image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), 
np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) diff --git a/app/service/image2sketch/util/visualizer.py b/app/service/image2sketch/util/visualizer.py new file mode 100644 index 0000000..239c5ee --- /dev/null +++ b/app/service/image2sketch/util/visualizer.py @@ -0,0 +1,223 @@ +import numpy as np +import os +import sys +import ntpath +import time +from . import util, html +from subprocess import Popen, PIPE + + +if sys.version_info[0] == 2: + VisdomExceptionBase = Exception +else: + VisdomExceptionBase = ConnectionError + + +def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): + """Save images to the disk. + + Parameters: + webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) + visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs + image_path (str) -- the string is used to create image paths + aspect_ratio (float) -- the aspect ratio of saved images + width (int) -- the images will be resized to width x width + + This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. + """ + image_dir = webpage.get_image_dir() + short_path = ntpath.basename(image_path[0]) + name = os.path.splitext(short_path)[0] + + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, im_data in visuals.items(): + im = util.tensor2im(im_data) + image_name = '%s_%s.png' % (name, label) + save_path = os.path.join(image_dir, image_name) + util.save_image(im, save_path, aspect_ratio=aspect_ratio) + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=width) + + +class Visualizer(): + """This class includes several functions that can display/save images and print/save logging information. + + It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. + """ + + def __init__(self, opt): + """Initialize the Visualizer class + + Parameters: + opt -- stores all the experiment flags; needs to be a subclass of BaseOptions + Step 1: Cache the training/test options + Step 2: connect to a visdom server + Step 3: create an HTML object for saveing HTML filters + Step 4: create a logging file to store training losses + """ + self.opt = opt # cache the option + self.display_id = opt.display_id + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + self.port = opt.display_port + self.saved = False + ''' + if self.display_id > 0: # connect to a visdom server given and + import visdom + self.ncols = opt.display_ncols + self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) + if not self.vis.check_connection(): + self.create_visdom_connections() + ''' + if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' 
% self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + # create a logging file to store training losses + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + def reset(self): + """Reset the self.saved status""" + self.saved = False + ''' + def create_visdom_connections(self): + """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ + cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port + print('\n\nCould not connect to Visdom server. \n Trying to start a server....') + print('Command: %s' % cmd) + Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) + + def display_current_results(self, visuals, epoch, save_result): + """Display current results on visdom; save current results to an HTML file. + + Parameters: + visuals (OrderedDict) - - dictionary of images to display or save + epoch (int) - - the current epoch + save_result (bool) - - if save the current results to an HTML file + """ + if self.display_id > 0: # show images in the browser using visdom + ncols = self.ncols + if ncols > 0: # show all the images in one visdom panel + ncols = min(ncols, len(visuals)) + h, w = next(iter(visuals.values())).shape[:2] + table_css = """""" % (w, h) # create a table css + # create a table of images. + title = self.name + label_html = '' + label_html_row = '' + images = [] + idx = 0 + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + label_html_row += '%s' % label + images.append(image_numpy.transpose([2, 0, 1])) + idx += 1 + if idx % ncols == 0: + label_html += '%s' % label_html_row + label_html_row = '' + white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 + while idx % ncols != 0: + images.append(white_image) + label_html_row += '' + idx += 1 + if label_html_row != '': + label_html += '%s' % label_html_row + try: + self.vis.images(images, nrow=ncols, win=self.display_id + 1, + padding=2, opts=dict(title=title + ' images')) + label_html = '%s
' % label_html + self.vis.text(table_css + label_html, win=self.display_id + 2, + opts=dict(title=title + ' labels')) + except VisdomExceptionBase: + self.create_visdom_connections() + + else: # show each image in a separate visdom panel; + idx = 1 + try: + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), + win=self.display_id + idx) + idx += 1 + except VisdomExceptionBase: + self.create_visdom_connections() + + if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. + self.saved = True + # save images to the disk + for label, image in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + image_numpy = util.tensor2im(image) + img_path = 'epoch%.3d_%s.png' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + ''' + def plot_current_losses(self, epoch, counter_ratio, losses): + """display the current losses on visdom display: dictionary of error labels and values + + Parameters: + epoch (int) -- current epoch + counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + """ + if not hasattr(self, 'plot_data'): + self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) + ''' + try: + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) + except VisdomExceptionBase: + self.create_visdom_connections() + ''' + # losses: same format as |losses| of plot_current_losses + def print_current_losses(self, epoch, iters, losses, t_comp, t_data): + """print current losses on console; also save the losses to the disk + + Parameters: + epoch (int) -- current epoch + iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) + losses (OrderedDict) -- training losses stored in the format of (name, float) pairs + t_comp (float) -- computational time per data point (normalized by batch_size) + t_data (float) -- data loading time per data point (normalized by batch_size) + """ + message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) + for k, v in losses.items(): + message += '%s: %.3f ' % (k, v) + + print(message) # print the message + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) # save the message diff --git a/download_checkpoints.py b/download_checkpoints.py new file mode 100644 index 0000000..03cc2c6 --- /dev/null +++ b/download_checkpoints.py @@ -0,0 +1,45 @@ +import os + +from minio import Minio +from minio.error import S3Error + +MINIO_URL = "www.minio.aida.com.hk:12024" +MINIO_ACCESS = 
os.environ.get("MINIO_ACCESS_KEY", "") # credentials come from the environment rather than the repo; the variable names are placeholders + MINIO_SECRET = os.environ.get("MINIO_SECRET_KEY", "") + MINIO_SECURE = True + # configure the MinIO client + minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + + # download helper + def download_folder(bucket_name, folder_name, local_dir): + try: + # make sure the local directory exists + if not os.path.exists(local_dir): + os.makedirs(local_dir) + + # iterate over the objects stored in MinIO + objects = minio_client.list_objects(bucket_name, prefix=folder_name, recursive=True) + for obj in objects: + # build the local file path + local_file_path = os.path.join(local_dir, obj.object_name[len(folder_name):]) + local_file_dir = os.path.dirname(local_file_path) + + # make sure the local directory exists + if not os.path.exists(local_file_dir): + os.makedirs(local_file_dir) + + # download the file + minio_client.fget_object(bucket_name, obj.object_name, local_file_path) + print(f"Downloaded {obj.object_name} to {local_file_path}") + + except S3Error as e: + print(f"Error occurred: {e}") + + + # usage example + bucket_name = "test" # replace with your bucket name + folder_name = "checkpoints/" # path of the checkpoints folder in the bucket + local_dir = "app/service/image2sketch/checkpoints" # replace with the local directory you want to save to + + download_folder(bucket_name, folder_name, local_dir)
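A minimal sanity check one could run after this script finishes -- an illustrative sketch only: the file names are assumptions pieced together from opt.py (name='semi_unpair', epoch='100'), the '[epoch]_net_G[model_suffix].pth' naming convention quoted in test_model.py, and the fact that UnpairedModel loads G_A and G_B at test time:

import os

# assumed layout: <local_dir>/<opt.name>/<epoch>_net_<model>.pth
checkpoint_dir = "app/service/image2sketch/checkpoints/semi_unpair"
for net in ("G_A", "G_B"):
    path = os.path.join(checkpoint_dir, f"100_net_{net}.pth")
    print(path, "found" if os.path.exists(path) else "MISSING")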