From d0aa9e7d320f0a3827e4bcb4c4418be2afaf060d Mon Sep 17 00:00:00 2001
From: rafaelvalle
Date: Sun, 25 Nov 2018 22:32:54 -0800
Subject: [PATCH] distributed.py: rewrite

---
 distributed.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/distributed.py b/distributed.py
index ebe3b5b..1dd5910 100644
--- a/distributed.py
+++ b/distributed.py
@@ -118,3 +118,55 @@ class DistributedDataParallel(Module):
         super(DistributedDataParallel, self).train(mode)
         self.module.train(mode)
 '''
+'''
+Modifies an existing model to do gradient allreduce without changing its class,
+so you don't need to go through ".module" as with DistributedDataParallel.
+'''
+def apply_gradient_allreduce(module):
+    if not hasattr(dist, '_backend'):
+        module.warn_on_half = True
+    else:
+        module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+    for p in module.state_dict().values():
+        if not torch.is_tensor(p):
+            continue
+        dist.broadcast(p, 0)
+
+    def allreduce_params():
+        if module.needs_reduction:
+            module.needs_reduction = False
+            buckets = {}
+            for param in module.parameters():
+                if param.requires_grad and param.grad is not None:
+                    tp = type(param.data)
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+            if module.warn_on_half:
+                if torch.cuda.HalfTensor in buckets:
+                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
+                          " It is recommended to use the NCCL backend in this case. This currently requires" +
+                          " PyTorch built from top of tree master.")
+                    module.warn_on_half = False
+
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                dist.all_reduce(coalesced)
+                coalesced /= dist.get_world_size()
+                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                    buf.copy_(synced)
+
+    for param in list(module.parameters()):
+        def allreduce_hook(*unused):
+            param._execution_engine.queue_callback(allreduce_params)
+        if param.requires_grad:
+            param.register_hook(allreduce_hook)
+
+    def set_needs_reduction(self, input, output):
+        self.needs_reduction = True
+
+    module.register_forward_hook(set_needs_reduction)
+    return module
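
For context, a minimal usage sketch of the new apply_gradient_allreduce (not part of the patch). The process-group settings, toy model, and data below are placeholders; one copy of the script would run per rank.

import torch
import torch.nn as nn
import torch.distributed as dist

from distributed import apply_gradient_allreduce

# Placeholder values; in a real launch each process gets its own rank.
RANK, WORLD_SIZE = 0, 2

dist.init_process_group(backend='nccl', init_method='tcp://localhost:54321',
                        world_size=WORLD_SIZE, rank=RANK)
torch.cuda.set_device(RANK)

model = nn.Linear(10, 1).cuda()          # toy model for illustration
model = apply_gradient_allreduce(model)  # same class; no ".module" unwrapping

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda()

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
# The forward hook sets needs_reduction; the per-parameter backward hooks
# queue allreduce_params, which averages gradients across ranks.
loss.backward()
optimizer.step()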