""" Statistics calculation utility """
import time
import math
import sys

from onmt.utils.logging import logger


class Statistics(object):
    """
    Accumulator for loss statistics.
    Currently calculates:

    * accuracy
    * perplexity
    * elapsed time
    """

    def __init__(
        self,
        loss=0,
        n_batchs=0,
        n_sents=0,
        n_words=0,
        n_correct=0,
        computed_metrics=None,
    ):
        """
        Args:
            loss (float): accumulated loss; divided by ``n_words`` in
                ``xent``/``ppl``, so presumably summed over target words.
            n_batchs (int): number of batches accumulated.
            n_sents (int): number of sentences accumulated.
            n_words (int): number of target words accumulated.
            n_correct (int): number of correctly predicted target words.
            computed_metrics (dict, optional): extra metrics (name -> value)
                reported by ``output``/``log_tensorboard``. Defaults to a
                new empty dict for each instance.
        """
        self.loss = loss
        self.n_batchs = n_batchs
        self.n_sents = n_sents
        self.n_words = n_words
        self.n_correct = n_correct
        # Not a constructor argument: starts at 0 and is only summed in
        # ``update`` when ``update_n_src_words`` is True.
        self.n_src_words = 0
        # A fresh dict per instance: a ``{}`` default would be shared by
        # every Statistics created without this argument.
        self.computed_metrics = {} if computed_metrics is None else computed_metrics
        # Reference point for ``elapsed_time``.
        self.start_time = time.time()

    @staticmethod
    def all_gather_stats(stat, max_size=4096):
        """
        Gather a `Statistics` object across multiple processes/nodes.

        Args:
            stat(:obj:Statistics): the statistics object to gather
                across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            `Statistics`, the updated stats object
        """
        stats = Statistics.all_gather_stats_list([stat], max_size=max_size)
        return stats[0]

    @staticmethod
    def all_gather_stats_list(stat_list, max_size=4096):
        """
        Gather a `Statistics` list across all processes/nodes.

        Args:
            stat_list(list([`Statistics`])): list of statistics objects to
                gather across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            our_stats(list([`Statistics`])): list of updated stats
        """
        # Imported lazily so the module can be used without initializing
        # torch.distributed.
        from torch.distributed import get_rank
        from onmt.utils.distributed import all_gather_list

        # Get a list of world_size lists with len(stat_list) Statistics objects
        all_stats = all_gather_list(stat_list, max_size=max_size)

        our_rank = get_rank()
        our_stats = all_stats[our_rank]
        # Fold every other rank's stats into ours, position by position.
        for other_rank, stats in enumerate(all_stats):
            if other_rank == our_rank:
                continue
            for i, stat in enumerate(stats):
                our_stats[i].update(stat, update_n_src_words=True)
        return our_stats

    def update(self, stat, update_n_src_words=False):
        """
        Update statistics by summing values with another `Statistics` object.

        Note that ``computed_metrics`` is not summed: it is replaced by
        ``stat.computed_metrics``.

        Args:
            stat: another statistic object
            update_n_src_words(bool): whether to update (sum) `n_src_words`
                or not
        """
        self.loss += stat.loss
        self.n_batchs += stat.n_batchs
        self.n_sents += stat.n_sents
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct
        self.computed_metrics = stat.computed_metrics

        if update_n_src_words:
            self.n_src_words += stat.n_src_words

    def accuracy(self):
        """compute accuracy (percentage); raises ZeroDivisionError if
        n_words is 0"""
        return 100 * (self.n_correct / self.n_words)

    def xent(self):
        """compute cross entropy (loss per target word)"""
        return self.loss / self.n_words

    def ppl(self):
        """compute perplexity"""
        # Exponent capped at 100 so math.exp cannot overflow.
        return math.exp(min(self.loss / self.n_words, 100))

    def elapsed_time(self):
        """compute elapsed time in seconds since construction"""
        return time.time() - self.start_time

    def output(self, step, num_steps, learning_rate, start):
        """Write out statistics to stdout.

        Args:
            step (int): current step
            num_steps (int): total number of steps; appended to the step
                display only when > 0
            learning_rate (float): current learning rate
            start (float): start time of step, as returned by time.time()
        """
        t = self.elapsed_time()
        step_fmt = "%2d" % step
        if num_steps > 0:
            step_fmt = "%s/%5d" % (step_fmt, num_steps)
        logger.info(
            (
                "Step %s; acc: %2.1f; ppl: %5.1f; xent: %2.1f; "
                + "lr: %7.5f; sents: %7.0f; bsz: %4.0f/%4.0f/%2.0f; "
                + "%3.0f/%3.0f tok/s; %6.0f sec;"
            )
            % (
                step_fmt,
                self.accuracy(),
                self.ppl(),
                self.xent(),
                learning_rate,
                self.n_sents,
                # Average src words / tgt words / sentences per batch.
                self.n_src_words / self.n_batchs,
                self.n_words / self.n_batchs,
                self.n_sents / self.n_batchs,
                # Throughput; epsilon avoids division by zero when t == 0.
                self.n_src_words / (t + 1e-5),
                self.n_words / (t + 1e-5),
                time.time() - start,
            )
            + "".join(
                [
                    " {}: {}".format(k, round(v, 2))
                    for k, v in self.computed_metrics.items()
                ]
            )
        )
        sys.stdout.flush()

    def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
        """display statistics to tensorboard

        Args:
            prefix (str): tag prefix for every scalar (e.g. a phase name)
            writer: object exposing ``add_scalar(tag, value, step)``
                (presumably a TensorBoard SummaryWriter)
            learning_rate (float): current learning rate
            patience: early-stopping patience value; skipped when None
            step (int): global step for the logged scalars
        """
        t = self.elapsed_time()
        writer.add_scalar(prefix + "/xent", self.xent(), step)
        writer.add_scalar(prefix + "/ppl", self.ppl(), step)
        for k, v in self.computed_metrics.items():
            writer.add_scalar(prefix + "/" + k, round(v, 4), step)
        writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
        # Same epsilon as ``output``: the clock may not have advanced yet.
        writer.add_scalar(prefix + "/tgtper", self.n_words / (t + 1e-5), step)
        writer.add_scalar(prefix + "/lr", learning_rate, step)
        if patience is not None:
            writer.add_scalar(prefix + "/patience", patience, step)