diff --git a/data_utils.py b/data_utils.py
index 786a282..09f42ac 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -1,4 +1,5 @@
 import random
+import numpy as np
 import torch
 import torch.utils.data
 
@@ -19,6 +20,7 @@ class TextMelLoader(torch.utils.data.Dataset):
         self.text_cleaners = hparams.text_cleaners
         self.max_wav_value = hparams.max_wav_value
         self.sampling_rate = hparams.sampling_rate
+        self.load_mel_from_disk = hparams.load_mel_from_disk
         self.stft = layers.TacotronSTFT(
             hparams.filter_length, hparams.hop_length, hparams.win_length,
             hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
@@ -35,12 +37,19 @@ class TextMelLoader(torch.utils.data.Dataset):
         return (text, mel)
 
     def get_mel(self, filename):
-        audio = load_wav_to_torch(filename, self.sampling_rate)
-        audio_norm = audio / self.max_wav_value
-        audio_norm = audio_norm.unsqueeze(0)
-        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
-        melspec = self.stft.mel_spectrogram(audio_norm)
-        melspec = torch.squeeze(melspec, 0)
+        if not self.load_mel_from_disk:
+            audio = load_wav_to_torch(filename, self.sampling_rate)
+            audio_norm = audio / self.max_wav_value
+            audio_norm = audio_norm.unsqueeze(0)
+            audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
+            melspec = self.stft.mel_spectrogram(audio_norm)
+            melspec = torch.squeeze(melspec, 0)
+        else:
+            melspec = torch.from_numpy(np.load(filename))
+            assert melspec.size(0) == self.stft.n_mel_channels, (
+                'Mel dimension mismatch: given {}, expected {}'.format(
+                    melspec.size(0), self.stft.n_mel_channels))
+
         return melspec
 
     def get_text(self, text):
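For context (not part of the patch): a minimal sketch of how the .npy files consumed by the new `load_mel_from_disk` branch could be produced, reusing the same normalization and STFT path that `get_mel` applies to raw audio. The import path for `load_wav_to_torch`, the hyperparameter values, and the helper name `save_mel_to_disk` are assumptions for illustration, not something this diff defines.

```python
import numpy as np
import torch

import layers                          # assumed: same module TextMelLoader uses
from utils import load_wav_to_torch    # assumed import path; the diff only shows its usage


def save_mel_to_disk(wav_path, npy_path, stft, max_wav_value=32768.0,
                     sampling_rate=22050):
    """Compute a mel spectrogram the same way get_mel does and save it as .npy."""
    audio = load_wav_to_torch(wav_path, sampling_rate)
    audio_norm = (audio / max_wav_value).unsqueeze(0)
    with torch.no_grad():
        melspec = stft.mel_spectrogram(audio_norm)
    # Saved shape is (n_mel_channels, n_frames), which satisfies the
    # dimension assert in the new load_mel_from_disk branch.
    np.save(npy_path, melspec.squeeze(0).cpu().numpy())


# Illustrative hyperparameter values; the argument order follows the diff's
# TacotronSTFT call, with mel_fmax assumed to be the final constructor argument.
stft = layers.TacotronSTFT(1024, 256, 1024, 80, 22050, 0.0, 8000.0)
save_mel_to_disk('LJ001-0001.wav', 'LJ001-0001.npy', stft)
```

With files prepared this way, pointing the filelist entries at the .npy paths and setting `hparams.load_mel_from_disk = True` lets training skip the per-item STFT.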