Add codeformer and update license
This commit is contained in:
29
basicsr/utils/__init__.py
Normal file
29
basicsr/utils/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from .file_client import FileClient
|
||||
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
|
||||
from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
|
||||
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
|
||||
|
||||
__all__ = [
|
||||
# file_client.py
|
||||
'FileClient',
|
||||
# img_util.py
|
||||
'img2tensor',
|
||||
'tensor2img',
|
||||
'imfrombytes',
|
||||
'imwrite',
|
||||
'crop_border',
|
||||
# logger.py
|
||||
'MessageLogger',
|
||||
'init_tb_logger',
|
||||
'init_wandb_logger',
|
||||
'get_root_logger',
|
||||
'get_env_info',
|
||||
# misc.py
|
||||
'set_random_seed',
|
||||
'get_time_str',
|
||||
'mkdir_and_rename',
|
||||
'make_exp_dirs',
|
||||
'scandir',
|
||||
'check_resume',
|
||||
'sizeof_fmt'
|
||||
]
|
||||
82
basicsr/utils/dist_util.py
Normal file
82
basicsr/utils/dist_util.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
|
||||
import functools
|
||||
import os
|
||||
import subprocess
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize distributed training for the given launcher.

    Args:
        launcher (str): Either 'pytorch' or 'slurm'.
        backend (str): torch.distributed backend. Default: 'nccl'.

    Raises:
        ValueError: If ``launcher`` is not a supported launcher type.
    """
    # Distributed workers must use the 'spawn' start method; set it only if
    # no start method has been chosen yet.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
|
||||
|
||||
|
||||
def _init_dist_pytorch(backend, **kwargs):
    """Initialize distributed training launched with `torch.distributed.launch`.

    Reads the global rank from the ``RANK`` environment variable and pins the
    process to GPU ``rank % num_gpus`` before joining the process group.
    """
    global_rank = int(os.environ['RANK'])
    gpu_count = torch.cuda.device_count()
    torch.cuda.set_device(global_rank % gpu_count)
    dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
|
||||
def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    # Global rank of this process within the slurm job.
    proc_id = int(os.environ['SLURM_PROCID'])
    # Total number of processes (== world size).
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    # Round-robin GPU assignment on each node.
    torch.cuda.set_device(proc_id % num_gpus)
    # Resolve the first hostname in the slurm node list as the master address.
    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    # These variables must be set before init_process_group, which reads them
    # via the default 'env://' init method.
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
|
||||
|
||||
|
||||
def get_dist_info():
    """Return ``(rank, world_size)`` for the current process.

    Falls back to ``(0, 1)`` when torch.distributed is unavailable or has not
    been initialized, i.e. single-process execution.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    if not distributed_ready:
        return 0, 1
    return dist.get_rank(), dist.get_world_size()
|
||||
|
||||
|
||||
def master_only(func):
    """Decorator that runs ``func`` only on the master process (rank 0).

    On non-master ranks the wrapped function is a no-op returning ``None``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        current_rank, _ = get_dist_info()
        if current_rank != 0:
            return None
        return func(*args, **kwargs)

    return wrapper
|
||||
95
basicsr/utils/download_util.py
Normal file
95
basicsr/utils/download_util.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import math
|
||||
import os
|
||||
import requests
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
from tqdm import tqdm
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .misc import sizeof_fmt
|
||||
|
||||
|
||||
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.
    """

    session = requests.Session()
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id}

    response = session.get(URL, params=params, stream=True)
    token = get_confirm_token(response)
    if token:
        # Large files require a confirmation token before the real download.
        params['confirm'] = token
        response = session.get(URL, params=params, stream=True)

    # Probe the total size with a ranged request so the progress bar has an
    # accurate total; some responses omit Content-Range, in which case the
    # size is unknown.  (Removed a leftover debug print of the response.)
    response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
    if 'Content-Range' in response_file_size.headers:
        file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
    else:
        file_size = None

    save_response_content(response, save_path, file_size)
|
||||
|
||||
|
||||
def get_confirm_token(response):
    """Return the google-drive ``download_warning`` cookie value, or None."""
    warnings = (value for key, value in response.cookies.items() if key.startswith('download_warning'))
    return next(warnings, None)
|
||||
|
||||
|
||||
def save_response_content(response, destination, file_size=None, chunk_size=32768):
    """Stream a ``requests`` response body to ``destination`` on disk.

    Args:
        response: A ``requests`` streaming response object.
        destination (str): Output file path.
        file_size (int | None): Total size in bytes, used only to size the
            progress bar. If None, no progress bar is shown.
        chunk_size (int): Streaming chunk size in bytes. Default: 32768.
    """
    if file_size is not None:
        pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)
    else:
        pbar = None

    with open(destination, 'wb') as f:
        downloaded_size = 0
        for chunk in response.iter_content(chunk_size):
            # Use the actual chunk length: the final chunk is usually shorter
            # than ``chunk_size``, so adding ``chunk_size`` would overcount
            # the downloaded bytes shown in the progress description.
            downloaded_size += len(chunk)
            if pbar is not None:
                pbar.update(1)
                pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
        if pbar is not None:
            pbar.close()
|
||||
|
||||
|
||||
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file form http url, will download models if necessary.

    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:
        # Fall back to the pytorch hub checkpoint directory.
        model_dir = os.path.join(get_dir(), 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)

    filename = file_name if file_name is not None else os.path.basename(urlparse(url).path)
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    # Download only when the file is not already cached.
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
||||
167
basicsr/utils/file_client.py
Normal file
167
basicsr/utils/file_client.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
    as texts.
    """

    @abstractmethod
    def get(self, filepath):
        """Read ``filepath`` and return its content as a byte stream."""
        pass

    @abstractmethod
    def get_text(self, filepath):
        """Read ``filepath`` and return its content as text."""
        pass
|
||||
|
||||
|
||||
class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError('Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
        # mc.pyvector acts as a handle pointing at an in-memory cache buffer.
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        """Fetch the bytes stored under ``filepath`` from memcached."""
        import mc
        cache_key = str(filepath)
        self._client.Get(cache_key, self._mc_buffer)
        return mc.ConvertBuffer(self._mc_buffer)

    def get_text(self, filepath):
        raise NotImplementedError
|
||||
|
||||
|
||||
class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    def get(self, filepath):
        """Return the content of ``filepath`` as bytes."""
        with open(str(filepath), 'rb') as f:
            return f.read()

    def get_text(self, filepath):
        """Return the content of ``filepath`` as text."""
        with open(str(filepath), 'r') as f:
            return f.read()
|
||||
|
||||
|
||||
class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_paths (str | list[str]): Lmdb database paths.
        client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_paths (list): Lmdb database path.
        _client (list): A list of several lmdb envs.
    """

    def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        if isinstance(client_keys, str):
            client_keys = [client_keys]

        if isinstance(db_paths, list):
            self.db_paths = [str(v) for v in db_paths]
        elif isinstance(db_paths, str):
            self.db_paths = [str(db_paths)]
        else:
            # Fail early with a clear message instead of an AttributeError on
            # the otherwise-undefined ``self.db_paths`` just below.
            raise TypeError(f'db_paths should be str or list of str, but got {type(db_paths)}')
        assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
                                                        f'but received {len(client_keys)} and {len(self.db_paths)}.')

        # One lmdb environment per client key.
        self._client = {}
        for client, path in zip(client_keys, self.db_paths):
            self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)

    def get(self, filepath, client_key):
        """Get values according to the filepath from one lmdb named client_key.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
            client_key (str): Used for distinguishing different lmdb envs.
        """
        filepath = str(filepath)
        assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
        client = self._client[client_key]
        with client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath):
        raise NotImplementedError
|
||||
|
||||
|
||||
class FileClient(object):
    """A general file client to access files in different backend.

    The client loads a file or text in a specified backend from its path
    and return it as a binary file. it can also register other backend
    accessor with a given name and backend class.

    Attributes:
        backend (str): The storage backend type. Options are "disk",
            "memcached" and "lmdb".
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    # Registry mapping backend names to their implementing classes.
    _backends = {
        'disk': HardDiskBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
    }

    def __init__(self, backend='disk', **kwargs):
        if backend not in self._backends:
            raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
                             f' are {list(self._backends.keys())}')
        self.backend = backend
        self.client = self._backends[backend](**kwargs)

    def get(self, filepath, client_key='default'):
        """Fetch bytes for ``filepath``; ``client_key`` selects the lmdb env."""
        # client_key is meaningful only for the lmdb backend, where separate
        # file clients map to separate lmdb environments.
        if self.backend != 'lmdb':
            return self.client.get(filepath)
        return self.client.get(filepath, client_key)

    def get_text(self, filepath):
        """Fetch text content for ``filepath`` via the active backend."""
        return self.client.get_text(filepath)
|
||||
170
basicsr/utils/img_util.py
Normal file
170
basicsr/utils/img_util.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
|
||||
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _convert(img):
        # Closure over bgr2rgb/float32 instead of threading them as params.
        if bgr2rgb and img.shape[2] == 3:
            # cv2.cvtColor cannot handle float64 images; cast down first.
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(img.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if isinstance(imgs, list):
        return [_convert(img) for img in imgs]
    return _convert(imgs)
|
||||
|
||||
|
||||
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
        shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    # Normalize input to a list so single tensors and lists share one path.
    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # squeeze(0) drops a leading batch dim of size 1 (so a (1,C,H,W)
        # tensor is handled as 3D below); clamp then rescale to [0, 1].
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            # Mini-batch: tile the batch into a roughly square grid image.
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            # CHW -> HWC for image output.
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    # Unwrap a single-image result for convenience.
    if len(result) == 1:
        result = result[0]
    return result
|
||||
|
||||
|
||||
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """This implementation is slightly faster than tensor2img.
    It now only supports torch tensor with shape (1, c, h, w).

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.
    """
    lo, hi = min_max
    # Drop the batch dim, clamp, and go CHW -> HWC in one chain.
    img = tensor.squeeze(0).detach().clamp_(lo, hi).permute(1, 2, 0)
    img = (img - lo) / (hi - lo) * 255
    img = img.type(torch.uint8).cpu().numpy()
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) if rgb2bgr else img
|
||||
|
||||
|
||||
def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32., If True, will also norm
            to [0, 1]. Default: False.

    Returns:
        ndarray: Loaded image array.
    """
    flag_map = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    raw = np.frombuffer(content, np.uint8)
    img = cv2.imdecode(raw, flag_map[flag])
    if float32:
        # Normalize 8-bit values to [0, 1] floats.
        img = img.astype(np.float32) / 255.
    return img
|
||||
|
||||
|
||||
def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        parent_dir = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(parent_dir, exist_ok=True)
    return cv2.imwrite(file_path, img, params)
|
||||
|
||||
|
||||
def crop_border(imgs, crop_border):
    """Crop borders of images.

    Args:
        imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
        crop_border (int): Crop border for each end of height and weight.

    Returns:
        list[ndarray]: Cropped images.
    """
    if crop_border == 0:
        return imgs
    trim = slice(crop_border, -crop_border)
    if isinstance(imgs, list):
        return [img[trim, trim, ...] for img in imgs]
    return imgs[trim, trim, ...]
|
||||
196
basicsr/utils/lmdb_util.py
Normal file
196
basicsr/utils/lmdb_util.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import cv2
|
||||
import lmdb
|
||||
import sys
|
||||
from multiprocessing import Pool
|
||||
from os import path as osp
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def make_lmdb_from_imgs(data_path,
                        lmdb_path,
                        img_path_list,
                        keys,
                        batch=5000,
                        compress_level=1,
                        multiprocessing_read=False,
                        n_thread=40,
                        map_size=None):
    """Make lmdb from images.

    Contents of lmdb. The file structure is:
    example.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records 1)image name (with extension),
    2)image shape, and 3)compression level, separated by a white space.

    For example, the meta information could be:
    `000_00000000.png (720,1280,3) 1`, which means:
    1) image name (with extension): 000_00000000.png;
    2) image shape: (720,1280,3);
    3) compression level: 1

    We use the image name without extension as the lmdb key.

    If `multiprocessing_read` is True, it will read all the images to memory
    using multiprocessing. Thus, your server needs to have enough memory.

    Args:
        data_path (str): Data path for reading images.
        lmdb_path (str): Lmdb save path.
        img_path_list (str): Image path list.
        keys (str): Used for lmdb keys.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
        multiprocessing_read (bool): Whether use multiprocessing to read all
            the images to memory. Default: False.
        n_thread (int): For multiprocessing.
        map_size (int | None): Map size for lmdb env. If None, use the
            estimated size from images. Default: None
    """

    assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
                                             f'but got {len(img_path_list)} and {len(keys)}')
    print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
    # NOTE(review): 'Totoal' is a typo in this log message; left unchanged here.
    print(f'Totoal images: {len(img_path_list)}')
    if not lmdb_path.endswith('.lmdb'):
        raise ValueError("lmdb_path must end with '.lmdb'.")
    if osp.exists(lmdb_path):
        print(f'Folder {lmdb_path} already exists. Exit.')
        sys.exit(1)

    if multiprocessing_read:
        # read all the images to memory (multiprocessing)
        dataset = {}  # use dict to keep the order for multiprocessing
        shapes = {}
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = tqdm(total=len(img_path_list), unit='image')

        def callback(arg):
            """get the image data and update pbar."""
            # Runs in the parent process; mutates the enclosing dataset/shapes
            # dicts as each worker finishes.
            key, dataset[key], shapes[key] = arg
            pbar.update(1)
            pbar.set_description(f'Read {key}')

        pool = Pool(n_thread)
        for path, key in zip(img_path_list, keys):
            pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
        pool.close()
        pool.join()
        pbar.close()
        print(f'Finish reading {len(img_path_list)} images.')

    # create lmdb environment
    if map_size is None:
        # obtain data size for one image, then assume all images are roughly
        # that size and pad the estimate by 10x.
        img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
        _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
        data_size_per_img = img_byte.nbytes
        print('Data size per image is: ', data_size_per_img)
        data_size = data_size_per_img * len(img_path_list)
        map_size = data_size * 10

    env = lmdb.open(lmdb_path, map_size=map_size)

    # write data to lmdb
    pbar = tqdm(total=len(img_path_list), unit='chunk')
    txn = env.begin(write=True)
    txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
    for idx, (path, key) in enumerate(zip(img_path_list, keys)):
        pbar.update(1)
        pbar.set_description(f'Write {key}')
        key_byte = key.encode('ascii')
        if multiprocessing_read:
            img_byte = dataset[key]
            h, w, c = shapes[key]
        else:
            _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
            h, w, c = img_shape

        txn.put(key_byte, img_byte)
        # write meta information
        txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
        # Commit every `batch` writes and open a fresh transaction to bound
        # memory use on large datasets.
        if idx % batch == 0:
            txn.commit()
            txn = env.begin(write=True)
    pbar.close()
    txn.commit()
    env.close()
    txt_file.close()
    print('\nFinish writing lmdb.')
|
||||
|
||||
|
||||
def read_img_worker(path, key, compress_level):
    """Read image worker.

    Args:
        path (str): Image path.
        key (str): Image key.
        compress_level (int): Compress level when encoding images.

    Returns:
        str: Image key.
        byte: Image byte.
        tuple[int]: Image shape.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img.ndim == 2:
        # Grayscale images carry no channel axis; report a single channel.
        height, width = img.shape
        channels = 1
    else:
        height, width, channels = img.shape
    encode_params = [cv2.IMWRITE_PNG_COMPRESSION, compress_level]
    _, img_byte = cv2.imencode('.png', img, encode_params)
    return (key, img_byte, (height, width, channels))
|
||||
|
||||
|
||||
class LmdbMaker():
    """LMDB Maker.

    Args:
        lmdb_path (str): Lmdb save path.
        map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
        batch (int): After processing batch images, lmdb commits.
            Default: 5000.
        compress_level (int): Compress level when encoding images. Default: 1.
    """

    def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
        if not lmdb_path.endswith('.lmdb'):
            raise ValueError("lmdb_path must end with '.lmdb'.")
        if osp.exists(lmdb_path):
            print(f'Folder {lmdb_path} already exists. Exit.')
            sys.exit(1)

        self.lmdb_path = lmdb_path
        self.batch = batch
        self.compress_level = compress_level
        self.env = lmdb.open(lmdb_path, map_size=map_size)
        self.txn = self.env.begin(write=True)
        self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
        self.counter = 0

    def put(self, img_byte, key, img_shape):
        """Store one encoded image and append its meta line; commits every
        ``batch`` puts to bound transaction size."""
        self.counter += 1
        self.txn.put(key.encode('ascii'), img_byte)
        # write meta information
        height, width, channels = img_shape
        self.txt_file.write(f'{key}.png ({height},{width},{channels}) {self.compress_level}\n')
        if self.counter % self.batch == 0:
            self.txn.commit()
            self.txn = self.env.begin(write=True)

    def close(self):
        """Flush the pending transaction and close all resources."""
        self.txn.commit()
        self.env.close()
        self.txt_file.close()
|
||||
169
basicsr/utils/logger.py
Normal file
169
basicsr/utils/logger.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
from .dist_util import get_dist_info, master_only
|
||||
|
||||
initialized_logger = {}
|
||||
|
||||
|
||||
class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' (str) for logger interval.
            train (dict): Contains 'total_iter' (int) for total iters.
            use_tb_logger (bool): Use tensorboard logger.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = opt['logger']['use_tb_logger']
        self.tb_logger = tb_logger
        # Wall-clock reference for the ETA estimate below.
        self.start_time = time.time()
        self.logger = get_root_logger()

    @master_only
    def __call__(self, log_vars):
        """Format logging message.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.
                time (float): Iter time.
                data_time (float): Data time for each iter.

        NOTE(review): ``log_vars`` is consumed (``pop``) by this call, so
        callers should not reuse the dict afterwards.
        """
        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
        for v in lrs:
            message += f'{v:.3e},'
        message += ')] '

        # time and estimated time
        if 'time' in log_vars.keys():
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            # Average seconds per iter since start, extrapolated to the
            # remaining iterations for an ETA string.
            total_time = time.time() - self.start_time
            time_sec_avg = total_time / (current_iter - self.start_iter + 1)
            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            message += f'[eta: {eta_str}, '
            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '

        # other items, especially losses
        for k, v in log_vars.items():
            message += f'{k}: {v:.4e} '
            # tensorboard logger
            # NOTE(review): assumes tb_logger is not None whenever
            # use_tb_logger is True — TODO confirm at call sites.
            if self.use_tb_logger:
                if k.startswith('l_'):
                    # Loss terms are grouped under a 'losses/' prefix.
                    self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
                else:
                    self.tb_logger.add_scalar(k, v, current_iter)
        self.logger.info(message)
|
||||
|
||||
|
||||
@master_only
def init_tb_logger(log_dir):
    """Create a tensorboard ``SummaryWriter`` for ``log_dir`` (rank 0 only)."""
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(log_dir=log_dir)
|
||||
|
||||
|
||||
@master_only
def init_wandb_logger(opt):
    """We now only use wandb to sync tensorboard log."""
    import wandb
    logger = logging.getLogger('basicsr')

    wandb_cfg = opt['logger']['wandb']
    project = wandb_cfg['project']
    resume_id = wandb_cfg.get('resume_id')
    if resume_id:
        # Reattach to a previous run instead of starting a fresh one.
        wandb_id = resume_id
        resume = 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')
    else:
        wandb_id = wandb.util.generate_id()
        resume = 'never'

    wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)

    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
|
||||
|
||||
|
||||
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    # Do not bubble records up to the Python root logger (avoids duplicates).
    logger.propagate = False
    rank, _ = get_dist_info()
    if rank != 0:
        # Non-master ranks stay quiet except for errors.
        logger.setLevel('ERROR')
    elif log_file is not None:
        logger.setLevel(log_level)
        # add file handler
        # file_handler = logging.FileHandler(log_file, 'w')
        # Append mode so a resumed run keeps the previous log file intact.
        file_handler = logging.FileHandler(log_file, 'a')  # Shangchen: keep the previous log
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)
    initialized_logger[logger_name] = True
    return logger
|
||||
|
||||
|
||||
def get_env_info():
    """Get environment information.

    Currently, only log the software version.

    Returns:
        str: A banner plus BasicSR / PyTorch / TorchVision version strings.
    """
    # Imported here to keep module import light.
    import torch
    import torchvision

    from basicsr.version import __version__
    # NOTE(review): the exact spacing of this ASCII-art banner could not be
    # recovered from the mangled source; it is purely cosmetic.
    msg = r"""
                ____                _       _____  ____
               / __ ) ____ _ _____ (_)_____/ ___/ / __ \
              / __  |/ __ `// ___// // ___/\__ \ / /_/ /
             / /_/ // /_/ /(__  )/ // /__ ___/ // _, _/
            /_____/ \__,_//____//_/ \___//____//_/ |_|
     ______                   __   __                 __      __
    / ____/____  ____  ____/ /  / /   __  __ _____ / /__    / /
   / / __ / __ \ / __ \ / __ /  / /   / / / // ___// //_/   / /
  / /_/ // /_/ // /_/ // /_/ /  / /___/ /_/ // /__ / /<    /_/
  \____/ \____/ \____/ \____/  /_____/\____/ \___//_/|_|  (_)
    """
    msg += ('\nVersion Information: '
            f'\n\tBasicSR: {__version__}'
            f'\n\tPyTorch: {torch.__version__}'
            f'\n\tTorchVision: {torchvision.__version__}')
    return msg
|
||||
347
basicsr/utils/matlab_functions.py
Normal file
347
basicsr/utils/matlab_functions.py
Normal file
@@ -0,0 +1,347 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def cubic(x):
    """Bicubic interpolation kernel used by ``calculate_weights_indices``.

    Piecewise cubic with support on |x| < 2: equals 1 at x == 0 and 0 for
    |x| >= 2, matching MATLAB's bicubic weight function.

    Args:
        x (Tensor): Distances from the interpolation center.

    Returns:
        Tensor: Kernel weights, same shape as ``x``.
    """
    ax = torch.abs(x)
    ax2 = ax * ax
    ax3 = ax2 * ax
    # Inner segment: |x| <= 1.
    near = (1.5 * ax3 - 2.5 * ax2 + 1) * (ax <= 1).type_as(ax)
    # Outer segment: 1 < |x| <= 2.
    far = (-0.5 * ax3 + 2.5 * ax2 - 4 * ax + 2) * ((ax > 1) * (ax <= 2)).type_as(ax)
    return near + far
|
||||
|
||||
|
||||
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Calculate weights and indices, used for imresize function.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel (str): Kernel name. Unused here — the cubic kernel is
            hard-wired via ``cubic()``.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.

    Returns:
        tuple: ``(weights, indices, sym_len_s, sym_len_e)`` — per-output-pixel
        weight rows, the corresponding (shifted) input indices, and the
        symmetric padding needed at the start/end of the input.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialias
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel (rescaled when antialiasing a downsample)
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Shift indices to be valid into the symmetrically padded input.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
|
||||
|
||||
|
||||
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
    """imresize function same as MATLAB.

    It now only supports bicubic.
    The same scale applies for both height and width.

    The resize is separable: first along H, then along W, each pass using
    weights/indices from ``calculate_weights_indices`` on a symmetrically
    (reflect) padded buffer.

    Args:
        img (Tensor | Numpy array):
            Tensor: Input image with shape (c, h, w), [0, 1] range.
            Numpy: Input image with shape (h, w, c), [0, 1] range.
        scale (float): Scale factor. The same scale applies for both height
            and width.
        antialiasing (bool): Whether to apply anti-aliasing when downsampling.
            Default: True.

    Returns:
        Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
        (A numpy (h, w, c) array is returned if the input was numpy.)
    """
    if type(img).__module__ == np.__name__:  # numpy type
        numpy_type = True
        # (h, w, c) -> (c, h, w) for the torch-based computation below.
        img = torch.from_numpy(img.transpose(2, 0, 1)).float()
    else:
        numpy_type = False

    in_c, in_h, in_w = img.size()
    out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
    kernel_width = 4
    kernel = 'cubic'

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
                                                                             antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
                                                                             antialiasing)
    # process H dimension
    # symmetric copying (reflect-pad top and bottom)
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(img)

    sym_patch = img[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        # All indices in row i are consecutive, so only the first is needed.
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying (reflect-pad left and right)
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    if numpy_type:
        # Back to (h, w, c) numpy layout to mirror the input convention.
        out_2 = out_2.numpy().transpose(1, 2, 0)
    return out_2
|
||||
|
||||
|
||||
def rgb2ycbcr(img, y_only=False):
    """Convert a RGB image to YCbCr image.

    This function produces the same results as Matlab's `rgb2ycbcr` function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
        and range as input image.
    """
    img_type = img.dtype
    # Normalize to float32 [0, 1] so the BT.601 coefficients apply directly.
    img = _convert_input_type_range(img)
    if y_only:
        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
    else:
        # Columns are the Y, Cb, Cr coefficient vectors for RGB input.
        out_img = np.matmul(
            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
    # Restore the caller's original dtype/range.
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
|
||||
|
||||
|
||||
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The bgr version of rgb2ycbcr.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
        and range as input image.
    """
    img_type = img.dtype
    # Normalize to float32 [0, 1] so the BT.601 coefficients apply directly.
    img = _convert_input_type_range(img)
    if y_only:
        # Same Y coefficients as rgb2ycbcr, reordered for B, G, R channels.
        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
    else:
        out_img = np.matmul(
            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
    # Restore the caller's original dtype/range.
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
|
||||
|
||||
|
||||
def ycbcr2rgb(img):
    """Convert a YCbCr image to RGB image.

    This function produces the same results as Matlab's ycbcr2rgb function.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted RGB image. The output image has the same type
        and range as input image.
    """
    img_type = img.dtype
    # Scale to [0, 255] so the inverse BT.601 matrix (which expects the
    # full-range YCbCr encoding) applies directly.
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                              [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]  # noqa: E126
    # Restore the caller's original dtype/range.
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
|
||||
|
||||
|
||||
def ycbcr2bgr(img):
    """Convert a YCbCr image to BGR image.

    The bgr version of ycbcr2rgb.
    It implements the ITU-R BT.601 conversion for standard-definition
    television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        ndarray: The converted BGR image. The output image has the same type
        and range as input image.
    """
    img_type = img.dtype
    # Scale to [0, 255]; same inverse matrix as ycbcr2rgb with the output
    # channel order flipped to B, G, R.
    img = _convert_input_type_range(img) * 255
    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
                              [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]  # noqa: E126
    # Restore the caller's original dtype/range.
    out_img = _convert_output_type_range(out_img, img_type)
    return out_img
|
||||
|
||||
|
||||
def _convert_input_type_range(img):
|
||||
"""Convert the type and range of the input image.
|
||||
|
||||
It converts the input image to np.float32 type and range of [0, 1].
|
||||
It is mainly used for pre-processing the input image in colorspace
|
||||
convertion functions such as rgb2ycbcr and ycbcr2rgb.
|
||||
|
||||
Args:
|
||||
img (ndarray): The input image. It accepts:
|
||||
1. np.uint8 type with range [0, 255];
|
||||
2. np.float32 type with range [0, 1].
|
||||
|
||||
Returns:
|
||||
(ndarray): The converted image with type of np.float32 and range of
|
||||
[0, 1].
|
||||
"""
|
||||
img_type = img.dtype
|
||||
img = img.astype(np.float32)
|
||||
if img_type == np.float32:
|
||||
pass
|
||||
elif img_type == np.uint8:
|
||||
img /= 255.
|
||||
else:
|
||||
raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
|
||||
return img
|
||||
|
||||
|
||||
def _convert_output_type_range(img, dst_type):
|
||||
"""Convert the type and range of the image according to dst_type.
|
||||
|
||||
It converts the image to desired type and range. If `dst_type` is np.uint8,
|
||||
images will be converted to np.uint8 type with range [0, 255]. If
|
||||
`dst_type` is np.float32, it converts the image to np.float32 type with
|
||||
range [0, 1].
|
||||
It is mainly used for post-processing images in colorspace convertion
|
||||
functions such as rgb2ycbcr and ycbcr2rgb.
|
||||
|
||||
Args:
|
||||
img (ndarray): The image to be converted with np.float32 type and
|
||||
range [0, 255].
|
||||
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
|
||||
converts the image to np.uint8 type with range [0, 255]. If
|
||||
dst_type is np.float32, it converts the image to np.float32 type
|
||||
with range [0, 1].
|
||||
|
||||
Returns:
|
||||
(ndarray): The converted image with desired type and range.
|
||||
"""
|
||||
if dst_type not in (np.uint8, np.float32):
|
||||
raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
|
||||
if dst_type == np.uint8:
|
||||
img = img.round()
|
||||
else:
|
||||
img /= 255.
|
||||
return img.astype(dst_type)
|
||||
134
basicsr/utils/misc.py
Normal file
134
basicsr/utils/misc.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import torch
|
||||
from os import path as osp
|
||||
|
||||
from .dist_util import master_only
|
||||
from .logger import get_root_logger
|
||||
|
||||
|
||||
def set_random_seed(seed):
    """Seed every RNG the framework relies on.

    Covers python's ``random``, numpy, and torch (CPU plus current and all
    CUDA devices) so runs are reproducible.

    Args:
        seed (int): The seed value.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
|
||||
|
||||
|
||||
def get_time_str():
    """Return the current local time as a ``YYYYMMDD_HHMMSS`` string."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
|
||||
|
||||
|
||||
def mkdir_and_rename(path):
    """mkdirs. If path exists, rename it with timestamp and create a new one.

    An existing directory is moved aside to ``<path>_archived_<timestamp>``
    before the fresh directory is created, so old results are never clobbered.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        archived = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {archived}', flush=True)
        os.rename(path, archived)
    os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
@master_only
def make_exp_dirs(opt):
    """Make dirs for experiments.

    Creates (and archives, via ``mkdir_and_rename``) the experiment or
    results root, then creates every other configured path. Runs on the
    master process only.

    Args:
        opt (dict): Options with an ``is_train`` flag and a ``path`` dict.
    """
    # Work on a copy so popping roots does not mutate the caller's options.
    path_opt = opt['path'].copy()
    if opt['is_train']:
        mkdir_and_rename(path_opt.pop('experiments_root'))
    else:
        mkdir_and_rename(path_opt.pop('results_root'))
    for key, path in path_opt.items():
        # Skip non-directory entries (flags and checkpoint/resume paths).
        if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
            os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Hidden files (names starting with '.') are never yielded. Directories
    are only descended into when ``recursive`` is True.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files with relative paths.

    Raises:
        TypeError: If ``suffix`` is neither None, a str, nor a tuple of str.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                if full_path:
                    return_path = entry.path
                else:
                    return_path = osp.relpath(entry.path, root)

                if suffix is None or return_path.endswith(suffix):
                    yield return_path
            # Fix: recurse only into real directories. The previous code
            # recursed into *any* non-matching entry, so a hidden file
            # (e.g. '.DS_Store') raised NotADirectoryError in recursive mode.
            elif recursive and entry.is_dir():
                yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
||||
|
||||
|
||||
def check_resume(opt, resume_iter):
    """Check resume states and pretrain_network paths.

    When a resume state is configured, any explicitly set
    ``pretrain_network_*`` paths are overridden to point at the checkpoints
    saved for ``resume_iter`` (unless listed in
    ``opt['path']['ignore_resume_networks']``).

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    if opt['path']['resume_state']:
        # get all the networks
        networks = [key for key in opt.keys() if key.startswith('network_')]
        flag_pretrain = False
        for network in networks:
            if opt['path'].get(f'pretrain_{network}') is not None:
                flag_pretrain = True
        if flag_pretrain:
            logger.warning('pretrain_network path will be ignored during resuming.')
        # set pretrained model paths
        for network in networks:
            name = f'pretrain_{network}'
            basename = network.replace('network_', '')
            if opt['path'].get('ignore_resume_networks') is None or (basename
                                                                     not in opt['path']['ignore_resume_networks']):
                # Point the pretrain path at the checkpoint for this iteration.
                opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
                logger.info(f"Set {name} to {opt['path'][name]}")
|
||||
|
||||
|
||||
def sizeof_fmt(size, suffix='B'):
    """Get human-readable file size.

    Args:
        size (int): File size in bytes.
        suffix (str): Unit suffix. Default: 'B'.

    Return:
        str: Formatted file size, e.g. ``'1.5 KB'``.
    """
    # Walk up the binary (1024-based) unit prefixes until size fits.
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    return f'{size:3.1f} Y{suffix}'
|
||||
108
basicsr/utils/options.py
Normal file
108
basicsr/utils/options.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import yaml
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from os import path as osp
|
||||
from basicsr.utils.misc import get_time_str
|
||||
|
||||
def ordered_yaml():
    """Support OrderedDict for yaml.

    Registers an OrderedDict representer/constructor on the fastest
    available Loader/Dumper pair so mapping order in option files is
    preserved on load and dump.

    Returns:
        yaml Loader and Dumper.
    """
    try:
        # Prefer the C-accelerated implementations when libyaml is available.
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        # Dump OrderedDict as a plain mapping, preserving insertion order.
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        # Load every mapping as an OrderedDict.
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper
|
||||
|
||||
|
||||
def parse(opt_path, root_path, is_train=True):
    """Parse option file.

    Loads the YAML option file, names the run, and fills in derived dataset
    and path entries (experiment/result roots, model/state/log dirs).

    Args:
        opt_path (str): Option file path.
        root_path (str): Repository root; experiment/result dirs live under it.
        is_train (str): Indicate whether in training or not. Default: True.

    Returns:
        (dict): Options.
    """
    with open(opt_path, mode='r') as f:
        Loader, _ = ordered_yaml()
        opt = yaml.load(f, Loader=Loader)

    opt['is_train'] = is_train

    # opt['name'] = f"{get_time_str()}_{opt['name']}"
    if opt['path'].get('resume_state', None):  # Shangchen added
        # When resuming, reuse the original experiment name, which is the
        # third-from-last component of the resume-state path
        # (.../<name>/training_states/<iter>.state).
        resume_state_path = opt['path'].get('resume_state')
        opt['name'] = resume_state_path.split("/")[-3]
    else:
        # Fresh runs get a timestamp prefix so names never collide.
        opt['name'] = f"{get_time_str()}_{opt['name']}"

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for several datasets, e.g., test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths: expand '~' in checkpoint/state paths only
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt
|
||||
|
||||
|
||||
def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Nested dicts are rendered recursively inside ``key:[ ... ]`` brackets,
    two extra spaces of indent per level.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    pad = ' ' * (indent_level * 2)
    pieces = ['\n']
    for key, value in opt.items():
        if isinstance(value, dict):
            pieces.append(pad + key + ':[')
            pieces.append(dict2str(value, indent_level + 1))
            pieces.append(pad + ']\n')
        else:
            pieces.append(f'{pad}{key}: {value}\n')
    return ''.join(pieces)
|
||||
296
basicsr/utils/realesrgan_utils.py
Normal file
296
basicsr/utils/realesrgan_utils.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import cv2
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from torch.nn import functional as F
|
||||
|
||||
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class RealESRGANer():
|
||||
"""A helper class for upsampling images with RealESRGAN.
|
||||
|
||||
Args:
|
||||
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
||||
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
||||
model (nn.Module): The defined network. Default: None.
|
||||
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
|
||||
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
||||
0 denotes for do not use tile. Default: 0.
|
||||
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
||||
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
||||
half (float): Whether to use half precision during inference. Default: False.
|
||||
"""
|
||||
|
||||
    def __init__(self,
                 scale,
                 model_path,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        """Set up the upsampler: pick a device, load weights into ``model``.

        Args:
            scale (int): Network upsampling scale (usually 2 or 4).
            model_path (str): Path or https URL of the pretrained weights.
            model (nn.Module): The network to load the weights into.
            tile (int): Tile size for tiled inference; 0 disables tiling.
            tile_pad (int): Padding around each tile. Default: 10.
            pre_pad (int): Reflect padding applied before inference. Default: 10.
            half (bool): Run the model in fp16. Default: False.
            device (torch.device | None): Explicit device; overrides gpu_id.
            gpu_id (int | None): CUDA device index to use when device is None.
        """
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None  # set later by pre_process for scale 1/2
        self.half = half

        # initialize model
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer to use params_ema (exponential moving average weights)
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()
|
||||
|
||||
    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible

        Converts the HWC numpy image to a CHW float tensor on ``self.device``
        (stored as ``self.img``), applies reflect pre-padding, then pads H/W
        to a multiple of ``mod_scale`` for scale-1/2 networks.

        Args:
            img (ndarray): Input image, (h, w, c). Assumed BGR/RGB float or
                uint8-compatible values -- TODO confirm against `enhance`.
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad: reflect-pad right/bottom to reduce border artifacts
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
||||
|
||||
    def process(self):
        """Run the network on the whole pre-processed image ``self.img``."""
        # model inference
        self.output = self.model(self.img)
|
||||
|
||||
    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one images.

        Each tile is extracted with ``tile_pad`` extra context, upscaled, and
        the un-padded core of the result is written into ``self.output``.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    # NOTE(review): on the very first tile failure,
                    # `output_tile` is undefined (or stale on later tiles) and
                    # the code below still uses it -- consider re-raising.
                # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]
|
||||
|
||||
def post_process(self):
    """Strip the padding added during pre-processing from ``self.output``.

    Two crops may apply, in this order:
      1. the mod-scale padding (added so H/W are divisible by ``mod_scale``),
      2. the uniform pre-pad border around the whole image.
    Both pads were applied on the low-resolution input, so their sizes are
    multiplied by ``self.scale`` when cropped from the upscaled output.

    Returns:
        torch.Tensor: the cropped output tensor (also kept in ``self.output``).
    """
    # Crop the divisibility padding, if any was added.
    if self.mod_scale is not None:
        height, width = self.output.size()[2:4]
        keep_h = height - self.mod_pad_h * self.scale
        keep_w = width - self.mod_pad_w * self.scale
        self.output = self.output[:, :, :keep_h, :keep_w]
    # Crop the uniform pre-pad border, if any.
    if self.pre_pad != 0:
        height, width = self.output.size()[2:4]
        keep_h = height - self.pre_pad * self.scale
        keep_w = width - self.pre_pad * self.scale
        self.output = self.output[:, :, :keep_h, :keep_w]
    return self.output
|
||||
|
||||
@torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
    """Upscale a numpy image with the wrapped super-resolution model.

    Args:
        img (np.ndarray): Input image in OpenCV layout (BGR for color,
            BGRA when a 4th alpha channel is present, or single-channel
            gray). dtype may be uint8 or uint16; 16-bit input is detected
            by value range, not dtype.
        outscale (float | None): Final rescale factor relative to the input
            size. If it differs from the model's native ``self.scale``, the
            result is resized with Lanczos interpolation at the end.
        alpha_upsampler (str): ``'realesrgan'`` runs the alpha channel
            through the model too; any other value falls back to a plain
            ``cv2.resize`` for alpha.

    Returns:
        tuple: ``(output, img_mode)`` — the upscaled image as a uint8 or
        uint16 numpy array, and the detected mode ('L', 'RGB' or 'RGBA').
    """
    h_input, w_input = img.shape[0:2]
    # img: numpy
    img = img.astype(np.float32)
    # NOTE(review): values in (256, 65535] are treated as 16-bit; an 8-bit
    # image never exceeds 255, so 256 is a safe threshold.
    if np.max(img) > 256:  # 16-bit image
        max_range = 65535
        print('\tInput is a 16-bit image')
    else:
        max_range = 255
    # Normalize to [0, 1] before feeding the network.
    img = img / max_range
    if len(img.shape) == 2:  # gray image
        img_mode = 'L'
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:  # RGBA image with alpha channel
        img_mode = 'RGBA'
        # Split off alpha; the color planes are processed separately.
        alpha = img[:, :, 3]
        img = img[:, :, 0:3]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if alpha_upsampler == 'realesrgan':
            # The model expects 3 channels, so replicate alpha into RGB.
            alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
    else:
        img_mode = 'RGB'
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # ------------------- process image (without the alpha channel) ------------------- #
    with torch.no_grad():
        self.pre_process(img)
        # Tiled inference keeps GPU memory bounded for large images.
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img_t = self.post_process()
        output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        # CHW RGB tensor -> HWC BGR numpy (the [2, 1, 0] index flips channels).
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
        # Release the GPU-side result as soon as the numpy copy exists.
        del output_img_t
        torch.cuda.empty_cache()

    # ------------------- process the alpha channel if necessary ------------------- #
    if img_mode == 'RGBA':
        if alpha_upsampler == 'realesrgan':
            # Run the (3-channel replicated) alpha through the same pipeline.
            self.pre_process(alpha)
            if self.tile_size > 0:
                self.tile_process()
            else:
                self.process()
            output_alpha = self.post_process()
            output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
            # Collapse the replicated channels back to a single alpha plane.
            output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
        else:  # use the cv2 resize for alpha channel
            h, w = alpha.shape[0:2]
            output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

        # merge the alpha channel
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
        output_img[:, :, 3] = output_alpha

    # ------------------------------ return ------------------------------ #
    # Restore the original bit depth detected at the top.
    if max_range == 65535:  # 16-bit image
        output = (output_img * 65535.0).round().astype(np.uint16)
    else:
        output = (output_img * 255.0).round().astype(np.uint8)

    # Optional final resize when the caller wants a scale other than the
    # model's native one.
    if outscale is not None and outscale != float(self.scale):
        output = cv2.resize(
            output, (
                int(w_input * outscale),
                int(h_input * outscale),
            ), interpolation=cv2.INTER_LANCZOS4)

    return output, img_mode
|
||||
|
||||
|
||||
class PrefetchReader(threading.Thread):
    """Background thread that reads images ahead of consumption.

    Decoded images are pushed onto a bounded queue; iterating over this
    object pops them in order. A ``None`` sentinel marks the end of the
    list and terminates iteration.

    Args:
        img_list (list[str]): Paths of the images to read.
        num_prefetch_queue (int): Maximum number of decoded images held
            in the queue at once.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        # Decode every path, then signal completion with a None sentinel.
        for path in self.img_list:
            self.que.put(cv2.imread(path, cv2.IMREAD_UNCHANGED))
        self.que.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.que.get()
        if item is None:
            raise StopIteration
        return item
|
||||
|
||||
|
||||
class IOConsumer(threading.Thread):
    """Worker thread that saves images pulled from a queue.

    Each queue message is a dict with keys ``'output'`` (image array) and
    ``'save_path'`` (destination file); the literal string ``'quit'``
    stops the worker.

    Args:
        opt: Options object, stored for reference.
        que (queue.Queue): Queue of save requests.
        qid (int): Worker id, used in the completion message.
    """

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            # A bare 'quit' string shuts the worker down.
            if isinstance(msg, str) and msg == 'quit':
                break
            img = msg['output']
            dest = msg['save_path']
            cv2.imwrite(dest, img)
        print(f'IO worker {self.qid} is done.')
|
||||
82
basicsr/utils/registry.py
Normal file
82
basicsr/utils/registry.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
|
||||
|
||||
|
||||
class Registry():
    """A name -> object mapping that lets third-party users plug in their
    own modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    Register an object with a decorator:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    or with a direct call:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        # Refuse silent overwrites: each name may be registered only once.
        assert name not in self._obj_map, (f"An object named '{name}' was already registered "
                                           f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None):
        """Register *obj* under the name ``obj.__name__``; with no argument,
        act as a decorator. See the class docstring for usage.
        """
        if obj is not None:
            # Plain function-call form.
            self._do_register(obj.__name__, obj)
            return

        # Decorator form: return a wrapper that registers its target.
        def deco(func_or_class):
            self._do_register(func_or_class.__name__, func_or_class)
            return func_or_class

        return deco

    def get(self, name):
        """Look up *name*; raise KeyError with a helpful message if absent."""
        found = self._obj_map.get(name)
        if found is None:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return found

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        # Iterate (name, object) pairs, like dict.items().
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()
|
||||
|
||||
|
||||
# Global registries, one per extensible component type. Implementations
# register themselves here (typically via the @<REGISTRY>.register()
# decorator) and are later looked up by name.
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
|
||||
Reference in New Issue
Block a user