File size: 1,839 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from tensorboard.backend.event_processing import event_accumulator
from argparse import ArgumentParser
import os


class TestEvents:
    """Check that TensorBoard event data contains expected scalar tags.

    ``self.scalars`` maps a check name (``"train"``, ``"valid"``,
    ``"valid_metrics"``) to the list of scalar tag names that check expects.
    """

    def __init__(self):
        stats = ["xent", "ppl", "accuracy", "tgtper", "lr"]
        metrics = ["BLEU", "TER"]
        self.scalars = {}
        # Training statistics are expected under the "progress/" prefix.
        self.scalars["train"] = [("progress/" + stat) for stat in stats]
        # Validation statistics are expected under the "valid/" prefix.
        self.scalars["valid"] = [("valid/" + stat) for stat in stats]
        # The metrics check expects the validation stats plus BLEU/TER.
        self.scalars["valid_metrics"] = self.scalars["valid"] + [
            ("valid/" + metric) for metric in metrics
        ]

    def reload_events(self, path):
        """Load the scalar events found at ``path``.

        Args:
            path: path handed to ``EventAccumulator`` (an entry of the
                log directory being checked).

        Returns:
            The ``EventAccumulator`` after ``Reload()`` has been called.
        """
        ea = event_accumulator.EventAccumulator(
            path,
            # A size guidance of 0 asks TensorBoard to keep every scalar
            # event instead of a reservoir sample.
            size_guidance={event_accumulator.SCALARS: 0},
        )
        ea.Reload()
        return ea

    def check_scalars(self, scalars, logdir):
        """Verify that every entry of ``logdir`` contains all ``scalars``.

        Args:
            scalars: scalar tag names that must be present.
            logdir: directory whose entries are each loaded via
                ``reload_events`` and checked.

        Raises:
            AssertionError: if an entry is missing one or more tags; the
                message names the entry and the missing tags.
        """
        # Sorted so the check (and any failure report) is deterministic.
        for event_file in sorted(os.listdir(logdir)):
            path = os.path.join(logdir, event_file)
            # Named ``accumulator`` so the imported module is not shadowed.
            accumulator = self.reload_events(path)
            # Build the tag set once per file instead of once per scalar.
            found = set(accumulator.Tags()["scalars"])
            missing = [s for s in scalars if s not in found]
            # Explicit raise (not ``assert``) so the check survives ``-O``.
            if missing:
                raise AssertionError(
                    "{}: some scalars were not found in the event "
                    "accumulator: {}".format(path, missing)
                )


if __name__ == "__main__":
    # Command-line interface: both options are mandatory.
    arg_parser = ArgumentParser()
    required_group = arg_parser.add_argument_group("required arguments")
    required_group.add_argument("-logdir", "--logdir", type=str, required=True)
    required_group.add_argument(
        "-tensorboard_checks",
        "--tensorboard_checks",
        type=str,
        required=True,
        choices=["train", "valid", "valid_metrics"],
    )
    cli_args = arg_parser.parse_args()

    # Resolve the expected tags for the requested check, then verify them
    # against every entry of the log directory.
    checker = TestEvents()
    expected_tags = checker.scalars[cli_args.tensorboard_checks]
    print("looking for scalars: ", expected_tags)
    checker.check_scalars(expected_tags, cli_args.logdir)