diff --git a/train.py b/train.py
index 27061d0..1fca517 100644
--- a/train.py
+++ b/train.py
@@ -28,10 +28,10 @@ def batchnorm_to_float(module):
     return module
 
 
-def reduce_tensor(tensor, num_gpus):
+def reduce_tensor(tensor, n_gpus):
     rt = tensor.clone()
     dist.all_reduce(rt, op=dist.reduce_op.SUM)
-    rt /= num_gpus
+    rt /= n_gpus
     return rt
 
 
@@ -135,7 +135,7 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             y_pred = model(x)
             loss = criterion(y_pred, y)
             if distributed_run:
-                reduced_val_loss = reduce_tensor(loss.data, num_gpus).item()
+                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
             else:
                 reduced_val_loss = loss.item()
             val_loss += reduced_val_loss
@@ -212,7 +212,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
 
         loss = criterion(y_pred, y)
         if hparams.distributed_run:
-            reduced_loss = reduce_tensor(loss.data, num_gpus).item()
+            reduced_loss = reduce_tensor(loss.data, n_gpus).item()
         else:
             reduced_loss = loss.item()