diff --git a/inference.ipynb b/inference.ipynb index 7785cd2..2e94ff3 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -98,8 +98,11 @@ "source": [ "checkpoint_path = \"/home/scratch.adlr-gcf/audio_denoising/runs/TTS-Tacotron2-LJS-MSE-DRC-NoMaskPadding-Unsorted-Distributed-22khz/checkpoint_15500\"\n", "model = load_model(hparams)\n", + "try:\n", + " model = model.module\n", + "except:\n", + " pass\n" "model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(checkpoint_path)['state_dict'].items()})\n", - "model = model.module\n", "_ = model.eval()" ] },