# echomimic-v2 / app.py
# Hugging Face Space by fffiloni — last commit 7a168b3 ("Update app.py", verified)
import os
import random
from pathlib import Path
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from PIL import Image
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d_emo import EMOUNet3DConditionModel
from src.models.whisper.audio2feature import load_audio_model
from src.pipelines.pipeline_echomimicv2 import EchoMimicV2Pipeline
from src.utils.util import save_videos_grid
from src.models.pose_encoder import PoseEncoder
from src.utils.dwpose_util import draw_pose_select_v2
from moviepy.editor import VideoFileClip, AudioFileClip
import gradio as gr
from datetime import datetime
from torchao.quantization import quantize_, int8_weight_only
import gc
import tempfile
from pydub import AudioSegment
def cut_audio_to_5_seconds(audio_path):
    """Return a WAV copy of ``audio_path`` holding at most its first 5 seconds.

    The trimmed file is written as ``trimmed_audio.wav`` inside a new
    directory created by ``tempfile.mkdtemp()``; that directory is not
    removed here.

    Args:
        audio_path: Path to an audio file readable by ``AudioSegment.from_file``.

    Returns:
        Path to the trimmed WAV file.

    Raises:
        RuntimeError: If loading, trimming or exporting fails. Raising
            (instead of returning the error text) keeps callers from
            treating an error message as if it were a file path.
    """
    try:
        # Load the audio file
        audio = AudioSegment.from_file(audio_path)
        # Keep at most the first 5 seconds (pydub slices in milliseconds)
        trimmed_audio = audio[:5000]
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, "trimmed_audio.wav")
        # export() hands back an open handle to the written file; close it so
        # the file descriptor is not leaked.
        trimmed_audio.export(output_path, format="wav").close()
    except Exception as e:
        raise RuntimeError(f"An error occurred while trying to trim audio: {str(e)}") from e
    return output_path
import requests
import tarfile
def download_and_setup_ffmpeg():
    """Fetch a static FFmpeg 4.4 build and point ``FFMPEG_PATH`` at it.

    The archive is downloaded into the current working directory and
    extracted under ``ffmpeg-4.4-amd64-static/``. If the ``ffmpeg`` binary
    is already there from an earlier run, the download is skipped.

    ``FFMPEG_PATH`` is set to the *directory* that contains the ``ffmpeg``
    binary, because the startup code prepends ``FFMPEG_PATH`` to ``PATH``,
    and ``PATH`` entries must be directories, not files.

    Returns:
        A human-readable status string. Errors are caught and reported in
        this string rather than raised.
    """
    url = "https://www.johnvansickle.com/ffmpeg/old-releases/ffmpeg-4.4-amd64-static.tar.xz"
    download_path = "ffmpeg-4.4-amd64-static.tar.xz"
    extract_dir = "ffmpeg-4.4-amd64-static"
    # The tarball carries its own top-level folder, so the binary ends up nested.
    ffmpeg_dir = os.path.join(extract_dir, "ffmpeg-4.4-amd64-static")
    ffmpeg_binary_path = os.path.join(ffmpeg_dir, "ffmpeg")
    try:
        if not os.path.isfile(ffmpeg_binary_path):
            # Stream the archive to disk; the timeout keeps a stalled server
            # from hanging application startup forever.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()  # Check for HTTP request errors
                with open(download_path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        file.write(chunk)
            with tarfile.open(download_path, "r:xz") as tar:
                # Use the "data" extraction filter when this Python provides it,
                # rejecting absolute paths / path traversal in the downloaded archive.
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(path=extract_dir, filter="data")
                else:
                    tar.extractall(path=extract_dir)
        os.environ["FFMPEG_PATH"] = ffmpeg_dir
        return f"FFmpeg downloaded and setup successfully! Path: {ffmpeg_binary_path}"
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Fetch a static FFmpeg build. The helper reports failures through its return
# value instead of raising, so print it to make problems visible in the logs.
print(download_and_setup_ffmpeg())
from huggingface_hub import snapshot_download
# Create the main "pretrained_weights" folder and the sub-folders that the
# downloads below (and the Whisper download later) write into.
os.makedirs("pretrained_weights", exist_ok=True)
subfolders = [
    "sd-vae-ft-mse",
    "sd-image-variations-diffusers",
    "audio_processor",
]
for subfolder in subfolders:
    os.makedirs(os.path.join("pretrained_weights", subfolder), exist_ok=True)
# (Hub repo id, local directory) pairs, downloaded in this order.
for repo_id, local_dir in (
    ("BadToBest/EchoMimicV2", "./pretrained_weights"),
    ("stabilityai/sd-vae-ft-mse", "./pretrained_weights/sd-vae-ft-mse"),
    ("lambdalabs/sd-image-variations-diffusers", "./pretrained_weights/sd-image-variations-diffusers"),
):
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
# True only when running as the public shared Space, where generate() trims the
# input audio to 5 seconds. SPACE_ID may be unset (e.g. when run locally), so
# fall back to "" instead of raising KeyError.
is_shared_ui = "fffiloni/echomimic-v2" in os.environ.get("SPACE_ID", "")
# Download and place the Whisper model in the "audio_processor" folder
def download_whisper_model():
    """Ensure the Whisper "tiny" checkpoint exists at pretrained_weights/audio_processor/tiny.pt.

    Best-effort: failures are printed, never raised. A non-empty existing
    file is reused without re-downloading. A new download is streamed to a
    ``.part`` file, checked against the SHA-256 digest embedded in the URL,
    and only then moved into place. An interrupted or corrupted download
    therefore never leaves a bad ``tiny.pt`` behind.
    """
    import hashlib

    url = "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt"
    save_path = os.path.join("pretrained_weights", "audio_processor", "tiny.pt")
    if os.path.isfile(save_path) and os.path.getsize(save_path) > 0:
        print(f"Whisper model already present at {save_path}")
        return
    # The 64-hex-digit second-to-last URL segment is the checkpoint's SHA-256
    # digest (openai-whisper's own downloader verifies against it the same way).
    expected_sha256 = url.split("/")[-2]
    part_path = save_path + ".part"
    try:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        digest = hashlib.sha256()
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()  # Check for HTTP request errors
            with open(part_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
                    digest.update(chunk)
        if digest.hexdigest() != expected_sha256:
            raise ValueError(f"SHA256 checksum mismatch for {url}")
        # Replace in one step so readers never observe a half-written tiny.pt.
        os.replace(part_path, save_path)
        print(f"Whisper model downloaded and saved to {save_path}")
    except Exception as e:
        print(f"An error occurred while downloading the model: {str(e)}")
    finally:
        # Drop any leftover partial download (nothing to do after a successful replace).
        if os.path.exists(part_path):
            os.remove(part_path)
# Download the Whisper model
download_whisper_model()

dtype = torch.float16
if torch.cuda.is_available():
    device = "cuda"
    # Only query the GPU once CUDA is known to be available:
    # get_device_properties() raises on hosts without CUDA, which would
    # prevent the CPU fallback below from ever running.
    total_vram_in_gb = torch.cuda.get_device_properties(0).total_memory / 1073741824  # bytes -> GiB (2**30)
    # Startup banner (CUDA version, PyTorch version, GPU name, VRAM, precision)
    print(f'\033[32mCUDA版本:{torch.version.cuda}\033[0m')
    print(f'\033[32mPytorch版本:{torch.__version__}\033[0m')
    print(f'\033[32m显卡型号:{torch.cuda.get_device_name()}\033[0m')
    print(f'\033[32m显存大小:{total_vram_in_gb:.2f}GB\033[0m')
    print(f'\033[32m精度:float16\033[0m')
else:
    # NOTE(review): dtype stays float16 on CPU; some ops may be slow or
    # unsupported in half precision there — confirm before relying on CPU mode.
    print("cuda not available, using cpu")
    device = "cpu"

ffmpeg_path = os.getenv('FFMPEG_PATH')
if ffmpeg_path is None:
    print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=./ffmpeg-4.4-amd64-static")
else:
    # PATH entries must be directories; also accept FFMPEG_PATH pointing at
    # the ffmpeg binary itself.
    if os.path.isfile(ffmpeg_path):
        ffmpeg_path = os.path.dirname(ffmpeg_path)
    current_path = os.environ.get('PATH', '')
    # Compare whole PATH entries, not substrings of the joined string.
    if ffmpeg_path not in current_path.split(os.pathsep):
        print("add ffmpeg to path")
        os.environ["PATH"] = ffmpeg_path + os.pathsep + current_path if current_path else ffmpeg_path
def _load_pipeline(quantization_input):
    """Load every EchoMimicV2 sub-model from ./pretrained_weights and assemble the pipeline.

    Args:
        quantization_input: When truthy, the VAE and reference UNet are
            int8 weight-only quantized with torchao after loading.

    Returns:
        An ``EchoMimicV2Pipeline`` moved to the module-level ``device``/``dtype``.

    Raises:
        gr.Error: If ``./pretrained_weights/motion_module.pth`` is missing.
    """
    ## vae init
    vae = AutoencoderKL.from_pretrained("./pretrained_weights/sd-vae-ft-mse").to(device, dtype=dtype)
    if quantization_input:
        quantize_(vae, int8_weight_only())
        print("Use int8 quantization.")

    ## reference net init
    reference_unet = UNet2DConditionModel.from_pretrained(
        "./pretrained_weights/sd-image-variations-diffusers",
        subfolder="unet",
        use_safetensors=False,
    ).to(dtype=dtype, device=device)
    reference_unet.load_state_dict(torch.load("./pretrained_weights/reference_unet.pth", weights_only=True))
    if quantization_input:
        quantize_(reference_unet, int8_weight_only())

    ## denoising net init
    if not os.path.exists("./pretrained_weights/motion_module.pth"):
        # Raise instead of exit(): exit() raises SystemExit, which is meant to
        # stop the interpreter, not to fail a single request. gr.Error shows
        # the message to the user.
        raise gr.Error("motion module not found")
    print('using motion module')

    ### stage1 + stage2
    denoising_unet = EMOUNet3DConditionModel.from_pretrained_2d(
        "./pretrained_weights/sd-image-variations-diffusers",
        "./pretrained_weights/motion_module.pth",
        subfolder="unet",
        unet_additional_kwargs={
            "use_inflated_groupnorm": True,
            "unet_use_cross_frame_attention": False,
            "unet_use_temporal_attention": False,
            "use_motion_module": True,
            "cross_attention_dim": 384,
            "motion_module_resolutions": [1, 2, 4, 8],
            "motion_module_mid_block": True,
            "motion_module_decoder_only": False,
            "motion_module_type": "Vanilla",
            "motion_module_kwargs": {
                "num_attention_heads": 8,
                "num_transformer_block": 1,
                "attention_block_types": ['Temporal_Self', 'Temporal_Self'],
                "temporal_position_encoding": True,
                "temporal_position_encoding_max_len": 32,
                "temporal_attention_dim_div": 1,
            },
        },
    ).to(dtype=dtype, device=device)
    # strict=False: keys missing from / unexpected in the checkpoint are
    # ignored rather than raising.
    denoising_unet.load_state_dict(torch.load("./pretrained_weights/denoising_unet.pth", weights_only=True), strict=False)

    # pose net init
    pose_net = PoseEncoder(320, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)).to(dtype=dtype, device=device)
    pose_net.load_state_dict(torch.load("./pretrained_weights/pose_encoder.pth", weights_only=True))

    ### load audio processor params
    audio_processor = load_audio_model(model_path="./pretrained_weights/audio_processor/tiny.pt", device=device)

    sched_kwargs = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "linear",
        "clip_sample": False,
        "steps_offset": 1,
        "prediction_type": "v_prediction",
        "rescale_betas_zero_snr": True,
        "timestep_spacing": "trailing",
    }
    scheduler = DDIMScheduler(**sched_kwargs)
    pipe = EchoMimicV2Pipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        audio_guider=audio_processor,
        pose_encoder=pose_net,
        scheduler=scheduler,
    )
    return pipe.to(device, dtype=dtype)


def _build_pose_tensor(pose_dir, start_idx, length, width, height):
    """Render ``length`` pose frames from ``<pose_dir>/<index>.npy`` into one tensor.

    Each file's ``draw_pose_params`` entry gives ``(imh, imw, row_begin,
    row_end, col_begin, col_end)``. The drawn pose is pasted into those rows
    and columns of a black ``height x width`` RGB canvas.

    NOTE(review): files are read with ``allow_pickle=True``, which can execute
    pickled code. Only point ``pose_dir`` at trusted pose directories.

    Returns:
        Tensor of shape (1, 3, length, height, width) with values in [0, 1],
        on the module-level ``device`` with ``dtype``.
    """
    pose_list = []
    for index in range(start_idx, start_idx + length):
        # NumPy images are (rows, cols, channels) == (height, width, 3); the
        # order matters as soon as width != height.
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        pose_path = os.path.join(pose_dir, "{}.npy".format(index))
        detected_pose = np.load(pose_path, allow_pickle=True).tolist()
        imh_new, imw_new, row_begin, row_end, col_begin, col_end = detected_pose['draw_pose_params']
        im = draw_pose_select_v2(detected_pose, imh_new, imw_new, ref_w=800)
        im = np.transpose(np.array(im), (1, 2, 0))
        canvas[row_begin:row_end, col_begin:col_end, :] = im
        canvas_pil = Image.fromarray(canvas).convert('RGB')
        pose_list.append(torch.Tensor(np.array(canvas_pil)).to(dtype=dtype, device=device).permute(2, 0, 1) / 255.0)
    return torch.stack(pose_list, dim=1).unsqueeze(0)


def generate(image_input, audio_input, pose_input, width, height, length, steps, sample_rate, cfg, fps, context_frames, context_overlap, quantization_input, seed, progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: animate the reference image with the driving audio and pose sequence.

    All models are loaded from ./pretrained_weights on every call (see
    ``_load_pipeline``). On the shared Space (``is_shared_ui``) the audio is
    first trimmed to 5 seconds.

    Args:
        image_input: Path to the reference image; resized to ``(width, height)``.
        audio_input: Path to the driving audio file.
        pose_input: Directory holding per-frame pose files ``0.npy``, ``1.npy``, ...
        width, height: Output frame size in pixels.
        length: Requested frame count. It is capped by the audio duration
            (``duration * fps``) and by the number of entries in ``pose_input``.
        steps, cfg, sample_rate, fps, context_frames, context_overlap:
            Forwarded to ``EchoMimicV2Pipeline``.
        quantization_input: Apply int8 weight-only quantization to the VAE
            and reference UNet.
        seed: RNG seed; ``None`` or a value not greater than -1 picks a random one.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        ``(video_path, seed_update)``. ``video_path`` is
        ``outputs/<timestamp>_sig.mp4`` (video with audio). ``seed_update``
        is a ``gr.update`` that reveals the seed box with the seed actually used.

    Raises:
        gr.Error: On missing inputs, a missing motion module, or audio too
            short to yield a single frame.
    """
    if image_input is None or audio_input is None:
        raise gr.Error("Please provide both a reference image and an audio file.")
    # gr.Number may deliver floats (e.g. 768.0); array shapes, range() and
    # Image.resize() need ints.
    width, height, length, steps = int(width), int(height), int(length), int(steps)
    context_frames, context_overlap = int(context_frames), int(context_overlap)

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = Path("outputs")
    save_dir.mkdir(exist_ok=True, parents=True)

    pipe = _load_pipeline(quantization_input)

    if seed is not None and seed > -1:
        generator = torch.manual_seed(seed)
    else:
        seed = random.randint(100, 1000000)
        generator = torch.manual_seed(seed)

    if is_shared_ui:
        audio_input = cut_audio_to_5_seconds(audio_input)
        print(f"Trimmed audio saved at: {audio_input}")
        # cut_audio_to_5_seconds may report failure by returning its error
        # message instead of a path; surface that instead of opening it.
        if not os.path.isfile(audio_input):
            raise gr.Error(audio_input)

    inputs_dict = {
        "refimg": image_input,
        "audio": audio_input,
        "pose": pose_input,
    }
    print('Pose:', inputs_dict['pose'])
    print('Reference:', inputs_dict['refimg'])
    print('Audio:', inputs_dict['audio'])
    save_name = f"{save_dir}/{timestamp}"
    with Image.open(inputs_dict['refimg']) as ref_image:
        ref_image_pil = ref_image.resize((width, height))

    start_idx = 0
    # Open the driving audio once and let the context managers close every
    # moviepy clip (they keep file/ffmpeg resources open until closed).
    with AudioFileClip(inputs_dict['audio']) as audio_clip:
        # Never render more frames than the audio covers or pose files exist.
        length = min(length, int(audio_clip.duration * fps), len(os.listdir(inputs_dict['pose'])))
        if length <= 0:
            raise gr.Error("Audio is too short to produce a single video frame.")
        poses_tensor = _build_pose_tensor(inputs_dict['pose'], start_idx, length, width, height)
        audio_clip_cut = audio_clip.set_duration(length / fps)

        video = pipe(
            ref_image_pil,
            inputs_dict['audio'],
            poses_tensor[:, :, :length, ...],
            width,
            height,
            length,
            steps,
            cfg,
            generator=generator,
            audio_sample_rate=sample_rate,
            context_frames=context_frames,
            fps=fps,
            context_overlap=context_overlap,
            start_idx=start_idx,
        ).videos

        final_length = min(video.shape[2], poses_tensor.shape[2], length)
        video_sig = video[:, :, :final_length, :, :]
        save_videos_grid(
            video_sig,
            save_name + "_woa_sig.mp4",
            n_rows=1,
            fps=fps,
        )
        # Mux the silent render with the (duration-matched) audio track.
        with VideoFileClip(save_name + "_woa_sig.mp4") as video_clip_sig:
            video_clip_sig.set_audio(audio_clip_cut).write_videofile(
                save_name + "_sig.mp4", codec="libx264", audio_codec="aac", threads=2
            )

    video_output = save_name + "_sig.mp4"
    seed_text = gr.update(visible=True, value=seed)
    return video_output, seed_text
# Gradio UI: inputs and advanced settings on the left, output video and seed on
# the right, preset examples underneath. Integer-only settings use precision=0
# so Gradio hands generate() ints rather than floats.
with gr.Blocks() as demo:
    gr.Markdown("""
# EchoMimicV2
⚠️ This demonstration is for academic research and experiential use only.
""")
    gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/antgroup/echomimic_v2">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://antgroup.github.io/ai/echomimic_v2/">
<img src='https://img.shields.io/badge/Project-Page-green'>
</a>
<a href="https://arxiv.org/abs/2411.10061">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/echomimic-v2?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
</a>
</div>
""")
    with gr.Column():
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    image_input = gr.Image(label="Image Input (Auto Scaling)", type="filepath")
                    audio_input = gr.Audio(label="Audio Input - max 5 seconds on shared UI", type="filepath")
                    # Hidden, fixed pose directory passed through to generate().
                    pose_input = gr.Textbox(label="Pose Input (Directory Path)", placeholder="Please enter the directory path for pose data.", value="assets/halfbody_demo/pose/01", interactive=False, visible=False)
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        width = gr.Number(label="Width (multiple of 16, recommended: 768)", value=768, precision=0)
                        height = gr.Number(label="Height (multiple of 16, recommended: 768)", value=768, precision=0)
                        length = gr.Number(label="Video Length (recommended: 240)", value=240, precision=0)
                    with gr.Row():
                        steps = gr.Number(label="Steps (recommended: 30)", value=20, precision=0)
                        sample_rate = gr.Number(label="Sampling Rate (recommended: 16000)", value=16000, precision=0)
                        cfg = gr.Number(label="CFG (recommended: 2.5)", value=2.5, step=0.1)
                    with gr.Row():
                        fps = gr.Number(label="Frame Rate (recommended: 24)", value=24)
                        context_frames = gr.Number(label="Context Frames (recommended: 12)", value=12, precision=0)
                        context_overlap = gr.Number(label="Context Overlap (recommended: 3)", value=3, precision=0)
                    with gr.Row():
                        quantization_input = gr.Checkbox(label="Int8 Quantization (recommended for users with 12GB VRAM, use audio no longer than 5 seconds)", value=False)
                        seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                generate_button = gr.Button("🎬 Generate Video")
            with gr.Column():
                video_output = gr.Video(label="Output Video")
                # Hidden until generate() returns the seed it actually used.
                seed_text = gr.Textbox(label="Seed", interactive=False, visible=False)
        gr.Examples(
            examples=[
                ["EMTD_dataset/ref_imgs_by_FLUX/man/0001.png", "assets/halfbody_demo/audio/chinese/echomimicv2_man.wav"],
                ["EMTD_dataset/ref_imgs_by_FLUX/woman/0077.png", "assets/halfbody_demo/audio/chinese/echomimicv2_woman.wav"],
                ["EMTD_dataset/ref_imgs_by_FLUX/man/0003.png", "assets/halfbody_demo/audio/chinese/fighting.wav"],
                ["EMTD_dataset/ref_imgs_by_FLUX/woman/0033.png", "assets/halfbody_demo/audio/chinese/good.wav"],
                ["EMTD_dataset/ref_imgs_by_FLUX/man/0010.png", "assets/halfbody_demo/audio/chinese/news.wav"],
                ["EMTD_dataset/ref_imgs_by_FLUX/man/1168.png", "assets/halfbody_demo/audio/chinese/no_smoking.wav"],
                ["EMTD_dataset/ref_imgs_by_FLUX/woman/0057.png", "assets/halfbody_demo/audio/chinese/ultraman.wav"],
            ],
            inputs=[image_input, audio_input],
            label="Preset Characters and Audio",
        )
    generate_button.click(
        generate,
        inputs=[image_input, audio_input, pose_input, width, height, length, steps, sample_rate, cfg, fps, context_frames, context_overlap, quantization_input, seed],
        outputs=[video_output, seed_text],
    )
if __name__ == "__main__":
    # Turn on the request queue, then serve the app. The API page is hidden,
    # errors are shown in the UI, and server-side rendering is disabled.
    demo.queue().launch(
        show_api=False,
        show_error=True,
        ssr_mode=False,
    )