You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

104 lines
4.9 KiB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
  1. import os
  2. from os.path import exists, join, basename, splitext
  3. git_repo_url = 'https://github.com/NVIDIA/tacotron2.git'
  4. project_name = splitext(basename(git_repo_url))[0]
  5. git_repo_url2 = 'https://github.com/alokprasad/fastspeech_squeezewave.git'
  6. project_name2 = splitext(basename(git_repo_url2))[0]
  7. import sys
  8. sys.path.append(join(project_name2, "SqueezeWave/"))
  9. sys.path.append(project_name)
  10. import numpy as np
  11. import torch
  12. from hparams import create_hparams
  13. from model import Tacotron2
  14. from text import text_to_sequence
  15. from denoiser import Denoiser
  16. from glow import SqueezeWave
  17. import librosa
  18. import json
  19. thisdict = {}
  20. for line in reversed((open('merged.dict_1.1.txt', "r").read()).splitlines()):
  21. thisdict[(line.split(" ",1))[0]] = (line.split(" ",1))[1].strip()
  22. def ARPA(text):
  23. out = ''
  24. for word_ in text.split(" "):
  25. word=word_; end_chars = ''
  26. while any(elem in word for elem in r"!?,.;") and len(word) > 1:
  27. if word[-1] == '!': end_chars = '!' + end_chars; word = word[:-1]
  28. if word[-1] == '?': end_chars = '?' + end_chars; word = word[:-1]
  29. if word[-1] == ',': end_chars = ',' + end_chars; word = word[:-1]
  30. if word[-1] == '.': end_chars = '.' + end_chars; word = word[:-1]
  31. if word[-1] == ';': end_chars = ';' + end_chars; word = word[:-1]
  32. else: break
  33. try: word_arpa = thisdict[word.upper()]
  34. except: word_arpa = ''
  35. if len(word_arpa)!=0: word = "{" + str(word_arpa) + "}"
  36. out = (out + " " + word + end_chars).strip()
  37. if out[-1] != ";": out = out + ";"
  38. return out
  39. #torch.set_grad_enabled(False)
  40. # initialize Tacotron2 with the pretrained model
  41. hparams = create_hparams()
  42. tacotron2_pretrained_model = 'tacotron.pt'
  43. # Setup Parameters
  44. hparams = create_hparams()
  45. hparams.sampling_rate = 22050
  46. hparams.max_decoder_steps = 3000 # how many steps before cutting off generation, too many and you may get CUDA errors.
  47. hparams.gate_threshold = 0.30 # Model must be 30% sure the clip is over before ending generation
  48. # Load Tacotron2 model into GPU
  49. model = Tacotron2(hparams)
  50. model.load_state_dict(torch.load(tacotron2_pretrained_model, map_location=torch.device('cpu'))['state_dict'])
  51. _ = model.eval()
  52. print("This Tacotron model has been trained for ",torch.load(tacotron2_pretrained_model, map_location=torch.device('cpu'))['iteration']," Iterations.")
  53. # Load WaveGlow model into GPU
  54. waveglow_pretrained_model = 'squeezewave_dict.pt'
  55. with open(join(project_name2, 'SqueezeWave/configs/config_a128_c256.json')) as f:
  56. data = f.read()
  57. config = json.loads(data)
  58. waveglow = SqueezeWave(**config['squeezewave_config'])
  59. waveglow.load_state_dict(torch.load(waveglow_pretrained_model), strict=False)
  60. waveglow = waveglow.remove_weightnorm(waveglow)
  61. waveglow.eval()
  62. for k in waveglow.convinv:
  63. k.float()
  64. denoiser = Denoiser(waveglow)
  65. print("SqueezeWave model loaded")
  66. import time
  67. # All right, I've been thinking. , When life gives you lemons? , Don't make lemonade. , Make life take the lemons back! , Get mad! , 'I don't want your damn lemons! What am I supposed to do with these?' , Demand to see life's manager! , Make life rue the day it thought it could give Cave Johnson lemons! , Do you know who I am? , I'm the man who's going to burn your house down! , With the lemons! , I'm going to get my engineers to invent a combustible lemon that burns your house down!
  68. text = """
  69. Peter Piper picked a peck of pickled peppers, A peck of pickled peppers Peter Piper picked; If Peter Piper picked a peck of pickled peppers, wheres the peck of pickled peppers Peter Piper picked?
  70. She sells sea shells by the seashore, The shells she sells are sea shells, Im sure. So if she sells sea shells on the seashore, Then Im sure she sells seashore shells.
  71. """
  72. sigma = 0.75
  73. denoise_strength = 0.01
  74. raw_input = False # disables automatic ARPAbet conversion, useful for inputting your own ARPAbet pronounciations or just for testing
  75. counter = 0
  76. for i in text.split("\n"):
  77. start_time = time.time()
  78. if len(i) < 1: continue;
  79. print(i)
  80. if raw_input:
  81. if i[-1] != ";": i=i+";"
  82. else: i = ARPA(i)
  83. print(i)
  84. with torch.no_grad(): # save VRAM by not including gradients
  85. sequence = np.array(text_to_sequence(i, ['english_cleaners']))[None, :]
  86. sequence = torch.autograd.Variable(torch.from_numpy(sequence)).long()
  87. mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
  88. audio = waveglow.infer(mel_outputs_postnet, sigma=sigma); print("");
  89. audio_denoised = denoiser(audio, strength=denoise_strength)[:, 0]; print("Denoised");
  90. # librosa.output.write_wav('Inf_' + str(counter) + '.wav', np.swapaxes(audio.cpu().numpy(),0,1), hparams.sampling_rate)
  91. librosa.output.write_wav('Inf_' + str(counter) + '_denoised.wav', np.swapaxes(audio_denoised.cpu().numpy(),0,1), hparams.sampling_rate)
  92. counter += 1
  93. print("--- %s seconds ---" % (time.time() - start_time))