# Page header captured along with the file (not code):
# "mehal's picture" / "Update app.py" / 0832bd7
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModel
from utils import (
convert_frames_to_gif,
download_youtube_video,
get_num_total_frames,
sample_frames_from_video_file,
)
# Preferred stride between sampled frames; predict() lowers it when the video
# is too short to supply the model's frame count at this stride.
FRAME_SAMPLING_RATE = 4

# Checkpoint loaded at start-up and pre-selected in the model dropdown.
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# Checkpoints offered in the model dropdown. Every entry must be a single
# "<owner>/<name>" repository id so that from_pretrained() can resolve it.
VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [
    "microsoft/xclip-base-patch32",
    "microsoft/xclip-base-patch16-zero-shot",
    "microsoft/xclip-base-patch16-kinetics-600",
    # Was "microsoft/xclip-large-patch14ft/xclip-base-patch32-16-frames":
    # two ids fused together, which is not a valid repository id.
    "microsoft/xclip-base-patch32-16-frames",
    "microsoft/xclip-large-patch14",
    "microsoft/xclip-base-patch16-hmdb-4-shot",
    "microsoft/xclip-base-patch16-16-frames",
    "microsoft/xclip-base-patch16-hmdb-2-shot",
    "microsoft/xclip-base-patch16-ucf-2-shot",
    "microsoft/xclip-base-patch16-ucf-8-shot",
    "microsoft/xclip-base-patch16",
    "microsoft/xclip-base-patch16-hmdb-8-shot",
    "microsoft/xclip-base-patch16-hmdb-16-shot",
    "microsoft/xclip-base-patch16-ucf-16-shot",
]
# Load the default checkpoint's processor and model at import time. Both are
# module-level globals: predict() reads them and select_model() rebinds them
# when a different checkpoint is chosen.
processor = AutoProcessor.from_pretrained(DEFAULT_MODEL)
model = AutoModel.from_pretrained(DEFAULT_MODEL)
# Sample inputs, one row per example: [video URL, comma-separated labels].
examples = [
    ["https://www.youtu.be/l1dBM8ZECao", "sleeping dog,cat fight club,birds of prey"],
    ["https://youtu.be/VMj-3S1tku0", "programming course,eating spaghetti,playing football"],
    ["https://youtu.be/BRw7rvLdGzU", "game of thrones,the lord of the rings,vikings"],
]
def select_model(model_name):
    """Switch the shared ``processor`` and ``model`` globals to ``model_name``.

    Both objects are loaded into locals first and published together, so a
    failed load (the exception propagates to the caller) leaves the
    previously loaded processor/model pair in place instead of a new
    processor paired with the old model.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            processor and the model.
    """
    global processor, model
    new_processor = AutoProcessor.from_pretrained(model_name)
    new_model = AutoModel.from_pretrained(model_name)
    # Rebind only after both loads succeeded.
    processor, model = new_processor, new_model
def _parse_labels(labels_text):
    """Split comma-separated text into stripped, non-empty, unique labels."""
    stripped = (part.strip() for part in (labels_text or "").split(","))
    # dict.fromkeys drops duplicates while keeping first-seen order; duplicate
    # labels would otherwise collapse into a single key of the result dict.
    return list(dict.fromkeys(label for label in stripped if label))


def _pick_frame_sampling_rate(num_total_frames, num_model_input_frames):
    """Return the frame stride to use for a video with ``num_total_frames``."""
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        # Short video: shrink the stride, but never below 1. Integer division
        # yields 0 when the video has fewer frames than the model consumes.
        return max(1, num_total_frames // num_model_input_frames)
    return FRAME_SAMPLING_RATE


def predict(youtube_url_or_file_path, labels_text):
    """Score a video against a set of candidate text labels.

    Args:
        youtube_url_or_file_path: A string starting with ``"http"`` is passed
            to ``download_youtube_video``; any other string is used as a
            video file path as-is.
        labels_text: Candidate labels separated by commas. Surrounding
            whitespace, empty entries and duplicates are dropped.

    Returns:
        A ``(label_to_prob, gif_path)`` tuple: a dict mapping each label to
        its softmax probability as a float, and the value returned by
        ``convert_frames_to_gif`` for the sampled frames.

    Raises:
        gr.Error: If no video source or no usable label was provided.
    """
    # The video component yields None when nothing was uploaded, and an
    # untouched textbox yields an empty string.
    if not youtube_url_or_file_path:
        raise gr.Error("Please provide a Youtube URL or upload a video file.")
    labels = _parse_labels(labels_text)
    if not labels:
        raise gr.Error("Please provide at least one label (comma separated).")

    # Leading whitespace would otherwise make a URL look like a file path.
    source = youtube_url_or_file_path.strip()
    if source.startswith("http"):
        video_path = download_youtube_video(source)
    else:
        video_path = source

    # rearrange sampling rate based on video length and model input length
    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = model.config.vision_config.num_frames
    frame_sampling_rate = _pick_frame_sampling_rate(
        num_total_frames, num_model_input_frames
    )

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Index 0 selects the scores of the single clip that was passed in.
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}
    return label_to_prob, gif_path
# Build the UI: a model picker with two input tabs (Youtube URL / local file)
# on the left, then the sampled clip preview and the prediction scores.
with gr.Blocks() as app:
    gr.Markdown("# **<p align='center'>PROGTOG VIOLENCE DETECTION</p>**")

    with gr.Row():
        with gr.Column():
            model_names_dropdown = gr.Dropdown(
                choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
                label="Model:",
                show_label=True,
                value=DEFAULT_MODEL,
            )
            # Reload the shared processor/model whenever the selection changes.
            model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)

            with gr.Tab(label="Youtube URL"):
                gr.Markdown("### **Youtube URL**")
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                youtube_url_predict_btn = gr.Button(value="Predict")

            with gr.Tab(label="Local File"):
                gr.Markdown("### **Tags**")
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                local_video_predict_btn = gr.Button(value="Predict")

        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)

        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # Examples gallery (currently disabled):
    # gr.Markdown("**Examples:**")
    # gr.Examples(
    #     examples,
    #     [youtube_url, youtube_url_labels_text],
    #     [predictions, video_gif],
    #     fn=predict,
    #     cache_examples=True,
    # )

    # Both "Predict" buttons run the same function and fill the same outputs;
    # only the component supplying the video source differs.
    youtube_url_predict_btn.click(
        fn=predict,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=[predictions, video_gif],
    )
    local_video_predict_btn.click(
        fn=predict,
        inputs=[video_file, local_video_labels_text],
        outputs=[predictions, video_gif],
    )

    # Credits footer (currently disabled):
    # gr.Markdown(
    #     """
    #     \n Demo created by: <a href=\"https://github.com/fcakyon\">fcakyon</a>.
    #     <br> Based on this <a href=\"https://huggingface.co/docs/transformers/main/model_doc/xclip">HuggingFace model</a>.
    #     """
    # )

app.launch()