diff --git a/README.md b/README.md
index 107e713..3a904d0 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,7 @@ Distributed and FP16 support relies on work by Christian Sarofeen and NVIDIA's
 2. (OPTIONAL) `tensorboard --logdir=outdir/logdir`
 
 ## Multi-GPU (distributed) and FP16 Training
-1. `python -m multiproc train.py --output_directory=/outdir --log_directory=/logdir --hparams=distributed_run=True --fp16_run=True`
+1. `python -m multiproc train.py --output_directory=outdir --log_directory=logdir --hparams=distributed_run=True,fp16_run=True`
 
 ## Inference
 1. `jupyter notebook --ip=127.0.0.1 --port=31337`
diff --git a/train.py b/train.py
index dd5de2e..ee01b07 100644
--- a/train.py
+++ b/train.py
@@ -78,15 +78,19 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
 def load_model(hparams):
     model = Tacotron2(hparams).cuda()
     model = batchnorm_to_float(model.half()) if hparams.fp16_run else model
-    model = DistributedDataParallel(model) \
-        if hparams.distributed_run else DataParallel(model)
+
+    if hparams.distributed_run:
+        model = DistributedDataParallel(model)
+    elif torch.cuda.device_count() > 1:
+        model = DataParallel(model)
+
     return model
 
 
 def warm_start_model(checkpoint_path, model):
     assert os.path.isfile(checkpoint_path)
     print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
-    checkpoint_dict = torch.load(checkpoint_path)
+    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
     model.load_state_dict(checkpoint_dict['state_dict'])
     return model
 
@@ -124,8 +128,13 @@ def validate(model, criterion, valset, iteration, batch_size, n_gpus,
                             pin_memory=False, collate_fn=collate_fn)
 
     val_loss = 0.0
+    if distributed_run or torch.cuda.device_count() > 1:
+        batch_parser = model.module.parse_batch
+    else:
+        batch_parser = model.parse_batch
+
     for i, batch in enumerate(val_loader):
-        x, y = model.module.parse_batch(batch)
+        x, y = batch_parser(batch)
         y_pred = model(x)
         loss = criterion(y_pred, y)
         reduced_val_loss = reduce_tensor(loss.data, n_gpus)[0] \
@@ -184,6 +193,10 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
     epoch_offset = max(0, int(iteration / len(train_loader)))
 
     model.train()
+    if hparams.distributed_run or torch.cuda.device_count() > 1:
+        batch_parser = model.module.parse_batch
+    else:
+        batch_parser = model.parse_batch
     # ================ MAIN TRAINNIG LOOP! ===================
     for epoch in range(epoch_offset, hparams.epochs):
         print("Epoch: {}".format(epoch))
@@ -193,7 +206,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
                 param_group['lr'] = learning_rate
 
             model.zero_grad()
-            x, y = model.module.parse_batch(batch)
+            x, y = batch_parser(batch)
             y_pred = model(x)
             loss = criterion(y_pred, y)
             reduced_loss = reduce_tensor(loss.data, n_gpus)[0] \
@@ -205,7 +218,7 @@ def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
             else:
                 loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm(
-                    model.module.parameters(), hparams.grad_clip_thresh)
+                    model.parameters(), hparams.grad_clip_thresh)
 
             optimizer.step()
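
The `batch_parser` hunks above all rely on the same PyTorch detail: `DataParallel` and `DistributedDataParallel` expose built-in `nn.Module` methods such as `parameters()` on the wrapper itself (which is why the grad-clipping hunk can drop `.module`), but user-defined methods like Tacotron2's `parse_batch` are only reachable through `.module`. Below is a minimal sketch of that dispatch; it uses an `isinstance` check on torch's wrapper classes rather than the patch's `torch.cuda.device_count() > 1` test, and the helper name `get_batch_parser` is illustrative, not part of the patch:

```python
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel


def get_batch_parser(model):
    # The wrapper classes proxy standard nn.Module attributes (parameters(),
    # state_dict(), ...) but not user-defined methods such as parse_batch;
    # those still live on the original model, reachable via model.module.
    if isinstance(model, (DataParallel, DistributedDataParallel)):
        return model.module.parse_batch
    return model.parse_batch
```

The `map_location='cpu'` change in `warm_start_model` is a similar defensive fix: loading onto the CPU first keeps `torch.load` from restoring tensors onto the GPU ids recorded at save time, which may not exist (or may already be full) on the machine doing the warm start; `load_state_dict` then moves the weights onto the model's actual device.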