|
@ -1,4 +1,5 @@ |
|
|
import random |
|
|
import random |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch |
|
|
import torch.utils.data |
|
|
import torch.utils.data |
|
|
|
|
|
|
|
@ -19,6 +20,7 @@ class TextMelLoader(torch.utils.data.Dataset): |
|
|
self.text_cleaners = hparams.text_cleaners |
|
|
self.text_cleaners = hparams.text_cleaners |
|
|
self.max_wav_value = hparams.max_wav_value |
|
|
self.max_wav_value = hparams.max_wav_value |
|
|
self.sampling_rate = hparams.sampling_rate |
|
|
self.sampling_rate = hparams.sampling_rate |
|
|
|
|
|
self.load_mel_from_disk = hparams.load_mel_from_disk |
|
|
self.stft = layers.TacotronSTFT( |
|
|
self.stft = layers.TacotronSTFT( |
|
|
hparams.filter_length, hparams.hop_length, hparams.win_length, |
|
|
hparams.filter_length, hparams.hop_length, hparams.win_length, |
|
|
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, |
|
|
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, |
|
@ -35,12 +37,19 @@ class TextMelLoader(torch.utils.data.Dataset): |
|
|
return (text, mel) |
|
|
return (text, mel) |
|
|
|
|
|
|
|
|
def get_mel(self, filename): |
|
|
def get_mel(self, filename): |
|
|
audio = load_wav_to_torch(filename, self.sampling_rate) |
|
|
|
|
|
audio_norm = audio / self.max_wav_value |
|
|
|
|
|
audio_norm = audio_norm.unsqueeze(0) |
|
|
|
|
|
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) |
|
|
|
|
|
melspec = self.stft.mel_spectrogram(audio_norm) |
|
|
|
|
|
melspec = torch.squeeze(melspec, 0) |
|
|
|
|
|
|
|
|
if not self.load_mel_from_disk: |
|
|
|
|
|
audio = load_wav_to_torch(filename, self.sampling_rate) |
|
|
|
|
|
audio_norm = audio / self.max_wav_value |
|
|
|
|
|
audio_norm = audio_norm.unsqueeze(0) |
|
|
|
|
|
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) |
|
|
|
|
|
melspec = self.stft.mel_spectrogram(audio_norm) |
|
|
|
|
|
melspec = torch.squeeze(melspec, 0) |
|
|
|
|
|
else: |
|
|
|
|
|
melspec = torch.from_numpy(np.load(filename)) |
|
|
|
|
|
assert melspec.size(0) == self.stft.n_mel_channels, ( |
|
|
|
|
|
'Mel dimension mismatch: given {}, expected {}'.format( |
|
|
|
|
|
melspec.size(0), self.stft.n_mel_channels)) |
|
|
|
|
|
|
|
|
return melspec |
|
|
return melspec |
|
|
|
|
|
|
|
|
def get_text(self, text): |
|
|
def get_text(self, text): |
|
|