
train.py: renaming function, removing dataparallel

branch master
rafaelvalle committed 6 years ago
commit f06063f746
1 changed file with 1 addition and 3 deletions
train.py

@@ -84,9 +84,7 @@ def load_model(hparams):
         model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)
 
     if hparams.distributed_run:
-        model = DistributedDataParallel(model)
-    elif torch.cuda.device_count() > 1:
-        model = DataParallel(model)
+        model = apply_gradient_allreduce(model)
 
     return model
 
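
For context: DistributedDataParallel and DataParallel wrap the model so gradient synchronization (or replication) happens automatically, whereas apply_gradient_allreduce, imported from the repo's distributed helpers, achieves synchronization by averaging gradients across workers. The sketch below illustrates that gradient-averaging step as an explicit per-iteration call rather than the hook-based wrapper the repo uses; the function name allreduce_gradients is illustrative, not the repo's API.

import torch.distributed as dist

def allreduce_gradients(model):
    """Average gradients across all distributed workers.

    Illustrative stand-in for what a gradient all-reduce achieves.
    The repo's apply_gradient_allreduce registers backward hooks so
    this happens automatically; this sketch would instead be called
    explicitly between loss.backward() and optimizer.step().
    """
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            # Sum this parameter's gradient across workers, then divide
            # so every replica applies the same averaged update.
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size

With the DataParallel branch removed, multi-GPU training goes exclusively through hparams.distributed_run (one process per GPU) instead of single-process multi-GPU replication.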
