Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.

from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from tacotron2.layers import ConvNorm, LinearNorm
from tacotron2.utils import to_gpu, get_mask_from_lengths


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(
            attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            # Dropout stays active at inference time (training=True), as in
            # the original Tacotron 2 prenet, to introduce output variation.
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolutions with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int(
                                 (hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
        )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(
                self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x


class Encoder(nn.Module):
    """Encoder module:
        - Three 1-d convolution banks
        - Bidirectional LSTM
    """

    def __init__(self, hparams):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        # PyTorch tensors are not reversible, hence the conversion to numpy
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs


class Decoder(nn.Module):
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1) / self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:

        RETURNS
        -------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output

        RETURNS
        -------
        mel_output:
        gate_output: gate output energies
        attention_weights:
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze().unsqueeze(0)]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)
        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                # print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments


class Tacotron2(nn.Module):
    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

    def parse_batch(self, batch):
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (text_padded, input_lengths, mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs

    def forward(self, inputs):
        text_inputs, text_lengths, mels, max_len, output_lengths = inputs
        text_lengths, output_lengths = text_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mels, memory_lengths=text_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths), encoder_outputs

    def inference(self, inputs):
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        return outputs, encoder_outputs
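

# Minimal inference sketch, assuming this fork's tacotron2 package provides
# create_hparams and text_to_sequence (as in the upstream NVIDIA Tacotron 2
# code) and that the checkpoint stores weights under a "state_dict" key;
# the checkpoint path and cleaner name below are placeholders.
if __name__ == "__main__":
    from tacotron2.hparams import create_hparams  # assumed helper
    from tacotron2.text import text_to_sequence   # assumed helper

    hparams = create_hparams()
    model = Tacotron2(hparams).cuda().eval()

    checkpoint = torch.load("tacotron2_statedict.pt", map_location="cuda")
    model.load_state_dict(checkpoint["state_dict"])

    # Convert text to a (1, T_in) LongTensor of symbol ids.
    sequence = torch.LongTensor(
        text_to_sequence("Hello world.", ["english_cleaners"])
    ).unsqueeze(0).cuda()

    with torch.no_grad():
        (mel_outputs, mel_outputs_postnet, gate_outputs, alignments), \
            encoder_outputs = model.inference(sequence)

    # mel_outputs_postnet: (1, n_mel_channels, T_out) spectrogram to feed to
    # the SqueezeWave vocoder; encoder_outputs is the extra return value this
    # fork exposes (e.g. for FastSpeech-style distillation).
    print(mel_outputs_postnet.shape, encoder_outputs.shape)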