Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

70 lines
3.0 KiB

import copy
import sys

import torch
  4. def _check_model_old_version(model):
  5. if hasattr(model.WN[0], 'res_layers') or hasattr(model.WN[0], 'cond_layers'):
  6. return True
  7. else:
  8. return False
  9. def _update_model_res_skip(old_model, new_model):
  10. for idx in range(0, len(new_model.WN)):
  11. wavenet = new_model.WN[idx]
  12. n_channels = wavenet.n_channels
  13. n_layers = wavenet.n_layers
  14. wavenet.res_skip_layers = torch.nn.ModuleList()
  15. for i in range(0, n_layers):
  16. if i < n_layers - 1:
  17. res_skip_channels = 2*n_channels
  18. else:
  19. res_skip_channels = n_channels
  20. res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
  21. skip_layer = torch.nn.utils.remove_weight_norm(wavenet.skip_layers[i])
  22. if i < n_layers - 1:
  23. res_layer = torch.nn.utils.remove_weight_norm(wavenet.res_layers[i])
  24. res_skip_layer.weight = torch.nn.Parameter(torch.cat([res_layer.weight, skip_layer.weight]))
  25. res_skip_layer.bias = torch.nn.Parameter(torch.cat([res_layer.bias, skip_layer.bias]))
  26. else:
  27. res_skip_layer.weight = torch.nn.Parameter(skip_layer.weight)
  28. res_skip_layer.bias = torch.nn.Parameter(skip_layer.bias)
  29. res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
  30. wavenet.res_skip_layers.append(res_skip_layer)
  31. del wavenet.res_layers
  32. del wavenet.skip_layers
  33. def _update_model_cond(old_model, new_model):
  34. for idx in range(0, len(new_model.WN)):
  35. wavenet = new_model.WN[idx]
  36. n_channels = wavenet.n_channels
  37. n_layers = wavenet.n_layers
  38. n_mel_channels = wavenet.cond_layers[0].weight.shape[1]
  39. cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
  40. cond_layer_weight = []
  41. cond_layer_bias = []
  42. for i in range(0, n_layers):
  43. _cond_layer = torch.nn.utils.remove_weight_norm(wavenet.cond_layers[i])
  44. cond_layer_weight.append(_cond_layer.weight)
  45. cond_layer_bias.append(_cond_layer.bias)
  46. cond_layer.weight = torch.nn.Parameter(torch.cat(cond_layer_weight))
  47. cond_layer.bias = torch.nn.Parameter(torch.cat(cond_layer_bias))
  48. cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
  49. wavenet.cond_layer = cond_layer
  50. del wavenet.cond_layers
  51. def update_model(old_model):
  52. if not _check_model_old_version(old_model):
  53. return old_model
  54. new_model = copy.deepcopy(old_model)
  55. if hasattr(old_model.WN[0], 'res_layers'):
  56. _update_model_res_skip(old_model, new_model)
  57. if hasattr(old_model.WN[0], 'cond_layers'):
  58. _update_model_cond(old_model, new_model)
  59. return new_model
  60. if __name__ == '__main__':
  61. old_model_path = sys.argv[1]
  62. new_model_path = sys.argv[2]
  63. model = torch.load(old_model_path)
  64. model['model'] = update_model(model['model'])
  65. torch.save(model, new_model_path)