diff --git a/train.py b/train.py index 61f0888..88b0949 100644 --- a/train.py +++ b/train.py @@ -142,7 +142,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus, model.train() if rank == 0: - print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss)) + print("Validation loss {}: {:9f} ".format(iteration, val_loss)) logger.log_validation(val_loss, model, y, y_pred, iteration)