# LVM / eval_perplexity.py
# Last commit by Emma02: a858bb2 "Add application file" (3.62 kB).
"""
Evaluate perplexity on few-shot tasks. This script accepts a JSONL file
as input. Each line of the file is a JSON dictionary describing one
example in the evaluation set. The dictionary should have two keys
(by default `input` and `target`; see --json_input_key / --json_target_key):
input: a list of paths to the input images as context to the model. This
list should include the few shot examples.
target: a list of paths to the target images to evaluate perplexity
The script runs the model and reports the average perplexity over the
evaluation set. When --input_base_dir is set, every image path is joined
onto it.
"""
import os
import json
from PIL import Image
import numpy as np
import mlxu
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from .inference import MultiProcessInferenceModel
# Default values for every command-line flag, registered through
# mlxu.define_flags_with_default below.
_FLAG_DEFAULTS = dict(
    # JSONL file with one evaluation example per line.
    input_file='',
    # Checkpoint to evaluate; main() refuses to run when this is empty.
    checkpoint='',
    # Optional directory joined (os.path.join) in front of every image path.
    input_base_dir='',
    # Per-process batch size; one DataLoader batch holds
    # batch_size * model.n_processes examples.
    batch_size=2,
    # JSON keys holding the context-image and target-image path lists.
    json_input_key='input',
    json_target_key='target',
    # Forwarded unchanged to MultiProcessInferenceModel.
    dtype='float16',
    torch_devices='',
    # Number of DataLoader worker processes.
    n_workers=4,
    # Evaluate only the first N examples; 0 (or negative) means all.
    max_examples=0,
)
FLAGS, _ = mlxu.define_flags_with_default(**_FLAG_DEFAULTS)
def read_image_to_tensor(path, size=(256, 256)):
    """Load an image file as a normalized float32 RGB array.

    The image is converted to RGB, resized to ``size`` and scaled from
    [0, 255] to [0.0, 1.0]. Despite the function name, the result is a
    NumPy array, not a torch tensor.

    Args:
        path: Path (or file object) accepted by ``PIL.Image.open``.
        size: Target ``(width, height)`` passed to ``Image.resize``.
            Defaults to ``(256, 256)``, the previously hard-coded size.

    Returns:
        ``np.ndarray`` of dtype float32 with shape ``(height, width, 3)``.
    """
    # Close the underlying file as soon as the pixels are decoded instead of
    # leaving the handle open until garbage collection; DataLoader workers
    # may open many images. ``convert`` returns a fully loaded, independent
    # image, so it stays valid after the file is closed.
    with Image.open(path) as pil_im:
        rgb_im = pil_im.convert('RGB')
    input_img = rgb_im.resize(tuple(size))
    # Divide in float64 first, then cast, to keep the original numerics.
    input_img = np.array(input_img) / 255.0
    return input_img.astype(np.float32)
class MultiFrameDataset(torch.utils.data.Dataset):
    """Dataset pairing a list of context images with a list of target images.

    Item ``idx`` loads every path in ``input_files[idx]`` and
    ``target_files[idx]`` with ``read_image_to_tensor`` and stacks each group
    along a new leading axis (``np.stack(..., axis=0)``). The item is the
    tuple ``(stacked_inputs, stacked_targets)``.

    Args:
        input_files: Sequence whose elements are lists of context-image paths.
        target_files: Sequence whose elements are lists of target-image paths;
            must have the same length as ``input_files``.

    Raises:
        ValueError: If ``input_files`` and ``target_files`` differ in length.
    """

    def __init__(self, input_files, target_files):
        # Explicit check rather than ``assert`` so it still fires under
        # ``python -O``.
        if len(input_files) != len(target_files):
            raise ValueError(
                'input_files and target_files must have the same length, '
                f'got {len(input_files)} and {len(target_files)}'
            )
        self.input_files = input_files
        self.target_files = target_files

    def __len__(self):
        return len(self.input_files)

    @staticmethod
    def _load_frames(paths):
        """Read every image in ``paths`` and stack them along axis 0."""
        return np.stack([read_image_to_tensor(p) for p in paths], axis=0)

    def __getitem__(self, idx):
        return (
            self._load_frames(self.input_files[idx]),
            self._load_frames(self.target_files[idx]),
        )
def _read_examples(path, input_key, target_key, base_dir='', max_examples=0):
    """Read per-example (input paths, target paths) lists from a JSONL file.

    Blank lines are skipped. When ``base_dir`` is non-empty, every path is
    joined onto it with ``os.path.join``. When ``max_examples`` is positive,
    reading stops once that many examples have been collected.

    Returns:
        Tuple ``(input_files, target_files)`` of equal-length lists; each
        element is the list of image paths stored under ``input_key`` /
        ``target_key`` for one example.
    """
    input_files = []
    target_files = []
    with mlxu.open_file(path, 'r') as f:
        for line in f:
            if max_examples > 0 and len(input_files) >= max_examples:
                break  # Remaining lines would be discarded anyway.
            if not line.strip():
                continue  # json.loads('') would raise on blank/trailing lines.
            record = json.loads(line)
            inputs = record[input_key]
            targets = record[target_key]
            if base_dir != '':
                inputs = [os.path.join(base_dir, x) for x in inputs]
                targets = [os.path.join(base_dir, x) for x in targets]
            input_files.append(inputs)
            target_files.append(targets)
    return input_files, target_files


def main(_):
    """Print the mean perplexity of ``--checkpoint`` on ``--input_file``.

    Loads the JSONL examples and wraps them in ``MultiFrameDataset`` and a
    ``DataLoader``. Each batch goes through ``model.compute_perplexity``,
    the per-batch results are concatenated, and their mean is printed.

    Raises:
        ValueError: If ``--checkpoint`` or ``--input_file`` is empty, or if
            the input file contains no examples.
    """
    # Explicit checks rather than ``assert`` so they survive ``python -O``.
    if FLAGS.checkpoint == '':
        raise ValueError('--checkpoint must be set')
    if FLAGS.input_file == '':
        raise ValueError('--input_file must be set')
    print(f'Loading checkpoint from {FLAGS.checkpoint}')
    print(f'Evaluating input file from {FLAGS.input_file}')

    # Read the examples before loading the model so bad or empty input fails
    # fast instead of after an expensive checkpoint load.
    input_files, target_files = _read_examples(
        FLAGS.input_file,
        FLAGS.json_input_key,
        FLAGS.json_target_key,
        base_dir=FLAGS.input_base_dir,
        max_examples=FLAGS.max_examples,
    )
    if not input_files:
        # Without this, np.concatenate([]) below fails with an opaque error.
        raise ValueError(f'No examples found in {FLAGS.input_file}')

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        use_lock=True,
        perplexity_batch_size=FLAGS.batch_size,
    )
    dataset = MultiFrameDataset(input_files, target_files)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        # Presumably split across the model's processes, batch_size each
        # (see perplexity_batch_size above) — TODO confirm in inference.py.
        batch_size=FLAGS.batch_size * model.n_processes,
        shuffle=False,
        num_workers=FLAGS.n_workers,
    )
    perplexities = []
    for input_images, target_images in tqdm(data_loader, ncols=0):
        perplexity = model.compute_perplexity(input_images, target_images)
        perplexities.append(perplexity)
    perplexities = np.concatenate(perplexities, axis=0)
    print(f'Perplexity: {np.mean(perplexities)}')
# Script entry point. Because of the relative import
# ``from .inference import MultiProcessInferenceModel`` at the top of the file,
# this file must be run as a module (``python -m <package>.eval_perplexity``);
# running it directly as a script fails with "attempted relative import with
# no known parent package".
if __name__ == "__main__":
    mlxu.run(main)