import logging
from mpi4py import MPI
import os
import re
import subprocess
import torch

logger = logging.getLogger(__name__)


class MPIAdapter:
    """
    MPIAdapter automatically detects and analyzes the training environment for distributed training
    and offers methods to set up distributed training jobs.

    For example, it determines whether training happens on AML, Philly, or locally.
    It also determines variables such as the world size and the rank of each GPU.
    """

    def __init__(self, set_env_vars=True, master_address=None, port='55551'):
        local_address = '127.0.0.1'
        default_torch_distributed_port = str(port)  # the default of 55551 was chosen arbitrarily

        if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
            # application was started without MPI
            # default to single node with single process
            self.env_info = 'no MPI'
            self.world_size = 1
            self.local_size = 1
            self.rank = 0
            self.local_rank = 0
            self.master_address = local_address
            self.master_port = default_torch_distributed_port
        else:
            # application was started with MPI
            # get MPI parameters
            self.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
            self.local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
            self.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
            self.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])

            if master_address is not None:
                self.master_address = master_address
                self.master_port = default_torch_distributed_port
                self.env_info = 'manually set master ip'
            elif 'PHILLY_CONTAINER_IP' in os.environ:
                # application is running on Philly
                # read environment variables on master node and broadcast via MPI
                self.env_info = 'philly'
                if self.rank == 0:
                    self.master_address = os.environ['PHILLY_CONTAINER_IP']
                    self.master_port = os.environ['PHILLY_CONTAINER_PORT_RANGE_START']
                else:
                    self.master_address = None
                    self.master_port = None
                self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
                self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)
            elif "AMLK8S_NUM_WORKER" in os.environ or "AZ_CMK8S_JOB_WORK_DIR" in os.environ:
                # application is running on AMLK8S (ITP)
                # read master address from a specific file.
                self.env_info = 'AMLK8S (ITP)'
                # from: https://k8s-wiki.azureml.com/faq.html
                regexp = r"[\s\S]*export[\s]*DLTS_SD_worker0_IP=([0-9.]+)[\s|s]*"
                with open("/dlts-runtime/env/init.env", 'r') as f:
                    line = f.read()
                match = re.match(regexp, line)
                if match:
                    self.master_address = str(match.group(1))
                else:
                    # Did not find the master node IP in the file. It must be a single-node
                    # debugging job launched with a custom "mpirun" command.
                    assert self.world_size == self.local_size, \
                        "No master IP found in /dlts-runtime/env/init.env, but this is not a single-node AMLK8S (ITP) debugging job."
                    self.env_info = 'single-node AMLK8S (ITP) debugging job'
                    self.master_address = local_address
                self.master_port = default_torch_distributed_port
            elif 'AZ_BATCH_MASTER_NODE' in os.environ:
                # application is running on multiple nodes on AML
                self.env_info = 'multi-node AML'
                master_node_params = os.environ['AZ_BATCH_MASTER_NODE'].split(':')
                self.master_address = master_node_params[0]
                self.master_port = default_torch_distributed_port
            elif self.world_size == self.local_size:
                # application is running with MPI on single node
                self.env_info = 'single-node AML or other MPI environment'
                self.master_address = local_address
                self.master_port = default_torch_distributed_port
            else:
                # multi-node MPI environment, but not Philly or AML
                # we use "hostname -I" command on rank 0 to get the master address
                self.env_info = 'multi-node other MPI environment'
                if self.rank == 0:
                    hostname_cmd = ["hostname -I"]
                    result = subprocess.check_output(hostname_cmd, shell=True)
                    self.master_address = result.decode('utf-8').split()[0]
                    self.master_port = default_torch_distributed_port
                else:
                    self.master_address = None
                    self.master_port = None
                self.master_address = MPI.COMM_WORLD.bcast(self.master_address, root=0)
                self.master_port = MPI.COMM_WORLD.bcast(self.master_port, root=0)

        self.init_method_url = f'tcp://{self.master_address}:{self.master_port}'
        if set_env_vars:
            self._set_env_vars()

    def log_info(self):
        """
        Logs information about distributed training environment.
        """
        # use logger.warning because MainzTrain has a hidden convention
        # of not printing logger.info messages on processes with rank > 0
        logger.warning('----------------')
        logger.warning('MPI Adapter data')
        logger.warning('----------------')
        logger.warning(f'environment info: {self.env_info}')
        logger.warning(f'init method url: {self.init_method_url}')
        logger.warning(f'world size: {self.world_size}')
        logger.warning(f'local size: {self.local_size}')
        logger.warning(f'rank: {self.rank}')
        logger.warning(f'local rank: {self.local_rank}')
        logger.warning(f'master address: {self.master_address}')
        logger.warning(f'master port: {self.master_port}')
        logger.warning('----------------')

    def init_process_group(self, backend):
        """
        Initializes the default PyTorch distributed process group.
        """
        # use logger.warning because MainzTrain has a hidden convention
        # of not printing logger.info messages on processes with rank > 0
        logger.warning('trying to initialize process group ...')
        torch.distributed.init_process_group(backend=backend,
                                             init_method=self.init_method_url,
                                             world_size=self.world_size,
                                             rank=self.rank)
        logger.warning('process group initialized')

    def _set_env_vars(self):
        """
        Sets environment variables for world size, rank, local rank, master addr, and master port.
        """
        os.environ['WORLD_SIZE'] = str(self.world_size)
        os.environ['RANK'] = str(self.rank)
        os.environ["LOCAL_RANK"] = str(self.local_rank)
        os.environ['MASTER_ADDR'] = self.master_address
        os.environ['MASTER_PORT'] = self.master_port
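

# --- Usage sketch (illustrative, not part of the original module) ---
# A typical training entry point constructs the adapter once per process, logs the
# detected environment, and then creates the default process group. Under MPI, the
# OMPI_* variables are provided by the launcher, e.g. `mpirun -np 4 python train.py`
# (example command). The 'nccl'/'gloo' backend choice and the all_reduce sanity
# check below are assumptions for illustration only.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    adapter = MPIAdapter()  # detects MPI / Philly / AMLK8S / AML / local environment
    adapter.log_info()      # logs world size, rank, master address, etc.

    # 'nccl' assumes one GPU per process; 'gloo' works for CPU-only debugging.
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    if backend == 'nccl':
        torch.cuda.set_device(adapter.local_rank)
    adapter.init_process_group(backend)

    # Minimal sanity check: every rank contributes 1, so the reduced value equals world_size.
    device = torch.device(f'cuda:{adapter.local_rank}') if backend == 'nccl' else torch.device('cpu')
    tensor = torch.ones(1, device=device)
    torch.distributed.all_reduce(tensor)
    logger.warning(f'all_reduce result: {tensor.item()} (expected {adapter.world_size})')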