{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tacotron 2 inference code\n",
    "Edit the variables **checkpoint_path** and **text** to match your setup, then run all cells to plot the mel outputs and alignments and to synthesize audio from the generated mel-spectrogram with WaveGlow; an optional Griffin-Lim fallback is sketched at the end."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Import libraries and setup matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/dcg-adlr-rafaelvalle-source.cosmos597/repos/nvidia/tacotron2/plotting_utils.py:2: UserWarning: matplotlib.pyplot as already been imported, this call will have no effect.\n",
      "  matplotlib.use(\"Agg\")\n"
     ]
    }
   ],
  32. "source": [
  33. "import matplotlib\n",
  34. "matplotlib.use(\"Agg\")\n",
  35. "import matplotlib.pylab as plt\n",
  36. "%matplotlib inline\n",
  37. "import IPython.display as ipd\n",
  38. "\n",
  39. "import sys\n",
  40. "sys.path.append('waveglow/')\n",
  41. "import numpy as np\n",
  42. "import torch\n",
  43. "\n",
  44. "from hparams import create_hparams\n",
  45. "from model import Tacotron2\n",
  46. "from layers import TacotronSTFT\n",
  47. "from audio_processing import griffin_lim\n",
  48. "from train import load_model\n",
  49. "from text import text_to_sequence\n"
  50. ]
  51. },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_data(data, figsize=(16, 4)):\n",
    "    fig, axes = plt.subplots(1, len(data), figsize=figsize)\n",
    "    for i in range(len(data)):\n",
    "        axes[i].imshow(data[i], aspect='auto', origin='lower',\n",
    "                       interpolation='none')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Setup hparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hparams = create_hparams(\"distributed_run=False,mask_padding=False\")\n",
    "hparams.sampling_rate = 22050\n",
    "hparams.filter_length = 1024\n",
    "hparams.hop_length = 256\n",
    "hparams.win_length = 1024"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load model from checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
  98. "checkpoint_path = \"tacotron2_statedict\"\n",
  99. "\n",
  100. "model = load_model(hparams)\n",
  101. "try:\n",
  102. " model = model.module\n",
  103. "except:\n",
  104. " pass\n",
  105. "model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(checkpoint_path)['state_dict'].items()})\n",
  106. "_ = model.eval()"
  107. ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load WaveGlow for mel2audio synthesis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
  122. "waveglow_path = 'waveglow_old.pt'\n",
  123. "waveglow = torch.load(waveglow_path)['model']"
  124. ]
  125. },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prepare text input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
  139. "text = \"Waveglow is really awesome!\"\n",
  140. "sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]\n",
  141. "sequence = torch.autograd.Variable(\n",
  142. " torch.from_numpy(sequence)).cuda().long()"
  143. ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Decode text input and plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)\n",
    "plot_data((mel_outputs.data.cpu().numpy()[0],\n",
    "           mel_outputs_postnet.data.cpu().numpy()[0],\n",
    "           alignments.data.cpu().numpy()[0].T))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Synthesize audio from spectrogram using WaveGlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)\n",
    "ipd.Audio(audio[0].data.cpu().numpy(), rate=hparams.sampling_rate)"
   ]
  },
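  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### (Optional) Synthesize audio with Griffin-Lim\n",
    "A minimal sketch of the Griffin-Lim path mentioned at the top, assuming `TacotronSTFT` exposes `spectral_de_normalize`, `mel_basis` and `stft_fn` as in this repo's `layers.py`; the spectrogram scaling factor and the iteration count are illustrative values, not tuned ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Invert the normalized mel-spectrogram to a linear-frequency magnitude\n",
    "# spectrogram, then estimate the phase with Griffin-Lim.\n",
    "taco_stft = TacotronSTFT(\n",
    "    hparams.filter_length, hparams.hop_length, hparams.win_length,\n",
    "    sampling_rate=hparams.sampling_rate)\n",
    "mel_decompress = taco_stft.spectral_de_normalize(mel_outputs_postnet)\n",
    "mel_decompress = mel_decompress.transpose(1, 2).data.cpu()\n",
    "spec_from_mel_scaling = 1000  # empirical gain before phase estimation\n",
    "spec_from_mel = torch.mm(mel_decompress[0], taco_stft.mel_basis)\n",
    "spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)\n",
    "spec_from_mel = spec_from_mel * spec_from_mel_scaling\n",
    "\n",
    "# 60 Griffin-Lim iterations; the last (padding) frame is dropped\n",
    "audio_gl = griffin_lim(spec_from_mel[:, :, :-1], taco_stft.stft_fn, 60)\n",
    "ipd.Audio(audio_gl[0].data.cpu().numpy(), rate=hparams.sampling_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### (Optional) Save the synthesized audio to disk\n",
    "A small usage example, assuming SciPy is installed; `'audio.wav'` is an arbitrary output path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io.wavfile import write\n",
    "\n",
    "# audio is a (batch, samples) float tensor in [-1, 1]; save the first item\n",
    "write('audio.wav', hparams.sampling_rate, audio[0].data.cpu().numpy())"
   ]
  }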
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}