From 6e430556bd4e1404c4dbf7cf4c790b4dd53ee93d Mon Sep 17 00:00:00 2001
From: rafaelvalle
Date: Tue, 27 Nov 2018 22:03:11 -0800
Subject: [PATCH] train.py: val logger on gpu 0 only

---
 train.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index 1fca517..cd7635f 100644
--- a/train.py
+++ b/train.py
@@ -142,8 +142,9 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
         val_loss = val_loss / (i + 1)
 
     model.train()
-    print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
-    logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)
+    if rank == 0:
+        print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
+        logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)
 
 
 def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
@@ -236,9 +237,9 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
                     reduced_loss, grad_norm, learning_rate, duration, iteration)
 
             if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
-                validate(model, criterion, valset, iteration, hparams.batch_size,
-                         n_gpus, collate_fn, logger, hparams.distributed_run, rank)
-
+                validate(model, criterion, valset, iteration,
+                         hparams.batch_size, n_gpus, collate_fn, logger,
+                         hparams.distributed_run, rank)
                 if rank == 0:
                     checkpoint_path = os.path.join(
                         output_directory, "checkpoint_{}".format(iteration))
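
---
Note: the patch guards the validation print and `logger.log_validation` call with `rank == 0`, so that in a multi-GPU distributed run only one process writes to stdout and the summary writer while every rank still computes the reduced validation loss. Below is a minimal standalone sketch of that guard pattern, not the repository's exact code; `get_rank` and `log_validation_rank0` are hypothetical helper names, while `logger.log_validation` mirrors the call in the patch.

    import torch.distributed as dist

    def get_rank():
        """Return this process's rank, or 0 when not running distributed."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    def log_validation_rank0(logger, reduced_val_loss, model, y, y_pred, iteration):
        # Every rank participates in computing the (all-reduced) validation
        # loss, but only rank 0 prints and logs, avoiding duplicated console
        # output and concurrent writes to the same event file.
        if get_rank() == 0:
            print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
            logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)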