import os
import subprocess
from functools import wraps
from time import time

import torch
from cog import BasePredictor, Input, Path

import inference


def make_mem_efficient(cls: BasePredictor):
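    """
    Class decorator that keeps the loaded models on the CPU between requests.

    ``setup`` is wrapped so the models are moved back to the CPU right after
    loading, and ``predict`` is wrapped so they are moved to CUDA only for the
    duration of a prediction and returned to the CPU afterwards, even if the
    prediction raises.
    """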
    if not torch.cuda.is_available():
        return cls

    old_setup = cls.setup
    old_predict = cls.predict

    @wraps(old_setup)
    def new_setup(self, *args, **kwargs):
        ret = old_setup(self, *args, **kwargs)
        _move_to(self, "cpu")
        return ret

    @wraps(old_predict)
    def new_predict(self, *args, **kwargs):
        _move_to(self, "cuda")
        try:
            ret = old_predict(self, *args, **kwargs)
        finally:
            _move_to(self, "cpu")
        return ret

    cls.setup = new_setup
    cls.predict = new_predict

    return cls


def _move_to(self, device):
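    """
    Move every attribute of ``self.cached_models`` (here, the ``inference``
    module) that supports ``.to()``, i.e. models and tensors, to the given
    device, then release any cached CUDA memory.
    """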
    try:
        self = self.cached_models
    except AttributeError:
        pass
    for attr, value in vars(self).items():
        try:
            value = value.to(device)
        except AttributeError:
            pass
        else:
            print(f"Moving {self.__name__}.{attr} to {device}")
            setattr(self, attr, value)
    torch.cuda.empty_cache()


@make_mem_efficient
class Predictor(BasePredictor):
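    # Attributes of this object are what the decorator shuttles between the
    # CPU and the GPU; here it is the whole ``inference`` module.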
    cached_models = inference

    def setup(self):
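        # Load the Wav2Lip GAN checkpoint once at start-up; the decorator then
        # parks the loaded models on the CPU until the first prediction.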
        inference.do_load("checkpoints/wav2lip_gan.pth")

    def predict(
        self,
        face: Path = Input(description="Video or image that contains the face to use"),
        audio: Path = Input(description="Video or audio file to use as the raw audio source"),
        pads: str = Input(
            description="Padding for the detected face bounding box.\n"
            "Please adjust to include at least the chin.\n"
            'Format: "top bottom left right"',
            default="0 10 0 0",
        ),
        smooth: bool = Input(
            description="Smooth face detections over a short temporal window",
            default=True,
        ),
        fps: float = Input(
            description="Frame rate; can only be specified if the input is a static image",
            default=25.0,
        ),
        out_height: int = Input(
            description="Output video height. Best results are obtained at 480 or 720",
            default=480,
        ),
    ) -> Path:
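        # Delete any output left over from a previous run so a failed
        # prediction can never return stale results.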
        try:
            os.remove("results/result_voice.mp4")
        except FileNotFoundError:
            pass

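        # Reject inputs the pipeline cannot handle before doing any work.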
        face_ext = os.path.splitext(face)[-1]
        if face_ext not in [".mp4", ".mov", ".png", ".jpg", ".jpeg", ".gif", ".mkv", ".webp"]:
            raise ValueError(f"Unsupported face format {face_ext!r}")

        audio_ext = os.path.splitext(audio)[-1]
        if audio_ext not in [".wav", ".mp3"]:
            raise ValueError(f"Unsupported audio format {audio_ext!r}")

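        # Build the argument list understood by the ``inference`` module's
        # argparse parser, then run the lip-sync pipeline in-process.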
        args = [
            "--checkpoint_path", "checkpoints/wav2lip_gan.pth",
            "--face", str(face),
            "--audio", str(audio),
            "--pads", *pads.split(" "),
            "--fps", str(fps),
            "--out_height", str(out_height),
        ]
        if not smooth:
            args += ["--nosmooth"]

        print("-> run:", " ".join(args))
        inference.args = inference.parser.parse_args(args)

        s = time()

        try:
            inference.main()
        except ValueError as e:
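            # ``inference.main()`` raises ValueError when it cannot proceed,
            # e.g. when no face is detected; fall back to muxing the inputs.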
            print("-> Encountered error, skipping lipsync:", e)

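            # Fallback: loop the face video indefinitely (-stream_loop -1),
            # take its video stream and the audio file's audio stream, and cut
            # the output at the shorter of the two (-shortest).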
            args = [
                "ffmpeg", "-y",
                "-stream_loop", "-1",
                "-i", str(face),
                "-i", str(audio),
                "-shortest",
                "-fflags", "+shortest",
                "-max_interleave_delta", "100M",
                "-map", "0:v:0",
                "-map", "1:a:0",
                "results/result_voice.mp4",
            ]
            print("-> run:", " ".join(args))
            print(subprocess.check_output(args, encoding="utf-8"))

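        # Report total processing time in seconds.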
        print(time() - s)

        return Path("results/result_voice.mp4")