MinxuanQin committed
Commit 58c2c99 · 1 Parent(s): 7b4b5f6

add display visualbert

Files changed (1)
  1. model_loader.py +3 -1
model_loader.py CHANGED
```diff
@@ -5,6 +5,7 @@ from datasets import load_dataset, get_dataset_split_names
 import numpy as np
 
 import requests
+import streamlit as st
 from transformers import ViltProcessor, ViltForQuestionAnswering
 from transformers import AutoProcessor, AutoModelForCausalLM
 from transformers import BlipProcessor, BlipForQuestionAnswering
@@ -87,6 +88,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
     )
     visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
         .squeeze(2, 3).unsqueeze(0)
+    st.text(f"ques embed: {inputs.shape}, visual: {visual_embeds.shape}")
     visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
     visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
     upd_dict = {
@@ -95,7 +97,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
         "visual_attention_mask": visual_attention_mask,
     }
     inputs.update(upd_dict)
-
+
     return upd_dict, inputs
```
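For context, the hunk above extends a tokenizer output with a visual stream before a VisualBERT forward pass, and the new `st.text(...)` call renders the embedding shapes in the Streamlit app for debugging. Below is a minimal, self-contained sketch of that input-assembly pattern; it is not the repo's code. The random `visual_embeds` tensor stands in for `get_img_feats(...)`, and the tokenizer and checkpoint names (`bert-base-uncased`, `uclanlp/visualbert-vqa`) are illustrative choices, not ones confirmed by this commit. One caveat on the debug line itself: if `inputs` is the tokenizer's `BatchEncoding`, it has no `.shape` attribute, so the question shape would need to come from something like `inputs["input_ids"].shape`.

```python
import torch
from transformers import BertTokenizerFast, VisualBertForQuestionAnswering

# Illustrative checkpoints; the commit does not name the ones it uses.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")

# Tokenize the question; this returns a BatchEncoding (a dict-like object).
inputs = tokenizer("What is on the table?", return_tensors="pt")

# Stand-in for get_img_feats(...): (batch, num_regions, feature_dim).
# 36 regions x 2048-dim features is a common Faster R-CNN convention.
visual_embeds = torch.randn(1, 36, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

# Mirrors upd_dict / inputs.update(...) in model_loader.py: the text
# encoding is extended with the visual stream before the forward pass.
inputs.update({
    "visual_embeds": visual_embeds,
    "visual_token_type_ids": visual_token_type_ids,
    "visual_attention_mask": visual_attention_mask,
})

with torch.no_grad():
    outputs = model(**inputs)

# What the added st.text(...) presumably means to display: the shape of
# the tokenized question and of the visual embeddings.
print("ques embed:", inputs["input_ids"].shape, "visual:", visual_embeds.shape)
print("predicted answer id:", outputs.logits.argmax(-1).item())
```

In a Streamlit app, the two `print` calls would be `st.text(...)` calls as in the diff; the shapes render in the page on every run, which is a quick way to confirm the visual features were squeezed and batched as expected before `model(**inputs)`.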