From 646ab0d8c868c094d46a9feeb0549dccad0f9499 Mon Sep 17 00:00:00 2001 From: Rafael Valle Date: Thu, 3 May 2018 19:42:37 -0700 Subject: [PATCH] model.py removing top of three code, cleanup --- train.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/train.py b/train.py index f914cfc..7103321 100644 --- a/train.py +++ b/train.py @@ -74,22 +74,23 @@ def prepare_directories_and_logger(output_directory, log_directory, rank): logger = None return logger + def load_model(hparams): model = Tacotron2(hparams).cuda() model = batchnorm_to_float(model.half()) if hparams.fp16_run else model - tacotron_model = model if hparams.distributed_run: model = DistributedDataParallel(model) elif torch.cuda.device_count() > 1: model = DataParallel(model) - return model, tacotron + return model + def warm_start_model(checkpoint_path, model): assert os.path.isfile(checkpoint_path) print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) - checkpoint_dict = torch.load(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') model.load_state_dict(checkpoint_dict['state_dict']) return model @@ -117,7 +118,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): def validate(model, criterion, valset, iteration, batch_size, n_gpus, - collate_fn, logger, distributed_run, rank, batch_parser): + collate_fn, logger, distributed_run, rank): """Handles all the validation scoring and printing""" model.eval() with torch.no_grad(): @@ -128,7 +129,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus, val_loss = 0.0 for i, batch in enumerate(val_loader): - x, y = batch_parser(batch) + x, y = model.parse_batch(batch) y_pred = model(x) loss = criterion(y_pred, y) reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \ @@ -196,11 +197,11 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, param_group['lr'] = learning_rate model.zero_grad() - x, y = tacotron_model.parse_batch(batch) + x, y = model.parse_batch(batch) y_pred = model(x) loss = criterion(y_pred, y) - reduced_loss = reduce_tensor(loss.data, n_gpus).item() \ - if hparams.distributed_run else loss.item() + reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \ + if hparams.distributed_run else loss.data[0] if hparams.fp16_run: optimizer.backward(loss) @@ -208,7 +209,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, else: loss.backward() grad_norm = torch.nn.utils.clip_grad_norm( - tacotron_model.parameters(), hparams.grad_clip_thresh) + model.parameters(), hparams.grad_clip_thresh) optimizer.step() @@ -225,7 +226,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, if not overflow and (iteration % hparams.iters_per_checkpoint == 0): reduced_val_loss = validate( model, criterion, valset, iteration, hparams.batch_size, - n_gpus, collate_fn, logger, hparams.distributed_run, rank, tacotron_model.parse_batch) + n_gpus, collate_fn, logger, hparams.distributed_run, rank) if rank == 0: print("Validation loss {}: {:9f} ".format(