File size: 3,622 Bytes
a858bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Evaluate the perplexity of a model on few-shot tasks. This script accepts a
jsonl file as input. Each line of the jsonl file is a JSON dictionary that
represents one example in the evaluation set. The dictionary should have two
keys (their names can be changed with --json_input_key / --json_target_key):

    input: a list of paths to the input images given to the model as context.
        This list should include the few-shot examples.
    target: a list of paths to the target images whose perplexity is
        evaluated.

Relative paths can be resolved against --input_base_dir. The script runs the
model and prints the average perplexity over the evaluation set.
"""

import os
import json
from PIL import Image
import numpy as np
import mlxu
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops

from .inference import MultiProcessInferenceModel


# Command-line flags and their defaults. Presumably mlxu infers each flag's
# type from its default value — confirm against mlxu.define_flags_with_default.
FLAGS, _ = mlxu.define_flags_with_default(
    # Path of the jsonl evaluation file (read with mlxu.open_file in main).
    input_file='',
    # Checkpoint to load into the inference model; required (main rejects
    # an empty value).
    checkpoint='',
    # Optional directory joined onto every input/target image path read from
    # the jsonl file; '' leaves the paths unchanged.
    input_base_dir='',
    # Per-process perplexity batch size; the DataLoader batch size is this
    # value multiplied by the model's n_processes.
    batch_size=2,
    # Record keys holding the list of input paths and target paths.
    json_input_key='input',
    json_target_key='target',
    # Forwarded unchanged to MultiProcessInferenceModel.
    dtype='float16',
    torch_devices='',
    # Number of DataLoader worker processes used to load images.
    n_workers=4,
    # Evaluate only the first N examples; 0 (or any value <= 0) means all.
    max_examples=0,
)


def read_image_to_tensor(path, size=(256, 256)):
    """Load an image file as a float32 RGB array scaled to [0, 1].

    Despite the name, the result is a numpy array, not a torch tensor.

    Args:
        path: Path of the image file to read.
        size: ``(width, height)`` target passed to PIL ``Image.resize``.
            Defaults to the original fixed 256x256.

    Returns:
        A float32 numpy array of shape ``(height, width, 3)`` with values
        in ``[0, 1]``.
    """
    # Open the image in a context manager so the file handle is released
    # deterministically instead of whenever the image object is collected.
    # convert() returns a new, fully loaded image, so it outlives the block.
    with Image.open(path) as pil_im:
        resized = pil_im.convert('RGB').resize(size)
    # Divide in float64 first, then cast, exactly as before, so pixel values
    # are bit-identical to the previous implementation.
    return (np.array(resized) / 255.0).astype(np.float32)


class MultiFrameDataset(torch.utils.data.Dataset):
    """Map-style dataset of (input frames, target frames) example pairs.

    Example ``i`` is described by ``input_files[i]`` and ``target_files[i]``,
    each a list of image paths. Indexing loads every image in both lists
    with ``read_image_to_tensor`` and stacks each list along a new axis 0.
    """

    def __init__(self, input_files, target_files):
        # Inputs and targets are paired by position, so the counts must agree.
        assert len(input_files) == len(target_files)
        self.input_files = input_files
        self.target_files = target_files

    def __len__(self):
        n_examples = len(self.input_files)
        return n_examples

    @staticmethod
    def _load_stack(paths):
        """Read each image in ``paths`` and stack them into one array."""
        frames = list(map(read_image_to_tensor, paths))
        return np.stack(frames, axis=0)

    def __getitem__(self, idx):
        context_frames = self._load_stack(self.input_files[idx])
        target_frames = self._load_stack(self.target_files[idx])
        return context_frames, target_frames


def _read_jsonl_pairs(path, input_key, target_key, max_examples=0):
    """Read paired input/target path lists from a jsonl file.

    Args:
        path: jsonl file to read (opened with ``mlxu.open_file``).
        input_key: Record key holding the list of input image paths.
        target_key: Record key holding the list of target image paths.
        max_examples: Stop after this many records; 0 or less reads all.

    Returns:
        ``(input_files, target_files)``: two equally long lists whose i-th
        entries come from the i-th non-blank line of the file.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``input_key`` or ``target_key``.
    """
    input_files = []
    target_files = []
    with mlxu.open_file(path, 'r') as f:
        for line in f:
            # Skip blank lines (e.g. a stray empty line at the end of the
            # file), which json.loads would otherwise reject.
            if not line.strip():
                continue
            record = json.loads(line)
            input_files.append(record[input_key])
            target_files.append(record[target_key])
            # Stop as soon as enough examples are collected rather than
            # parsing the whole file and truncating afterwards.
            if 0 < max_examples <= len(input_files):
                break
    return input_files, target_files


def _prefix_paths(base_dir, path_lists):
    """Return ``path_lists`` with ``base_dir`` joined onto every path."""
    return [[os.path.join(base_dir, p) for p in paths] for paths in path_lists]


def main(_):
    """Compute and print the mean perplexity of a checkpoint on a jsonl set.

    Flags used:
        - ``checkpoint`` and ``input_file``: both required.
        - ``input_base_dir``: joined onto every input/target path.
        - ``json_input_key`` / ``json_target_key``: record keys.
        - ``max_examples``: evaluate at most this many examples.
        - ``batch_size``: per-process batch size.
        - ``dtype``, ``torch_devices``: forwarded to the model.
        - ``n_workers``: DataLoader worker processes.

    Raises:
        ValueError: If ``--checkpoint`` or ``--input_file`` is empty.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if FLAGS.checkpoint == '':
        raise ValueError('--checkpoint must be specified')
    if FLAGS.input_file == '':
        raise ValueError('--input_file must be specified')

    print(f'Loading checkpoint from {FLAGS.checkpoint}')
    print(f'Evaluating input file from {FLAGS.input_file}')

    # Load the evaluation set before constructing the inference model so a
    # malformed or empty input file fails fast.
    input_files, target_files = _read_jsonl_pairs(
        FLAGS.input_file,
        FLAGS.json_input_key,
        FLAGS.json_target_key,
        max_examples=FLAGS.max_examples,
    )
    if FLAGS.input_base_dir != '':
        input_files = _prefix_paths(FLAGS.input_base_dir, input_files)
        target_files = _prefix_paths(FLAGS.input_base_dir, target_files)

    if not input_files:
        # np.concatenate below raises on an empty list; report and stop.
        print('No examples to evaluate.')
        return

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        use_lock=True,
        perplexity_batch_size=FLAGS.batch_size,
    )

    dataset = MultiFrameDataset(input_files, target_files)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        # batch_size is per model process; presumably each loader batch is
        # divided among the model's processes — confirm in inference.py.
        batch_size=FLAGS.batch_size * model.n_processes,
        shuffle=False,
        num_workers=FLAGS.n_workers,
    )

    perplexities = []
    for input_images, target_images in tqdm(data_loader, ncols=0):
        # Assumed to return an array with one entry per example in the
        # batch (concatenated along axis 0 below) — confirm in inference.py.
        perplexities.append(
            model.compute_perplexity(input_images, target_images)
        )

    # Concatenate the per-batch outputs and average over all entries.
    perplexities = np.concatenate(perplexities, axis=0)
    print(f'Perplexity: {np.mean(perplexities)}')


# Script entry point: mlxu.run is handed `main` (presumably it parses the
# command-line flags defined above before invoking it — confirm in mlxu).
if __name__ == "__main__":
    mlxu.run(main)