import os
import torch
import gradio as gr
from transformers import XCLIPProcessor, XCLIPModel
from utils import convert_frames_to_gif, download_youtube_video, sample_frames_from_video_file

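# Load the zero-shot X-CLIP checkpoint and its processor from the Hugging Face Hub.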
model_name = "microsoft/xclip-base-patch16-zero-shot"
processor = XCLIPProcessor.from_pretrained(model_name)
model = XCLIPModel.from_pretrained(model_name)

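# Example inputs: a YouTube URL and comma-separated candidate labels.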
examples = [
    ["https://www.youtu.be/l1dBM8ZECao", "sleeping dog,cat fight club,birds of prey"],
    ["https://youtu.be/VMj-3S1tku0", "programming course,eating spaghetti,playing football"], 
    ["https://www.youtu.be/x8UAUAuKNcU", "game of thrones,the lord of the rings,vikings"]
]


def predict(youtube_url, labels_text):
    # Parse the comma-separated labels, then download the video and sample frames.
    labels = [label.strip() for label in labels_text.split(",")]
    video_path = download_youtube_video(youtube_url)
    frames = sample_frames_from_video_file(video_path, num_frames=32)
    os.remove(video_path)
    gif_path = convert_frames_to_gif(frames)

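    # Preprocess: tokenize the labels and normalize the sampled frames into one batch.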
    inputs = processor(
        text=labels,
        videos=list(frames),
        return_tensors="pt",
        padding=True
    )
    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)

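    # Softmax over the video-to-text similarity logits gives one probability per label.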
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}
    
    return label_to_prob, gif_path

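# Build the Gradio UI: inputs in the first column, the clip preview and label probabilities in the next two.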
app = gr.Blocks()
with app:
    gr.Markdown("# **<p align='center'>Zero-shot Video Classification with X-CLIP</p>**")

    with gr.Row():
        with gr.Column():
            gr.Markdown("Provide a Youtube video URL and a list of labels separated by commas")
            youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
            labels_text = gr.Textbox(label="Labels Text:", show_label=True)
            predict_btn = gr.Button(value="Predict")
        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    gr.Markdown("**Examples:**")
    gr.Examples(examples, [youtube_url, labels_text], [predictions, video_gif], fn=predict, cache_examples=True)

    predict_btn.click(predict, inputs=[youtube_url, labels_text], outputs=[predictions, video_gif])
    gr.Markdown(
        """
        \n Demo created by: <a href="https://github.com/fcakyon">fcakyon</a>
        <br> Based on this <a href="https://huggingface.co/microsoft/xclip-base-patch16-zero-shot">Hugging Face model</a>
        """
    )
    
app.launch()