diff --git a/train.py b/train.py
index 7103321..c64ea1c 100644
--- a/train.py
+++ b/train.py
@@ -128,8 +128,13 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
                                 pin_memory=False, collate_fn=collate_fn)
 
         val_loss = 0.0
+        if distributed_run or torch.cuda.device_count() > 1:
+            batch_parser = model.module.parse_batch
+        else:
+            batch_parser = model.parse_batch
+
         for i, batch in enumerate(val_loader):
-            x, y = model.parse_batch(batch)
+            x, y = batch_parser(batch)
             y_pred = model(x)
             loss = criterion(y_pred, y)
             reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
@@ -157,6 +162,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
 
     if hparams.distributed_run:
         init_distributed(hparams, n_gpus, rank, group_name)
+
     torch.manual_seed(hparams.seed)
     torch.cuda.manual_seed(hparams.seed)
 
@@ -188,6 +194,10 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
             epoch_offset = max(0, int(iteration / len(train_loader)))
 
     model.train()
+    if distributed_run or torch.cuda.device_count() > 1:
+        batch_parser = model.module.parse_batch
+    else:
+        batch_parser = model.parse_batch
     # ================ MAIN TRAINNIG LOOP! ===================
     for epoch in range(epoch_offset, hparams.epochs):
         print("Epoch: {}".format(epoch))
@@ -197,7 +207,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
                 param_group['lr'] = learning_rate
 
             model.zero_grad()
-            x, y = model.parse_batch(batch)
+            x, y = batch_parser(batch)
             y_pred = model(x)
             loss = criterion(y_pred, y)
             reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
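
For context: torch.nn.parallel.DistributedDataParallel and torch.nn.DataParallel wrap the network, so a custom method such as parse_batch is only reachable through the wrapper's .module attribute, which is why the patch dispatches through batch_parser instead of calling model.parse_batch directly. Below is a minimal sketch of the same unwrapping idea, kept separate from the patch; the helper name unwrap_model is illustrative, and checking the wrapper type is an alternative to the patch's torch.cuda.device_count() heuristic, not what the patch itself does:

    from torch import nn


    def unwrap_model(model):
        """Return the underlying network when `model` is a DataParallel or
        DistributedDataParallel wrapper; a plain model passes through unchanged."""
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return model.module
        return model


    # Usage sketch: resolve parse_batch once, then reuse it for every batch.
    # batch_parser = unwrap_model(model).parse_batch
    # x, y = batch_parser(batch)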