# Hugging Face Spaces page status at capture time: Runtime error
import torch | |
import gradio as gr | |
from transformers import AutoProcessor, AutoModel | |
from utils import ( | |
convert_frames_to_gif, | |
download_youtube_video, | |
get_num_total_frames, | |
sample_frames_from_video_file, | |
) | |
# Preferred stride (in frames) between sampled frames; `predict` lowers it for
# videos that are too short to be sampled at this stride.
FRAME_SAMPLING_RATE = 4

# Checkpoint loaded at start-up and pre-selected in the model dropdown.
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# Checkpoints offered in the model dropdown. Every entry must be a plain
# "<namespace>/<repo>" Hub id, because it is passed verbatim to
# `from_pretrained` when selected.
VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [
    "microsoft/xclip-base-patch32",
    "microsoft/xclip-base-patch16-zero-shot",
    "microsoft/xclip-base-patch16-kinetics-600",
    # This entry used to be two names fused into one string containing two
    # slashes ("microsoft/xclip-large-patch14ft/xclip-base-patch32-16-frames"),
    # which is not a loadable repo id.
    "microsoft/xclip-base-patch32-16-frames",
    "microsoft/xclip-large-patch14",
    "microsoft/xclip-base-patch16-hmdb-4-shot",
    "microsoft/xclip-base-patch16-16-frames",
    "microsoft/xclip-base-patch16-hmdb-2-shot",
    "microsoft/xclip-base-patch16-ucf-2-shot",
    "microsoft/xclip-base-patch16-ucf-8-shot",
    "microsoft/xclip-base-patch16",
    "microsoft/xclip-base-patch16-hmdb-8-shot",
    "microsoft/xclip-base-patch16-hmdb-16-shot",
    "microsoft/xclip-base-patch16-ucf-16-shot",
]
# Module-level processor/model pair shared by `predict`; `select_model`
# rebinds both when another checkpoint is chosen.
processor, model = (
    AutoProcessor.from_pretrained(DEFAULT_MODEL),
    AutoModel.from_pretrained(DEFAULT_MODEL),
)

# Sample rows of [video URL, comma-separated labels]. They are only referenced
# by the `gr.Examples` call, which is currently commented out in the UI code.
examples = [
    ["https://www.youtu.be/l1dBM8ZECao", "sleeping dog,cat fight club,birds of prey"],
    ["https://youtu.be/VMj-3S1tku0", "programming course,eating spaghetti,playing football"],
    ["https://youtu.be/BRw7rvLdGzU", "game of thrones,the lord of the rings,vikings"],
]
def select_model(model_name):
    """Replace the shared processor/model pair with another checkpoint.

    Args:
        model_name: Checkpoint identifier passed unchanged to
            ``AutoProcessor.from_pretrained`` and ``AutoModel.from_pretrained``.

    Raises:
        Whatever either ``from_pretrained`` call raises. In that case the
        previously loaded pair is left untouched.
    """
    global processor, model
    # Load both halves before rebinding either global, so that a failure in
    # the second load cannot leave a new processor paired with the old model.
    new_processor = AutoProcessor.from_pretrained(model_name)
    new_model = AutoModel.from_pretrained(model_name)
    processor, model = new_processor, new_model
def predict(youtube_url_or_file_path, labels_text):
    """Score a video against comma-separated text labels with the loaded model.

    Args:
        youtube_url_or_file_path: Either a URL starting with ``http`` (handed
            to ``download_youtube_video``) or a path to a local video file.
        labels_text: Candidate labels separated by commas. Whitespace around
            each label is stripped and empty entries are ignored.

    Returns:
        A ``(label_to_prob, gif_path)`` tuple: a dict mapping every label to
        its softmax probability for the video, and the value returned by
        ``convert_frames_to_gif`` for the sampled frames.

    Raises:
        ValueError: If no video source or no usable label was provided.
    """
    # Validate before any download or inference so that an empty form yields a
    # clear message instead of an AttributeError on ``None``.
    if not youtube_url_or_file_path or not youtube_url_or_file_path.strip():
        raise ValueError("Please provide a Youtube URL or a video file.")
    labels = [part.strip() for part in (labels_text or "").split(",") if part.strip()]
    if not labels:
        raise ValueError("Please provide at least one label (comma-separated).")

    source = youtube_url_or_file_path.strip()
    if source.startswith("http"):
        video_path = download_youtube_video(source)
    else:
        video_path = source

    # Rearrange the sampling rate based on video length and model input length:
    # use the preferred stride when the video is long enough, otherwise the
    # largest stride that fits. The floor division yields 0 for videos with
    # fewer frames than the model consumes, so clamp the stride to 1.
    # NOTE(review): whether the sampler copes with such very short videos
    # depends on `sample_frames_from_video_file` - confirm.
    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = model.config.vision_config.num_frames
    frame_sampling_rate = max(
        1, min(FRAME_SAMPLING_RATE, num_total_frames // num_model_input_frames)
    )

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )

    # Forward pass; gradients are not needed for inference.
    with torch.no_grad():
        outputs = model(**inputs)

    # Only one video is passed in, so row 0 holds its per-label logits.
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}
    return label_to_prob, gif_path
# Build the Gradio UI and wire both "Predict" buttons to `predict`.
# NOTE(review): the block nesting below was reconstructed from the `with`
# statements because the captured source had lost its indentation - confirm
# the intended layout (one row holding three columns).
with gr.Blocks() as app:
    gr.Markdown("# **<p align='center'>PROGTOG VIOLENCE DETECTION</p>**")

    with gr.Row():
        # Column 1: checkpoint selection plus the two input tabs.
        with gr.Column():
            model_names_dropdown = gr.Dropdown(
                label="Model:",
                show_label=True,
                choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
                value=DEFAULT_MODEL,
            )
            # Reload the shared processor/model whenever the selection changes.
            model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)

            with gr.Tab(label="Youtube URL"):
                gr.Markdown("### **Youtube URL**")
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                youtube_url_predict_btn = gr.Button(value="Predict")

            with gr.Tab(label="Local File"):
                gr.Markdown("### **Tags**")
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                local_video_predict_btn = gr.Button(value="Predict")

        # Column 2: GIF built from the frames that were fed to the model.
        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)

        # Column 3: per-label probabilities.
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # gr.Markdown("**Examples:**")
    # gr.Examples(
    # examples,
    # [youtube_url, youtube_url_labels_text],
    # [predictions, video_gif],
    # fn=predict,
    # cache_examples=True,
    # )

    # Both tabs share the same handler and the same output components.
    shared_outputs = [predictions, video_gif]
    youtube_url_predict_btn.click(
        predict,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=shared_outputs,
    )
    local_video_predict_btn.click(
        predict,
        inputs=[video_file, local_video_labels_text],
        outputs=shared_outputs,
    )

    # gr.Markdown(
    # """
    # \n Demo created by: <a href=\"https://github.com/fcakyon\">fcakyon</a>.
    # <br> Based on this <a href=\"https://huggingface.co/docs/transformers/main/model_doc/xclip">HuggingFace model</a>.
    # """
    # )

app.launch()