import torch
import torch.nn as nn
from multiprocessing import cpu_count
import numpy as np
import argparse
import os
import time
import math

from fastspeech import FastSpeech
from loss import FastSpeechLoss
from dataset import FastSpeechDataset, collate_fn, DataLoader
from optimizer import ScheduledOptim
import hparams as hp
import utils


def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    scheduled_optim = ScheduledOptim(
        optimizer, hp.d_model, hp.n_warm_up_step, args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if it exists
    try:
        checkpoint = torch.load(os.path.join(
            hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Per-step timing buffer, used to estimate the remaining training time
    Time = np.array([])
    Start = time.perf_counter()  # time.clock() was removed in Python 3.8

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        # Get Training Loader. collate_fn is expected to split each oversized
        # batch into hp.batch_size sub-batches, hence batch_size=hp.batch_size**2
        # and the nested loop below.
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=D)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output,
                    mel_postnet_output,
                    duration_predictor_output,
                    mel_target,
                    D)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                with open(os.path.join(hp.logger_path, "total_loss.txt"), "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")

                with open(os.path.join(hp.logger_path, "mel_loss.txt"), "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")

                with open(os.path.join(hp.logger_path, "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")

                with open(os.path.join(hp.logger_path, "duration_loss.txt"), "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")

                # Backward
                total_loss.backward()

                # Clip gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(
                    model.parameters(), hp.grad_clip_thresh)
                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start), (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(hp.logger_path, "logger.txt"), "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save({'model': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               os.path.join(hp.checkpoint_path,
                                            'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing buffer to its mean so it stays small
                    temp_value = np.mean(Time)
                    Time = np.delete(
                        Time, [i for i in range(len(Time))], axis=None)
                    Time = np.append(Time, temp_value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_step', type=int, default=0)
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so expose this option as a flag instead
    parser.add_argument('--frozen_learning_rate',
                        action='store_true', default=False)
    parser.add_argument("--learning_rate_frozen", type=float, default=1e-3)
    args = parser.parse_args()

    main(args)
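
# Usage sketch (illustrative, not from the source): assuming this script is
# saved as train.py and hparams.py defines the attributes referenced above
# (checkpoint_path, logger_path, batch_size, epochs, log_step, save_step,
# d_model, n_warm_up_step, grad_clip_thresh, clear_Time), training might be
# launched as:
#
#   # start a fresh run
#   python train.py
#
#   # resume from the checkpoint saved at some step (e.g. 20000) and train
#   # with a fixed learning rate instead of the warm-up schedule
#   python train.py --restore_step 20000 --frozen_learning_rate --learning_rate_frozen 1e-3
#
# The file name and the step value are hypothetical examples.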