import torch


class LossScaler:
    """Static loss scaler: multiplies the loss by a fixed scale and never
    reports overflow. It is the no-op counterpart to DynamicLossScaler."""

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        return False

    # `overflow` is a boolean indicating whether the gradients overflowed
    def update_scale(self, overflow):
        pass

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()
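
# Note (illustrative sketch, not part of the original file): `scale_gradient`
# has the (module, grad_input, grad_output) signature used by
# `torch.nn.Module.register_backward_hook`, so in principle a scaler could be
# attached to a model so that incoming gradients are multiplied by the current
# loss scale:
#
#     scaler = LossScaler(scale=128)
#     model.register_backward_hook(scaler.scale_gradient)
#
# The example at the bottom of this file instead scales the loss directly and
# unscales the gradients by hand, which is the simpler and more common pattern.
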
class DynamicLossScaler:
    """Dynamic loss scaler: multiplies the scale by `scale_factor` every
    `scale_window` overflow-free iterations and shrinks it (never below 1)
    whenever an inf/NaN gradient is detected."""

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        for p in params:
            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
                return True
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        # Summing propagates inf/NaN, so a single reduction detects both;
        # `cpu_sum != cpu_sum` is the standard NaN check.
        cpu_sum = float(x.float().sum())
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
        return False

    # `overflow` is a boolean indicating whether the gradients overflowed
    def update_scale(self, overflow):
        if overflow:
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()
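
# A minimal sketch (not part of the original file) of how the scale reacts to
# overflow, using the class on its own with small, made-up settings:
#
#     scaler = DynamicLossScaler(init_scale=2.**16, scale_factor=2., scale_window=3)
#     for step, overflow in enumerate([False, False, False, True, False]):
#         scaler.update_scale(overflow)
#         print(step, scaler.loss_scale)
#
# This prints 65536, 65536, 131072, 65536, 65536: the scale doubles after three
# clean steps and is halved again by the overflow at step 3. In a real run the
# scale falls quickly from 2**32 to the largest value whose scaled gradients
# still fit in FP16, then hovers around it.
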

##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
if __name__ == "__main__":
    import torch
    from torch.autograd import Variable
    from dynamic_loss_scaler import DynamicLossScaler

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
    x = Variable(torch.randn(N, D_in), requires_grad=False)
    y = Variable(torch.randn(N, D_out), requires_grad=False)
    w1 = Variable(torch.randn(D_in, H), requires_grad=True)
    w2 = Variable(torch.randn(H, D_out), requires_grad=True)
    parameters = [w1, w2]

    learning_rate = 1e-6
    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
    loss_scaler = DynamicLossScaler()

    for t in range(500):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale

        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
        print('Iter {} scaled loss: {}'.format(t, loss.item()))
        print('Iter {} unscaled loss: {}'.format(t, loss.item() / loss_scaler.loss_scale))

        # Run backprop
        optimizer.zero_grad()
        loss.backward()

        # Check for overflow (call on the instance, not the class)
        has_overflow = loss_scaler.has_overflow(parameters)

        # If there was no overflow, unscale the gradients and update as usual
        if not has_overflow:
            for param in parameters:
                param.grad.data.mul_(1. / loss_scaler.loss_scale)
            optimizer.step()
        # Otherwise, don't update anything -- i.e., skip the iteration
        else:
            print('OVERFLOW!')

        # Update the loss scale for the next iteration
        loss_scaler.update_scale(has_overflow)
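
# Note (sketch, not part of the original example): current PyTorch ships the
# same skip-on-overflow logic as torch.cuda.amp.GradScaler, so the loop body
# above roughly corresponds to:
#
#     scaler = torch.cuda.amp.GradScaler()
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)   # skips optimizer.step() if grads contain inf/NaN
#     scaler.update()          # grows/shrinks the scale, like update_scale()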