diff --git a/train.py b/train.py
index ee01b07..b13f01f 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import os
 import time
 import argparse
 import math
+from numpy import finfo
 
 import torch
 from distributed import DistributedDataParallel
@@ -77,7 +78,9 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
 
 def load_model(hparams):
     model = Tacotron2(hparams).cuda()
-    model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
+    if hparams.fp16_run:
+        model = batchnorm_to_float(model.half())
+        model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
 
     if hparams.distributed_run:
         model = DistributedDataParallel(model)
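
For context on why the patch resets `score_mask_value` when `fp16_run` is enabled: attention-score mask constants that are harmless in float32 (for example `-1e9` or `-inf`) overflow or propagate NaNs once the model is cast to half precision, whereas `finfo('float16').min` is the most negative value float16 can actually represent. The snippet below is a minimal illustrative sketch of that behavior, not code from the repository; the tensor shapes and constants are made up for demonstration.

    import torch
    from numpy import finfo

    # A mask constant that is safe in float32 overflows once cast to float16.
    fp32_mask = torch.full((1, 4), -1e9)          # fine in float32
    print(fp32_mask.half())                       # becomes -inf in float16

    # Masking every score with -inf makes softmax collapse to NaN.
    inf_scores = torch.full((1, 4), float('-inf')).half()
    print(torch.softmax(inf_scores, dim=-1))      # tensor of nan

    # The smallest finite float16 value keeps masked scores representable,
    # so a fully masked row still yields valid (uniform) attention weights.
    safe_mask_value = float(finfo('float16').min)  # -65504.0
    safe_scores = torch.full((1, 4), safe_mask_value).half()
    print(torch.softmax(safe_scores, dim=-1))      # no NaN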