diff --git a/train.py b/train.py index dd5de2e..f914cfc 100644 --- a/train.py +++ b/train.py @@ -74,14 +74,17 @@ def prepare_directories_and_logger(output_directory, log_directory, rank): logger = None return logger - def load_model(hparams): model = Tacotron2(hparams).cuda() model = batchnorm_to_float(model.half()) if hparams.fp16_run else model - model = DistributedDataParallel(model) \ - if hparams.distributed_run else DataParallel(model) - return model + tacotron_model = model + if hparams.distributed_run: + model = DistributedDataParallel(model) + elif torch.cuda.device_count() > 1: + model = DataParallel(model) + + return model, tacotron def warm_start_model(checkpoint_path, model): assert os.path.isfile(checkpoint_path) @@ -114,7 +117,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): def validate(model, criterion, valset, iteration, batch_size, n_gpus, - collate_fn, logger, distributed_run, rank): + collate_fn, logger, distributed_run, rank, batch_parser): """Handles all the validation scoring and printing""" model.eval() with torch.no_grad(): @@ -125,7 +128,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus, val_loss = 0.0 for i, batch in enumerate(val_loader): - x, y = model.module.parse_batch(batch) + x, y = batch_parser(batch) y_pred = model(x) loss = criterion(y_pred, y) reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \ @@ -193,11 +196,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, param_group['lr'] = learning_rate model.zero_grad() - x, y = model.module.parse_batch(batch) + x, y = tacotron_model.parse_batch(batch) y_pred = model(x) loss = criterion(y_pred, y) - reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \ - if hparams.distributed_run else loss.data[0] + reduced_loss = reduce_tensor(loss.data, n_gpus).item() \ + if hparams.distributed_run else loss.item() if hparams.fp16_run: optimizer.backward(loss) @@ -205,7 +208,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, else: loss.backward() grad_norm = torch.nn.utils.clip_grad_norm( - model.module.parameters(), hparams.grad_clip_thresh) + tacotron_model.parameters(), hparams.grad_clip_thresh) optimizer.step() @@ -222,7 +225,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, if not overflow and (iteration % hparams.iters_per_checkpoint == 0): reduced_val_loss = validate( model, criterion, valset, iteration, hparams.batch_size, - n_gpus, collate_fn, logger, hparams.distributed_run, rank) + n_gpus, collate_fn, logger, hparams.distributed_run, rank, tacotron_model.parse_batch) if rank == 0: print("Validation loss {}: {:9f} ".format(