import os
import time
import argparse
import math
from numpy import finfo

import torch
from distributed import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.nn import DataParallel
from torch.utils.data import DataLoader

from fp16_optimizer import FP16_Optimizer

from model import Tacotron2
from data_utils import TextMelLoader, TextMelCollate
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger
from hparams import create_hparams


def batchnorm_to_float(module):
    """Converts batch norm modules to FP32

    Recursively walks ``module`` and its children and calls ``.float()`` on
    every batch norm layer it finds.

    Params
    ------
    module (torch.nn.Module): root module to convert in place

    Returns
    -------
    The same ``module`` object that was passed in.
    """
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        batchnorm_to_float(child)
    return module


def reduce_tensor(tensor, num_gpus):
    """Sums a tensor across all processes and divides by ``num_gpus``

    Params
    ------
    tensor (torch.Tensor): value to reduce; it is cloned, not modified
    num_gpus (int): divisor applied after the all-reduce sum

    Returns
    -------
    A new tensor holding the summed value divided by ``num_gpus``.
    """
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.reduce_op.SUM)
    rt /= num_gpus
    return rt


def init_distributed(hparams, n_gpus, rank, group_name):
    """Selects the CUDA device for this rank and joins the process group

    Params
    ------
    hparams (object): provides ``dist_backend`` and ``dist_url``
    n_gpus (int): world size passed to ``init_process_group``
    rank (int): rank of the current process
    group_name (string): name passed to ``init_process_group``

    Raises
    ------
    AssertionError: if CUDA is not available
    """
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    torch.distributed.init_process_group(
        backend=hparams.dist_backend, init_method=hparams.dist_url,
        world_size=n_gpus, rank=rank, group_name=group_name)

    print("Done initializing distributed")


def prepare_dataloaders(hparams):
    """Builds the training loader, validation dataset and collate function

    Params
    ------
    hparams (object): provides ``training_files``, ``validation_files``,
        ``n_frames_per_step``, ``distributed_run`` and ``batch_size``

    Returns
    -------
    ``(train_loader, valset, collate_fn)``. The validation set is returned
    as a dataset, not a loader; ``validate`` wraps it in a ``DataLoader``.
    """
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    train_sampler = DistributedSampler(trainset) \
        if hparams.distributed_run else None

    # NOTE(review): shuffle=False with no sampler means single-process runs
    # iterate the dataset in a fixed order every epoch - confirm intended.
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn


def prepare_directories_and_logger(output_directory, log_directory, rank):
    """Creates the output directory and logger on rank 0

    Params
    ------
    output_directory (string): directory created (mode 0o775) if missing
    log_directory (string): joined onto ``output_directory`` for the logger
    rank (int): rank of the current process

    Returns
    -------
    A ``Tacotron2Logger`` when ``rank == 0``, otherwise ``None``.
    """
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
    else:
        logger = None
    return logger


def load_model(hparams):
    """Builds the Tacotron2 model on the GPU and wraps it for multi-GPU use

    Params
    ------
    hparams (object): provides ``fp16_run`` and ``distributed_run`` and is
        passed to the ``Tacotron2`` constructor

    Returns
    -------
    The model, wrapped in ``DistributedDataParallel`` for distributed runs
    or in ``DataParallel`` when more than one CUDA device is visible.
    """
    model = Tacotron2(hparams).cuda()
    if hparams.fp16_run:
        # Batch norm layers are kept in FP32 while the rest runs in half.
        model = batchnorm_to_float(model.half())
        # The mask value must be representable in float16.
        model.decoder.attention_layer.score_mask_value = float(
            finfo('float16').min)

    if hparams.distributed_run:
        model = DistributedDataParallel(model)
    elif torch.cuda.device_count() > 1:
        model = DataParallel(model)

    return model


def warm_start_model(checkpoint_path, model):
    """Loads only the model weights from a checkpoint

    Params
    ------
    checkpoint_path (string): path to a file with a ``'state_dict'`` entry
    model (torch.nn.Module): model that receives the weights

    Returns
    -------
    The same ``model`` with the checkpoint weights loaded.

    Raises
    ------
    AssertionError: if ``checkpoint_path`` is not an existing file
    """
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    return model


def load_checkpoint(checkpoint_path, model, optimizer):
    """Restores model, optimizer, learning rate and iteration from a file

    Params
    ------
    checkpoint_path (string): path to a checkpoint written by
        ``save_checkpoint``
    model (torch.nn.Module): model that receives the weights
    optimizer (object): optimizer that receives the saved state

    Returns
    -------
    ``(model, optimizer, learning_rate, iteration)``

    Raises
    ------
    AssertionError: if ``checkpoint_path`` is not an existing file
    """
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    learning_rate = checkpoint_dict['learning_rate']
    iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}" .format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """Writes model and optimizer state plus training progress to a file

    Params
    ------
    model (torch.nn.Module): model whose ``state_dict`` is saved
    optimizer (object): optimizer whose ``state_dict`` is saved
    learning_rate (float): learning rate stored alongside the state
    iteration (int): iteration counter stored alongside the state
    filepath (string): destination passed to ``torch.save``
    """
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank):
    """Handles all the validation scoring and printing

    Runs the model in eval mode over ``valset`` without gradients and
    switches it back to train mode afterwards.

    Params
    ------
    model (torch.nn.Module): model to evaluate
    criterion (callable): loss called as ``criterion(y_pred, y)``
    valset (Dataset): validation dataset
    iteration (int): unused here; kept for the existing call signature
    batch_size (int): validation batch size
    n_gpus (int): divisor used when averaging the loss across processes
    collate_fn (callable): collate function for the ``DataLoader``
    logger (object): unused here; kept for the existing call signature
    distributed_run (bool): whether to shard data and all-reduce the loss
    rank (int): unused here; kept for the existing call signature

    Returns
    -------
    Mean per-batch validation loss as a float, or NaN when the validation
    loader yields no batches.
    """
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        # Wrapped models expose the original module under ``.module``.
        if distributed_run or torch.cuda.device_count() > 1:
            batch_parser = model.module.parse_batch
        else:
            batch_parser = model.parse_batch

        val_loss = 0.0
        num_batches = 0
        for batch in val_loader:
            x, y = batch_parser(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            # ``.item()`` replaces the deprecated ``tensor[0]`` indexing of
            # a scalar loss.
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
            num_batches += 1

    # An empty loader would otherwise fail on an unbound loop index; NaN
    # makes "no validation data" visible instead of reporting a fake 0.0.
    if num_batches > 0:
        val_loss = val_loss / num_batches
    else:
        val_loss = float('nan')

    model.train()
    return val_loss


def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
          rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string): directory to save tensorboard logs
    checkpoint_path (string): checkpoint path, or None to start from scratch
    warm_start (bool): load only the model weights from the checkpoint
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    group_name (string): distributed group name
    hparams (object): hyperparameter object as returned by create_hparams
    """
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams)
    learning_rate = hparams.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)
    if hparams.fp16_run:
        optimizer = FP16_Optimizer(
            optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)

    criterion = Tacotron2Loss()

    logger = prepare_directories_and_logger(
        output_directory, log_directory, rank)

    train_loader, valset, collate_fn = prepare_dataloaders(hparams)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            model = warm_start_model(checkpoint_path, model)
        else:
            model, optimizer, _learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    # Wrapped models expose the original module under ``.module``.
    if hparams.distributed_run or torch.cuda.device_count() > 1:
        batch_parser = model.module.parse_batch
    else:
        batch_parser = model.parse_batch

    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        if hparams.distributed_run:
            # Without this the DistributedSampler reuses the same ordering
            # in every epoch.
            train_loader.sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            model.zero_grad()
            x, y = batch_parser(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            # ``.item()`` replaces the deprecated ``tensor[0]`` indexing of
            # a scalar loss and yields a plain float for math.isnan below.
            if hparams.distributed_run:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()

            if hparams.fp16_run:
                optimizer.backward(loss)
                grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            overflow = optimizer.overflow if hparams.fp16_run else False

            if not overflow and not math.isnan(reduced_loss) and rank == 0:
                duration = time.perf_counter() - start
                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                    iteration, reduced_loss, grad_norm, duration))
                logger.log_training(
                    reduced_loss, grad_norm, learning_rate, duration,
                    iteration)

            if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
                # Every rank runs validation; only rank 0 logs and saves.
                reduced_val_loss = validate(
                    model, criterion, valset, iteration, hparams.batch_size,
                    n_gpus, collate_fn, logger, hparams.distributed_run, rank)

                if rank == 0:
                    print("Validation loss {}: {:9f} ".format(
                        iteration, reduced_val_loss))
                    logger.log_validation(
                        reduced_val_loss, model, y, y_pred, iteration)
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output_directory', type=str,
                        help='directory to save checkpoints')
    parser.add_argument('-l', '--log_directory', type=str,
                        help='directory to save tensorboard logs')
    parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
                        required=False, help='checkpoint path')
    parser.add_argument('--warm_start', action='store_true',
                        help='load the model only (warm start)')
    parser.add_argument('--n_gpus', type=int, default=1,
                        required=False, help='number of gpus')
    parser.add_argument('--rank', type=int, default=0,
                        required=False, help='rank of current gpu')
    parser.add_argument('--group_name', type=str, default='group_name',
                        required=False, help='Distributed group name')
    parser.add_argument('--hparams', type=str,
                        required=False, help='comma separated name=value pairs')

    args = parser.parse_args()
    hparams = create_hparams(args.hparams)

    torch.backends.cudnn.enabled = hparams.cudnn_enabled
    torch.backends.cudnn.benchmark = hparams.cudnn_benchmark

    print("FP16 Run:", hparams.fp16_run)
    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
    print("Distributed Run:", hparams.distributed_run)
    print("cuDNN Enabled:", hparams.cudnn_enabled)
    print("cuDNN Benchmark:", hparams.cudnn_benchmark)

    train(args.output_directory, args.log_directory, args.checkpoint_path,
          args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)