|
import gradio as gr |
|
import torch |
|
import os |
|
import base64 |
|
import uuid |
|
import tempfile |
|
import numpy as np |
|
import cv2 |
|
import subprocess |
|
from DeepCache import DeepCacheSDHelper |
|
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler |
|
from huggingface_hub import hf_hub_download |
|
from safetensors.torch import load_file |
|
from PIL import Image |
|
|
|
|
|
# Shared secret that API callers must pass to generate_image. The fallback
# value is visible in this source file, so running without SECRET_TOKEN set
# leaves the endpoint effectively open; warn instead of falling back silently.
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
if not os.environ.get('SECRET_TOKEN'):
    print("WARNING: SECRET_TOKEN is unset or empty; the token check is not protecting this endpoint.")
|
|
|
|
|
# Dropdown label -> Hugging Face repo id of the Stable Diffusion checkpoint.
# The pipeline is first built from bases[base_loaded], and generate_image
# later swaps in the UNet weights of whichever entry is requested.
bases = dict(
    ToonYou="frankjoshua/toonyou_beta6",
    epiCRealism="emilianJR/epiCRealism",
)

# Bookkeeping of what is currently loaded into `pipe`. generate_image compares
# each request against these so weights are only reloaded when they change.
step_loaded = None            # Lightning step count whose checkpoint is loaded
base_loaded = "epiCRealism"   # key into `bases`
motion_loaded = None          # motion LoRA repo id; "" means no LoRA
|
|
|
|
|
# Everything below targets CUDA in fp16, so refuse to start without a GPU.
if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")

device = "cuda"
dtype = torch.float16

# Build the AnimateDiff pipeline from the initial base checkpoint and move it
# onto the GPU.
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype)
pipe = pipe.to(device)

# Replace the checkpoint's scheduler with an Euler scheduler derived from the
# same config, overriding timestep spacing to "trailing" and the beta
# schedule to "linear".
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    beta_schedule="linear",
)
|
|
|
|
|
# Native timing of the generated clip: the raw video is written at
# hardcoded_fps and assumed to last hardcoded_duration_sec. A request for any
# other fps/duration triggers an ffmpeg retime in generate_image.
hardcoded_fps, hardcoded_duration_sec = 10, 1.6
|
|
|
|
|
# Preload the Lightning checkpoint for the default step count (4, which is
# also the UI dropdown's default). generate_image only reloads when a request
# asks for a different step count.
step = 4
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = "animatediff_lightning_{}step_diffusers.safetensors".format(step)
# strict=False: presumably the checkpoint covers only part of the UNet's
# parameters -- TODO confirm against the checkpoint contents.
pipe.unet.load_state_dict(
    load_file(hf_hub_download(repo, ckpt), device=device),
    strict=False,
)
step_loaded = step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _frame_to_uint8_rgb(frame):
    """Convert one frame to a C-contiguous HxWx3 uint8 RGB array.

    PIL images are converted to RGB. Floating-point arrays are treated as values
    in [0, 1] and clipped to that range before scaling, so slight overshoot
    cannot wrap around in the uint8 cast. Integer arrays are treated as 0-255
    values and clipped to that range. Greyscale (HxW or HxWx1) arrays are
    replicated to three channels, and a fourth (alpha) channel is dropped.
    """
    if isinstance(frame, Image.Image):
        return np.array(frame.convert("RGB"))

    arr = np.asarray(frame)
    if np.issubdtype(arr.dtype, np.floating):
        # Truncating cast (not rounding) keeps the original 0-1 -> 0-255 mapping.
        arr = (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        arr = np.clip(arr, 0, 255).astype(np.uint8)

    if arr.ndim == 2:
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 1:
        arr = np.repeat(arr, 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[-1] == 4:
        arr = arr[..., :3]
    # Slicing off alpha yields a non-contiguous view; hand OpenCV a packed array.
    return np.ascontiguousarray(arr)


def export_to_video_file(video_frames, output_video_path=None, fps=hardcoded_fps):
    """Encode a sequence of RGB frames into a VP9 (.webm) file with OpenCV.

    Args:
        video_frames: Non-empty sequence of frames. Each is a ``PIL.Image.Image``
            or an array-like (float in [0, 1] or integer 0-255), HxW, HxWx1,
            HxWx3 (RGB) or HxWx4 (RGBA). All frames should share one size.
        output_video_path: Destination path. When ``None``, a fresh temporary
            ``.webm`` file is created and used.
        fps: Frame rate written into the container.

    Returns:
        The path of the written video file.

    Raises:
        ValueError: If ``video_frames`` is empty.
        RuntimeError: If OpenCV cannot open a VP90 writer for the path.
    """
    if len(video_frames) == 0:
        raise ValueError("video_frames must contain at least one frame")

    if output_video_path is None:
        # mkstemp creates the file atomically and leaves it on disk. The old
        # NamedTemporaryFile(...).name idiom deleted the file as soon as the
        # object was garbage-collected, handing back a racy, reusable name.
        fd, output_video_path = tempfile.mkstemp(suffix=".webm")
        os.close(fd)

    frames = [_frame_to_uint8_rgb(frame) for frame in video_frames]
    height, width = frames[0].shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height), True)
    if not video_writer.isOpened():
        # Otherwise write() silently does nothing and no file is produced.
        raise RuntimeError(f"Could not open a VP90 video writer for {output_video_path!r}")

    try:
        for frame in frames:
            # OpenCV expects BGR channel order.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        video_writer.release()

    return output_video_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def interpolate_video_frames(
        input_file_path,
        output_file_path,
        output_fps=hardcoded_fps,
        desired_duration=hardcoded_duration_sec,
        original_duration=hardcoded_duration_sec,
        output_width=None,
        output_height=None,
        use_cuda=False,
        verbose=False):
    """Retime a video with ffmpeg, synthesising in-between frames by motion interpolation.

    The input timeline is scaled by ``desired_duration / original_duration``.
    ffmpeg's ``minterpolate`` filter then generates frames at ``output_fps``
    over the new timeline, after an optional resize to
    ``output_width`` x ``output_height``.

    Args:
        input_file_path: Video to read.
        output_file_path: Where ffmpeg writes the result. An existing file is
            overwritten.
        output_fps: Frame rate of the output video.
        desired_duration: Target duration in seconds.
        original_duration: Duration of the input in seconds. Must be non-zero.
        output_width, output_height: Optional output size. Scaling is only
            applied when both are truthy.
        use_cuda: Decode the input with CUDA hardware acceleration.
        verbose: Show ffmpeg's normal log output and print the parameters and
            command.

    Returns:
        ``output_file_path`` on success. If ffmpeg exits non-zero or cannot be
        launched, the error is printed, any partial output is removed, and
        ``input_file_path`` is returned so the caller can fall back to the
        un-retimed video.

    Raises:
        ZeroDivisionError: If ``original_duration`` is zero.
    """
    scale_factor = desired_duration / original_duration

    filters = []
    if output_width and output_height:
        filters.append(f'scale={output_width}:{output_height}')
    # Stretch timestamps BEFORE interpolating, so minterpolate fills the
    # longer timeline with synthesised frames. Applied after minterpolate,
    # a slowed clip only had duplicated frames inserted by '-r'.
    filters.append(f'setpts={scale_factor}*PTS')
    filters.append(
        f'minterpolate=mi_mode=mci:mc_mode=obmc:me=hexbs:vsbmc=1:mb_size=4:fps={output_fps}:scd=none'
    )
    filter_complex = ','.join(filters)

    # -nostdin / -y: never block on an interactive "overwrite?" prompt.
    cmd = ['ffmpeg', '-nostdin', '-y']
    if not verbose:
        cmd.extend(['-loglevel', 'error'])
    if use_cuda:
        # -hwaccel is an input option and must come before -i; after it, ffmpeg
        # rejects it as an output option. Decoded frames are left to be copied
        # back to system memory (no '-hwaccel_output_format cuda') because
        # scale/minterpolate are software filters that cannot take GPU frames.
        cmd.extend(['-hwaccel', 'cuda'])
    cmd.extend([
        '-i', input_file_path,
        '-filter:v', filter_complex,
        '-r', str(output_fps),
        output_file_path,
    ])

    if verbose:
        print("output_fps:", output_fps)
        print("desired_duration:", desired_duration)
        print("original_duration:", original_duration)
        print("cmd:", cmd)

    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # Best effort: FileNotFoundError covers a missing ffmpeg binary.
        print("Failed to interpolate video. Error:", e)
        if output_file_path != input_file_path:
            # Drop any partially written output so it does not leak.
            try:
                os.remove(output_file_path)
            except OSError:
                pass
        return input_file_path

    return output_file_path
|
|
|
def _sync_pipeline_state(base, motion, step):
    """Load into `pipe` whichever Lightning step weights, base UNet, or motion LoRA differ from the current ones."""
    global step_loaded, base_loaded, motion_loaded

    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    if base_loaded != base:
        # NOTE(review): torch.load unpickles the downloaded .bin. The repo id
        # comes from the fixed `bases` map, but consider weights_only=True if
        # the installed torch supports it.
        pipe.unet.load_state_dict(
            torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
            strict=False,
        )
        base_loaded = base

    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion


def generate_image(secret_token, prompt, base, width, height, motion, step, desired_duration, desired_fps):
    """Render an AnimateDiff-Lightning clip and return it as a WebM data URI.

    Args:
        secret_token: Must equal ``SECRET_TOKEN``.
        prompt: Text prompt for the pipeline.
        base: Key into ``bases`` naming the base UNet checkpoint.
        width, height: Output frame size in pixels. Coerced to ``int``.
        motion: Motion LoRA repo id, or ``""`` for none.
        step: Lightning step count. Selects the checkpoint and is also used as
            ``num_inference_steps``.
        desired_duration: Requested clip length in seconds.
        desired_fps: Requested frame rate. When either value differs from the
            native ``hardcoded_duration_sec`` / ``hardcoded_fps``, the clip is
            retimed with ffmpeg (best effort).

    Returns:
        ``'data:video/webm;base64,...'`` containing the final video.

    Raises:
        gr.Error: On an invalid secret token or an unknown ``base``.
    """
    import hmac  # constant-time token comparison

    # compare_digest avoids leaking how much of the token matched through
    # response timing. Comparing bytes keeps non-ASCII input from raising.
    token_ok = isinstance(secret_token, str) and hmac.compare_digest(
        secret_token.encode('utf-8'), SECRET_TOKEN.encode('utf-8'))
    if not token_ok:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    if base not in bases:
        raise gr.Error(f'Unknown base model: {base!r}')

    _sync_pipeline_state(base, motion, step)

    output = pipe(
        prompt=prompt,
        width=int(width),
        height=int(height),
        guidance_scale=1.0,
        num_inference_steps=step,
    )

    video_id = uuid.uuid4().hex
    tmp_dir = tempfile.gettempdir()
    raw_video_path = os.path.join(tmp_dir, f"{video_id}_raw.webm")
    enhanced_video_path = os.path.join(tmp_dir, f"{video_id}_enhanced.webm")

    needs_retiming = (desired_duration != hardcoded_duration_sec
                      or desired_fps != hardcoded_fps)

    try:
        raw_video_path = export_to_video_file(output.frames[0], raw_video_path, fps=hardcoded_fps)
        final_video_path = raw_video_path
        if needs_retiming:
            # Returns raw_video_path unchanged if ffmpeg fails.
            final_video_path = interpolate_video_frames(
                raw_video_path, enhanced_video_path,
                output_fps=desired_fps, desired_duration=desired_duration)

        with open(final_video_path, "rb") as video_file:
            video_base64 = base64.b64encode(video_file.read()).decode('utf-8')
    finally:
        # Both files are scratch space. Remove whichever exist, even when an
        # earlier step raised.
        for path in (raw_video_path, enhanced_video_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
            except OSError as e:
                print("Failed to delete a video path:", e)

    return 'data:video/webm;base64,' + video_base64
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # Full-viewport white overlay: this Space is meant to be driven through
    # its API (show_api=True below), so visitors only see the notice.
    gr.HTML("""
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
<div style="text-align: center; color: black;">
<p style="color: black;">This space is a headless component of the cloud rendering engine used by AiTube.</p>
<p style="color: black;">It is not available for public use, but you can use the <a href="https://huggingface.co/spaces/ByteDance/AnimateDiff-Lightning" target="_blank">original space</a>.</p>
</div>
</div>""")

    secret_token = gr.Text(label='Secret Token', max_lines=1)

    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label='Prompt'
            )
        with gr.Row():
            select_base = gr.Dropdown(
                label='Base model',
                choices=[
                    "ToonYou",
                    "epiCRealism",
                ],
                value=base_loaded
            )
            width = gr.Slider(
                label='Width',
                minimum=128,
                maximum=2048,
                step=32,
                value=512,
            )
            height = gr.Slider(
                label='Height',
                minimum=128,
                maximum=2048,
                step=32,
                value=288,
            )
            # (label, value) pairs; the value is the LoRA repo id passed to
            # generate_image, and "" means no motion LoRA.
            select_motion = gr.Dropdown(
                label='Motion',
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="",
            )
            select_step = gr.Dropdown(
                label='Inference steps',
                choices=[
                    ('1-Step', 1),
                    ('2-Step', 2),
                    ('4-Step', 4),
                    ('8-Step', 8)],
                value=4,
            )
            # gr.Slider bounds are `minimum`/`maximum`; `min_value`/`max_value`
            # are not Slider parameters, so those bounds would never apply.
            duration_slider = gr.Slider(label="Desired Duration (seconds)", minimum=1, maximum=120, value=hardcoded_duration_sec, step=0.1)
            fps_slider = gr.Slider(label="Desired Frames Per Second", minimum=10, maximum=60, value=hardcoded_fps, step=1)

            submit = gr.Button()

    # Receives the base64 data URI returned by generate_image.
    output_video_base64 = gr.Text()

    submit.click(
        fn=generate_image,
        inputs=[secret_token, prompt, select_base, width, height, select_motion, select_step, duration_slider, fps_slider],
        outputs=output_video_base64,
    )

demo.queue(max_size=12).launch(show_api=True)