Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave
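
The denoising being fixed is the WaveGlow-style bias removal applied after vocoding: a Denoiser module runs the vocoder once on an all-zero mel to estimate its constant background spectrum, then subtracts a fraction of that spectrum from the generated audio. The sketch below only shows the typical usage pattern; it assumes the Denoiser class from the SqueezeWave repo's denoiser.py, and the checkpoint name, mel shape, sigma and strength values are illustrative.

import torch
from denoiser import Denoiser  # assumption: denoiser.py as shipped with the SqueezeWave code

squeezewave = torch.load("squeezewave.pt")["model"]  # illustrative checkpoint name/layout
squeezewave.eval()

# The Denoiser vocodes an all-zero mel once to capture the model's constant bias spectrum.
denoiser = Denoiser(squeezewave)

# Placeholder mel of shape (batch, n_mel_channels, frames) on the same device as the model.
mel = torch.zeros(1, 80, 400).to(next(squeezewave.parameters()).device)

with torch.no_grad():
    audio = squeezewave.infer(mel, sigma=0.6)   # raw vocoder output
    audio = denoiser(audio, strength=0.01)      # subtract a small fraction of the bias noise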
import torch
import torch.nn as nn
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import time
import os

from fastspeech import FastSpeech
from text import text_to_sequence
import hparams as hp
import utils
import audio as Audio
import glow
import waveglow

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_FastSpeech(num):
    """Load the FastSpeech checkpoint saved at training step `num`."""
    checkpoint_path = "checkpoint_" + str(num) + ".pth.tar"
    model = nn.DataParallel(FastSpeech()).to(device)
    model.load_state_dict(torch.load(
        os.path.join(hp.checkpoint_path, checkpoint_path),
        map_location=device)['model'])
    model.eval()
    return model


def synthesis(model, text, alpha=1.0):
    """Convert text to mel spectrograms; alpha controls the speaking rate."""
    text = np.array(text_to_sequence(text, hp.text_cleaners))
    text = np.stack([text])
    src_pos = np.array([i + 1 for i in range(text.shape[1])])
    src_pos = np.stack([src_pos])

    with torch.no_grad():
        # Move the inputs to the same device as the model.
        sequence = torch.from_numpy(text).long().to(device)
        src_pos = torch.from_numpy(src_pos).long().to(device)

        mel, mel_postnet = model.module.forward(sequence, src_pos, alpha=alpha)

        # Script for generating a TorchScript module:
        # traced_script_module = torch.jit.trace(model, (sequence, src_pos))
        # traced_script_module.save("traced_fastspeech_model.pt")

        return mel[0].cpu().transpose(0, 1), \
            mel_postnet[0].cpu().transpose(0, 1), \
            mel.transpose(1, 2), \
            mel_postnet.transpose(1, 2)


if __name__ == "__main__":
    # Test
    num = 112000
    alpha = 1.0
    model = get_FastSpeech(num)
    words = "Let’s go out to the airport. The plane landed ten minutes ago."

    start = time.time()
    mel, mel_postnet, mel_torch, mel_postnet_torch = synthesis(
        model, words, alpha=alpha)

    if not os.path.exists("results"):
        os.mkdir("results")

    # Do not run any vocoder here; the generated mel file is saved and
    # passed to the SqueezeWave vocoder instead.
    """
    Audio.tools.inv_mel_spec(mel_postnet, os.path.join(
        "results", words + "_" + str(num) + "_griffin_lim.wav"))

    wave_glow = utils.get_WaveGlow()
    waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(
        "results", words + "_" + str(num) + "_waveglow.wav"))

    tacotron2 = utils.get_Tacotron2()
    mel_tac2, _, _ = utils.load_data_from_tacotron2(words, tacotron2)
    waveglow.inference.inference(torch.stack([torch.from_numpy(
        mel_tac2).cuda()]), wave_glow, os.path.join("results", "tacotron2.wav"))

    utils.plot_data([mel.numpy(), mel_postnet.numpy(), mel_tac2])
    """

    # melspec = torch.squeeze(mel_postnet_torch, 0)
    torch.save(mel_postnet_torch, "../SqueezeWave/mel_spectrograms/test.pt")

    end = time.time()
    print("MEL Calculation:")
    print(end - start)
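
On the SqueezeWave side, the saved mel_spectrograms/test.pt tensor is what the vocoder consumes. The sketch below outlines that handoff under stated assumptions: it uses the infer() method and checkpoint layout of the WaveGlow/SqueezeWave code, and the checkpoint name, sigma, sampling rate and output file name are illustrative.

import numpy as np
import torch
from scipy.io.wavfile import write

MAX_WAV_VALUE = 32768.0
sampling_rate = 22050  # assumption: LJSpeech-style configuration

squeezewave = torch.load("squeezewave.pt")["model"]  # illustrative checkpoint name/layout
squeezewave.eval()

# The tensor saved by the script above, shape (1, n_mel_channels, frames),
# seen from the SqueezeWave working directory.
mel = torch.load("mel_spectrograms/test.pt").to(next(squeezewave.parameters()).device)

with torch.no_grad():
    audio = squeezewave.infer(mel, sigma=0.6)
    # Optionally apply the Denoiser shown above before scaling.

# Scale to 16-bit PCM and write the waveform.
audio = audio.squeeze().cpu().numpy()
audio = (audio * MAX_WAV_VALUE).astype(np.int16)
write("test_squeezewave.wav", sampling_rate, audio)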