A fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.

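The file below is the fork's FastSpeech training script. It defines the model, restores a checkpoint when one is available, and runs the training loop with scheduled learning-rate warm-up, gradient clipping, per-step loss logging, and periodic checkpointing.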

import torch
import torch.nn as nn
from multiprocessing import cpu_count
import numpy as np
import argparse
import os
import time
import math

from fastspeech import FastSpeech
from loss import FastSpeechLoss
from dataset import FastSpeechDataset, collate_fn, DataLoader
from optimizer import ScheduledOptim
import hparams as hp
import utils

def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer,
                                     hp.d_model,
                                     hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint for the requested step if one exists,
    # otherwise start a new training run
    try:
        checkpoint = torch.load(os.path.join(
            hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        # a bare `except:` would also swallow KeyboardInterrupt
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Timing bookkeeping
    Time = np.array([])
    Start = time.perf_counter()  # time.clock() was removed in Python 3.8

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader. Each element it yields holds batch_size**2
        # samples, which collate_fn splits into batch_size real mini-batches;
        # the inner loop below iterates over those.
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=D)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output,
                    mel_postnet_output,
                    duration_predictor_output,
                    mel_target,
                    D)
                total_loss = mel_loss + mel_postnet_loss + duration_loss
                # Logger: append each loss term to its own file under the
                # logger directory created above
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                with open(os.path.join(hp.logger_path, "total_loss.txt"), "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(hp.logger_path, "mel_loss.txt"), "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(hp.logger_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(hp.logger_path, "duration_loss.txt"), "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                # Backward
                total_loss.backward()

                # Clip gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(
                    model.parameters(), hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()
                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start), (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(hp.logger_path, "logger.txt"), "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")
                if current_step % hp.save_step == 0:
                    torch.save({'model': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               os.path.join(hp.checkpoint_path,
                                            'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)

                # Keep the per-step timing history bounded: once it reaches
                # hp.clear_Time entries, collapse it to its running mean
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(
                        Time, [i for i in range(len(Time))], axis=None)
                    Time = np.append(Time, temp_value)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_step', type=int, default=0)
    # `type=bool` would parse any non-empty string (even "False") as True,
    # so expose this as a store_true flag instead
    parser.add_argument('--frozen_learning_rate', action='store_true')
    parser.add_argument("--learning_rate_frozen", type=float, default=1e-3)
    args = parser.parse_args()

    main(args)
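A minimal sketch of typical invocations, assuming hparams.py and the preprocessed dataset are set up as in the upstream FastSpeech repo (the flag names come from the argparse block above, and the checkpoint filename from the save logic in the loop):

    python train.py                        # fresh run with the warm-up schedule
    python train.py --restore_step 10000   # resume from checkpoint_10000.pth.tar
    python train.py --frozen_learning_rate --learning_rate_frozen 1e-3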