|
@@ -140,7 +140,7 @@ def apply_gradient_allreduce(module): |
|
|
buckets = {} |
|
|
buckets = {} |
|
|
for param in module.parameters(): |
|
|
for param in module.parameters(): |
|
|
if param.requires_grad and param.grad is not None: |
|
|
if param.requires_grad and param.grad is not None: |
|
|
tp = type(param.data) |
|
|
|
|
|
|
|
|
tp = param.data.dtype |
|
|
if tp not in buckets: |
|
|
if tp not in buckets: |
|
|
buckets[tp] = [] |
|
|
buckets[tp] = [] |
|
|
buckets[tp].append(param) |
|
|
buckets[tp].append(param) |
|
|