Fatih commited on
Commit
20985fb
·
1 Parent(s): fb6d4e3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from transformers import XCLIPProcessor, XCLIPModel
5
+ from utils import convert_frames_to_gif, download_youtube_video, sample_frames_from_video_file
6
+
7
+ model_name = "microsoft/xclip-base-patch16-zero-shot"
8
+ processor = XCLIPProcessor.from_pretrained(model_name)
9
+ model = XCLIPModel.from_pretrained(model_name)
10
+
11
+ examples = [
12
+ ["https://www.youtu.be/l1dBM8ZECao", "sleeping dog,cat fight club,birds of prey"],
13
+ ["https://youtu.be/VMj-3S1tku0", "programming course,eating spaghetti,playing football"],
14
+ ["https://www.youtu.be/x8UAUAuKNcU", "game of thrones,the lord of the rings,vikings"]
15
+ ]
16
+
17
+
18
+ def predict(youtube_url, labels_text):
19
+
20
+ labels = labels_text.split(",")
21
+ video_path = download_youtube_video(youtube_url)
22
+ frames = sample_frames_from_video_file(video_path, num_frames=32)
23
+ os.remove(video_path)
24
+ gif_path = convert_frames_to_gif(frames)
25
+
26
+ inputs = processor(
27
+ text=labels,
28
+ videos=list(frames),
29
+ return_tensors="pt",
30
+ padding=True
31
+ )
32
+ # forward pass
33
+ with torch.no_grad():
34
+ outputs = model(**inputs)
35
+
36
+ probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
37
+ label_to_prob = {}
38
+ for ind, label in enumerate(labels):
39
+ label_to_prob[label] = float(probs[ind])
40
+
41
+ return label_to_prob, gif_path
42
+
43
+ app = gr.Blocks()
44
+ with app:
45
+ gr.Markdown("# **<p align='center'>Zero-shot Video Classification with X-CLIP</p>**")
46
+
47
+ with gr.Row():
48
+ with gr.Column():
49
+ gr.Markdown("Provide a Youtube video URL and a list of labels separated by commas")
50
+ youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
51
+ labels_text = gr.Textbox(label="Labels Text:", show_label=True)
52
+ predict_btn = gr.Button(value="Predict")
53
+ with gr.Column():
54
+ video_gif = gr.Image(label="Input Clip", show_label=True,)
55
+ with gr.Column():
56
+ predictions = gr.Label(label='Predictions:', show_label=True)
57
+
58
+ gr.Markdown("**Examples:**")
59
+ gr.Examples(examples, [youtube_url, labels_text], [predictions, video_gif], fn=predict, cache_examples=True)
60
+
61
+ predict_btn.click(predict, inputs=[youtube_url, labels_text], outputs=[predictions, video_gif])
62
+ gr.Markdown(
63
+ """
64
+ \n Demo created by: <a href=\"https://github.com/fcakyon\">fcakyon</a>
65
+ <br> Based on this <a href=\"https://huggingface.co/microsoft/xclip-base-patch16-zero-shot\">HuggingFace model</a>
66
+ """
67
+ )
68
+
69
+ app.launch()