You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

381 lines
17 KiB

  1. import torch
  2. from torch import nn
  3. from torch.autograd import Variable
  4. from torch.nn.parameter import Parameter
  5. from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
  6. from loss_scaler import DynamicLossScaler, LossScaler
  7. FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
  8. HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
  9. def conversion_helper(val, conversion):
  10. """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
  11. if not isinstance(val, (tuple, list)):
  12. return conversion(val)
  13. rtn = [conversion_helper(v, conversion) for v in val]
  14. if isinstance(val, tuple):
  15. rtn = tuple(rtn)
  16. return rtn
  17. def fp32_to_fp16(val):
  18. """Convert fp32 `val` to fp16"""
  19. def half_conversion(val):
  20. val_typecheck = val
  21. if isinstance(val_typecheck, (Parameter, Variable)):
  22. val_typecheck = val.data
  23. if isinstance(val_typecheck, FLOAT_TYPES):
  24. val = val.half()
  25. return val
  26. return conversion_helper(val, half_conversion)
  27. def fp16_to_fp32(val):
  28. """Convert fp16 `val` to fp32"""
  29. def float_conversion(val):
  30. val_typecheck = val
  31. if isinstance(val_typecheck, (Parameter, Variable)):
  32. val_typecheck = val.data
  33. if isinstance(val_typecheck, HALF_TYPES):
  34. val = val.float()
  35. return val
  36. return conversion_helper(val, float_conversion)
  37. class FP16_Module(nn.Module):
  38. def __init__(self, module):
  39. super(FP16_Module, self).__init__()
  40. self.add_module('module', module.half())
  41. def forward(self, *inputs, **kwargs):
  42. return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))
  43. class FP16_Optimizer(object):
  44. """
  45. FP16_Optimizer is designed to wrap an existing PyTorch optimizer,
  46. and enable an fp16 model to be trained using a master copy of fp32 weights.
  47. Args:
  48. optimizer (torch.optim.optimizer): Existing optimizer containing initialized fp16 parameters. Internally, FP16_Optimizer replaces the passed optimizer's fp16 parameters with new fp32 parameters copied from the original ones. FP16_Optimizer also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy after each step.
  49. static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale fp16 gradients computed by the model. Scaled gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so static_loss_scale should not affect learning rate.
  50. dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any static_loss_scale option.
  51. """
  52. def __init__(self, optimizer, static_loss_scale=1.0, dynamic_loss_scale=False):
  53. if not torch.cuda.is_available:
  54. raise SystemError('Cannot use fp16 without CUDA')
  55. self.fp16_param_groups = []
  56. self.fp32_param_groups = []
  57. self.fp32_flattened_groups = []
  58. for i, param_group in enumerate(optimizer.param_groups):
  59. print("FP16_Optimizer processing param group {}:".format(i))
  60. fp16_params_this_group = []
  61. fp32_params_this_group = []
  62. for param in param_group['params']:
  63. if param.requires_grad:
  64. if param.type() == 'torch.cuda.HalfTensor':
  65. print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
  66. .format(param.size()))
  67. fp16_params_this_group.append(param)
  68. elif param.type() == 'torch.cuda.FloatTensor':
  69. print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
  70. .format(param.size()))
  71. fp32_params_this_group.append(param)
  72. else:
  73. raise TypeError("Wrapped parameters must be either "
  74. "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
  75. "Received {}".format(param.type()))
  76. fp32_flattened_this_group = None
  77. if len(fp16_params_this_group) > 0:
  78. fp32_flattened_this_group = _flatten_dense_tensors(
  79. [param.detach().data.clone().float() for param in fp16_params_this_group])
  80. fp32_flattened_this_group = Variable(fp32_flattened_this_group, requires_grad = True)
  81. fp32_flattened_this_group.grad = fp32_flattened_this_group.new(
  82. *fp32_flattened_this_group.size())
  83. # python's lovely list concatenation via +
  84. if fp32_flattened_this_group is not None:
  85. param_group['params'] = [fp32_flattened_this_group] + fp32_params_this_group
  86. else:
  87. param_group['params'] = fp32_params_this_group
  88. self.fp16_param_groups.append(fp16_params_this_group)
  89. self.fp32_param_groups.append(fp32_params_this_group)
  90. self.fp32_flattened_groups.append(fp32_flattened_this_group)
  91. # print("self.fp32_flattened_groups = ", self.fp32_flattened_groups)
  92. # print("self.fp16_param_groups = ", self.fp16_param_groups)
  93. self.optimizer = optimizer.__class__(optimizer.param_groups)
  94. # self.optimizer.load_state_dict(optimizer.state_dict())
  95. self.param_groups = self.optimizer.param_groups
  96. if dynamic_loss_scale:
  97. self.dynamic_loss_scale = True
  98. self.loss_scaler = DynamicLossScaler()
  99. else:
  100. self.dynamic_loss_scale = False
  101. self.loss_scaler = LossScaler(static_loss_scale)
  102. self.overflow = False
  103. self.first_closure_call_this_step = True
  104. def zero_grad(self):
  105. """
  106. Zero fp32 and fp16 parameter grads.
  107. """
  108. self.optimizer.zero_grad()
  109. for fp16_group in self.fp16_param_groups:
  110. for param in fp16_group:
  111. if param.grad is not None:
  112. param.grad.detach_() # This does appear in torch.optim.optimizer.zero_grad(),
  113. # but I'm not sure why it's needed.
  114. param.grad.zero_()
  115. def _check_overflow(self):
  116. params = []
  117. for group in self.fp16_param_groups:
  118. for param in group:
  119. params.append(param)
  120. for group in self.fp32_param_groups:
  121. for param in group:
  122. params.append(param)
  123. self.overflow = self.loss_scaler.has_overflow(params)
  124. def _update_scale(self, has_overflow=False):
  125. self.loss_scaler.update_scale(has_overflow)
  126. def _copy_grads_fp16_to_fp32(self):
  127. for fp32_group, fp16_group in zip(self.fp32_flattened_groups, self.fp16_param_groups):
  128. if len(fp16_group) > 0:
  129. # This might incur one more deep copy than is necessary.
  130. fp32_group.grad.data.copy_(
  131. _flatten_dense_tensors([fp16_param.grad.data for fp16_param in fp16_group]))
  132. def _downscale_fp32(self):
  133. if self.loss_scale != 1.0:
  134. for param_group in self.optimizer.param_groups:
  135. for param in param_group['params']:
  136. param.grad.data.mul_(1./self.loss_scale)
  137. def clip_fp32_grads(self, clip=-1):
  138. if not self.overflow:
  139. fp32_params = []
  140. for param_group in self.optimizer.param_groups:
  141. for param in param_group['params']:
  142. fp32_params.append(param)
  143. if clip > 0:
  144. return torch.nn.utils.clip_grad_norm(fp32_params, clip)
  145. def _copy_params_fp32_to_fp16(self):
  146. for fp16_group, fp32_group in zip(self.fp16_param_groups, self.fp32_flattened_groups):
  147. if len(fp16_group) > 0:
  148. for fp16_param, fp32_data in zip(fp16_group,
  149. _unflatten_dense_tensors(fp32_group.data, fp16_group)):
  150. fp16_param.data.copy_(fp32_data)
  151. def state_dict(self):
  152. """
  153. Returns a dict containing the current state of this FP16_Optimizer instance.
  154. This dict contains attributes of FP16_Optimizer, as well as the state_dict
  155. of the contained Pytorch optimizer.
  156. Untested.
  157. """
  158. state_dict = {}
  159. state_dict['loss_scaler'] = self.loss_scaler
  160. state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
  161. state_dict['overflow'] = self.overflow
  162. state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
  163. state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
  164. return state_dict
  165. def load_state_dict(self, state_dict):
  166. """
  167. Loads a state_dict created by an earlier call to state_dict.
  168. Untested.
  169. """
  170. self.loss_scaler = state_dict['loss_scaler']
  171. self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
  172. self.overflow = state_dict['overflow']
  173. self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
  174. self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
  175. def step(self, closure=None): # could add clip option.
  176. """
  177. If no closure is supplied, step should be called after fp16_optimizer_obj.backward(loss).
  178. step updates the fp32 master copy of parameters using the optimizer supplied to
  179. FP16_Optimizer's constructor, then copies the updated fp32 params into the fp16 params
  180. originally referenced by Fp16_Optimizer's constructor, so the user may immediately run
  181. another forward pass using their model.
  182. If a closure is supplied, step may be called without a prior call to self.backward(loss).
  183. However, the user should take care that any loss.backward() call within the closure
  184. has been replaced by fp16_optimizer_obj.backward(loss).
  185. Args:
  186. closure (optional): Closure that will be supplied to the underlying optimizer originally passed to FP16_Optimizer's constructor. closure should call zero_grad on the FP16_Optimizer object, compute the loss, call .backward(loss), and return the loss.
  187. Closure example::
  188. # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
  189. # existing pytorch optimizer.
  190. for input, target in dataset:
  191. def closure():
  192. optimizer.zero_grad()
  193. output = model(input)
  194. loss = loss_fn(output, target)
  195. optimizer.backward(loss)
  196. return loss
  197. optimizer.step(closure)
  198. .. note::
  199. The only changes that need to be made compared to
  200. `ordinary optimizer closures`_ are that "optimizer" itself should be an instance of
  201. FP16_Optimizer, and that the call to loss.backward should be replaced by
  202. optimizer.backward(loss).
  203. .. warning::
  204. Currently, calling step with a closure is not compatible with dynamic loss scaling.
  205. .. _`ordinary optimizer closures`:
  206. http://pytorch.org/docs/master/optim.html#optimizer-step-closure
  207. """
  208. if closure is not None and isinstance(self.loss_scaler, DynamicLossScaler):
  209. raise TypeError("Using step with a closure is currently not "
  210. "compatible with dynamic loss scaling.")
  211. scale = self.loss_scaler.loss_scale
  212. self._update_scale(self.overflow)
  213. if self.overflow:
  214. print("OVERFLOW! Skipping step. Attempted loss scale: {}".format(scale))
  215. return
  216. if closure is not None:
  217. self._step_with_closure(closure)
  218. else:
  219. self.optimizer.step()
  220. self._copy_params_fp32_to_fp16()
  221. return
  222. def _step_with_closure(self, closure):
  223. def wrapped_closure():
  224. if self.first_closure_call_this_step:
  225. """
  226. We expect that the fp16 params are initially fresh on entering self.step(),
  227. so _copy_params_fp32_to_fp16() is unnecessary the first time wrapped_closure()
  228. is called within self.optimizer.step().
  229. """
  230. self.first_closure_call_this_step = False
  231. else:
  232. """
  233. If self.optimizer.step() internally calls wrapped_closure more than once,
  234. it may update the fp32 params after each call. However, self.optimizer
  235. doesn't know about the fp16 params at all. If the fp32 params get updated,
  236. we can't rely on self.optimizer to refresh the fp16 params. We need
  237. to handle that manually:
  238. """
  239. self._copy_params_fp32_to_fp16()
  240. """
  241. Our API expects the user to give us ownership of the backward() call by
  242. replacing all calls to loss.backward() with optimizer.backward(loss).
  243. This requirement holds whether or not the call to backward() is made within
  244. a closure.
  245. If the user is properly calling optimizer.backward(loss) within "closure,"
  246. calling closure() here will give the fp32 master params fresh gradients
  247. for the optimizer to play with,
  248. so all wrapped_closure needs to do is call closure() and return the loss.
  249. """
  250. temp_loss = closure()
  251. return temp_loss
  252. self.optimizer.step(wrapped_closure)
  253. self.first_closure_call_this_step = True
  254. def backward(self, loss, update_fp32_grads=True):
  255. """
  256. fp16_optimizer_obj.backward performs the following conceptual operations:
  257. fp32_loss = loss.float() (see first Note below)
  258. scaled_loss = fp32_loss*loss_scale
  259. scaled_loss.backward(), which accumulates scaled gradients into the .grad attributes of the
  260. fp16 model's leaves.
  261. fp16 grads are then copied to the stored fp32 params' .grad attributes (see second Note).
  262. Finally, fp32 grads are divided by loss_scale.
  263. In this way, after fp16_optimizer_obj.backward, the fp32 parameters have fresh gradients,
  264. and fp16_optimizer_obj.step may be called.
  265. .. note::
  266. Converting the loss to fp32 before applying the loss scale provides some
  267. additional safety against overflow if the user has supplied an fp16 value.
  268. However, for maximum overflow safety, the user should
  269. compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
  270. fp16_optimizer_obj.backward.
  271. .. note::
  272. The gradients found in an fp16 model's leaves after a call to
  273. fp16_optimizer_obj.backward should not be regarded as valid in general,
  274. because it's possible
  275. they have been scaled (and in the case of dynamic loss scaling,
  276. the scale factor may silently change over time).
  277. If the user wants to inspect gradients after a call to fp16_optimizer_obj.backward,
  278. he/she should query the .grad attribute of FP16_Optimizer's stored fp32 parameters.
  279. Args:
  280. loss: The loss output by the user's model. loss may be either float or half (but see first Note above).
  281. update_fp32_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay this copy, which is useful to eliminate redundant fp16->fp32 grad copies if fp16_optimizer_obj.backward is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling fp16_optimizer_obj.update_fp32_grads before calling fp16_optimizer_obj.step.
  282. Example::
  283. # Ordinary operation:
  284. optimizer.backward(loss)
  285. # Naive operation with multiple losses (technically valid, but less efficient):
  286. # fp32 grads will be correct after the second call, but
  287. # the first call incurs an unnecessary fp16->fp32 grad copy.
  288. optimizer.backward(loss1)
  289. optimizer.backward(loss2)
  290. # More efficient way to handle multiple losses:
  291. # The fp16->fp32 grad copy is delayed until fp16 grads from all
  292. # losses have been accumulated.
  293. optimizer.backward(loss1, update_fp32_grads=False)
  294. optimizer.backward(loss2, update_fp32_grads=False)
  295. optimizer.update_fp32_grads()
  296. """
  297. self.loss_scaler.backward(loss.float())
  298. if update_fp32_grads:
  299. self.update_fp32_grads()
  300. def update_fp32_grads(self):
  301. """
  302. Copy the .grad attribute from stored references to fp16 parameters to
  303. the .grad attribute of the master fp32 parameters that are directly
  304. updated by the optimizer. :attr:`update_fp32_grads` only needs to be called if
  305. fp16_optimizer_obj.backward was called with update_fp32_grads=False.
  306. """
  307. if self.dynamic_loss_scale:
  308. self._check_overflow()
  309. if self.overflow: return
  310. self._copy_grads_fp16_to_fp32()
  311. self._downscale_fp32()
  312. @property
  313. def loss_scale(self):
  314. return self.loss_scaler.loss_scale