diff --git a/loss_scaler.py b/loss_scaler.py
index c7dfa13..0662a60 100644
--- a/loss_scaler.py
+++ b/loss_scaler.py
@@ -51,11 +51,10 @@ class DynamicLossScaler:
 
     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
-        inf_count = torch.sum(x.abs() == float('inf'))
-        if inf_count > 0:
+        cpu_sum = float(x.float().sum())
+        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
             return True
-        nan_count = torch.sum(x != x)
-        return nan_count > 0
+        return False
 
     # `overflow` is boolean indicating whether we overflowed in gradient
     def update_scale(self, overflow):