"""Gradio demo: zero-shot VQA with Img2Prompt (LAVIS) and a frozen OPT language model.

The UI offers three actions on an uploaded image + question:
GradCam visualisation, question-guided captions, and a final LLM answer.
"""
import re

import torch
import requests
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from lavis.common.gradcam import getAttMap
from lavis.models import load_model_and_preprocess
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr


def prepare_data(image, question):
    """Preprocess an uploaded image and question into the sample dict used by the LAVIS model.

    Args:
        image: PIL image from the Gradio ``Image`` input (``None`` when nothing was uploaded).
        question: Free-text question from the Gradio ``Textbox``.

    Returns:
        dict with ``"image"`` (the processed image with a leading batch dim of 1,
        moved to ``device``) and ``"text_input"`` (one-element list with the processed question).

    Raises:
        gr.Error: if the image or question is missing, so the UI shows a readable
            message instead of a stack trace from deep inside the model.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if not question or not question.strip():
        raise gr.Error("Please enter a question.")
    image_tensor = vis_processors["eval"](image).unsqueeze(0).to(device)
    processed_question = txt_processors["eval"](question)
    return {"image": image_tensor, "text_input": [processed_question]}


def gradcam_attention(image, question):
    """Overlay the model's GradCam map for *question* on a 720-px-wide copy of *image*.

    Returns:
        ``uint8`` numpy array suitable for a Gradio ``Image`` output.
    """
    dst_w = 720
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)

    w, h = image.size
    scaling_factor = dst_w / w
    resized_img = image.resize((int(w * scaling_factor), int(h * scaling_factor)))
    norm_img = np.float32(resized_img) / 255
    # NOTE(review): 24x24 is presumably the image-text-matching model's patch grid
    # (e.g. 384 px input / 16 px patches) — TODO confirm if the LAVIS model changes.
    gradcam = samples['gradcams'].reshape(24, 24)
    avg_gradcam = getAttMap(norm_img, gradcam, blur=True)
    return (avg_gradcam * 255).astype(np.uint8)


def generate_cap(image, question, cap_number):
    """Generate *cap_number* question-guided captions and return them as a one-column DataFrame.

    Raises:
        gr.Error: if *cap_number* is empty or smaller than 1.
    """
    if cap_number is None or cap_number < 1:
        raise gr.Error("Number of captions must be at least 1.")
    cap_number = int(cap_number)  # gr.Number may hand back a float
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)
    samples = model.forward_cap(samples=samples, num_captions=cap_number, num_patches=5)
    print('Examples of question-guided captions: ')
    return pd.DataFrame({'Caption': samples['captions'][0][:cap_number]})


def postprocess(text):
    """Reduce decoded LLM output to one short, lowercase answer.

    Only the last decoded sequence is used (presumably there is exactly one, since
    each request carries a single image). Leading whitespace is skipped, the text is
    cut at the first '.' or newline, trimmed, and lowercased.

    Args:
        text: list of decoded strings from ``tokenizer.batch_decode``.

    Returns:
        The cleaned answer, or ``""`` if *text* is empty.
    """
    if not text:
        return ""
    answer = re.split(r"[.\n]", text[-1].lstrip(), maxsplit=1)[0]
    return answer.strip().lower()


def generate_answer(image, question):
    """Answer *question* about *image* with the Img2Prompt pipeline and the frozen LLM.

    Steps: image-question matching (``forward_itm``) -> 5 question-guided captions
    (``forward_cap``) -> synthetic QA generation (``forward_qa_generation``) ->
    prompt construction -> greedy LLM generation of at most 20 new tokens -> ``postprocess``.
    """
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)
    samples = model.forward_cap(samples=samples, num_captions=5, num_patches=5)
    samples = model.forward_qa_generation(samples)
    prompts = model.prompts_construction(samples)

    prompt_inputs = tokenizer(
        prompts, padding='longest', truncation=True, return_tensors="pt"
    ).to(device)
    prompt_len = prompt_inputs.input_ids.shape[1]
    outputs = llm_model.generate(
        input_ids=prompt_inputs.input_ids,
        attention_mask=prompt_inputs.attention_mask,
        max_new_tokens=20,
        return_dict_in_generate=True,
    )
    # Decode only the newly generated tokens; drop </s>/<pad> so they never reach the UI.
    pred_answer = tokenizer.batch_decode(
        outputs.sequences[:, prompt_len:], skip_special_tokens=True
    )
    pred_answer = postprocess(pred_answer)
    print(pred_answer, type(pred_answer))
    return pred_answer


# setup device to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


def load_model(model_selection):
    """Load a Hugging Face causal LM and its slow (non-fast) tokenizer by hub id or local path."""
    llm = AutoModelForCausalLM.from_pretrained(model_selection)
    llm_tokenizer = AutoTokenizer.from_pretrained(model_selection, use_fast=False)
    return llm, llm_tokenizer


# Choose LLM to use
# weights for OPT-6.7B/OPT-13B/OPT-30B/OPT-66B will download automatically
print("Loading Large Language Model (LLM)...")
# OPT-1.3B: roughly 2.6 GB of weights at FP16 (about double that at FP32).
llm_model, tokenizer = load_model('facebook/opt-1.3b')
llm_model.to(device)

model, vis_processors, txt_processors = load_model_and_preprocess(
    name="img2prompt_vqa", model_type="base", is_eval=True, device=device
)

# ---- Gradio Layout -----
title = "From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models"
df_init = pd.DataFrame(columns=['Caption'])
raw_image = gr.Image(label='Input image', type="pil")
question = gr.Textbox(label="Input question", lines=1, interactive=True)
demo = gr.Blocks(title=title)
demo.encrypt = False
# NOTE(review): max_rows / overflow_row_behaviour appear to be Gradio 3.x-only
# DataFrame arguments — drop them if the Space is upgraded to Gradio 4.
cap_df = gr.DataFrame(
    value=df_init,
    label="Caption dataframe",
    row_count=(0, "dynamic"),
    max_rows=20,
    wrap=True,
    overflow_row_behaviour='paginate',
)

# Kept at column 0 on purpose: indenting these lines would turn them into a
# Markdown code block.
DESCRIPTION = '''

From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models

What you can do with this space

1. Upload your image and fill your question

2. Generating gradcam attention from model

3. Creating caption from your image

4. Answering your question based on uploaded image

'''

with demo:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    examples = gr.Examples(
        examples=[["image1.jpg", "What type of bird is this?"]],
        label="Examples",
        inputs=[raw_image, question],
    )
    with gr.Row():
        with gr.Column():
            raw_image.render()
        with gr.Column():
            avg_gradcam = gr.Image(label="GradCam image")
    with gr.Row():
        with gr.Column():
            question.render()
        with gr.Column():
            number_cap = gr.Number(
                precision=0,
                value=5,
                label="Selected number of caption you want to generate",
                interactive=True,
            )
    with gr.Row():
        with gr.Column():
            gradcam_btn = gr.Button("Generate Gradcam")
            gradcam_btn.click(gradcam_attention, [raw_image, question], outputs=[avg_gradcam])
        with gr.Column():
            cap_btn = gr.Button("Generate caption")
            cap_btn.click(generate_cap, [raw_image, question, number_cap], [cap_df])
    with gr.Row():
        with gr.Column():
            cap_df.render()
    with gr.Row():
        anws_btn = gr.Button("Answer")
        text_output = gr.Textbox()
        anws_btn.click(generate_answer, [raw_image, question], outputs=text_output)

demo.launch(debug=True)