Create app.py
app.py
ADDED
import torch
import requests
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

from lavis.common.gradcam import getAttMap
from lavis.models import load_model_and_preprocess

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr

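# Preprocess the raw image and question with the LAVIS processors and pack them
# into the sample dict expected by the Img2Prompt-VQA model.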
def prepare_data(image, question):
    image = vis_processors["eval"](image).unsqueeze(0).to(device)
    question = txt_processors["eval"](question)
    samples = {"image": image, "text_input": [question]}
    return samples

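# Run image-text matching (forward_itm) and overlay the resulting GradCAM
# attention map on a resized copy of the input image.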
def gradcam_attention(image, question):
    dst_w = 720
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)

    w, h = image.size
    scaling_factor = dst_w / w

    resized_img = image.resize((int(w * scaling_factor), int(h * scaling_factor)))
    norm_img = np.float32(resized_img) / 255
    gradcam = samples['gradcams'].reshape(24, 24)

    avg_gradcam = getAttMap(norm_img, gradcam, blur=True)
    return (avg_gradcam * 255).astype(np.uint8)

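# Generate question-guided captions from the attention-selected image patches
# and return them as a one-column DataFrame for display.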
def generate_cap(image, question, cap_number):
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)
    samples = model.forward_cap(samples=samples, num_captions=cap_number, num_patches=5)
    print('Examples of question-guided captions: ')
    return pd.DataFrame({'Caption': samples['captions'][0][:cap_number]})

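# Truncate each decoded answer at the first period or newline and lowercase it.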
def postprocess(text):
    for i, ans in enumerate(text):
        for j, w in enumerate(ans):
            if w == '.' or w == '\n':
                ans = ans[:j].lower()
                break
    return ans

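# Full Img2Prompt-VQA pipeline: image-text matching, caption generation,
# question generation, prompt construction, then answer generation with the frozen LLM.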
def generate_answer(image, question):
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)
    samples = model.forward_cap(samples=samples, num_captions=5, num_patches=5)
    samples = model.forward_qa_generation(samples)
    Img2Prompt = model.prompts_construction(samples)
    Img2Prompt_input = tokenizer(Img2Prompt, padding='longest', truncation=True, return_tensors="pt").to(device)

    outputs = llm_model.generate(input_ids=Img2Prompt_input.input_ids,
                                 attention_mask=Img2Prompt_input.attention_mask,
                                 max_length=20 + len(Img2Prompt_input.input_ids[0]),
                                 return_dict_in_generate=True,
                                 output_scores=True)
    pred_answer = tokenizer.batch_decode(outputs.sequences[:, len(Img2Prompt_input.input_ids[0]):])
    pred_answer = postprocess(pred_answer)
    print(pred_answer, type(pred_answer))
    return pred_answer

# setup device to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

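# Load a causal language model and its tokenizer from the Hugging Face Hub.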
def load_model(model_selection):
    model = AutoModelForCausalLM.from_pretrained(model_selection)
    tokenizer = AutoTokenizer.from_pretrained(model_selection, use_fast=False)
    return model, tokenizer

# Choose LLM to use
# weights for larger OPT variants (OPT-6.7B/OPT-13B/OPT-30B/OPT-66B) download automatically if selected
print("Loading Large Language Model (LLM)...")
llm_model, tokenizer = load_model('facebook/opt-1.3b')  # ~2.6 GB in FP16
llm_model.to(device)
model, vis_processors, txt_processors = load_model_and_preprocess(name="img2prompt_vqa", model_type="base", is_eval=True, device=device)


# ---- Gradio Layout -----
title = "From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models"
df_init = pd.DataFrame(columns=['Caption'])
raw_image = gr.Image(label='Input image', type="pil")
question = gr.Textbox(label="Input question", lines=1, interactive=True)
demo = gr.Blocks(title=title)
demo.encrypt = False
cap_df = gr.DataFrame(value=df_init, label="Caption dataframe", row_count=(0, "dynamic"), max_rows=20, wrap=True, overflow_row_behaviour='paginate')

with demo:
    with gr.Row():
        gr.Markdown(f'''
            <div>
            <h1 style='text-align: center'>From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models</h1>
            </div>
            <div align="center">
                <h3> What you can do with this space </h3>
                <h4> 1. Upload an image and enter your question </h4>
                <h4> 2. Generate GradCAM attention from the model </h4>
                <h4> 3. Create captions for your image </h4>
                <h4> 4. Answer your question based on the uploaded image </h4>
            </div>
        ''')
    examples = gr.Examples(examples=[["image1.jpg", "What type of bird is this?"]],
                           label="Examples", inputs=[raw_image, question])
    with gr.Row():
        with gr.Column():
            raw_image.render()
        with gr.Column():
            avg_gradcam = gr.Image(label="GradCAM image")

    with gr.Row():
        with gr.Column():
            question.render()
        with gr.Column():
            number_cap = gr.Number(precision=0, value=5, label="Number of captions to generate", interactive=True)
    with gr.Row():
        with gr.Column():
            gradcam_btn = gr.Button("Generate GradCAM")
            gradcam_btn.click(gradcam_attention, [raw_image, question], outputs=[avg_gradcam])
        with gr.Column():
            cap_btn = gr.Button("Generate caption")
            cap_btn.click(generate_cap, [raw_image, question, number_cap], [cap_df])
    with gr.Row():
        with gr.Column():
            cap_df.render()
    with gr.Row():
        anws_btn = gr.Button("Answer")
        text_output = gr.Textbox()
        anws_btn.click(generate_answer, [raw_image, question], outputs=text_output)

demo.launch(debug=True)