|
|
@ -84,9 +84,7 @@ def load_model(hparams): |
|
|
|
model.decoder.attention_layer.score_mask_value = float(finfo('float16').min) |
|
|
|
|
|
|
|
if hparams.distributed_run: |
|
|
|
model = DistributedDataParallel(model) |
|
|
|
elif torch.cuda.device_count() > 1: |
|
|
|
model = DataParallel(model) |
|
|
|
model = apply_gradient_allreduce(model) |
|
|
|
|
|
|
|
return model |
|
|
|
|
|
|
|