diff --git a/distributed.py b/distributed.py index 9076a73..cce7494 100644 --- a/distributed.py +++ b/distributed.py @@ -140,7 +140,7 @@ def apply_gradient_allreduce(module): buckets = {} for param in module.parameters(): if param.requires_grad and param.grad is not None: - tp = type(param.data) + tp = param.data.dtype if tp not in buckets: buckets[tp] = [] buckets[tp].append(param)