import gradio as gr
import torch
import os
import base64
import uuid
import tempfile
import numpy as np
import cv2
import subprocess

from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')

bases = {
    "ToonYou": "frankjoshua/toonyou_beta6",
    "epiCRealism": "emilianJR/epiCRealism"
}

# Track which weights are currently loaded so generate_image() only reloads
# the pieces that actually changed between calls.
step_loaded = None
base_loaded = "epiCRealism"
motion_loaded = None

if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")

device = "cuda"
dtype = torch.float16

# AnimateDiffPipeline requires a MotionAdapter (imported above but previously
# unused); an empty adapter is created here and its weights are overwritten by
# the AnimateDiff-Lightning checkpoint on the first call to generate_image().
pipe = AnimateDiffPipeline.from_pretrained(
    bases[base_loaded],
    torch_dtype=dtype,
    motion_adapter=MotionAdapter().to(device, dtype),
).to(device)

# AnimateDiff-Lightning expects trailing timestep spacing and a linear beta schedule.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")


def export_to_video_file(video_frames, output_video_path=None, fps=10):
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".webm").name

    # Normalize frames to uint8 RGB arrays (floats in [0, 1] or PIL images).
    if isinstance(video_frames[0], np.ndarray):
        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
    elif isinstance(video_frames[0], Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    # 'VP90' is the FourCC code for VP9, the codec used in the WebM container.
    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    h, w, c = video_frames[0].shape
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (w, h), True)

    for frame in video_frames:
        # OpenCV expects BGR channel order, so convert from RGB before writing.
        img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    video_writer.release()

    return output_video_path
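
# Example with hypothetical values: 16 generated frames written at 10 fps
# yield a 1.6 second WebM clip:
#
#   export_to_video_file(output.frames[0], "/tmp/example.webm", fps=10)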


def interpolate_video_frames(input_file_path, output_file_path, output_fps=10, desired_duration=2):
    """
    Stretches a video to the desired duration and interpolates frames to the
    target frame rate using ffmpeg's minterpolate filter.

    Parameters:
        input_file_path (str): Path to the input video file.
        output_file_path (str): Path to the output video file.
        output_fps (int): Target frames per second for the output video.
        desired_duration (int): Desired duration of the video in seconds.

    Returns:
        str: The path of the interpolated video, or the input path if
        interpolation failed.
    """
    # Speed factor (original duration / desired duration). Dividing the
    # presentation timestamps by it stretches or compresses the timeline
    # before minterpolate synthesizes the intermediate frames.
    speed_factor = find_input_fps(input_file_path, desired_duration)

    cmd = [
        'ffmpeg',
        '-i', input_file_path,
        '-filter:v', f'setpts=PTS/{speed_factor},minterpolate=fps={output_fps}',
        '-r', str(output_fps),
        output_file_path
    ]

    try:
        subprocess.run(cmd, check=True)
        print("Video interpolation successful.")
        return output_file_path
    except subprocess.CalledProcessError as e:
        print("Failed to interpolate video. Error:", e)
        return input_file_path  # fall back to the unmodified input
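
# For example (hypothetical paths), stretching a 2 s clip to 4 s at 30 fps
# runs roughly:
#
#   ffmpeg -i raw.webm -filter:v "setpts=PTS/0.5,minterpolate=fps=30" -r 30 out.webm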


def find_input_fps(file_path, desired_duration):
    """
    Compute the ratio between the video's actual duration (read via ffprobe)
    and the desired duration, i.e. the speed factor by which the timeline
    must be stretched to reach the desired length.

    Parameters:
        file_path (str): Path to the video file.
        desired_duration (int or float): Desired duration in seconds.

    Returns:
        float: actual duration / desired duration, or 1.0 (no change) if the
        duration could not be read.
    """
    ffprobe_cmd = [
        'ffprobe',
        '-v', 'error',
        '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        file_path
    ]

    try:
        result = subprocess.run(ffprobe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        duration = float(result.stdout.strip())
        speed_factor = duration / desired_duration
    except Exception as e:
        print("Failed to get video duration. Error:", e)
        speed_factor = 1.0  # neutral factor: leave the timeline unchanged

    return speed_factor
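
# Worked example with hypothetical numbers: a 16-frame clip exported at 10 fps
# lasts 1.6 s; stretching it to 4 s needs a speed factor of 1.6 / 4 = 0.4,
# so setpts=PTS/0.4 slows the clip down 2.5x before interpolation.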


def generate_image(secret_token, prompt, base, width, height, motion, step, desired_duration, desired_fps):
    if secret_token != SECRET_TOKEN:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    global step_loaded
    global base_loaded
    global motion_loaded

    # Load the AnimateDiff-Lightning distilled weights for the requested step count.
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    # Swap in the UNet weights of the newly selected base model.
    if base_loaded != base:
        pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
        base_loaded = base

    # Swap the camera-motion LoRA (an empty string means no motion LoRA).
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion

    output = pipe(
        prompt=prompt,
        width=width,
        height=height,
        guidance_scale=1.0,
        num_inference_steps=step,
    )

    video_uuid = str(uuid.uuid4()).replace("-", "")
    raw_video_path = f"/tmp/{video_uuid}_raw.webm"
    enhanced_video_path = f"/tmp/{video_uuid}_enhanced.webm"

    raw_video_path = export_to_video_file(output.frames[0], raw_video_path, fps=10)

    final_video_path = raw_video_path

    # Only run the (slow) ffmpeg interpolation pass when the request differs
    # from the defaults (2 s duration, 10 fps).
    if desired_duration != 2 or desired_fps != 10:
        final_video_path = interpolate_video_frames(raw_video_path, enhanced_video_path, output_fps=desired_fps, desired_duration=desired_duration)

    # Encode the video as base64 so it can be returned as a text data URI.
    with open(final_video_path, "rb") as video_file:
        video_base64 = base64.b64encode(video_file.read()).decode('utf-8')

    # Clean up the temporary files.
    try:
        os.remove(raw_video_path)
        if final_video_path != raw_video_path:
            os.remove(final_video_path)
    except Exception as e:
        print("Failed to delete a video path:", e)

    video_data_uri = 'data:video/webm;base64,' + video_base64

    return video_data_uri
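
# A client can turn the returned data URI back into a playable file; a
# minimal sketch:
#
#   header, encoded = video_data_uri.split(",", 1)
#   with open("video.webm", "wb") as f:
#       f.write(base64.b64decode(encoded))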


with gr.Blocks() as demo:
    gr.HTML("""
    <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
      <div style="text-align: center; color: black;">
        <p style="color: black;">This space is a headless component of the cloud rendering engine used by AiTube.</p>
        <p style="color: black;">It is not available for public use, but you can use the <a href="https://huggingface.co/spaces/ByteDance/AnimateDiff-Lightning" target="_blank">original space</a>.</p>
      </div>
    </div>""")

    secret_token = gr.Text(label='Secret Token', max_lines=1)

    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label='Prompt'
            )
        with gr.Row():
            select_base = gr.Dropdown(
                label='Base model',
                choices=[
                    "ToonYou",
                    "epiCRealism",
                ],
                value=base_loaded
            )
            width = gr.Slider(
                label='Width',
                minimum=128,
                maximum=2048,
                step=32,
                value=512,
            )
            height = gr.Slider(
                label='Height',
                minimum=128,
                maximum=2048,
                step=32,
                value=256,
            )
            select_motion = gr.Dropdown(
                label='Motion',
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="",
            )
            select_step = gr.Dropdown(
                label='Inference steps',
                choices=[
                    ('1-Step', 1),
                    ('2-Step', 2),
                    ('4-Step', 4),
                    ('8-Step', 8),
                ],
                value=4,
            )
            duration_slider = gr.Slider(label="Desired Duration (seconds)", minimum=2, maximum=30, value=2, step=1)
            fps_slider = gr.Slider(label="Desired Frames Per Second", minimum=10, maximum=60, value=10, step=1)

    submit = gr.Button()

    output_video_base64 = gr.Text()

    submit.click(
        fn=generate_image,
        inputs=[secret_token, prompt, select_base, width, height, select_motion, select_step, duration_slider, fps_slider],
        outputs=output_video_base64,
    )

demo.queue(max_size=12).launch(show_api=True)
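
# Example client call (hypothetical URL and token; verify the endpoint name
# on the space's /?view=api page). A minimal sketch using the gradio_client
# package, with arguments in the order wired into submit.click() above:
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860/")
#   video_data_uri = client.predict(
#       "default_secret",           # secret_token
#       "a cat playing the piano",  # prompt
#       "epiCRealism",              # base model
#       512,                        # width
#       256,                        # height
#       "",                         # motion LoRA ("" = default)
#       4,                          # inference steps
#       2,                          # desired duration in seconds
#       10,                         # desired frames per second
#       api_name="/predict",        # assumption: check the API page
#   )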