import os
import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
from PIL import Image
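
# Use the GPU when available, otherwise fall back to CPU.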
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
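
# Load the pretrained BLIP-2 model (frozen OPT-2.7B language model) and its image
# preprocessors from LAVIS; is_eval=True puts the model in inference mode.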
model, vis_processors, _ = load_model_and_preprocess(
name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
)
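

# Caption generation: preprocess the image, then decode a caption with either beam
# search (one caption) or nucleus sampling (three captions).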
def generate_caption(image, caption_type):
    image = vis_processors["eval"](image).unsqueeze(0).to(device)
    if caption_type == "Beam Search":
        # Default decoding: beam search returns a single caption.
        caption = model.generate({"image": image})
    else:
        # Nucleus sampling: draw three diverse captions.
        caption = model.generate(
            {"image": image}, use_nucleus_sampling=True, num_captions=3
        )
    caption = "\n".join(caption)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return caption
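

# Chat-style visual question answering: fold the running history into a
# "Question: ... Answer: ..." prompt and let BLIP-2 answer the new question.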
def chat(input_image, question, history):
    history = history or []
    question = question.lower()
    image = vis_processors["eval"](input_image).unsqueeze(0).to(device)
    # Strip HTML tags and newlines that Gradio may add to chat history entries.
    clean = lambda x: x.replace("<p>", "").replace("</p>", "").replace("\n", "")
    clean_h = lambda x: (clean(x[0]), clean(x[1]))
    context = list(map(clean_h, history))
    # Turn the prior turns into an in-context prompt, then append the new question.
    template = "Question: {} Answer: {}."
    prompt = (
        " ".join(template.format(q, a) for q, a in context)
        + " Question: "
        + question
        + " Answer:"
    )
    response = model.generate({"image": image, "prompt": prompt})
    history.append((question, response[0]))
    return history, history
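

# Reset both the visible chat window and the stored conversation history.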
def clear_chat(history):
    return [], []
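

# Gradio front end for captioning and visual question answering.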
with gr.Blocks() as demo:
    gr.Markdown(
        "### BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models"
    )
    gr.Markdown(
        "This demo uses the `pretrain_opt2.7b` weights. For more information, please visit the [GitHub repository](https://github.com/salesforce/LAVIS/tree/main/projects/blip2) or read the [paper](https://arxiv.org/abs/2301.12597)."
    )
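
    # Two-column layout: captioning controls on the left, chat-style Q&A on the right.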
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Image", type="pil")
            caption_type = gr.Radio(
                ["Beam Search", "Nucleus Sampling"],
                label="Caption Decoding Strategy",
                value="Beam Search",
            )
            btn_caption = gr.Button("Generate Caption")
            output_text = gr.Textbox(label="Caption", lines=5)
        with gr.Column():
            chatbot = gr.Chatbot()
            chat_state = gr.State()
            question_txt = gr.Textbox(label="Question", lines=1)
            btn_answer = gr.Button("Generate Answer")
            btn_clear = gr.Button("Clear Chat")
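
    # Wire each button to its callback.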
    btn_caption.click(
        generate_caption, inputs=[input_image, caption_type], outputs=[output_text]
    )
    btn_answer.click(
        chat,
        inputs=[input_image, question_txt, chat_state],
        outputs=[chatbot, chat_state],
    )
    btn_clear.click(clear_chat, inputs=[chat_state], outputs=[chatbot, chat_state])
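
    # Clickable examples: (image, decoding strategy, question).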
    gr.Examples(
        [
            ["./merlion.png", "Beam Search", "which city is this?"],
            [
                "./Blue_Jay_0044_62759.jpg",
                "Beam Search",
                "what is the name of this bird?",
            ],
            ["./5kstbz-0001.png", "Beam Search", "where is the man standing?"],
            [
                "ILSVRC2012_val_00000008.JPEG",
                "Beam Search",
                "Name the colors of macarons you see in the image.",
            ],
        ],
        inputs=[input_image, caption_type, question_txt],
    )
    gr.Markdown(
        "Sample images are taken from the [ImageNet](https://paperswithcode.com/sota/image-classification-on-imagenet), [CUB](https://paperswithcode.com/dataset/cub-200-2011) and [GamePhysics](https://asgaardlab.github.io/CLIPxGamePhysics/) datasets."
    )
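
# Start the Gradio interface.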
demo.launch()