Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

150 lines
6.1 KiB

  1. # We retain the copyright notice by NVIDIA from the original code. However, we
  2. # we reserve our rights on the modifications based on the original code.
  3. #
  4. # *****************************************************************************
  5. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  6. #
  7. # Redistribution and use in source and binary forms, with or without
  8. # modification, are permitted provided that the following conditions are met:
  9. # * Redistributions of source code must retain the above copyright
  10. # notice, this list of conditions and the following disclaimer.
  11. # * Redistributions in binary form must reproduce the above copyright
  12. # notice, this list of conditions and the following disclaimer in the
  13. # documentation and/or other materials provided with the distribution.
  14. # * Neither the name of the NVIDIA CORPORATION nor the
  15. # names of its contributors may be used to endorse or promote products
  16. # derived from this software without specific prior written permission.
  17. #
  18. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  19. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  20. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  21. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  22. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  23. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  24. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  25. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  26. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  27. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  28. #
  29. # *****************************************************************************\
  30. import os
  31. import random
  32. import argparse
  33. import json
  34. import torch
  35. import torch.utils.data
  36. import sys
  37. from scipy.io.wavfile import read
  38. # We're using the audio processing from TacoTron2 to make sure it matches
  39. from TacotronSTFT import TacotronSTFT
  40. MAX_WAV_VALUE = 32768.0
  41. def files_to_list(filename):
  42. """
  43. Takes a text file of filenames and makes a list of filenames
  44. """
  45. with open(filename, encoding='utf-8') as f:
  46. files = f.readlines()
  47. files = [f.rstrip() for f in files]
  48. return files
  49. def load_wav_to_torch(full_path):
  50. """
  51. Loads wavdata into torch array
  52. """
  53. sampling_rate, data = read(full_path)
  54. return torch.from_numpy(data).float(), sampling_rate
  55. class Mel2Samp(torch.utils.data.Dataset):
  56. """
  57. This is the main class that calculates the spectrogram and returns the
  58. spectrogram, audio pair.
  59. """
  60. def __init__(self, n_audio_channel, training_files, segment_length,
  61. filter_length, hop_length, win_length, sampling_rate, mel_fmin,
  62. mel_fmax):
  63. self.audio_files = files_to_list(training_files)
  64. random.seed(1234)
  65. random.shuffle(self.audio_files)
  66. self.stft = TacotronSTFT(filter_length=filter_length,
  67. hop_length=hop_length,
  68. win_length=win_length,
  69. sampling_rate=sampling_rate,
  70. mel_fmin=mel_fmin, mel_fmax=mel_fmax,
  71. n_group=n_audio_channel)
  72. self.segment_length = segment_length
  73. self.sampling_rate = sampling_rate
  74. def get_mel(self, audio):
  75. audio_norm = audio / MAX_WAV_VALUE
  76. audio_norm = audio_norm.unsqueeze(0)
  77. audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
  78. melspec = self.stft.mel_spectrogram(audio_norm)
  79. melspec = torch.squeeze(melspec, 0)
  80. return melspec
  81. def __getitem__(self, index):
  82. # Read audio
  83. filename = self.audio_files[index]
  84. audio, sampling_rate = load_wav_to_torch(filename)
  85. if sampling_rate != self.sampling_rate:
  86. raise ValueError("{} SR doesn't match target {} SR".format(
  87. sampling_rate, self.sampling_rate))
  88. # Take segment
  89. if audio.size(0) >= self.segment_length:
  90. max_audio_start = audio.size(0) - self.segment_length
  91. audio_start = random.randint(0, max_audio_start)
  92. audio = audio[audio_start:audio_start+self.segment_length]
  93. else:
  94. audio = torch.nn.functional.pad(
  95. audio, (0, self.segment_length - audio.size(0)),
  96. 'constant').data
  97. mel = self.get_mel(audio)
  98. audio = audio / MAX_WAV_VALUE
  99. return (mel, audio)
  100. def __len__(self):
  101. return len(self.audio_files)
  102. # ===================================================================
  103. # Takes directory of clean audio and makes directory of spectrograms
  104. # Useful for making test sets
  105. # ===================================================================
  106. if __name__ == "__main__":
  107. # Get defaults so it can work with no Sacred
  108. parser = argparse.ArgumentParser()
  109. parser.add_argument('-f', "--filelist_path", required=True)
  110. parser.add_argument('-c', '--config', type=str,
  111. help='JSON file for configuration')
  112. parser.add_argument('-o', '--output_dir', type=str,
  113. help='Output directory')
  114. args = parser.parse_args()
  115. with open(args.config) as f:
  116. data = f.read()
  117. config = json.loads(data)
  118. data_config = config["data_config"]
  119. squeezewave_config = config["squeezewave_config"]
  120. mel2samp = Mel2Samp(squeezewave_config['n_audio_channel'], **data_config)
  121. filepaths = files_to_list(args.filelist_path)
  122. # Make directory if it doesn't exist
  123. if not os.path.isdir(args.output_dir):
  124. os.makedirs(args.output_dir)
  125. os.chmod(args.output_dir, 0o775)
  126. for filepath in filepaths:
  127. audio, sr = load_wav_to_torch(filepath)
  128. melspectrogram = mel2samp.get_mel(audio)
  129. filename = os.path.basename(filepath)
  130. new_filepath = args.output_dir + '/' + filename + '.pt'
  131. print(new_filepath)
  132. torch.save(melspectrogram, new_filepath)