import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

from layers import ConvNorm, LinearNorm
from utils import to_gpu, get_mask_from_lengths
from fp16_optimizer import fp32_to_fp16, fp16_to_fp32


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: attention RNN hidden state (B, attention_rnn_dim)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
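

# A minimal shape sketch (not part of the original file) of the
# location-sensitive attention above, exercised in isolation. The dimensions
# below are illustrative assumptions, not values from this repo's hparams.
def _attention_shape_sketch():
    attention = Attention(attention_rnn_dim=1024, embedding_dim=512,
                          attention_dim=128, attention_location_n_filters=32,
                          attention_location_kernel_size=31)
    B, T_in = 2, 7
    query = torch.zeros(B, 1024)                        # attention RNN state
    memory = torch.zeros(B, T_in, 512)                  # encoder outputs
    processed_memory = attention.memory_layer(memory)   # (B, T_in, attention_dim)
    weights_cat = torch.zeros(B, 2, T_in)               # prev + cumulative weights
    context, weights = attention(
        query, memory, processed_memory, weights_cat, mask=None)
    return context.shape, weights.shape                 # (B, 512), (B, T_in)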


class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x
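

# Note (not part of the original file): the Prenet above calls F.dropout with
# training=True, so dropout stays active even in eval mode, matching the
# Tacotron 2 paper's inference-time prenet dropout. A minimal sketch with
# illustrative dimensions: two passes over the same frame almost surely differ.
def _prenet_always_dropout_sketch():
    prenet = Prenet(in_dim=80, sizes=[256, 256]).eval()
    frame = torch.ones(2, 80)
    out_a, out_b = prenet(frame), prenet(frame)
    return torch.equal(out_a, out_b)  # almost surely False because of dropout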


class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.dropout = nn.Dropout(0.5)
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
        )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = self.dropout(torch.tanh(self.convolutions[i](x)))
        x = self.dropout(self.convolutions[-1](x))
        return x
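

# Shape note (not part of the original file): the Postnet maps
# (B, n_mel_channels, T) -> (B, n_mel_channels, T) and is applied as a residual
# correction on the decoder's coarse mels in Tacotron2.forward below.
# Hypothetical helper, assuming `postnet` was built from this repo's hparams:
def _postnet_residual_sketch(postnet, mel_outputs):
    return mel_outputs + postnet(mel_outputs)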


class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()
        self.dropout = nn.Dropout(0.5)

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = self.dropout(F.relu(conv(x)))

        x = x.transpose(1, 2)

        # pytorch tensors are not reversible, hence the conversion to numpy
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = self.dropout(F.relu(conv(x)))

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs
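

# Layout sketch (not part of the original file): the Encoder consumes
# channels-first embeddings for its conv stack and returns batch-first LSTM
# outputs. The sizes below are illustrative assumptions; a real run would use
# an Encoder built from this repo's hparams. Lengths must be sorted in
# decreasing order for pack_padded_sequence.
def _encoder_layout_sketch(encoder, embedding_dim=512):
    B, T_in = 2, 11
    x = torch.zeros(B, embedding_dim, T_in)    # (B, C, T) for the Conv1d stack
    lengths = torch.tensor([11, 8])            # sorted, longest first
    outputs = encoder(x, lengths)              # (B, T_in, embedding_dim)
    return outputs.shape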


class Decoder(nn.Module):
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        self.decoder_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1) / self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)

        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:

        RETURNS
        -------
        mel_outputs:
        gate_outputs: gate output energies
        alignments:
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output

        RETURNS
        -------
        mel_output:
        gate_output: gate output energies
        attention_weights:
        """
        cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        self.attention_weights_cum += self.attention_weights

        prenet_output = self.prenet(decoder_input)
        decoder_input = torch.cat((prenet_output, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)

        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0):
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [attention_weights]
            decoder_input = decoder_inputs[len(mel_outputs) - 1]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)
        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output]
            gate_outputs += [gate_output.squeeze(1)]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments
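

# A single-step sketch (not part of the original file) of Decoder.decode above:
# the previous decoder-RNN state and attention context feed the attention RNN,
# its new state drives the location-sensitive attention, then the
# prenet-processed previous frame plus the fresh context feed the decoder RNN,
# whose state (with the context) is projected to a mel frame and a stop gate.
# Hypothetical helper; `decoder` is assumed to be built from this repo's
# hparams and `memory` to be encoder outputs of shape
# (B, T_in, encoder_embedding_dim).
def _single_decode_step_sketch(decoder, memory):
    decoder.initialize_decoder_states(memory, mask=None)
    go_frame = decoder.get_go_frame(memory)  # (B, n_mel_channels * n_frames_per_step)
    mel_frame, gate, weights = decoder.decode(go_frame)
    # mel_frame: (B, n_mel_channels * n_frames_per_step)
    # gate:      (B, 1) stop-token energy
    # weights:   (B, T_in) attention weights for this step
    return mel_frame, gate, weights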


class Tacotron2(nn.Module):
    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

    def parse_batch(self, batch):
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch
        text_padded = to_gpu(text_padded).long()
        max_len = int(torch.max(input_lengths.data).numpy())
        input_lengths = to_gpu(input_lengths).long()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (text_padded, input_lengths, mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_input(self, inputs):
        inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
        return inputs

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths + 1)  # +1 <stop> token
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
        return outputs

    def forward(self, inputs):
        inputs, input_lengths, targets, max_len, \
            output_lengths = self.parse_input(inputs)
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(inputs).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, input_lengths)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=input_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        # DataParallel expects equal sized inputs/outputs, hence padding
        if input_lengths is not None:
            alignments = alignments.unsqueeze(0)
            alignments = nn.functional.pad(
                alignments,
                (0, max_len - alignments.size(3), 0, 0),
                "constant", 0)
            alignments = alignments.squeeze()

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths)

    def inference(self, inputs):
        inputs = self.parse_input(inputs)
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        return outputs
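

# End-to-end inference sketch (not part of the original file). `create_hparams`
# is assumed to be the helper defined in this repo's hparams.py; the random
# symbol ids below stand in for an encoded text sequence. Batch size 1 is used
# because Decoder.inference's gate check expects a single utterance.
if __name__ == '__main__':
    from hparams import create_hparams  # assumption: this repo's hparams helper
    hparams = create_hparams()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Tacotron2(hparams).to(device).eval()
    text = torch.randint(0, hparams.n_symbols, (1, 30), device=device).long()
    with torch.no_grad():
        mel, mel_postnet, gate, alignments = model.inference(text)
    print(mel_postnet.shape)  # (1, n_mel_channels, n_decoder_steps)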