
Merge pull request #20 from NVIDIA/fp16_path

Fp16 patch, not path!
master
Rafael Valle committed 6 years ago (committed by GitHub)
commit da30fd8709
2 changed files with 8 additions and 6 deletions
  1. loss_scaler.py  +3  -4
  2. train.py        +5  -2

loss_scaler.py  +3 -4

@@ -51,11 +51,10 @@ class DynamicLossScaler:
     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
-        inf_count = torch.sum(x.abs() == float('inf'))
-        if inf_count > 0:
+        cpu_sum = float(x.float().sum())
+        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
             return True
-        nan_count = torch.sum(x != x)
-        return nan_count > 0
+        return False

     # `overflow` is boolean indicating whether we overflowed in gradient
     def update_scale(self, overflow):
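The rewritten check replaces two separate reductions (one for inf, one for NaN) with a single sum: if any gradient element is inf or NaN, the float sum itself comes out as inf, -inf, or NaN, and the NaN case is caught by the fact that NaN compares unequal to itself. Below is a minimal sketch of how a dynamic loss scaler typically consumes such a check; the step and scale bookkeeping here is illustrative, not the repo's exact DynamicLossScaler API.

    import torch

    def has_inf_or_nan(x):
        # One reduction covers both cases: any inf or NaN in `x` makes the
        # sum itself inf, -inf, or NaN (NaN compares unequal to itself).
        s = float(x.float().sum())
        return s == float('inf') or s == -float('inf') or s != s

    def scaled_backward_and_step(loss, params, optimizer, state):
        # Hypothetical fp16 training step: scale the loss before backward,
        # skip the update and shrink the scale if any gradient overflowed,
        # otherwise unscale the gradients and step.
        (loss * state['scale']).backward()
        grads = [p.grad for p in params if p.grad is not None]
        if any(has_inf_or_nan(g) for g in grads):
            state['scale'] /= 2.0        # back off after an overflow
        else:
            for g in grads:
                g.div_(state['scale'])   # undo the loss scaling before stepping
            optimizer.step()
            state['scale'] *= 2.0        # real scalers grow only after many good steps
        optimizer.zero_grad()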

train.py  +5 -2

@@ -2,6 +2,7 @@ import os
 import time
 import argparse
 import math
+from numpy import finfo

 import torch
 from distributed import DistributedDataParallel
@@ -77,7 +78,9 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):

 def load_model(hparams):
     model = Tacotron2(hparams).cuda()
-    model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
+    if hparams.fp16_run:
+        model = batchnorm_to_float(model.half())
+        model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)

     if hparams.distributed_run:
         model = DistributedDataParallel(model)
@@ -276,7 +279,7 @@ if __name__ == '__main__':
     torch.backends.cudnn.benchmark = hparams.cudnn_benchmark

     print("FP16 Run:", hparams.fp16_run)
-    print("Dynamic Loss Scaling", hparams.dynamic_loss_scaling)
+    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
     print("Distributed Run:", hparams.distributed_run)
     print("cuDNN Enabled:", hparams.cudnn_enabled)
     print("cuDNN Benchmark:", hparams.cudnn_benchmark)
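Two details of the fp16 path in load_model are worth noting. First, batchnorm_to_float presumably converts BatchNorm layers back to fp32 after model.half(), since batch-norm statistics are numerically fragile in half precision; the helper below is a sketch of that idea, not necessarily the repo's exact implementation. Second, the attention layer's score_mask_value, the fill applied to padded positions before the softmax, must be representable in fp16: a fill on the order of -inf or fp32's minimum is risky in half precision and can surface as NaNs, so the patch uses numpy's finfo('float16').min, roughly -65504.

    import torch
    from torch import nn
    from numpy import finfo

    def batchnorm_to_float(module):
        # Sketch: walk the module tree and put BatchNorm layers back in fp32
        # while the rest of the (already halved) model stays in fp16.
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.float()
        for child in module.children():
            batchnorm_to_float(child)
        return module

    # Why the mask value changes under fp16:
    fp16_min = float(finfo('float16').min)   # -65504.0, the lowest finite fp16 value
    print(torch.tensor(-1e38).half())        # overflows to -inf when cast to fp16
    print(torch.tensor(fp16_min).half())     # stays finite: tensor(-65504., dtype=torch.float16)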
