Increased GPU usage

This commit is contained in:
Felipe Daragon
2025-05-15 22:30:23 +01:00
parent 59b6233882
commit badbcc6edf
9 changed files with 239 additions and 223 deletions

View File

@@ -1,65 +1,64 @@
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import copy import copy
import os
from basicsr.utils import get_root_logger from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
# Select Device
def select_device(prefer_coreml=False):
    """Pick the torch device used by the BasicSR arch modules.

    Args:
        prefer_coreml: When true and the MPS backend is available, return
            the ``mps`` device ahead of CUDA. Despite the parameter name,
            the device returned is PyTorch's ``mps`` device.

    Returns:
        torch.device: ``mps``, ``cuda`` or ``cpu``, checked in that order.
    """
    # torch.backends.mps is not defined on every PyTorch build; look it up
    # defensively so such builds fall through to CUDA/CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if prefer_coreml and mps_backend is not None and mps_backend.is_available():
        print("BasicSR Archs: Using CoreML backend (MPS).")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("BasicSR Archs: Using CUDA backend.")
        return torch.device("cuda")
    print("BasicSR Archs: Using CPU backend.")
    return torch.device("cpu")
# Set device globally
DEVICE = select_device(prefer_coreml=True)
def normalize(in_channels): def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script @torch.jit.script
def swish(x): def swish(x):
return x*torch.sigmoid(x) return x * torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module): class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta): def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__() super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings self.codebook_size = codebook_size
self.emb_dim = emb_dim # dimension of embedding self.emb_dim = emb_dim
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 self.beta = beta
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
def forward(self, z): def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous() z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim) z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = (z_flattened ** 2).sum(dim=1, keepdim=True) + \
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ (self.embedding.weight ** 2).sum(1) - \
2 * torch.matmul(z_flattened, self.embedding.weight.t()) 2 * torch.matmul(z_flattened, self.embedding.weight.t())
mean_distance = torch.mean(d) mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
# [0-1], higher score, higher confidence min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encoding_scores = torch.exp(-min_encoding_scores/10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size, device=z.device)
min_encodings.scatter_(1, min_encoding_indices, 1) min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach() z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0) e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous() z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, { return z_q, loss, {
@@ -68,18 +67,15 @@ class VectorQuantizer(nn.Module):
"min_encoding_indices": min_encoding_indices, "min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores, "min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance "mean_distance": mean_distance
} }
def get_codebook_feat(self, indices, shape): def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1 indices = indices.view(-1, 1)
# shape: batch, height, width, channel min_encodings = torch.zeros(indices.shape[0], self.codebook_size, device=indices.device)
indices = indices.view(-1,1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1) min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight) z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape if shape is not None:
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q return z_q
@@ -324,112 +320,87 @@ class Generator(nn.Module):
return x return x
# Autoencoder with device transfer
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module): class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): attn_resolutions=[16], codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__() super().__init__()
logger = get_root_logger() logger = get_root_logger()
self.in_channels = 3 self.in_channels = 3
self.nf = nf self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size self.codebook_size = codebook_size
self.embed_dim = emb_dim self.embed_dim = emb_dim
self.ch_mult = ch_mult self.ch_mult = ch_mult
self.resolution = img_size self.resolution = img_size
self.attn_resolutions = attn_resolutions self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer self.quantizer_type = quantizer
self.encoder = Encoder( self.encoder = Encoder(
self.in_channels, self.in_channels, self.nf, self.embed_dim, self.ch_mult,
self.nf, res_blocks, self.resolution, self.attn_resolutions
self.embed_dim, ).to(DEVICE)
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if self.quantizer_type == "nearest": if self.quantizer_type == "nearest":
self.beta = beta #0.25 self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, beta).to(DEVICE)
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) else:
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer( self.quantize = GumbelQuantizer(
self.codebook_size, self.codebook_size, self.embed_dim, emb_dim,
self.embed_dim, gumbel_straight_through, gumbel_kl_weight
self.gumbel_num_hiddens, ).to(DEVICE)
self.straight_through,
self.kl_weight
)
self.generator = Generator( self.generator = Generator(
self.nf, self.nf, self.embed_dim, self.ch_mult, res_blocks,
self.embed_dim, self.resolution, self.attn_resolutions
self.ch_mult, ).to(DEVICE)
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if model_path is not None: if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu') chkpt = torch.load(model_path, map_location='cpu')
if 'params_ema' in chkpt: if 'params_ema' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) self.load_state_dict(chkpt['params_ema'])
logger.info(f'vqgan is loaded from: {model_path} [params_ema]') logger.info(f'Loaded VQGAN from: {model_path} [params_ema]')
elif 'params' in chkpt: elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) self.load_state_dict(chkpt['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]') logger.info(f'Loaded VQGAN from: {model_path} [params]')
else: else:
raise ValueError(f'Wrong params!') raise ValueError("Invalid model format!")
def forward(self, x): def forward(self, x):
x = x.to(DEVICE)
x = self.encoder(x) x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x) quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant) x = self.generator(quant)
return x, codebook_loss, quant_stats return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module): class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__() super().__init__()
layers = [
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True) nn.LeakyReLU(0.2, True)
] ]
nf_mult = 1
for n in range(1, n_layers):
prev = nf_mult
nf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * prev, ndf * nf_mult, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
layers += [ layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map nn.Conv2d(ndf * nf_mult, 1, 4, 1, 1)
self.main = nn.Sequential(*layers) ]
self.main = nn.Sequential(*layers).to(DEVICE)
if model_path is not None: if model_path:
chkpt = torch.load(model_path, map_location='cpu') chkpt = torch.load(model_path, map_location='cpu')
if 'params_d' in chkpt: if 'params_d' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) self.load_state_dict(chkpt['params_d'])
elif 'params' in chkpt: elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) self.load_state_dict(chkpt['params'])
else:
raise ValueError(f'Wrong params!')
def forward(self, x): def forward(self, x):
return self.main(x) return self.main(x.to(DEVICE))

View File

@@ -82,12 +82,9 @@ class CPUPrefetcher():
class CUDAPrefetcher(): class CUDAPrefetcher():
"""CUDA prefetcher. """CUDA (or MPS/CPU) prefetcher.
Ref: It may consume more GPU memory.
https://github.com/NVIDIA/apex/issues/304#
It may consums more GPU memory.
Args: Args:
loader: Dataloader. loader: Dataloader.
@@ -98,8 +95,18 @@ class CUDAPrefetcher():
self.ori_loader = loader self.ori_loader = loader
self.loader = iter(loader) self.loader = iter(loader)
self.opt = opt self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') # Cross-platform device detection
if opt['num_gpu'] != 0 and torch.cuda.is_available():
self.device = torch.device('cuda')
self.stream = torch.cuda.Stream()
elif torch.backends.mps.is_available():
self.device = torch.device('mps')
self.stream = None
else:
self.device = torch.device('cpu')
self.stream = None
self.preload() self.preload()
def preload(self): def preload(self):
@@ -108,18 +115,24 @@ class CUDAPrefetcher():
except StopIteration: except StopIteration:
self.batch = None self.batch = None
return None return None
# put tensors to gpu
with torch.cuda.stream(self.stream): if self.stream is not None:
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
else:
for k, v in self.batch.items(): for k, v in self.batch.items():
if torch.is_tensor(v): if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) self.batch[k] = self.batch[k].to(device=self.device)
def next(self): def next(self):
torch.cuda.current_stream().wait_stream(self.stream) if self.stream is not None:
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch batch = self.batch
self.preload() self.preload()
return batch return batch
def reset(self): def reset(self):
self.loader = iter(self.ori_loader) self.loader = iter(self.ori_loader)
self.preload() self.preload()

View File

@@ -2,22 +2,38 @@ import argparse
import datetime import datetime
import logging import logging
import math import math
import copy
import random import random
import time import time
import torch import torch
import platform
from os import path as osp from os import path as osp
import warnings
from basicsr.data import build_dataloader, build_dataset from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model from basicsr.models import build_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger, from basicsr.utils import (
init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed) MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed
)
from basicsr.utils.dist_util import get_dist_info, init_dist from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse from basicsr.utils.options import dict2str, parse
import warnings # ----------- DEVICE SELECTION ----------
def select_device(prefer_coreml=True):
    """Pick the torch device used by the training script.

    Args:
        prefer_coreml: When true, and the host is macOS with the MPS backend
            available, return the ``mps`` device ahead of CUDA. Despite the
            parameter name, the device returned is PyTorch's ``mps`` device.

    Returns:
        torch.device: ``mps``, ``cuda`` or ``cpu``, checked in that order.
    """
    # torch.backends.mps is not defined on every PyTorch build; look it up
    # defensively so such builds fall through to CUDA/CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    wants_mps = prefer_coreml and platform.system() == "Darwin"
    if wants_mps and mps_backend is not None and mps_backend.is_available():
        print("BasicSR: Using CoreML backend (MPS).")
        return torch.device("mps")
    if torch.cuda.is_available():
        print("BasicSR: Using CUDA backend.")
        return torch.device("cuda")
    print("BasicSR: Using CPU backend.")
    return torch.device("cpu")
DEVICE = select_device(prefer_coreml=True)
# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. # ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
@@ -30,9 +46,9 @@ def parse_options(root_path, is_train=True):
opt = parse(args.opt, root_path, is_train=is_train) opt = parse(args.opt, root_path, is_train=is_train)
# distributed settings # distributed settings
if args.launcher == 'none': if args.launcher == 'none' or DEVICE.type != 'cuda':
opt['dist'] = False opt['dist'] = False
print('Disable distributed.', flush=True) print('Distributed training disabled.', flush=True)
else: else:
opt['dist'] = True opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt: if args.launcher == 'slurm' and 'dist_params' in opt:
@@ -51,122 +67,96 @@ def parse_options(root_path, is_train=True):
return opt return opt
def init_loggers(opt): def init_loggers(opt):
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
logger.info(get_env_info()) logger.info(get_env_info())
logger.info(dict2str(opt)) logger.info(dict2str(opt))
# initialize wandb logger before tensorboard logger to allow proper sync:
if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None): if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None):
assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') assert opt['logger'].get('use_tb_logger') is True
init_wandb_logger(opt) init_wandb_logger(opt)
tb_logger = None tb_logger = None
if opt['logger'].get('use_tb_logger'): if opt['logger'].get('use_tb_logger'):
tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
return logger, tb_logger return logger, tb_logger
def create_train_val_dataloader(opt, logger): def create_train_val_dataloader(opt, logger):
# create train and val dataloaders
train_loader, val_loader = None, None train_loader, val_loader = None, None
for phase, dataset_opt in opt['datasets'].items(): for phase, dataset_opt in opt['datasets'].items():
if phase == 'train': if phase == 'train':
dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
train_set = build_dataset(dataset_opt) train_set = build_dataset(dataset_opt)
train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
train_loader = build_dataloader( train_loader = build_dataloader(train_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=train_sampler, seed=opt['manual_seed'])
train_set, num_iter_per_epoch = math.ceil(len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
dataset_opt,
num_gpu=opt['num_gpu'],
dist=opt['dist'],
sampler=train_sampler,
seed=opt['manual_seed'])
num_iter_per_epoch = math.ceil(
len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
total_iters = int(opt['train']['total_iter']) total_iters = int(opt['train']['total_iter'])
total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) total_epochs = math.ceil(total_iters / num_iter_per_epoch)
logger.info('Training statistics:' logger.info(f'Training stats:\n\tTrain images: {len(train_set)}\n\tEnlarge ratio: {dataset_enlarge_ratio}\n\tBatch/GPU: {dataset_opt["batch_size_per_gpu"]}\n\tGPUs: {opt["world_size"]}\n\tIters/epoch: {num_iter_per_epoch}\n\tTotal epochs: {total_epochs}, Iters: {total_iters}')
f'\n\tNumber of train images: {len(train_set)}'
f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
f'\n\tWorld size (gpu number): {opt["world_size"]}'
f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
elif phase == 'val': elif phase == 'val':
val_set = build_dataset(dataset_opt) val_set = build_dataset(dataset_opt)
val_loader = build_dataloader( val_loader = build_dataloader(val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) logger.info(f'Validation items in {dataset_opt["name"]}: {len(val_set)}')
logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
else: else:
raise ValueError(f'Dataset phase {phase} is not recognized.') raise ValueError(f'Dataset phase {phase} not recognized.')
return train_loader, train_sampler, val_loader, total_epochs, total_iters return train_loader, train_sampler, val_loader, total_epochs, total_iters
def train_pipeline(root_path): def train_pipeline(root_path):
# parse options, set distributed setting, set ramdom seed
opt = parse_options(root_path, is_train=True) opt = parse_options(root_path, is_train=True)
torch.backends.cudnn.benchmark = True if DEVICE.type == 'cuda':
# torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True
# load resume states if necessary
if opt['path'].get('resume_state'): if opt['path'].get('resume_state'):
device_id = torch.cuda.current_device() resume_state = torch.load(opt['path']['resume_state'], map_location=DEVICE)
resume_state = torch.load(
opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
else: else:
resume_state = None resume_state = None
# mkdir for experiments and logger
if resume_state is None: if resume_state is None:
make_exp_dirs(opt) make_exp_dirs(opt)
if opt['logger'].get('use_tb_logger') and opt['rank'] == 0: if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
mkdir_and_rename(osp.join('tb_logger', opt['name'])) mkdir_and_rename(osp.join('tb_logger', opt['name']))
# initialize loggers
logger, tb_logger = init_loggers(opt) logger, tb_logger = init_loggers(opt)
train_loader, train_sampler, val_loader, total_epochs, total_iters = create_train_val_dataloader(opt, logger)
# create train and validation dataloaders
result = create_train_val_dataloader(opt, logger)
train_loader, train_sampler, val_loader, total_epochs, total_iters = result
# create model if resume_state:
if resume_state: # resume training
check_resume(opt, resume_state['iter']) check_resume(opt, resume_state['iter'])
model = build_model(opt) model = build_model(opt).to(DEVICE)
model.resume_training(resume_state) # handle optimizers and schedulers model.resume_training(resume_state)
logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") logger.info(f"Resuming from epoch {resume_state['epoch']}, iter {resume_state['iter']}")
start_epoch = resume_state['epoch'] start_epoch = resume_state['epoch']
current_iter = resume_state['iter'] current_iter = resume_state['iter']
else: else:
model = build_model(opt) model = build_model(opt).to(DEVICE)
start_epoch = 0 start_epoch = 0
current_iter = 0 current_iter = 0
# create message logger (formatted outputs)
msg_logger = MessageLogger(opt, current_iter, tb_logger) msg_logger = MessageLogger(opt, current_iter, tb_logger)
# dataloader prefetcher
prefetch_mode = opt['datasets']['train'].get('prefetch_mode') prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
if prefetch_mode is None or prefetch_mode == 'cpu': if prefetch_mode is None or prefetch_mode == 'cpu' or DEVICE.type in ['cpu', 'mps']:
if prefetch_mode == 'cuda' and DEVICE.type == 'mps':
logger.warning("CUDA prefetch requested but MPS (CoreML) is in use. Falling back to CPU prefetch.")
prefetcher = CPUPrefetcher(train_loader) prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'cuda': elif prefetch_mode == 'cuda':
prefetcher = CUDAPrefetcher(train_loader, opt) if DEVICE.type != 'cuda':
logger.info(f'Use {prefetch_mode} prefetch dataloader') logger.warning("CUDA prefetch requested but CUDA unavailable. Using CPU prefetch.")
if opt['datasets']['train'].get('pin_memory') is not True: prefetcher = CPUPrefetcher(train_loader)
raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') else:
if opt['datasets']['train'].get('pin_memory') is not True:
raise ValueError('Set pin_memory=True for CUDAPrefetcher.')
prefetcher = CUDAPrefetcher(train_loader, opt)
logger.info(f'Using CUDA prefetcher')
else: else:
raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' "Supported ones are: None, 'cuda', 'cpu'.") raise ValueError(f"Invalid prefetch_mode: {prefetch_mode}. Supported: 'cpu', 'cuda', None")
# training logger.info(f'Start training at epoch {start_epoch}, iter {current_iter + 1}')
logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
data_time, iter_time = time.time(), time.time()
start_time = time.time() start_time = time.time()
data_time, iter_time = time.time(), time.time()
for epoch in range(start_epoch, total_epochs + 1): for epoch in range(start_epoch, total_epochs + 1):
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
@@ -175,17 +165,15 @@ def train_pipeline(root_path):
while train_data is not None: while train_data is not None:
data_time = time.time() - data_time data_time = time.time() - data_time
current_iter += 1 current_iter += 1
if current_iter > total_iters: if current_iter > total_iters:
break break
# update learning rate
model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
# training
model.feed_data(train_data) model.feed_data(train_data)
model.optimize_parameters(current_iter) model.optimize_parameters(current_iter)
iter_time = time.time() - iter_time iter_time = time.time() - iter_time
# log
if current_iter % opt['logger']['print_freq'] == 0: if current_iter % opt['logger']['print_freq'] == 0:
log_vars = {'epoch': epoch, 'iter': current_iter} log_vars = {'epoch': epoch, 'iter': current_iter}
log_vars.update({'lrs': model.get_current_learning_rate()}) log_vars.update({'lrs': model.get_current_learning_rate()})
@@ -193,33 +181,27 @@ def train_pipeline(root_path):
log_vars.update(model.get_current_log()) log_vars.update(model.get_current_log())
msg_logger(log_vars) msg_logger(log_vars)
# save models and training states
if current_iter % opt['logger']['save_checkpoint_freq'] == 0: if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
logger.info('Saving models and training states.') logger.info('Saving model and training state.')
model.save(epoch, current_iter) model.save(epoch, current_iter)
# validation if opt.get('val') and opt['datasets'].get('val') and (current_iter % opt['val']['val_freq'] == 0):
if opt.get('val') is not None and opt['datasets'].get('val') is not None \
and (current_iter % opt['val']['val_freq'] == 0):
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
data_time = time.time() data_time = time.time()
iter_time = time.time() iter_time = time.time()
train_data = prefetcher.next() train_data = prefetcher.next()
# end of iter
# end of epoch
consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
logger.info(f'End of training. Time consumed: {consumed_time}') logger.info(f'Training complete. Time: {consumed_time}')
logger.info('Save the latest model.') logger.info('Saving latest model.')
model.save(epoch=-1, current_iter=-1) # -1 stands for the latest model.save(epoch=-1, current_iter=-1)
if opt.get('val') is not None and opt['datasets'].get('val'):
if opt.get('val') and opt['datasets'].get('val'):
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
if tb_logger: if tb_logger:
tb_logger.close() tb_logger.close()
if __name__ == '__main__': if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path) train_pipeline(root_path)

View File

@@ -44,11 +44,20 @@ class RealESRGANer():
self.half = half self.half = half
# initialize model # initialize model
if gpu_id: if device is not None:
self.device = torch.device( self.device = device
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
else: else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device if torch.cuda.is_available():
if gpu_id is not None and gpu_id < torch.cuda.device_count():
self.device = torch.device(f"cuda:{gpu_id}")
else:
self.device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
else:
self.device = torch.device("cpu")
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'): if model_path.startswith('https://'):
model_path = load_file_from_url( model_path = load_file_from_url(

View File

@@ -9,8 +9,15 @@ from basicsr.utils.download_util import load_file_from_url
from facelib.utils.face_restoration_helper import FaceRestoreHelper from facelib.utils.face_restoration_helper import FaceRestoreHelper
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Cross-platform device selection: CUDA > MPS > CPU
# Pick the module-level device in order of preference: CUDA, then MPS, then CPU.
# torch.backends.mps is looked up with getattr because PyTorch builds without
# the MPS backend do not define it, and a bare attribute access would raise
# AttributeError while this module is being imported.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Download and load model
pretrain_model_url = { pretrain_model_url = {
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
} }
@@ -20,7 +27,7 @@ net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
model_dir='weights/CodeFormer', progress=True, file_name=None) model_dir='weights/CodeFormer', progress=True, file_name=None)
checkpoint = torch.load(ckpt_path)['params_ema'] checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
net.load_state_dict(checkpoint) net.load_state_dict(checkpoint)
net.eval() net.eval()
@@ -47,9 +54,9 @@ def _enhance_img(img: np.ndarray, w: float = 0.5) -> np.ndarray:
face_helper.align_warp_face() face_helper.align_warp_face()
for cropped_face in face_helper.cropped_faces: for cropped_face in face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True).to(device)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device) cropped_face_t = cropped_face_t.unsqueeze(0) # (1, 3, H, W), already on correct device
with torch.no_grad(): with torch.no_grad():
output = net(cropped_face_t, w=w, adain=True)[0] output = net(cropped_face_t, w=w, adain=True)[0]
@@ -84,4 +91,4 @@ def enhance_image_memory(img: np.ndarray, w: float = 0.5) -> np.ndarray:
""" """
Enhances an input image entirely in memory and returns the enhanced image. Enhances an input image entirely in memory and returns the enhanced image.
""" """
return _enhance_img(img, w=w) return _enhance_img(img, w=w)

View File

@@ -11,7 +11,13 @@ from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, m
from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
py_cpu_nms) py_cpu_nms)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
def generate_config(network_name): def generate_config(network_name):
@@ -367,4 +373,4 @@ class RetinaFace(nn.Module):
# self.total_frame += len(frames) # self.total_frame += len(frames)
# print(self.batch_time / self.total_frame) # print(self.batch_time / self.total_frame)
return final_bounding_boxes, final_landmarks return final_bounding_boxes, final_landmarks

View File

@@ -139,4 +139,4 @@ class YoloDetector:
return None return None
def __call__(self, *args): def __call__(self, *args):
return self.predict(*args) return self.predict(*args)

View File

@@ -1,5 +1,21 @@
import torch import torch
import sys import sys
sys.path.insert(0,'./facelib/detection/yolov5face') import os
model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth') # Setup dynamic device selection
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
sys.path.insert(0, './facelib/detection/yolov5face')
# Load the model to the selected device
ckpt = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location=device)
model = ckpt['model'].to(device)
# Save only the weights
os.makedirs('weights/facelib', exist_ok=True)
torch.save(model.state_dict(), 'weights/facelib/yolov5n-face.pth')

View File

@@ -269,32 +269,45 @@ class SCRFD:
return det, kpss return det, kpss
def autodetect(self, img, max_num=0, metric='max'): def autodetect(self, img, max_num=0, metric='max'):
bboxes, kpss = self.detect(img, input_size=(640, 640), thresh=0.5) if self.session.get_providers()[0] == 'CoreMLExecutionProvider':
bboxes2, kpss2 = self.detect(img, input_size=(128, 128), thresh=0.5) # Cache the CPU-based detector
if not hasattr(self, '_cpu_fallback_detector'):
model_path = self.model_file
cpu_session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
self._cpu_fallback_detector = SCRFD(model_file=model_path, session=cpu_session)
self._cpu_fallback_detector.prepare(0, input_size=(640, 640))
detector = self._cpu_fallback_detector
else:
detector = self # Use the original GPU/CoreML session
bboxes, kpss = detector.detect(img, input_size=(640, 640), thresh=0.5)
bboxes2, kpss2 = detector.detect(img, input_size=(128, 128), thresh=0.5)
bboxes_all = np.concatenate([bboxes, bboxes2], axis=0) bboxes_all = np.concatenate([bboxes, bboxes2], axis=0)
kpss_all = np.concatenate([kpss, kpss2], axis=0) kpss_all = np.concatenate([kpss, kpss2], axis=0)
keep = self.nms(bboxes_all) keep = self.nms(bboxes_all)
det = bboxes_all[keep,:] det = bboxes_all[keep, :]
kpss = kpss_all[keep,:] kpss = kpss_all[keep, :]
if max_num > 0 and det.shape[0] > max_num: if max_num > 0 and det.shape[0] > max_num:
area = (det[:, 2] - det[:, 0]) * (det[:, 3] - area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
det[:, 1])
img_center = img.shape[0] // 2, img.shape[1] // 2 img_center = img.shape[0] // 2, img.shape[1] // 2
offsets = np.vstack([ offsets = np.vstack([
(det[:, 0] + det[:, 2]) / 2 - img_center[1], (det[:, 0] + det[:, 2]) / 2 - img_center[1],
(det[:, 1] + det[:, 3]) / 2 - img_center[0] (det[:, 1] + det[:, 3]) / 2 - img_center[0]
]) ])
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
if metric=='max': if metric == 'max':
values = area values = area
else: else:
values = area - offset_dist_squared * 2.0 # some extra weight on the centering values = area - offset_dist_squared * 2.0
bindex = np.argsort( bindex = np.argsort(values)[::-1]
values)[::-1] # some extra weight on the centering
bindex = bindex[0:max_num] bindex = bindex[0:max_num]
det = det[bindex, :] det = det[bindex, :]
if kpss is not None: if kpss is not None:
kpss = kpss[bindex, :] kpss = kpss[bindex, :]
return det, kpss return det, kpss
def nms(self, dets): def nms(self, dets):
@@ -325,5 +338,4 @@ class SCRFD:
inds = np.where(ovr <= thresh)[0] inds = np.where(ovr <= thresh)[0]
order = order[inds + 1] order = order[inds + 1]
return keep return keep