|
|
@ -51,11 +51,10 @@ class DynamicLossScaler: |
|
|
|
|
|
|
|
# `x` is a torch.Tensor |
|
|
|
def _has_inf_or_nan(x): |
|
|
|
inf_count = torch.sum(x.abs() == float('inf')) |
|
|
|
if inf_count > 0: |
|
|
|
cpu_sum = float(x.float().sum()) |
|
|
|
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: |
|
|
|
return True |
|
|
|
nan_count = torch.sum(x != x) |
|
|
|
return nan_count > 0 |
|
|
|
return False |
|
|
|
|
|
|
|
# `overflow` is boolean indicating whether we overflowed in gradient |
|
|
|
def update_scale(self, overflow): |
|
|
|