import torch import gradio as gr from transformers import AutoProcessor, AutoModel from utils import ( convert_frames_to_gif, download_youtube_video, get_num_total_frames, sample_frames_from_video_file, ) FRAME_SAMPLING_RATE = 4 DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot" VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [ "microsoft/xclip-base-patch32", "microsoft/xclip-base-patch16-zero-shot", "microsoft/xclip-base-patch16-kinetics-600", "microsoft/xclip-large-patch14ft/xclip-base-patch32-16-frames", "microsoft/xclip-large-patch14", "microsoft/xclip-base-patch16-hmdb-4-shot", "microsoft/xclip-base-patch16-16-frames", "microsoft/xclip-base-patch16-hmdb-2-shot", "microsoft/xclip-base-patch16-ucf-2-shot", "microsoft/xclip-base-patch16-ucf-8-shot", "microsoft/xclip-base-patch16", "microsoft/xclip-base-patch16-hmdb-8-shot", "microsoft/xclip-base-patch16-hmdb-16-shot", "microsoft/xclip-base-patch16-ucf-16-shot", ] processor = AutoProcessor.from_pretrained(DEFAULT_MODEL) model = AutoModel.from_pretrained(DEFAULT_MODEL) examples = [ [ "https://www.youtu.be/l1dBM8ZECao", "sleeping dog,cat fight club,birds of prey", ], [ "https://youtu.be/VMj-3S1tku0", "programming course,eating spaghetti,playing football", ], [ "https://youtu.be/Tm6BlRMEny0", "game of thrones,the lord of the rings,vikings", ], ] def select_model(model_name): global processor, model processor = AutoProcessor.from_pretrained(model_name) model = AutoModel.from_pretrained(model_name) def predict(youtube_url_or_file_path, labels_text): if youtube_url_or_file_path.startswith("http"): video_path = download_youtube_video(youtube_url_or_file_path) else: video_path = youtube_url_or_file_path # rearrange sampling rate based on video length and model input length num_total_frames = get_num_total_frames(video_path) num_model_input_frames = model.config.vision_config.num_frames if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames: frame_sampling_rate = num_total_frames // num_model_input_frames else: frame_sampling_rate = FRAME_SAMPLING_RATE labels = labels_text.split(",") frames = sample_frames_from_video_file( video_path, num_model_input_frames, frame_sampling_rate ) gif_path = convert_frames_to_gif(frames, save_path="video.gif") inputs = processor( text=labels, videos=list(frames), return_tensors="pt", padding=True ) # forward pass with torch.no_grad(): outputs = model(**inputs) probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy() label_to_prob = {} for ind, label in enumerate(labels): label_to_prob[label] = float(probs[ind]) return label_to_prob, gif_path app = gr.Blocks() with app: gr.Markdown( "# **
Zero-shot Video Classification with 🤗 Transformers
**" ) gr.Markdown( """
Follow me for more!
twitter | github | linkedin | medium