You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

173 lines
7.1 KiB

  1. import torch
  2. import torch.distributed as dist
  3. from torch.nn.modules import Module
  4. from torch.autograd import Variable
  5. def _flatten_dense_tensors(tensors):
  6. """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
  7. same dense type.
  8. Since inputs are dense, the resulting tensor will be a concatenated 1D
  9. buffer. Element-wise operation on this buffer will be equivalent to
  10. operating individually.
  11. Arguments:
  12. tensors (Iterable[Tensor]): dense tensors to flatten.
  13. Returns:
  14. A contiguous 1D buffer containing input tensors.
  15. """
  16. if len(tensors) == 1:
  17. return tensors[0].contiguous().view(-1)
  18. flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
  19. return flat
  20. def _unflatten_dense_tensors(flat, tensors):
  21. """View a flat buffer using the sizes of tensors. Assume that tensors are of
  22. same dense type, and that flat is given by _flatten_dense_tensors.
  23. Arguments:
  24. flat (Tensor): flattened dense tensors to unflatten.
  25. tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
  26. unflatten flat.
  27. Returns:
  28. Unflattened dense tensors with sizes same as tensors and values from
  29. flat.
  30. """
  31. outputs = []
  32. offset = 0
  33. for tensor in tensors:
  34. numel = tensor.numel()
  35. outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
  36. offset += numel
  37. return tuple(outputs)
'''
This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
launcher included with this example. It assumes that your run is using multiprocess with 1
GPU/process, that the model is on the correct device, and that torch.cuda.set_device has been
used to set the device.
Parameters are broadcast to the other processes on initialization of DistributedDataParallel,
and will be allreduced at the finish of the backward pass.
'''
  46. class DistributedDataParallel(Module):
  47. def __init__(self, module):
  48. super(DistributedDataParallel, self).__init__()
  49. #fallback for PyTorch 0.3
  50. if not hasattr(dist, '_backend'):
  51. self.warn_on_half = True
  52. else:
  53. self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
  54. self.module = module
  55. for p in self.module.state_dict().values():
  56. if not torch.is_tensor(p):
  57. continue
  58. dist.broadcast(p, 0)
  59. def allreduce_params():
  60. if(self.needs_reduction):
  61. self.needs_reduction = False
  62. buckets = {}
  63. for param in self.module.parameters():
  64. if param.requires_grad and param.grad is not None:
  65. tp = type(param.data)
  66. if tp not in buckets:
  67. buckets[tp] = []
  68. buckets[tp].append(param)
  69. if self.warn_on_half:
  70. if torch.cuda.HalfTensor in buckets:
  71. print("WARNING: gloo dist backend for half parameters may be extremely slow." +
  72. " It is recommended to use the NCCL backend in this case. This currently requires" +
  73. "PyTorch built from top of tree master.")
  74. self.warn_on_half = False
  75. for tp in buckets:
  76. bucket = buckets[tp]
  77. grads = [param.grad.data for param in bucket]
  78. coalesced = _flatten_dense_tensors(grads)
  79. dist.all_reduce(coalesced)
  80. coalesced /= dist.get_world_size()
  81. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
  82. buf.copy_(synced)
  83. for param in list(self.module.parameters()):
  84. def allreduce_hook(*unused):
  85. param._execution_engine.queue_callback(allreduce_params)
  86. if param.requires_grad:
  87. param.register_hook(allreduce_hook)
  88. def forward(self, *inputs, **kwargs):
  89. self.needs_reduction = True
  90. return self.module(*inputs, **kwargs)
  91. '''
  92. def _sync_buffers(self):
  93. buffers = list(self.module._all_buffers())
  94. if len(buffers) > 0:
  95. # cross-node buffer sync
  96. flat_buffers = _flatten_dense_tensors(buffers)
  97. dist.broadcast(flat_buffers, 0)
  98. for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
  99. buf.copy_(synced)
  100. def train(self, mode=True):
  101. # Clear NCCL communicator and CUDA event cache of the default group ID,
  102. # These cache will be recreated at the later call. This is currently a
  103. # work-around for a potential NCCL deadlock.
  104. if dist._backend == dist.dist_backend.NCCL:
  105. dist._clear_group_cache()
  106. super(DistributedDataParallel, self).train(mode)
  107. self.module.train(mode)
  108. '''
  109. '''
  110. Modifies existing model to do gradient allreduce, but doesn't change class
  111. so you don't need "module"
  112. '''
  113. def apply_gradient_allreduce(module):
  114. if not hasattr(dist, '_backend'):
  115. module.warn_on_half = True
  116. else:
  117. module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
  118. for p in module.state_dict().values():
  119. if not torch.is_tensor(p):
  120. continue
  121. dist.broadcast(p, 0)
  122. def allreduce_params():
  123. if(module.needs_reduction):
  124. module.needs_reduction = False
  125. buckets = {}
  126. for param in module.parameters():
  127. if param.requires_grad and param.grad is not None:
  128. tp = param.data.dtype
  129. if tp not in buckets:
  130. buckets[tp] = []
  131. buckets[tp].append(param)
  132. if module.warn_on_half:
  133. if torch.cuda.HalfTensor in buckets:
  134. print("WARNING: gloo dist backend for half parameters may be extremely slow." +
  135. " It is recommended to use the NCCL backend in this case. This currently requires" +
  136. "PyTorch built from top of tree master.")
  137. module.warn_on_half = False
  138. for tp in buckets:
  139. bucket = buckets[tp]
  140. grads = [param.grad.data for param in bucket]
  141. coalesced = _flatten_dense_tensors(grads)
  142. dist.all_reduce(coalesced)
  143. coalesced /= dist.get_world_size()
  144. for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
  145. buf.copy_(synced)
  146. for param in list(module.parameters()):
  147. def allreduce_hook(*unused):
  148. Variable._execution_engine.queue_callback(allreduce_params)
  149. if param.requires_grad:
  150. param.register_hook(allreduce_hook)
  151. def set_needs_reduction(self, input, output):
  152. self.needs_reduction = True
  153. module.register_forward_hook(set_needs_reduction)
  154. return module