|
from tensorboard.backend.event_processing import event_accumulator |
|
from argparse import ArgumentParser |
|
import os |
|
|
|
|
|
class TestEvents:
    """Verify that TensorBoard event files contain an expected set of scalar tags.

    ``self.scalars`` maps a check name ("train", "valid", "valid_metrics") to
    the list of scalar tag names that must be present for that check.
    """

    def __init__(self):
        stats = ["xent", "ppl", "accuracy", "tgtper", "lr"]
        metrics = ["BLEU", "TER"]

        self.scalars = {}
        # The "train" check expects every stat under the "progress/" prefix.
        self.scalars["train"] = [("progress/" + stat) for stat in stats]
        # The "valid" check expects the same stats under the "valid/" prefix.
        self.scalars["valid"] = [("valid/" + stat) for stat in stats]
        # The "valid_metrics" check expects the validation stats plus one
        # "valid/<METRIC>" tag per metric.
        self.scalars["valid_metrics"] = self.scalars["valid"] + [
            ("valid/" + metric) for metric in metrics
        ]

    def reload_events(self, path):
        """Build an EventAccumulator for ``path``, load it, and return it.

        Args:
            path: Path handed to ``EventAccumulator``. ``check_scalars`` passes
                each entry of its log directory here.

        Returns:
            The accumulator after ``Reload()`` has been called on it.
        """
        ea = event_accumulator.EventAccumulator(
            path,
            # A size guidance of 0 for scalars is intended to keep every scalar
            # event (TensorBoard convention; confirm for the installed version).
            size_guidance={event_accumulator.SCALARS: 0},
        )
        ea.Reload()
        return ea

    def check_scalars(self, scalars, logdir):
        """Check that every entry in ``logdir`` logs all tags in ``scalars``.

        Args:
            scalars: Iterable of scalar tag names that must be present.
            logdir: Directory whose entries are each loaded with
                ``reload_events``. NOTE(review): every entry is treated as an
                event path; this assumes logdir holds only event files.

        Raises:
            AssertionError: If ``logdir`` has no entries, or if any entry is
                missing one of the expected tags. The error is raised explicitly
                instead of via ``assert`` so the check still runs under
                ``python -O``.
        """
        # Sorting makes the order in which files are checked deterministic.
        event_files = sorted(os.listdir(logdir))
        if not event_files:
            # An empty directory would otherwise pass without checking anything.
            raise AssertionError("no event files found in {}".format(logdir))

        for event_file in event_files:
            path = os.path.join(logdir, event_file)
            # Named so as not to shadow the imported `event_accumulator` module.
            accumulator = self.reload_events(path)
            found = set(accumulator.Tags()["scalars"])
            missing = [s for s in scalars if s not in found]
            if missing:
                raise AssertionError(
                    "{} some scalars were not found in the event accumulator: {}".format(
                        path, missing
                    )
                )
|
|
|
|
|
if __name__ == "__main__":
    # Command line: the directory holding the event files, plus which named
    # group of scalar tags (see TestEvents.scalars) must be present in them.
    arg_parser = ArgumentParser()
    required_group = arg_parser.add_argument_group("required arguments")
    required_group.add_argument("-logdir", "--logdir", type=str, required=True)
    required_group.add_argument(
        "-tensorboard_checks",
        "--tensorboard_checks",
        type=str,
        required=True,
        choices=["train", "valid", "valid_metrics"],
    )
    cli_args = arg_parser.parse_args()

    # Look up the expected tag list for the chosen check and verify every
    # entry of the log directory against it.
    checker = TestEvents()
    expected_scalars = checker.scalars[cli_args.tensorboard_checks]
    print("looking for scalars: ", expected_scalars)
    checker.check_scalars(expected_scalars, cli_args.logdir)
|
|