diff --git a/train.py b/train.py index e93917b..61f0888 100644 --- a/train.py +++ b/train.py @@ -143,7 +143,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus, model.train() if rank == 0: print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss)) - logger.log_validation(reduced_val_loss, model, y, y_pred, iteration) + logger.log_validation(val_loss, model, y, y_pred, iteration) def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,