Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.

import torch
import numpy as np
from scipy.io.wavfile import read
from scipy.io.wavfile import write

import audio.stft as stft
import audio.hparams as hparams
from audio.audio_processing import griffin_lim

# Shared STFT/mel transform, configured once from the hyperparameters.
_stft = stft.TacotronSTFT(
    hparams.filter_length, hparams.hop_length, hparams.win_length,
    hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
    hparams.mel_fmax)


def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def get_mel(filename):
    """Load a wav file and return its mel spectrogram (n_mel_channels, n_frames)."""
    audio, sampling_rate = load_wav_to_torch(filename)
    if sampling_rate != _stft.sampling_rate:
        raise ValueError("{} {} SR doesn't match target {} SR".format(
            filename, sampling_rate, _stft.sampling_rate))
    audio_norm = audio / hparams.max_wav_value  # scale int16 range to [-1, 1]
    audio_norm = audio_norm.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
    melspec = _stft.mel_spectrogram(audio_norm)
    melspec = torch.squeeze(melspec, 0)
    # melspec = torch.from_numpy(_normalize(melspec.numpy()))
    return melspec


def get_mel_from_wav(audio):
    """Compute the mel spectrogram of an in-memory 1-D waveform tensor."""
    sampling_rate = hparams.sampling_rate
    if sampling_rate != _stft.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, _stft.sampling_rate))
    audio_norm = audio / hparams.max_wav_value
    audio_norm = audio_norm.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
    melspec = _stft.mel_spectrogram(audio_norm)
    melspec = torch.squeeze(melspec, 0)
    return melspec


def inv_mel_spec(mel, out_filename, griffin_iters=60):
    """Roughly invert a mel spectrogram to audio with Griffin-Lim and write a wav file."""
    mel = torch.stack([mel])
    # mel = torch.stack([torch.from_numpy(_denormalize(mel.numpy()))])
    mel_decompress = _stft.spectral_de_normalize(mel)
    mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
    # Project the mel spectrogram back onto the linear-frequency axis by
    # multiplying with the mel filterbank, then rescale before Griffin-Lim.
    spec_from_mel_scaling = 1000
    spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
    spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
    spec_from_mel = spec_from_mel * spec_from_mel_scaling
    audio = griffin_lim(torch.autograd.Variable(
        spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters)
    audio = audio.squeeze()
    audio = audio.cpu().numpy()
    audio_path = out_filename
    write(audio_path, hparams.sampling_rate, audio)
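
For reference, a minimal usage sketch of the two entry points above. The import path audio.tools and the wav file names are assumptions chosen for illustration, not taken from this repository; the input wav must already be at hparams.sampling_rate.

    # Minimal usage sketch (assumed module path and placeholder file names).
    from audio.tools import get_mel, inv_mel_spec

    mel = get_mel("sample.wav")              # (n_mel_channels, n_frames) mel spectrogram
    inv_mel_spec(mel, "sample_griffin.wav")  # crude Griffin-Lim reconstruction for a listening check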