From f06063f74672ec982dc1b49602bb0da53e2c9fb4 Mon Sep 17 00:00:00 2001 From: rafaelvalle Date: Tue, 27 Nov 2018 18:04:12 -0800 Subject: [PATCH] train.py: renaming function, removing dataparallel --- train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/train.py b/train.py index 7600ceb..27061d0 100644 --- a/train.py +++ b/train.py @@ -84,9 +84,7 @@ def load_model(hparams): model.decoder.attention_layer.score_mask_value = float(finfo('float16').min) if hparams.distributed_run: - model = DistributedDataParallel(model) - elif torch.cuda.device_count() > 1: - model = DataParallel(model) + model = apply_gradient_allreduce(model) return model