|
""" Pytorch Distributed utils |
|
This piece of code was heavily inspired by the equivalent of Fairseq-py |
|
https://github.com/pytorch/fairseq |
|
""" |
|
import os |
|
import signal |
|
import math |
|
import pickle |
|
import torch.distributed |
|
from datetime import timedelta |
|
from onmt.translate.translator import build_translator |
|
from onmt.transforms import get_transforms_cls |
|
from onmt.constants import CorpusTask |
|
from onmt.utils.logging import init_logger, logger |
|
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter |
|
from onmt.inputters.inputter import IterOnDevice |
|
|
|
|
|
def is_master(opt, device_id):
    """Tell whether the process attached to `device_id` holds global rank 0.

    Args:
        opt: options object exposing `gpu_ranks`, a sequence indexed by
            device id.
        device_id: index into `opt.gpu_ranks` for the current process.

    Returns:
        The result of comparing the looked-up rank with 0.
    """
    rank = opt.gpu_ranks[device_id]
    return rank == 0
|
|
|
|
|
def multi_init(opt, device_id):
    """Join the torch.distributed process group for this worker.

    The rendezvous address is a TCP URL built from `opt.master_ip` and
    `opt.master_port`; the group is created with a 60 second timeout.
    Processes that are not the master (see `is_master`) get the shared
    `logger` disabled.

    Args:
        opt: options object providing `master_ip`, `master_port`,
            `world_size`, `gpu_backend` and `gpu_ranks`.
        device_id: index into `opt.gpu_ranks` for the current process.

    Returns:
        The rank reported by `torch.distributed.get_rank()` once the group
        is initialised.
    """
    rendezvous_url = f"tcp://{opt.master_ip}:{opt.master_port}"
    n_processes = opt.world_size

    torch.distributed.init_process_group(
        backend=opt.gpu_backend,
        init_method=rendezvous_url,
        world_size=n_processes,
        rank=opt.gpu_ranks[device_id],
        timeout=timedelta(seconds=60),
    )
    assigned_rank = torch.distributed.get_rank()

    # Silence the logger on every process except the master one.
    on_master = is_master(opt, device_id)
    if not on_master:
        logger.disabled = True

    return assigned_rank
|
|
|
|
|
def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=104857600):
    """All-reduce and rescale tensors in chunks of the specified size.

    Tensors are packed back to back into one flat staging buffer so that
    several of them share a single all-reduce call. A tensor whose byte size
    exceeds `buffer_size` is all-reduced on its own. Every tensor is updated
    in place; nothing is returned.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    if not tensors:
        # Nothing to reduce; also avoids indexing tensors[0] on an empty list.
        return

    # Flat staging buffer created from the first tensor (so it shares its
    # dtype and device); its length is the byte budget converted to elements.
    first = tensors[0]
    staging = first.new_zeros(math.ceil(buffer_size / first.element_size()))

    def flush(pending):
        """All-reduce the `pending` tensors through the staging buffer."""
        # Pack every pending tensor, flattened, into the staging buffer.
        offset = 0
        for t in pending:
            numel = t.numel()
            staging[offset : offset + numel].copy_(t.view(-1))
            offset += numel

        # Reduce and rescale only the prefix that was actually filled.
        used = staging[:offset]
        torch.distributed.all_reduce(used, async_op=False)
        used.div_(rescale_denom)

        # Scatter the reduced values back into the original tensors.
        offset = 0
        for t in pending:
            numel = t.numel()
            t.view(-1).copy_(staging[offset : offset + numel])
            offset += numel

    pending = []
    filled = 0  # bytes currently queued in `pending`
    for t in tensors:
        sz = t.numel() * t.element_size()

        if sz > buffer_size:
            # Tensor is bigger than the buffer: all-reduce and rescale directly.
            torch.distributed.all_reduce(t, async_op=False)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # Buffer would overflow: flush it and start a new batch with `t`.
            flush(pending)
            pending = [t]
            filled = sz
        else:
            # Still room: queue the tensor for the next flush.
            pending.append(t)
            filled += sz

    if pending:
        flush(pending)
|
|
|
|
|
def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list.

    `data` is pickled and written into a CUDA byte buffer of `max_size`
    bytes, preceded by a two-byte header holding the payload length as
    `255 * byte0 + byte1`. The buffers are exchanged with
    `torch.distributed.all_gather` and each received payload is unpickled.

    The buffers are cached as attributes of this function and reused by
    later calls that use the same `max_size` and world size.

    Args:
        data: any picklable object.
        max_size: buffer size in bytes; the pickled payload plus the two
            header bytes must fit in it.

    Returns:
        A list with one unpickled object per entry of the gathered output
        buffers (`world_size` entries).

    Raises:
        ValueError: if the pickled payload plus header exceeds `max_size`.
        AssertionError: if `max_size` is not below 255 * 256, the largest
            length the two-byte header can represent.
    """
    world_size = torch.distributed.get_world_size()

    in_buffer = getattr(all_gather_list, "_in_buffer", None)
    out_buffers = getattr(all_gather_list, "_out_buffers", None)
    # Compare element counts: Tensor.size() returns a torch.Size, which never
    # equals an int, so comparing it with max_size would reallocate the
    # buffers on every call.
    if (
        in_buffer is None
        or out_buffers is None
        or in_buffer.numel() != max_size
        or len(out_buffers) != world_size
    ):
        in_buffer = torch.cuda.ByteTensor(max_size)
        out_buffers = [torch.cuda.ByteTensor(max_size) for _ in range(world_size)]
        all_gather_list._in_buffer = in_buffer
        all_gather_list._out_buffers = out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError("encoded data exceeds max_size: {}".format(enc_size + 2))
    # The first header byte stores enc_size // 255 and must fit in one byte.
    assert max_size < 255 * 256
    in_buffer[0] = enc_size // 255
    in_buffer[1] = enc_size % 255
    in_buffer[2 : enc_size + 2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer)

    results = []
    for out_buffer in out_buffers:
        size = (255 * out_buffer[0].item()) + out_buffer[1].item()
        payload = bytes(out_buffer[2 : size + 2].tolist())
        # NOTE(review): pickle.loads executes arbitrary code; this is only
        # acceptable because the bytes come from peer ranks of the same job.
        results.append(pickle.loads(payload))
    return results
|
|
|
|
|
class ErrorHandler(object):
    """A class that listens for exceptions in children processes and propagates
    the tracebacks to the parent process."""

    def __init__(self, error_queue):
        """Install the SIGUSR1 handler and start the listener thread.

        Args:
            error_queue: queue on which children put `(rank, traceback)`
                tuples when they fail.
        """
        import threading

        self.error_queue = error_queue
        self.children_pids = []
        # Install the handler before the listener can send SIGUSR1: if an
        # error is already queued, the signal could otherwise arrive while
        # the default SIGUSR1 action is still in place.
        signal.signal(signal.SIGUSR1, self.signal_handler)
        self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
        self.error_thread.start()

    def add_child(self, pid):
        """Register the pid of a child process to interrupt on error."""
        self.children_pids.append(pid)

    def error_listener(self):
        """Wait for the first reported error, then signal this process.

        The item is put back on the queue so that `signal_handler` can read
        it after SIGUSR1 is delivered.
        """
        (rank, original_trace) = self.error_queue.get()
        self.error_queue.put((rank, original_trace))
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        """Interrupt all children and re-raise the reported traceback.

        Args:
            signalnum: signal number (unused).
            stackframe: interrupted stack frame (unused).

        Raises:
            Exception: always; the message ends with the traceback text read
                from the error queue.
        """
        for pid in self.children_pids:
            try:
                os.kill(pid, signal.SIGINT)
            except ProcessLookupError:
                # The child is already gone (e.g. the one that crashed);
                # that must not hide the traceback reported below.
                continue
        (rank, original_trace) = self.error_queue.get()
        msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
        msg += original_trace
        raise Exception(msg)
|
|
|
|
|
def spawned_train(process_fn, opt, device_id, error_queue):
    """Run `process_fn` on `device_id` inside a spawned worker process.

    The worker first joins the distributed process group through
    `multi_init`, checks that the rank it was given matches
    `opt.gpu_ranks[device_id]`, then calls
    `process_fn(opt, device_id=device_id)`.

    Args:
        process_fn: callable invoked as `process_fn(opt, device_id=device_id)`.
        opt: options object; `gpu_ranks` is read here.
        device_id: index of the device handled by this worker.
        error_queue: queue receiving `(rank, traceback)` when anything in
            the worker raises an `Exception`.
    """
    try:
        gpu_rank = multi_init(opt, device_id)
        if gpu_rank != opt.gpu_ranks[device_id]:
            raise AssertionError("An error occurred in Distributed initialization")
        process_fn(opt, device_id=device_id)
    except KeyboardInterrupt:
        # ErrorHandler.signal_handler interrupts children with SIGINT after
        # another worker failed; exit quietly in that case.
        pass
    except Exception:
        # Forward the formatted traceback so the parent can re-raise it.
        import traceback

        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
|
|
|
|
|
def spawned_infer(opt, device_id, error_queue, queue_instruct, queue_result):
    """Run various functions for translation in spawned process on `device_id`.

    After joining the process group and building a translator, the worker
    loops on `queue_instruct`. Each instruction is a sequence whose first
    item is a command:

    * ``"stop"``: leave the loop and return.
    * ``"infer_list"``: translate the source given as ``instruction[1]``.
    * ``"infer_file"``: translate with a dataset iterator built from `opt`
      alone (no explicit source is passed).

    Any other command is ignored. For every inference, the scores and then
    the predictions are put on `queue_result` as two separate items.

    Args:
        opt: options object (`gpu_ranks`, `log_file`, `_all_transform`,
            `attn_debug` and `align_debug` are read here).
        device_id: index of the device handled by this worker.
        error_queue: queue receiving `(rank, traceback)` when anything in
            the worker raises an `Exception`.
        queue_instruct: queue the instructions are read from.
        queue_result: queue the scores and predictions are written to.
    """
    try:
        gpu_rank = multi_init(opt, device_id)
        if gpu_rank != opt.gpu_ranks[device_id]:
            raise AssertionError("An error occurred in Distributed initialization")
        torch.cuda.set_device(device_id)
        init_logger(opt.log_file)
        translator = build_translator(opt, device_id, logger=logger, report_score=True)
        transforms_cls = get_transforms_cls(opt._all_transform)
        print("Device_id: ", device_id, " translator built")
        while True:
            instruction = queue_instruct.get()
            command = instruction[0]
            if command == "stop":
                break
            if command == "infer_list":
                # The source to translate travels with the instruction.
                iter_kwargs = {"src": instruction[1]}
            elif command == "infer_file":
                iter_kwargs = {}
            else:
                # Unknown commands are skipped; wait for the next one.
                continue

            infer_iter = build_dynamic_dataset_iter(
                opt,
                transforms_cls,
                translator.vocabs,
                task=CorpusTask.INFER,
                **iter_kwargs,
            )
            infer_iter = IterOnDevice(infer_iter, device_id)
            scores, preds = translator._translate(
                infer_iter, infer_iter.transform, opt.attn_debug, opt.align_debug
            )
            queue_result.put(scores)
            queue_result.put(preds)

    except KeyboardInterrupt:
        # ErrorHandler.signal_handler interrupts children with SIGINT after
        # another worker failed; exit quietly in that case.
        pass
    except Exception:
        # Forward the formatted traceback so the parent can re-raise it.
        import traceback

        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
|
|