diff --git a/synthesize.py b/synthesize.py index d5bec89..4c213f0 100644 --- a/synthesize.py +++ b/synthesize.py @@ -59,14 +59,12 @@ _ = model.eval() print("This Tacotron model has been trained for ",torch.load(tacotron2_pretrained_model, map_location=torch.device('cpu'))['iteration']," Iterations.") # Load WaveGlow model into GPU -waveglow_pretrained_model = 'squeezewave.pt' -# squeezewave = torch.load(waveglow_pretrained_model, map_location=torch.device('cpu'))['model'] +waveglow_pretrained_model = 'squeezewave_dict.pt' with open(join(project_name2, 'SqueezeWave/configs/config_a128_c256.json')) as f: data = f.read() config = json.loads(data) waveglow = SqueezeWave(**config['squeezewave_config']) -waveglow.load_state_dict(torch.load('squeezewave_dict.pt'), strict=False) -# waveglow.load_state_dict(squeezewave.state_dict(), strict=False) +waveglow.load_state_dict(torch.load(waveglow_pretrained_model), strict=False) waveglow = waveglow.remove_weightnorm(waveglow) waveglow.eval() for k in waveglow.convinv: