vumichien commited on
Commit
183f457
·
1 Parent(s): ee3f6a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import requests
3
+ from PIL import Image
4
+ from matplotlib import pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from lavis.common.gradcam import getAttMap
9
+ from lavis.models import load_model_and_preprocess
10
+
11
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
12
+ import gradio as gr
13
+
14
+ def prepare_data(image, question):
15
+ image = vis_processors["eval"](image).unsqueeze(0).to(device)
16
+ question = txt_processors["eval"](question)
17
+ samples = {"image": image, "text_input": [question]}
18
+ return samples
19
+
20
+ def gradcam_attention(image, question):
21
+ dst_w = 720
22
+ samples = prepare_data(image, question)
23
+ samples = model.forward_itm(samples=samples)
24
+
25
+ w, h = image.size
26
+ scaling_factor = dst_w / w
27
+
28
+ resized_img = image.resize((int(w * scaling_factor), int(h * scaling_factor)))
29
+ norm_img = np.float32(resized_img) / 255
30
+ gradcam = samples['gradcams'].reshape(24,24)
31
+
32
+ avg_gradcam = getAttMap(norm_img, gradcam, blur=True)
33
+ return (avg_gradcam * 255).astype(np.uint8)
34
+
35
+ def generate_cap(image, question, cap_number):
36
+ samples = prepare_data(image, question)
37
+ samples = model.forward_itm(samples=samples)
38
+ samples = model.forward_cap(samples=samples, num_captions=cap_number, num_patches=5)
39
+ print('Examples of question-guided captions: ')
40
+ return pd.DataFrame({'Caption': samples['captions'][0][:cap_number]})
41
+
42
+ def postprocess(text):
43
+ for i, ans in enumerate(text):
44
+ for j, w in enumerate(ans):
45
+ if w == '.' or w == '\n':
46
+ ans = ans[:j].lower()
47
+ break
48
+ return ans
49
+
50
+ def generate_answer(image, question):
51
+ samples = prepare_data(image, question)
52
+ samples = model.forward_itm(samples=samples)
53
+ samples = model.forward_cap(samples=samples, num_captions=5, num_patches=5)
54
+ samples = model.forward_qa_generation(samples)
55
+ Img2Prompt = model.prompts_construction(samples)
56
+ Img2Prompt_input = tokenizer(Img2Prompt, padding='longest', truncation=True, return_tensors="pt").to(device)
57
+
58
+ outputs = llm_model.generate(input_ids=Img2Prompt_input.input_ids,
59
+ attention_mask=Img2Prompt_input.attention_mask,
60
+ max_length=20+len(Img2Prompt_input.input_ids[0]),
61
+ return_dict_in_generate=True,
62
+ output_scores=True
63
+ )
64
+ pred_answer = tokenizer.batch_decode(outputs.sequences[:, len(Img2Prompt_input.input_ids[0]):])
65
+ pred_answer = postprocess(pred_answer)
66
+ print(pred_answer, type(pred_answer))
67
+ return pred_answer
68
+
69
+ # setup device to use
70
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
+ print(device)
72
+
73
+ def load_model(model_selection):
74
+ model = AutoModelForCausalLM.from_pretrained(model_selection)
75
+ tokenizer = AutoTokenizer.from_pretrained(model_selection, use_fast=False)
76
+ return model,tokenizer
77
+
78
+ # Choose LLM to use
79
+ # weights for OPT-6.7B/OPT-13B/OPT-30B/OPT-66B will download automatically
80
+ print("Loading Large Language Model (LLM)...")
81
+ llm_model, tokenizer = load_model('facebook/opt-1.3b') # ~13G (FP16)
82
+ llm_model.to(device)
83
+ model, vis_processors, txt_processors = load_model_and_preprocess(name="img2prompt_vqa", model_type="base", is_eval=True, device=device)
84
+
85
+
86
+ # ---- Gradio Layout -----
87
+ title = "From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models"
88
+ df_init = pd.DataFrame(columns=['Caption'])
89
+ raw_image = gr.Image(label='Input image', type="pil")
90
+ question = gr.Textbox(label="Input question", lines=1, interactive=True)
91
+ demo = gr.Blocks(title=title)
92
+ demo.encrypt = False
93
+ cap_df = gr.DataFrame(value=df_init, label="Caption dataframe", row_count=(0, "dynamic"), max_rows = 20, wrap=True, overflow_row_behaviour='paginate')
94
+
95
+ with demo:
96
+ with gr.Row():
97
+ gr.Markdown(f'''
98
+ <div>
99
+ <h1 style='text-align: center'>From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models</h1>
100
+ </div>
101
+ <div align="center">
102
+ <h3> What you can do with this space </h3>
103
+ <h4> 1. Upload your image and fill your question </h4>
104
+ <h4> 2. Generating gradcam attention from model </h4>
105
+ <h4> 3. Creating caption from your image </h4>
106
+ <h4> 4. Answering your question based on uploaded image </h4>
107
+ </div>
108
+ ''')
109
+ examples = gr.Examples(examples=
110
+ [["image1.jpg", "What type of bird is this?"]],
111
+ label="Examples", inputs=[raw_image, question])
112
+ with gr.Row():
113
+ with gr.Column():
114
+ raw_image.render()
115
+ with gr.Column():
116
+ avg_gradcam = gr.Image(label="GradCam image",)
117
+
118
+ with gr.Row():
119
+ with gr.Column():
120
+ question.render()
121
+ with gr.Column():
122
+ number_cap = gr.Number(precision=0, value=5, label="Selected number of caption you want to generate", interactive=True)
123
+ with gr.Row():
124
+ with gr.Column():
125
+ gradcam_btn = gr.Button("Generate Gradcam")
126
+ gradcam_btn.click(gradcam_attention, [raw_image, question], outputs=[avg_gradcam])
127
+ with gr.Column():
128
+ cap_btn = gr.Button("Generate caption")
129
+ cap_btn.click(generate_cap, [raw_image, question, number_cap], [cap_df])
130
+ with gr.Row():
131
+ with gr.Column():
132
+ cap_df.render()
133
+ with gr.Row():
134
+ anws_btn = gr.Button("Answer")
135
+ text_output = gr.Textbox()
136
+ anws_btn.click(generate_answer, [raw_image, question], outputs=text_output)
137
+
138
+ demo.launch(debug=True)