MinxuanQin
commited on
Commit
·
a5ab0ec
1
Parent(s):
58c2c99
debug vbert
Browse files- model_loader.py +10 -10
model_loader.py
CHANGED
@@ -189,16 +189,16 @@ def get_answer(model_loader_args, img, question, model_name):
|
|
189 |
|
190 |
elif model_name == "vbert":
|
191 |
vqa_answers = get_data(VQA_URL)
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
except Exception:
|
198 |
-
|
199 |
-
else:
|
200 |
-
|
201 |
-
|
202 |
|
203 |
elif model_name == "blip":
|
204 |
try:
|
|
|
189 |
|
190 |
elif model_name == "vbert":
|
191 |
vqa_answers = get_data(VQA_URL)
|
192 |
+
|
193 |
+
# load question and image (processor = tokenizer)
|
194 |
+
## MOD Minxuan: fix error
|
195 |
+
_, inputs = get_item(img, question, processor, "resnet50")
|
196 |
+
outputs = model(**inputs)
|
197 |
+
#except Exception:
|
198 |
+
# return err_msg()
|
199 |
+
# else:
|
200 |
+
answer_idx = torch.argmax(outputs.logits, dim=1).item() # from 3129
|
201 |
+
pred = vqa_answers[answer_idx]
|
202 |
|
203 |
elif model_name == "blip":
|
204 |
try:
|