@@ -1,6 +1,7 @@
 import torch
 import torch.distributed as dist
 from torch.nn.modules import Module
+from torch.autograd import Variable
 
 def _flatten_dense_tensors(tensors):
     """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
@@ -161,12 +162,12 @@ def apply_gradient_allreduce(module):
 
     for param in list(module.parameters()):
         def allreduce_hook(*unused):
-            param._execution_engine.queue_callback(allreduce_params)
+            Variable._execution_engine.queue_callback(allreduce_params)
         if param.requires_grad:
             param.register_hook(allreduce_hook)
 
     def set_needs_reduction(self, input, output):
         self.needs_reduction = True
 
     module.register_forward_hook(set_needs_reduction)
     return module
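
A brief, hedged note on why the change is needed: on recent PyTorch versions a Parameter is a plain Tensor and no longer exposes the autograd engine as param._execution_engine, so the old call raises an AttributeError, while torch.autograd.Variable still carries the engine handle; hence the added import and the switch to Variable._execution_engine.queue_callback(...). The sketch below shows how the patched apply_gradient_allreduce might be used; the module name `distributed`, the launch setup, and the training details are assumptions for illustration, not part of this diff.

# Usage sketch (assumptions: the patched file above is importable as `distributed`,
# and the script is launched with a tool such as torchrun, which sets LOCAL_RANK).
import os
import torch
import torch.nn as nn
import torch.distributed as dist

from distributed import apply_gradient_allreduce  # hypothetical module name

dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = nn.Linear(128, 10).cuda()
model = apply_gradient_allreduce(model)  # registers the hooks patched above

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = torch.randn(32, 128).cuda()
targets = torch.randint(0, 10, (32,)).cuda()

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(inputs), targets)
# backward() fires the per-parameter hooks, which queue allreduce_params on the
# autograd engine via Variable._execution_engine.queue_callback
loss.backward()
optimizer.step()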