Browse Source

data_utils.py: adding support for loading mel from disk

master
Rafael Valle 6 years ago
parent
commit
62d2c8b957
1 changed files with 15 additions and 6 deletions
  1. +15
    -6
      data_utils.py

+ 15
- 6
data_utils.py View File

@ -1,4 +1,5 @@
import random import random
import numpy as np
import torch import torch
import torch.utils.data import torch.utils.data
@ -19,6 +20,7 @@ class TextMelLoader(torch.utils.data.Dataset):
self.text_cleaners = hparams.text_cleaners self.text_cleaners = hparams.text_cleaners
self.max_wav_value = hparams.max_wav_value self.max_wav_value = hparams.max_wav_value
self.sampling_rate = hparams.sampling_rate self.sampling_rate = hparams.sampling_rate
self.load_mel_from_disk = hparams.load_mel_from_disk
self.stft = layers.TacotronSTFT( self.stft = layers.TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length, hparams.filter_length, hparams.hop_length, hparams.win_length,
hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin, hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
@ -35,12 +37,19 @@ class TextMelLoader(torch.utils.data.Dataset):
return (text, mel) return (text, mel)
def get_mel(self, filename): def get_mel(self, filename):
audio = load_wav_to_torch(filename, self.sampling_rate)
audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
melspec = self.stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
if not self.load_mel_from_disk:
audio = load_wav_to_torch(filename, self.sampling_rate)
audio_norm = audio / self.max_wav_value
audio_norm = audio_norm.unsqueeze(0)
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
melspec = self.stft.mel_spectrogram(audio_norm)
melspec = torch.squeeze(melspec, 0)
else:
melspec = torch.from_numpy(np.load(filename))
assert melspec.size(0) == self.stft.n_mel_channels, (
'Mel dimension mismatch: given {}, expected {}'.format(
melspec.size(0), self.stft.n_mel_channels))
return melspec return melspec
def get_text(self, text): def get_text(self, text):

Loading…
Cancel
Save