Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.
import torch
import torch.nn as nn
from multiprocessing import cpu_count
import numpy as np
import argparse
import os
import time
import math
from fastspeech import FastSpeech
from loss import FastSpeechLoss
from dataset import FastSpeechDataset, collate_fn, DataLoader
from optimizer import ScheduledOptim
import hparams as hp
import utils
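

# Training entry point for FastSpeech: builds the model, dataset and loss,
# optionally restores a checkpoint, and runs the training loop with the
# warm-up learning-rate schedule provided by optimizer.ScheduledOptim.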
def main(args):
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer,
                                     hp.d_model,
                                     hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")
    # Load checkpoint if exists
    try:
        checkpoint = torch.load(os.path.join(
            hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(hp.checkpoint_path):
            os.mkdir(hp.checkpoint_path)

    # Init logger
    if not os.path.exists(hp.logger_path):
        os.mkdir(hp.logger_path)

    # Define Some Information
    Time = np.array([])
    # time.clock() was removed in Python 3.8; perf_counter() is the drop-in timer
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        training_loader = DataLoader(dataset,
                                     batch_size=hp.batch_size**2,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     drop_last=True,
                                     num_workers=0)
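
        # Note: each DataLoader batch holds hp.batch_size ** 2 samples, and
        # collate_fn appears to split it into hp.batch_size sub-batches (the
        # "batchs" list iterated below), which is why total_step and
        # current_step multiply len(training_loader) by hp.batch_size.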
        total_step = hp.epochs * len(training_loader) * hp.batch_size

        for i, batchs in enumerate(training_loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                max_mel_len = data_of_batch["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=D)
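
                # The model returns the decoder mel, the postnet-refined mel and the
                # duration predictor output; the first two are compared against
                # mel_target and the last against the target durations D below.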
                # print(mel_target.size())
                # print(mel_output.size())

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output,
                    mel_postnet_output,
                    duration_predictor_output,
                    mel_target,
                    D)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                with open(os.path.join("logger", "total_loss.txt"), "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join("logger", "mel_loss.txt"), "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join("logger", "mel_postnet_loss.txt"), "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join("logger", "duration_loss.txt"), "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(
                    model.parameters(), hp.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start), (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join("logger", "logger.txt"), "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save({'model': model.state_dict(),
                                'optimizer': optimizer.state_dict()},
                               os.path.join(hp.checkpoint_path,
                                            'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                # Keep a short moving window of step times for the ETA estimate
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(
                        Time, [i for i in range(len(Time))], axis=None)
                    Time = np.append(Time, temp_value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--restore_step', type=int, default=0)
    # argparse's type=bool treats any non-empty string as True, so use a flag instead
    parser.add_argument('--frozen_learning_rate', action='store_true', default=False)
    parser.add_argument("--learning_rate_frozen", type=float, default=1e-3)
    args = parser.parse_args()

    main(args)
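
A typical invocation, assuming this file is saved as train.py at the repo root and hparams.py points at the prepared data (the flag names match the argparse definitions above):

    python train.py --restore_step 0
    python train.py --restore_step 3000 --frozen_learning_rate --learning_rate_frozen 1e-3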