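"""Gradio app that generates videos from text prompts with Tencent's HunyuanVideo.

The app downloads the published checkpoints, clones the official repository, and
invokes its sample_video.py script to render the requested video.
"""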
import os
import shlex

import gradio as gr
from loguru import logger


# Download the HunyuanVideo checkpoints with aria2
def download_models():
    logger.info("Downloading models...")
    os.system("apt update && apt install aria2 -y")
    base_url = "https://huggingface.co/camenduru/HunyuanVideo"
    # Map each remote file to the local folder it should be downloaded into
    models = {
        "transformers/mp_rank_00_model_states.pt": "ckpts/hunyuan-video-t2v-720p/transformers",
        "vae/config.json": "ckpts/hunyuan-video-t2v-720p/vae",
        "vae/pytorch_model.pt": "ckpts/hunyuan-video-t2v-720p/vae",
        "text_encoder/config.json": "ckpts/text_encoder",
        "text_encoder/generation_config.json": "ckpts/text_encoder",
        "text_encoder/model-00001-of-00004.safetensors": "ckpts/text_encoder",
        "text_encoder/model-00002-of-00004.safetensors": "ckpts/text_encoder",
        "text_encoder/model-00003-of-00004.safetensors": "ckpts/text_encoder",
        "text_encoder/model-00004-of-00004.safetensors": "ckpts/text_encoder",
        "text_encoder/model.safetensors.index.json": "ckpts/text_encoder",
        "text_encoder/special_tokens_map.json": "ckpts/text_encoder",
        "text_encoder/tokenizer.json": "ckpts/text_encoder",
        "text_encoder/tokenizer_config.json": "ckpts/text_encoder",
    }
    for file_path, folder in models.items():
        os.makedirs(folder, exist_ok=True)
        command = (
            f"aria2c --console-log-level=error -c -x 16 -s 16 -k 1M "
            f"{base_url}/resolve/main/{file_path} -d {folder} -o {os.path.basename(file_path)}"
        )
        logger.info(f"Downloading: {file_path}")
        os.system(command)
    logger.info("Download complete.")


# Generate a video by running the HunyuanVideo sampling script
def generate_video(prompt, video_size, video_length, infer_steps, seed):
    download_models()
    repo_dir = "/content/HunyuanVideo"
    if not os.path.exists(repo_dir):
        logger.info("Cloning the repository...")
        os.system(f"git clone https://github.com/Tencent/HunyuanVideo {repo_dir}")
    os.chdir(repo_dir)
    save_path = "./results/generated_video.mp4"
    command = (
        f"python sample_video.py "
        f"--video-size {video_size[0]} {video_size[1]} "
        f"--video-length {video_length} "
        f"--infer-steps {infer_steps} "
        f"--prompt {shlex.quote(prompt)} "
        f"--flow-reverse "
        f"--seed {seed} "
        f"--use-cpu-offload "
        f"--save-path {save_path}"
    )
    logger.info("Running the model...")
    os.system(command)
    if os.path.exists(save_path):
        return save_path
    else:
        logger.error("The video was not generated correctly.")
        return None


# Gradio interface
def infer(prompt, width, height, video_length, infer_steps, seed):
    video_size = (width, height)
    video_path = generate_video(prompt, video_size, video_length, infer_steps, seed)
    if video_path:
        return video_path
    raise gr.Error("Error while generating the video.")


with gr.Blocks() as demo:
    gr.Markdown("# HunyuanVideo - Text-to-Video Generation")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", placeholder="Describe your video (e.g. a cat is running, realistic.)")
            width = gr.Slider(label="Video Width", minimum=360, maximum=1920, step=1, value=720)
            height = gr.Slider(label="Video Height", minimum=360, maximum=1920, step=1, value=1280)
            video_length = gr.Slider(label="Video Length (frames)", minimum=10, maximum=300, step=1, value=129)
            infer_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=50)
            seed = gr.Slider(label="Seed", minimum=0, maximum=1000, step=1, value=0)
            submit_btn = gr.Button("Generate Video")
        with gr.Column():
            output = gr.Video(label="Generated Video")
    submit_btn.click(infer, inputs=[prompt, width, height, video_length, infer_steps, seed], outputs=output)

demo.launch()