Browse Source

Fix model names

mistress
Daniel Muckerman 4 years ago
parent
commit
e5ec86c9fd
1 changed files with 2 additions and 4 deletions
  1. +2
    -4
      synthesize.py

+ 2
- 4
synthesize.py View File

@ -59,14 +59,12 @@ _ = model.eval()
print("This Tacotron model has been trained for ",torch.load(tacotron2_pretrained_model, map_location=torch.device('cpu'))['iteration']," Iterations.") print("This Tacotron model has been trained for ",torch.load(tacotron2_pretrained_model, map_location=torch.device('cpu'))['iteration']," Iterations.")
# Load WaveGlow model into GPU # Load WaveGlow model into GPU
waveglow_pretrained_model = 'squeezewave.pt'
# squeezewave = torch.load(waveglow_pretrained_model, map_location=torch.device('cpu'))['model']
waveglow_pretrained_model = 'squeezewave_dict.pt'
with open(join(project_name2, 'SqueezeWave/configs/config_a128_c256.json')) as f: with open(join(project_name2, 'SqueezeWave/configs/config_a128_c256.json')) as f:
data = f.read() data = f.read()
config = json.loads(data) config = json.loads(data)
waveglow = SqueezeWave(**config['squeezewave_config']) waveglow = SqueezeWave(**config['squeezewave_config'])
waveglow.load_state_dict(torch.load('squeezewave_dict.pt'), strict=False)
# waveglow.load_state_dict(squeezewave.state_dict(), strict=False)
waveglow.load_state_dict(torch.load(waveglow_pretrained_model), strict=False)
waveglow = waveglow.remove_weightnorm(waveglow) waveglow = waveglow.remove_weightnorm(waveglow)
waveglow.eval() waveglow.eval()
for k in waveglow.convinv: for k in waveglow.convinv:

Loading…
Cancel
Save