File size: 3,866 Bytes
7183fd3
07c39b0
 
7183fd3
07c39b0
7183fd3
07c39b0
7183fd3
 
 
 
984b4fb
7183fd3
 
 
 
 
07c39b0
7183fd3
 
 
07c39b0
 
 
 
 
 
70bb060
 
 
7183fd3
 
 
07c39b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7183fd3
07c39b0
a822ce3
07c39b0
 
984b4fb
07c39b0
 
7183fd3
 
 
07c39b0
 
02a6b14
07c39b0
 
7183fd3
07c39b0
 
 
e71d8d9
07c39b0
 
7183fd3
 
07c39b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02a6b14
 
 
 
 
 
 
 
 
 
a822ce3
02a6b14
 
 
 
07c39b0
 
a822ce3
 
 
02a6b14
07c39b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os

import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
from PIL import Image

# Choose where inference runs: a CUDA torch.device when a GPU is available,
# otherwise the plain string "cpu" (the value types are kept as they were).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

# Load the BLIP-2 model and its preprocessors once at import time, in eval
# mode and already placed on `device`. The third returned value is discarded.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt",
    model_type="pretrain_opt2.7b",
    is_eval=True,
    device=device,
)


def generate_caption(image, caption_type):
    """Generate caption text for an image with the loaded model.

    Args:
        image: The image handed over by the image input component, or
            ``None`` when nothing has been uploaded yet.
        caption_type: ``"Beam Search"`` calls ``model.generate`` with its
            default decoding settings; any other value enables nucleus
            sampling and requests three captions.

    Returns:
        The generated caption(s) joined with newlines, or a short hint asking
        for an image when ``image`` is ``None``.
    """
    if image is None:
        # Without this guard the preprocessor is handed None and fails with
        # an opaque error; a readable hint in the output box is friendlier.
        return "Please upload an image first."

    # Preprocess, add a leading batch dimension, and move to the model device.
    image = vis_processors["eval"](image).unsqueeze(0).to(device)

    try:
        if caption_type == "Beam Search":
            captions = model.generate({"image": image})
        else:
            captions = model.generate(
                {"image": image}, use_nucleus_sampling=True, num_captions=3
            )
    finally:
        # Release cached GPU memory even when generation raises, so a failed
        # request does not keep memory pinned for the next one.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return "\n".join(captions)


def chat(input_image, question, history):
    """Answer a question about an image, using earlier turns as context.

    Args:
        input_image: The image handed over by the image input component, or
            ``None`` when nothing has been uploaded yet.
        question: The user's question; it is lower-cased before being used.
        history: Sequence of ``(question, answer)`` pairs from earlier turns,
            or ``None``/empty for a new conversation. A non-empty list is
            extended in place.

    Returns:
        A ``(history, history)`` tuple: the same updated list twice, once for
        each output the caller wires up. When the image is missing or the
        question is empty/blank, the history is returned unchanged.
    """
    history = history or []

    # Nothing to ask the model without both an image and a question; return
    # the conversation as-is instead of failing inside preprocessing.
    if input_image is None or not question or not question.strip():
        return history, history

    question = question.lower()

    # Preprocess, add a leading batch dimension, and move to the model device.
    image = vis_processors["eval"](input_image).unsqueeze(0).to(device)

    def clean(text):
        # Drop <p> wrappers and newlines from stored turns before they go
        # into the prompt (presumably markup added by the chat display --
        # TODO confirm).
        return text.replace("<p>", "").replace("</p>", "").replace("\n", "")

    template = "Question: {} Answer: {}."
    past_turns = " ".join(
        template.format(clean(asked), clean(answered)) for asked, answered in history
    )
    # The separator before "Question:" is kept even when there are no past
    # turns, so the prompt text matches what the model was given before.
    prompt = past_turns + " Question: " + question + " Answer:"

    response = model.generate({"image": image, "prompt": prompt})
    history.append((question, response[0]))

    return history, history


def clear_chat(history):
    """Reset the conversation.

    The incoming ``history`` is ignored. Two fresh, independent empty lists
    are returned, one for each output the caller wires up.
    """
    cleared_display = []
    cleared_state = []
    return cleared_display, cleared_state


# Build the demo UI. Components are created inside nested context managers,
# so the order of the statements below determines the page layout.
with gr.Blocks() as demo:
    # Page header: title and a short description with project links.
    gr.Markdown(
        "### BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models"
    )
    gr.Markdown(
        "This demo uses the `pretrain_opt2.7b` weights. For more information please visit [Github](https://github.com/salesforce/LAVIS/tree/main/projects/blip2) or [Paper](https://arxiv.org/abs/2301.12597)."
    )

    with gr.Row():
        # First column: image input plus the captioning controls and output.
        with gr.Column():
            input_image = gr.Image(label="Image", type="pil")
            caption_type = gr.Radio(
                ["Beam Search", "Nucleus Sampling"],
                label="Caption Decoding Strategy",
                value="Beam Search",
            )
            btn_caption = gr.Button("Generate Caption")
            # Receives the string returned by generate_caption.
            output_text = gr.Textbox(label="Answer", lines=5)

        # Second column: question answering over the same image.
        with gr.Column():
            chatbot = gr.Chatbot()
            # Holds the chat history that is passed into and returned by chat().
            chat_state = gr.State()

            question_txt = gr.Textbox(label="Question", lines=1)
            btn_answer = gr.Button("Generate Answer")
            btn_clear = gr.Button("Clear Chat")

    # Captioning: image + decoding strategy in, caption text out.
    btn_caption.click(
        generate_caption, inputs=[input_image, caption_type], outputs=[output_text]
    )

    # Question answering: chat() returns the history twice, once for the
    # visible chatbot and once for the stored state.
    btn_answer.click(
        chat,
        inputs=[input_image, question_txt, chat_state],
        outputs=[chatbot, chat_state],
    )

    # Clearing resets both the visible chatbot and the stored state.
    btn_clear.click(clear_chat, inputs=[chat_state], outputs=[chatbot, chat_state])

    # Each example row supplies, in order, the image path, the decoding
    # strategy and the question, matching the `inputs` list below.
    gr.Examples(
        [
            ["./merlion.png", "Beam Search", "which city is this?"],
            [
                "./Blue_Jay_0044_62759.jpg",
                "Beam Search",
                "what is the name of this bird?",
            ],
            ["./5kstbz-0001.png", "Beam Search", "where is the man standing?"],
            [
                "ILSVRC2012_val_00000008.JPEG",
                "Beam Search",
                "Name the colors of macarons you see in the image.",
            ],
        ],
        inputs=[input_image, caption_type, question_txt],
    )

    # Footer: attribution for the sample images.
    gr.Markdown(
        "Sample images are taken from [ImageNet](https://paperswithcode.com/sota/image-classification-on-imagenet), [CUB](https://paperswithcode.com/dataset/cub-200-2011) and [GamePhysics](https://asgaardlab.github.io/CLIPxGamePhysics/) datasets."
    )

# Start the app. This is a module-level call, so it also runs on import.
demo.launch()