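# Training script for FastSpeech: builds the model, restores a checkpoint if
# one exists, runs the training loop, logs the individual losses to text files,
# and periodically saves checkpoints.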
import torch
import torch.nn as nn

from multiprocessing import cpu_count
import numpy as np
import argparse
import os
import time
import math

from fastspeech import FastSpeech
from loss import FastSpeechLoss
from dataset import FastSpeechDataset, collate_fn, DataLoader
from optimizer import ScheduledOptim
import hparams as hp
import utils

def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

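    # nn.DataParallel replicates the model across all visible GPUs and scatters
    # each batch among them; with a single device it simply wraps the model.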
    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(), betas=(0.9, 0.98), eps=1e-9)
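    # ScheduledOptim is assumed to implement the Transformer-style warm-up
    # schedule: the learning rate ramps up for hp.n_warm_up_step steps and
    # decays afterwards; restore_step lets the schedule resume mid-run.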
    scheduled_optim = ScheduledOptim(optimizer,
                                     hp.d_model,
                                     hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if it exists
    try:
        checkpoint = torch.load(os.path.join(
            hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Timing bookkeeping for the progress/ETA report
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        # Get Training Loader
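        # Each DataLoader item holds hp.batch_size**2 samples; collate_fn is
        # expected to split it into hp.batch_size real batches, which is why
        # the training loop below iterates over "batchs" a second time.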
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=D)

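                # FastSpeechLoss is assumed to compare the decoder output and
                # the PostNet-refined output against mel_target, and the
                # duration predictor output against the target durations D.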
                # Calculate loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output,
                    mel_postnet_output,
                    duration_predictor_output,
                    mel_target,
                    D)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger: append each loss to its own text file
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                for filename, value in [("total_loss.txt", t_l),
                                        ("mel_loss.txt", m_l),
                                        ("mel_postnet_loss.txt", m_p_l),
                                        ("duration_loss.txt", d_l)]:
                    with open(os.path.join(hp.logger_path, filename), "a") as f:
                        f.write(str(value) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(
                    model.parameters(), hp.grad_clip_thresh)

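                # With --frozen_learning_rate the learning rate is assumed to
                # stay fixed at args.learning_rate_frozen; otherwise
                # ScheduledOptim applies its warm-up/decay schedule.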
                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start), (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(hp.logger_path, "logger.txt"), "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {'model': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

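                # Record this step's duration for the ETA estimate; once
                # hp.clear_Time samples accumulate, collapse them to their mean
                # so the history stays small.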
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    Time = np.array([np.mean(Time)])


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_step', type=int, default=0)
    # argparse's type=bool treats any non-empty string as True, so use a
    # store_true flag instead: pass --frozen_learning_rate to enable it.
    parser.add_argument('--frozen_learning_rate', action='store_true')
    parser.add_argument("--learning_rate_frozen", type=float, default=1e-3)
    args = parser.parse_args()

    main(args)