# NOTE: file-viewer residue (Space status banner, "File size: 3,622 Bytes",
# revision a858bb2, and the 1-127 line-number gutter) appeared here. It is not
# part of the module and has been reduced to this comment.
"""
Evaluate perplexity on few-shot tasks. This script accepts a JSONL file as
input. Each line of the file is a JSON dictionary representing one example in
the evaluation set. The dictionary should have two keys:
    input: a list of paths to the input images used as context for the model.
        This list should include the few-shot examples.
    target: a list of paths to the target images on which perplexity is
        evaluated.
This script runs the model and computes the average perplexity over the
evaluation set.
"""
import os
import json
from PIL import Image
import numpy as np
import mlxu
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from .inference import MultiProcessInferenceModel
# Command-line flags for the evaluation script, with their default values.
FLAGS, _ = mlxu.define_flags_with_default(
    # Path of the JSONL evaluation file; opened with mlxu.open_file in main().
    input_file='',
    # Checkpoint handed to MultiProcessInferenceModel; main() asserts it is
    # non-empty.
    checkpoint='',
    # Optional directory joined in front of every image path read from the
    # JSONL file; the empty string leaves the paths untouched.
    input_base_dir='',
    # Passed to the model as perplexity_batch_size; the DataLoader batch size
    # is this value multiplied by model.n_processes.
    batch_size=2,
    # Keys looked up in each JSONL record for the context and target path lists.
    json_input_key='input',
    json_target_key='target',
    # Forwarded unchanged to MultiProcessInferenceModel.
    dtype='float16',
    # Forwarded unchanged to MultiProcessInferenceModel.
    torch_devices='',
    # Number of DataLoader worker processes (num_workers).
    n_workers=4,
    # When greater than zero, only the first max_examples records are evaluated.
    max_examples=0,
)
def read_image_to_tensor(path):
    """Load an image file as a 256x256 RGB float32 array scaled to [0, 1].

    Despite the name, the result is a NumPy array, not a ``torch.Tensor``.

    Args:
        path: Location of the image, in any form ``PIL.Image.open`` accepts.

    Returns:
        A ``numpy.ndarray`` of dtype ``float32`` holding the image converted
        to RGB, resized to 256x256, with pixel values divided by 255.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle even when decoding raises.
    with Image.open(path) as pil_im:
        # convert() loads the pixel data and returns a new image, so the
        # result remains usable after the source file is closed.
        rgb_im = pil_im.convert('RGB')
    input_img = rgb_im.resize((256, 256))
    # Divide first (float64) and cast afterwards, exactly as before, so the
    # produced pixel values are unchanged.
    input_img = np.array(input_img) / 255.0
    return input_img.astype(np.float32)
class MultiFrameDataset(torch.utils.data.Dataset):
    """Map-style dataset of (context images, target images) pairs.

    ``input_files`` and ``target_files`` are parallel sequences: element ``i``
    of each is the list of image paths belonging to example ``i``.
    """

    def __init__(self, input_files, target_files):
        """Store the two parallel path lists; they must have equal length."""
        n_inputs, n_targets = len(input_files), len(target_files)
        assert n_inputs == n_targets
        self.input_files = input_files
        self.target_files = target_files

    def __len__(self):
        """Return the number of examples."""
        return len(self.input_files)

    def __getitem__(self, idx):
        """Return the stacked context and target images of example ``idx``.

        Each path is loaded with ``read_image_to_tensor`` and the images of
        one list are stacked along a new leading axis.
        """
        def load_stack(paths):
            frames = list(map(read_image_to_tensor, paths))
            return np.stack(frames, axis=0)

        # Tuple elements are evaluated left to right, so the context images
        # are fully loaded and stacked before any target image is read.
        return load_stack(self.input_files[idx]), load_stack(self.target_files[idx])
def main(_):
    """Compute and print the mean perplexity over a JSONL evaluation set.

    Every non-blank line of ``FLAGS.input_file`` is parsed as a JSON object
    whose ``FLAGS.json_input_key`` and ``FLAGS.json_target_key`` entries are
    lists of image paths. The paths are optionally prefixed with
    ``FLAGS.input_base_dir``, the images are batched through a ``DataLoader``
    and scored with ``MultiProcessInferenceModel.compute_perplexity``; the
    mean of all returned values is printed.

    Args:
        _: Positional argument supplied by ``mlxu.run``; unused.

    Raises:
        AssertionError: If ``FLAGS.checkpoint`` is empty.
        ValueError: If the input file contains no examples.
    """
    assert FLAGS.checkpoint != '', 'A checkpoint must be provided'
    print(f'Loading checkpoint from {FLAGS.checkpoint}')
    print(f'Evaluating input file from {FLAGS.input_file}')

    # The example list is read before the model is constructed so that an
    # unusable input file is reported without first building the model.
    input_files = []
    target_files = []
    with mlxu.open_file(FLAGS.input_file, 'r') as f:
        for line in f:
            # Stop as soon as the requested number of examples is collected
            # instead of parsing the whole file and slicing afterwards.
            if 0 < FLAGS.max_examples <= len(input_files):
                break
            # Tolerate blank lines (e.g. a trailing newline at end of file),
            # which json.loads would otherwise reject.
            if not line.strip():
                continue
            record = json.loads(line)
            input_files.append(record[FLAGS.json_input_key])
            target_files.append(record[FLAGS.json_target_key])

    if not input_files:
        # Previously this surfaced later as an opaque np.concatenate error.
        raise ValueError(f'No examples found in {FLAGS.input_file}')

    if FLAGS.input_base_dir != '':
        input_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in input_files
        ]
        target_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in target_files
        ]

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        use_lock=True,
        perplexity_batch_size=FLAGS.batch_size,
    )

    dataset = MultiFrameDataset(input_files, target_files)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        # One model-sized batch per inference process in every loader step.
        batch_size=FLAGS.batch_size * model.n_processes,
        shuffle=False,
        num_workers=FLAGS.n_workers
    )

    perplexities = []
    for input_images, target_images in tqdm(data_loader, ncols=0):
        # NOTE(review): compute_perplexity presumably returns one value per
        # example as an array; confirm against the inference module.
        perplexity = model.compute_perplexity(input_images, target_images)
        perplexities.append(perplexity)

    perplexities = np.concatenate(perplexities, axis=0)
    print(f'Perplexity: {np.mean(perplexities)}')
# Script entry point: let mlxu parse the command-line flags and invoke main().
if __name__ == "__main__":
    mlxu.run(main)