You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

529 lines
20 KiB

  1. from math import sqrt
  2. import torch
  3. from torch.autograd import Variable
  4. from torch import nn
  5. from torch.nn import functional as F
  6. from layers import ConvNorm, LinearNorm
  7. from utils import to_gpu, get_mask_from_lengths
  8. class LocationLayer(nn.Module):
  9. def __init__(self, attention_n_filters, attention_kernel_size,
  10. attention_dim):
  11. super(LocationLayer, self).__init__()
  12. padding = int((attention_kernel_size - 1) / 2)
  13. self.location_conv = ConvNorm(2, attention_n_filters,
  14. kernel_size=attention_kernel_size,
  15. padding=padding, bias=False, stride=1,
  16. dilation=1)
  17. self.location_dense = LinearNorm(attention_n_filters, attention_dim,
  18. bias=False, w_init_gain='tanh')
  19. def forward(self, attention_weights_cat):
  20. processed_attention = self.location_conv(attention_weights_cat)
  21. processed_attention = processed_attention.transpose(1, 2)
  22. processed_attention = self.location_dense(processed_attention)
  23. return processed_attention
  24. class Attention(nn.Module):
  25. def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
  26. attention_location_n_filters, attention_location_kernel_size):
  27. super(Attention, self).__init__()
  28. self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
  29. bias=False, w_init_gain='tanh')
  30. self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
  31. w_init_gain='tanh')
  32. self.v = LinearNorm(attention_dim, 1, bias=False)
  33. self.location_layer = LocationLayer(attention_location_n_filters,
  34. attention_location_kernel_size,
  35. attention_dim)
  36. self.score_mask_value = -float("inf")
  37. def get_alignment_energies(self, query, processed_memory,
  38. attention_weights_cat):
  39. """
  40. PARAMS
  41. ------
  42. query: decoder output (batch, n_mel_channels * n_frames_per_step)
  43. processed_memory: processed encoder outputs (B, T_in, attention_dim)
  44. attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
  45. RETURNS
  46. -------
  47. alignment (batch, max_time)
  48. """
  49. processed_query = self.query_layer(query.unsqueeze(1))
  50. processed_attention_weights = self.location_layer(attention_weights_cat)
  51. energies = self.v(torch.tanh(
  52. processed_query + processed_attention_weights + processed_memory))
  53. energies = energies.squeeze(-1)
  54. return energies
  55. def forward(self, attention_hidden_state, memory, processed_memory,
  56. attention_weights_cat, mask):
  57. """
  58. PARAMS
  59. ------
  60. attention_hidden_state: attention rnn last output
  61. memory: encoder outputs
  62. processed_memory: processed encoder outputs
  63. attention_weights_cat: previous and cummulative attention weights
  64. mask: binary mask for padded data
  65. """
  66. alignment = self.get_alignment_energies(
  67. attention_hidden_state, processed_memory, attention_weights_cat)
  68. if mask is not None:
  69. alignment.data.masked_fill_(mask, self.score_mask_value)
  70. attention_weights = F.softmax(alignment, dim=1)
  71. attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
  72. attention_context = attention_context.squeeze(1)
  73. return attention_context, attention_weights
  74. class Prenet(nn.Module):
  75. def __init__(self, in_dim, sizes):
  76. super(Prenet, self).__init__()
  77. in_sizes = [in_dim] + sizes[:-1]
  78. self.layers = nn.ModuleList(
  79. [LinearNorm(in_size, out_size, bias=False)
  80. for (in_size, out_size) in zip(in_sizes, sizes)])
  81. def forward(self, x):
  82. for linear in self.layers:
  83. x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
  84. return x
  85. class Postnet(nn.Module):
  86. """Postnet
  87. - Five 1-d convolution with 512 channels and kernel size 5
  88. """
  89. def __init__(self, hparams):
  90. super(Postnet, self).__init__()
  91. self.convolutions = nn.ModuleList()
  92. self.convolutions.append(
  93. nn.Sequential(
  94. ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
  95. kernel_size=hparams.postnet_kernel_size, stride=1,
  96. padding=int((hparams.postnet_kernel_size - 1) / 2),
  97. dilation=1, w_init_gain='tanh'),
  98. nn.BatchNorm1d(hparams.postnet_embedding_dim))
  99. )
  100. for i in range(1, hparams.postnet_n_convolutions - 1):
  101. self.convolutions.append(
  102. nn.Sequential(
  103. ConvNorm(hparams.postnet_embedding_dim,
  104. hparams.postnet_embedding_dim,
  105. kernel_size=hparams.postnet_kernel_size, stride=1,
  106. padding=int((hparams.postnet_kernel_size - 1) / 2),
  107. dilation=1, w_init_gain='tanh'),
  108. nn.BatchNorm1d(hparams.postnet_embedding_dim))
  109. )
  110. self.convolutions.append(
  111. nn.Sequential(
  112. ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
  113. kernel_size=hparams.postnet_kernel_size, stride=1,
  114. padding=int((hparams.postnet_kernel_size - 1) / 2),
  115. dilation=1, w_init_gain='linear'),
  116. nn.BatchNorm1d(hparams.n_mel_channels))
  117. )
  118. def forward(self, x):
  119. for i in range(len(self.convolutions) - 1):
  120. x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
  121. x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
  122. return x
  123. class Encoder(nn.Module):
  124. """Encoder module:
  125. - Three 1-d convolution banks
  126. - Bidirectional LSTM
  127. """
  128. def __init__(self, hparams):
  129. super(Encoder, self).__init__()
  130. convolutions = []
  131. for _ in range(hparams.encoder_n_convolutions):
  132. conv_layer = nn.Sequential(
  133. ConvNorm(hparams.encoder_embedding_dim,
  134. hparams.encoder_embedding_dim,
  135. kernel_size=hparams.encoder_kernel_size, stride=1,
  136. padding=int((hparams.encoder_kernel_size - 1) / 2),
  137. dilation=1, w_init_gain='relu'),
  138. nn.BatchNorm1d(hparams.encoder_embedding_dim))
  139. convolutions.append(conv_layer)
  140. self.convolutions = nn.ModuleList(convolutions)
  141. self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
  142. int(hparams.encoder_embedding_dim / 2), 1,
  143. batch_first=True, bidirectional=True)
  144. def forward(self, x, input_lengths):
  145. for conv in self.convolutions:
  146. x = F.dropout(F.relu(conv(x)), 0.5, self.training)
  147. x = x.transpose(1, 2)
  148. # pytorch tensor are not reversible, hence the conversion
  149. input_lengths = input_lengths.cpu().numpy()
  150. x = nn.utils.rnn.pack_padded_sequence(
  151. x, input_lengths, batch_first=True)
  152. self.lstm.flatten_parameters()
  153. outputs, _ = self.lstm(x)
  154. outputs, _ = nn.utils.rnn.pad_packed_sequence(
  155. outputs, batch_first=True)
  156. return outputs
  157. def inference(self, x):
  158. for conv in self.convolutions:
  159. x = F.dropout(F.relu(conv(x)), 0.5, self.training)
  160. x = x.transpose(1, 2)
  161. self.lstm.flatten_parameters()
  162. outputs, _ = self.lstm(x)
  163. return outputs
  164. class Decoder(nn.Module):
  165. def __init__(self, hparams):
  166. super(Decoder, self).__init__()
  167. self.n_mel_channels = hparams.n_mel_channels
  168. self.n_frames_per_step = hparams.n_frames_per_step
  169. self.encoder_embedding_dim = hparams.encoder_embedding_dim
  170. self.attention_rnn_dim = hparams.attention_rnn_dim
  171. self.decoder_rnn_dim = hparams.decoder_rnn_dim
  172. self.prenet_dim = hparams.prenet_dim
  173. self.max_decoder_steps = hparams.max_decoder_steps
  174. self.gate_threshold = hparams.gate_threshold
  175. self.p_attention_dropout = hparams.p_attention_dropout
  176. self.p_decoder_dropout = hparams.p_decoder_dropout
  177. self.prenet = Prenet(
  178. hparams.n_mel_channels * hparams.n_frames_per_step,
  179. [hparams.prenet_dim, hparams.prenet_dim])
  180. self.attention_rnn = nn.LSTMCell(
  181. hparams.prenet_dim + hparams.encoder_embedding_dim,
  182. hparams.attention_rnn_dim)
  183. self.attention_layer = Attention(
  184. hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
  185. hparams.attention_dim, hparams.attention_location_n_filters,
  186. hparams.attention_location_kernel_size)
  187. self.decoder_rnn = nn.LSTMCell(
  188. hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
  189. hparams.decoder_rnn_dim, 1)
  190. self.linear_projection = LinearNorm(
  191. hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
  192. hparams.n_mel_channels * hparams.n_frames_per_step)
  193. self.gate_layer = LinearNorm(
  194. hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
  195. bias=True, w_init_gain='sigmoid')
  196. def get_go_frame(self, memory):
  197. """ Gets all zeros frames to use as first decoder input
  198. PARAMS
  199. ------
  200. memory: decoder outputs
  201. RETURNS
  202. -------
  203. decoder_input: all zeros frames
  204. """
  205. B = memory.size(0)
  206. decoder_input = Variable(memory.data.new(
  207. B, self.n_mel_channels * self.n_frames_per_step).zero_())
  208. return decoder_input
  209. def initialize_decoder_states(self, memory, mask):
  210. """ Initializes attention rnn states, decoder rnn states, attention
  211. weights, attention cumulative weights, attention context, stores memory
  212. and stores processed memory
  213. PARAMS
  214. ------
  215. memory: Encoder outputs
  216. mask: Mask for padded data if training, expects None for inference
  217. """
  218. B = memory.size(0)
  219. MAX_TIME = memory.size(1)
  220. self.attention_hidden = Variable(memory.data.new(
  221. B, self.attention_rnn_dim).zero_())
  222. self.attention_cell = Variable(memory.data.new(
  223. B, self.attention_rnn_dim).zero_())
  224. self.decoder_hidden = Variable(memory.data.new(
  225. B, self.decoder_rnn_dim).zero_())
  226. self.decoder_cell = Variable(memory.data.new(
  227. B, self.decoder_rnn_dim).zero_())
  228. self.attention_weights = Variable(memory.data.new(
  229. B, MAX_TIME).zero_())
  230. self.attention_weights_cum = Variable(memory.data.new(
  231. B, MAX_TIME).zero_())
  232. self.attention_context = Variable(memory.data.new(
  233. B, self.encoder_embedding_dim).zero_())
  234. self.memory = memory
  235. self.processed_memory = self.attention_layer.memory_layer(memory)
  236. self.mask = mask
  237. def parse_decoder_inputs(self, decoder_inputs):
  238. """ Prepares decoder inputs, i.e. mel outputs
  239. PARAMS
  240. ------
  241. decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
  242. RETURNS
  243. -------
  244. inputs: processed decoder inputs
  245. """
  246. # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
  247. decoder_inputs = decoder_inputs.transpose(1, 2)
  248. decoder_inputs = decoder_inputs.view(
  249. decoder_inputs.size(0),
  250. int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
  251. # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
  252. decoder_inputs = decoder_inputs.transpose(0, 1)
  253. return decoder_inputs
  254. def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
  255. """ Prepares decoder outputs for output
  256. PARAMS
  257. ------
  258. mel_outputs:
  259. gate_outputs: gate output energies
  260. alignments:
  261. RETURNS
  262. -------
  263. mel_outputs:
  264. gate_outpust: gate output energies
  265. alignments:
  266. """
  267. # (T_out, B) -> (B, T_out)
  268. alignments = torch.stack(alignments).transpose(0, 1)
  269. # (T_out, B) -> (B, T_out)
  270. gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
  271. gate_outputs = gate_outputs.contiguous()
  272. # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
  273. mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
  274. # decouple frames per step
  275. mel_outputs = mel_outputs.view(
  276. mel_outputs.size(0), -1, self.n_mel_channels)
  277. # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
  278. mel_outputs = mel_outputs.transpose(1, 2)
  279. return mel_outputs, gate_outputs, alignments
  280. def decode(self, decoder_input):
  281. """ Decoder step using stored states, attention and memory
  282. PARAMS
  283. ------
  284. decoder_input: previous mel output
  285. RETURNS
  286. -------
  287. mel_output:
  288. gate_output: gate output energies
  289. attention_weights:
  290. """
  291. cell_input = torch.cat((decoder_input, self.attention_context), -1)
  292. self.attention_hidden, self.attention_cell = self.attention_rnn(
  293. cell_input, (self.attention_hidden, self.attention_cell))
  294. self.attention_hidden = F.dropout(
  295. self.attention_hidden, self.p_attention_dropout, self.training)
  296. attention_weights_cat = torch.cat(
  297. (self.attention_weights.unsqueeze(1),
  298. self.attention_weights_cum.unsqueeze(1)), dim=1)
  299. self.attention_context, self.attention_weights = self.attention_layer(
  300. self.attention_hidden, self.memory, self.processed_memory,
  301. attention_weights_cat, self.mask)
  302. self.attention_weights_cum += self.attention_weights
  303. decoder_input = torch.cat(
  304. (self.attention_hidden, self.attention_context), -1)
  305. self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
  306. decoder_input, (self.decoder_hidden, self.decoder_cell))
  307. self.decoder_hidden = F.dropout(
  308. self.decoder_hidden, self.p_decoder_dropout, self.training)
  309. decoder_hidden_attention_context = torch.cat(
  310. (self.decoder_hidden, self.attention_context), dim=1)
  311. decoder_output = self.linear_projection(
  312. decoder_hidden_attention_context)
  313. gate_prediction = self.gate_layer(decoder_hidden_attention_context)
  314. return decoder_output, gate_prediction, self.attention_weights
  315. def forward(self, memory, decoder_inputs, memory_lengths):
  316. """ Decoder forward pass for training
  317. PARAMS
  318. ------
  319. memory: Encoder outputs
  320. decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
  321. memory_lengths: Encoder output lengths for attention masking.
  322. RETURNS
  323. -------
  324. mel_outputs: mel outputs from the decoder
  325. gate_outputs: gate outputs from the decoder
  326. alignments: sequence of attention weights from the decoder
  327. """
  328. decoder_input = self.get_go_frame(memory).unsqueeze(0)
  329. decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
  330. decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
  331. decoder_inputs = self.prenet(decoder_inputs)
  332. self.initialize_decoder_states(
  333. memory, mask=~get_mask_from_lengths(memory_lengths))
  334. mel_outputs, gate_outputs, alignments = [], [], []
  335. while len(mel_outputs) < decoder_inputs.size(0) - 1:
  336. decoder_input = decoder_inputs[len(mel_outputs)]
  337. mel_output, gate_output, attention_weights = self.decode(
  338. decoder_input)
  339. mel_outputs += [mel_output.squeeze(1)]
  340. gate_outputs += [gate_output.squeeze(1)]
  341. alignments += [attention_weights]
  342. mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
  343. mel_outputs, gate_outputs, alignments)
  344. return mel_outputs, gate_outputs, alignments
  345. def inference(self, memory):
  346. """ Decoder inference
  347. PARAMS
  348. ------
  349. memory: Encoder outputs
  350. RETURNS
  351. -------
  352. mel_outputs: mel outputs from the decoder
  353. gate_outputs: gate outputs from the decoder
  354. alignments: sequence of attention weights from the decoder
  355. """
  356. decoder_input = self.get_go_frame(memory)
  357. self.initialize_decoder_states(memory, mask=None)
  358. mel_outputs, gate_outputs, alignments = [], [], []
  359. while True:
  360. decoder_input = self.prenet(decoder_input)
  361. mel_output, gate_output, alignment = self.decode(decoder_input)
  362. mel_outputs += [mel_output.squeeze(1)]
  363. gate_outputs += [gate_output]
  364. alignments += [alignment]
  365. if torch.sigmoid(gate_output.data) > self.gate_threshold:
  366. break
  367. elif len(mel_outputs) == self.max_decoder_steps:
  368. print("Warning! Reached max decoder steps")
  369. break
  370. decoder_input = mel_output
  371. mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
  372. mel_outputs, gate_outputs, alignments)
  373. return mel_outputs, gate_outputs, alignments
  374. class Tacotron2(nn.Module):
  375. def __init__(self, hparams):
  376. super(Tacotron2, self).__init__()
  377. self.mask_padding = hparams.mask_padding
  378. self.fp16_run = hparams.fp16_run
  379. self.n_mel_channels = hparams.n_mel_channels
  380. self.n_frames_per_step = hparams.n_frames_per_step
  381. self.embedding = nn.Embedding(
  382. hparams.n_symbols, hparams.symbols_embedding_dim)
  383. std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
  384. val = sqrt(3.0) * std # uniform bounds for std
  385. self.embedding.weight.data.uniform_(-val, val)
  386. self.encoder = Encoder(hparams)
  387. self.decoder = Decoder(hparams)
  388. self.postnet = Postnet(hparams)
  389. def parse_batch(self, batch):
  390. text_padded, input_lengths, mel_padded, gate_padded, \
  391. output_lengths = batch
  392. text_padded = to_gpu(text_padded).long()
  393. input_lengths = to_gpu(input_lengths).long()
  394. max_len = torch.max(input_lengths.data).item()
  395. mel_padded = to_gpu(mel_padded).float()
  396. gate_padded = to_gpu(gate_padded).float()
  397. output_lengths = to_gpu(output_lengths).long()
  398. return (
  399. (text_padded, input_lengths, mel_padded, max_len, output_lengths),
  400. (mel_padded, gate_padded))
  401. def parse_output(self, outputs, output_lengths=None):
  402. if self.mask_padding and output_lengths is not None:
  403. mask = ~get_mask_from_lengths(output_lengths)
  404. mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
  405. mask = mask.permute(1, 0, 2)
  406. outputs[0].data.masked_fill_(mask, 0.0)
  407. outputs[1].data.masked_fill_(mask, 0.0)
  408. outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
  409. return outputs
  410. def forward(self, inputs):
  411. text_inputs, text_lengths, mels, max_len, output_lengths = inputs
  412. text_lengths, output_lengths = text_lengths.data, output_lengths.data
  413. embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
  414. encoder_outputs = self.encoder(embedded_inputs, text_lengths)
  415. mel_outputs, gate_outputs, alignments = self.decoder(
  416. encoder_outputs, mels, memory_lengths=text_lengths)
  417. mel_outputs_postnet = self.postnet(mel_outputs)
  418. mel_outputs_postnet = mel_outputs + mel_outputs_postnet
  419. return self.parse_output(
  420. [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
  421. output_lengths)
  422. def inference(self, inputs):
  423. embedded_inputs = self.embedding(inputs).transpose(1, 2)
  424. encoder_outputs = self.encoder.inference(embedded_inputs)
  425. mel_outputs, gate_outputs, alignments = self.decoder.inference(
  426. encoder_outputs)
  427. mel_outputs_postnet = self.postnet(mel_outputs)
  428. mel_outputs_postnet = mel_outputs + mel_outputs_postnet
  429. outputs = self.parse_output(
  430. [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
  431. return outputs