diff --git a/train.py b/train.py
index 7600ceb..27061d0 100644
--- a/train.py
+++ b/train.py
@@ -84,9 +84,7 @@ def load_model(hparams):
     model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
 
     if hparams.distributed_run:
-        model = DistributedDataParallel(model)
-    elif torch.cuda.device_count() > 1:
-        model = DataParallel(model)
+        model = apply_gradient_allreduce(model)
 
     return model
 