From 0274619e45252180aec5b6ca9c18c9761d1a2765 Mon Sep 17 00:00:00 2001
From: rafaelvalle
Date: Wed, 3 Apr 2019 13:42:00 -0800
Subject: [PATCH] train.py: using amp for mixed precision training

---
 train.py | 38 ++++++++++++++++----------------------
 1 file changed, 16 insertions(+), 22 deletions(-)

diff --git a/train.py b/train.py
index 4287016..4c8988f 100644
--- a/train.py
+++ b/train.py
@@ -10,8 +10,6 @@ import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data import DataLoader
 
-from fp16_optimizer import FP16_Optimizer
-
 from model import Tacotron2
 from data_utils import TextMelLoader, TextMelCollate
 from loss_function import Tacotron2Loss
@@ -19,15 +17,6 @@ from logger import Tacotron2Logger
 from hparams import create_hparams
 
 
-def batchnorm_to_float(module):
-    """Converts batch norm modules to FP32"""
-    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        module.float()
-    for child in module.children():
-        batchnorm_to_float(child)
-    return module
-
-
 def reduce_tensor(tensor, n_gpus):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.reduce_op.SUM)
@@ -80,8 +69,7 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
 def load_model(hparams):
     model = Tacotron2(hparams).cuda()
     if hparams.fp16_run:
-        model = batchnorm_to_float(model.half())
-        model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
+        model.decoder.attention_layer.score_mask_value = finfo('float16').min
 
     if hparams.distributed_run:
         model = apply_gradient_allreduce(model)
@@ -177,9 +165,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
     learning_rate = hparams.learning_rate
     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                  weight_decay=hparams.weight_decay)
+
     if hparams.fp16_run:
-        optimizer = FP16_Optimizer(
-            optimizer, dynamic_loss_scale=hparams.dynamic_loss_scaling)
+        from apex import amp
+        model, optimizer = amp.initialize(
+            model, optimizer, opt_level='O2')
 
     if hparams.distributed_run:
         model = apply_gradient_allreduce(model)
@@ -207,6 +197,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
     epoch_offset = max(0, int(iteration / len(train_loader)))
 
     model.train()
+    is_overflow = False
     # ================ MAIN TRAINNIG LOOP! ===================
     for epoch in range(epoch_offset, hparams.epochs):
         print("Epoch: {}".format(epoch))
@@ -224,27 +215,30 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
                 reduced_loss = reduce_tensor(loss.data, n_gpus).item()
             else:
                 reduced_loss = loss.item()
-
             if hparams.fp16_run:
-                optimizer.backward(loss)
-                grad_norm = optimizer.clip_fp32_grads(hparams.grad_clip_thresh)
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
             else:
                 loss.backward()
+
+            if hparams.fp16_run:
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    amp.master_params(optimizer), hparams.grad_clip_thresh)
+                is_overflow = math.isnan(grad_norm)
+            else:
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(), hparams.grad_clip_thresh)
 
             optimizer.step()
 
-            overflow = optimizer.overflow if hparams.fp16_run else False
-
-            if not overflow and not math.isnan(reduced_loss) and rank == 0:
+            if not is_overflow and rank == 0:
                 duration = time.perf_counter() - start
                 print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                     iteration, reduced_loss, grad_norm, duration))
                 logger.log_training(
                     reduced_loss, grad_norm, learning_rate, duration, iteration)
 
-            if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
+            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
                 validate(model, criterion, valset, iteration,
                          hparams.batch_size, n_gpus, collate_fn, logger,
                          hparams.distributed_run, rank)
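
Note for reviewers: the apex amp calls adopted above (amp.initialize,
amp.scale_loss, amp.master_params) follow NVIDIA's documented usage.
Below is a minimal standalone sketch of the same pattern, for reference
only; the toy model, optimizer, and data are placeholders and not names
from train.py.

    import math
    import torch
    from apex import amp

    # Toy stand-ins for the Tacotron2 model and its Adam optimizer.
    net = torch.nn.Linear(128, 4).cuda()
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)

    # O2 keeps FP32 master weights behind an FP16 model, replacing the
    # manual half()/batchnorm_to_float()/FP16_Optimizer handling this
    # patch deletes.
    net, opt = amp.initialize(net, opt, opt_level='O2')

    x = torch.randn(32, 128, device='cuda')
    y = torch.randint(0, 4, (32,), device='cuda')
    loss = torch.nn.functional.cross_entropy(net(x), y)

    # Scale the loss before backward so small FP16 gradients do not
    # underflow to zero.
    with amp.scale_loss(loss, opt) as scaled_loss:
        scaled_loss.backward()

    # Clip the FP32 master gradients rather than the FP16 model copies.
    # A NaN norm signals a dynamic-loss-scale overflow; train.py uses
    # the same check to skip logging and checkpointing for that step.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        amp.master_params(opt), max_norm=1.0)
    is_overflow = math.isnan(grad_norm)

    opt.step()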