diff --git a/README.md b/README.md index 5353ed6..193c152 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ Distributed and FP16 support relies on work by Christian Sarofeen and NVIDIA's 2. Clone this repo: `git clone https://github.com/NVIDIA/tacotron2.git` 3. CD into this repo: `cd tacotron2` 4. Update .wav paths: `sed -i -- 's,DUMMY,ljs_dataset_folder/wavs,g' filelists/*.txt` + - Alternatively, set `load_mel_from_disk=True` in `hparams.py` and update mel-spectrogram paths 5. Install [pytorch 0.4](https://github.com/pytorch/pytorch) 6. Install python requirements or build docker image - Install python requirements: `pip install -r requirements.txt` diff --git a/data_utils.py b/data_utils.py index 786a282..09f42ac 100644 --- a/data_utils.py +++ b/data_utils.py @@ -1,4 +1,5 @@ import random +import numpy as np import torch import torch.utils.data @@ -19,6 +20,7 @@ class TextMelLoader(torch.utils.data.Dataset): self.text_cleaners = hparams.text_cleaners self.max_wav_value = hparams.max_wav_value self.sampling_rate = hparams.sampling_rate + self.load_mel_from_disk = hparams.load_mel_from_disk self.stft = layers.TacotronSTFT( hparams.filter_length, hparams.hop_length, hparams.win_length, hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, @@ -35,12 +37,19 @@ class TextMelLoader(torch.utils.data.Dataset): return (text, mel) def get_mel(self, filename): - audio = load_wav_to_torch(filename, self.sampling_rate) - audio_norm = audio / self.max_wav_value - audio_norm = audio_norm.unsqueeze(0) - audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) - melspec = self.stft.mel_spectrogram(audio_norm) - melspec = torch.squeeze(melspec, 0) + if not self.load_mel_from_disk: + audio = load_wav_to_torch(filename, self.sampling_rate) + audio_norm = audio / self.max_wav_value + audio_norm = audio_norm.unsqueeze(0) + audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False) + melspec = self.stft.mel_spectrogram(audio_norm) + melspec = torch.squeeze(melspec, 0) + else: + melspec = torch.from_numpy(np.load(filename)) + assert melspec.size(0) == self.stft.n_mel_channels, ( + 'Mel dimension mismatch: given {}, expected {}'.format( + melspec.size(0), self.stft.n_mel_channels)) + return melspec def get_text(self, text): diff --git a/hparams.py b/hparams.py index c5379cb..f692fa0 100644 --- a/hparams.py +++ b/hparams.py @@ -23,6 +23,7 @@ def create_hparams(hparams_string=None, verbose=False): ################################ # Data Parameters # ################################ + load_mel_from_disk=False, training_files='filelists/ljs_audio_text_train_filelist.txt', validation_files='filelists/ljs_audio_text_val_filelist.txt', text_cleaners=['english_cleaners'],