first commit
This commit is contained in:
70
litserver_main.py
Normal file
70
litserver_main.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import litserve as ls
|
||||
from pydantic import BaseModel
|
||||
|
||||
from refacer_no_path import Refacer as NoPathRefacer
|
||||
from utils.minio_client import oss_get_image, minio_client, oss_upload_image
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
input_image_list: list[str] # 待换脸图片
|
||||
input_face: str # 目标脸图片
|
||||
threshold: float = 0.2 # 相似度 max:0.5
|
||||
|
||||
|
||||
class InferencePipeline(ls.LitAPI):
|
||||
def setup(self, device):
|
||||
force_cpu = False
|
||||
colab_performance = False
|
||||
self.supported_exts = {'jpg', 'jpeg', 'png', 'bmp', 'webp'}
|
||||
self.refacer = NoPathRefacer(force_cpu=force_cpu, colab_performance=colab_performance)
|
||||
|
||||
def decode_request(self, request: PredictRequest):
|
||||
self.input_image_list = []
|
||||
for path in request.input_image_list:
|
||||
self.input_image_list.append({
|
||||
'img_obj': oss_get_image(oss_client=minio_client, path=path, data_type="cv2"),
|
||||
'img_path': path
|
||||
})
|
||||
dest_img = oss_get_image(oss_client=minio_client, path=request.input_face, data_type="cv2")
|
||||
faces_config = [
|
||||
{
|
||||
'origin': None,
|
||||
'destination': dest_img,
|
||||
'destination_path': request.input_face,
|
||||
'threshold': request.threshold,
|
||||
}
|
||||
]
|
||||
self.refacer.prepare_faces(faces_config)
|
||||
return faces_config
|
||||
|
||||
def predict(self, faces_config):
|
||||
refaced_images_url = []
|
||||
for i, image in enumerate(self.input_image_list):
|
||||
ext = image['img_path'].rsplit(".", 1)[1].lower()
|
||||
|
||||
if ext not in self.supported_exts:
|
||||
print(f"Skipping non-image file: {image['img_path']}")
|
||||
continue
|
||||
print(f"Refacing: {image['img_path']}")
|
||||
try:
|
||||
refaced_image = self.refacer.reface_image(image['img_obj'], faces_config, disable_similarity=True)
|
||||
refaced_image_rgb = cv2.cvtColor(refaced_image, cv2.COLOR_RGB2BGR)
|
||||
image_bytes = cv2.imencode('.jpg', refaced_image_rgb)[1].tobytes()
|
||||
req = oss_upload_image(oss_client=minio_client, bucket="lanecarford", object_name=f"refaced_image/refaced{time.time()}.{ext}", image_bytes=image_bytes)
|
||||
refaced_images_url.append(f"{req.bucket_name}/{req.object_name}")
|
||||
print(f"Saved -> {req.bucket_name}/{req.object_name}")
|
||||
except Exception as e:
|
||||
print(f"Failed to process {image['img_path']}: {e}")
|
||||
return refaced_images_url
|
||||
|
||||
def encode_response(self, output):
|
||||
return {"output": output}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
api = InferencePipeline()
|
||||
server = ls.LitServer(api, accelerator="gpu")
|
||||
server.run(port=8080)
|
||||
Reference in New Issue
Block a user