ReactSeq / onmt /tests /test_events.py
Oopstom's picture
Upload 313 files
c668e80 verified
raw
history blame
1.84 kB
from tensorboard.backend.event_processing import event_accumulator
from argparse import ArgumentParser
import os
class TestEvents:
def __init__(self):
stats = ["xent", "ppl", "accuracy", "tgtper", "lr"]
metrics = ["BLEU", "TER"]
self.scalars = {}
self.scalars["train"] = [("progress/" + stat) for stat in stats]
self.scalars["valid"] = [("valid/" + stat) for stat in stats]
self.scalars["valid_metrics"] = self.scalars["valid"] + [
("valid/" + metric) for metric in metrics
]
def reload_events(self, path):
ea = event_accumulator.EventAccumulator(
path,
size_guidance={event_accumulator.SCALARS: 0},
)
ea.Reload()
return ea
def check_scalars(self, scalars, logdir):
for event_file in os.listdir(logdir):
path = os.path.join(logdir, event_file)
event_accumulator = self.reload_events(path)
# make sure the scalars are in the event accumulator tags
assert all(
s in event_accumulator.Tags()["scalars"] for s in scalars
), "{} some scalars were not found in the event accumulator"
if __name__ == "__main__":
# required arguments
parser = ArgumentParser()
requiredArgs = parser.add_argument_group("required arguments")
requiredArgs.add_argument("-logdir", "--logdir", type=str, required=True)
requiredArgs.add_argument(
"-tensorboard_checks",
"--tensorboard_checks",
type=str,
required=True,
choices=["train", "valid", "valid_metrics"],
)
args = parser.parse_args()
test_event = TestEvents()
scalars = test_event.scalars[args.tensorboard_checks]
print("looking for scalars: ", scalars)
test_event.check_scalars(scalars, args.logdir)