From 4d7b04120ab3d812a3a304efb62a701bf2a0b318 Mon Sep 17 00:00:00 2001 From: rafaelvalle Date: Wed, 5 Dec 2018 22:14:35 -0800 Subject: [PATCH] inference.ipynb: changing waverglow inference fo fp16 --- inference.ipynb | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/inference.ipynb b/inference.ipynb index dc5f8ce..3e25002 100644 --- a/inference.ipynb +++ b/inference.ipynb @@ -23,10 +23,7 @@ { "name": "stderr", "output_type": "stream", - "text": [ - "/home/dcg-adlr-rafaelvalle-source.cosmos597/repos/nvidia/tacotron2/plotting_utils.py:2: UserWarning: matplotlib.pyplot as already been imported, this call will have no effect.\n", - " matplotlib.use(\"Agg\")\n" - ] + "text": [] } ], "source": [ @@ -113,17 +110,15 @@ { "name": "stderr", "output_type": "stream", - "text": [ - "/opt/conda/lib/python3.6/site-packages/torch/serialization.py:425: SourceChangeWarning: source code of class 'glow_old.WaveGlow' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n", - "/opt/conda/lib/python3.6/site-packages/torch/serialization.py:425: SourceChangeWarning: source code of class 'glow_old.WN' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", - " warnings.warn(msg, SourceChangeWarning)\n" - ] + "text": [] } ], "source": [ "waveglow_path = 'waveglow_old.pt'\n", - "waveglow = torch.load(waveglow_path)['model']" + "waveglow = torch.load(waveglow_path)['model']\n", + "waveglow.cuda().half()\n", + "for k in waveglow.convinv:\n", + " k.float()" ] }, { @@ -210,7 +205,7 @@ ], "source": [ "with torch.no_grad():\n", - " audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)\n", + " audio = waveglow.infer(mel_outputs_postnet.half(), sigma=0.666)\n", "ipd.Audio(audio[0].data.cpu().numpy(), rate=hparams.sampling_rate)" ] }