From ecfe19659827afbe47806e67ab51efe3cd7b70a4 Mon Sep 17 00:00:00 2001 From: alokprasad Date: Tue, 10 Mar 2020 17:17:17 +0530 Subject: [PATCH] changes for squeezewave and non cuda --- FastSpeech/synthesis.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/FastSpeech/synthesis.py b/FastSpeech/synthesis.py index 4cc91db..a9ff0fe 100644 --- a/FastSpeech/synthesis.py +++ b/FastSpeech/synthesis.py @@ -13,6 +13,9 @@ import utils import audio as Audio import glow import waveglow +import time + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -21,7 +24,7 @@ def get_FastSpeech(num): checkpoint_path = "checkpoint_" + str(num) + ".pth.tar" model = nn.DataParallel(FastSpeech()).to(device) model.load_state_dict(torch.load(os.path.join( - hp.checkpoint_path, checkpoint_path))['model']) + hp.checkpoint_path, checkpoint_path),map_location=device)['model']) model.eval() return model @@ -35,11 +38,15 @@ def synthesis(model, text, alpha=1.0): src_pos = np.stack([src_pos]) with torch.no_grad(): sequence = torch.autograd.Variable( - torch.from_numpy(text)).cuda().long() + torch.from_numpy(text)).long() src_pos = torch.autograd.Variable( - torch.from_numpy(src_pos)).cuda().long() + torch.from_numpy(src_pos)).long() mel, mel_postnet = model.module.forward(sequence, src_pos, alpha=alpha) + + #script for generating torch script + #traced_script_module = torch.jit.trace(model,(sequence,src_pos)) + #traced_script_module.save("traced_fastspeech_model.pt") return mel[0].cpu().transpose(0, 1), \ mel_postnet[0].cpu().transpose(0, 1), \ @@ -54,11 +61,15 @@ if __name__ == "__main__": model = get_FastSpeech(num) words = "Let’s go out to the airport. The plane landed ten minutes ago." + start = time.time() mel, mel_postnet, mel_torch, mel_postnet_torch = synthesis( model, words, alpha=alpha) if not os.path.exists("results"): os.mkdir("results") + + #do not use any vocoder , mel file generated will be passed to squeezewave vocoder. + """ Audio.tools.inv_mel_spec(mel_postnet, os.path.join( "results", words + "_" + str(num) + "_griffin_lim.wav")) @@ -70,5 +81,13 @@ if __name__ == "__main__": mel_tac2, _, _ = utils.load_data_from_tacotron2(words, tacotron2) waveglow.inference.inference(torch.stack([torch.from_numpy( mel_tac2).cuda()]), wave_glow, os.path.join("results", "tacotron2.wav")) - utils.plot_data([mel.numpy(), mel_postnet.numpy(), mel_tac2]) + + """ + #melspec = torch.squeeze(mel_postnet_torch, 0) + torch.save(mel_postnet_torch, "../SqueezeWave/mel_spectrograms/test.pt") + end = time.time() + print("MEL Calculation:") + print(end-start) + +