|
@ -98,8 +98,11 @@ |
|
|
"source": [ |
|
|
"source": [ |
|
|
"checkpoint_path = \"/home/scratch.adlr-gcf/audio_denoising/runs/TTS-Tacotron2-LJS-MSE-DRC-NoMaskPadding-Unsorted-Distributed-22khz/checkpoint_15500\"\n", |
|
|
"checkpoint_path = \"/home/scratch.adlr-gcf/audio_denoising/runs/TTS-Tacotron2-LJS-MSE-DRC-NoMaskPadding-Unsorted-Distributed-22khz/checkpoint_15500\"\n", |
|
|
"model = load_model(hparams)\n", |
|
|
"model = load_model(hparams)\n", |
|
|
|
|
|
"try:\n", |
|
|
|
|
|
" model = model.module\n", |
|
|
|
|
|
"except:\n", |
|
|
|
|
|
" pass\n" |
|
|
"model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(checkpoint_path)['state_dict'].items()})\n", |
|
|
"model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(checkpoint_path)['state_dict'].items()})\n", |
|
|
"model = model.module\n", |
|
|
|
|
|
"_ = model.eval()" |
|
|
"_ = model.eval()" |
|
|
] |
|
|
] |
|
|
}, |
|
|
}, |
|
|