|
|
@ -118,3 +118,55 @@ class DistributedDataParallel(Module): |
|
|
|
super(DistributedDataParallel, self).train(mode) |
|
|
|
self.module.train(mode) |
|
|
|
''' |
|
|
|
''' |
|
|
|
Modifies existing model to do gradient allreduce, but doesn't change class |
|
|
|
so you don't need "module" |
|
|
|
''' |
|
|
|
def apply_gradient_allreduce(module): |
|
|
|
if not hasattr(dist, '_backend'): |
|
|
|
module.warn_on_half = True |
|
|
|
else: |
|
|
|
module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False |
|
|
|
|
|
|
|
for p in module.state_dict().values(): |
|
|
|
if not torch.is_tensor(p): |
|
|
|
continue |
|
|
|
dist.broadcast(p, 0) |
|
|
|
|
|
|
|
def allreduce_params(): |
|
|
|
if(module.needs_reduction): |
|
|
|
module.needs_reduction = False |
|
|
|
buckets = {} |
|
|
|
for param in module.parameters(): |
|
|
|
if param.requires_grad and param.grad is not None: |
|
|
|
tp = type(param.data) |
|
|
|
if tp not in buckets: |
|
|
|
buckets[tp] = [] |
|
|
|
buckets[tp].append(param) |
|
|
|
if module.warn_on_half: |
|
|
|
if torch.cuda.HalfTensor in buckets: |
|
|
|
print("WARNING: gloo dist backend for half parameters may be extremely slow." + |
|
|
|
" It is recommended to use the NCCL backend in this case. This currently requires" + |
|
|
|
"PyTorch built from top of tree master.") |
|
|
|
module.warn_on_half = False |
|
|
|
|
|
|
|
for tp in buckets: |
|
|
|
bucket = buckets[tp] |
|
|
|
grads = [param.grad.data for param in bucket] |
|
|
|
coalesced = _flatten_dense_tensors(grads) |
|
|
|
dist.all_reduce(coalesced) |
|
|
|
coalesced /= dist.get_world_size() |
|
|
|
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): |
|
|
|
buf.copy_(synced) |
|
|
|
|
|
|
|
for param in list(module.parameters()): |
|
|
|
def allreduce_hook(*unused): |
|
|
|
param._execution_engine.queue_callback(allreduce_params) |
|
|
|
if param.requires_grad: |
|
|
|
param.register_hook(allreduce_hook) |
|
|
|
|
|
|
|
def set_needs_reduction(self, input, output): |
|
|
|
self.needs_reduction = True |
|
|
|
|
|
|
|
module.register_forward_hook(set_needs_reduction) |
|
|
|
return module |