Browse Source

train.py: patching score_mask_value formerly inf, not concrete value, for compatibility with pytorch

master
Rafael Valle 6 years ago
parent
commit
1071023017
1 changed files with 4 additions and 1 deletions
  1. +4
    -1
      train.py

+ 4
- 1
train.py View File

@ -2,6 +2,7 @@ import os
import time
import argparse
import math
from numpy import finfo
import torch
from distributed import DistributedDataParallel
@ -77,7 +78,9 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
def load_model(hparams):
model = Tacotron2(hparams).cuda()
model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
if hparams.fp16_run:
model = batchnorm_to_float(model.half())
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
if hparams.distributed_run:
model = DistributedDataParallel(model)

Loading…
Cancel
Save