Compare commits

...

10 Commits

Author SHA1 Message Date
zhh
caa3abc21d first commit 2025-12-23 10:18:20 +08:00
zhh
56c3987688 first commit 2025-10-27 17:01:04 +08:00
zhh
63ab13be55 first commit 2025-10-27 17:00:04 +08:00
zhh
47352f6e30 first commit 2025-10-27 16:52:14 +08:00
zhh
7de7c916b1 first commit 2025-10-27 16:51:16 +08:00
zhh
10a125f454 first commit 2025-10-17 16:04:58 +08:00
zhh
d8d558e185 first commit 2025-10-17 16:04:42 +08:00
Felipe Daragon
ad4dccfe49 Update install notes 2025-05-16 18:02:35 +01:00
Felipe Daragon
e57cfdb0d0 Improve install notes 2025-05-15 23:35:40 +01:00
Felipe Daragon
badbcc6edf Increased GPU usage 2025-05-15 22:30:23 +01:00
110 changed files with 1249 additions and 314 deletions

8
.gitignore vendored Normal file → Executable file
View File

@@ -170,4 +170,10 @@ aaa.md
*_test.py *_test.py
img.jpg img.jpg
test_data test_data
testsrc.mp4 testsrc.mp4
*.jpg
*.png
*.pth
.idea
*.jpeg

40
Dockerfile Executable file
View File

@@ -0,0 +1,40 @@
# Change CUDA and cuDNN version here
FROM nvidia/cuda:12.4.1-base-ubuntu22.04
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
software-properties-common \
wget \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update && apt-get install -y --no-install-recommends \
python$PYTHON_VERSION \
python$PYTHON_VERSION-dev \
python$PYTHON_VERSION-venv \
&& wget https://bootstrap.pypa.io/get-pip.py -O get-pip.py \
&& python$PYTHON_VERSION get-pip.py \
&& rm get-pip.py \
&& ln -sf /usr/bin/python$PYTHON_VERSION /usr/bin/python \
&& ln -sf /usr/local/bin/pip$PYTHON_VERSION /usr/local/bin/pip \
&& python --version \
&& pip --version \
&& apt-get purge -y --auto-remove software-properties-common \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
####### Add your own installation commands here #######
# RUN pip install some-package
# RUN wget https://path/to/some/data/or/weights
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y \
&& apt install -y build-essential g++
WORKDIR /app
COPY . /app
# Install litserve and requirements
RUN pip install --no-cache-dir litserve==0.2.16
RUN pip install -r requirements-GPU.txt
RUN pip install opencv-python
EXPOSE 8000
CMD ["python", "litserver_main.py"]
#CMD ["tail", "-f","/dev/null"]

0
LICENSE Normal file → Executable file
View File

8
README-AIDA-LC.md Executable file
View File

@@ -0,0 +1,8 @@
创建Docker file
litserve dockerize server.py --port 8000 --gpu
构建镜像:
docker build -t litserve-model .
运行容器:
docker run -p 8000:8000 litserve-model

10
README.md Normal file → Executable file
View File

@@ -93,9 +93,9 @@ Follow these steps to install Refacer and its dependencies:
# Create the environment # Create the environment
# Windows: # Windows:
conda create -n neorefacer-env python=3.11 nomkl conda-forge::vs2015_runtime conda create -n neorefacer-env python=3.11 conda-forge::vs2015_runtime
# Linux: # Linux:
conda create -n neorefacer-env python=3.11 nomkl conda create -n neorefacer-env python=3.11
# MacOS: # MacOS:
conda create -n neorefacer-env python=3.11 conda create -n neorefacer-env python=3.11
@@ -107,6 +107,12 @@ Follow these steps to install Refacer and its dependencies:
pip install -r requirements-CPU.txt pip install -r requirements-CPU.txt
# For NVIDIA RTX GPU only (compatible with Windows and Linux only, requires a NVIDIA GPU with CUDA and its libraries) # For NVIDIA RTX GPU only (compatible with Windows and Linux only, requires a NVIDIA GPU with CUDA and its libraries)
# Install Torch with CUDA enabled:
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
# This should install torch 2.5.1, torchaudio 2.5.1 and torchvision 0.20.1
# Make sure that CUDA is returning True:
python -c "import torch; print('CUDA:', torch.cuda.is_available()); print(torch.version.cuda); print(torch.cuda.get_device_name(0))"
# Now install the rest of the dependencies
pip install -r requirements-GPU.txt pip install -r requirements-GPU.txt
# For CoreML only (compatible with MacOSX, requires Silicon architecture): # For CoreML only (compatible with MacOSX, requires Silicon architecture):

0
app.py Normal file → Executable file
View File

0
basicsr/VERSION Normal file → Executable file
View File

0
basicsr/__init__.py Normal file → Executable file
View File

0
basicsr/archs/__init__.py Normal file → Executable file
View File

0
basicsr/archs/arcface_arch.py Normal file → Executable file
View File

0
basicsr/archs/arch_util.py Normal file → Executable file
View File

0
basicsr/archs/codeformer_arch.py Normal file → Executable file
View File

0
basicsr/archs/rrdbnet_arch.py Normal file → Executable file
View File

0
basicsr/archs/vgg_arch.py Normal file → Executable file
View File

175
basicsr/archs/vqgan_arch.py Normal file → Executable file
View File

@@ -1,65 +1,64 @@
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import copy import copy
import os
from basicsr.utils import get_root_logger from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
# Select Device
def select_device(prefer_coreml=False):
if torch.backends.mps.is_available() and prefer_coreml:
print("BasicSR Archs: Using CoreML backend (MPS).")
return torch.device("mps")
elif torch.cuda.is_available():
print("BasicSR Archs: Using CUDA backend.")
return torch.device("cuda")
else:
print("BasicSR Archs: Using CPU backend.")
return torch.device("cpu")
# Set device globally
DEVICE = select_device(prefer_coreml=True)
def normalize(in_channels): def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script @torch.jit.script
def swish(x): def swish(x):
return x*torch.sigmoid(x) return x * torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module): class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta): def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__() super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings self.codebook_size = codebook_size
self.emb_dim = emb_dim # dimension of embedding self.emb_dim = emb_dim
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 self.beta = beta
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
def forward(self, z): def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous() z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim) z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = (z_flattened ** 2).sum(dim=1, keepdim=True) + \
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ (self.embedding.weight ** 2).sum(1) - \
2 * torch.matmul(z_flattened, self.embedding.weight.t()) 2 * torch.matmul(z_flattened, self.embedding.weight.t())
mean_distance = torch.mean(d) mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
# [0-1], higher score, higher confidence min_encoding_scores = torch.exp(-min_encoding_scores / 10)
min_encoding_scores = torch.exp(-min_encoding_scores/10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size, device=z.device)
min_encodings.scatter_(1, min_encoding_indices, 1) min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach() z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0) e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous() z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, { return z_q, loss, {
@@ -68,18 +67,15 @@ class VectorQuantizer(nn.Module):
"min_encoding_indices": min_encoding_indices, "min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores, "min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance "mean_distance": mean_distance
} }
def get_codebook_feat(self, indices, shape): def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1 indices = indices.view(-1, 1)
# shape: batch, height, width, channel min_encodings = torch.zeros(indices.shape[0], self.codebook_size, device=indices.device)
indices = indices.view(-1,1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1) min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight) z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape if shape is not None:
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q return z_q
@@ -324,112 +320,87 @@ class Generator(nn.Module):
return x return x
# Autoencoder with device transfer
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module): class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): attn_resolutions=[16], codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__() super().__init__()
logger = get_root_logger() logger = get_root_logger()
self.in_channels = 3 self.in_channels = 3
self.nf = nf self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size self.codebook_size = codebook_size
self.embed_dim = emb_dim self.embed_dim = emb_dim
self.ch_mult = ch_mult self.ch_mult = ch_mult
self.resolution = img_size self.resolution = img_size
self.attn_resolutions = attn_resolutions self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer self.quantizer_type = quantizer
self.encoder = Encoder( self.encoder = Encoder(
self.in_channels, self.in_channels, self.nf, self.embed_dim, self.ch_mult,
self.nf, res_blocks, self.resolution, self.attn_resolutions
self.embed_dim, ).to(DEVICE)
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if self.quantizer_type == "nearest": if self.quantizer_type == "nearest":
self.beta = beta #0.25 self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, beta).to(DEVICE)
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) else:
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer( self.quantize = GumbelQuantizer(
self.codebook_size, self.codebook_size, self.embed_dim, emb_dim,
self.embed_dim, gumbel_straight_through, gumbel_kl_weight
self.gumbel_num_hiddens, ).to(DEVICE)
self.straight_through,
self.kl_weight
)
self.generator = Generator( self.generator = Generator(
self.nf, self.nf, self.embed_dim, self.ch_mult, res_blocks,
self.embed_dim, self.resolution, self.attn_resolutions
self.ch_mult, ).to(DEVICE)
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if model_path is not None: if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu') chkpt = torch.load(model_path, map_location='cpu')
if 'params_ema' in chkpt: if 'params_ema' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) self.load_state_dict(chkpt['params_ema'])
logger.info(f'vqgan is loaded from: {model_path} [params_ema]') logger.info(f'Loaded VQGAN from: {model_path} [params_ema]')
elif 'params' in chkpt: elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) self.load_state_dict(chkpt['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]') logger.info(f'Loaded VQGAN from: {model_path} [params]')
else: else:
raise ValueError(f'Wrong params!') raise ValueError("Invalid model format!")
def forward(self, x): def forward(self, x):
x = x.to(DEVICE)
x = self.encoder(x) x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x) quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant) x = self.generator(quant)
return x, codebook_loss, quant_stats return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module): class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__() super().__init__()
layers = [
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True) nn.LeakyReLU(0.2, True)
] ]
nf_mult = 1
for n in range(1, n_layers):
prev = nf_mult
nf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * prev, ndf * nf_mult, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
layers += [ layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map nn.Conv2d(ndf * nf_mult, 1, 4, 1, 1)
self.main = nn.Sequential(*layers) ]
self.main = nn.Sequential(*layers).to(DEVICE)
if model_path is not None: if model_path:
chkpt = torch.load(model_path, map_location='cpu') chkpt = torch.load(model_path, map_location='cpu')
if 'params_d' in chkpt: if 'params_d' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) self.load_state_dict(chkpt['params_d'])
elif 'params' in chkpt: elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) self.load_state_dict(chkpt['params'])
else:
raise ValueError(f'Wrong params!')
def forward(self, x): def forward(self, x):
return self.main(x) return self.main(x.to(DEVICE))

0
basicsr/data/__init__.py Normal file → Executable file
View File

0
basicsr/data/data_sampler.py Normal file → Executable file
View File

0
basicsr/data/data_util.py Normal file → Executable file
View File

37
basicsr/data/prefetch_dataloader.py Normal file → Executable file
View File

@@ -82,12 +82,9 @@ class CPUPrefetcher():
class CUDAPrefetcher(): class CUDAPrefetcher():
"""CUDA prefetcher. """CUDA (or MPS/CPU) prefetcher.
Ref: It may consume more GPU memory.
https://github.com/NVIDIA/apex/issues/304#
It may consums more GPU memory.
Args: Args:
loader: Dataloader. loader: Dataloader.
@@ -98,8 +95,18 @@ class CUDAPrefetcher():
self.ori_loader = loader self.ori_loader = loader
self.loader = iter(loader) self.loader = iter(loader)
self.opt = opt self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') # Cross-platform device detection
if opt['num_gpu'] != 0 and torch.cuda.is_available():
self.device = torch.device('cuda')
self.stream = torch.cuda.Stream()
elif torch.backends.mps.is_available():
self.device = torch.device('mps')
self.stream = None
else:
self.device = torch.device('cpu')
self.stream = None
self.preload() self.preload()
def preload(self): def preload(self):
@@ -108,18 +115,24 @@ class CUDAPrefetcher():
except StopIteration: except StopIteration:
self.batch = None self.batch = None
return None return None
# put tensors to gpu
with torch.cuda.stream(self.stream): if self.stream is not None:
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
else:
for k, v in self.batch.items(): for k, v in self.batch.items():
if torch.is_tensor(v): if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) self.batch[k] = self.batch[k].to(device=self.device)
def next(self): def next(self):
torch.cuda.current_stream().wait_stream(self.stream) if self.stream is not None:
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch batch = self.batch
self.preload() self.preload()
return batch return batch
def reset(self): def reset(self):
self.loader = iter(self.ori_loader) self.loader = iter(self.ori_loader)
self.preload() self.preload()

0
basicsr/data/transforms.py Normal file → Executable file
View File

0
basicsr/losses/__init__.py Normal file → Executable file
View File

0
basicsr/losses/loss_util.py Normal file → Executable file
View File

0
basicsr/losses/losses.py Normal file → Executable file
View File

0
basicsr/metrics/__init__.py Normal file → Executable file
View File

0
basicsr/metrics/metric_util.py Normal file → Executable file
View File

0
basicsr/metrics/psnr_ssim.py Normal file → Executable file
View File

0
basicsr/models/__init__.py Normal file → Executable file
View File

0
basicsr/ops/__init__.py Normal file → Executable file
View File

0
basicsr/ops/dcn/__init__.py Normal file → Executable file
View File

0
basicsr/ops/dcn/deform_conv.py Normal file → Executable file
View File

0
basicsr/ops/dcn/src/deform_conv_cuda.cpp Normal file → Executable file
View File

0
basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu Normal file → Executable file
View File

0
basicsr/ops/dcn/src/deform_conv_ext.cpp Normal file → Executable file
View File

0
basicsr/ops/fused_act/__init__.py Normal file → Executable file
View File

0
basicsr/ops/fused_act/fused_act.py Normal file → Executable file
View File

0
basicsr/ops/fused_act/src/fused_bias_act.cpp Normal file → Executable file
View File

0
basicsr/ops/fused_act/src/fused_bias_act_kernel.cu Normal file → Executable file
View File

0
basicsr/ops/upfirdn2d/__init__.py Normal file → Executable file
View File

0
basicsr/ops/upfirdn2d/src/upfirdn2d.cpp Normal file → Executable file
View File

0
basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu Normal file → Executable file
View File

0
basicsr/ops/upfirdn2d/upfirdn2d.py Normal file → Executable file
View File

0
basicsr/setup.py Normal file → Executable file
View File

146
basicsr/train.py Normal file → Executable file
View File

@@ -2,22 +2,38 @@ import argparse
import datetime import datetime
import logging import logging
import math import math
import copy
import random import random
import time import time
import torch import torch
import platform
from os import path as osp from os import path as osp
import warnings
from basicsr.data import build_dataloader, build_dataset from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model from basicsr.models import build_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger, from basicsr.utils import (
init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed) MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed
)
from basicsr.utils.dist_util import get_dist_info, init_dist from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse from basicsr.utils.options import dict2str, parse
import warnings # ----------- DEVICE SELECTION ----------
def select_device(prefer_coreml=True):
if torch.backends.mps.is_available() and prefer_coreml and platform.system() == "Darwin":
print("BasicSR: Using CoreML backend (MPS).")
return torch.device("mps")
elif torch.cuda.is_available():
print("BasicSR: Using CUDA backend.")
return torch.device("cuda")
else:
print("BasicSR: Using CPU backend.")
return torch.device("cpu")
DEVICE = select_device(prefer_coreml=True)
# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. # ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=UserWarning)
@@ -30,9 +46,9 @@ def parse_options(root_path, is_train=True):
opt = parse(args.opt, root_path, is_train=is_train) opt = parse(args.opt, root_path, is_train=is_train)
# distributed settings # distributed settings
if args.launcher == 'none': if args.launcher == 'none' or DEVICE.type != 'cuda':
opt['dist'] = False opt['dist'] = False
print('Disable distributed.', flush=True) print('Distributed training disabled.', flush=True)
else: else:
opt['dist'] = True opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt: if args.launcher == 'slurm' and 'dist_params' in opt:
@@ -51,122 +67,96 @@ def parse_options(root_path, is_train=True):
return opt return opt
def init_loggers(opt): def init_loggers(opt):
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
logger.info(get_env_info()) logger.info(get_env_info())
logger.info(dict2str(opt)) logger.info(dict2str(opt))
# initialize wandb logger before tensorboard logger to allow proper sync:
if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None): if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None):
assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') assert opt['logger'].get('use_tb_logger') is True
init_wandb_logger(opt) init_wandb_logger(opt)
tb_logger = None tb_logger = None
if opt['logger'].get('use_tb_logger'): if opt['logger'].get('use_tb_logger'):
tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
return logger, tb_logger return logger, tb_logger
def create_train_val_dataloader(opt, logger): def create_train_val_dataloader(opt, logger):
# create train and val dataloaders
train_loader, val_loader = None, None train_loader, val_loader = None, None
for phase, dataset_opt in opt['datasets'].items(): for phase, dataset_opt in opt['datasets'].items():
if phase == 'train': if phase == 'train':
dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
train_set = build_dataset(dataset_opt) train_set = build_dataset(dataset_opt)
train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
train_loader = build_dataloader( train_loader = build_dataloader(train_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=train_sampler, seed=opt['manual_seed'])
train_set, num_iter_per_epoch = math.ceil(len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
dataset_opt,
num_gpu=opt['num_gpu'],
dist=opt['dist'],
sampler=train_sampler,
seed=opt['manual_seed'])
num_iter_per_epoch = math.ceil(
len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
total_iters = int(opt['train']['total_iter']) total_iters = int(opt['train']['total_iter'])
total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) total_epochs = math.ceil(total_iters / num_iter_per_epoch)
logger.info('Training statistics:' logger.info(f'Training stats:\n\tTrain images: {len(train_set)}\n\tEnlarge ratio: {dataset_enlarge_ratio}\n\tBatch/GPU: {dataset_opt["batch_size_per_gpu"]}\n\tGPUs: {opt["world_size"]}\n\tIters/epoch: {num_iter_per_epoch}\n\tTotal epochs: {total_epochs}, Iters: {total_iters}')
f'\n\tNumber of train images: {len(train_set)}'
f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
f'\n\tWorld size (gpu number): {opt["world_size"]}'
f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
elif phase == 'val': elif phase == 'val':
val_set = build_dataset(dataset_opt) val_set = build_dataset(dataset_opt)
val_loader = build_dataloader( val_loader = build_dataloader(val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) logger.info(f'Validation items in {dataset_opt["name"]}: {len(val_set)}')
logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
else: else:
raise ValueError(f'Dataset phase {phase} is not recognized.') raise ValueError(f'Dataset phase {phase} not recognized.')
return train_loader, train_sampler, val_loader, total_epochs, total_iters return train_loader, train_sampler, val_loader, total_epochs, total_iters
def train_pipeline(root_path): def train_pipeline(root_path):
# parse options, set distributed setting, set ramdom seed
opt = parse_options(root_path, is_train=True) opt = parse_options(root_path, is_train=True)
torch.backends.cudnn.benchmark = True if DEVICE.type == 'cuda':
# torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True
# load resume states if necessary
if opt['path'].get('resume_state'): if opt['path'].get('resume_state'):
device_id = torch.cuda.current_device() resume_state = torch.load(opt['path']['resume_state'], map_location=DEVICE)
resume_state = torch.load(
opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
else: else:
resume_state = None resume_state = None
# mkdir for experiments and logger
if resume_state is None: if resume_state is None:
make_exp_dirs(opt) make_exp_dirs(opt)
if opt['logger'].get('use_tb_logger') and opt['rank'] == 0: if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
mkdir_and_rename(osp.join('tb_logger', opt['name'])) mkdir_and_rename(osp.join('tb_logger', opt['name']))
# initialize loggers
logger, tb_logger = init_loggers(opt) logger, tb_logger = init_loggers(opt)
train_loader, train_sampler, val_loader, total_epochs, total_iters = create_train_val_dataloader(opt, logger)
# create train and validation dataloaders
result = create_train_val_dataloader(opt, logger)
train_loader, train_sampler, val_loader, total_epochs, total_iters = result
# create model if resume_state:
if resume_state: # resume training
check_resume(opt, resume_state['iter']) check_resume(opt, resume_state['iter'])
model = build_model(opt) model = build_model(opt).to(DEVICE)
model.resume_training(resume_state) # handle optimizers and schedulers model.resume_training(resume_state)
logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") logger.info(f"Resuming from epoch {resume_state['epoch']}, iter {resume_state['iter']}")
start_epoch = resume_state['epoch'] start_epoch = resume_state['epoch']
current_iter = resume_state['iter'] current_iter = resume_state['iter']
else: else:
model = build_model(opt) model = build_model(opt).to(DEVICE)
start_epoch = 0 start_epoch = 0
current_iter = 0 current_iter = 0
# create message logger (formatted outputs)
msg_logger = MessageLogger(opt, current_iter, tb_logger) msg_logger = MessageLogger(opt, current_iter, tb_logger)
# dataloader prefetcher
prefetch_mode = opt['datasets']['train'].get('prefetch_mode') prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
if prefetch_mode is None or prefetch_mode == 'cpu': if prefetch_mode is None or prefetch_mode == 'cpu' or DEVICE.type in ['cpu', 'mps']:
if prefetch_mode == 'cuda' and DEVICE.type == 'mps':
logger.warning("CUDA prefetch requested but MPS (CoreML) is in use. Falling back to CPU prefetch.")
prefetcher = CPUPrefetcher(train_loader) prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'cuda': elif prefetch_mode == 'cuda':
prefetcher = CUDAPrefetcher(train_loader, opt) if DEVICE.type != 'cuda':
logger.info(f'Use {prefetch_mode} prefetch dataloader') logger.warning("CUDA prefetch requested but CUDA unavailable. Using CPU prefetch.")
if opt['datasets']['train'].get('pin_memory') is not True: prefetcher = CPUPrefetcher(train_loader)
raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') else:
if opt['datasets']['train'].get('pin_memory') is not True:
raise ValueError('Set pin_memory=True for CUDAPrefetcher.')
prefetcher = CUDAPrefetcher(train_loader, opt)
logger.info(f'Using CUDA prefetcher')
else: else:
raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' "Supported ones are: None, 'cuda', 'cpu'.") raise ValueError(f"Invalid prefetch_mode: {prefetch_mode}. Supported: 'cpu', 'cuda', None")
# training logger.info(f'Start training at epoch {start_epoch}, iter {current_iter + 1}')
logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
data_time, iter_time = time.time(), time.time()
start_time = time.time() start_time = time.time()
data_time, iter_time = time.time(), time.time()
for epoch in range(start_epoch, total_epochs + 1): for epoch in range(start_epoch, total_epochs + 1):
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
@@ -175,17 +165,15 @@ def train_pipeline(root_path):
while train_data is not None: while train_data is not None:
data_time = time.time() - data_time data_time = time.time() - data_time
current_iter += 1 current_iter += 1
if current_iter > total_iters: if current_iter > total_iters:
break break
# update learning rate
model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
# training
model.feed_data(train_data) model.feed_data(train_data)
model.optimize_parameters(current_iter) model.optimize_parameters(current_iter)
iter_time = time.time() - iter_time iter_time = time.time() - iter_time
# log
if current_iter % opt['logger']['print_freq'] == 0: if current_iter % opt['logger']['print_freq'] == 0:
log_vars = {'epoch': epoch, 'iter': current_iter} log_vars = {'epoch': epoch, 'iter': current_iter}
log_vars.update({'lrs': model.get_current_learning_rate()}) log_vars.update({'lrs': model.get_current_learning_rate()})
@@ -193,33 +181,27 @@ def train_pipeline(root_path):
log_vars.update(model.get_current_log()) log_vars.update(model.get_current_log())
msg_logger(log_vars) msg_logger(log_vars)
# save models and training states
if current_iter % opt['logger']['save_checkpoint_freq'] == 0: if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
logger.info('Saving models and training states.') logger.info('Saving model and training state.')
model.save(epoch, current_iter) model.save(epoch, current_iter)
# validation if opt.get('val') and opt['datasets'].get('val') and (current_iter % opt['val']['val_freq'] == 0):
if opt.get('val') is not None and opt['datasets'].get('val') is not None \
and (current_iter % opt['val']['val_freq'] == 0):
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
data_time = time.time() data_time = time.time()
iter_time = time.time() iter_time = time.time()
train_data = prefetcher.next() train_data = prefetcher.next()
# end of iter
# end of epoch
consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
logger.info(f'End of training. Time consumed: {consumed_time}') logger.info(f'Training complete. Time: {consumed_time}')
logger.info('Save the latest model.') logger.info('Saving latest model.')
model.save(epoch=-1, current_iter=-1) # -1 stands for the latest model.save(epoch=-1, current_iter=-1)
if opt.get('val') is not None and opt['datasets'].get('val'):
if opt.get('val') and opt['datasets'].get('val'):
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
if tb_logger: if tb_logger:
tb_logger.close() tb_logger.close()
if __name__ == '__main__': if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path) train_pipeline(root_path)

0
basicsr/utils/__init__.py Normal file → Executable file
View File

0
basicsr/utils/dist_util.py Normal file → Executable file
View File

0
basicsr/utils/download_util.py Normal file → Executable file
View File

0
basicsr/utils/file_client.py Normal file → Executable file
View File

0
basicsr/utils/img_util.py Normal file → Executable file
View File

0
basicsr/utils/lmdb_util.py Normal file → Executable file
View File

0
basicsr/utils/logger.py Normal file → Executable file
View File

0
basicsr/utils/matlab_functions.py Normal file → Executable file
View File

0
basicsr/utils/misc.py Normal file → Executable file
View File

0
basicsr/utils/options.py Normal file → Executable file
View File

17
basicsr/utils/realesrgan_utils.py Normal file → Executable file
View File

@@ -44,11 +44,20 @@ class RealESRGANer():
self.half = half self.half = half
# initialize model # initialize model
if gpu_id: if device is not None:
self.device = torch.device( self.device = device
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
else: else:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device if torch.cuda.is_available():
if gpu_id is not None and gpu_id < torch.cuda.device_count():
self.device = torch.device(f"cuda:{gpu_id}")
else:
self.device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
else:
self.device = torch.device("cpu")
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'): if model_path.startswith('https://'):
model_path = load_file_from_url( model_path = load_file_from_url(

0
basicsr/utils/registry.py Normal file → Executable file
View File

0
basicsr/version.py Normal file → Executable file
View File

7
client.py Executable file
View File

@@ -0,0 +1,7 @@
# This file is auto-generated by LitServe.
# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
import requests

# Send one JSON request to the local /predict endpoint and print the HTTP status and raw body.
# NOTE(review): this targets port 8080 with an {"input": 4.0} payload, whereas litserver_main.py
# starts the server with port=8000 and its PredictRequest model declares input_image_list /
# input_face / threshold - confirm this generated client still matches the running service.
response = requests.post("http://127.0.0.1:8080/predict", json={"input": 4.0})
print(f"Status: {response.status_code}\nResponse:\n {response.text}")

17
codeformer_wrapper.py Normal file → Executable file
View File

@@ -9,8 +9,15 @@ from basicsr.utils.download_util import load_file_from_url
from facelib.utils.face_restoration_helper import FaceRestoreHelper from facelib.utils.face_restoration_helper import FaceRestoreHelper
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Cross-platform device selection: CUDA > MPS > CPU
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
# Download and load model
pretrain_model_url = { pretrain_model_url = {
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
} }
@@ -20,7 +27,7 @@ net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
model_dir='weights/CodeFormer', progress=True, file_name=None) model_dir='weights/CodeFormer', progress=True, file_name=None)
checkpoint = torch.load(ckpt_path)['params_ema'] checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
net.load_state_dict(checkpoint) net.load_state_dict(checkpoint)
net.eval() net.eval()
@@ -47,9 +54,9 @@ def _enhance_img(img: np.ndarray, w: float = 0.5) -> np.ndarray:
face_helper.align_warp_face() face_helper.align_warp_face()
for cropped_face in face_helper.cropped_faces: for cropped_face in face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True).to(device)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device) cropped_face_t = cropped_face_t.unsqueeze(0) # (1, 3, H, W), already on correct device
with torch.no_grad(): with torch.no_grad():
output = net(cropped_face_t, w=w, adain=True)[0] output = net(cropped_face_t, w=w, adain=True)[0]
@@ -84,4 +91,4 @@ def enhance_image_memory(img: np.ndarray, w: float = 0.5) -> np.ndarray:
""" """
Enhances an input image entirely in memory and returns the enhanced image. Enhances an input image entirely in memory and returns the enhanced image.
""" """
return _enhance_img(img, w=w) return _enhance_img(img, w=w)

93
codeformer_wrapper_no_path.py Executable file
View File

@@ -0,0 +1,93 @@
import os
import torch
import cv2
import numpy as np
from pathlib import Path
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from basicsr.utils.registry import ARCH_REGISTRY
# Cross-platform device selection: CUDA > MPS > CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Download and load model
# NOTE: everything below runs at import time, so importing this module triggers the
# checkpoint fetch and the model construction.
pretrain_model_url = {
    'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
}

# Look up the 'CodeFormer' architecture in the registry, instantiate it and move it to `device`.
net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                                      connect_list=['32', '64', '128', '256']).to(device)

# Resolve the checkpoint path under weights/CodeFormer
# (presumably downloaded only when missing - confirm in load_file_from_url).
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
                               model_dir='weights/CodeFormer', progress=True, file_name=None)
# map_location keeps the load working on machines without the device the file was saved from;
# only the 'params_ema' entry of the checkpoint is used.
checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
net.load_state_dict(checkpoint)
net.eval()

# Shared helper used by _enhance_img for detection, alignment and paste-back.
# It is a single module-level instance that is reset with clean_all() on every call.
face_helper = FaceRestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='jpg',
    use_parse=True,
    device=device
)
def _enhance_img(img: np.ndarray, w: float = 0.5) -> np.ndarray:
    """
    Restore every detected face of ``img`` with CodeFormer and paste the results back.

    The module-level ``face_helper`` is reset and reused on each call. When it reports
    zero faces the input image is returned untouched.

    Args:
        img: Image array handed to ``face_helper.read_image``
            (presumably BGR, as produced by OpenCV - confirm with callers).
        w: Value forwarded to the network as its ``w`` argument.

    Returns:
        The image produced by ``face_helper.paste_faces_to_input_image()``, or ``img``
        itself when no face was found.
    """
    half = (0.5, 0.5, 0.5)  # used as both mean and std for the normalisation below

    face_helper.clean_all()
    face_helper.read_image(img)
    detected = face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
    if detected == 0:
        return img  # Return original if no faces detected

    face_helper.align_warp_face()
    for face_crop in face_helper.cropped_faces:
        face_tensor = img2tensor(face_crop / 255., bgr2rgb=True, float32=True).to(device)
        normalize(face_tensor, half, half, inplace=True)
        batch = face_tensor.unsqueeze(0)  # add the leading batch axis; tensor is already on `device`

        with torch.no_grad():
            prediction = net(batch, w=w, adain=True)[0]

        face_helper.add_restored_face(
            tensor2img(prediction, rgb2bgr=True, min_max=(-1, 1)).astype('uint8')
        )

    face_helper.get_inverse_affine(None)
    return face_helper.paste_faces_to_input_image()
def enhance_image(img: np.ndarray, w: float = 0.5) -> np.ndarray:
    """
    Enhance an already-decoded image with CodeFormer and return the result.

    This "no path" variant neither reads from nor writes to disk: the caller supplies
    the decoded image and receives whatever ``_enhance_img`` returns.

    Args:
        img: Decoded image array; ``None`` is rejected (e.g. the result of a failed decode).
        w: Forwarded unchanged to ``_enhance_img``.

    Returns:
        The value returned by ``_enhance_img(img, w=w)``.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        raise ValueError("Cannot read image")
    return _enhance_img(img, w=w)
def enhance_image_memory(img: np.ndarray, w: float = 0.5) -> np.ndarray:
    """
    Run the in-memory enhancement on ``img`` and hand back the enhanced image.

    Thin public wrapper: both arguments are passed straight through to ``_enhance_img``.
    """
    enhanced = _enhance_img(img, w=w)
    return enhanced

0
demo.jpg Normal file → Executable file
View File

Before

Width:  |  Height:  |  Size: 37 KiB

After

Width:  |  Height:  |  Size: 37 KiB

18
docker-compose.yml Executable file
View File

@@ -0,0 +1,18 @@
services:
lc_neo_refacer:
build:
context: .
dockerfile: Dockerfile
working_dir: /app
volumes:
- .:/app
ports:
- "10071:8000"
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: [ '1' ]
capabilities:
- gpu

0
facelib/detection/__init__.py Normal file → Executable file
View File

0
facelib/detection/align_trans.py Normal file → Executable file
View File

0
facelib/detection/matlab_cp2tform.py Normal file → Executable file
View File

10
facelib/detection/retinaface/retinaface.py Normal file → Executable file
View File

@@ -11,7 +11,13 @@ from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, m
from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
py_cpu_nms) py_cpu_nms)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
def generate_config(network_name): def generate_config(network_name):
@@ -367,4 +373,4 @@ class RetinaFace(nn.Module):
# self.total_frame += len(frames) # self.total_frame += len(frames)
# print(self.batch_time / self.total_frame) # print(self.batch_time / self.total_frame)
return final_bounding_boxes, final_landmarks return final_bounding_boxes, final_landmarks

0
facelib/detection/retinaface/retinaface_net.py Normal file → Executable file
View File

0
facelib/detection/retinaface/retinaface_utils.py Normal file → Executable file
View File

0
facelib/detection/yolov5face/__init__.py Normal file → Executable file
View File

2
facelib/detection/yolov5face/face_detector.py Normal file → Executable file
View File

@@ -139,4 +139,4 @@ class YoloDetector:
return None return None
def __call__(self, *args): def __call__(self, *args):
return self.predict(*args) return self.predict(*args)

0
facelib/detection/yolov5face/models/__init__.py Normal file → Executable file
View File

0
facelib/detection/yolov5face/models/common.py Normal file → Executable file
View File

0
facelib/detection/yolov5face/models/experimental.py Normal file → Executable file
View File

0
facelib/detection/yolov5face/models/yolo.py Normal file → Executable file
View File

0
facelib/detection/yolov5face/models/yolov5l.yaml Normal file → Executable file
View File

0
facelib/detection/yolov5face/models/yolov5n.yaml Normal file → Executable file
View File

0
facelib/detection/yolov5face/utils/__init__.py Normal file → Executable file
View File

0
facelib/detection/yolov5face/utils/autoanchor.py Normal file → Executable file
View File

0
facelib/detection/yolov5face/utils/datasets.py Normal file → Executable file
View File

22
facelib/detection/yolov5face/utils/extract_ckpt.py Normal file → Executable file
View File

@@ -1,5 +1,21 @@
import torch import torch
import sys import sys
sys.path.insert(0,'./facelib/detection/yolov5face') import os
model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth') # Setup dynamic device selection
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
sys.path.insert(0, './facelib/detection/yolov5face')
# Load the model to the selected device
ckpt = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location=device)
model = ckpt['model'].to(device)
# Save only the weights
os.makedirs('weights/facelib', exist_ok=True)
torch.save(model.state_dict(), 'weights/facelib/yolov5n-face.pth')

0
facelib/detection/yolov5face/utils/general.py Normal file → Executable file
View File

0
facelib/detection/yolov5face/utils/torch_utils.py Normal file → Executable file
View File

0
facelib/parsing/__init__.py Normal file → Executable file
View File

0
facelib/parsing/bisenet.py Normal file → Executable file
View File

0
facelib/parsing/parsenet.py Normal file → Executable file
View File

0
facelib/parsing/resnet.py Normal file → Executable file
View File

0
facelib/utils/__init__.py Normal file → Executable file
View File

0
facelib/utils/face_restoration_helper.py Normal file → Executable file
View File

0
facelib/utils/face_utils.py Normal file → Executable file
View File

0
facelib/utils/misc.py Normal file → Executable file
View File

0
icon.png Normal file → Executable file
View File

Before

Width:  |  Height:  |  Size: 46 KiB

After

Width:  |  Height:  |  Size: 46 KiB

70
litserver_main.py Executable file
View File

@@ -0,0 +1,70 @@
import time
import cv2
import litserve as ls
from pydantic import BaseModel
from refacer_no_path import Refacer as NoPathRefacer
from utils.minio_client import oss_get_image, minio_client, oss_upload_image
class PredictRequest(BaseModel):
    """Request body for the face-swap endpoint."""
    input_image_list: list[str]  # images whose faces are to be swapped
    input_face: str  # target face image
    threshold: float = 0.2  # similarity threshold, max 0.5
class InferencePipeline(ls.LitAPI):
    """LitServe API that swaps one target face into a list of stored images.

    Flow per request: ``decode_request`` fetches the images through ``oss_get_image``,
    ``predict`` refaces and uploads each supported image, and ``encode_response`` wraps
    the list of uploaded object locations.
    """

    def setup(self, device):
        """Create the refacer and the set of accepted image extensions.

        Args:
            device: Supplied by LitServe; not used here (the refacer is built with
                ``force_cpu=False`` and picks its own backend).
        """
        force_cpu = False
        colab_performance = False
        self.supported_exts = {'jpg', 'jpeg', 'png', 'bmp', 'webp'}
        self.refacer = NoPathRefacer(force_cpu=force_cpu, colab_performance=colab_performance)

    @staticmethod
    def _extension(path):
        """Return the lower-cased text after the last dot of ``path``, or '' when there is no dot."""
        _, dot, suffix = path.rpartition('.')
        return suffix.lower() if dot else ''

    def decode_request(self, request: PredictRequest):
        """Fetch the request's images and build the face configuration for the refacer.

        Args:
            request: Parsed request body.

        Returns:
            A one-element list of face-config dicts (no origin face, the fetched target
            face, its path and the requested threshold); LitServe passes it to ``predict``.
        """
        # NOTE(review): the fetched images are kept on ``self`` and read back in predict().
        # That is only safe while requests are handled one at a time per API instance;
        # with batching or concurrent decoding this would be overwritten - confirm the
        # server configuration before enabling either.
        self.input_image_list = [
            {
                'img_obj': oss_get_image(oss_client=minio_client, path=path, data_type="cv2"),
                'img_path': path
            }
            for path in request.input_image_list
        ]
        dest_img = oss_get_image(oss_client=minio_client, path=request.input_face, data_type="cv2")
        faces_config = [
            {
                'origin': None,
                'destination': dest_img,
                'destination_path': request.input_face,
                'threshold': request.threshold,
            }
        ]
        self.refacer.prepare_faces(faces_config)
        return faces_config

    def predict(self, faces_config):
        """Reface every supported image of the current request and upload the results.

        Images with a missing or unsupported extension are skipped. A failure on one
        image is reported and does not stop the remaining images (best effort).

        Args:
            faces_config: The list returned by ``decode_request``.

        Returns:
            List of ``"<bucket>/<object_name>"`` strings, one per uploaded image.
        """
        refaced_images_url = []
        for image in self.input_image_list:
            img_path = image['img_path']
            # A path without any dot yields '' here and is skipped instead of raising.
            ext = self._extension(img_path)
            if ext not in self.supported_exts:
                print(f"Skipping non-image file: {img_path}")
                continue

            print(f"Refacing: {img_path}")
            try:
                refaced_image = self.refacer.reface_image(image['img_obj'], faces_config, disable_similarity=True)
                # Channel order is swapped before encoding (presumably reface_image returns
                # RGB while cv2.imencode expects BGR - confirm against refacer_no_path).
                encodable = cv2.cvtColor(refaced_image, cv2.COLOR_RGB2BGR)
                # Encode in the same format as the extension used in the object name so
                # the stored bytes match the name.
                encoded_ok, buffer = cv2.imencode(f'.{ext}', encodable)
                if not encoded_ok:
                    raise ValueError(f"cv2.imencode failed for extension '.{ext}'")
                req = oss_upload_image(oss_client=minio_client, bucket="lanecarford", object_name=f"refaced_image/refaced{time.time()}.{ext}", image_bytes=buffer.tobytes())
                refaced_images_url.append(f"{req.bucket_name}/{req.object_name}")
                print(f"Saved -> {req.bucket_name}/{req.object_name}")
            except Exception as e:
                # Deliberate best effort: report and continue with the next image.
                print(f"Failed to process {img_path}: {e}")
        return refaced_images_url

    def encode_response(self, output):
        """Wrap the list of uploaded object locations in the response payload."""
        return {"output": output}
if __name__ == '__main__':
    # Start the LitServe HTTP server for the pipeline on port 8000
    # (the port that docker-compose.yml publishes as 10071 on the host).
    api = InferencePipeline()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)

0
output/.gitkeep Normal file → Executable file
View File

0
recognition/arcface_onnx.py Normal file → Executable file
View File

0
recognition/face_align.py Normal file → Executable file
View File

0
recognition/main.py Normal file → Executable file
View File

36
recognition/scrfd.py Normal file → Executable file
View File

@@ -269,32 +269,45 @@ class SCRFD:
return det, kpss return det, kpss
def autodetect(self, img, max_num=0, metric='max'): def autodetect(self, img, max_num=0, metric='max'):
bboxes, kpss = self.detect(img, input_size=(640, 640), thresh=0.5) if self.session.get_providers()[0] == 'CoreMLExecutionProvider':
bboxes2, kpss2 = self.detect(img, input_size=(128, 128), thresh=0.5) # Cache the CPU-based detector
if not hasattr(self, '_cpu_fallback_detector'):
model_path = self.model_file
cpu_session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
self._cpu_fallback_detector = SCRFD(model_file=model_path, session=cpu_session)
self._cpu_fallback_detector.prepare(0, input_size=(640, 640))
detector = self._cpu_fallback_detector
else:
detector = self # Use the original GPU/CoreML session
bboxes, kpss = detector.detect(img, input_size=(640, 640), thresh=0.5)
bboxes2, kpss2 = detector.detect(img, input_size=(128, 128), thresh=0.5)
bboxes_all = np.concatenate([bboxes, bboxes2], axis=0) bboxes_all = np.concatenate([bboxes, bboxes2], axis=0)
kpss_all = np.concatenate([kpss, kpss2], axis=0) kpss_all = np.concatenate([kpss, kpss2], axis=0)
keep = self.nms(bboxes_all) keep = self.nms(bboxes_all)
det = bboxes_all[keep,:] det = bboxes_all[keep, :]
kpss = kpss_all[keep,:] kpss = kpss_all[keep, :]
if max_num > 0 and det.shape[0] > max_num: if max_num > 0 and det.shape[0] > max_num:
area = (det[:, 2] - det[:, 0]) * (det[:, 3] - area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
det[:, 1])
img_center = img.shape[0] // 2, img.shape[1] // 2 img_center = img.shape[0] // 2, img.shape[1] // 2
offsets = np.vstack([ offsets = np.vstack([
(det[:, 0] + det[:, 2]) / 2 - img_center[1], (det[:, 0] + det[:, 2]) / 2 - img_center[1],
(det[:, 1] + det[:, 3]) / 2 - img_center[0] (det[:, 1] + det[:, 3]) / 2 - img_center[0]
]) ])
offset_dist_squared = np.sum(np.power(offsets, 2.0), 0) offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
if metric=='max': if metric == 'max':
values = area values = area
else: else:
values = area - offset_dist_squared * 2.0 # some extra weight on the centering values = area - offset_dist_squared * 2.0
bindex = np.argsort( bindex = np.argsort(values)[::-1]
values)[::-1] # some extra weight on the centering
bindex = bindex[0:max_num] bindex = bindex[0:max_num]
det = det[bindex, :] det = det[bindex, :]
if kpss is not None: if kpss is not None:
kpss = kpss[bindex, :] kpss = kpss[bindex, :]
return det, kpss return det, kpss
def nms(self, dets): def nms(self, dets):
@@ -325,5 +338,4 @@ class SCRFD:
inds = np.where(ovr <= thresh)[0] inds = np.where(ovr <= thresh)[0]
order = order[inds + 1] order = order[inds + 1]
return keep return keep

178
refacer.py Normal file → Executable file
View File

@@ -1,6 +1,9 @@
import cv2 import cv2
import onnxruntime as rt import onnxruntime as rt
import sys import sys
from utils.minio_client import oss_get_image, minio_client
sys.path.insert(1, './recognition') sys.path.insert(1, './recognition')
from scrfd import SCRFD from scrfd import SCRFD
from arcface_onnx import ArcFaceONNX from arcface_onnx import ArcFaceONNX
@@ -40,9 +43,11 @@ if sys.platform in ("win32", "win64"):
if hasattr(rt, "preload_dlls"): if hasattr(rt, "preload_dlls"):
rt.preload_dlls() rt.preload_dlls()
class RefacerMode(Enum): class RefacerMode(Enum):
CPU, CUDA, COREML, TENSORRT = range(1, 5) CPU, CUDA, COREML, TENSORRT = range(1, 5)
class Refacer: class Refacer:
def __init__(self, force_cpu=False, colab_performance=False): def __init__(self, force_cpu=False, colab_performance=False):
self.disable_similarity = False self.disable_similarity = False
@@ -55,46 +60,45 @@ class Refacer:
self.__check_providers() self.__check_providers()
self.total_mem = psutil.virtual_memory().total self.total_mem = psutil.virtual_memory().total
self.__init_apps() self.__init_apps()
def _partial_face_blend(self, original_frame, swapped_frame, face): def _partial_face_blend(self, original_frame, swapped_frame, face):
h_frame, w_frame = original_frame.shape[:2] h_frame, w_frame = original_frame.shape[:2]
x1, y1, x2, y2 = map(int, face.bbox) x1, y1, x2, y2 = map(int, face.bbox)
x1 = max(0, min(x1, w_frame-1)) x1 = max(0, min(x1, w_frame - 1))
y1 = max(0, min(y1, h_frame-1)) y1 = max(0, min(y1, h_frame - 1))
x2 = max(0, min(x2, w_frame)) x2 = max(0, min(x2, w_frame))
y2 = max(0, min(y2, h_frame)) y2 = max(0, min(y2, h_frame))
if x2 <= x1 or y2 <= y1: if x2 <= x1 or y2 <= y1:
print(f"Invalid bbox: {x1},{y1},{x2},{y2}") print(f"Invalid bbox: {x1},{y1},{x2},{y2}")
return swapped_frame return swapped_frame
w = x2 - x1 w = x2 - x1
h = y2 - y1 h = y2 - y1
cutoff = int(h * (1.0 - self.blend_height_ratio)) cutoff = int(h * (1.0 - self.blend_height_ratio))
swap_crop = swapped_frame[y1:y2, x1:x2].copy() swap_crop = swapped_frame[y1:y2, x1:x2].copy()
orig_crop = original_frame[y1:y2, x1:x2].copy() orig_crop = original_frame[y1:y2, x1:x2].copy()
mask = np.ones((h, w, 3), dtype=np.float32) mask = np.ones((h, w, 3), dtype=np.float32)
transition = 40 transition = 40
if cutoff < h: if cutoff < h:
blend_start = max(cutoff - transition // 2, 0) blend_start = max(cutoff - transition // 2, 0)
blend_end = min(cutoff + transition // 2, h) blend_end = min(cutoff + transition // 2, h)
if blend_end > blend_start: if blend_end > blend_start:
alpha = np.linspace(1.0, 0.0, blend_end - blend_start)[:, np.newaxis, np.newaxis] alpha = np.linspace(1.0, 0.0, blend_end - blend_start)[:, np.newaxis, np.newaxis]
mask[blend_start:blend_end, :, :] = alpha mask[blend_start:blend_end, :, :] = alpha
mask[blend_end:, :, :] = 0.0 mask[blend_end:, :, :] = 0.0
blended_crop = (swap_crop.astype(np.float32) * mask + orig_crop.astype(np.float32) * (1.0 - mask)).astype(np.uint8) blended_crop = (swap_crop.astype(np.float32) * mask + orig_crop.astype(np.float32) * (1.0 - mask)).astype(np.uint8)
blended_frame = swapped_frame.copy() blended_frame = swapped_frame.copy()
blended_frame[y1:y2, x1:x2] = blended_crop blended_frame[y1:y2, x1:x2] = blended_crop
return blended_frame return blended_frame
def __download_with_progress(self, url, output_path): def __download_with_progress(self, url, output_path):
response = requests.get(url, stream=True) response = requests.get(url, stream=True)
@@ -238,7 +242,7 @@ class Refacer:
faces = self.__get_faces(frame, max_num=0) faces = self.__get_faces(frame, max_num=0)
if not faces: if not faces:
return frame return frame
if self.disable_similarity: if self.disable_similarity:
for face in faces: for face in faces:
swapped = self.face_swapper.get(frame, face, self.replacement_faces[0][1], paste_back=True) swapped = self.face_swapper.get(frame, face, self.replacement_faces[0][1], paste_back=True)
@@ -253,9 +257,9 @@ class Refacer:
faces = self.__get_faces(frame, max_num=0) faces = self.__get_faces(frame, max_num=0)
if not faces: if not faces:
return frame return frame
faces = sorted(faces, key=lambda face: face.bbox[0]) faces = sorted(faces, key=lambda face: face.bbox[0])
if self.multiple_faces_mode: if self.multiple_faces_mode:
for idx, face in enumerate(faces): for idx, face in enumerate(faces):
if idx >= len(self.replacement_faces): if idx >= len(self.replacement_faces):
@@ -309,33 +313,33 @@ class Refacer:
original_name = osp.splitext(osp.basename(video_path))[0] original_name = osp.splitext(osp.basename(video_path))[0]
timestamp = str(int(time.time())) timestamp = str(int(time.time()))
filename = f"{original_name}_preview.mp4" if preview else f"{original_name}_{timestamp}.mp4" filename = f"{original_name}_preview.mp4" if preview else f"{original_name}_{timestamp}.mp4"
self.__check_video_has_audio(video_path) self.__check_video_has_audio(video_path)
if preview: if preview:
os.makedirs("output/preview", exist_ok=True) os.makedirs("output/preview", exist_ok=True)
output_video_path = os.path.join('output', 'preview', filename) output_video_path = os.path.join('output', 'preview', filename)
else: else:
os.makedirs("output", exist_ok=True) os.makedirs("output", exist_ok=True)
output_video_path = os.path.join('output', filename) output_video_path = os.path.join('output', filename)
self.prepare_faces(faces, disable_similarity=disable_similarity, multiple_faces_mode=multiple_faces_mode) self.prepare_faces(faces, disable_similarity=disable_similarity, multiple_faces_mode=multiple_faces_mode)
self.first_face = False if multiple_faces_mode else (faces[0].get("origin") is None or disable_similarity) self.first_face = False if multiple_faces_mode else (faces[0].get("origin") is None or disable_similarity)
self.partial_reface_ratio = partial_reface_ratio self.partial_reface_ratio = partial_reface_ratio
cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG) cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v') fourcc = cv2.VideoWriter_fourcc(*'mp4v')
output = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height)) output = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
frames = [] frames = []
frame_index = 0 frame_index = 0
skip_rate = 10 if preview else 1 skip_rate = 10 if preview else 1
with tqdm(total=total_frames, desc="Extracting frames") as pbar: with tqdm(total=total_frames, desc="Extracting frames") as pbar:
while cap.isOpened(): while cap.isOpened():
flag, frame = cap.read() flag, frame = cap.read()
@@ -349,28 +353,24 @@ class Refacer:
gc.collect() gc.collect()
frame_index += 1 frame_index += 1
pbar.update() pbar.update()
cap.release() cap.release()
if frames: if frames:
self.reface_group(faces, frames, output) self.reface_group(faces, frames, output)
output.release() output.release()
converted_path = self.__convert_video(video_path, output_video_path, preview=preview) converted_path = self.__convert_video(video_path, output_video_path, preview=preview)
if video_path.lower().endswith(".gif"): if video_path.lower().endswith(".gif"):
if preview: if preview:
gif_output_path = os.path.join("output", "preview", os.path.basename(converted_path).replace(".mp4", ".gif")) gif_output_path = os.path.join("output", "preview", os.path.basename(converted_path).replace(".mp4", ".gif"))
else: else:
gif_output_path = os.path.join("output", "gifs", os.path.basename(converted_path).replace(".mp4", ".gif")) gif_output_path = os.path.join("output", "gifs", os.path.basename(converted_path).replace(".mp4", ".gif"))
self.__generate_gif(converted_path, gif_output_path) self.__generate_gif(converted_path, gif_output_path)
return converted_path, gif_output_path return converted_path, gif_output_path
return converted_path, None
return converted_path, None
def __generate_gif(self, video_path, gif_output_path): def __generate_gif(self, video_path, gif_output_path):
os.makedirs(os.path.dirname(gif_output_path), exist_ok=True) os.makedirs(os.path.dirname(gif_output_path), exist_ok=True)
@@ -396,60 +396,64 @@ class Refacer:
return new_path return new_path
def reface_image(self, image_path, faces, disable_similarity=False, multiple_faces_mode=False, partial_reface_ratio=0.0): def reface_image(self, image_path, faces, disable_similarity=False, multiple_faces_mode=False, partial_reface_ratio=0.0):
self.prepare_faces(faces, disable_similarity=disable_similarity, multiple_faces_mode=multiple_faces_mode) self.prepare_faces(faces, disable_similarity=disable_similarity, multiple_faces_mode=multiple_faces_mode)
self.first_face = False if multiple_faces_mode else (faces[0].get("origin") is None or disable_similarity) self.first_face = False if multiple_faces_mode else (faces[0].get("origin") is None or disable_similarity)
self.partial_reface_ratio = partial_reface_ratio self.partial_reface_ratio = partial_reface_ratio
ext = osp.splitext(image_path)[1].lower()
os.makedirs("output", exist_ok=True)
original_name = osp.splitext(osp.basename(image_path))[0]
timestamp = str(int(time.time()))
if ext in ['.tif', '.tiff']:
pil_img = Image.open(image_path)
frames = []
page_count = 0
try:
while True:
pil_img.seek(page_count)
page_count += 1
except EOFError:
pass
pil_img = Image.open(image_path)
with tqdm(total=page_count, desc="Processing TIFF pages") as pbar:
for page in range(page_count):
pil_img.seek(page)
bgr_image = cv2.cvtColor(np.array(pil_img.convert('RGB')), cv2.COLOR_RGB2BGR)
refaced_bgr = self.process_first_face(bgr_image.copy()) if self.first_face else self.process_faces(bgr_image.copy())
enhanced_bgr = enhance_image_memory(refaced_bgr)
enhanced_rgb = cv2.cvtColor(enhanced_bgr, cv2.COLOR_BGR2RGB)
enhanced_pil = Image.fromarray(enhanced_rgb)
frames.append(enhanced_pil)
pbar.update(1)
output_path = os.path.join("output", f"{original_name}_{timestamp}.tif")
frames[0].save(output_path, save_all=True, append_images=frames[1:], compression="tiff_deflate")
print(f"Saved multipage refaced TIFF to {output_path}")
return output_path
else:
bgr_image = cv2.imread(image_path)
if bgr_image is None:
raise ValueError("Failed to read input image")
refaced_bgr = self.process_first_face(bgr_image.copy()) if self.first_face else self.process_faces(bgr_image.copy())
refaced_rgb = cv2.cvtColor(refaced_bgr, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(refaced_rgb)
filename = f"{original_name}_{timestamp}.jpg"
output_path = os.path.join("output", filename)
pil_img.save(output_path, format='JPEG', quality=100, subsampling=0)
output_path = enhance_image(output_path)
print(f"Saved refaced image to {output_path}")
return output_path
ext = osp.splitext(image_path)[1].lower() #
# ext = image_path.rsplit('.',1)[1].lower()
os.makedirs("output", exist_ok=True) #
original_name = osp.splitext(osp.basename(image_path))[0]
timestamp = str(int(time.time()))
if ext in ['.tif', '.tiff']:
pil_img = Image.open(image_path) #
# pil_img = oss_get_image(oss_client=minio_client, path=image_path, data_type="PIL")
frames = []
page_count = 0
try:
while True:
pil_img.seek(page_count)
page_count += 1
except EOFError:
pass
pil_img = Image.open(image_path) #
# pil_img = oss_get_image(oss_client=minio_client, path=image_path, data_type="PIL")
with tqdm(total=page_count, desc="Processing TIFF pages") as pbar:
for page in range(page_count):
pil_img.seek(page)
bgr_image = cv2.cvtColor(np.array(pil_img.convert('RGB')), cv2.COLOR_RGB2BGR)
refaced_bgr = self.process_first_face(bgr_image.copy()) if self.first_face else self.process_faces(bgr_image.copy())
enhanced_bgr = enhance_image_memory(refaced_bgr)
enhanced_rgb = cv2.cvtColor(enhanced_bgr, cv2.COLOR_BGR2RGB)
enhanced_pil = Image.fromarray(enhanced_rgb)
frames.append(enhanced_pil)
pbar.update(1)
output_path = os.path.join("output", f"{original_name}_{timestamp}.tif")
frames[0].save(output_path, save_all=True, append_images=frames[1:], compression="tiff_deflate")
print(f"Saved multipage refaced TIFF to {output_path}")
return output_path
else:
bgr_image = cv2.imread(image_path) #
# bgr_image = oss_get_image(oss_client=minio_client, path=image_path, data_type="cv2")
if bgr_image is None:
raise ValueError("Failed to read input image")
refaced_bgr = self.process_first_face(bgr_image.copy()) if self.first_face else self.process_faces(bgr_image.copy())
refaced_rgb = cv2.cvtColor(refaced_bgr, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(refaced_rgb)
filename = f"{original_name}_{timestamp}.jpg"
output_path = os.path.join("output", filename)
pil_img.save(output_path, format='JPEG', quality=100, subsampling=0)
output_path = enhance_image(output_path)
print(f"Saved refaced image to {output_path}")
return output_path
def extract_faces_from_image(self, image_path, max_faces=5): def extract_faces_from_image(self, image_path, max_faces=5):
frame = cv2.imread(image_path) frame = cv2.imread(image_path)
@@ -508,4 +512,4 @@ class Refacer:
'h264_videotoolbox': '0', 'h264_videotoolbox': '0',
'h264_nvenc': '0', 'h264_nvenc': '0',
'libx264': '0' 'libx264': '0'
} }

0
refacer_bulk.py Normal file → Executable file
View File

82
refacer_bulk_no_path.py Executable file
View File

@@ -0,0 +1,82 @@
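# refacer_bulk_no_path.py
#
# Object-storage variant of refacer_bulk.py: images are fetched and uploaded
# through utils.minio_client instead of being read from / written to local paths.
# The input object paths, destination face, face to replace and threshold are
# hard-coded in main(); this script does not parse command-line arguments.
#
# Example usage:
# python refacer_bulk_no_path.py
#
# For reference, the path-based refacer_bulk.py is invoked as:
# python refacer_bulk.py --input_path ./input --dest_face myface.jpg --facetoreplace face1.jpg --threshold 0.3
#
# Or, to disable similarity check (i.e., just apply the destination face to all detected faces):
# python refacer_bulk.py --input_path ./input --dest_face myface.jpg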
import argparse
import os
import time
import cv2
from PIL import Image
from refacer_no_path import Refacer as NoPathRefacer
import pyfiglet
from utils.minio_client import oss_get_image, minio_client, oss_upload_image
def main():
    """Reface a hard-coded list of object-storage images and upload the results.

    The destination face (and, optionally, the face to replace) is fetched with
    ``oss_get_image``; each supported input image is fetched the same way,
    refaced in memory, JPEG-encoded and uploaded with ``oss_upload_image``.

    Returns:
        list[str]: ``"<bucket>/<object_name>"`` for every image uploaded.

    Raises:
        ValueError: If the destination face, or a configured face to replace,
            cannot be loaded.
    """
    input_path = [
        "lanecarford/original_image/7450d6e8-bc54-4c85-940c-4a31c879e02f-0-89.png",
        "lanecarford-outfits/outfits/outfit_6420.jpg",
        "lanecarford-outfits/outfits/outfit_7579.jpg"
    ]
    dest_face = "lanecarford/input_face/leijun.jpg"
    facetoreplace = ""
    threshold = 0.2
    force_cpu = False
    colab_performance = False

    input_dir = input_path
    refacer = NoPathRefacer(force_cpu=force_cpu, colab_performance=colab_performance)

    # Load destination and origin face
    dest_img = oss_get_image(oss_client=minio_client, path=dest_face, data_type="cv2")
    if dest_img is None:
        raise ValueError(f"Destination face image not found: {dest_face}")

    origin_img = None
    if facetoreplace:
        origin_img = oss_get_image(oss_client=minio_client, path=facetoreplace, data_type="cv2")
        if origin_img is None:
            raise ValueError(f"Face to replace image not found: {facetoreplace}")

    # Without a face to replace there is nothing to compare against, so the
    # destination face is applied to every detected face.
    disable_similarity = origin_img is None
    faces_config = [{
        'origin': origin_img,
        'destination': dest_img,
        'threshold': threshold
    }]
    # Validates the destination/origin faces once, before any input is fetched.
    refacer.prepare_faces(faces_config, disable_similarity=disable_similarity)

    print(f"Processing images from: {input_dir}")
    image_files = list(input_dir)
    supported_exts = {'jpg', 'jpeg', 'png', 'bmp', 'webp'}
    refaced_images_url = []

    for image_path in image_files:
        # rpartition never raises on a path without a dot; such a path simply
        # gets an empty extension and is skipped below.
        _, dot, suffix = image_path.rpartition(".")
        ext = suffix.lower() if dot else ""
        if ext not in supported_exts:
            print(f"Skipping non-image file: {image_path}")
            continue

        print(f"Refacing: {image_path}")
        try:
            # reface_image() works on a decoded BGR array, not on a path, so
            # the image has to be fetched from object storage first.
            bgr_image = oss_get_image(oss_client=minio_client, path=image_path, data_type="cv2")
            if bgr_image is None:
                raise ValueError(f"Input image not found: {image_path}")

            refaced_image = refacer.reface_image(bgr_image, faces_config, disable_similarity=disable_similarity)
            # NOTE(review): assumes reface_image() returns an RGB array, hence
            # the channel swap before encoding - confirm against enhance_image().
            refaced_bgr = cv2.cvtColor(refaced_image, cv2.COLOR_RGB2BGR)
            # NOTE(review): the payload is always JPEG while the object name
            # keeps the input extension (e.g. ".png") - confirm this is intended.
            image_bytes = cv2.imencode('.jpg', refaced_bgr)[1].tobytes()
            req = oss_upload_image(oss_client=minio_client, bucket="lanecarford", object_name=f"refaced_image/refaced{time.time()}.{ext}", image_bytes=image_bytes)
            refaced_images_url.append(f"{req.bucket_name}/{req.object_name}")
            print(f"Saved -> {req.bucket_name}/{req.object_name}")
        except Exception as e:
            # Best effort: one bad image must not abort the rest of the batch.
            print(f"Failed to process {image_path}: {e}")

    return refaced_images_url


if __name__ == "__main__":
    main()

466
refacer_no_path.py Executable file
View File

@@ -0,0 +1,466 @@
import cv2
import onnxruntime as rt
import sys
from utils.minio_client import oss_get_image, minio_client
sys.path.insert(1, './recognition')
from scrfd import SCRFD
from arcface_onnx import ArcFaceONNX
import os.path as osp
import os
import requests
from tqdm import tqdm
import ffmpeg
import random
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor
from insightface.model_zoo.inswapper import INSwapper
import psutil
from enum import Enum
from insightface.app.common import Face
from insightface.utils.storage import ensure_available
import re
import subprocess
from PIL import Image
import numpy as np
import time
from codeformer_wrapper_no_path import enhance_image, enhance_image_memory
import tempfile
gc = __import__('gc')
# On Windows, register the CUDA / cuDNN DLL directories and let ONNX Runtime
# preload its DLLs, when the interpreter and the runtime offer those hooks.
if sys.platform in ("win32", "win64"):
    if hasattr(os, "add_dll_directory"):
        try:
            # A failure on the first directory skips the second one as well.
            for _dll_dir in (
                r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin",
                r"C:\Program Files\NVIDIA\CUDNN\v9.4\bin\12.6",
            ):
                os.add_dll_directory(_dll_dir)
        except Exception as e:
            print(f"[INFO] Failed to add CUDA or CUDNN DLL directory: {e}")
            print("[INFO] This error can be ignored if running in CPU mode. Otherwise, make sure the paths are correct.")
    if hasattr(rt, "preload_dlls"):
        rt.preload_dlls()
class RefacerMode(Enum):
    """Execution mode selected by ``Refacer`` after probing ONNX Runtime providers."""

    CPU = 1
    CUDA = 2
    COREML = 3
    TENSORRT = 4
class Refacer:
    """Face-swapping pipeline for in-memory images and video files.

    Construction probes ffmpeg for a usable H.264 encoder, selects ONNX Runtime
    execution providers, and loads three ONNX models: an SCRFD face detector,
    an ArcFace recognizer and an INSwapper face swapper (downloaded on first
    use when ``weights/inswapper/inswapper_128.onnx`` is missing).
    """

    # H.264 encoders that __check_encoders() may select, mapped to the video
    # bitrate handed to ffmpeg for that encoder. Iteration order is the order
    # in which candidates are tried.
    VIDEO_CODECS = {
        'h264_videotoolbox': '0',
        'h264_nvenc': '0',
        'libx264': '0'
    }

    def __init__(self, force_cpu=False, colab_performance=False):
        """Probe the environment and load all models.

        Args:
            force_cpu: Restrict ONNX Runtime to ``CPUExecutionProvider``.
            colab_performance: When the active provider is neither CUDA nor
                CoreML, report ``RefacerMode.TENSORRT`` instead of ``CPU``.
        """
        self.disable_similarity = False
        self.multiple_faces_mode = False
        self.first_face = False
        # Defaults so the processing methods are usable before reface() /
        # reface_image() have set these.
        self.replacement_faces = []
        self.partial_reface_ratio = 0.0
        self.blend_height_ratio = 0.0
        self.force_cpu = force_cpu
        self.colab_performance = colab_performance
        self.use_num_cpus = mp.cpu_count()
        self.__check_encoders()
        self.__check_providers()
        self.total_mem = psutil.virtual_memory().total
        self.__init_apps()

    def _partial_face_blend(self, original_frame, swapped_frame, face):
        """Keep the swap in the upper part of the face box and fade to the original below.

        The split row is ``h * (1 - self.blend_height_ratio)`` inside the face
        bounding box; a linear fade (up to 40 rows, clipped to the box) is
        centred on it.

        Args:
            original_frame: Frame before swapping (H x W x 3 array).
            swapped_frame: Same frame after swapping.
            face: Object exposing ``bbox`` as ``(x1, y1, x2, y2)``.

        Returns:
            A blended copy of ``swapped_frame``, or ``swapped_frame`` itself
            when the clamped bounding box is empty.
        """
        h_frame, w_frame = original_frame.shape[:2]
        x1, y1, x2, y2 = map(int, face.bbox)
        # Clamp the box to the frame so the slices below stay in range.
        x1 = max(0, min(x1, w_frame - 1))
        y1 = max(0, min(y1, h_frame - 1))
        x2 = max(0, min(x2, w_frame))
        y2 = max(0, min(y2, h_frame))
        if x2 <= x1 or y2 <= y1:
            print(f"Invalid bbox: {x1},{y1},{x2},{y2}")
            return swapped_frame

        w = x2 - x1
        h = y2 - y1
        cutoff = int(h * (1.0 - self.blend_height_ratio))
        swap_crop = swapped_frame[y1:y2, x1:x2].copy()
        orig_crop = original_frame[y1:y2, x1:x2].copy()

        # mask == 1 keeps the swapped pixel, mask == 0 restores the original.
        mask = np.ones((h, w, 3), dtype=np.float32)
        transition = 40
        if cutoff < h:
            blend_start = max(cutoff - transition // 2, 0)
            blend_end = min(cutoff + transition // 2, h)
            if blend_end > blend_start:
                alpha = np.linspace(1.0, 0.0, blend_end - blend_start)[:, np.newaxis, np.newaxis]
                mask[blend_start:blend_end, :, :] = alpha
            mask[blend_end:, :, :] = 0.0

        blended_crop = (swap_crop.astype(np.float32) * mask + orig_crop.astype(np.float32) * (1.0 - mask)).astype(np.uint8)
        blended_frame = swapped_frame.copy()
        blended_frame[y1:y2, x1:x2] = blended_crop
        return blended_frame

    def _apply_swap(self, frame, target_face, source_face):
        """Swap ``source_face`` onto ``target_face`` in ``frame``, blending partially if configured."""
        swapped = self.face_swapper.get(frame, target_face, source_face, paste_back=True)
        ratio = getattr(self, 'partial_reface_ratio', 0.0)
        if ratio > 0.0:
            self.blend_height_ratio = ratio
            return self._partial_face_blend(frame, swapped, target_face)
        return swapped

    def __download_with_progress(self, url, output_path):
        """Stream ``url`` to ``output_path`` while showing a progress bar.

        The payload is written to ``<output_path>.part`` and only moved into
        place once complete, so a failed download never leaves a truncated
        file that a later existence check would mistake for the model.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            Exception: If fewer bytes arrive than ``content-length`` announced.
        """
        block_size = 1024
        tmp_path = output_path + ".part"
        try:
            with requests.get(url, stream=True) as response:
                # Without this an HTML error page would be saved as the model.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(tmp_path, 'wb') as f, tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(output_path)}") as t:
                    for data in response.iter_content(block_size):
                        t.update(len(data))
                        f.write(data)
                    downloaded = t.n
            if total_size != 0 and downloaded != total_size:
                raise Exception("ERROR, something went wrong downloading the model!")
            os.replace(tmp_path, output_path)
        finally:
            # Only present here when the download did not complete.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def __check_providers(self):
        """Choose execution providers and session options, and derive the run mode.

        Sets ``self.providers``, ``self.sess_options``, ``self.mode`` and
        ``self.use_num_cpus``. The active provider is whichever a throw-away
        session on the detector model reports first; if that session cannot be
        created the CPU provider is assumed.
        """
        available_providers = rt.get_available_providers()
        if self.force_cpu:
            self.providers = ['CPUExecutionProvider']
        else:
            # Prefer faster execution providers in order
            preferred = ['CoreMLExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
            self.providers = [p for p in preferred if p in available_providers]

        rt.set_default_logger_severity(4)
        self.sess_options = rt.SessionOptions()
        self.sess_options.execution_mode = rt.ExecutionMode.ORT_PARALLEL
        self.sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

        test_model = os.path.expanduser("~/.insightface/models/buffalo_l/det_10g.onnx")
        try:
            test_session = rt.InferenceSession(test_model, self.sess_options, providers=self.providers)
            active_provider = test_session.get_providers()[0]
        except Exception as e:
            print(f"[ERROR] Failed to create test session: {e}")
            active_provider = 'CPUExecutionProvider'

        if active_provider == 'CUDAExecutionProvider':
            self.mode = RefacerMode.CUDA
            self.use_num_cpus = 2
            self.sess_options.intra_op_num_threads = 1
        else:
            if active_provider == 'CoreMLExecutionProvider':
                self.mode = RefacerMode.COREML
            elif self.colab_performance:
                self.mode = RefacerMode.TENSORRT
            else:
                self.mode = RefacerMode.CPU
            # All non-CUDA modes share the same threading setup.
            self.use_num_cpus = max(mp.cpu_count() - 1, 1)
            self.sess_options.intra_op_num_threads = int(self.use_num_cpus / 2)

        print(f"Available providers: {available_providers}")
        print(f"Using providers: {self.providers}")
        print(f"Active provider: {active_provider}")
        print(f"Mode: {self.mode}")

    def __init_apps(self):
        """Create the detector, recognizer and swapper sessions.

        Raises:
            RuntimeError: If the INSwapper model is missing and cannot be downloaded.
        """
        assets_dir = ensure_available('models', 'buffalo_l', root='~/.insightface')

        model_path = os.path.join(assets_dir, 'det_10g.onnx')
        sess_face = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
        print(f"Face Detector providers: {sess_face.get_providers()}")
        self.face_detector = SCRFD(model_path, sess_face)
        self.face_detector.prepare(0, input_size=(640, 640))

        model_path = os.path.join(assets_dir, 'w600k_r50.onnx')
        sess_rec = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
        print(f"Face Recognizer providers: {sess_rec.get_providers()}")
        self.rec_app = ArcFaceONNX(model_path, sess_rec)
        self.rec_app.prepare(0)

        model_dir = os.path.join('weights', 'inswapper')
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, 'inswapper_128.onnx')
        if not os.path.exists(model_path):
            print(f"Model {model_path} not found. Downloading from HuggingFace...")
            url = "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"
            try:
                self.__download_with_progress(url, model_path)
                print(f"Downloaded {model_path}")
            except Exception as e:
                raise RuntimeError(f"Failed to download {model_path}. Error: {e}") from e

        sess_swap = rt.InferenceSession(model_path, self.sess_options, providers=self.providers)
        print(f"Face Swapper providers: {sess_swap.get_providers()}")
        self.face_swapper = INSwapper(model_path, sess_swap)

    def prepare_faces(self, faces, disable_similarity=False, multiple_faces_mode=False):
        """Build ``self.replacement_faces`` from a list of face configurations.

        Each stored entry is ``(origin_embedding_or_None, destination_face, threshold)``.

        Args:
            faces: Dicts with a ``'destination'`` image and, optionally, an
                ``'origin'`` image plus ``'threshold'``. Entries without a
                destination are skipped.
            disable_similarity: Ignore ``'origin'`` and do no similarity matching.
            multiple_faces_mode: Store destinations only, for positional matching.

        Raises:
            Exception: If no face is detected in a destination or origin image.
        """
        self.replacement_faces = []
        self.disable_similarity = disable_similarity
        self.multiple_faces_mode = multiple_faces_mode

        for face in faces:
            if "destination" not in face or face["destination"] is None:
                print("Skipping face config: No destination face provided.")
                continue

            _faces = self.__get_faces(face['destination'], max_num=1)
            if len(_faces) < 1:
                raise Exception('No face detected on "Destination face" image')

            if multiple_faces_mode:
                self.replacement_faces.append((None, _faces[0], 0.0))
                continue

            if "origin" in face and face["origin"] is not None and not disable_similarity:
                face_threshold = face['threshold']
                bboxes1, kpss1 = self.face_detector.autodetect(face['origin'], max_num=1)
                if len(kpss1) < 1:
                    raise Exception('No face detected on "Face to replace" image')
                feat_original = self.rec_app.get(face['origin'], kpss1[0])
            else:
                # No usable origin: fall back to first-face processing.
                face_threshold = 0
                self.first_face = True
                feat_original = None
            self.replacement_faces.append((feat_original, _faces[0], face_threshold))

    def __get_faces(self, frame, max_num=0):
        """Detect faces in ``frame`` and return ``Face`` objects with embeddings attached."""
        bboxes, kpss = self.face_detector.detect(frame, max_num=max_num, metric='default')
        if bboxes.shape[0] == 0:
            return []

        ret = []
        for i in range(bboxes.shape[0]):
            bbox = bboxes[i, 0:4]
            det_score = bboxes[i, 4]
            kps = kpss[i] if kpss is not None else None
            face = Face(bbox=bbox, kps=kps, det_score=det_score)
            face.embedding = self.rec_app.get(frame, kps)
            ret.append(face)
        return ret

    def process_first_face(self, frame):
        """Apply the first replacement face to every detected face.

        Swapping only happens while ``self.disable_similarity`` is set;
        otherwise the frame is returned unchanged.
        """
        faces = self.__get_faces(frame, max_num=0)
        if not faces:
            return frame

        if self.disable_similarity:
            for face in faces:
                frame = self._apply_swap(frame, face, self.replacement_faces[0][1])
        return frame

    def process_faces(self, frame):
        """Swap faces in ``frame`` according to the active matching mode.

        * multiple-faces mode: detected faces, ordered left to right, are
          paired positionally with ``self.replacement_faces``;
        * similarity disabled: the first replacement face is applied to every face;
        * otherwise: each replacement is applied to at most one detected face
          whose similarity to the origin embedding reaches the threshold.
        """
        faces = self.__get_faces(frame, max_num=0)
        if not faces:
            return frame

        faces = sorted(faces, key=lambda face: face.bbox[0])

        if self.multiple_faces_mode:
            # zip() stops at the shorter sequence, i.e. surplus faces are left alone.
            for face, replacement in zip(faces, self.replacement_faces):
                frame = self._apply_swap(frame, face, replacement[1])
        elif self.disable_similarity:
            for face in faces:
                frame = self._apply_swap(frame, face, self.replacement_faces[0][1])
        else:
            for rep_face in self.replacement_faces:
                # Walk backwards so the matched face can be deleted in place.
                for i in range(len(faces) - 1, -1, -1):
                    sim = self.rec_app.compute_sim(rep_face[0], faces[i].embedding)
                    if sim >= rep_face[2]:
                        frame = self._apply_swap(frame, faces[i], rep_face[1])
                        # A face is consumed by the first replacement that matches it.
                        del faces[i]
                        break
        return frame

    def reface_group(self, faces, frames, output):
        """Process ``frames`` on a thread pool and write the results, in order, to ``output``.

        ``faces`` is accepted for interface compatibility and is not used here.
        """
        worker = self.process_first_face if self.first_face else self.process_faces
        with ThreadPoolExecutor(max_workers=self.use_num_cpus) as executor:
            results = list(tqdm(executor.map(worker, frames), total=len(frames), desc="Processing frames"))
            for result in results:
                output.write(result)

    def __check_video_has_audio(self, video_path):
        """Set ``self.video_has_audio`` from an ffmpeg probe of ``video_path``."""
        probe = ffmpeg.probe(video_path)
        self.video_has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])

    def reface(self, video_path, faces, preview=False, disable_similarity=False, multiple_faces_mode=False, partial_reface_ratio=0.0):
        """Reface a video (or GIF) file and write the result under ``output/``.

        Args:
            video_path: Path of the input video.
            faces: Face configurations, see ``prepare_faces``.
            preview: Process only every 10th frame and write to ``output/preview``.
            disable_similarity: Forwarded to ``prepare_faces``.
            multiple_faces_mode: Forwarded to ``prepare_faces``.
            partial_reface_ratio: When > 0, blend the swap with the original
                face (see ``_partial_face_blend``).

        Returns:
            ``(video_path, gif_path)``; ``gif_path`` is ``None`` unless the
            input file name ends with ``.gif``.
        """
        original_name = osp.splitext(osp.basename(video_path))[0]
        timestamp = str(int(time.time()))
        filename = f"{original_name}_preview.mp4" if preview else f"{original_name}_{timestamp}.mp4"

        self.__check_video_has_audio(video_path)

        if preview:
            os.makedirs("output/preview", exist_ok=True)
            output_video_path = os.path.join('output', 'preview', filename)
        else:
            os.makedirs("output", exist_ok=True)
            output_video_path = os.path.join('output', filename)

        self.prepare_faces(faces, disable_similarity=disable_similarity, multiple_faces_mode=multiple_faces_mode)
        self.first_face = False if multiple_faces_mode else (faces[0].get("origin") is None or disable_similarity)
        self.partial_reface_ratio = partial_reface_ratio

        cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
        output = None
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            output = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

            frames = []
            frame_index = 0
            skip_rate = 10 if preview else 1

            with tqdm(total=total_frames, desc="Extracting frames") as pbar:
                while cap.isOpened():
                    flag, frame = cap.read()
                    if not flag:
                        break
                    if frame_index % skip_rate == 0:
                        frames.append(frame)
                    # Flush in batches so the whole video never sits in memory.
                    if len(frames) > 300:
                        self.reface_group(faces, frames, output)
                        frames = []
                        gc.collect()
                    frame_index += 1
                    pbar.update()

            if frames:
                self.reface_group(faces, frames, output)
        finally:
            # Release the capture and writer even when processing fails.
            cap.release()
            if output is not None:
                output.release()

        converted_path = self.__convert_video(video_path, output_video_path, preview=preview)

        if video_path.lower().endswith(".gif"):
            gif_subdir = "preview" if preview else "gifs"
            gif_name = os.path.basename(converted_path).replace(".mp4", ".gif")
            gif_output_path = os.path.join("output", gif_subdir, gif_name)
            self.__generate_gif(converted_path, gif_output_path)
            return converted_path, gif_output_path

        return converted_path, None

    def __generate_gif(self, video_path, gif_output_path):
        """Render ``video_path`` as a looping 10 fps, 512-px-wide GIF at ``gif_output_path``."""
        os.makedirs(os.path.dirname(gif_output_path), exist_ok=True)
        print(f"Generating GIF at {gif_output_path}")
        (
            ffmpeg
            .input(video_path)
            .output(gif_output_path, vf='fps=10,scale=512:-1:flags=lanczos', loop=0)
            .overwrite_output()
            .run(quiet=True)
        )

    def __convert_video(self, video_path, output_video_path, preview=False):
        """Mux the original audio into the refaced video when applicable.

        Re-encoding only happens when the source has audio and this is not a
        preview; otherwise ``output_video_path`` is returned unchanged.

        Returns:
            Path of the final video file.
        """
        if self.video_has_audio and not preview:
            new_path = output_video_path + str(random.randint(0, 999)) + "_c.mp4"
            in1 = ffmpeg.input(output_video_path)
            in2 = ffmpeg.input(video_path)
            out = ffmpeg.output(in1.video, in2.audio, new_path, video_bitrate=self.ffmpeg_video_bitrate, vcodec=self.ffmpeg_video_encoder)
            out.run(overwrite_output=True, quiet=True)
        else:
            new_path = output_video_path

        print(f"Refaced video saved at: {os.path.abspath(new_path)}")
        return new_path

    def reface_image(self, bgr_image, faces, disable_similarity=False, multiple_faces_mode=False, partial_reface_ratio=0.0):
        """Reface an in-memory BGR image and return the enhanced result.

        Args:
            bgr_image: Decoded image in BGR channel order.
            faces: Face configurations, see ``prepare_faces``.
            disable_similarity: Forwarded to ``prepare_faces``.
            multiple_faces_mode: Forwarded to ``prepare_faces``.
            partial_reface_ratio: When > 0, blend the swap with the original face.

        Returns:
            Whatever ``enhance_image`` returns for the refaced RGB image.

        Raises:
            ValueError: If ``bgr_image`` is ``None``.
        """
        self.prepare_faces(faces, disable_similarity=disable_similarity, multiple_faces_mode=multiple_faces_mode)
        self.first_face = False if multiple_faces_mode else (faces[0].get("origin") is None or disable_similarity)
        self.partial_reface_ratio = partial_reface_ratio

        if bgr_image is None:
            raise ValueError("Failed to read input image")

        # Work on a copy so the caller's array is never modified.
        working = bgr_image.copy()
        refaced_bgr = self.process_first_face(working) if self.first_face else self.process_faces(working)
        refaced_rgb = cv2.cvtColor(refaced_bgr, cv2.COLOR_BGR2RGB)
        output_image = enhance_image(refaced_rgb)
        return output_image

    def extract_faces_from_image(self, image_path, max_faces=5):
        """Crop detected faces from an image file into PNG files under ``./tmp``.

        Args:
            image_path: Path of the image to read with OpenCV.
            max_faces: Maximum number of faces to detect and crop.

        Returns:
            list[str]: Paths of the temporary PNG files, one per cropped face.

        Raises:
            ValueError: If the image cannot be read.
        """
        frame = cv2.imread(image_path)
        if frame is None:
            raise ValueError("Failed to read input image for face extraction.")

        faces = self.__get_faces(frame, max_num=max_faces)
        # NamedTemporaryFile(dir=...) fails when the directory does not exist.
        os.makedirs("./tmp", exist_ok=True)

        cropped_faces = []
        for face in faces:
            x1, y1, x2, y2 = map(int, face.bbox)
            x1 = max(x1, 0)
            y1 = max(y1, 0)
            x2 = min(x2, frame.shape[1])
            y2 = min(y2, frame.shape[0])
            cropped = frame[y1:y2, x1:x2]
            if cropped.size == 0:
                # A box that clamps to nothing cannot be colour-converted or saved.
                continue

            pil_img = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
            # Reserve a unique name and close the handle before writing to it.
            with tempfile.NamedTemporaryFile(delete=False, dir="./tmp", suffix=".png") as temp_file:
                temp_path = temp_file.name
            pil_img.save(temp_path)
            cropped_faces.append(temp_path)

            if len(cropped_faces) >= max_faces:
                break
        return cropped_faces

    def __try_ffmpeg_encoder(self, vcodec):
        """Return ``True`` if ffmpeg can encode a one-second test clip with ``vcodec``.

        The clip is written to ``testsrc.mp4`` in the working directory.
        """
        command = ['ffmpeg', '-y', '-f', 'lavfi', '-i', 'testsrc=duration=1:size=1280x720:rate=30', '-vcodec', vcodec, 'testsrc.mp4']
        try:
            subprocess.run(command, check=True, capture_output=True)
        except subprocess.CalledProcessError:
            return False
        return True

    def __check_encoders(self):
        """Pick the ffmpeg H.264 encoder and bitrate used by ``__convert_video``.

        Defaults to ``libx264``; it is replaced by the first ``VIDEO_CODECS``
        entry that ffmpeg both lists for an H.264 codec line and can actually
        encode with.
        """
        self.ffmpeg_video_encoder = 'libx264'
        self.ffmpeg_video_bitrate = '0'

        pattern = r"encoders: ([a-zA-Z0-9_]+(?: [a-zA-Z0-9_]+)*)"
        command = ['ffmpeg', '-codecs', '--list-encoders']
        commandout = subprocess.run(command, check=True, capture_output=True).stdout
        result = commandout.decode('utf-8').split('\n')
        for r in result:
            if "264" not in r:
                continue
            encoders = re.search(pattern, r)
            if not encoders:
                continue
            listed = encoders.group(1).split(' ')
            for v_c in Refacer.VIDEO_CODECS:
                for v_k in listed:
                    if v_c == v_k and self.__try_ffmpeg_encoder(v_k):
                        self.ffmpeg_video_encoder = v_k
                        self.ffmpeg_video_bitrate = Refacer.VIDEO_CODECS[v_k]
                        return

Some files were not shown because too many files have changed in this diff Show More