Spaces: Runtime error
from fastapi import FastAPI, Query, File, UploadFile
from fastapi.responses import FileResponse
import torch
from diffusion import Diffusion  # Make sure you import your own modules correctly
from utils import get_id_frame, get_audio_emb, save_video  # Make sure you import your own modules correctly
import shutil
from pathlib import Path

app = FastAPI()

@app.post("/generate_video")  # the original snippet never registered the route; the path name here is illustrative
async def generate_video(
    id_frame_file: UploadFile = File(...),
    audio_file: UploadFile = File(...),
    gpu: bool = Query(True, description="Use GPU if available"),
    id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
    inference_steps: int = Query(100, description="Number of inference diffusion steps"),
    output: str = Query("/Users/a/Documents/Automations/git talking heads/output_video.mp4", description="Path to save the output video")
):
    device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'

    # Load the TorchScript UNet and wrap it in the diffusion sampler
    print('Loading model...')
    unet = torch.jit.load("/Users/a/Documents/Automations/git talking heads/checkpoints/crema_script.pt")
    diffusion_args = {
        "in_channels": 3,
        "image_size": 128,
        "out_channels": 6,
        "n_timesteps": 1000,
    }
    diffusion = Diffusion(unet, device, **diffusion_args).to(device)
    diffusion.space(inference_steps)

    # Save uploaded files to disk so the helper functions can read them from a path
    id_frame_path = Path("temp_id_frame.jpg")
    audio_path = Path("temp_audio.mp3")
    with id_frame_path.open("wb") as buffer:
        shutil.copyfileobj(id_frame_file.file, buffer)
    with audio_path.open("wb") as buffer:
        shutil.copyfileobj(audio_file.file, buffer)

    # Extract the identity frame and the audio embedding
    id_frame = get_id_frame(str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]).to(device)
    audio, audio_emb = get_audio_emb(str(audio_path), "/Users/a/Documents/Automations/git talking heads/checkpoints/audio_encoder.pt", device)

    # Sample the video frames and mux them with the original audio
    unet_args = {
        "n_audio_motion_embs": 2,
        "n_motion_frames": 2,
        "motion_channels": 3
    }
    samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0), **unet_args)
    save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
    print(f'Results saved at {output}')

    return FileResponse(output)
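
For reference, here is a minimal client sketch showing how the endpoint could be exercised once the API is running. It assumes the script above is saved as app.py, served locally on port 8000, and that the route is registered as POST /generate_video as in the snippet; the URL, port, and file names below are placeholders, not part of the original code.

# Client sketch (assumptions: script saved as app.py, route registered as POST /generate_video).
# Start the server first, e.g.:
#   uvicorn app:app --host 0.0.0.0 --port 8000
import requests

url = "http://localhost:8000/generate_video"  # placeholder host/port
files = {
    "id_frame_file": open("id_frame.jpg", "rb"),  # placeholder identity frame
    "audio_file": open("speech.mp3", "rb"),       # placeholder driving audio
}
params = {
    "gpu": True,
    "id_frame_random": False,
    "inference_steps": 100,
    "output": "output_video.mp4",  # path where the server writes the result
}

response = requests.post(url, files=files, params=params)
with open("result.mp4", "wb") as f:
    f.write(response.content)  # the endpoint returns the rendered video via FileResponse

Note that the temporary file names and the default output path in the endpoint are fixed on the server side, so concurrent requests would overwrite each other's files; per-request unique names would avoid that.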