Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

147 lines
5.9 KiB

  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************\
  27. # from tacotron2.layers import TacotronSTFT
  28. import os
  29. import random
  30. import argparse
  31. import json
  32. import torch
  33. import torch.utils.data
  34. import sys
  35. from scipy.io.wavfile import read
  36. # We're using the audio processing from TacoTron2 to make sure it matches
  37. sys.path.insert(0, 'tacotron2')
  38. MAX_WAV_VALUE = 32768.0
  39. def files_to_list(filename):
  40. """
  41. Takes a text file of filenames and makes a list of filenames
  42. """
  43. with open(filename, encoding='utf-8') as f:
  44. files = f.readlines()
  45. files = [f.rstrip() for f in files]
  46. return files
  47. # def load_wav_to_torch(full_path):
  48. # """
  49. # Loads wavdata into torch array
  50. # """
  51. # sampling_rate, data = read(full_path)
  52. # return torch.from_numpy(data).float(), sampling_rate
  53. # class Mel2Samp(torch.utils.data.Dataset):
  54. # """
  55. # This is the main class that calculates the spectrogram and returns the
  56. # spectrogram, audio pair.
  57. # """
  58. # def __init__(self, training_files, segment_length, filter_length,
  59. # hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
  60. # self.audio_files = files_to_list(training_files)
  61. # random.seed(1234)
  62. # random.shuffle(self.audio_files)
  63. # self.stft = TacotronSTFT(filter_length=filter_length,
  64. # hop_length=hop_length,
  65. # win_length=win_length,
  66. # sampling_rate=sampling_rate,
  67. # mel_fmin=mel_fmin, mel_fmax=mel_fmax)
  68. # self.segment_length = segment_length
  69. # self.sampling_rate = sampling_rate
  70. # def get_mel(self, audio):
  71. # audio_norm = audio / MAX_WAV_VALUE
  72. # audio_norm = audio_norm.unsqueeze(0)
  73. # audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
  74. # melspec = self.stft.mel_spectrogram(audio_norm)
  75. # melspec = torch.squeeze(melspec, 0)
  76. # return melspec
  77. # def __getitem__(self, index):
  78. # # Read audio
  79. # filename = self.audio_files[index]
  80. # audio, sampling_rate = load_wav_to_torch(filename)
  81. # if sampling_rate != self.sampling_rate:
  82. # raise ValueError("{} SR doesn't match target {} SR".format(
  83. # sampling_rate, self.sampling_rate))
  84. # # Take segment
  85. # if audio.size(0) >= self.segment_length:
  86. # max_audio_start = audio.size(0) - self.segment_length
  87. # audio_start = random.randint(0, max_audio_start)
  88. # audio = audio[audio_start:audio_start+self.segment_length]
  89. # else:
  90. # audio = torch.nn.functional.pad(
  91. # audio, (0, self.segment_length - audio.size(0)), 'constant').data
  92. # mel = self.get_mel(audio)
  93. # audio = audio / MAX_WAV_VALUE
  94. # return (mel, audio)
  95. # def __len__(self):
  96. # return len(self.audio_files)
  97. # # ===================================================================
  98. # # Takes directory of clean audio and makes directory of spectrograms
  99. # # Useful for making test sets
  100. # # ===================================================================
  101. # if __name__ == "__main__":
  102. # # Get defaults so it can work with no Sacred
  103. # parser = argparse.ArgumentParser()
  104. # parser.add_argument('-f', "--filelist_path", required=True)
  105. # parser.add_argument('-c', '--config', type=str,
  106. # help='JSON file for configuration')
  107. # parser.add_argument('-o', '--output_dir', type=str,
  108. # help='Output directory')
  109. # args = parser.parse_args()
  110. # with open(args.config) as f:
  111. # data = f.read()
  112. # data_config = json.loads(data)["data_config"]
  113. # mel2samp = Mel2Samp(**data_config)
  114. # filepaths = files_to_list(args.filelist_path)
  115. # # Make directory if it doesn't exist
  116. # if not os.path.isdir(args.output_dir):
  117. # os.makedirs(args.output_dir)
  118. # os.chmod(args.output_dir, 0o775)
  119. # for filepath in filepaths:
  120. # audio, sr = load_wav_to_torch(filepath)
  121. # melspectrogram = mel2samp.get_mel(audio)
  122. # filename = os.path.basename(filepath)
  123. # new_filepath = args.output_dir + '/' + filename + '.pt'
  124. # print(new_filepath)
  125. # torch.save(melspectrogram, new_filepath)