#!/usr/bin/env python
"""Train models with dynamic data."""
import torch
from functools import partial
from onmt.utils.distributed import ErrorHandler, spawned_train
from onmt.utils.misc import set_random_seed
from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser
from onmt.opts import train_opts
from onmt.train_single import main as single_main


# Set sharing strategy manually instead of default based on the OS.
# torch.multiprocessing.set_sharing_strategy('file_system')


def _launch_workers(run_fn, opt, worker_count):
    """Start one spawned process per device index and wait for all of them.

    Args:
        run_fn: training callable handed to ``spawned_train`` in each child.
        opt: parsed training options, forwarded unchanged to every child.
        worker_count: number of child processes to start; device indices
            ``0 .. worker_count - 1`` are assigned in order.
    """
    ctx = torch.multiprocessing.get_context("spawn")
    # Create a thread to listen for errors in the child processes.
    error_queue = ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    # Train with multiprocessing.
    workers = []
    for device_id in range(worker_count):
        worker = ctx.Process(
            target=spawned_train,
            args=(run_fn, opt, device_id, error_queue),
            daemon=False,
        )
        workers.append(worker)
        worker.start()
        logger.info(" Starting process pid: %d " % worker.pid)
        # Register the child's pid with the error handler.
        error_handler.add_child(worker.pid)

    # Block until every child process has exited.
    for worker in workers:
        worker.join()


def train(opt):
    """Validate the options, seed the RNG and run training.

    The execution mode is chosen from ``opt``:

    * ``opt.world_size > 1``: one spawned process per entry of
      ``opt.gpu_ranks``;
    * exactly one entry in ``opt.gpu_ranks``: in-process run with
      ``device_id=0``;
    * otherwise: in-process run with ``device_id=-1`` (CPU).

    Args:
        opt: parsed training options (as produced by the parser built in
            ``_get_parser``).
    """
    init_logger(opt.log_file)

    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    set_random_seed(opt.seed, False)

    run_fn = partial(single_main)
    gpu_count = len(opt.gpu_ranks)

    if opt.world_size > 1:
        _launch_workers(run_fn, opt, gpu_count)
        return

    # Single-process run: device 0 for the one-GPU case, -1 for CPU only.
    device_id = 0 if gpu_count == 1 else -1
    run_fn(opt, device_id=device_id)


def _get_parser():
    """Build the command-line parser populated with the training options."""
    parser = ArgumentParser(description="train.py")
    train_opts(parser)
    return parser


def main():
    """Parse the command line and start training."""
    # parse_known_args is used, so unrecognized arguments are collected
    # and then ignored rather than causing a parse error.
    opt, _unrecognized = _get_parser().parse_known_args()
    train(opt)


if __name__ == "__main__":
    main()