File size: 6,681 Bytes
caca7fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46a655a
 
 
 
 
 
 
6a6f278
46a655a
 
caca7fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a6f278
 
 
 
 
 
 
 
 
 
 
 
 
 
caca7fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import numpy as np
import gradio as gr
import torch

from transformers import BertTokenizer, FlavaForPreTraining, FlavaModel, FlavaFeatureExtractor, FlavaProcessor
from PIL import Image


# Every model / preprocessor below comes from this one Hugging Face Hub
# checkpoint and is loaded once, when the module is imported.
_FLAVA_CHECKPOINT = "facebook/flava-full"

# Top-level Gradio container; the UI is laid out inside `with demo:` below.
demo = gr.Blocks()

# Text tokenizer used by the text, image and multimodal classifiers.
tokenizer = BertTokenizer.from_pretrained(_FLAVA_CHECKPOINT)
# Pretraining wrapper: its outputs carry the masked-token logit heads
# (`mlm_logits`, `mmm_text_logits`) used by the text / multimodal tabs.
flava_pt = FlavaForPreTraining.from_pretrained(_FLAVA_CHECKPOINT)
# Base model: provides get_image_features / get_text_features for the image tab.
flava = FlavaModel.from_pretrained(_FLAVA_CHECKPOINT)
# Joint image + text preprocessor used by the multimodal tab.
processor = FlavaProcessor.from_pretrained(_FLAVA_CHECKPOINT)
# Image-only preprocessor used by the image tab.
fe = FlavaFeatureExtractor.from_pretrained(_FLAVA_CHECKPOINT)


# Attribute of flava_pt's output holding the text-only masked-LM logits.
PREDICTION_ATTR = "mlm_logits"

def zero_shot_text(text, options):
    options = [option.strip() for option in options.split(";")]
    option_indices = tokenizer.convert_tokens_to_ids(options)
    tokens = tokenizer([text], return_tensors="pt")
    mask_ids = tokens["input_ids"][0] == 103
    with torch.no_grad():
        output = flava_pt(**tokens)
    
    text_logits = getattr(output, PREDICTION_ATTR)
    probs = text_logits[0, mask_ids, option_indices].view(-1, len(option_indices)).mean(dim=0)
    probs = torch.nn.functional.softmax(probs, dim=-1)
    return {label: probs[idx].item() for idx, label in enumerate(options)}


def zero_shot_image(image, options):
    PIL_image = Image.fromarray(np.uint8(image)).convert("RGB")
    labels = [label.strip() for label in options.split(";")]
    image_input = fe([PIL_image], return_tensors="pt")
    text_inputs = tokenizer(
        labels, padding="max_length", return_tensors="pt"
    )

    image_embeddings = flava.get_image_features(**image_input)[:, 0, :]
    text_embeddings = flava.get_text_features(**text_inputs)[:, 0, :]
    similarities = list(
        torch.nn.functional.softmax(
            (text_embeddings @ image_embeddings.T).squeeze(0), dim=0
        )
    )
    return {label: similarities[idx].item() for idx, label in enumerate(labels)}
  
def zero_shot_multimodal(image, text, options):
    """Rank candidate words for the ``[MASK]`` slot(s) of a prompt about an image.

    The image and prompt are encoded together by ``processor`` and run through
    ``flava_pt``. The ``mmm_text_logits`` output is then read at every
    ``[MASK]`` position of the prompt.

    Args:
        image: image array as delivered by the Gradio ``Image`` component.
        text: prompt containing one or more ``[MASK]`` tokens.
        options: ``;``-separated candidate words.

    Returns:
        Dict mapping each candidate to its probability (values sum to 1).

    Raises:
        ValueError: if no image or options are given, or the prompt has no
            ``[MASK]`` token.
    """
    options = [option.strip() for option in options.split(";") if option.strip()]
    if image is None:
        raise ValueError("Please provide an image.")
    if not options:
        raise ValueError("Provide at least one option, separated by ';'.")
    # NOTE(review): each option is looked up as ONE vocabulary token; a
    # multi-word option presumably maps to the unknown token id — confirm.
    option_indices = tokenizer.convert_tokens_to_ids(options)

    # Same normalization as zero_shot_image: give the processor a 3-channel
    # RGB PIL image regardless of the incoming array's channel layout.
    pil_image = Image.fromarray(np.uint8(image)).convert("RGB")
    tokens = processor(
        [pil_image], [text],
        return_tensors="pt", return_codebook_pixels=True, return_image_mask=True,
    )

    mask_ids = tokens["input_ids"][0] == processor.tokenizer.mask_token_id
    if not bool(mask_ids.any()):
        raise ValueError("The prompt must contain at least one [MASK] token.")
    tokens["bool_masked_pos"] = torch.ones_like(tokens["bool_masked_pos"])

    with torch.no_grad():
        output = flava_pt(**tokens)

    # Select the masked rows, then the candidate columns:
    # (num_masks, vocab) -> (num_masks, num_options). Indexing both at once
    # would pair masks with options element-wise.
    option_logits = output.mmm_text_logits[0][mask_ids][:, option_indices]
    probs = torch.nn.functional.softmax(option_logits.mean(dim=0), dim=-1)
    return {label: probs[idx].item() for idx, label in enumerate(options)}

# Build the three classification tabs and wire each tab's button and example
# dataset to its classifier function.
with demo:
    gr.Markdown(
    """
    # Zero-Shot image, text or multimodal classification using the same FLAVA model

    Click on one of the examples provided to load them into the UI and "Classify".

    - For image classification, provide class options to be ranked separated by `;`.
    - For text and multimodal classification, provide your 1) prompt with the word you want to be filled in as `[MASK]`, and 2) possible options to be ranked separated by `;`.  
    """
    )
    with gr.Tabs():
        with gr.TabItem("Zero-Shot Image Classification"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image()
                    text_options_i = gr.Textbox(label="Classes (separated by ;)")
                    image_button = gr.Button("Classify")
                    image_dataset = gr.Dataset(
                        components=[image_input, text_options_i],
                        samples=[
                            ["cows.jpg", "a cow; two cows in a green field; a cow in a green field"],
                            ["sofa.jpg", "a room with red sofa; a red room with sofa; ladder in a room"]
                        ]
                    )

                labels_image = gr.Label(label="Probabilities")
        with gr.TabItem("Zero-Shot Text Classification"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(label="Prompt")
                    text_options = gr.Textbox(label="Label options (separate by ;)")
                    text_button = gr.Button("Classify")
                    text_dataset = gr.Dataset(
                        components=[text_input, text_options],
                        samples=[
                            ["by far the worst movie of the year. This was [MASK]", "negative; positive"],
                            ["Lord Voldemort -- in the films; born Tom Marvolo Riddle) is a fictional character and the main antagonist in J.K. Rowling's series of Harry Potter novels. Voldemort first appeared in Harry Potter and the Philosopher's Stone, which was released in 1997. Voldemort appears either in person or in flashbacks in each book and its film adaptation in the series, except the third, Harry Potter and the Prisoner of Azkaban, where he is only mentioned. Question: are tom riddle and lord voldemort the same person? Answer: [MASK]", "no; yes"],
                        ]
                    )
                labels_text = gr.Label(label="Probabilities")
        with gr.TabItem("Zero-Shot MultiModal Classification"):
            with gr.Row():
                with gr.Column():
                    image_input_mm = gr.Image()
                    text_input_mm = gr.Textbox(label="Prompt")
                    text_options_mm = gr.Textbox(label="Options (separate by ;)")
                    multimodal_button = gr.Button("Classify")
                    # Each sample has three values (image, prompt, options) and
                    # the click handler below fills three outputs, so the
                    # dataset must declare all three components.
                    multimodal_dataset = gr.Dataset(
                        components=[image_input_mm, text_input_mm, text_options_mm],
                        samples=[
                            ["cows.jpg", "What animals are in the field? They are [MASK].", "cows; lions; sheep; monkeys"],
                            ["sofa.jpg", "What furniture is in the room? It is [MASK].", "sofa; ladder; bucket"]
                        ]
                    )
                labels_multimodal = gr.Label(label="Probabilities")

    # "Classify" buttons run the model for their tab.
    text_button.click(zero_shot_text, inputs=[text_input, text_options], outputs=labels_text)
    image_button.click(zero_shot_image, inputs=[image_input, text_options_i], outputs=labels_image)
    multimodal_button.click(zero_shot_multimodal, inputs=[image_input_mm, text_input_mm, text_options_mm], outputs=labels_multimodal)
    # Clicking an example copies its values into the tab's input components.
    text_dataset.click(lambda a: a, inputs=[text_dataset], outputs=[text_input, text_options])
    image_dataset.click(lambda a: a, inputs=[image_dataset], outputs=[image_input, text_options_i])
    multimodal_dataset.click(lambda a: a, inputs=[multimodal_dataset], outputs=[image_input_mm, text_input_mm, text_options_mm])

# Start the web server only when the file is run as a script
# (e.g. `python app.py`), so importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()