import time
import traceback
import uuid

import cv2
import litserve as ls
from pydantic import BaseModel

from refacer_no_path import Refacer as NoPathRefacer
from utils.minio_client import oss_get_image, minio_client, oss_upload_image


def _file_extension(path: str) -> str:
    """Return the lower-case extension of *path* (text after the last dot), or "" if none.

    Unlike ``path.rsplit(".", 1)[1]`` this never raises for dot-less paths.
    """
    _, dot, ext = path.rpartition(".")
    return ext.lower() if dot else ""


class PredictRequest(BaseModel):
    """Request body for the face-swap endpoint."""

    input_image_list: list[str]  # object-storage paths of the images whose faces get swapped
    input_face: str  # object-storage path of the target-face image
    threshold: float = 0.2  # similarity threshold (original note: max 0.5)


class InferencePipeline(ls.LitAPI):
    """LitServe API that swaps a target face into a list of images stored in object storage.

    Flow: ``decode_request`` downloads the inputs and prepares the refacer,
    ``predict`` refaces each image and uploads the result, and
    ``encode_response`` wraps the list of uploaded object paths.
    """

    # Destination for refaced images; override on a subclass/instance to redirect output.
    output_bucket = "lanecarford"
    output_prefix = "refaced_image/"

    def setup(self, device):
        """Create the refacer model once per worker.

        NOTE(review): ``device`` is not forwarded to the refacer; device choice is
        left to ``NoPathRefacer`` (only ``force_cpu`` is passed) — confirm intended.
        """
        force_cpu = False
        colab_performance = False
        # Lower-case extensions (without the dot) that are treated as images.
        self.supported_exts = {'jpg', 'jpeg', 'png', 'bmp', 'webp'}
        self.refacer = NoPathRefacer(force_cpu=force_cpu, colab_performance=colab_performance)

    def decode_request(self, request: PredictRequest):
        """Download the request's images and target face, and prepare the refacer.

        Paths whose extension is not in ``self.supported_exts`` are skipped here,
        before any download happens.

        Returns:
            dict with keys
              - ``"images"``: list of ``{"img_obj": <decoded image>, "img_path": str}``
              - ``"faces_config"``: the single-entry face config handed to the refacer
            The data travels through the return value (not ``self``) so that a
            request's images are never mixed with, or kept alive by, another request.
        """
        images = []
        for path in request.input_image_list:
            if _file_extension(path) not in self.supported_exts:
                print(f"Skipping non-image file: {path}")
                continue
            images.append({
                'img_obj': oss_get_image(oss_client=minio_client, path=path, data_type="cv2"),
                'img_path': path,
            })

        dest_img = oss_get_image(oss_client=minio_client, path=request.input_face, data_type="cv2")
        faces_config = [
            {
                'origin': None,
                'destination': dest_img,
                'destination_path': request.input_face,
                'threshold': request.threshold,
            }
        ]
        # NOTE(review): prepare_faces is called on the shared refacer instance; whether
        # it stores per-request state internally is not visible here — confirm.
        self.refacer.prepare_faces(faces_config)
        return {"images": images, "faces_config": faces_config}

    def predict(self, inputs):
        """Reface every decoded image and upload the results.

        Each image is processed best-effort: a failure is reported (with traceback)
        and that image is omitted from the result instead of failing the request.

        Args:
            inputs: the dict returned by ``decode_request``.

        Returns:
            list of ``"<bucket>/<object_name>"`` strings for the uploaded JPEGs.
        """
        faces_config = inputs["faces_config"]
        refaced_images_url = []
        for image in inputs["images"]:
            path = image['img_path']
            print(f"Refacing: {path}")
            try:
                # NOTE(review): with disable_similarity=True the request threshold may
                # not be applied — confirm in Refacer.reface_image.
                refaced_image = self.refacer.reface_image(image['img_obj'], faces_config, disable_similarity=True)
                # Swap channels 0 and 2 before encoding, since cv2.imencode reads BGR.
                # Presumably reface_image returns RGB — TODO confirm.
                refaced_image_bgr = cv2.cvtColor(refaced_image, cv2.COLOR_RGB2BGR)
                ok, encoded = cv2.imencode('.jpg', refaced_image_bgr)
                if not ok:
                    raise RuntimeError("cv2.imencode failed to encode the refaced image")
                # Bytes are always JPEG, so the object name uses .jpg (not the source
                # extension); the uuid suffix avoids collisions between concurrent uploads.
                object_name = f"{self.output_prefix}refaced{time.time()}_{uuid.uuid4().hex}.jpg"
                req = oss_upload_image(oss_client=minio_client, bucket=self.output_bucket,
                                       object_name=object_name, image_bytes=encoded.tobytes())
            except Exception as e:
                print(f"Failed to process {path}: {e}")
                traceback.print_exc()
                continue
            url = f"{req.bucket_name}/{req.object_name}"
            refaced_images_url.append(url)
            print(f"Saved -> {url}")
        return refaced_images_url

    def encode_response(self, output):
        """Wrap the uploaded object paths in the JSON response body."""
        return {"output": output}


if __name__ == '__main__':
    api = InferencePipeline()
    server = ls.LitServer(api, accelerator="auto")
    server.run(port=8000)