Add codeformer and update license

This commit is contained in:
Felipe Daragon
2025-04-09 23:03:26 +01:00
parent f40ffde905
commit 3549b7e50c
84 changed files with 12545 additions and 19 deletions

100
basicsr/data/__init__.py Normal file
View File

@@ -0,0 +1,100 @@
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY
__all__ = ['build_dataset', 'build_dataloader']
# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
def build_dataset(dataset_opt):
"""Build dataset from options.
Args:
dataset_opt (dict): Configuration for dataset. It must constain:
name (str): Dataset name.
type (str): Dataset type.
"""
dataset_opt = deepcopy(dataset_opt)
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
logger = get_root_logger()
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
return dataset
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
"""Build dataloader.
Args:
dataset (torch.utils.data.Dataset): Dataset.
dataset_opt (dict): Dataset options. It contains the following keys:
phase (str): 'train' or 'val'.
num_worker_per_gpu (int): Number of workers for each GPU.
batch_size_per_gpu (int): Training batch size for each GPU.
num_gpu (int): Number of GPUs. Used only in the train phase.
Default: 1.
dist (bool): Whether in distributed training. Used only in the train
phase. Default: False.
sampler (torch.utils.data.sampler): Data sampler. Default: None.
seed (int | None): Seed. Default: None
"""
phase = dataset_opt['phase']
rank, _ = get_dist_info()
if phase == 'train':
if dist: # distributed training
batch_size = dataset_opt['batch_size_per_gpu']
num_workers = dataset_opt['num_worker_per_gpu']
else: # non-distributed training
multiplier = 1 if num_gpu == 0 else num_gpu
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
drop_last=True)
if sampler is None:
dataloader_args['shuffle'] = True
dataloader_args['worker_init_fn'] = partial(
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
elif phase in ['val', 'test']: # validation
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
else:
raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
prefetch_mode = dataset_opt.get('prefetch_mode')
if prefetch_mode == 'cpu': # CPUPrefetcher
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
logger = get_root_logger()
logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
else:
# prefetch_mode=None: Normal dataloader
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
return torch.utils.data.DataLoader(**dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
# Set the worker seed to num_workers * rank + worker_id + seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)

View File

@@ -0,0 +1,48 @@
import math
import torch
from torch.utils.data.sampler import Sampler
class EnlargedSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
Modified from torch.utils.data.distributed.DistributedSampler
Support enlarging the dataset for iteration-based training, for saving
time when restart the dataloader after each epoch
Args:
dataset (torch.utils.data.Dataset): Dataset used for sampling.
num_replicas (int | None): Number of processes participating in
the training. It is usually the world_size.
rank (int | None): Rank of the current process within num_replicas.
ratio (int): Enlarging ratio. Default: 1.
"""
def __init__(self, dataset, num_replicas, rank, ratio=1):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(self.total_size, generator=g).tolist()
dataset_size = len(self.dataset)
indices = [v % dataset_size for v in indices]
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch

305
basicsr/data/data_util.py Normal file
View File

@@ -0,0 +1,305 @@
import cv2
import numpy as np
import torch
from os import path as osp
from torch.nn import functional as F
from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir
def read_img_seq(path, require_mod_crop=False, scale=1):
"""Read a sequence of images from a given folder path.
Args:
path (list[str] | str): List of image paths or image folder path.
require_mod_crop (bool): Require mod crop for each image.
Default: False.
scale (int): Scale factor for mod_crop. Default: 1.
Returns:
Tensor: size (t, c, h, w), RGB, [0, 1].
"""
if isinstance(path, list):
img_paths = path
else:
img_paths = sorted(list(scandir(path, full_path=True)))
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = torch.stack(imgs, dim=0)
return imgs
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
"""Generate an index list for reading `num_frames` frames from a sequence
of images.
Args:
crt_idx (int): Current center index.
max_frame_num (int): Max number of the sequence of images (from 1).
num_frames (int): Reading num_frames frames.
padding (str): Padding mode, one of
'replicate' | 'reflection' | 'reflection_circle' | 'circle'
Examples: current_idx = 0, num_frames = 5
The generated frame indices under different padding mode:
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
reflection_circle: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
list[int]: A list of indices.
"""
assert num_frames % 2 == 1, 'num_frames should be an odd number.'
assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
max_frame_num = max_frame_num - 1 # start from 0
num_pad = num_frames // 2
indices = []
for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
if i < 0:
if padding == 'replicate':
pad_idx = 0
elif padding == 'reflection':
pad_idx = -i
elif padding == 'reflection_circle':
pad_idx = crt_idx + num_pad - i
else:
pad_idx = num_frames + i
elif i > max_frame_num:
if padding == 'replicate':
pad_idx = max_frame_num
elif padding == 'reflection':
pad_idx = max_frame_num * 2 - i
elif padding == 'reflection_circle':
pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
else:
pad_idx = i - num_frames
else:
pad_idx = i
indices.append(pad_idx)
return indices
def paired_paths_from_lmdb(folders, keys):
"""Generate paired paths from lmdb files.
Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
lq.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records
1)image name (with extension),
2)image shape,
3)compression level, separated by a white space.
Example: `baboon.png (120,125,3) 1`
We use the image name without extension as the lmdb key.
Note that we use the same key for the corresponding lq and gt images.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
Note that this key is different from lmdb keys.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
f'formats. But received {input_key}: {input_folder}; '
f'{gt_key}: {gt_folder}')
# ensure that the two meta_info files are the same
with open(osp.join(input_folder, 'meta_info.txt')) as fin:
input_lmdb_keys = [line.split('.')[0] for line in fin]
with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
gt_lmdb_keys = [line.split('.')[0] for line in fin]
if set(input_lmdb_keys) != set(gt_lmdb_keys):
raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
else:
paths = []
for lmdb_key in sorted(input_lmdb_keys):
paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
return paths
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
"""Generate paired paths from an meta information file.
Each line in the meta information file contains the image names and
image shape (usually for gt), separated by a white space.
Example of an meta information file:
```
0001_s001.png (480,480,3)
0001_s002.png (480,480,3)
```
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
meta_info_file (str): Path to the meta information file.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
with open(meta_info_file, 'r') as fin:
gt_names = [line.split(' ')[0] for line in fin]
paths = []
for gt_name in gt_names:
basename, ext = osp.splitext(osp.basename(gt_name))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
gt_path = osp.join(gt_folder, gt_name)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
input_paths = list(scandir(input_folder))
gt_paths = list(scandir(gt_folder))
assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for gt_path in gt_paths:
basename, ext = osp.splitext(osp.basename(gt_path))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
gt_path = osp.join(gt_folder, gt_path)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paths_from_folder(folder):
"""Generate paths from folder.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
paths = list(scandir(folder))
paths = [osp.join(folder, path) for path in paths]
return paths
def paths_from_lmdb(folder):
"""Generate paths from lmdb.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
if not folder.endswith('.lmdb'):
raise ValueError(f'Folder {folder}folder should in lmdb format.')
with open(osp.join(folder, 'meta_info.txt')) as fin:
paths = [line.split('.')[0] for line in fin]
return paths
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
"""Generate Gaussian kernel used in `duf_downsample`.
Args:
kernel_size (int): Kernel size. Default: 13.
sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
Returns:
np.array: The Gaussian kernel.
"""
from scipy.ndimage import filters as filters
kernel = np.zeros((kernel_size, kernel_size))
# set element at the middle to one, a dirac delta
kernel[kernel_size // 2, kernel_size // 2] = 1
# gaussian-smooth the dirac, resulting in a gaussian filter
return filters.gaussian_filter(kernel, sigma)
def duf_downsample(x, kernel_size=13, scale=4):
"""Downsamping with Gaussian kernel used in the DUF official code.
Args:
x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
kernel_size (int): Kernel size. Default: 13.
scale (int): Downsampling factor. Supported scale: (2, 3, 4).
Default: 4.
Returns:
Tensor: DUF downsampled frames.
"""
assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
squeeze_flag = False
if x.ndim == 4:
squeeze_flag = True
x = x.unsqueeze(0)
b, t, c, h, w = x.size()
x = x.view(-1, 1, h, w)
pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
x = F.conv2d(x, gaussian_filter, stride=scale)
x = x[:, :, 2:-2, 2:-2]
x = x.view(b, t, c, x.size(2), x.size(3))
if squeeze_flag:
x = x.squeeze(0)
return x

View File

@@ -0,0 +1,125 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader
class PrefetchGenerator(threading.Thread):
"""A general prefetch generator.
Ref:
https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
Args:
generator: Python generator.
num_prefetch_queue (int): Number of prefetch queue.
"""
def __init__(self, generator, num_prefetch_queue):
threading.Thread.__init__(self)
self.queue = Queue.Queue(num_prefetch_queue)
self.generator = generator
self.daemon = True
self.start()
def run(self):
for item in self.generator:
self.queue.put(item)
self.queue.put(None)
def __next__(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class PrefetchDataLoader(DataLoader):
"""Prefetch version of dataloader.
Ref:
https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
TODO:
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
ddp.
Args:
num_prefetch_queue (int): Number of prefetch queue.
kwargs (dict): Other arguments for dataloader.
"""
def __init__(self, num_prefetch_queue, **kwargs):
self.num_prefetch_queue = num_prefetch_queue
super(PrefetchDataLoader, self).__init__(**kwargs)
def __iter__(self):
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
class CPUPrefetcher():
"""CPU prefetcher.
Args:
loader: Dataloader.
"""
def __init__(self, loader):
self.ori_loader = loader
self.loader = iter(loader)
def next(self):
try:
return next(self.loader)
except StopIteration:
return None
def reset(self):
self.loader = iter(self.ori_loader)
class CUDAPrefetcher():
"""CUDA prefetcher.
Ref:
https://github.com/NVIDIA/apex/issues/304#
It may consums more GPU memory.
Args:
loader: Dataloader.
opt (dict): Options.
"""
def __init__(self, loader, opt):
self.ori_loader = loader
self.loader = iter(loader)
self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.preload()
def preload(self):
try:
self.batch = next(self.loader) # self.batch is a dict
except StopIteration:
self.batch = None
return None
# put tensors to gpu
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
self.preload()
return batch
def reset(self):
self.loader = iter(self.ori_loader)
self.preload()

165
basicsr/data/transforms.py Normal file
View File

@@ -0,0 +1,165 @@
import cv2
import random
def mod_crop(img, scale):
"""Mod crop images, used during testing.
Args:
img (ndarray): Input image.
scale (int): Scale factor.
Returns:
ndarray: Result image.
"""
img = img.copy()
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
img = img[:h - h_remainder, :w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
"""Paired random crop.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
h_lq, w_lq, _ = img_lqs[0].shape
h_gt, w_gt, _ = img_gts[0].shape
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
We use vertical flip and transpose for rotation implementation.
All the images in the list use the same augmentation.
Args:
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
is an ndarray, it will be transformed to a list.
hflip (bool): Horizontal flip. Default: True.
rotation (bool): Ratotation. Default: True.
flows (list[ndarray]: Flows to be augmented. If the input is an
ndarray, it will be transformed to a list.
Dimension is (h, w, 2). Default: None.
return_status (bool): Return the status of flip and rotation.
Default: False.
Returns:
list[ndarray] | ndarray: Augmented images and flows. If returned
results only have one element, just return ndarray.
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip: # horizontal
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip: # vertical
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
if return_status:
return imgs, (hflip, vflip, rot90)
else:
return imgs
def img_rotate(img, angle, center=None, scale=1.0):
"""Rotate image.
Args:
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees. Positive values mean
counter-clockwise rotation.
center (tuple[int]): Rotation center. If the center is None,
initialize it as the center of the image. Default: None.
scale (float): Isotropic scale factor. Default: 1.0.
"""
(h, w) = img.shape[:2]
if center is None:
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
rotated_img = cv2.warpAffine(img, matrix, (w, h))
return rotated_img