You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

120 lines
4.8 KiB

  1. import torch
  2. import torch.distributed as dist
  3. from torch.nn.modules import Module
  4. def _flatten_dense_tensors(tensors):
  5. """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
  6. same dense type.
  7. Since inputs are dense, the resulting tensor will be a concatenated 1D
  8. buffer. Element-wise operation on this buffer will be equivalent to
  9. operating individually.
  10. Arguments:
  11. tensors (Iterable[Tensor]): dense tensors to flatten.
  12. Returns:
  13. A contiguous 1D buffer containing input tensors.
  14. """
  15. if len(tensors) == 1:
  16. return tensors[0].contiguous().view(-1)
  17. flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
  18. return flat
  19. def _unflatten_dense_tensors(flat, tensors):
  20. """View a flat buffer using the sizes of tensors. Assume that tensors are of
  21. same dense type, and that flat is given by _flatten_dense_tensors.
  22. Arguments:
  23. flat (Tensor): flattened dense tensors to unflatten.
  24. tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
  25. unflatten flat.
  26. Returns:
  27. Unflattened dense tensors with sizes same as tensors and values from
  28. flat.
  29. """
  30. outputs = []
  31. offset = 0
  32. for tensor in tensors:
  33. numel = tensor.numel()
  34. outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
  35. offset += numel
  36. return tuple(outputs)
'''
This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
launcher included with this example. It assumes that your run uses multiple processes with one
GPU per process, that the model is on the correct device, and that torch.cuda.set_device has been
used to set the device.
Parameters are broadcast to the other processes on initialization of DistributedDataParallel,
and gradients are allreduced at the end of the backward pass.
'''
  45. class DistributedDataParallel(Module):
  46. def __init__(self, module):
  47. super(DistributedDataParallel, self).__init__()
  48. #fallback for PyTorch 0.3
  49. if not hasattr(dist, '_backend'):
  50. self.warn_on_half = True
  51. else:
  52. self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
  53. self.module = module
  54. for p in self.module.state_dict().values():
  55. if not torch.is_tensor(p):
  56. continue
  57. dist.broadcast(p, 0)
  58. def allreduce_params():
  59. if(self.needs_reduction):
  60. self.needs_reduction = False
  61. buckets = {}
  62. for param in self.module.parameters():
  63. if param.requires_grad and param.grad is not None:
  64. tp = type(param.data)
  65. if tp not in buckets:
  66. buckets[tp] = []
  67. buckets[tp].append(param)
  68. if self.warn_on_half:
  69. if torch.cuda.HalfTensor in buckets:
  70. print("WARNING: gloo dist backend for half parameters may be extremely slow." +
  71. " It is recommended to use the NCCL backend in this case. This currently requires" +
  72. "PyTorch built from top of tree master.")
  73. self.warn_on_half = False
  74. for tp in buckets:
  75. bucket = buckets[tp]
  76. grads = [param.grad.data for param in bucket]
  77. coalesced = _flatten_dense_tensors(grads)
  78. dist.all_reduce(coalesced)
  79. coalesced /= dist.get_world_size()
  80. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
  81. buf.copy_(synced)
  82. for param in list(self.module.parameters()):
  83. def allreduce_hook(*unused):
  84. param._execution_engine.queue_callback(allreduce_params)
  85. if param.requires_grad:
  86. param.register_hook(allreduce_hook)
  87. def forward(self, *inputs, **kwargs):
  88. self.needs_reduction = True
  89. return self.module(*inputs, **kwargs)
  90. '''
  91. def _sync_buffers(self):
  92. buffers = list(self.module._all_buffers())
  93. if len(buffers) > 0:
  94. # cross-node buffer sync
  95. flat_buffers = _flatten_dense_tensors(buffers)
  96. dist.broadcast(flat_buffers, 0)
  97. for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
  98. buf.copy_(synced)
  99. def train(self, mode=True):
  100. # Clear NCCL communicator and CUDA event cache of the default group ID,
  101. # These cache will be recreated at the later call. This is currently a
  102. # work-around for a potential NCCL deadlock.
  103. if dist._backend == dist.dist_backend.NCCL:
  104. dist._clear_group_cache()
  105. super(DistributedDataParallel, self).train(mode)
  106. self.module.train(mode)
  107. '''