You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

111 lines
4.4 KiB

  1. import random
  2. import numpy as np
  3. import torch
  4. import torch.utils.data
  5. import layers
  6. from utils import load_wav_to_torch, load_filepaths_and_text
  7. from text import text_to_sequence
  8. class TextMelLoader(torch.utils.data.Dataset):
  9. """
  10. 1) loads audio,text pairs
  11. 2) normalizes text and converts them to sequences of one-hot vectors
  12. 3) computes mel-spectrograms from audio files.
  13. """
  14. def __init__(self, audiopaths_and_text, hparams):
  15. self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
  16. self.text_cleaners = hparams.text_cleaners
  17. self.max_wav_value = hparams.max_wav_value
  18. self.sampling_rate = hparams.sampling_rate
  19. self.load_mel_from_disk = hparams.load_mel_from_disk
  20. self.stft = layers.TacotronSTFT(
  21. hparams.filter_length, hparams.hop_length, hparams.win_length,
  22. hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
  23. hparams.mel_fmax)
  24. random.seed(1234)
  25. random.shuffle(self.audiopaths_and_text)
  26. def get_mel_text_pair(self, audiopath_and_text):
  27. # separate filename and text
  28. audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
  29. text = self.get_text(text)
  30. mel = self.get_mel(audiopath)
  31. return (text, mel)
  32. def get_mel(self, filename):
  33. if not self.load_mel_from_disk:
  34. audio, sampling_rate = load_wav_to_torch(filename)
  35. if sampling_rate != self.stft.sampling_rate:
  36. raise ValueError("{} {} SR doesn't match target {} SR".format(
  37. sampling_rate, self.stft.sampling_rate))
  38. audio_norm = audio / self.max_wav_value
  39. audio_norm = audio_norm.unsqueeze(0)
  40. audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
  41. melspec = self.stft.mel_spectrogram(audio_norm)
  42. melspec = torch.squeeze(melspec, 0)
  43. else:
  44. melspec = torch.from_numpy(np.load(filename))
  45. assert melspec.size(0) == self.stft.n_mel_channels, (
  46. 'Mel dimension mismatch: given {}, expected {}'.format(
  47. melspec.size(0), self.stft.n_mel_channels))
  48. return melspec
  49. def get_text(self, text):
  50. text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
  51. return text_norm
  52. def __getitem__(self, index):
  53. return self.get_mel_text_pair(self.audiopaths_and_text[index])
  54. def __len__(self):
  55. return len(self.audiopaths_and_text)
  56. class TextMelCollate():
  57. """ Zero-pads model inputs and targets based on number of frames per setep
  58. """
  59. def __init__(self, n_frames_per_step):
  60. self.n_frames_per_step = n_frames_per_step
  61. def __call__(self, batch):
  62. """Collate's training batch from normalized text and mel-spectrogram
  63. PARAMS
  64. ------
  65. batch: [text_normalized, mel_normalized]
  66. """
  67. # Right zero-pad all one-hot text sequences to max input length
  68. input_lengths, ids_sorted_decreasing = torch.sort(
  69. torch.LongTensor([len(x[0]) for x in batch]),
  70. dim=0, descending=True)
  71. max_input_len = input_lengths[0]
  72. text_padded = torch.LongTensor(len(batch), max_input_len)
  73. text_padded.zero_()
  74. for i in range(len(ids_sorted_decreasing)):
  75. text = batch[ids_sorted_decreasing[i]][0]
  76. text_padded[i, :text.size(0)] = text
  77. # Right zero-pad mel-spec
  78. num_mels = batch[0][1].size(0)
  79. max_target_len = max([x[1].size(1) for x in batch])
  80. if max_target_len % self.n_frames_per_step != 0:
  81. max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
  82. assert max_target_len % self.n_frames_per_step == 0
  83. # include mel padded and gate padded
  84. mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
  85. mel_padded.zero_()
  86. gate_padded = torch.FloatTensor(len(batch), max_target_len)
  87. gate_padded.zero_()
  88. output_lengths = torch.LongTensor(len(batch))
  89. for i in range(len(ids_sorted_decreasing)):
  90. mel = batch[ids_sorted_decreasing[i]][1]
  91. mel_padded[i, :, :mel.size(1)] = mel
  92. gate_padded[i, mel.size(1)-1:] = 1
  93. output_lengths[i] = mel.size(1)
  94. return text_padded, input_lengths, mel_padded, gate_padded, \
  95. output_lengths