Lim0011's picture
Upload 251 files
85e3d20 verified
raw
history blame
1.05 kB
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.loader import NeighborLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import MessagePassing, SAGEConv
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
import pandas as pd
import numpy as np
def get_score(submission_folder = "../env"):
submission_path = os.path.join(submission_folder, "submission.csv")
submission = pd.read_csv(submission_path)
target_dataset = 'ogbn-arxiv'
dataset = PygNodePropPredDataset(name=target_dataset, root='networks')
data = dataset[0]
split_idx = dataset.get_idx_split()
test_idx = split_idx['test']
evaluator = Evaluator(name=target_dataset)
y_true = data.y.cpu()
submission = torch.tensor(np.array(submission))
test_acc = evaluator.eval({
'y_true': y_true[test_idx],
'y_pred': submission,
})['acc']
return test_acc
if __name__ == "__main__":
print(get_score())