You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
4.3 KiB

  1. import random
  2. import numpy as np
  3. import torch
  4. import torch.utils.data
  5. import layers
  6. from utils import load_wav_to_torch, load_filepaths_and_text
  7. from text import text_to_sequence
  8. class TextMelLoader(torch.utils.data.Dataset):
  9. """
  10. 1) loads audio,text pairs
  11. 2) normalizes text and converts them to sequences of one-hot vectors
  12. 3) computes mel-spectrograms from audio files.
  13. """
  14. def __init__(self, audiopaths_and_text, hparams, shuffle=True):
  15. self.audiopaths_and_text = load_filepaths_and_text(
  16. audiopaths_and_text, hparams.sort_by_length)
  17. self.text_cleaners = hparams.text_cleaners
  18. self.max_wav_value = hparams.max_wav_value
  19. self.sampling_rate = hparams.sampling_rate
  20. self.load_mel_from_disk = hparams.load_mel_from_disk
  21. self.stft = layers.TacotronSTFT(
  22. hparams.filter_length, hparams.hop_length, hparams.win_length,
  23. hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
  24. hparams.mel_fmax)
  25. random.seed(1234)
  26. if shuffle:
  27. random.shuffle(self.audiopaths_and_text)
  28. def get_mel_text_pair(self, audiopath_and_text):
  29. # separate filename and text
  30. audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
  31. text = self.get_text(text)
  32. mel = self.get_mel(audiopath)
  33. return (text, mel)
  34. def get_mel(self, filename):
  35. if not self.load_mel_from_disk:
  36. audio = load_wav_to_torch(filename, self.sampling_rate)
  37. audio_norm = audio / self.max_wav_value
  38. audio_norm = audio_norm.unsqueeze(0)
  39. audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
  40. melspec = self.stft.mel_spectrogram(audio_norm)
  41. melspec = torch.squeeze(melspec, 0)
  42. else:
  43. melspec = torch.from_numpy(np.load(filename))
  44. assert melspec.size(0) == self.stft.n_mel_channels, (
  45. 'Mel dimension mismatch: given {}, expected {}'.format(
  46. melspec.size(0), self.stft.n_mel_channels))
  47. return melspec
  48. def get_text(self, text):
  49. text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
  50. return text_norm
  51. def __getitem__(self, index):
  52. return self.get_mel_text_pair(self.audiopaths_and_text[index])
  53. def __len__(self):
  54. return len(self.audiopaths_and_text)
  55. class TextMelCollate():
  56. """ Zero-pads model inputs and targets based on number of frames per setep
  57. """
  58. def __init__(self, n_frames_per_step):
  59. self.n_frames_per_step = n_frames_per_step
  60. def __call__(self, batch):
  61. """Collate's training batch from normalized text and mel-spectrogram
  62. PARAMS
  63. ------
  64. batch: [text_normalized, mel_normalized]
  65. """
  66. # Right zero-pad all one-hot text sequences to max input length
  67. input_lengths, ids_sorted_decreasing = torch.sort(
  68. torch.LongTensor([len(x[0]) for x in batch]),
  69. dim=0, descending=True)
  70. max_input_len = input_lengths[0]
  71. text_padded = torch.LongTensor(len(batch), max_input_len)
  72. text_padded.zero_()
  73. for i in range(len(ids_sorted_decreasing)):
  74. text = batch[ids_sorted_decreasing[i]][0]
  75. text_padded[i, :text.size(0)] = text
  76. # Right zero-pad mel-spec with extra single zero vector to mark the end
  77. num_mels = batch[0][1].size(0)
  78. max_target_len = max([x[1].size(1) for x in batch]) + 1
  79. if max_target_len % self.n_frames_per_step != 0:
  80. max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
  81. assert max_target_len % self.n_frames_per_step == 0
  82. # include mel padded and gate padded
  83. mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
  84. mel_padded.zero_()
  85. gate_padded = torch.FloatTensor(len(batch), max_target_len)
  86. gate_padded.zero_()
  87. output_lengths = torch.LongTensor(len(batch))
  88. for i in range(len(ids_sorted_decreasing)):
  89. mel = batch[ids_sorted_decreasing[i]][1]
  90. mel_padded[i, :, :mel.size(1)] = mel
  91. gate_padded[i, mel.size(1):] = 1
  92. output_lengths[i] = mel.size(1)
  93. return text_padded, input_lengths, mel_padded, gate_padded, \
  94. output_lengths