MinxuanQin
commited on
Commit
·
0c9e22d
1
Parent(s):
a5ab0ec
fix error in visualbert
Browse files- model_loader.py +16 -8
model_loader.py
CHANGED
@@ -62,13 +62,20 @@ def load_dataset(type):
|
|
62 |
raise ValueError("invalid dataset: ", type)
|
63 |
'''
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
sample = {}
|
68 |
-
sample['inputs'] = processor(images=examples['image'], text=examples['question'], return_tensors="pt")
|
69 |
-
sample['outputs'] = examples['multiple_choice_answer']
|
70 |
-
return sample
|
71 |
-
|
72 |
|
73 |
def label_count_list(labels):
|
74 |
res = {}
|
@@ -88,7 +95,7 @@ def get_item(image, question, tokenizer, image_model, model_name):
|
|
88 |
)
|
89 |
visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
|
90 |
.squeeze(2, 3).unsqueeze(0)
|
91 |
-
|
92 |
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
|
93 |
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
|
94 |
upd_dict = {
|
@@ -192,7 +199,8 @@ def get_answer(model_loader_args, img, question, model_name):
|
|
192 |
|
193 |
# load question and image (processor = tokenizer)
|
194 |
## MOD Minxuan: fix error
|
195 |
-
|
|
|
196 |
outputs = model(**inputs)
|
197 |
#except Exception:
|
198 |
# return err_msg()
|
|
|
62 |
raise ValueError("invalid dataset: ", type)
|
63 |
'''
|
64 |
|
65 |
+
def load_img_model(name):
|
66 |
+
"""
|
67 |
+
loads image models for feature extraction
|
68 |
+
returns model name and the loaded model
|
69 |
+
"""
|
70 |
+
if name == "resnet50":
|
71 |
+
model = resnet50(weights='DEFAULT')
|
72 |
+
elif name == "vitb16":
|
73 |
+
## MOD Minxuan: add param
|
74 |
+
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
|
75 |
+
else:
|
76 |
+
raise ValueError("undefined model name: ", name)
|
77 |
|
78 |
+
return model, name
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
def label_count_list(labels):
|
81 |
res = {}
|
|
|
95 |
)
|
96 |
visual_embeds = get_img_feats(image, image_model=image_model, name=model_name)\
|
97 |
.squeeze(2, 3).unsqueeze(0)
|
98 |
+
|
99 |
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
|
100 |
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
|
101 |
upd_dict = {
|
|
|
199 |
|
200 |
# load question and image (processor = tokenizer)
|
201 |
## MOD Minxuan: fix error
|
202 |
+
img_model, name = load_img_model("resnet50")
|
203 |
+
_, inputs = get_item(img, question, processor, img_model, name)
|
204 |
outputs = model(**inputs)
|
205 |
#except Exception:
|
206 |
# return err_msg()
|