Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

203 lines
8.6 KiB

  1. # We retain the copyright notice by NVIDIA from the original code. However,
  2. # we reserve our rights on the modifications based on the original code.
  3. #
  4. # *****************************************************************************
  5. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  6. #
  7. # Redistribution and use in source and binary forms, with or without
  8. # modification, are permitted provided that the following conditions are met:
  9. # * Redistributions of source code must retain the above copyright
  10. # notice, this list of conditions and the following disclaimer.
  11. # * Redistributions in binary form must reproduce the above copyright
  12. # notice, this list of conditions and the following disclaimer in the
  13. # documentation and/or other materials provided with the distribution.
  14. # * Neither the name of the NVIDIA CORPORATION nor the
  15. # names of its contributors may be used to endorse or promote products
  16. # derived from this software without specific prior written permission.
  17. #
  18. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  19. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  20. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  21. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  22. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  23. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  24. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  25. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  26. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  27. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  28. #
  29. # *****************************************************************************
  30. import argparse
  31. import json
  32. import os
  33. import torch
  34. #=====START: ADDED FOR DISTRIBUTED======
  35. from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor
  36. from torch.utils.data.distributed import DistributedSampler
  37. #=====END: ADDED FOR DISTRIBUTED======
  38. from torch.utils.data import DataLoader
  39. from glow import SqueezeWave, SqueezeWaveLoss
  40. from mel2samp import Mel2Samp
  41. def load_checkpoint(
  42. checkpoint_path, model, optimizer, n_flows, n_early_every,
  43. n_early_size, n_mel_channels, n_audio_channel, WN_config):
  44. assert os.path.isfile(checkpoint_path)
  45. checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
  46. iteration = checkpoint_dict['iteration']
  47. #iteration = 1
  48. optimizer.load_state_dict(checkpoint_dict['optimizer'])
  49. model_for_loading = checkpoint_dict['model']
  50. state_dict = model_for_loading.state_dict()
  51. model.load_state_dict(state_dict, strict = False)
  52. print("Loaded checkpoint '{}' (iteration {})" .format(checkpoint_path, iteration))
  53. return model, optimizer, iteration
  54. def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
  55. print("Saving model and optimizer state at iteration {} to {}".format(
  56. iteration, filepath))
  57. model_for_saving = SqueezeWave(**squeezewave_config).cuda()
  58. model_for_saving.load_state_dict(model.state_dict())
  59. torch.save({'model': model_for_saving,
  60. 'iteration': iteration,
  61. 'optimizer': optimizer.state_dict(),
  62. 'learning_rate': learning_rate}, filepath)
  63. def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
  64. sigma, iters_per_checkpoint, batch_size, seed, fp16_run,
  65. checkpoint_path, with_tensorboard):
  66. torch.manual_seed(seed)
  67. torch.cuda.manual_seed(seed)
  68. #=====START: ADDED FOR DISTRIBUTED======
  69. if num_gpus > 1:
  70. init_distributed(rank, num_gpus, group_name, **dist_config)
  71. #=====END: ADDED FOR DISTRIBUTED======
  72. criterion = SqueezeWaveLoss(sigma)
  73. model = SqueezeWave(**squeezewave_config).cuda()
  74. print(model)
  75. pytorch_total_params = sum(p.numel() for p in model.parameters())
  76. pytorch_total_params_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
  77. print("param", pytorch_total_params)
  78. print("param trainable", pytorch_total_params_train)
  79. #=====START: ADDED FOR DISTRIBUTED======
  80. if num_gpus > 1:
  81. model = apply_gradient_allreduce(model)
  82. #=====END: ADDED FOR DISTRIBUTED======
  83. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  84. if fp16_run:
  85. from apex import amp
  86. model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
  87. # Load checkpoint if one exists
  88. iteration = 0
  89. if checkpoint_path != "":
  90. model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
  91. optimizer, **squeezewave_config)
  92. iteration += 1 # next iteration is iteration + 1
  93. n_audio_channel = squeezewave_config["n_audio_channel"]
  94. trainset = Mel2Samp(n_audio_channel, **data_config)
  95. # =====START: ADDED FOR DISTRIBUTED======
  96. train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
  97. # =====END: ADDED FOR DISTRIBUTED======
  98. train_loader = DataLoader(trainset, num_workers=0, shuffle=False,
  99. sampler=train_sampler,
  100. batch_size=batch_size,
  101. pin_memory=False,
  102. drop_last=True)
  103. # Get shared output_directory ready
  104. if rank == 0:
  105. if not os.path.isdir(output_directory):
  106. os.makedirs(output_directory)
  107. os.chmod(output_directory, 0o775)
  108. print("output directory", output_directory)
  109. if with_tensorboard and rank == 0:
  110. from tensorboardX import SummaryWriter
  111. logger = SummaryWriter(os.path.join(output_directory, 'logs'))
  112. model.train()
  113. epoch_offset = max(0, int(iteration / len(train_loader)))
  114. # ================ MAIN TRAINNIG LOOP! ===================
  115. for epoch in range(epoch_offset, epochs):
  116. print("Epoch: {}".format(epoch))
  117. for i, batch in enumerate(train_loader):
  118. model.zero_grad()
  119. mel, audio = batch
  120. mel = torch.autograd.Variable(mel.cuda())
  121. audio = torch.autograd.Variable(audio.cuda())
  122. outputs = model((mel, audio))
  123. loss = criterion(outputs)
  124. if num_gpus > 1:
  125. reduced_loss = reduce_tensor(loss.data, num_gpus).item()
  126. else:
  127. reduced_loss = loss.item()
  128. if fp16_run:
  129. with amp.scale_loss(loss, optimizer) as scaled_loss:
  130. scaled_loss.backward()
  131. else:
  132. loss.backward()
  133. optimizer.step()
  134. print("{}:\t{:.9f}\t".format(iteration, reduced_loss))
  135. if with_tensorboard and rank == 0:
  136. logger.add_scalar('training_loss', reduced_loss, i + len(train_loader) * epoch)
  137. if (iteration % iters_per_checkpoint == 0):
  138. if rank == 0:
  139. checkpoint_path = "{}/SqueezeWave_{}".format(
  140. output_directory, iteration)
  141. save_checkpoint(model, optimizer, learning_rate, iteration,
  142. checkpoint_path)
  143. iteration += 1
  144. if __name__ == "__main__":
  145. parser = argparse.ArgumentParser()
  146. parser.add_argument('-c', '--config', type=str,
  147. help='JSON file for configuration')
  148. parser.add_argument('-r', '--rank', type=int, default=0,
  149. help='rank of process for distributed')
  150. parser.add_argument('-g', '--group_name', type=str, default='',
  151. help='name of group for distributed')
  152. args = parser.parse_args()
  153. # Parse configs. Globals nicer in this case
  154. with open(args.config) as f:
  155. data = f.read()
  156. config = json.loads(data)
  157. train_config = config["train_config"]
  158. global data_config
  159. data_config = config["data_config"]
  160. global dist_config
  161. dist_config = config["dist_config"]
  162. global squeezewave_config
  163. squeezewave_config = config["squeezewave_config"]
  164. num_gpus = torch.cuda.device_count()
  165. if num_gpus > 1:
  166. if args.group_name == '':
  167. print("WARNING: Multiple GPUs detected but no distributed group set")
  168. print("Only running 1 GPU. Use distributed.py for multiple GPUs")
  169. num_gpus = 1
  170. if num_gpus == 1 and args.rank != 0:
  171. raise Exception("Doing single GPU training on rank > 0")
  172. torch.backends.cudnn.enabled = True
  173. torch.backends.cudnn.benchmark = False
  174. train(num_gpus, args.rank, args.group_name, **train_config)