import logging
import os
import re
import subprocess

import torch
from mpi4py import MPI

logger = logging.getLogger(__name__)


class MPIAdapter:
    """
    MPIAdapter automatically detects and analyzes the distributed training environment
    and offers methods to set up distributed training jobs.

    For example, it determines whether training happens on AML, Philly, or locally.
    It also determines variables such as the world size and the rank of each GPU.
    """

    def __init__(self, set_env_vars=True, master_address=None, port='55551'):
        local_address = '127.0.0.1'
        default_torch_distributed_port = str(port)

        if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
            # not launched through MPI: fall back to single-process, single-GPU values
            self.env_info = 'no MPI'
            self.world_size = 1
            self.local_size = 1
            self.rank = 0
            self.local_rank = 0
            self.master_address = local_address
            self.master_port = default_torch_distributed_port
        else:
            # an MPI job: world size, local size, and ranks are provided by Open MPI
            self.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
            self.local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
            self.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            self.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

            if master_address is not None:
                self.master_address = master_address
                self.master_port = default_torch_distributed_port
                self.env_info = 'manually set master ip'
            elif 'PHILLY_CONTAINER_IP' in os.environ:
                # Philly job: rank 0 knows the master container's IP and port
                # and broadcasts them to all other ranks over MPI
                self.env_info = 'philly'
                if self.rank == 0:
                    self.master_address = os.environ['PHILLY_CONTAINER_IP']
                    self.master_port = os.environ['PHILLY_CONTAINER_PORT_RANGE_START']
                else:
                    self.master_address = None
                    self.master_port = None
                self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
                self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
elif "AMLK8S_NUM_WORKER" in os.environ or "AZ_CMK8S_JOB_WORK_DIR" in os.environ: |
|
|
|
|
|
self.env_info = 'AMLK8S (ITP)' |
|
|
|
regexp = r"[\s\S]*export[\s]*DLTS_SD_worker0_IP=([0-9.]+)[\s|s]*" |
|
with open("/dlts-runtime/env/init.env", 'r') as f: |
|
line = f.read() |
|
match = re.match(regexp, line) |
|
if match: |
|
self.master_address = str(match.group(1)) |
|
else: |
|
|
|
|
|
assert self.world_size == self.local_size, \ |
|
"It's not a single-node debugging job on AMLK8S (ITP), but no master ip is found in file." |
|
self.env_info = 'single-node AMLK8S (ITP) debugging job' |
|
self.master_address = local_address |
|
self.master_port = default_torch_distributed_port |
|
            elif 'AZ_BATCH_MASTER_NODE' in os.environ:
                # multi-node AML job: AZ_BATCH_MASTER_NODE has the form "<master ip>:<port>"
                self.env_info = 'multi-node AML'
                master_node_params = os.environ['AZ_BATCH_MASTER_NODE'].split(':')
                self.master_address = master_node_params[0]
                self.master_port = default_torch_distributed_port
            elif self.world_size == self.local_size:
                # all processes run on a single node
                self.env_info = 'single-node AML or other MPI environment'
                self.master_address = local_address
                self.master_port = default_torch_distributed_port
            else:
                # multi-node job in some other MPI environment:
                # rank 0 looks up its own IP address and broadcasts it to all ranks
                self.env_info = 'multi-node other MPI environment'
                if self.rank == 0:
                    hostname_cmd = ["hostname -I"]
                    result = subprocess.check_output(hostname_cmd, shell=True)
                    self.master_address = result.decode('utf-8').split()[0]
                    self.master_port = default_torch_distributed_port
                else:
                    self.master_address = None
                    self.master_port = None
                self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
                self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)

        self.init_method_url = f'tcp://{self.master_address}:{self.master_port}'
        if set_env_vars:
            self._set_env_vars()

    def log_info(self):
        """
        Logs information about the distributed training environment.
        """
        # warning level is used so the messages remain visible even when INFO logging is disabled
        logger.warning('----------------')
        logger.warning('MPI Adapter data')
        logger.warning('----------------')
        logger.warning(f'environment info: {self.env_info}')
        logger.warning(f'init method url: {self.init_method_url}')
        logger.warning(f'world size: {self.world_size}')
        logger.warning(f'local size: {self.local_size}')
        logger.warning(f'rank: {self.rank}')
        logger.warning(f'local rank: {self.local_rank}')
        logger.warning(f'master address: {self.master_address}')
        logger.warning(f'master port: {self.master_port}')
        logger.warning('----------------')

    def init_process_group(self, backend):
        """
        Initializes the default PyTorch distributed process group.
        """
        logger.warning('trying to initialize process group ...')
        # all ranks must call this; it performs a TCP rendezvous at init_method_url
        torch.distributed.init_process_group(backend=backend,
                                             init_method=self.init_method_url,
                                             world_size=self.world_size,
                                             rank=self.rank)
        logger.warning('process group initialized')

    def _set_env_vars(self):
        """
        Sets environment variables for world size, rank, local rank, master addr, and master port.
        """
        os.environ['WORLD_SIZE'] = str(self.world_size)
        os.environ['RANK'] = str(self.rank)
        os.environ['LOCAL_RANK'] = str(self.local_rank)
        os.environ['MASTER_ADDR'] = self.master_address
        os.environ['MASTER_PORT'] = self.master_port
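
# Since _set_env_vars() exports MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, and
# LOCAL_RANK, the process group can also be initialized through PyTorch's env://
# rendezvous instead of the explicit tcp:// URL. A minimal sketch ('gloo' is only
# an example backend):
#
#   adapter = MPIAdapter(set_env_vars=True)
#   torch.distributed.init_process_group(backend='gloo', init_method='env://')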