Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.
import torch
import numpy as np
from scipy.io.wavfile import read
from scipy.io.wavfile import write

import audio.stft as stft
import audio.hparams as hparams
from audio.audio_processing import griffin_lim

_stft = stft.TacotronSTFT(
    hparams.filter_length, hparams.hop_length, hparams.win_length,
    hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
    hparams.mel_fmax)
def load_wav_to_torch(full_path):
    """Read a wav file and return its samples as a FloatTensor plus the sampling rate."""
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def get_mel(filename):
    """Load a wav file and convert it to a mel spectrogram."""
    audio, sampling_rate = load_wav_to_torch(filename)
    if sampling_rate != _stft.sampling_rate:
        raise ValueError("{} {} SR doesn't match target {} SR".format(
            filename, sampling_rate, _stft.sampling_rate))
    audio_norm = audio / hparams.max_wav_value
    audio_norm = audio_norm.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
    melspec = _stft.mel_spectrogram(audio_norm)
    melspec = torch.squeeze(melspec, 0)
    # melspec = torch.from_numpy(_normalize(melspec.numpy()))
    return melspec
def get_mel_from_wav(audio):
    """Convert an in-memory waveform tensor to a mel spectrogram."""
    sampling_rate = hparams.sampling_rate
    if sampling_rate != _stft.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, _stft.sampling_rate))
    audio_norm = audio / hparams.max_wav_value
    audio_norm = audio_norm.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
    melspec = _stft.mel_spectrogram(audio_norm)
    melspec = torch.squeeze(melspec, 0)
    return melspec
def inv_mel_spec(mel, out_filename, griffin_iters=60):
    """Invert a mel spectrogram to audio with Griffin-Lim and write it to a wav file."""
    mel = torch.stack([mel])
    # mel = torch.stack([torch.from_numpy(_denormalize(mel.numpy()))])
    mel_decompress = _stft.spectral_de_normalize(mel)
    mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
    spec_from_mel_scaling = 1000
    # Project the mel spectrogram back onto the linear-frequency STFT basis
    spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
    spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
    spec_from_mel = spec_from_mel * spec_from_mel_scaling
    # Estimate the phase and reconstruct the waveform with Griffin-Lim
    audio = griffin_lim(torch.autograd.Variable(
        spec_from_mel[:, :, :-1]), _stft.stft_fn, griffin_iters)
    audio = audio.squeeze()
    audio = audio.cpu().numpy()
    audio_path = out_filename
    write(audio_path, hparams.sampling_rate, audio)
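
# Minimal usage sketch (an assumption, not part of the original file): round-trip a wav
# file through get_mel() and inv_mel_spec(). The file names "example.wav" and
# "example_griffin_lim.wav" are placeholders; the input must match hparams.sampling_rate.
if __name__ == "__main__":
    mel = get_mel("example.wav")  # shape: (n_mel_channels, n_frames)
    print("mel spectrogram shape:", tuple(mel.shape))
    # Reconstruct audio from the mel spectrogram with Griffin-Lim and save it.
    inv_mel_spec(mel, "example_griffin_lim.wav", griffin_iters=60)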