# ViVQA / Predict.py
# Author: windy2612 -- commit 10439c7 ("Update Predict.py"), 904 bytes.
from Model import BaseModel
import json
import numpy as np
from PIL import Image
from torchvision import transforms as T
import torch
# Inference runs on the CPU; map_location remaps the checkpoint's stored tensors
# onto this device when the file is loaded.
device = torch.device('cpu')

# NOTE(review): torch.load unpickles the file, so only load checkpoints that come
# from a trusted source.
checkpoint = torch.load('last_checkpoint.pt', map_location=device)

# answer.json is inverted so a predicted value can be mapped back to its key.
# Presumably it maps answer text -> integer class index -- TODO confirm against
# the file.
with open('answer.json', 'r', encoding='utf8') as f:
    answer_space = json.load(f)
swap_space = {v: k for k, v in answer_space.items()}

model = BaseModel().to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Put the model in inference mode. torch.no_grad() alone does not do this; without
# it, any dropout / batch-norm layers in BaseModel keep their training behaviour
# and predictions can vary between calls.
model.eval()
def generate_caption(image, question):
    """Answer a question about an image using the loaded VQA model.

    Despite its name, this does not produce a caption. It runs ``model`` on the
    (image, question) pair and returns the answer with the highest logit.

    Args:
        image: a NumPy array (converted with ``Image.fromarray``), a path string
            to an image file, or an already-loaded ``PIL.Image.Image``. PIL
            images are normalized to RGB so every input yields a 3-channel
            tensor.
        question: the question, passed unchanged to ``model``.

    Returns:
        The entry of ``swap_space`` for the index of the largest logit.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, str):
        # The context manager releases the file handle that Image.open holds.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    # Arrays and caller-supplied PIL images may be grayscale or RGBA; convert them
    # to RGB like the path branch so the model always sees 3 channels.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    # Add a batch dimension and place the input on the same device as the model.
    pixel_values = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(pixel_values, question)
    # argmax over the flattened logits; valid because the batch has one item.
    idx = torch.argmax(logits)
    return swap_space[idx.item()]