|
|
@ -13,6 +13,9 @@ import utils |
|
|
|
import audio as Audio |
|
|
|
import glow |
|
|
|
import waveglow |
|
|
|
import time |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
|
@ -21,7 +24,7 @@ def get_FastSpeech(num): |
|
|
|
checkpoint_path = "checkpoint_" + str(num) + ".pth.tar" |
|
|
|
model = nn.DataParallel(FastSpeech()).to(device) |
|
|
|
model.load_state_dict(torch.load(os.path.join( |
|
|
|
hp.checkpoint_path, checkpoint_path))['model']) |
|
|
|
hp.checkpoint_path, checkpoint_path),map_location=device)['model']) |
|
|
|
model.eval() |
|
|
|
|
|
|
|
return model |
|
|
@ -35,11 +38,15 @@ def synthesis(model, text, alpha=1.0): |
|
|
|
src_pos = np.stack([src_pos]) |
|
|
|
with torch.no_grad(): |
|
|
|
sequence = torch.autograd.Variable( |
|
|
|
torch.from_numpy(text)).cuda().long() |
|
|
|
torch.from_numpy(text)).long() |
|
|
|
src_pos = torch.autograd.Variable( |
|
|
|
torch.from_numpy(src_pos)).cuda().long() |
|
|
|
torch.from_numpy(src_pos)).long() |
|
|
|
|
|
|
|
mel, mel_postnet = model.module.forward(sequence, src_pos, alpha=alpha) |
|
|
|
|
|
|
|
#script for generating torch script |
|
|
|
#traced_script_module = torch.jit.trace(model,(sequence,src_pos)) |
|
|
|
#traced_script_module.save("traced_fastspeech_model.pt") |
|
|
|
|
|
|
|
return mel[0].cpu().transpose(0, 1), \ |
|
|
|
mel_postnet[0].cpu().transpose(0, 1), \ |
|
|
@ -54,11 +61,15 @@ if __name__ == "__main__": |
|
|
|
model = get_FastSpeech(num) |
|
|
|
words = "Let’s go out to the airport. The plane landed ten minutes ago." |
|
|
|
|
|
|
|
start = time.time() |
|
|
|
mel, mel_postnet, mel_torch, mel_postnet_torch = synthesis( |
|
|
|
model, words, alpha=alpha) |
|
|
|
|
|
|
|
if not os.path.exists("results"): |
|
|
|
os.mkdir("results") |
|
|
|
|
|
|
|
#do not use any vocoder , mel file generated will be passed to squeezewave vocoder. |
|
|
|
""" |
|
|
|
Audio.tools.inv_mel_spec(mel_postnet, os.path.join( |
|
|
|
"results", words + "_" + str(num) + "_griffin_lim.wav")) |
|
|
|
|
|
|
@ -70,5 +81,13 @@ if __name__ == "__main__": |
|
|
|
mel_tac2, _, _ = utils.load_data_from_tacotron2(words, tacotron2) |
|
|
|
waveglow.inference.inference(torch.stack([torch.from_numpy( |
|
|
|
mel_tac2).cuda()]), wave_glow, os.path.join("results", "tacotron2.wav")) |
|
|
|
|
|
|
|
utils.plot_data([mel.numpy(), mel_postnet.numpy(), mel_tac2]) |
|
|
|
|
|
|
|
""" |
|
|
|
#melspec = torch.squeeze(mel_postnet_torch, 0) |
|
|
|
torch.save(mel_postnet_torch, "../SqueezeWave/mel_spectrograms/test.pt") |
|
|
|
end = time.time() |
|
|
|
print("MEL Calculation:") |
|
|
|
print(end-start) |
|
|
|
|
|
|
|
|