71 lines
2.7 KiB
Python
71 lines
2.7 KiB
Python
|
|
import time
import uuid

import cv2
import litserve as ls
from pydantic import BaseModel

from refacer_no_path import Refacer as NoPathRefacer
from utils.minio_client import oss_get_image, minio_client, oss_upload_image
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PredictRequest(BaseModel):
    """Request body for the face-swap endpoint.

    All paths are object paths read through ``oss_get_image`` with the
    MinIO client.
    """

    input_image_list: list[str]  # images whose faces will be swapped
    input_face: str  # image containing the target (replacement) face
    threshold: float = 0.2  # similarity threshold, max: 0.5 (not enforced by this model)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class InferencePipeline(ls.LitAPI):
    """LitServe API that swaps a target face into a list of images stored in MinIO.

    Per request: ``decode_request`` downloads the input images and registers
    the target face with the refacer, ``predict`` refaces each image and
    uploads the result, and ``encode_response`` wraps the uploaded paths.

    NOTE(review): per-request state lives on ``self`` (``input_image_list``)
    and inside ``self.refacer`` (via ``prepare_faces``), so a single instance
    must not process overlapping requests.
    """

    # Bucket that refaced images are uploaded to.
    OUTPUT_BUCKET = "lanecarford"

    def setup(self, device):
        """Load the face-swapping model once per worker.

        Args:
            device: Device assigned by LitServe. Presumably a string such as
                ``"cpu"`` or ``"cuda:0"`` (confirm). The refacer is forced onto
                the CPU only when the device is the CPU.
        """
        # Lower-case extensions (without the dot) accepted as input images.
        self.supported_exts = {'jpg', 'jpeg', 'png', 'bmp', 'webp'}
        # force_cpu used to be hard-coded to False. On a GPU device the result
        # is unchanged, but a CPU deployment now actually runs on the CPU.
        force_cpu = str(device).lower().startswith("cpu")
        self.refacer = NoPathRefacer(force_cpu=force_cpu, colab_performance=False)

    @staticmethod
    def _file_extension(path):
        """Return the lower-cased extension of the last path segment, or '' if none.

        Only the part after the last '/' is examined, so dots in directory
        names are ignored and a path without a dot yields '' instead of raising.
        """
        filename = path.rsplit("/", 1)[-1]
        _, dot, ext = filename.rpartition(".")
        return ext.lower() if dot else ""

    def decode_request(self, request: PredictRequest):
        """Download the request's images and register the target face.

        Files whose extension is not in ``supported_exts`` are skipped here,
        before download. Previously they were still fetched and decoded as cv2
        images before ``predict`` got a chance to skip them.

        Args:
            request: The validated request body.

        Returns:
            The faces config list that was passed to ``refacer.prepare_faces``.
            LitServe hands it to ``predict``.
        """
        self.input_image_list = []
        for path in request.input_image_list:
            ext = self._file_extension(path)
            if ext not in self.supported_exts:
                print(f"Skipping non-image file: {path}")
                continue
            self.input_image_list.append({
                'img_obj': oss_get_image(oss_client=minio_client, path=path, data_type="cv2"),
                'img_path': path,
                'img_ext': ext,
            })

        dest_img = oss_get_image(oss_client=minio_client, path=request.input_face, data_type="cv2")
        faces_config = [
            {
                'origin': None,
                'destination': dest_img,
                'destination_path': request.input_face,
                'threshold': request.threshold,
            }
        ]
        self.refacer.prepare_faces(faces_config)
        return faces_config

    def predict(self, faces_config):
        """Reface every downloaded image and upload the results.

        Each image is handled best-effort: a failure is printed and that image
        is left out of the result, so the returned list may be shorter than the
        request's ``input_image_list``.

        Args:
            faces_config: The list returned by ``decode_request``.

        Returns:
            A list of ``"<bucket>/<object_name>"`` strings for the uploaded images.
        """
        refaced_images_url = []
        for image in self.input_image_list:
            path = image['img_path']
            ext = image['img_ext']
            print(f"Refacing: {path}")
            try:
                # NOTE(review): disable_similarity=True may mean the request's
                # threshold has no effect; confirm in Refacer.reface_image.
                refaced_image = self.refacer.reface_image(image['img_obj'], faces_config, disable_similarity=True)
                # RGB2BGR swaps channels 0 and 2. The result is BGR, which
                # cv2.imencode expects (assumes reface_image returns RGB; confirm).
                refaced_bgr = cv2.cvtColor(refaced_image, cv2.COLOR_RGB2BGR)
                # Encode in the same format as the object's extension. Previously
                # JPEG bytes were always uploaded, even under a .png/.bmp/.webp name.
                ok, encoded = cv2.imencode(f'.{ext}', refaced_bgr)
                if not ok:
                    raise RuntimeError(f"cv2.imencode failed for format '.{ext}'")
                # The random suffix keeps names unique even when several workers
                # upload within the same clock tick.
                object_name = f"refaced_image/refaced{time.time()}_{uuid.uuid4().hex}.{ext}"
                uploaded = oss_upload_image(
                    oss_client=minio_client,
                    bucket=self.OUTPUT_BUCKET,
                    object_name=object_name,
                    image_bytes=encoded.tobytes(),
                )
                refaced_images_url.append(f"{uploaded.bucket_name}/{uploaded.object_name}")
                print(f"Saved -> {uploaded.bucket_name}/{uploaded.object_name}")
            except Exception as e:
                # Deliberate per-image best effort: one bad image must not fail the batch.
                print(f"Failed to process {path}: {e}")
        return refaced_images_url

    def encode_response(self, output):
        """Wrap the list of uploaded object paths as the JSON response body."""
        return {"output": output}
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Script entry point: serve InferencePipeline on port 8080 with the GPU accelerator.
    ls.LitServer(InferencePipeline(), accelerator="gpu").run(port=8080)
|