# torch packages | |
import torch | |
from model.transformer import Transformer | |
import json | |
def main():
    """Load the Transformer hyper-parameters and trained weights, then print the model.

    Steps:
      1. Read hyper-parameters from ``params.json`` in the current working
         directory and print them.
      2. Build a ``Transformer`` on CUDA when available, otherwise on CPU.
      3. Restore the weights from ``pytorch_transformer_model.pt``.
      4. Print the resulting module.

    Raises:
        FileNotFoundError: if either ``params.json`` or the checkpoint file
            is missing.
        KeyError: if ``params.json`` lacks one of the expected keys.
    """
    # The hyper-parameters in params.json are for the Multi30K dataset.
    # Read with an explicit encoding so parsing does not depend on the
    # platform's default locale.
    with open("params.json", encoding="utf-8") as json_data:
        config = json.load(json_data)
    print(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Arguments are positional, so their order must match the Transformer
    # constructor signature.
    model = Transformer(
        config["dk"],
        config["dv"],
        config["h"],
        config["src_vocab_size"],
        config["target_vocab_size"],
        config["num_encoders"],
        config["num_decoders"],
        config["dim_multiplier"],
        config["pdropout"],
        device=device,
    )

    # map_location remaps stored tensors onto the selected device, so a
    # checkpoint saved on GPU can still be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources. On torch >= 1.13 consider passing
    # weights_only=True, since this file is used as a plain state_dict.
    state_dict = torch.load("pytorch_transformer_model.pt", map_location=device)
    model.load_state_dict(state_dict)
    print(model)


if __name__ == "__main__":
    main()