
train.py: val logger on gpu 0 only

master
rafaelvalle · 6 years ago
commit 6e430556bd
1 changed file with 6 additions and 5 deletions
train.py

@@ -142,8 +142,9 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
         val_loss = val_loss / (i + 1)
 
     model.train()
-    print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
-    logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)
+    if rank == 0:
+        print("Validation loss {}: {:9f} ".format(iteration, reduced_val_loss))
+        logger.log_validation(reduced_val_loss, model, y, y_pred, iteration)
 
 
 def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
@@ -236,9 +237,9 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
                 reduced_loss, grad_norm, learning_rate, duration, iteration)
 
         if not overflow and (iteration % hparams.iters_per_checkpoint == 0):
-            validate(model, criterion, valset, iteration, hparams.batch_size,
-                     n_gpus, collate_fn, logger, hparams.distributed_run,
-                     rank)
+            validate(model, criterion, valset, iteration,
+                     hparams.batch_size, n_gpus, collate_fn, logger,
+                     hparams.distributed_run, rank)
             if rank == 0:
                 checkpoint_path = os.path.join(
                     output_directory, "checkpoint_{}".format(iteration))
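
The gist of the change: validate() now receives the process rank, every GPU still runs the validation pass and the loss reduction, but only rank 0 prints the result and writes it through the logger, so multi-GPU runs stop emitting one duplicate log line per process. Below is a minimal sketch of that pattern, assuming torch.distributed has already been initialized; reduce_mean, log_validation, and the SummaryWriter log dir are illustrative names, not the repo's API.

    # Sketch of the rank-0 gating pattern, assuming torch.distributed has
    # already been initialized via init_process_group; reduce_mean and the
    # SummaryWriter log dir are illustrative, not taken from the repo.
    import torch
    import torch.distributed as dist
    from torch.utils.tensorboard import SummaryWriter


    def reduce_mean(tensor, n_gpus):
        # Collective op: every rank must call this, or the others block.
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        return rt / n_gpus


    def log_validation(val_loss, iteration, n_gpus, rank, distributed_run):
        # All ranks participate in the reduction together ...
        if distributed_run:
            reduced_val_loss = reduce_mean(val_loss, n_gpus).item()
        else:
            reduced_val_loss = val_loss.item()
        # ... but only rank 0 touches stdout and the event file, matching
        # the commit's change to validate().
        if rank == 0:
            print("Validation loss {}: {:9f}".format(iteration, reduced_val_loss))
            writer = SummaryWriter(log_dir="logs")  # hypothetical log dir
            writer.add_scalar("validation.loss", reduced_val_loss, iteration)
            writer.close()

Gating only the I/O matters: dist.all_reduce is a collective, so every rank must still call it; the rank check belongs around print and the logger calls, not around the reduction itself.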
