From 52a30bb7b6b023e800b9ed813c5dcbc94d515451 Mon Sep 17 00:00:00 2001
From: rafaelvalle
Date: Tue, 27 Nov 2018 21:01:26 -0800
Subject: [PATCH] distributed.py: replacing to avoid distributed error

---
 distributed.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/distributed.py b/distributed.py
index 1dd5910..9076a73 100644
--- a/distributed.py
+++ b/distributed.py
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed as dist
 from torch.nn.modules import Module
+from torch.autograd import Variable
 
 def _flatten_dense_tensors(tensors):
     """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
@@ -161,12 +162,12 @@ def apply_gradient_allreduce(module):
 
     for param in list(module.parameters()):
         def allreduce_hook(*unused):
-            param._execution_engine.queue_callback(allreduce_params)
+            Variable._execution_engine.queue_callback(allreduce_params)
         if param.requires_grad:
             param.register_hook(allreduce_hook)
-
+
     def set_needs_reduction(self, input, output):
         self.needs_reduction = True
-
+
     module.register_forward_hook(set_needs_reduction)
     return module
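
For reference, a minimal sketch of how the patched apply_gradient_allreduce might be wired into a training loop. This usage is not part of the patch itself: the nccl/env:// process-group setup, the nn.Linear stand-in model, and the synthetic data are assumptions for illustration only.

import torch
import torch.nn as nn
import torch.distributed as dist

from distributed import apply_gradient_allreduce  # the module patched above

# Assumed setup: one process per GPU, with RANK / WORLD_SIZE / MASTER_ADDR
# already in the environment (e.g. launched via torch.distributed.launch).
dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(10, 1).cuda()           # stand-in for a real network
model = apply_gradient_allreduce(model)   # installs the hooks shown in the patch

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(10):
    inputs = torch.randn(8, 10).cuda()    # synthetic batch
    targets = torch.randn(8, 1).cuda()

    optimizer.zero_grad()
    loss = (model(inputs) - targets).pow(2).mean()
    # The forward hook sets model.needs_reduction; backward() fires the
    # per-parameter hooks, which queue allreduce_params on the autograd
    # execution engine (the Variable._execution_engine call in the patch).
    loss.backward()
    optimizer.step()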