import torch
import torch.distributed as dist
from torch.nn.modules import Module


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer.

    Assume tensors are of same dense type. Since inputs are dense, the
    resulting tensor will be a concatenated 1D buffer. Element-wise
    operation on this buffer will be equivalent to operating individually.

    Arguments:
        tensors (Iterable[Tensor]): dense tensors to flatten. Any iterable
            is accepted, including generators.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    # Materialize first: the length check below would otherwise fail on
    # generators and other iterables that do not implement ``len``.
    tensors = list(tensors)
    if len(tensors) == 1:
        return tensors[0].contiguous().view(-1)
    return torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors.

    Assume that tensors are of same dense type, and that flat is given by
    _flatten_dense_tensors.

    Arguments:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used
            to unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values
        from flat.
    """
    outputs = []
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        # Slice ``numel`` elements starting at ``offset`` and give the slice
        # the shape of the reference tensor.
        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
        offset += numel
    return tuple(outputs)


class DistributedDataParallel(Module):
    """Wrap a module so its gradients are averaged across processes.

    This version of DistributedDataParallel is designed to be used in
    conjunction with the multiproc.py launcher included with this example.
    It assumes that your run is using multiprocess with 1 GPU/process, that
    the model is on the correct device, and that torch.cuda.set_device has
    been used to set the device.

    Parameters are broadcasted to the other processes on initialization of
    DistributedDataParallel, and will be allreduced at the finish of the
    backward pass.

    Arguments:
        module (Module): the module to wrap. Every tensor in its
            ``state_dict()`` is broadcast from rank 0 during construction.
    """

    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()

        # fallback for PyTorch 0.3: without ``dist._backend`` the backend
        # cannot be inspected, so the half-precision warning stays enabled.
        if not hasattr(dist, '_backend'):
            self.warn_on_half = True
        else:
            self.warn_on_half = dist._backend == dist.dist_backend.GLOO

        self.module = module

        # ``forward`` arms this flag and the backward callback clears it.
        # Initialize it here so a backward pass that fires the hooks before
        # any ``forward`` through this wrapper does not hit AttributeError.
        self.needs_reduction = False

        # Start every process from rank 0's parameters and buffers.
        for p in self.module.state_dict().values():
            if not torch.is_tensor(p):
                continue
            dist.broadcast(p, 0)

        def allreduce_params():
            # The callback is queued once per parameter hook; the flag makes
            # sure only the first invocation per forward pass does the work.
            if not self.needs_reduction:
                return
            self.needs_reduction = False

            # Group parameters by tensor type so each group can be
            # flattened into one buffer.
            buckets = {}
            for param in self.module.parameters():
                if param.requires_grad and param.grad is not None:
                    buckets.setdefault(type(param.data), []).append(param)

            if self.warn_on_half:
                if torch.cuda.HalfTensor in buckets:
                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                          " It is recommended to use the NCCL backend in this case. This currently requires" +
                          " PyTorch built from top of tree master.")
                    # Warn at most once per wrapper instance.
                    self.warn_on_half = False

            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                # Sum across processes, then divide to obtain the mean.
                dist.all_reduce(coalesced)
                coalesced /= dist.get_world_size()
                # Copy the averaged values back into each gradient tensor.
                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                    buf.copy_(synced)

        def allreduce_hook(*unused):
            # Look the engine up on the legacy Variable class rather than on
            # an individual parameter: the lookup then depends neither on
            # which parameter fired the hook nor on tensors exposing the
            # private ``_execution_engine`` attribute.
            torch.autograd.Variable._execution_engine.queue_callback(allreduce_params)

        for param in list(self.module.parameters()):
            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module and arm the gradient all-reduce.

        Arguments and return value are passed straight through to and from
        ``self.module``.
        """
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)

    # The string below is inert: it holds disabled method drafts and is never
    # executed.
    '''
    def _sync_buffers(self):
        buffers = list(self.module._all_buffers())
        if len(buffers) > 0:
            # cross-node buffer sync
            flat_buffers = _flatten_dense_tensors(buffers)
            dist.broadcast(flat_buffers, 0)
            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
                buf.copy_(synced)

    def train(self, mode=True):
        # Clear NCCL communicator and CUDA event cache of the default group ID,
        # These cache will be recreated at the later call. This is currently a
        # work-around for a potential NCCL deadlock.
        if dist._backend == dist.dist_backend.NCCL:
            dist._clear_group_cache()
        super(DistributedDataParallel, self).train(mode)
        self.module.train(mode)
    '''