diff --git a/inference.ipynb b/inference.ipynb index 26e38d1..2aa9b50 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -98,8 +98,11 @@ "source": [ "checkpoint_path = \"/home/scratch.adlr-gcf/audio_denoising/runs/TTS-Tacotron2-LJS-MSE-DRC-NoMaskPadding-Unsorted-Distributed-22khz/checkpoint_15500\"\n", "model = load_model(hparams)\n", - "model.load_state_dict(torch.load(checkpoint_path)['state_dict'])\n", - "model = model.module\n", + "try:\n", + " model = model.module\n", + "except:\n", + " pass\n", + "model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(checkpoint_path)['state_dict'].items()})\n", "_ = model.eval()" ] },