diff --git a/refacer.py b/refacer.py index 9ae0d36..5be6d70 100644 --- a/refacer.py +++ b/refacer.py @@ -61,7 +61,7 @@ class Refacer: print(f"CoreML mode with providers {self.providers}") elif 'CUDAExecutionProvider' in self.providers: self.mode = RefacerMode.CUDA - self.use_num_cpus = 1 + self.use_num_cpus = 2 self.sess_options.intra_op_num_threads = 1 if 'TensorrtExecutionProvider' in self.providers: self.providers.remove('TensorrtExecutionProvider') @@ -149,21 +149,21 @@ class Refacer: ret.append(face) return ret + def process_first_face(self,frame): + faces = self.__get_faces(frame,max_num=1) + if len(faces) != 0: + frame = self.face_swapper.get(frame, faces[0], self.replacement_faces[0][1], paste_back=True) + return frame + def process_faces(self,frame): - max_num=0 - if self.first_face: - max_num=1 - - faces = self.__get_faces(frame,max_num=max_num) - for face in faces: - if self.first_face: - frame = self.face_swapper.get(frame, face, self.replacement_faces[0][1], paste_back=True) - break - else: - for rep_face in self.replacement_faces: - sim = self.rec_app.compute_sim(rep_face[0], face.embedding) - if sim>=rep_face[2]: - frame = self.face_swapper.get(frame, face, rep_face[1], paste_back=True) + faces = self.__get_faces(frame,max_num=0) + for rep_face in self.replacement_faces: + for i in range(len(faces) - 1, -1, -1): + sim = self.rec_app.compute_sim(rep_face[0], faces[i].embedding) + if sim>=rep_face[2]: + frame = self.face_swapper.get(frame, faces[i], rep_face[1], paste_back=True) + del faces[i] + break return frame def __check_video_has_audio(self,video_path): @@ -175,7 +175,10 @@ class Refacer: def reface_group(self, faces, frames, output): with ThreadPoolExecutor(max_workers = self.use_num_cpus) as executor: - results = list(tqdm(executor.map(self.process_faces, frames), total=len(frames),desc="Processing frames")) + if self.first_face: + results = list(tqdm(executor.map(self.process_first_face, frames), total=len(frames),desc="Processing frames")) + else: + results = list(tqdm(executor.map(self.process_faces, frames), total=len(frames),desc="Processing frames")) for result in results: output.write(result)