import logging
import os
import re
import subprocess

import torch
from mpi4py import MPI

logger = logging.getLogger(__name__)


class MPIAdapter:
    """
    MPIAdapter automatically detects and analyzes the environment for distributed training
    and offers methods to set up distributed training jobs.

    For example, it determines whether training happens on AML, Philly, or locally.
    It also determines variables such as the world size and the rank of each GPU.
    """

    def __init__(self, set_env_vars=True, master_address=None, port='55551'):
        local_address = '127.0.0.1'
        default_torch_distributed_port = str(port)  # chosen arbitrarily

        if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
            # application was started without MPI
            # default to a single node with a single process
            self.env_info = 'no MPI'
            self.world_size = 1
            self.local_size = 1
            self.rank = 0
            self.local_rank = 0
            self.master_address = local_address
            self.master_port = default_torch_distributed_port
        else:
            # application was started with MPI
            # get MPI parameters
            self.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
            self.local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
            self.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            self.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

            if master_address is not None:
                self.master_address = master_address
                self.master_port = default_torch_distributed_port
                self.env_info = 'manually set master ip'
            elif 'PHILLY_CONTAINER_IP' in os.environ:
                # application is running on Philly
                # read environment variables on the master node and broadcast via MPI
                self.env_info = 'philly'
                if self.rank == 0:
                    self.master_address = os.environ['PHILLY_CONTAINER_IP']
                    self.master_port = os.environ['PHILLY_CONTAINER_PORT_RANGE_START']
                else:
                    self.master_address = None
                    self.master_port = None
                self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
                self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
            elif 'AMLK8S_NUM_WORKER' in os.environ or 'AZ_CMK8S_JOB_WORK_DIR' in os.environ:
                # application is running on AMLK8S (ITP)
                # read the master address from a specific file
                # from: https://k8s-wiki.azureml.com/faq.html
                self.env_info = 'AMLK8S (ITP)'
                regexp = r"[\s\S]*export[\s]*DLTS_SD_worker0_IP=([0-9.]+)[\s\S]*"
                with open("/dlts-runtime/env/init.env", 'r') as f:
                    line = f.read()
                match = re.match(regexp, line)
                if match:
                    self.master_address = str(match.group(1))
                    self.master_port = default_torch_distributed_port
                else:
                    # Did not find the master node ip in the file. It must be a single-node
                    # debugging job with a custom "mpirun" command.
                    assert self.world_size == self.local_size, \
                        "No master ip found in file, and this is not a single-node debugging job on AMLK8S (ITP)."
                    self.env_info = 'single-node AMLK8S (ITP) debugging job'
                    self.master_address = local_address
                    self.master_port = default_torch_distributed_port
            elif 'AZ_BATCH_MASTER_NODE' in os.environ:
                # application is running on multiple nodes on AML
                self.env_info = 'multi-node AML'
                master_node_params = os.environ['AZ_BATCH_MASTER_NODE'].split(':')
                self.master_address = master_node_params[0]
                self.master_port = default_torch_distributed_port
            elif self.world_size == self.local_size:
                # application is running with MPI on a single node
                self.env_info = 'single-node AML or other MPI environment'
                self.master_address = local_address
                self.master_port = default_torch_distributed_port
            else:
                # multi-node MPI environment, but not Philly or AML
                # use the "hostname -I" command on rank 0 to get the master address
                self.env_info = 'multi-node other MPI environment'
                if self.rank == 0:
                    hostname_cmd = ["hostname -I"]
                    result = subprocess.check_output(hostname_cmd, shell=True)
                    self.master_address = result.decode('utf-8').split()[0]
                    self.master_port = default_torch_distributed_port
                else:
                    self.master_address = None
                    self.master_port = None
                self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
                self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)

        self.init_method_url = f'tcp://{self.master_address}:{self.master_port}'

        if set_env_vars:
            self._set_env_vars()

    def log_info(self):
        """
        Logs information about the distributed training environment.
        """
        # use logger.warning because MainzTrain has a hidden convention
        # of not printing logger.info messages on processes with rank > 0
        logger.warning('----------------')
        logger.warning('MPI Adapter data')
        logger.warning('----------------')
        logger.warning(f'environment info: {self.env_info}')
        logger.warning(f'init method url: {self.init_method_url}')
        logger.warning(f'world size: {self.world_size}')
        logger.warning(f'local size: {self.local_size}')
        logger.warning(f'rank: {self.rank}')
        logger.warning(f'local rank: {self.local_rank}')
        logger.warning(f'master address: {self.master_address}')
        logger.warning(f'master port: {self.master_port}')
        logger.warning('----------------')

    def init_process_group(self, backend):
        """
        Initializes the default PyTorch distributed process group.
        """
        # use logger.warning because MainzTrain has a hidden convention
        # of not printing logger.info messages on processes with rank > 0
        logger.warning('trying to initialize process group ...')
        torch.distributed.init_process_group(backend=backend,
                                             init_method=self.init_method_url,
                                             world_size=self.world_size,
                                             rank=self.rank)
        logger.warning('process group initialized')

    def _set_env_vars(self):
        """
        Sets environment variables for world size, rank, local rank,
        master address, and master port.
        """
        os.environ['WORLD_SIZE'] = str(self.world_size)
        os.environ['RANK'] = str(self.rank)
        os.environ['LOCAL_RANK'] = str(self.local_rank)
        os.environ['MASTER_ADDR'] = self.master_address
        os.environ['MASTER_PORT'] = self.master_port
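

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original module).
# Assumptions: the script is launched either stand-alone or under mpirun /
# Open MPI, 'nccl' is the appropriate backend when CUDA is available (falling
# back to 'gloo' otherwise), and the all-reduce sanity check below is purely
# for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # Detect the environment (no MPI, Philly, AMLK8S/ITP, AML, or generic MPI)
    # and export WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT.
    adapter = MPIAdapter()
    adapter.log_info()

    # Initialize torch.distributed using the detected init method URL.
    adapter.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo')

    # Minimal sanity check: every rank contributes 1, so the all-reduced sum
    # should equal the world size on every process.
    if torch.cuda.is_available():
        torch.cuda.set_device(adapter.local_rank)
        device = torch.device('cuda', adapter.local_rank)
    else:
        device = torch.device('cpu')
    ones = torch.ones(1, device=device)
    torch.distributed.all_reduce(ones)
    logger.warning(f'all-reduce result on rank {adapter.rank}: {int(ones.item())} '
                   f'(expected {adapter.world_size})')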