From cd851585cb6f96128fac5b590e3ab4aa1ab59b95 Mon Sep 17 00:00:00 2001
From: Rafael Valle
Date: Tue, 15 May 2018 09:50:08 -0700
Subject: [PATCH] loss_scaler.py: patching loss scaler for compatibility with
 current pytorch

---
 loss_scaler.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/loss_scaler.py b/loss_scaler.py
index c7dfa13..0662a60 100644
--- a/loss_scaler.py
+++ b/loss_scaler.py
@@ -51,11 +51,10 @@ class DynamicLossScaler:
 
     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
-        inf_count = torch.sum(x.abs() == float('inf'))
-        if inf_count > 0:
+        cpu_sum = float(x.float().sum())
+        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
             return True
-        nan_count = torch.sum(x != x)
-        return nan_count > 0
+        return False
 
     # `overflow` is boolean indicating whether we overflowed in gradient
     def update_scale(self, overflow):
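
Note (not part of the patch): the rewrite replaces two boolean reductions with a single sum, so only one value crosses from device to host. The one-sum check is sound because any Inf in the tensor drives the sum to +/-Inf (or to NaN when +Inf and -Inf both occur), any NaN propagates through the sum, and NaN is the only float value unequal to itself, which is what `cpu_sum != cpu_sum` tests. A minimal standalone sketch of the patched check follows; the free-function form and the example tensors are illustrative, only PyTorch is assumed.

    import torch

    def has_inf_or_nan(x):
        # Single device-to-host transfer: sum in float32, then inspect
        # the scalar on the CPU instead of reducing twice on the device.
        cpu_sum = float(x.float().sum())
        # An Inf element makes the sum +/-Inf (or NaN if +Inf and -Inf
        # cancel); a NaN element propagates. NaN != NaN catches NaN.
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
        return False

    print(has_inf_or_nan(torch.tensor([1.0, float('nan'), 3.0])))  # True
    print(has_inf_or_nan(torch.tensor([1.0, float('inf')])))       # True
    print(has_inf_or_nan(torch.ones(4)))                           # False

A sum of large finite values can also overflow to Inf in float32; for dynamic loss scaling that false positive is acceptable, since it still signals that the scale should be reduced.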