diff --git a/configs/trt_infer.yaml b/configs/trt_infer.yaml index f00f0a4..86d666b 100644 --- a/configs/trt_infer.yaml +++ b/configs/trt_infer.yaml @@ -78,10 +78,12 @@ joyvasa_models: crop_params: src_dsize: 512 src_scale: 2.3 - src_vx_ratio: 0.0 + #src_vx_ratio: 0.0 + src_vx_ratio: -0.0 src_vy_ratio: -0.125 dri_scale: 2.2 - dri_vx_ratio: 0.0 + #dri_vx_ratio: 0.0 + dri_vx_ratio: -0. dri_vy_ratio: -0.1 @@ -99,12 +101,12 @@ infer_params: flag_do_rot: True # NOT EXPOERTED PARAMS - lip_normalize_threshold: 0.1 # threshold for flag_normalize_lip - source_video_eye_retargeting_threshold: 0.18 # threshold for eyes retargeting if the input is a source video - driving_smooth_observation_variance: 1e-7 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy + lip_normalize_threshold: 0.15 # threshold for flag_normalize_lip + source_video_eye_retargeting_threshold: 0.15 # threshold for eyes retargeting if the input is a source video + driving_smooth_observation_variance: 1e-6 # smooth strength scalar for the animated video when the input is a source video, the larger the number, the smoother the animated video; too much smoothness would result in loss of motion accuracy anchor_frame: 0 # TO IMPLEMENT mask_crop_path: "./assets/mask_template.png" - driving_multiplier: 1.0 + driving_multiplier: .8 animation_region: "all" cfg_mode: "incremental" diff --git a/download_kokoro.py b/download_kokoro.py new file mode 100644 index 0000000..6458975 --- /dev/null +++ b/download_kokoro.py @@ -0,0 +1,57 @@ +import os +from huggingface_hub import snapshot_download, HfApi + +print("Attempting to download Kokoro-82M model...") + +# Try different possible repositories +repositories = [ + 'kokoro-ai/Kokoro-82M', + 'agiresearch/Kokoro-82M', + 'KwaiVGI/Kokoro-82M', + 'kokoro-ai/Kokoro-82M-v0' +] + +api = HfApi() +success = False + +for repo in 
repositories: + try: + print(f"Checking repository: {repo}") + + # First check if repo exists + model_info = api.model_info(repo) + print(f"Repository exists: {repo}") + print(f"Model ID: {model_info.modelId}") + + # Try to download + print(f"Attempting download from: {repo}") + snapshot_download( + repo_id=repo, + local_dir='./Kokoro-82M', + local_dir_use_symlinks=False, + resume_download=True, + allow_patterns=["*.json", "*.pt", "*.pth", "*.bin", "*.safetensors"] + ) + print(f" Successfully downloaded from {repo}") + success = True + break + + except Exception as e: + print(f" Failed with {repo}: {str(e)}") + continue + +if not success: + print("") + print("="*50) + print("All repositories failed. Possible solutions:") + print("1. The model might require authentication") + print("2. Check if you need to accept terms of use") + print("3. The model name might be different") + print("") + print("Searching for similar models...") + try: + models = api.list_models(search="kokoro") + for model in list(models)[:5]: + print(f"Found: {model.modelId}") + except: + print("Could not search for alternative models") diff --git a/run.py b/run.py index 52ed8d1..025f7f1 100644 --- a/run.py +++ b/run.py @@ -4,77 +4,172 @@ # @Project : FasterLivePortrait # @FileName: run.py -""" -# video - python run.py \ - --src_image assets/examples/driving/d13.mp4 \ - --dri_video assets/examples/driving/d11.mp4 \ - --cfg configs/trt_infer.yaml \ - --paste_back \ - --animal -# pkl - python run.py \ - --src_image assets/examples/source/s12.jpg \ - --dri_video ./results/2024-09-13-081710/d0.mp4.pkl \ - --cfg configs/trt_infer.yaml \ - --paste_back \ - --animal -""" -import os import argparse -import pdb +import datetime +import os +import pickle +import platform import subprocess -import ffmpeg -import cv2 import time + +import cv2 import numpy as np -import os -import datetime -import platform -import pickle +from colorama import Fore, Style from omegaconf import OmegaConf from tqdm import tqdm 
-from colorama import Fore, Back, Style + from src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline from src.utils.utils import video_has_audio -if platform.system().lower() == 'windows': +if platform.system().lower() == "windows": FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe" else: FFMPEG = "ffmpeg" +# === ADJUSTABLE SMOOTHING STABILIZER === +class AdjustableStabilizer: + """Smoothing stabilizer with real-time adjustable parameters""" + + def __init__( + self, alpha=0.3, movement_threshold=1.5 + ): # FIXED: Consistent threshold + self.alpha = alpha + self.movement_threshold = movement_threshold + self.prev_frame = None + self.is_moving = False + + def stabilize(self, frame): + if self.prev_frame is None: + self.prev_frame = frame.copy() + return frame + + # Calculate movement + movement = self.calculate_movement(self.prev_frame, frame) + + # Only apply smoothing when still + if movement < self.movement_threshold: + smoothed = cv2.addWeighted( + frame, 1 - self.alpha, self.prev_frame, self.alpha, 0 + ) + self.prev_frame = smoothed.copy() + self.is_moving = False + return smoothed + else: + self.prev_frame = frame.copy() + self.is_moving = True + return frame + + def calculate_movement(self, prev_frame, curr_frame): + prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY) + curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY) + + flow = cv2.calcOpticalFlowFarneback( + prev_gray, curr_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0 + ) + + magnitude = np.sqrt(flow[..., 0] ** 2 + flow[..., 1] ** 2) + return np.mean(magnitude) + + def increase_smoothing(self): + self.alpha = min(0.9, self.alpha + 0.1) + print(f"Smoothing increased: alpha={self.alpha:.2f}") + + def decrease_smoothing(self): + self.alpha = max(0.1, self.alpha - 0.1) + print(f"Smoothing decreased: alpha={self.alpha:.2f}") + + def increase_threshold(self): + self.movement_threshold = min(20.0, self.movement_threshold + 1.0) + print(f"Movement threshold increased: 
{self.movement_threshold:.1f}") + + def decrease_threshold(self): + self.movement_threshold = max(1.0, self.movement_threshold - 1.0) + print(f"Movement threshold decreased: {self.movement_threshold:.1f}") + + def run_with_video(args): - print(Fore.RED+'Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate'+Style.RESET_ALL) + print( + Fore.RED + + "Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate" + + Style.RESET_ALL + ) + print( + Fore.GREEN + + "1/2 > Smoothing, 3/4 > Movement Threshold, 0 > Show Settings" + + Style.RESET_ALL + ) + infer_cfg = OmegaConf.load(args.cfg) infer_cfg.infer_params.flag_pasteback = args.paste_back + # Good settings for stability + infer_cfg.infer_params.flag_relative_motion = True + infer_cfg.infer_params.flag_stitching = True + infer_cfg.infer_params.animation_region = "all" + pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=args.animal) - ret = pipe.prepare_source(args.src_image, realtime=args.realtime) + + # FORCE RESIZE FOR REALTIME PERFORMANCE + if args.realtime: + print("Optimizing for realtime performance...") + # Create a temporary resized version if source is large + original_img = cv2.imread(args.src_image) + if original_img is not None and max(original_img.shape[:2]) > 512: + print( + f"Source image is large: {original_img.shape}, resizing to 512x512 for better performance" + ) + temp_resized_path = "temp_resized_source.jpg" + resized_img = cv2.resize(original_img, (512, 512)) + cv2.imwrite(temp_resized_path, resized_img) + ret = pipe.prepare_source(temp_resized_path, realtime=args.realtime) + # Clean up temp file + try: + os.remove(temp_resized_path) + except: + pass + else: + ret = pipe.prepare_source(args.src_image, 
realtime=args.realtime) + else: + ret = pipe.prepare_source(args.src_image, realtime=args.realtime) + if not ret: print(f"no face in {args.src_image}! exit!") exit(1) - if not args.dri_video or not os.path.exists(args.dri_video): - # read frame from camera if no driving video input + + print(f"Source image size for processing: {pipe.src_imgs[0].shape}") + + # Initialize adjustable stabilizer - FIXED: Consistent threshold + stabilizer = ( + AdjustableStabilizer(alpha=0.3, movement_threshold=1.5) + if args.realtime + else None + ) + + if args.dri_video and os.path.exists(args.dri_video): + vcap = cv2.VideoCapture(args.dri_video) + else: vcap = cv2.VideoCapture(0) if not vcap.isOpened(): print("no camera found! exit!") exit(1) - else: - vcap = cv2.VideoCapture(args.dri_video) + fps = int(vcap.get(cv2.CAP_PROP_FPS)) h, w = pipe.src_imgs[0].shape[:2] save_dir = f"./results/{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}" os.makedirs(save_dir, exist_ok=True) - # render output video if not args.realtime: - fourcc = cv2.VideoWriter_fourcc(*'mp4v') - vsave_crop_path = os.path.join(save_dir, - f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-crop.mp4") + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + vsave_crop_path = os.path.join( + save_dir, + f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-crop.mp4", + ) vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512 * 2, 512)) - vsave_org_path = os.path.join(save_dir, - f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-org.mp4") + vsave_org_path = os.path.join( + save_dir, + f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-org.mp4", + ) vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h)) infer_times = [] @@ -83,14 +178,29 @@ def run_with_video(args): c_lip_lst = [] frame_ind = 0 + + # === FORCE PASTEBACK FOR REALTIME MODE === + if args.realtime: + infer_cfg.infer_params.flag_pasteback = True + 
infer_cfg.infer_params.flag_do_crop = True + infer_cfg.infer_params.flag_stitching = True + print("Forced pasteback enabled for realtime full output") + while vcap.isOpened(): ret, frame = vcap.read() if not ret: break t0 = time.time() first_frame = frame_ind == 0 - dri_crop, out_crop, out_org, dri_motion_info = pipe.run(frame, pipe.src_imgs[0], pipe.src_infos[0], - first_frame=first_frame) + + # FIXED: Correct parameter order for pipe.run() + dri_crop, out_crop, out_org, dri_motion_info = pipe.run( + frame, # image (driving frame) FIRST + pipe.src_imgs[0], # img_src (source image) SECOND + pipe.src_infos[0], # src_info (source info) THIRD + first_frame=first_frame, + ) + frame_ind += 1 if out_crop is None: print(f"no face in driving frame:{frame_ind}") @@ -101,24 +211,82 @@ def run_with_video(args): c_lip_lst.append(dri_motion_info[2]) infer_times.append(time.time() - t0) - # print(time.time() - t0) dri_crop = cv2.resize(dri_crop, (512, 512)) out_crop = np.concatenate([dri_crop, out_crop], axis=1) out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR) + + # Apply stabilization + if args.realtime and stabilizer: + if infer_cfg.infer_params.flag_pasteback and out_org is not None: + out_org = stabilizer.stabilize(out_org) + else: + out_crop = stabilizer.stabilize(out_crop) + if not args.realtime: vout_crop.write(out_crop) - out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR) - vout_org.write(out_org) - else: - if infer_cfg.infer_params.flag_pasteback: + if out_org is not None: out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR) - cv2.imshow('Render', out_org) + vout_org.write(out_org) + else: + # FIXED: Safe display with fallback + if out_org is not None: + out_org_display = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR) + cv2.imshow("Render", out_org_display) else: - # image show in realtime mode - cv2.imshow('Render', out_crop) - # 按下'q'键退出循环 - if cv2.waitKey(1) & 0xFF == ord('q'): + # Fallback to cropped view if full output isn't available + cv2.imshow("Render", out_crop) + 
+ k = cv2.waitKey(1) & 0xFF + if k == ord("q"): break + # Existing keys + if k == ord("s"): + infer_cfg.infer_params.flag_stitching = ( + not infer_cfg.infer_params.flag_stitching + ) + print("flag_stitching:" + str(infer_cfg.infer_params.flag_stitching)) + if k == ord("z"): + infer_cfg.infer_params.flag_relative_motion = ( + not infer_cfg.infer_params.flag_relative_motion + ) + print( + "flag_relative_motion:" + + str(infer_cfg.infer_params.flag_relative_motion) + ) + if k == ord("x"): + if infer_cfg.infer_params.animation_region == "all": + infer_cfg.infer_params.animation_region = "exp" + print('animation_region = "exp"') + else: + infer_cfg.infer_params.animation_region = "all" + print('animation_region = "all"') + if k == ord("c"): + infer_cfg.infer_params.flag_crop_driving_video = ( + not infer_cfg.infer_params.flag_crop_driving_video + ) + print( + "flag_crop_driving_video:" + + str(infer_cfg.infer_params.flag_crop_driving_video) + ) + # NEW SMOOTHING CONTROLS + if k == ord("1"): # Decrease smoothing (less blur, more jitter) + if stabilizer: + stabilizer.decrease_smoothing() + if k == ord("2"): # Increase smoothing (more blur, less jitter) + if stabilizer: + stabilizer.increase_smoothing() + if k == ord("3"): # Decrease movement threshold (smoother when moving) + if stabilizer: + stabilizer.decrease_threshold() + if k == ord("4"): # Increase movement threshold (less smooth when moving) + if stabilizer: + stabilizer.increase_threshold() + if k == ord("0"): # Print current settings + if stabilizer: + print( + f"Current: alpha={stabilizer.alpha:.2f}, threshold={stabilizer.movement_threshold:.1f}" + ) + vcap.release() if not args.realtime: vout_crop.release() @@ -126,19 +294,54 @@ def run_with_video(args): if video_has_audio(args.dri_video): vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4" subprocess.call( - [FFMPEG, "-i", vsave_crop_path, "-i", args.dri_video, - "-b:v", "10M", "-c:v", - "libx264", "-map", "0:v", "-map", "1:a", - "-c:a", 
"aac", - "-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"]) + [ + FFMPEG, + "-i", + vsave_crop_path, + "-i", + args.dri_video, + "-b:v", + "10M", + "-c:v", + "libx264", + "-map", + "0:v", + "-map", + "1:a", + "-c:a", + "aac", + "-pix_fmt", + "yuv420p", + vsave_crop_path_new, + "-y", + "-shortest", + ] + ) vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4" subprocess.call( - [FFMPEG, "-i", vsave_org_path, "-i", args.dri_video, - "-b:v", "10M", "-c:v", - "libx264", "-map", "0:v", "-map", "1:a", - "-c:a", "aac", - "-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"]) - + [ + FFMPEG, + "-i", + vsave_org_path, + "-i", + args.dri_video, + "-b:v", + "10M", + "-c:v", + "libx264", + "-map", + "0:v", + "-map", + "1:a", + "-c:a", + "aac", + "-pix_fmt", + "yuv420p", + vsave_org_path_new, + "-y", + "-shortest", + ] + ) print(vsave_crop_path_new) print(vsave_org_path_new) else: @@ -148,18 +351,21 @@ def run_with_video(args): cv2.destroyAllWindows() print( - "inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000, - np.mean(infer_times) * 1000)) - # save driving motion to pkl + "inference median time: {} ms/frame, mean time: {} ms/frame".format( + np.median(infer_times) * 1000, np.mean(infer_times) * 1000 + ) + ) + template_dct = { - 'n_frames': len(motion_lst), - 'output_fps': fps, - 'motion': motion_lst, - 'c_eyes_lst': c_eyes_lst, - 'c_lip_lst': c_lip_lst, + "n_frames": len(motion_lst), + "output_fps": fps, + "motion": motion_lst, + "c_eyes_lst": c_eyes_lst, + "c_lip_lst": c_lip_lst, } - template_pkl_path = os.path.join(save_dir, - f"{os.path.basename(args.dri_video)}.pkl") + template_pkl_path = os.path.join( + save_dir, f"{os.path.basename(args.dri_video)}.pkl" + ) with open(template_pkl_path, "wb") as fw: pickle.dump(template_dct, fw) print(f"save driving motion pkl file at : {template_pkl_path}") @@ -169,6 +375,11 @@ def run_with_pkl(args): infer_cfg = OmegaConf.load(args.cfg) 
infer_cfg.infer_params.flag_pasteback = args.paste_back + # Good settings for stability + infer_cfg.infer_params.flag_relative_motion = True + infer_cfg.infer_params.flag_stitching = True + infer_cfg.infer_params.animation_region = "all" + pipe = FasterLivePortraitPipeline(cfg=infer_cfg, is_animal=args.animal) ret = pipe.prepare_source(args.src_image, realtime=args.realtime) if not ret: @@ -182,95 +393,142 @@ def run_with_pkl(args): save_dir = f"./results/{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}" os.makedirs(save_dir, exist_ok=True) - # render output video if not args.realtime: - fourcc = cv2.VideoWriter_fourcc(*'mp4v') - vsave_crop_path = os.path.join(save_dir, - f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-crop.mp4") + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + vsave_crop_path = os.path.join( + save_dir, + f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-crop.mp4", + ) vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512, 512)) - vsave_org_path = os.path.join(save_dir, - f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-org.mp4") + vsave_org_path = os.path.join( + save_dir, + f"{os.path.basename(args.src_image)}-{os.path.basename(args.dri_video)}-org.mp4", + ) vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h)) infer_times = [] motion_lst = dri_motion_infos["motion"] - c_eyes_lst = dri_motion_infos["c_eyes_lst"] if "c_eyes_lst" in dri_motion_infos else dri_motion_infos[ - "c_d_eyes_lst"] - c_lip_lst = dri_motion_infos["c_lip_lst"] if "c_lip_lst" in dri_motion_infos else dri_motion_infos["c_d_lip_lst"] + c_eyes_lst = ( + dri_motion_infos["c_eyes_lst"] + if "c_eyes_lst" in dri_motion_infos + else dri_motion_infos["c_d_eyes_lst"] + ) + c_lip_lst = ( + dri_motion_infos["c_lip_lst"] + if "c_lip_lst" in dri_motion_infos + else dri_motion_infos["c_d_lip_lst"] + ) frame_num = len(motion_lst) + stabilizer = ( + AdjustableStabilizer(alpha=0.3, movement_threshold=1.5) 
+ if args.realtime + else None + ) + + # === FORCE PASTEBACK FOR REALTIME MODE === + if args.realtime: + infer_cfg.infer_params.flag_pasteback = True + infer_cfg.infer_params.flag_do_crop = True + infer_cfg.infer_params.flag_stitching = True + print("Forced pasteback enabled for realtime full output") + for frame_ind in tqdm(range(frame_num)): t0 = time.time() first_frame = frame_ind == 0 - dri_motion_info_ = [motion_lst[frame_ind], c_eyes_lst[frame_ind], c_lip_lst[frame_ind]] - out_crop, out_org = pipe.run_with_pkl(dri_motion_info_, pipe.src_imgs[0], pipe.src_infos[0], - first_frame=first_frame) + dri_motion_info_ = [ + motion_lst[frame_ind], + c_eyes_lst[frame_ind], + c_lip_lst[frame_ind], + ] + out_crop, out_org = pipe.run_with_pkl( + dri_motion_info_, + pipe.src_imgs[0], + pipe.src_infos[0], + first_frame=first_frame, + ) if out_crop is None: print(f"no face in driving frame:{frame_ind}") continue infer_times.append(time.time() - t0) - # print(time.time() - t0) out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR) + + if args.realtime and stabilizer: + if infer_cfg.infer_params.flag_pasteback and out_org is not None: + out_org = stabilizer.stabilize(out_org) + else: + out_crop = stabilizer.stabilize(out_crop) + if not args.realtime: vout_crop.write(out_crop) - out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR) - vout_org.write(out_org) - else: - if infer_cfg.infer_params.flag_pasteback: + if out_org is not None: out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR) - cv2.imshow('Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate',out_org) + vout_org.write(out_org) + else: + # FIXED: Safe display with fallback + if out_org is not None: + out_org_display = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR) + cv2.imshow( + "Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > 
AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate", + out_org_display, + ) else: - # image show in realtime mode - cv2.imshow('Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate', out_crop) - # Press the 'q' key to exit the loop, r to switch realtime src_webcam update, spacebar to switch sourceisWebcam + cv2.imshow( + "Render, Q > exit, S > Stitching, Z > RelativeMotion, X > AnimationRegion, C > CropDrivingVideo, KL > AdjustSourceScale, NM > AdjustDriverScale, Space > Webcamassource, R > SwitchRealtimeWebcamUpdate", + out_crop, + ) + k = cv2.waitKey(1) & 0xFF - if k == ord('q'): + if k == ord("q"): break - # Key for Interesting Params - if k == ord('s'): - infer_cfg.infer_params.flag_stitching = not infer_cfg.infer_params.flag_stitching - print('flag_stitching:'+str(infer_cfg.infer_params.flag_stitching)) - if k == ord('z'): - infer_cfg.infer_params.flag_relative_motion = not infer_cfg.infer_params.flag_relative_motion - print('flag_relative_motion:'+str(infer_cfg.infer_params.flag_relative_motion)) - if k == ord('x'): - if infer_cfg.infer_params.animation_region == "all": infer_cfg.infer_params.animation_region = "exp", print('animation_region = "exp"') - else:infer_cfg.infer_params.animation_region = "all", print('animation_region = "all"') - if k == ord('c'): - infer_cfg.infer_params.flag_crop_driving_video = not infer_cfg.infer_params.flag_crop_driving_video - print('flag_crop_driving_video:'+str(infer_cfg.infer_params.flag_crop_driving_video)) - if k == ord('v'): - infer_cfg.infer_params.flag_pasteback = not infer_cfg.infer_params.flag_pasteback - print('flag_pasteback:'+str(infer_cfg.infer_params.flag_pasteback)) - - if k == ord('a'): - infer_cfg.infer_params.flag_normalize_lip = not infer_cfg.infer_params.flag_normalize_lip - 
print('flag_normalize_lip:'+str(infer_cfg.infer_params.flag_normalize_lip)) - if k == ord('d'): - infer_cfg.infer_params.flag_source_video_eye_retargeting = not infer_cfg.infer_params.flag_source_video_eye_retargeting - print('flag_source_video_eye_retargeting:'+str(infer_cfg.infer_params.flag_source_video_eye_retargeting)) - if k == ord('f'): - infer_cfg.infer_params.flag_video_editing_head_rotation = not infer_cfg.infer_params.flag_video_editing_head_rotation - print('flag_video_editing_head_rotation:'+str(infer_cfg.infer_params.flag_video_editing_head_rotation)) - if k == ord('g'): - infer_cfg.infer_params.flag_eye_retargeting = not infer_cfg.infer_params.flag_eye_retargeting - print('flag_eye_retargeting:'+str(infer_cfg.infer_params.flag_eye_retargeting)) - - if k == ord('k'): - infer_cfg.crop_params.src_scale -= 0.1 - ret = pipe.prepare_source(args.src_image, realtime=args.realtime) - print('src_scale:'+str(infer_cfg.crop_params.src_scale)) - if k == ord('l'): - infer_cfg.crop_params.src_scale += 0.1 - ret = pipe.prepare_source(args.src_image, realtime=args.realtime) - print('src_scale:'+str(infer_cfg.crop_params.src_scale)) - if k == ord('n'): - infer_cfg.crop_params.dri_scale -= 0.1 - print('dri_scale:'+str(infer_cfg.crop_params.dri_scale)) - if k == ord('m'): - infer_cfg.crop_params.dri_scale += 0.1 - print('dri_scale:'+str(infer_cfg.crop_params.dri_scale)) + # Existing keys + if k == ord("s"): + infer_cfg.infer_params.flag_stitching = ( + not infer_cfg.infer_params.flag_stitching + ) + print("flag_stitching:" + str(infer_cfg.infer_params.flag_stitching)) + if k == ord("z"): + infer_cfg.infer_params.flag_relative_motion = ( + not infer_cfg.infer_params.flag_relative_motion + ) + print( + "flag_relative_motion:" + + str(infer_cfg.infer_params.flag_relative_motion) + ) + if k == ord("x"): + if infer_cfg.infer_params.animation_region == "all": + infer_cfg.infer_params.animation_region = "exp" + print('animation_region = "exp"') + else: + 
infer_cfg.infer_params.animation_region = "all" + print('animation_region = "all"') + if k == ord("c"): + infer_cfg.infer_params.flag_crop_driving_video = ( + not infer_cfg.infer_params.flag_crop_driving_video + ) + print( + "flag_crop_driving_video:" + + str(infer_cfg.infer_params.flag_crop_driving_video) + ) + # NEW SMOOTHING CONTROLS + if k == ord("1"): # Decrease smoothing (less blur, more jitter) + if stabilizer: + stabilizer.decrease_smoothing() + if k == ord("2"): # Increase smoothing (more blur, less jitter) + if stabilizer: + stabilizer.increase_smoothing() + if k == ord("3"): # Decrease movement threshold (smoother when moving) + if stabilizer: + stabilizer.decrease_threshold() + if k == ord("4"): # Increase movement threshold (less smooth when moving) + if stabilizer: + stabilizer.increase_threshold() + if k == ord("0"): # Print current settings + if stabilizer: + print( + f"Current: alpha={stabilizer.alpha:.2f}, threshold={stabilizer.movement_threshold:.1f}" + ) if not args.realtime: vout_crop.release() @@ -278,19 +536,54 @@ def run_with_pkl(args): if video_has_audio(args.dri_video): vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4" subprocess.call( - [FFMPEG, "-i", vsave_crop_path, "-i", args.dri_video, - "-b:v", "10M", "-c:v", - "libx264", "-map", "0:v", "-map", "1:a", - "-c:a", "aac", - "-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"]) + [ + FFMPEG, + "-i", + vsave_crop_path, + "-i", + args.dri_video, + "-b:v", + "10M", + "-c:v", + "libx264", + "-map", + "0:v", + "-map", + "1:a", + "-c:a", + "aac", + "-pix_fmt", + "yuv420p", + vsave_crop_path_new, + "-y", + "-shortest", + ] + ) vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4" subprocess.call( - [FFMPEG, "-i", vsave_org_path, "-i", args.dri_video, - "-b:v", "10M", "-c:v", - "libx264", "-map", "0:v", "-map", "1:a", - "-c:a", "aac", - "-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"]) - + [ + FFMPEG, + "-i", + vsave_org_path, + 
"-i", + args.dri_video, + "-b:v", + "10M", + "-c:v", + "libx264", + "-map", + "0:v", + "-map", + "1:a", + "-c:a", + "aac", + "-pix_fmt", + "yuv420p", + vsave_org_path_new, + "-y", + "-shortest", + ] + ) print(vsave_crop_path_new) print(vsave_org_path_new) else: @@ -300,20 +593,43 @@ def run_with_pkl(args): cv2.destroyAllWindows() print( - "inference median time: {} ms/frame, mean time: {} ms/frame".format(np.median(infer_times) * 1000, - np.mean(infer_times) * 1000)) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Faster Live Portrait Pipeline') - parser.add_argument('--src_image', required=False, type=str, default="assets/examples/source/s12.jpg", - help='source image') - parser.add_argument('--dri_video', required=False, type=str, default="assets/examples/driving/d14.mp4", - help='driving video') - parser.add_argument('--cfg', required=False, type=str, default="configs/onnx_infer.yaml", help='inference config') - parser.add_argument('--realtime', action='store_true', help='realtime inference') - parser.add_argument('--animal', action='store_true', help='use animal model') - parser.add_argument('--paste_back', action='store_true', default=False, help='paste back to origin image') + "inference median time: {} ms/frame, mean time: {} ms/frame".format( + np.median(infer_times) * 1000, np.mean(infer_times) * 1000 + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Faster Live Portrait Pipeline") + parser.add_argument( + "--src_image", + required=False, + type=str, + default="assets/examples/source/s12.jpg", + help="source image", + ) + parser.add_argument( + "--dri_video", + required=False, + type=str, + default="assets/examples/driving/d14.mp4", + help="driving video", + ) + parser.add_argument( + "--cfg", + required=False, + type=str, + default="configs/onnx_infer.yaml", + help="inference config", + ) + parser.add_argument("--realtime", action="store_true", help="realtime inference") + 
parser.add_argument("--animal", action="store_true", help="use animal model") + parser.add_argument( + "--paste_back", + action="store_true", + default=False, + help="paste back to origin image", + ) args, unknown = parser.parse_known_args() if args.dri_video.endswith(".pkl"): diff --git a/src/models/JoyVASA/hubert.py b/src/models/JoyVASA/hubert.py index c98c8f0..2b7b496 100644 --- a/src/models/JoyVASA/hubert.py +++ b/src/models/JoyVASA/hubert.py @@ -1,4 +1,4 @@ -from transformers import HubertModel +from transformers import HubertModel from transformers.modeling_outputs import BaseModelOutput from .wav2vec2 import linear_interpolation @@ -12,7 +12,7 @@ def __init__(self, config): def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, frame_num=None): - self.config.output_attentions = True + # self.config.output_attentions = True # Commented out: conflicts with sdpa attention output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -49,3 +49,4 @@ def forward(self, input_values, output_fps=25, attention_mask=None, output_atten return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) + diff --git a/src/models/motion_extractor_model.py b/src/models/motion_extractor_model.py index fb45276..9733612 100644 --- a/src/models/motion_extractor_model.py +++ b/src/models/motion_extractor_model.py @@ -46,7 +46,7 @@ def input_process(self, *data): def output_process(self, *data): if self.predict_type == "trt": - kp, pitch, yaw, roll, t, exp, scale = data + pitch, yaw, roll, t, exp, scale, kp = data else: pitch, yaw, roll, t, exp, scale, kp = data if self.flag_refine_info: diff --git a/src/pipelines/faster_live_portrait_pipeline.py b/src/pipelines/faster_live_portrait_pipeline.py index 9031fc3..1d0102e 100644 --- 
a/src/pipelines/faster_live_portrait_pipeline.py +++ b/src/pipelines/faster_live_portrait_pipeline.py @@ -6,21 +6,33 @@ import copy import os.path -import pdb -import time import traceback -from PIL import Image + import cv2 -from tqdm import tqdm import numpy as np import torch +from PIL import Image +from tqdm import tqdm -from .. import models -from ..utils.crop import crop_image, parse_bbox_from_landmark, crop_image_by_bbox, paste_back, paste_back_pytorch -from ..utils.utils import resize_to_limit, prepare_paste_back, get_rotation_matrix, calc_lip_close_ratio, \ - calc_eye_close_ratio, transform_keypoint, concat_feat from src.utils import utils +from .. import models +from ..utils.crop import ( + crop_image, + crop_image_by_bbox, + parse_bbox_from_landmark, + paste_back_pytorch, +) +from ..utils.utils import ( + calc_eye_close_ratio, + calc_lip_close_ratio, + concat_feat, + get_rotation_matrix, + prepare_paste_back, + resize_to_limit, + transform_keypoint, +) + class FasterLivePortraitPipeline: def __init__(self, cfg, **kwargs): @@ -37,15 +49,26 @@ def update_cfg(self, args_user): if key in self.cfg.infer_params: if self.cfg.infer_params[key] != args_user[key]: update_ret = True - print("update infer cfg {} from {} to {}".format(key, self.cfg.infer_params[key], args_user[key])) + print( + "update infer cfg {} from {} to {}".format( + key, self.cfg.infer_params[key], args_user[key] + ) + ) self.cfg.infer_params[key] = args_user[key] elif key in self.cfg.crop_params: if self.cfg.crop_params[key] != args_user[key]: update_ret = True - print("update crop cfg {} from {} to {}".format(key, self.cfg.crop_params[key], args_user[key])) + print( + "update crop cfg {} from {} to {}".format( + key, self.cfg.crop_params[key], args_user[key] + ) + ) self.cfg.crop_params[key] = args_user[key] else: - if key in self.cfg.infer_params and self.cfg.infer_params[key] != args_user[key]: + if ( + key in self.cfg.infer_params + and self.cfg.infer_params[key] != args_user[key] + ): 
update_ret = True print("add {}:{} to infer cfg".format(key, args_user[key])) self.cfg.infer_params[key] = args_user[key] @@ -69,33 +92,48 @@ def init_models(self, **kwargs): for model_name in self.cfg.models: print(f"loading model: {model_name}") print(self.cfg.models[model_name]) - self.model_dict[model_name] = getattr(models, self.cfg.models[model_name]["name"])( - **self.cfg.models[model_name]) + self.model_dict[model_name] = getattr( + models, self.cfg.models[model_name]["name"] + )(**self.cfg.models[model_name]) else: print("load Animal Model >>>") self.is_animal = True self.model_dict = {} from src.utils.animal_landmark_runner import XPoseRunner from src.utils.utils import make_abs_path + checkpoint_dir = None for model_name in self.cfg.animal_models: print(f"loading model: {model_name}") print(self.cfg.animal_models[model_name]) - if checkpoint_dir is None and isinstance(self.cfg.animal_models[model_name].model_path, str): - checkpoint_dir = os.path.dirname(self.cfg.animal_models[model_name].model_path) - self.model_dict[model_name] = getattr(models, self.cfg.animal_models[model_name]["name"])( - **self.cfg.animal_models[model_name]) + if checkpoint_dir is None and isinstance( + self.cfg.animal_models[model_name].model_path, str + ): + checkpoint_dir = os.path.dirname( + self.cfg.animal_models[model_name].model_path + ) + self.model_dict[model_name] = getattr( + models, self.cfg.animal_models[model_name]["name"] + )(**self.cfg.animal_models[model_name]) - xpose_config_file_path: str = make_abs_path("models/XPose/config_model/UniPose_SwinT.py") + xpose_config_file_path: str = make_abs_path( + "models/XPose/config_model/UniPose_SwinT.py" + ) xpose_ckpt_path: str = os.path.join(checkpoint_dir, "xpose.pth") - xpose_embedding_cache_path: str = os.path.join(checkpoint_dir, 'clip_embedding') - self.model_dict["xpose"] = XPoseRunner(model_config_path=xpose_config_file_path, - model_checkpoint_path=xpose_ckpt_path, - embeddings_cache_path=xpose_embedding_cache_path, 
- flag_use_half_precision=True) + xpose_embedding_cache_path: str = os.path.join( + checkpoint_dir, "clip_embedding" + ) + self.model_dict["xpose"] = XPoseRunner( + model_config_path=xpose_config_file_path, + model_checkpoint_path=xpose_ckpt_path, + embeddings_cache_path=xpose_embedding_cache_path, + flag_use_half_precision=True, + ) def init_vars(self, **kwargs): - self.mask_crop = cv2.imread(self.cfg.infer_params.mask_crop_path, cv2.IMREAD_COLOR) + self.mask_crop = cv2.imread( + self.cfg.infer_params.mask_crop_path, cv2.IMREAD_COLOR + ) self.frame_id = 0 self.src_lmk_pre = None self.R_d_0 = None @@ -108,7 +146,9 @@ def init_vars(self, **kwargs): self.src_infos = [] self.src_imgs = [] self.is_source_video = False - self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk): c_s_eyes = calc_eye_close_ratio(source_lmk[None]) @@ -150,19 +190,18 @@ def prepare_source(self, source_path, **kwargs): self.source_path = source_path for ii, img_bgr in tqdm(enumerate(src_imgs_bgr), total=len(src_imgs_bgr)): - img_bgr = resize_to_limit(img_bgr, self.cfg.infer_params.source_max_dim, - self.cfg.infer_params.source_division) + img_bgr = resize_to_limit( + img_bgr, + self.cfg.infer_params.source_max_dim, + self.cfg.infer_params.source_division, + ) img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) src_faces = [] if self.is_animal: with torch.no_grad(): img_rgb_pil = Image.fromarray(img_rgb) lmk = self.model_dict["xpose"].run( - img_rgb_pil, - 'face', - 'animal_face', - 0, - 0 + img_rgb_pil, "face", "animal_face", 0, 0 ) if lmk is None: continue @@ -196,7 +235,9 @@ def prepare_source(self, source_path, **kwargs): else: lmk = self.model_dict["landmark"].predict(img_rgb, lmk) ret_dct["lmk_crop"] = lmk - ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / self.cfg.crop_params.src_dsize + 
ret_dct["lmk_crop_256x256"] = ( + ret_dct["lmk_crop"] * 256 / self.cfg.crop_params.src_dsize + ) # update a 256x256 version for network input ret_dct["img_crop_256x256"] = cv2.resize( @@ -206,10 +247,14 @@ def prepare_source(self, source_path, **kwargs): src_infos = [[] for _ in range(len(crop_infos))] for i, crop_info in enumerate(crop_infos): - source_lmk = crop_info['lmk_crop'] - img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256'] - pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict( - img_crop_256x256) + source_lmk = crop_info["lmk_crop"] + img_crop, img_crop_256x256 = ( + crop_info["img_crop"], + crop_info["img_crop_256x256"], + ) + pitch, yaw, roll, t, exp, scale, kp = self.model_dict[ + "motion_extractor" + ].predict(img_crop_256x256) x_s_info = { "pitch": pitch, "yaw": yaw, @@ -217,30 +262,52 @@ def prepare_source(self, source_path, **kwargs): "t": t, "exp": exp, "scale": scale, - "kp": kp + "kp": kp, } src_infos[i].append(copy.deepcopy(x_s_info)) x_c_s = kp R_s = get_rotation_matrix(pitch, yaw, roll) - f_s = self.model_dict["app_feat_extractor"].predict(img_crop_256x256) + f_s = self.model_dict["app_feat_extractor"].predict( + img_crop_256x256 + ) x_s = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp) - src_infos[i].extend([source_lmk.copy(), R_s.copy(), f_s.copy(), x_s.copy(), x_c_s.copy()]) + src_infos[i].extend( + [ + source_lmk.copy(), + R_s.copy(), + f_s.copy(), + x_s.copy(), + x_c_s.copy(), + ] + ) if not self.is_animal: - flag_lip_zero = self.cfg.infer_params.flag_normalize_lip # not overwrite + flag_lip_zero = ( + self.cfg.infer_params.flag_normalize_lip + ) # not overwrite if flag_lip_zero: # let lip-open scalar to be 0 at first # 似乎要调参? 
c_d_lip_before_animation = [0.05] - combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio( - c_d_lip_before_animation, source_lmk.copy()) - if combined_lip_ratio_tensor_before_animation[0][ - 0] < self.cfg.infer_params.lip_normalize_threshold: + combined_lip_ratio_tensor_before_animation = ( + self.calc_combined_lip_ratio( + c_d_lip_before_animation, source_lmk.copy() + ) + ) + if ( + combined_lip_ratio_tensor_before_animation[0][0] + < self.cfg.infer_params.lip_normalize_threshold + ): flag_lip_zero = False src_infos[i].append(None) src_infos[i].append(flag_lip_zero) else: - lip_delta_before_animation = self.model_dict['stitching_lip_retarget'].predict( - concat_feat(x_s, combined_lip_ratio_tensor_before_animation)) + lip_delta_before_animation = self.model_dict[ + "stitching_lip_retarget" + ].predict( + concat_feat( + x_s, combined_lip_ratio_tensor_before_animation + ) + ) src_infos[i].append(lip_delta_before_animation.copy()) src_infos[i].append(flag_lip_zero) else: @@ -251,19 +318,28 @@ def prepare_source(self, source_path, **kwargs): src_infos[i].append(False) ######## prepare for pasteback ######## - if self.cfg.infer_params.flag_pasteback and self.cfg.infer_params.flag_do_crop and self.cfg.infer_params.flag_stitching: - mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'], - dsize=(img_rgb.shape[1], img_rgb.shape[0])) - mask_ori_float = torch.from_numpy(mask_ori_float).to(self.device) + if ( + self.cfg.infer_params.flag_pasteback + and self.cfg.infer_params.flag_do_crop + and self.cfg.infer_params.flag_stitching + ): + mask_ori_float = prepare_paste_back( + self.mask_crop, + crop_info["M_c2o"], + dsize=(img_rgb.shape[1], img_rgb.shape[0]), + ) + mask_ori_float = torch.from_numpy(mask_ori_float).to( + self.device + ) src_infos[i].append(mask_ori_float) else: src_infos[i].append(None) - M = torch.from_numpy(crop_info['M_c2o']).to(self.device) + M = torch.from_numpy(crop_info["M_c2o"]).to(self.device) src_infos[i].append(M) 
self.src_infos.append(src_infos[:]) print(f"finish process source:{source_path} >>>>>>>>") return len(self.src_infos) > 0 - except Exception as e: + except Exception: traceback.print_exc() return False @@ -274,7 +350,7 @@ def retarget_eye(self, kp_source, eye_close_ratio): Return: Bx(3*num_kp+2) """ feat_eye = concat_feat(kp_source, eye_close_ratio) - delta = self.model_dict['stitching_eye_retarget'].predict(feat_eye) + delta = self.model_dict["stitching_eye_retarget"].predict(feat_eye) return delta def retarget_lip(self, kp_source, lip_close_ratio): @@ -283,11 +359,11 @@ def retarget_lip(self, kp_source, lip_close_ratio): lip_close_ratio: Bx2 """ feat_lip = concat_feat(kp_source, lip_close_ratio) - delta = self.model_dict['stitching_lip_retarget'].predict(feat_lip) + delta = self.model_dict["stitching_lip_retarget"].predict(feat_lip) return delta def stitching(self, kp_source, kp_driving): - """ conduct the stitching + """conduct the stitching kp_source: Bxnum_kpx3 kp_driving: Bxnum_kpx3 """ @@ -296,47 +372,127 @@ def stitching(self, kp_source, kp_driving): kp_driving_new = kp_driving.copy() - delta = self.model_dict['stitching'].predict(concat_feat(kp_source, kp_driving_new)) + delta = self.model_dict["stitching"].predict( + concat_feat(kp_source, kp_driving_new) + ) - delta_exp = delta[..., :3 * num_kp].reshape(bs, num_kp, 3) # 1x20x3 - delta_tx_ty = delta[..., 3 * num_kp:3 * num_kp + 2].reshape(bs, 1, 2) # 1x1x2 + delta_exp = delta[..., : 3 * num_kp].reshape(bs, num_kp, 3) # 1x20x3 + delta_tx_ty = delta[..., 3 * num_kp : 3 * num_kp + 2].reshape(bs, 1, 2) # 1x1x2 kp_driving_new += delta_exp kp_driving_new[..., :2] += delta_tx_ty return kp_driving_new - def _run(self, src_info, x_d_i_info, x_d_0_info, R_d_i, R_d_0, realtime, input_eye_ratio, input_lip_ratio, - I_p_pstbk, **kwargs): + def _run( + self, + src_info, + x_d_i_info, + x_d_0_info, + R_d_i, + R_d_0, + realtime, + input_eye_ratio, + input_lip_ratio, + I_p_pstbk, + **kwargs, + ): out_crop, out_org = None, 
None eye_delta_before_animation = None + for j in range(len(src_info)): + # Initialize mask_ori_float to None at the start of each iteration + mask_ori_float = None + if self.is_source_video: - x_s_info, source_lmk, R_s, f_s, x_s, x_c_s, lip_delta_before_animation, flag_lip_zero, mask_ori_float, M = \ - src_info[j] + ( + x_s_info, + source_lmk, + R_s, + f_s, + x_s, + x_c_s, + lip_delta_before_animation, + flag_lip_zero, + mask_ori_float, + M, + ) = src_info[j] # let lip-open scalar to be 0 at first if the input is a video and flag_relative_motion - if not (self.cfg.infer_params.flag_normalize_lip and self.cfg.infer_params.flag_relative_motion): + if not ( + self.cfg.infer_params.flag_normalize_lip + and self.cfg.infer_params.flag_relative_motion + ): lip_delta_before_animation = None # let eye-open scalar to be the same as the first frame if the latter is eye-open state - if self.cfg.infer_params.flag_source_video_eye_retargeting and source_lmk is not None: - combined_eye_ratio_tensor_frame_zero = utils.calc_eye_close_ratio(src_info[0][1]) + if ( + self.cfg.infer_params.flag_source_video_eye_retargeting + and source_lmk is not None + ): + combined_eye_ratio_tensor_frame_zero = utils.calc_eye_close_ratio( + src_info[0][1] + ) c_d_eye_before_animation_frame_zero = [ - [combined_eye_ratio_tensor_frame_zero[0][:2].mean()]] - if c_d_eye_before_animation_frame_zero[0][ - 0] < self.cfg.infer_params.source_video_eye_retargeting_threshold: + [combined_eye_ratio_tensor_frame_zero[0][:2].mean()] + ] + if ( + c_d_eye_before_animation_frame_zero[0][0] + < self.cfg.infer_params.source_video_eye_retargeting_threshold + ): c_d_eye_before_animation_frame_zero = [[0.39]] - combined_eye_ratio_tensor_before_animation = self.calc_combined_eye_ratio( - c_d_eye_before_animation_frame_zero, source_lmk) - eye_delta_before_animation = self.retarget_eye(x_s, combined_eye_ratio_tensor_before_animation) - - if not realtime and self.cfg.infer_params.flag_pasteback and 
self.cfg.infer_params.flag_do_crop and \ - self.cfg.infer_params.flag_stitching: - mask_ori_float = prepare_paste_back(self.mask_crop, M.cpu().numpy(), - dsize=(self.src_imgs[0].shape[1], self.src_imgs[0].shape[0])) - mask_ori_float = torch.from_numpy(mask_ori_float).to(self.device) + combined_eye_ratio_tensor_before_animation = ( + self.calc_combined_eye_ratio( + c_d_eye_before_animation_frame_zero, source_lmk + ) + ) + eye_delta_before_animation = self.retarget_eye( + x_s, combined_eye_ratio_tensor_before_animation + ) + + # FIX: ALWAYS prepare mask when pasteback is enabled, regardless of realtime mode + if ( + self.cfg.infer_params.flag_pasteback + and self.cfg.infer_params.flag_do_crop + and self.cfg.infer_params.flag_stitching + ): + if mask_ori_float is None: + mask_ori_float = prepare_paste_back( + self.mask_crop, + M.cpu().numpy(), + dsize=( + self.src_imgs[0].shape[1], + self.src_imgs[0].shape[0], + ), + ) + mask_ori_float = torch.from_numpy(mask_ori_float).to( + self.device + ) else: - x_s_info, source_lmk, R_s, f_s, x_s, x_c_s, lip_delta_before_animation, flag_lip_zero, mask_ori_float, M = \ - src_info[j] + ( + x_s_info, + source_lmk, + R_s, + f_s, + x_s, + x_c_s, + lip_delta_before_animation, + flag_lip_zero, + mask_ori_float, + M, + ) = src_info[j] + # FIX: Also prepare mask for non-video sources if needed + if ( + self.cfg.infer_params.flag_pasteback + and self.cfg.infer_params.flag_do_crop + and self.cfg.infer_params.flag_stitching + and mask_ori_float is None + ): + mask_ori_float = prepare_paste_back( + self.mask_crop, + M.cpu().numpy(), + dsize=(self.src_imgs[0].shape[1], self.src_imgs[0].shape[0]), + ) + mask_ori_float = torch.from_numpy(mask_ori_float).to(self.device) + if self.cfg.infer_params.flag_relative_motion: if self.cfg.infer_params.animation_region in ["all", "pose"]: if self.is_source_video: @@ -346,8 +502,8 @@ def _run(self, src_info, x_d_i_info, x_d_0_info, R_d_i, R_d_0, realtime, input_e else: R_new = R_s - delta_new = 
x_s_info['exp'].copy() - x_d_exp_smooth = x_d_i_info['exp'].copy() + delta_new = x_s_info["exp"].copy() + x_d_exp_smooth = x_d_i_info["exp"].copy() if self.is_source_video: x_d_exp_smooth = self.exp_smooth.process(x_d_exp_smooth) if self.cfg.infer_params.animation_region in ["all", "exp"]: @@ -359,31 +515,44 @@ def _run(self, src_info, x_d_i_info, x_d_0_info, R_d_i, R_d_0, realtime, input_e delta_new[:, 8, 2] = x_d_exp_smooth[:, 8, 2] delta_new[:, 9, 1:] = x_d_exp_smooth[:, 9, 1:] else: - delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']) + delta_new = x_s_info["exp"] + ( + x_d_i_info["exp"] - x_d_0_info["exp"] + ) elif self.cfg.infer_params.animation_region in ["lip"]: for lip_idx in [6, 12, 14, 17, 19, 20]: if self.is_source_video: delta_new[:, lip_idx, :] = x_d_exp_smooth[:, lip_idx, :] else: - delta_new[:, lip_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, - lip_idx, :] + delta_new[:, lip_idx, :] = ( + x_s_info["exp"] + + (x_d_i_info["exp"] - x_d_0_info["exp"]) + )[:, lip_idx, :] elif self.cfg.infer_params.animation_region in ["eyes"]: for eyes_idx in [11, 13, 15, 16, 18]: if self.is_source_video: delta_new[:, eyes_idx, :] = x_d_exp_smooth[:, eyes_idx, :] else: - delta_new[:, eyes_idx, :] = (x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp']))[:, - eyes_idx, :] + delta_new[:, eyes_idx, :] = ( + x_s_info["exp"] + + (x_d_i_info["exp"] - x_d_0_info["exp"]) + )[:, eyes_idx, :] if self.cfg.infer_params.animation_region in ["all"]: - scale_new = x_s_info['scale'] if self.is_source_video else x_s_info['scale'] * ( - x_d_i_info['scale'] / x_d_0_info['scale']) + scale_new = ( + x_s_info["scale"] + if self.is_source_video + else x_s_info["scale"] + * (x_d_i_info["scale"] / x_d_0_info["scale"]) + ) else: - scale_new = x_s_info['scale'] + scale_new = x_s_info["scale"] if self.cfg.infer_params.animation_region in ["all"]: - t_new = x_s_info['t'] if self.is_source_video else x_s_info['t'] + ( - x_d_i_info['t'] - 
x_d_0_info['t']) + t_new = ( + x_s_info["t"] + if self.is_source_video + else x_s_info["t"] + (x_d_i_info["t"] - x_d_0_info["t"]) + ) else: - t_new = x_s_info['t'] + t_new = x_s_info["t"] else: if self.cfg.infer_params.animation_region in ["all", "pose"]: if self.is_source_video: @@ -393,77 +562,138 @@ def _run(self, src_info, x_d_i_info, x_d_0_info, R_d_i, R_d_0, realtime, input_e else: R_new = R_s - delta_new = x_s_info['exp'].copy() - x_d_exp_smooth = x_d_i_info['exp'].copy() + delta_new = x_s_info["exp"].copy() + x_d_exp_smooth = x_d_i_info["exp"].copy() if self.is_source_video: x_d_exp_smooth = self.exp_smooth.process(x_d_exp_smooth) if self.cfg.infer_params.animation_region in ["all", "exp"]: for idx in [1, 2, 6, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]: - delta_new[:, idx, :] = x_d_exp_smooth[:, idx, :] if self.is_source_video else x_d_i_info['exp'][ - :, idx, :] - delta_new[:, 3:5, 1] = x_d_exp_smooth[:, 3:5, 1] if self.is_source_video else x_d_i_info['exp'][:, - 3:5, 1] - delta_new[:, 5, 2] = x_d_exp_smooth[:, 5, 2] if self.is_source_video else x_d_i_info['exp'][:, - 5, 2] - delta_new[:, 8, 2] = x_d_exp_smooth[:, 8, 2] if self.is_source_video else x_d_i_info['exp'][:, - 8, 2] - delta_new[:, 9, 1:] = x_d_exp_smooth[:, 9, 1:] if self.is_source_video else x_d_i_info['exp'][:, - 9, 1:] + delta_new[:, idx, :] = ( + x_d_exp_smooth[:, idx, :] + if self.is_source_video + else x_d_i_info["exp"][:, idx, :] + ) + delta_new[:, 3:5, 1] = ( + x_d_exp_smooth[:, 3:5, 1] + if self.is_source_video + else x_d_i_info["exp"][:, 3:5, 1] + ) + delta_new[:, 5, 2] = ( + x_d_exp_smooth[:, 5, 2] + if self.is_source_video + else x_d_i_info["exp"][:, 5, 2] + ) + delta_new[:, 8, 2] = ( + x_d_exp_smooth[:, 8, 2] + if self.is_source_video + else x_d_i_info["exp"][:, 8, 2] + ) + delta_new[:, 9, 1:] = ( + x_d_exp_smooth[:, 9, 1:] + if self.is_source_video + else x_d_i_info["exp"][:, 9, 1:] + ) elif self.cfg.infer_params.animation_region in ["lip"]: for lip_idx in [6, 12, 14, 17, 19, 20]: - 
delta_new[:, lip_idx, :] = x_d_exp_smooth[:, lip_idx, :] if self.is_source_video else \ - x_d_i_info['exp'][:, lip_idx, :] + delta_new[:, lip_idx, :] = ( + x_d_exp_smooth[:, lip_idx, :] + if self.is_source_video + else x_d_i_info["exp"][:, lip_idx, :] + ) elif self.cfg.infer_params.animation_region in ["eyes"]: for eyes_idx in [11, 13, 15, 16, 18]: - delta_new[:, eyes_idx, :] = x_d_exp_smooth[:, eyes_idx, :] if self.is_source_video else \ - x_d_i_info['exp'][:, eyes_idx, :] - scale_new = x_s_info['scale'].copy() + delta_new[:, eyes_idx, :] = ( + x_d_exp_smooth[:, eyes_idx, :] + if self.is_source_video + else x_d_i_info["exp"][:, eyes_idx, :] + ) + scale_new = x_s_info["scale"].copy() if self.cfg.infer_params.animation_region in ["all", "pose"]: - t_new = x_d_i_info['t'].copy() + t_new = x_d_i_info["t"].copy() else: - t_new = x_s_info['t'].copy() + t_new = x_s_info["t"].copy() t_new[..., 2] = 0 # zero tz x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new if not self.is_animal: # Algorithm 1: - if not self.cfg.infer_params.flag_stitching and not self.cfg.infer_params.flag_eye_retargeting and not self.cfg.infer_params.flag_lip_retargeting: + if ( + not self.cfg.infer_params.flag_stitching + and not self.cfg.infer_params.flag_eye_retargeting + and not self.cfg.infer_params.flag_lip_retargeting + ): # without stitching or retargeting if flag_lip_zero and lip_delta_before_animation is not None: - x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) - if self.cfg.infer_params.flag_source_video_eye_retargeting and eye_delta_before_animation is not None: + x_d_i_new += lip_delta_before_animation.reshape( + -1, x_s.shape[1], 3 + ) + if ( + self.cfg.infer_params.flag_source_video_eye_retargeting + and eye_delta_before_animation is not None + ): x_d_i_new += eye_delta_before_animation - elif self.cfg.infer_params.flag_stitching and not self.cfg.infer_params.flag_eye_retargeting and not self.cfg.infer_params.flag_lip_retargeting: + elif ( + 
self.cfg.infer_params.flag_stitching + and not self.cfg.infer_params.flag_eye_retargeting + and not self.cfg.infer_params.flag_lip_retargeting + ): # with stitching and without retargeting if flag_lip_zero and lip_delta_before_animation is not None: - x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape( - -1, x_s.shape[1], 3) + x_d_i_new = self.stitching( + x_s, x_d_i_new + ) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3) else: x_d_i_new = self.stitching(x_s, x_d_i_new) - if self.cfg.infer_params.flag_source_video_eye_retargeting and eye_delta_before_animation is not None: + if ( + self.cfg.infer_params.flag_source_video_eye_retargeting + and eye_delta_before_animation is not None + ): x_d_i_new += eye_delta_before_animation else: eyes_delta, lip_delta = None, None if self.cfg.infer_params.flag_eye_retargeting: c_d_eyes_i = input_eye_ratio - combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, - source_lmk) + combined_eye_ratio_tensor = self.calc_combined_eye_ratio( + c_d_eyes_i, source_lmk + ) # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i) eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor) if self.cfg.infer_params.flag_lip_retargeting: c_d_lip_i = input_lip_ratio - combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, source_lmk) + combined_lip_ratio_tensor = self.calc_combined_lip_ratio( + c_d_lip_i, source_lmk + ) # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i) lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor) if self.cfg.infer_params.flag_relative_motion: # use x_s - x_d_i_new = x_s + \ - (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ - (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) + x_d_i_new = ( + x_s + + ( + eyes_delta.reshape(-1, x_s.shape[1], 3) + if eyes_delta is not None + else 0 + ) + + ( + lip_delta.reshape(-1, x_s.shape[1], 3) + if lip_delta is not None + else 0 + ) + ) else: # use x_d,i - x_d_i_new = 
x_d_i_new + \ - (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \ - (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0) + x_d_i_new = ( + x_d_i_new + + ( + eyes_delta.reshape(-1, x_s.shape[1], 3) + if eyes_delta is not None + else 0 + ) + + ( + lip_delta.reshape(-1, x_s.shape[1], 3) + if lip_delta is not None + else 0 + ) + ) if self.cfg.infer_params.flag_stitching: x_d_i_new = self.stitching(x_s, x_d_i_new) @@ -471,17 +701,40 @@ def _run(self, src_info, x_d_i_info, x_d_0_info, R_d_i, R_d_0, realtime, input_e if self.cfg.infer_params.flag_stitching: x_d_i_new = self.stitching(x_s, x_d_i_new) - x_d_i_new = x_s + (x_d_i_new - x_s) * self.cfg.infer_params.driving_multiplier + x_d_i_new = ( + x_s + (x_d_i_new - x_s) * self.cfg.infer_params.driving_multiplier + ) out_crop = self.model_dict["warping_spade"].predict(f_s, x_s, x_d_i_new) - if not realtime and self.cfg.infer_params.flag_pasteback and self.cfg.infer_params.flag_do_crop and self.cfg.infer_params.flag_stitching: - # TODO: pasteback is slow, considering optimize it using multi-threading or GPU - # I_p_pstbk = paste_back(out_crop, crop_info['M_c2o'], I_p_pstbk, mask_ori_float) - I_p_pstbk = paste_back_pytorch(out_crop, M, I_p_pstbk, mask_ori_float) - return out_crop.to(dtype=torch.uint8).cpu().numpy(), I_p_pstbk.to(dtype=torch.uint8).cpu().numpy() + + # FIX: Enable pasteback for both realtime and non-realtime modes + if ( + self.cfg.infer_params.flag_pasteback + and self.cfg.infer_params.flag_do_crop + and self.cfg.infer_params.flag_stitching + ): + if mask_ori_float is not None and I_p_pstbk is not None: + try: + I_p_pstbk = paste_back_pytorch( + out_crop, M, I_p_pstbk, mask_ori_float + ) + except Exception: + # Fallback to cropped output if pasteback fails + I_p_pstbk = out_crop + else: + I_p_pstbk = out_crop + + # Ensure we always return valid images + if I_p_pstbk is None: + I_p_pstbk = out_crop + + return out_crop.to(dtype=torch.uint8).cpu().numpy(), 
I_p_pstbk.to( + dtype=torch.uint8 + ).cpu().numpy() def run(self, image, img_src, src_info, **kwargs): img_bgr = image img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + I_p_pstbk = torch.from_numpy(img_src).to(self.device).float() realtime = kwargs.get("realtime", False) if self.cfg.infer_params.flag_crop_driving_video: @@ -535,7 +788,9 @@ def run(self, image, img_src, src_info, **kwargs): input_eye_ratio = calc_eye_close_ratio(lmk_crop[None]) input_lip_ratio = calc_lip_close_ratio(lmk_crop[None]) - pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict(img_crop) + pitch, yaw, roll, t, exp, scale, kp = self.model_dict[ + "motion_extractor" + ].predict(img_crop) x_d_i_info = { "pitch": pitch, "yaw": yaw, @@ -543,15 +798,18 @@ def run(self, image, img_src, src_info, **kwargs): "t": t, "exp": exp, "scale": scale, - "kp": kp + "kp": kp, } R_d_i = get_rotation_matrix(pitch, yaw, roll) x_d_i_info["R"] = R_d_i x_d_i_info_copy = copy.deepcopy(x_d_i_info) for key in x_d_i_info_copy: x_d_i_info_copy[key] = x_d_i_info_copy[key].astype(np.float32) - dri_motion_info = [x_d_i_info_copy, copy.deepcopy(input_eye_ratio.astype(np.float32)), - copy.deepcopy(input_lip_ratio.astype(np.float32))] + dri_motion_info = [ + x_d_i_info_copy, + copy.deepcopy(input_eye_ratio.astype(np.float32)), + copy.deepcopy(input_lip_ratio.astype(np.float32)), + ] if kwargs.get("first_frame", False) or self.R_d_0 is None: self.frame_id = 0 self.R_d_0 = R_d_i.copy() @@ -561,9 +819,23 @@ def run(self, image, img_src, src_info, **kwargs): self.exp_smooth = utils.OneEuroFilter(4, 0.3) R_d_0 = self.R_d_0.copy() x_d_0_info = copy.deepcopy(self.x_d_0_info) - out_crop, I_p_pstbk = self._run(src_info, x_d_i_info, x_d_0_info, R_d_i, R_d_0, realtime, input_eye_ratio, - input_lip_ratio, - I_p_pstbk, **kwargs) + + # FIX: Ensure I_p_pstbk is properly set for pasteback + if self.cfg.infer_params.flag_pasteback and I_p_pstbk is None: + I_p_pstbk = 
torch.from_numpy(img_src).to(self.device).float() + + out_crop, I_p_pstbk = self._run( + src_info, + x_d_i_info, + x_d_0_info, + R_d_i, + R_d_0, + realtime, + input_eye_ratio, + input_lip_ratio, + I_p_pstbk, + **kwargs, + ) return img_crop, out_crop, I_p_pstbk, dri_motion_info def run_with_pkl(self, dri_motion_info, img_src, src_info, **kwargs): @@ -584,8 +856,18 @@ def run_with_pkl(self, dri_motion_info, img_src, src_info, **kwargs): self.exp_smooth = utils.OneEuroFilter(4, 0.3) R_d_0 = self.R_d_0.copy() x_d_0_info = copy.deepcopy(self.x_d_0_info) - out_crop, I_p_pstbk = self._run(src_info, x_d_i_info, x_d_0_info, R_d_i, R_d_0, realtime, input_eye_ratio, - input_lip_ratio, I_p_pstbk, **kwargs) + out_crop, I_p_pstbk = self._run( + src_info, + x_d_i_info, + x_d_0_info, + R_d_i, + R_d_0, + realtime, + input_eye_ratio, + input_lip_ratio, + I_p_pstbk, + **kwargs, + ) return out_crop, I_p_pstbk def __del__(self): diff --git a/src/pipelines/joyvasa_audio_to_motion_pipeline.py b/src/pipelines/joyvasa_audio_to_motion_pipeline.py index 3538080..b3f8480 100644 --- a/src/pipelines/joyvasa_audio_to_motion_pipeline.py +++ b/src/pipelines/joyvasa_audio_to_motion_pipeline.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # @Time : 2024/12/15 # @Author : wenshao # @Email : wenshaoguo1026@gmail.com @@ -24,7 +24,7 @@ class JoyVASAAudio2MotionPipeline: """ - JoyVASA 声音生成LivePortrait Motion + JoyVASA 声音生成LivePortrait Motion """ def __init__(self, **kwargs): @@ -36,7 +36,7 @@ def __init__(self, **kwargs): motion_model_path = kwargs.get("motion_model_path", "") audio_model_path = kwargs.get("audio_model_path", "") motion_template_path = kwargs.get("motion_template_path", "") - model_data = torch.load(motion_model_path, map_location="cpu") + model_data = torch.load(motion_model_path, map_location="cpu", weights_only=False) model_args = NullableArgs(model_data['args']) model = DitTalkingHead(motion_feat_dim=model_args.motion_feat_dim, n_motions=model_args.n_motions, @@ 
-171,3 +171,4 @@ def gen_motion_sequence(self, audio_path, **kwargs): tgt_motion = {'n_frames': motion_coef.shape[0], 'output_fps': self.fps, 'motion': motion_list, 'c_eyes_lst': [], 'c_lip_lst': []} return tgt_motion +