from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from layers import ConvNorm, LinearNorm
from utils import to_gpu, get_mask_from_lengths
from fp16_optimizer import fp32_to_fp16, fp16_to_fp32


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
  75. class Prenet(nn.Module):
  76. def __init__(self, in_dim, sizes):
  77. super(Prenet, self).__init__()
  78. in_sizes = [in_dim] + sizes[:-1]
  79. self.layers = nn.ModuleList(
  80. [LinearNorm(in_size, out_size, bias=False)
  81. for (in_size, out_size) in zip(in_sizes, sizes)])
  82. def forward(self, x):
  83. for linear in self.layers:
  84. x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
  85. return x


class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolutions with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
        )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x


class Encoder(nn.Module):
    """Encoder module:
        - Three 1-d convolution banks
        - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        # PyTorch tensors are not reversible, hence the conversion to numpy
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs


class Decoder(nn.Module):
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
            hparams.decoder_rnn_dim, 1)  # trailing 1 binds to LSTMCell's bias flag

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs

        RETURNS
        -------
        inputs: processed decoder inputs
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        decoder_inputs = decoder_inputs.view(
            decoder_inputs.size(0),
            int(decoder_inputs.size(1) / self.n_frames_per_step), -1)
        # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(0, 1)

        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs: list of per-step mel outputs
        gate_outputs: gate output energies
        alignments: list of per-step attention weights

        RETURNS
        -------
        mel_outputs: mel outputs (B, n_mel_channels, T_out)
        gate_outputs: gate output energies
        alignments: attention weights (B, T_out, T_in)
        """
        # (T_out, B) -> (B, T_out)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out, B) -> (B, T_out)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output
        RETURNS
        -------
        mel_output: predicted mel frame(s) for this decoder step
        gate_output: gate output energies
        attention_weights: attention weights for this decoder step
        """
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)
        self.attention_cell = F.dropout(
            self.attention_cell, self.p_attention_dropout, self.training)

        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)

        self.attention_weights_cum += self.attention_weights
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)
        self.decoder_cell = F.dropout(
            self.decoder_cell, self.p_decoder_dropout, self.training)

        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)

        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing, i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]
            mel_output, gate_output, attention_weights = self.decode(
                decoder_input)
            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output.squeeze()]
            alignments += [attention_weights]

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs += [mel_output.squeeze(1)]
            gate_outputs += [gate_output]
            alignments += [alignment]

            if torch.sigmoid(gate_output.data) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments


class Tacotron2(nn.Module):
    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

    def parse_batch(self, batch):
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (text_padded, input_lengths, mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_input(self, inputs):
        inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
        return inputs

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
        return outputs

    def forward(self, inputs):
        inputs, input_lengths, targets, max_len, \
            output_lengths = self.parse_input(inputs)
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(inputs).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, input_lengths)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=input_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths)

    def inference(self, inputs):
        inputs = self.parse_input(inputs)
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        return outputs
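

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes an `hparams` object exposing the fields referenced above
# (n_mel_channels, n_symbols, encoder_embedding_dim, ...), with fp16_run
# disabled; `create_hparams` is assumed to come from this repo's hparams.py
# and may differ in your checkout.
if __name__ == "__main__":
    from hparams import create_hparams  # assumed helper from this repo

    hparams = create_hparams()
    model = Tacotron2(hparams)

    # Run inference on a dummy symbol sequence (batch of 1, 30 symbols);
    # decoding stops at the gate threshold or max_decoder_steps.
    dummy_text = torch.randint(0, hparams.n_symbols, (1, 30)).long()
    with torch.no_grad():
        mel, mel_postnet, gates, alignments = model.inference(dummy_text)
    print(mel_postnet.shape)  # (1, n_mel_channels, T_out)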