Files
LC_NeoRefacer/litserver_main.py
2025-12-23 11:06:56 +08:00

73 lines
2.8 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import time
import uuid

import cv2
import litserve as ls
from pydantic import BaseModel

from refacer_no_path import Refacer as NoPathRefacer
from utils.minio_client import oss_get_image, minio_client, oss_upload_image
class PredictRequest(BaseModel):
    """Request body accepted by the face-swap endpoint.

    Attributes:
        input_image_list: Object paths of the images whose faces will be swapped.
        input_face: Object path of the image that provides the face to swap in.
        threshold: Similarity threshold copied into the face config handed to the
            refacer (the original note gives 0.5 as its maximum).
    """

    input_image_list: list[str]
    input_face: str
    threshold: float = 0.2
class InferencePipeline(ls.LitAPI):
    """LitServe API that swaps one target face into a list of images stored in MinIO.

    Per request, LitServe calls:
      1. ``decode_request`` - downloads the target face and every supported input
         image, then prepares the refacer with the target face.
      2. ``predict`` - refaces each input image, encodes it in the format named by
         its extension and uploads it; images that fail are skipped.
      3. ``encode_response`` - wraps the list of uploaded ``bucket/object`` paths.
    """

    def setup(self, device):
        """Create the face-swap model once per worker.

        Args:
            device: Device selected by LitServe. Not used here; the Refacer is
                configured only through ``force_cpu``.
        """
        force_cpu = False
        colab_performance = False
        # Lower-case extensions (without the dot) of files that will be refaced.
        self.supported_exts = {'jpg', 'jpeg', 'png', 'bmp', 'webp'}
        self.refacer = NoPathRefacer(force_cpu=force_cpu, colab_performance=colab_performance)

    @staticmethod
    def _extension(path):
        """Return the lower-cased extension of the last path component, or "" if it has none."""
        # Look only at the file name, so a dot in a directory name is not taken as an extension.
        filename = path.rsplit("/", 1)[-1]
        _, dot, ext = filename.rpartition(".")
        return ext.lower() if dot else ""

    @staticmethod
    def _to_bgr(img):
        """Return ``img`` as a 3-channel image: drop an alpha channel or expand grayscale."""
        if img.ndim == 2 or img.shape[2] == 1:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        if img.shape[2] == 4:
            # Same numeric code as the former COLOR_RGBA2RGB: both just drop the 4th channel.
            return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        return img

    def decode_request(self, request: PredictRequest):
        """Download the request's images and prepare the refacer with the target face.

        Side effects:
            Sets ``self.input_image_list`` (a list of ``{'img_obj', 'img_path'}``
            dicts read by ``predict``) and calls ``self.refacer.prepare_faces``.

        Returns:
            The ``faces_config`` list that LitServe passes on to ``predict``.

        Raises:
            ValueError: If the target face image could not be loaded.
        """
        # NOTE(review): per-request data lives on ``self`` and on the shared refacer.
        # That is only safe while decode_request and predict run back-to-back for a
        # single request (no request batching) - confirm the server configuration.
        self.input_image_list = []
        for path in request.input_image_list:
            if self._extension(path) not in self.supported_exts:
                # Filter before downloading; predict would discard these anyway.
                print(f"Skipping non-image file: {path}")
                continue
            self.input_image_list.append({
                'img_obj': oss_get_image(oss_client=minio_client, path=path, data_type="cv2"),
                'img_path': path
            })

        dest_img = oss_get_image(oss_client=minio_client, path=request.input_face, data_type="cv2")
        # Presumably None when the object cannot be decoded (cv2 decoders return None);
        # fail with a clear message instead of an AttributeError on ``.ndim``.
        if dest_img is None:
            raise ValueError(f"Could not load target face image: {request.input_face}")
        dest_img = self._to_bgr(dest_img)

        faces_config = [
            {
                'origin': None,
                'destination': dest_img,
                'destination_path': request.input_face,
                'threshold': request.threshold,
            }
        ]
        self.refacer.prepare_faces(faces_config)
        return faces_config

    def predict(self, faces_config):
        """Reface every image prepared by ``decode_request`` and upload the results.

        Best-effort per image: a failure is printed and that image is left out of
        the result instead of failing the whole request.

        Args:
            faces_config: The face configuration returned by ``decode_request``.

        Returns:
            ``"<bucket>/<object>"`` paths of the uploaded images, in input order.
        """
        refaced_images_url = []
        for image in self.input_image_list:
            img_path = image['img_path']
            ext = self._extension(img_path)
            # Defensive: decode_request already filters, but input_image_list may be set elsewhere.
            if ext not in self.supported_exts:
                print(f"Skipping non-image file: {img_path}")
                continue
            print(f"Refacing: {img_path}")
            try:
                refaced_image = self.refacer.reface_image(image['img_obj'], faces_config, disable_similarity=True)
                # The refacer output is treated as RGB (hence RGB2BGR); cv2.imencode expects BGR.
                refaced_image_bgr = cv2.cvtColor(refaced_image, cv2.COLOR_RGB2BGR)
                # Encode in the format named by the extension, so the uploaded bytes match the object name.
                ok, buffer = cv2.imencode(f".{ext}", refaced_image_bgr)
                if not ok:
                    raise RuntimeError(f"cv2.imencode failed for format '.{ext}'")
                # A uuid keeps names unique when several images/workers finish in the same clock tick.
                object_name = f"refaced_image/refaced{time.time()}_{uuid.uuid4().hex}.{ext}"
                result = oss_upload_image(oss_client=minio_client, bucket="lanecarford", object_name=object_name, image_bytes=buffer.tobytes())
                uploaded = f"{result.bucket_name}/{result.object_name}"
                refaced_images_url.append(uploaded)
                print(f"Saved -> {uploaded}")
            except Exception as e:
                print(f"Failed to process {img_path}: {e}")
        return refaced_images_url

    def encode_response(self, output):
        """Wrap the list of uploaded image paths as the JSON response body."""
        return {"output": output}
def main():
    """Serve ``InferencePipeline`` with LitServe on port 8000, letting it pick the accelerator."""
    server = ls.LitServer(InferencePipeline(), accelerator="auto")
    server.run(port=8000)


if __name__ == '__main__':
    main()