From df4a466af2093f3825e72cd6e8a6cc96c4630360 Mon Sep 17 00:00:00 2001
From: gkarch
Date: Fri, 1 Feb 2019 09:55:59 +0100
Subject: [PATCH] Fixing concatenation error for fp16 distributed training

---
 distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/distributed.py b/distributed.py
index 9076a73..cce7494 100644
--- a/distributed.py
+++ b/distributed.py
@@ -140,7 +140,7 @@ def apply_gradient_allreduce(module):
             buckets = {}
             for param in module.parameters():
                 if param.requires_grad and param.grad is not None:
-                    tp = type(param.data)
+                    tp = param.data.dtype
                     if tp not in buckets:
                         buckets[tp] = []
                     buckets[tp].append(param)
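
For readers hitting the same error, the reasoning behind the one-line change: since the
PyTorch 0.4 Tensor/Variable merge, type(param.data) returns torch.Tensor for every
parameter regardless of precision, so in fp16 (mixed-precision) training the fp16 and
fp32 gradients all land in a single bucket. Flattening a bucket concatenates its
gradients with torch.cat, which rejects mixed dtypes; that is the concatenation error
the subject line refers to. Keying the buckets on param.data.dtype keeps each bucket
homogeneous. Below is a minimal, self-contained sketch of the bucketing step, not code
from distributed.py itself; the toy params list and its shapes are illustrative
assumptions:

import torch
from torch._utils import _flatten_dense_tensors

# Toy stand-in for the parameters of a mixed-precision model, where
# some weights are fp16 and others stay fp32 (shapes are arbitrary).
params = [
    torch.nn.Parameter(torch.zeros(4, dtype=torch.float16)),
    torch.nn.Parameter(torch.zeros(4, dtype=torch.float32)),
]
for p in params:
    p.grad = torch.zeros_like(p)

# Bucket gradients by dtype, as the patched code does. With the old
# key, type(p.data), every tensor is torch.Tensor on PyTorch >= 0.4,
# so both precisions would fall into one bucket.
buckets = {}
for p in params:
    tp = p.data.dtype
    if tp not in buckets:
        buckets[tp] = []
    buckets[tp].append(p)

# Flattening concatenates all gradients in a bucket into one tensor
# before the allreduce; this only works if the bucket is homogeneous.
for tp, bucket in buckets.items():
    coalesced = _flatten_dense_tensors([p.grad.data for p in bucket])
    print(tp, coalesced.dtype, coalesced.numel())

With the old type()-based key, both parameters above would share one bucket and the
torch.cat inside _flatten_dense_tensors would fail on the Half/Float mismatch; with
dtype as the key, each precision is flattened and allreduced separately.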