Add codeformer and update license
This commit is contained in:
100
basicsr/data/__init__.py
Normal file
100
basicsr/data/__init__.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import importlib
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from os import path as osp
|
||||
|
||||
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
|
||||
from basicsr.utils import get_root_logger, scandir
|
||||
from basicsr.utils.dist_util import get_dist_info
|
||||
from basicsr.utils.registry import DATASET_REGISTRY
|
||||
|
||||
__all__ = ['build_dataset', 'build_dataloader']
|
||||
|
||||
# automatically scan and import dataset modules for registry
|
||||
# scan all the files under the data folder with '_dataset' in file names
|
||||
data_folder = osp.dirname(osp.abspath(__file__))
|
||||
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
|
||||
# import all the dataset modules
|
||||
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
|
||||
|
||||
|
||||
def build_dataset(dataset_opt):
    """Instantiate a dataset from a configuration dict.

    Args:
        dataset_opt (dict): Dataset configuration. Required keys:
            name (str): Dataset name.
            type (str): Registered dataset type.

    Returns:
        The constructed dataset instance.
    """
    # Work on a copy so the caller's option dict is never mutated.
    opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(opt['type'])
    dataset = dataset_cls(opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {opt["name"]} is built.')
    return dataset
|
||||
|
||||
|
||||
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build a dataloader for the given dataset.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Returns:
        torch.utils.data.DataLoader (or PrefetchDataLoader when
        prefetch_mode == 'cpu').
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase == 'train':
        if dist:
            # Distributed training: per-GPU sizes are used as-is.
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:
            # Non-distributed training: scale per-GPU sizes by the GPU count.
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': num_workers,
            'sampler': sampler,
            'drop_last': True,
        }
        # Without an explicit sampler, let the DataLoader shuffle itself.
        if sampler is None:
            dataloader_args['shuffle'] = True
        if seed is not None:
            dataloader_args['worker_init_fn'] = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
        else:
            dataloader_args['worker_init_fn'] = None
    elif phase in ('val', 'test'):  # validation / testing: single-sample, no workers
        dataloader_args = {'dataset': dataset, 'batch_size': 1, 'shuffle': False, 'num_workers': 0}
    else:
        raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        get_root_logger().info(
            f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    # prefetch_mode=None: normal dataloader;
    # prefetch_mode='cuda': plain dataloader consumed by CUDAPrefetcher.
    return torch.utils.data.DataLoader(**dataloader_args)
|
||||
|
||||
|
||||
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed `numpy` and `random` inside a dataloader worker.

    The per-worker seed is ``num_workers * rank + worker_id + seed`` so
    every worker in every process draws a distinct deterministic seed.
    """
    worker_seed = seed + num_workers * rank + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)
|
||||
48
basicsr/data/data_sampler.py
Normal file
48
basicsr/data/data_sampler.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import math
|
||||
import torch
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler. It can
    enlarge the dataset by an integer ratio, which saves time when
    restarting the dataloader after each epoch in iteration-based training.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Round up so every replica draws the same number of samples.
        self.num_samples = math.ceil(len(dataset) * ratio / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def __iter__(self):
        # Deterministic shuffle keyed on the current epoch.
        generator = torch.Generator()
        generator.manual_seed(self.epoch)
        perm = torch.randperm(self.total_size, generator=generator).tolist()

        # Wrap enlarged indices back into the real dataset range.
        n = len(self.dataset)
        wrapped = [idx % n for idx in perm]

        # Each replica takes an interleaved slice of the shuffled indices.
        selected = wrapped[self.rank:self.total_size:self.num_replicas]
        assert len(selected) == self.num_samples
        return iter(selected)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
|
||||
305
basicsr/data/data_util.py
Normal file
305
basicsr/data/data_util.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from os import path as osp
|
||||
from torch.nn import functional as F
|
||||
|
||||
from basicsr.data.transforms import mod_crop
|
||||
from basicsr.utils import img2tensor, scandir
|
||||
|
||||
|
||||
def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a folder (or an explicit path list).

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Whether to mod-crop each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(scandir(path, full_path=True))
    # Read as BGR float32 in [0, 1].
    imgs = [cv2.imread(p).astype(np.float32) / 255. for p in img_paths]
    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    tensors = img2tensor(imgs, bgr2rgb=True, float32=True)
    return torch.stack(tensors, dim=0)
|
||||
|
||||
|
||||
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading ``num_frames`` frames centered at
    ``crt_idx`` from a sequence of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last_idx = max_frame_num - 1  # valid indices are 0 .. max_frame_num - 1
    half = num_frames // 2

    indices = []
    for i in range(crt_idx - half, crt_idx + half + 1):
        if i < 0:  # pad before the first frame
            if padding == 'replicate':
                i = 0
            elif padding == 'reflection':
                i = -i
            elif padding == 'reflection_circle':
                i = crt_idx + half - i
            else:  # circle
                i = num_frames + i
        elif i > last_idx:  # pad after the last frame
            if padding == 'replicate':
                i = last_idx
            elif padding == 'reflection':
                i = last_idx * 2 - i
            elif padding == 'reflection_circle':
                i = (crt_idx - half) - (i - last_idx)
            else:  # circle
                i = i - num_frames
        indices.append(i)
    return indices
|
||||
|
||||
|
||||
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files; see
    https://lmdb.readthedocs.io/en/release/ for details.

    The meta_info.txt records the meta information of the datasets and is
    created automatically by the provided dataset-preparation tools. Each
    line records 1) image name (with extension), 2) image shape,
    3) compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key, and the same
    key is shared by the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order
            should be consistent with folders, e.g., ['lq', 'gt']. Note that
            this key is different from lmdb keys.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    # Both meta_info files must describe the same set of lmdb keys.
    def _read_keys(folder):
        with open(osp.join(folder, 'meta_info.txt')) as fin:
            return [line.split('.')[0] for line in fin]

    input_lmdb_keys = _read_keys(input_folder)
    gt_lmdb_keys = _read_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    return [{f'{input_key}_path': k, f'{gt_key}_path': k} for k in sorted(input_lmdb_keys)]
|
||||
|
||||
|
||||
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image name and
    image shape (usually for gt), separated by a white space.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order
            should be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl
            is for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    # The first whitespace-separated token on each line is the gt name.
    with open(meta_info_file, 'r') as fin:
        gt_names = [line.split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        # Input filename is derived from the gt basename via the template.
        input_path = osp.join(input_folder, f'{filename_tmpl.format(basename)}{ext}')
        paths.append({f'{input_key}_path': input_path, f'{gt_key}_path': osp.join(gt_folder, gt_name)})
    return paths
|
||||
|
||||
|
||||
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order
            should be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl
            is for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for gt_name in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_name))
        # Input filename is derived from the gt basename via the template.
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        })
    return paths
|
||||
|
||||
|
||||
def paths_from_folder(folder):
    """Generate full paths for every file under ``folder``.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list (each entry joined with ``folder``).
    """
    return [osp.join(folder, name) for name in scandir(folder)]
|
||||
|
||||
|
||||
def paths_from_lmdb(folder):
    """Generate lmdb keys from the folder's meta_info.txt.

    Args:
        folder (str): Lmdb folder path; must end with '.lmdb'.

    Returns:
        list[str]: Lmdb keys (image names without extension).

    Raises:
        ValueError: If ``folder`` is not an lmdb folder.
    """
    if not folder.endswith('.lmdb'):
        # Fixed garbled error message ('Folder {folder}folder should in lmdb format.').
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    # The lmdb key is the image name without extension, one per line.
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths
|
||||
|
||||
|
||||
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The (kernel_size, kernel_size) Gaussian kernel.
    """
    # `scipy.ndimage.filters` is a deprecated alias namespace (removed in
    # recent SciPy); import gaussian_filter from `scipy.ndimage` directly.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # Set the center element to one: a discrete dirac delta ...
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # ... which, gaussian-smoothed, yields a normalized Gaussian filter.
    return gaussian_filter(kernel, sigma)
|
||||
|
||||
|
||||
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with the Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w)
            or (t, c, h, w).
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = x.ndim == 4
    if squeeze_flag:  # add a batch dimension for a (t, c, h, w) input
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    # Fold batch/time/channel together so conv2d blurs each map separately.
    x = x.view(-1, 1, h, w)
    pad = kernel_size // 2 + scale * 2
    x = F.pad(x, (pad, pad, pad, pad), 'reflect')

    blur = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    blur = torch.from_numpy(blur).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, blur, stride=scale)
    # Trim the extra border introduced by the enlarged padding.
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    return x.squeeze(0) if squeeze_flag else x
|
||||
125
basicsr/data/prefetch_dataloader.py
Normal file
125
basicsr/data/prefetch_dataloader.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import queue as Queue
|
||||
import threading
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        super().__init__()
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        # Producer thread: push every generated item, then a None sentinel.
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        item = self.queue.get()
        if item is None:  # sentinel: the wrapped generator is exhausted
            raise StopIteration
        return item

    def __iter__(self):
        return self
|
||||
|
||||
|
||||
class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue
    in ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        # Stash the queue depth before delegating to the normal DataLoader.
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        # Wrap the parent iterator in a background prefetching thread.
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
|
||||
|
||||
|
||||
class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        # Return the next batch, or None once the loader is exhausted.
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        # Restart iteration from the beginning of the original loader.
        self.loader = iter(self.ori_loader)
|
||||
|
||||
|
||||
class CUDAPrefetcher():
    """CUDA prefetcher: overlaps host-to-device batch transfer with compute.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory (the next batch is resident on the GPU
    while the current one is being processed).

    Args:
        loader: Dataloader.
        opt (dict): Options. Must contain 'num_gpu' (0 selects CPU).
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        # Dedicated side stream so batch copies can overlap default-stream work.
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # Eagerly stage the first batch.
        self.preload()

    def preload(self):
        # Fetch the next batch and start its async copy on the side stream.
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    # non_blocking=True: async copy (effective with pinned memory)
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        # Block the current stream until the staged copy is done, hand the
        # batch over, and immediately start prefetching the next one.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        # Restart iteration and re-stage the first batch.
        self.loader = iter(self.ori_loader)
        self.preload()
|
||||
165
basicsr/data/transforms.py
Normal file
165
basicsr/data/transforms.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import cv2
|
||||
import random
|
||||
|
||||
|
||||
def mod_crop(img, scale):
    """Crop an image so both spatial dims are divisible by ``scale``
    (used during testing).

    Args:
        img (ndarray): Input image, 2-D (grayscale) or 3-D (h, w, c).
        scale (int): Scale factor.

    Returns:
        ndarray: Cropped copy of the image.
    """
    img = img.copy()
    if img.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    h, w = img.shape[:2]
    # Drop the remainder rows/columns at the bottom/right.
    return img[:h - h % scale, :w - w % scale, ...]
|
||||
|
||||
|
||||
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images at corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth (used in error messages).

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If the returned
            results only have one element, just return ndarray.

    Raises:
        ValueError: If the GT/LQ shapes mismatch, or the LQ image is
            smaller than the requested patch size.
    """
    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # Bug fix: the two message parts were separated by a stray comma,
        # so ValueError received two args and rendered a tuple-like message.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
|
||||
|
||||
|
||||
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. An ndarray
            input is wrapped into a list.
        hflip (bool): Allow horizontal flip. Default: True.
        rotation (bool): Allow rotation. Default: True.
        flows (list[ndarray]): Optional flows of shape (h, w, 2) to augment
            in the same way. Default: None.
        return_status (bool): Also return the (hflip, vflip, rot90) status.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images (and flows / status).
    """
    # Draw the three decisions once so all images/flows share them.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rotation and random.random() < 0.5
    do_rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if do_hflip:  # horizontal (in-place)
            cv2.flip(img, 1, img)
        if do_vflip:  # vertical (in-place)
            cv2.flip(img, 0, img)
        if do_rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if do_hflip:  # horizontal: also negate the x component
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if do_vflip:  # vertical: also negate the y component
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if do_rot90:  # transpose swaps the component order too
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is None:
        if return_status:
            return imgs, (do_hflip, do_vflip, do_rot90)
        return imgs

    if not isinstance(flows, list):
        flows = [flows]
    flows = [_augment_flow(flow) for flow in flows]
    if len(flows) == 1:
        flows = flows[0]
    return imgs, flows
|
||||
|
||||
|
||||
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate an image around a center point.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If None, it is initialized
            as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.

    Returns:
        ndarray: Rotated image on the same (w, h) canvas.
    """
    h, w = img.shape[:2]
    if center is None:
        center = (w // 2, h // 2)
    rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale)
    return cv2.warpAffine(img, rotation_matrix, (w, h))
|
||||
Reference in New Issue
Block a user