import torch


class LossScaler:

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        return False

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        pass

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()
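

# LossScaler above applies a fixed scale and never reports overflow, so it can
# serve as a drop-in, no-op baseline. DynamicLossScaler below exposes the same
# interface but adapts the scale to the gradients it observes.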
class DynamicLossScaler:

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        for p in params:
            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
                return True
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        inf_count = torch.sum(x.abs() == float('inf'))
        if inf_count > 0:
            return True
        # NaN is the only value that compares unequal to itself.
        nan_count = torch.sum(x != x)
        return nan_count > 0

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        if overflow:
            # Back off on overflow, but never let the scale drop below 1.
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            # Grow the scale after `scale_window` overflow-free iterations.
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()
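

# Note: scale_gradient has the (module, grad_input, grad_output) signature of a
# torch.nn Module backward hook, so it appears intended for use via something
# like module.register_backward_hook(loss_scaler.scale_gradient). Be aware that
# grad_input can contain None entries for some modules, which the multiply in
# scale_gradient does not guard against; this reading is an assumption and is
# not exercised by the example below.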
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
if __name__ == "__main__":
    import torch
    from torch.autograd import Variable
    from dynamic_loss_scaler import DynamicLossScaler

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
    x = Variable(torch.randn(N, D_in), requires_grad=False)
    y = Variable(torch.randn(N, D_out), requires_grad=False)
    w1 = Variable(torch.randn(D_in, H), requires_grad=True)
    w2 = Variable(torch.randn(H, D_out), requires_grad=True)

    parameters = [w1, w2]
    learning_rate = 1e-6
    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
    loss_scaler = DynamicLossScaler()

    for t in range(500):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
        print('Iter {} scaled loss: {}'.format(t, loss.item()))
        print('Iter {} unscaled loss: {}'.format(t, loss.item() / loss_scaler.loss_scale))

        # Run backprop
        optimizer.zero_grad()
        loss.backward()

        # Check for overflow (has_overflow is an instance method, so call it on
        # the scaler rather than on the class).
        has_overflow = loss_scaler.has_overflow(parameters)

        # If no overflow, unscale grad and update as usual
        if not has_overflow:
            for param in parameters:
                param.grad.data.mul_(1. / loss_scaler.loss_scale)
            optimizer.step()
        # Otherwise, don't do anything -- ie, skip iteration
        else:
            print('OVERFLOW!')

        # Update loss scale for next iteration
        loss_scaler.update_scale(has_overflow)
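
    # Alternative (a sketch, not executed here): the scaler also provides
    # backward(), which multiplies the loss by loss_scale and calls backward()
    # itself, so the manual scaling above could be replaced by e.g.:
    #   y_pred = x.mm(w1).clamp(min=0).mm(w2)
    #   optimizer.zero_grad()
    #   loss_scaler.backward((y_pred - y).pow(2).sum())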