From 78d5150d83e74432d97f06ac1030e6c272f5a78b Mon Sep 17 00:00:00 2001 From: Raul Puri Date: Sat, 5 May 2018 17:23:11 -0700 Subject: [PATCH 1/3] inference (distributed) dataparallel patch removing the '.module' that comes from (distributed)dataparallel state dict --- inference.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference.ipynb b/inference.ipynb index 26e38d1..7785cd2 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -98,7 +98,7 @@ "source": [ "checkpoint_path = \"/home/scratch.adlr-gcf/audio_denoising/runs/TTS-Tacotron2-LJS-MSE-DRC-NoMaskPadding-Unsorted-Distributed-22khz/checkpoint_15500\"\n", "model = load_model(hparams)\n", - "model.load_state_dict(torch.load(checkpoint_path)['state_dict'])\n", + "model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(checkpoint_path)['state_dict'].items()})\n", "model = model.module\n", "_ = model.eval()" ] From c67ca6531edfedb2df16e188db16a3ec3be035f1 Mon Sep 17 00:00:00 2001 From: Raul Puri Date: Sat, 5 May 2018 17:29:09 -0700 Subject: [PATCH 2/3] force single gpu in inference.ipynb --- inference.ipynb | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/inference.ipynb b/inference.ipynb index 7785cd2..2e94ff3 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -98,8 +98,11 @@ "source": [ "checkpoint_path = \"/home/scratch.adlr-gcf/audio_denoising/runs/TTS-Tacotron2-LJS-MSE-DRC-NoMaskPadding-Unsorted-Distributed-22khz/checkpoint_15500\"\n", "model = load_model(hparams)\n", + "try:\n", + " model = model.module\n", + "except:\n", + " pass\n" "model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(checkpoint_path)['state_dict'].items()})\n", - "model = model.module\n", "_ = model.eval()" ] }, From 4ac6ce9ab51f43bbb21993e995b87e4db13f02e4 Mon Sep 17 00:00:00 2001 From: Raul Puri Date: Sat, 5 May 2018 17:30:08 -0700 Subject: [PATCH 3/3] ipynb typo --- inference.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/inference.ipynb b/inference.ipynb index 2e94ff3..2aa9b50 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -101,7 +101,7 @@ "try:\n", " model = model.module\n", "except:\n", - " pass\n" + " pass\n", "model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(checkpoint_path)['state_dict'].items()})\n", "_ = model.eval()" ]