Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

74 lines
2.2 KiB

4 years ago
  1. import torch
  2. import torch.nn as nn
  3. import matplotlib
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import time
  7. import os
  8. from fastspeech import FastSpeech
  9. from text import text_to_sequence
  10. import hparams as hp
  11. import utils
  12. import audio as Audio
  13. import glow
  14. import waveglow
  15. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  16. def get_FastSpeech(num):
  17. checkpoint_path = "checkpoint_" + str(num) + ".pth.tar"
  18. model = nn.DataParallel(FastSpeech()).to(device)
  19. model.load_state_dict(torch.load(os.path.join(
  20. hp.checkpoint_path, checkpoint_path))['model'])
  21. model.eval()
  22. return model
  23. def synthesis(model, text, alpha=1.0):
  24. text = np.array(text_to_sequence(text, hp.text_cleaners))
  25. text = np.stack([text])
  26. src_pos = np.array([i+1 for i in range(text.shape[1])])
  27. src_pos = np.stack([src_pos])
  28. with torch.no_grad():
  29. sequence = torch.autograd.Variable(
  30. torch.from_numpy(text)).cuda().long()
  31. src_pos = torch.autograd.Variable(
  32. torch.from_numpy(src_pos)).cuda().long()
  33. mel, mel_postnet = model.module.forward(sequence, src_pos, alpha=alpha)
  34. return mel[0].cpu().transpose(0, 1), \
  35. mel_postnet[0].cpu().transpose(0, 1), \
  36. mel.transpose(1, 2), \
  37. mel_postnet.transpose(1, 2)
  38. if __name__ == "__main__":
  39. # Test
  40. num = 112000
  41. alpha = 1.0
  42. model = get_FastSpeech(num)
  43. words = "Let’s go out to the airport. The plane landed ten minutes ago."
  44. mel, mel_postnet, mel_torch, mel_postnet_torch = synthesis(
  45. model, words, alpha=alpha)
  46. if not os.path.exists("results"):
  47. os.mkdir("results")
  48. Audio.tools.inv_mel_spec(mel_postnet, os.path.join(
  49. "results", words + "_" + str(num) + "_griffin_lim.wav"))
  50. wave_glow = utils.get_WaveGlow()
  51. waveglow.inference.inference(mel_postnet_torch, wave_glow, os.path.join(
  52. "results", words + "_" + str(num) + "_waveglow.wav"))
  53. tacotron2 = utils.get_Tacotron2()
  54. mel_tac2, _, _ = utils.load_data_from_tacotron2(words, tacotron2)
  55. waveglow.inference.inference(torch.stack([torch.from_numpy(
  56. mel_tac2).cuda()]), wave_glow, os.path.join("results", "tacotron2.wav"))
  57. utils.plot_data([mel.numpy(), mel_postnet.numpy(), mel_tac2])