Spaces:
Runtime error
Runtime error
import torch | |
import requests | |
from PIL import Image | |
from matplotlib import pyplot as plt | |
import numpy as np | |
import pandas as pd | |
from lavis.common.gradcam import getAttMap | |
from lavis.models import load_model_and_preprocess | |
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM | |
import gradio as gr | |
def prepare_data(image, question):
    """Build the sample dictionary passed to the Img2Prompt model.

    The image is run through ``vis_processors["eval"]``, given a leading
    dimension with ``unsqueeze(0)`` and moved to ``device``. The question is
    run through ``txt_processors["eval"]``.

    Args:
        image: Input image accepted by the "eval" visual processor.
        question: Question text accepted by the "eval" text processor.

    Returns:
        dict: ``{"image": <processed image>, "text_input": [<processed question>]}``.
    """
    processed_image = vis_processors["eval"](image)
    batched_image = processed_image.unsqueeze(0).to(device)
    return {
        "image": batched_image,
        # The question is wrapped in a one-element list.
        "text_input": [txt_processors["eval"](question)],
    }
def gradcam_attention(image, question):
    """Render the model's GradCam map for ``question`` on top of ``image``.

    The sample is passed through ``model.forward_itm``; the resulting
    ``'gradcams'`` entry is reshaped to 24x24 and blended with a copy of the
    image resized to a width of 720 pixels (aspect ratio kept) via
    ``getAttMap(..., blur=True)``.

    Args:
        image: Image exposing ``.size`` and ``.resize`` (the UI supplies a PIL
            image).
        question: Question text guiding the attention.

    Returns:
        numpy.ndarray: The blended image scaled by 255 and cast to ``uint8``.
    """
    target_width = 720

    itm_output = model.forward_itm(samples=prepare_data(image, question))

    # Resize to the target width while keeping the aspect ratio.
    width, height = image.size
    scale = target_width / width
    new_size = (int(width * scale), int(height * scale))
    normalized = np.float32(image.resize(new_size)) / 255

    # NOTE(review): 24x24 is hard-coded; presumably the model's patch grid —
    # confirm before changing the model type or input resolution.
    heatmap = itm_output['gradcams'].reshape(24, 24)
    overlay = getAttMap(normalized, heatmap, blur=True)
    return (overlay * 255).astype(np.uint8)
def generate_cap(image, question, cap_number):
    """Generate question-guided captions for ``image``.

    Runs ``model.forward_itm`` followed by ``model.forward_cap`` (with
    ``num_patches=5``) and collects the captions of the first sample.

    Args:
        image: Input image.
        question: Question text guiding caption generation.
        cap_number: Number of captions requested; also used to slice the
            result.

    Returns:
        pandas.DataFrame: A single ``'Caption'`` column holding at most
        ``cap_number`` captions.
    """
    batch = model.forward_itm(samples=prepare_data(image, question))
    batch = model.forward_cap(
        samples=batch,
        num_captions=cap_number,
        num_patches=5,
    )
    print('Examples of question-guided captions: ')
    first_sample_captions = batch['captions'][0]
    return pd.DataFrame({'Caption': first_sample_captions[:cap_number]})
def postprocess(text):
    """Return the last decoded answer, cut at its first sentence end.

    Only the final element of ``text`` is returned. If that element contains
    a ``'.'`` or a newline, everything from the first such character onward
    is dropped and the remainder is lower-cased. An element without either
    character is returned unchanged (not lower-cased).

    Args:
        text: Iterable of decoded answer strings.

    Returns:
        str: The processed last answer, or ``""`` when ``text`` is empty
        (an empty input used to raise ``UnboundLocalError``).
    """
    answers = list(text)
    if not answers:
        # Nothing was decoded; return an empty answer instead of crashing.
        return ""

    ans = answers[-1]
    for pos, char in enumerate(ans):
        if char in ('.', '\n'):
            return ans[:pos].lower()
    return ans
def generate_answer(image, question):
    """Answer ``question`` about ``image`` with the LLM.

    Pipeline: ``prepare_data`` -> ``model.forward_itm`` ->
    ``model.forward_cap`` (5 captions, 5 patches) ->
    ``model.forward_qa_generation`` -> ``model.prompts_construction``. The
    resulting prompt is tokenized and passed to ``llm_model.generate``, and
    the generated continuation is decoded and cleaned with ``postprocess``.

    Args:
        image: Input image.
        question: Question text.

    Returns:
        The value returned by ``postprocess`` for the decoded continuation.
    """
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)
    samples = model.forward_cap(samples=samples, num_captions=5, num_patches=5)
    samples = model.forward_qa_generation(samples)
    Img2Prompt = model.prompts_construction(samples)

    Img2Prompt_input = tokenizer(
        Img2Prompt, padding='longest', truncation=True, return_tensors="pt"
    ).to(device)
    # Length of the tokenized prompt; used both to bound generation and to
    # strip the prompt from the generated sequences.
    prompt_len = len(Img2Prompt_input.input_ids[0])

    outputs = llm_model.generate(
        input_ids=Img2Prompt_input.input_ids,
        attention_mask=Img2Prompt_input.attention_mask,
        max_length=20 + prompt_len,  # allow up to 20 tokens past the prompt
        return_dict_in_generate=True,
        output_scores=True,
    )

    # Decode only the tokens after the prompt. skip_special_tokens keeps
    # special-token text (e.g. an end-of-sequence marker) out of the answer
    # when generation ends before postprocess() finds a '.' or newline.
    pred_answer = tokenizer.batch_decode(
        outputs.sequences[:, prompt_len:], skip_special_tokens=True
    )
    pred_answer = postprocess(pred_answer)
    print(pred_answer, type(pred_answer))
    return pred_answer
# Select the compute device: CUDA when PyTorch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
def load_model(model_selection):
    """Load a causal language model and its tokenizer.

    Args:
        model_selection: Identifier passed to both ``from_pretrained`` calls
            (the file uses a Hugging Face model id).

    Returns:
        tuple: ``(model, tokenizer)``; the tokenizer is requested with
        ``use_fast=False``.
    """
    return (
        AutoModelForCausalLM.from_pretrained(model_selection),
        AutoTokenizer.from_pretrained(model_selection, use_fast=False),
    )
# Choose LLM to use
# weights for OPT-6.7B/OPT-13B/OPT-30B/OPT-66B will download automatically
print("Loading Large Language Model (LLM)...")
# NOTE(review): the original trailing note here read "~13G (FP16)"; that
# figure looks carried over from a larger OPT checkpoint — confirm the real
# footprint for opt-1.3b.
llm_model, tokenizer = load_model('facebook/opt-1.3b')
llm_model.to(device)

# Img2Prompt VQA model together with its visual and text preprocessors.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="img2prompt_vqa",
    model_type="base",
    is_eval=True,
    device=device,
)
# ---- Gradio Layout -----
title = "From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models"

# Empty single-column frame used as the initial value of the caption table.
df_init = pd.DataFrame(columns=['Caption'])

# These components are created before the `with demo:` block and are placed
# into the layout later through their .render() calls.
raw_image = gr.Image(label='Input image', type="pil")
question = gr.Textbox(label="Input question", lines=1, interactive=True)

demo = gr.Blocks(title=title)
demo.encrypt = False

cap_df = gr.DataFrame(
    value=df_init,
    label="Caption dataframe",
    row_count=(0, "dynamic"),
    max_rows=20,
    wrap=True,
    overflow_row_behaviour='paginate',
)
# NOTE(review): the Row/Column nesting below is reconstructed from statement
# order; confirm it against the intended page layout.
with demo:
    # Header with a short usage guide.
    with gr.Row():
        gr.Markdown(f'''
<div>
<h1 style='text-align: center'>From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models</h1>
</div>
<div align="center">
<h3> What you can do with this space </h3>
<h4> 1. Upload your image and fill your question </h4>
<h4> 2. Generating gradcam attention from model </h4>
<h4> 3. Creating caption from your image </h4>
<h4> 4. Answering your question based on uploaded image </h4>
</div>
''')

    # Example input that fills the image and question components.
    examples = gr.Examples(
        examples=[["image1.jpg", "What type of bird is this?"]],
        label="Examples",
        inputs=[raw_image, question],
    )

    # Input image next to the GradCam visualisation.
    with gr.Row():
        with gr.Column():
            raw_image.render()
        with gr.Column():
            avg_gradcam = gr.Image(label="GradCam image")

    # Question box next to the caption-count selector.
    with gr.Row():
        with gr.Column():
            question.render()
        with gr.Column():
            number_cap = gr.Number(
                precision=0,
                value=5,
                label="Selected number of caption you want to generate",
                interactive=True,
            )

    # Action buttons for GradCam and caption generation.
    with gr.Row():
        with gr.Column():
            gradcam_btn = gr.Button("Generate Gradcam")
            gradcam_btn.click(
                fn=gradcam_attention,
                inputs=[raw_image, question],
                outputs=[avg_gradcam],
            )
        with gr.Column():
            cap_btn = gr.Button("Generate caption")
            cap_btn.click(
                fn=generate_cap,
                inputs=[raw_image, question, number_cap],
                outputs=[cap_df],
            )

    # Table of generated captions.
    with gr.Row():
        with gr.Column():
            cap_df.render()

    # Final answer produced by the LLM.
    with gr.Row():
        anws_btn = gr.Button("Answer")
        text_output = gr.Textbox()
        anws_btn.click(
            fn=generate_answer,
            inputs=[raw_image, question],
            outputs=text_output,
        )

demo.launch(debug=True)