|
|
@ -143,7 +143,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus, |
|
|
|
model.train() |
|
|
|
if rank == 0: |
|
|
|
print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss)) |
|
|
|
logger.log_validation(reduced_val_loss, model, y, y_pred, iteration) |
|
|
|
logger.log_validation(val_loss, model, y, y_pred, iteration) |
|
|
|
|
|
|
|
|
|
|
|
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, |
|
|
|