Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave

import numpy as np


class ScheduledOptim():
    ''' A simple wrapper class for learning rate scheduling '''

    def __init__(self, optimizer, d_model, n_warmup_steps, current_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_steps
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr_frozen(self, learning_rate_frozen):
        # Step with a fixed learning rate, bypassing the schedule.
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = learning_rate_frozen
        self._optimizer.step()

    def step_and_update_lr(self):
        self._update_learning_rate()
        self._optimizer.step()

    def get_learning_rate(self):
        learning_rate = 0.0
        for param_group in self._optimizer.param_groups:
            learning_rate = param_group['lr']
        return learning_rate

    def zero_grad(self):
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        # Linear warmup for n_warmup_steps, then inverse-square-root decay.
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''
        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr
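
The wrapper implements the Transformer-style "Noam" schedule, lr = d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5), i.e. a linear warmup followed by inverse-square-root decay. Below is a minimal usage sketch, assuming a standard PyTorch optimizer; the model, d_model value, and warmup length are illustrative placeholders, not values taken from this repository.

    # Usage sketch (assumed values, not from this repo).
    import torch
    import torch.nn as nn

    model = nn.Linear(256, 256)  # stand-in for the actual network
    optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer,
                                     d_model=256,
                                     n_warmup_steps=4000,
                                     current_steps=0)

    for step in range(10):
        x = torch.randn(8, 256)
        loss = model(x).pow(2).mean()  # dummy loss for illustration

        scheduled_optim.zero_grad()
        loss.backward()
        scheduled_optim.step_and_update_lr()  # warmup, then decay

    print(scheduled_optim.get_learning_rate())

To resume training from a checkpoint, pass the saved step count as current_steps so the schedule continues where it left off; step_and_update_lr_frozen can be used instead to fine-tune at a fixed learning rate.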