# diffused-heads / main.py
# Source page metadata: author "Sof22", commit d44dd3f ("Update main.py"), 2.28 kB.
import os
import shutil
import tempfile
from pathlib import Path

import torch
from fastapi import FastAPI, Query, File, UploadFile
from fastapi.responses import FileResponse

from diffusion import Diffusion # Make sure you import your own modules correctly
from utils import get_id_frame, get_audio_emb, save_video # Make sure you import your own modules correctly
# Checkpoint and output locations. The defaults are the paths this script was
# written against; set the environment variables to run it elsewhere without
# editing the code.
UNET_CHECKPOINT = os.environ.get(
    "DIFFUSED_HEADS_UNET_CHECKPOINT",
    "/Users/a/Documents/Automations/git talking heads/checkpoints/crema_script.pt",
)
AUDIO_ENCODER_CHECKPOINT = os.environ.get(
    "DIFFUSED_HEADS_AUDIO_ENCODER_CHECKPOINT",
    "/Users/a/Documents/Automations/git talking heads/checkpoints/audio_encoder.pt",
)
DEFAULT_OUTPUT_PATH = os.environ.get(
    "DIFFUSED_HEADS_OUTPUT_PATH",
    "/Users/a/Documents/Automations/git talking heads/output_video.mp4",
)

app = FastAPI()


def _save_upload(upload, directory, stem, default_suffix):
    """Copy an uploaded file into *directory* and return its new ``Path``.

    The client's file extension is kept (falling back to *default_suffix*
    when the upload has none), in case the downstream loaders infer the
    format from the extension. The stem is fixed, so the client-supplied
    filename never decides where on disk the file is written.
    """
    suffix = Path(upload.filename or "").suffix or default_suffix
    path = Path(directory) / f"{stem}{suffix}"
    with path.open("wb") as buffer:
        shutil.copyfileobj(upload.file, buffer)
    return path


@app.post("/generate_video/")
async def generate_video(
    id_frame_file: UploadFile = File(...),
    audio_file: UploadFile = File(...),
    gpu: bool = Query(True, description="Use GPU if available"),
    id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
    inference_steps: int = Query(100, ge=1, description="Number of inference diffusion steps"),
    output: str = Query(DEFAULT_OUTPUT_PATH, description="Path to save the output video"),
):
    """Render a video from an uploaded identity frame and audio file.

    Loads the TorchScript UNet, wraps it in ``Diffusion``, calls
    ``diffusion.sample`` with the identity frame and the audio embedding,
    and writes the samples plus the audio to *output* via ``save_video``
    (fps=25, audio_rate=16000).

    Args:
        id_frame_file: Identity image upload (per the ``id_frame_random``
            description this may also be a video -- TODO confirm against
            ``utils.get_id_frame``).
        audio_file: Audio upload passed to ``get_audio_emb``.
        gpu: Run on CUDA when it is available; otherwise fall back to CPU.
        id_frame_random: Forwarded to ``get_id_frame`` as ``random``.
        inference_steps: Passed to ``diffusion.space``; must be >= 1 and
            presumably should not exceed ``n_timesteps`` (1000).
        output: Server-side path the video is written to.

    Returns:
        A ``FileResponse`` serving the file written at *output*.
    """
    # NOTE(review): everything below is synchronous and runs directly on the
    # event loop, so one request blocks the server until it finishes; the
    # model is also reloaded on every request.
    # NOTE(review): `output` is client-controlled and written verbatim, so any
    # caller can write anywhere this process can. Consider restricting it to
    # a fixed output directory.
    device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'
    print('Loading model...')
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host,
    # matching the CPU fallback chosen above.
    unet = torch.jit.load(UNET_CHECKPOINT, map_location=device)
    diffusion_args = {
        "in_channels": 3,
        "image_size": 128,
        "out_channels": 6,
        "n_timesteps": 1000,
    }
    diffusion = Diffusion(unet, device, **diffusion_args).to(device)
    diffusion.space(inference_steps)
    unet_args = {
        "n_audio_motion_embs": 2,
        "n_motion_frames": 2,
        "motion_channels": 3,
    }

    # A private temporary directory per request: concurrent requests (or
    # worker processes sharing a CWD) cannot overwrite each other's inputs,
    # and the uploads are deleted when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        id_frame_path = _save_upload(id_frame_file, tmp_dir, "id_frame", ".jpg")
        audio_path = _save_upload(audio_file, tmp_dir, "audio", ".mp3")

        id_frame = get_id_frame(
            str(id_frame_path), random=id_frame_random, resize=diffusion_args["image_size"]
        ).to(device)
        audio, audio_emb = get_audio_emb(str(audio_path), AUDIO_ENCODER_CHECKPOINT, device)

        samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0), **unet_args)
        # Written while the temp files still exist, in case `audio` still
        # refers to the uploaded file rather than decoded samples.
        save_video(output, samples, audio=audio, fps=25, audio_rate=16000)

    print(f'Results saved at {output}')
    return FileResponse(output)