diff --git a/train.py b/train.py index c64ea1c..4a83921 100644 --- a/train.py +++ b/train.py @@ -194,7 +194,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, epoch_offset = max(0, int(iteration / len(train_loader))) model.train() - if distributed_run or torch.cuda.device_count() > 1: + if hparams.distributed_run or torch.cuda.device_count() > 1: batch_parser = model.module.parse_batch else: batch_parser = model.parse_batch