|
|
@ -1,3 +1,4 @@ |
|
|
|
from math import sqrt |
|
|
|
import torch |
|
|
|
from torch.autograd import Variable |
|
|
|
from torch import nn |
|
|
@ -56,7 +57,7 @@ class Attention(nn.Module): |
|
|
|
|
|
|
|
processed_query = self.query_layer(query.unsqueeze(1)) |
|
|
|
processed_attention_weights = self.location_layer(attention_weights_cat) |
|
|
|
energies = self.v(F.tanh( |
|
|
|
energies = self.v(torch.tanh( |
|
|
|
processed_query + processed_attention_weights + processed_memory)) |
|
|
|
|
|
|
|
energies = energies.squeeze(-1) |
|
|
@ -107,7 +108,6 @@ class Postnet(nn.Module): |
|
|
|
|
|
|
|
def __init__(self, hparams): |
|
|
|
super(Postnet, self).__init__() |
|
|
|
self.dropout = nn.Dropout(0.5) |
|
|
|
self.convolutions = nn.ModuleList() |
|
|
|
|
|
|
|
self.convolutions.append( |
|
|
@ -141,9 +141,8 @@ class Postnet(nn.Module): |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
for i in range(len(self.convolutions) - 1): |
|
|
|
x = self.dropout(F.tanh(self.convolutions[i](x))) |
|
|
|
|
|
|
|
x = self.dropout(self.convolutions[-1](x)) |
|
|
|
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) |
|
|
|
x = F.dropout(self.convolutions[-1](x), 0.5, self.training) |
|
|
|
|
|
|
|
return x |
|
|
|
|
|
|
@ -155,7 +154,6 @@ class Encoder(nn.Module): |
|
|
|
""" |
|
|
|
def __init__(self, hparams): |
|
|
|
super(Encoder, self).__init__() |
|
|
|
self.dropout = nn.Dropout(0.5) |
|
|
|
|
|
|
|
convolutions = [] |
|
|
|
for _ in range(hparams.encoder_n_convolutions): |
|
|
@ -175,7 +173,7 @@ class Encoder(nn.Module): |
|
|
|
|
|
|
|
def forward(self, x, input_lengths): |
|
|
|
for conv in self.convolutions: |
|
|
|
x = self.dropout(F.relu(conv(x))) |
|
|
|
x = F.dropout(F.relu(conv(x)), 0.5, self.training) |
|
|
|
|
|
|
|
x = x.transpose(1, 2) |
|
|
|
|
|
|
@ -194,7 +192,7 @@ class Encoder(nn.Module): |
|
|
|
|
|
|
|
def inference(self, x): |
|
|
|
for conv in self.convolutions: |
|
|
|
x = self.dropout(F.relu(conv(x))) |
|
|
|
x = F.dropout(F.relu(conv(x)), 0.5, self.training) |
|
|
|
|
|
|
|
x = x.transpose(1, 2) |
|
|
|
|
|
|
@ -215,13 +213,15 @@ class Decoder(nn.Module): |
|
|
|
self.prenet_dim = hparams.prenet_dim |
|
|
|
self.max_decoder_steps = hparams.max_decoder_steps |
|
|
|
self.gate_threshold = hparams.gate_threshold |
|
|
|
self.p_attention_dropout = hparams.p_attention_dropout |
|
|
|
self.p_decoder_dropout = hparams.p_decoder_dropout |
|
|
|
|
|
|
|
self.prenet = Prenet( |
|
|
|
hparams.n_mel_channels * hparams.n_frames_per_step, |
|
|
|
[hparams.prenet_dim, hparams.prenet_dim]) |
|
|
|
|
|
|
|
self.attention_rnn = nn.LSTMCell( |
|
|
|
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, |
|
|
|
hparams.prenet_dim + hparams.encoder_embedding_dim, |
|
|
|
hparams.attention_rnn_dim) |
|
|
|
|
|
|
|
self.attention_layer = Attention( |
|
|
@ -230,12 +230,12 @@ class Decoder(nn.Module): |
|
|
|
hparams.attention_location_kernel_size) |
|
|
|
|
|
|
|
self.decoder_rnn = nn.LSTMCell( |
|
|
|
hparams.prenet_dim + hparams.encoder_embedding_dim, |
|
|
|
hparams.attention_rnn_dim + hparams.encoder_embedding_dim, |
|
|
|
hparams.decoder_rnn_dim, 1) |
|
|
|
|
|
|
|
self.linear_projection = LinearNorm( |
|
|
|
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, |
|
|
|
hparams.n_mel_channels*hparams.n_frames_per_step) |
|
|
|
hparams.n_mel_channels * hparams.n_frames_per_step) |
|
|
|
|
|
|
|
self.gate_layer = LinearNorm( |
|
|
|
hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1, |
|
|
@ -350,10 +350,13 @@ class Decoder(nn.Module): |
|
|
|
gate_output: gate output energies |
|
|
|
attention_weights: |
|
|
|
""" |
|
|
|
|
|
|
|
cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1) |
|
|
|
cell_input = torch.cat((decoder_input, self.attention_context), -1) |
|
|
|
self.attention_hidden, self.attention_cell = self.attention_rnn( |
|
|
|
cell_input, (self.attention_hidden, self.attention_cell)) |
|
|
|
self.attention_hidden = F.dropout( |
|
|
|
self.attention_hidden, self.p_attention_dropout, self.training) |
|
|
|
self.attention_cell = F.dropout( |
|
|
|
self.attention_cell, self.p_attention_dropout, self.training) |
|
|
|
|
|
|
|
attention_weights_cat = torch.cat( |
|
|
|
(self.attention_weights.unsqueeze(1), |
|
|
@ -363,10 +366,14 @@ class Decoder(nn.Module): |
|
|
|
attention_weights_cat, self.mask) |
|
|
|
|
|
|
|
self.attention_weights_cum += self.attention_weights |
|
|
|
prenet_output = self.prenet(decoder_input) |
|
|
|
decoder_input = torch.cat((prenet_output, self.attention_context), -1) |
|
|
|
decoder_input = torch.cat( |
|
|
|
(self.attention_hidden, self.attention_context), -1) |
|
|
|
self.decoder_hidden, self.decoder_cell = self.decoder_rnn( |
|
|
|
decoder_input, (self.decoder_hidden, self.decoder_cell)) |
|
|
|
self.decoder_hidden = F.dropout( |
|
|
|
self.decoder_hidden, self.p_decoder_dropout, self.training) |
|
|
|
self.decoder_cell = F.dropout( |
|
|
|
self.decoder_cell, self.p_decoder_dropout, self.training) |
|
|
|
|
|
|
|
decoder_hidden_attention_context = torch.cat( |
|
|
|
(self.decoder_hidden, self.attention_context), dim=1) |
|
|
@ -391,22 +398,23 @@ class Decoder(nn.Module): |
|
|
|
alignments: sequence of attention weights from the decoder |
|
|
|
""" |
|
|
|
|
|
|
|
decoder_input = self.get_go_frame(memory) |
|
|
|
decoder_input = self.get_go_frame(memory).unsqueeze(0) |
|
|
|
decoder_inputs = self.parse_decoder_inputs(decoder_inputs) |
|
|
|
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) |
|
|
|
decoder_inputs = self.prenet(decoder_inputs) |
|
|
|
|
|
|
|
self.initialize_decoder_states( |
|
|
|
memory, mask=~get_mask_from_lengths(memory_lengths)) |
|
|
|
|
|
|
|
mel_outputs, gate_outputs, alignments = [], [], [] |
|
|
|
|
|
|
|
while len(mel_outputs) < decoder_inputs.size(0): |
|
|
|
while len(mel_outputs) < decoder_inputs.size(0) - 1: |
|
|
|
decoder_input = decoder_inputs[len(mel_outputs)] |
|
|
|
mel_output, gate_output, attention_weights = self.decode( |
|
|
|
decoder_input) |
|
|
|
mel_outputs += [mel_output] |
|
|
|
gate_outputs += [gate_output.squeeze(1)] |
|
|
|
mel_outputs += [mel_output.squeeze(1)] |
|
|
|
gate_outputs += [gate_output.squeeze()] |
|
|
|
alignments += [attention_weights] |
|
|
|
|
|
|
|
decoder_input = decoder_inputs[len(mel_outputs) - 1] |
|
|
|
|
|
|
|
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( |
|
|
|
mel_outputs, gate_outputs, alignments) |
|
|
|
|
|
|
@ -430,13 +438,14 @@ class Decoder(nn.Module): |
|
|
|
|
|
|
|
mel_outputs, gate_outputs, alignments = [], [], [] |
|
|
|
while True: |
|
|
|
decoder_input = self.prenet(decoder_input) |
|
|
|
mel_output, gate_output, alignment = self.decode(decoder_input) |
|
|
|
|
|
|
|
mel_outputs += [mel_output] |
|
|
|
gate_outputs += [gate_output.squeeze(1)] |
|
|
|
mel_outputs += [mel_output.squeeze(1)] |
|
|
|
gate_outputs += [gate_output] |
|
|
|
alignments += [alignment] |
|
|
|
|
|
|
|
if F.sigmoid(gate_output.data) > self.gate_threshold: |
|
|
|
if torch.sigmoid(gate_output.data) > self.gate_threshold: |
|
|
|
break |
|
|
|
elif len(mel_outputs) == self.max_decoder_steps: |
|
|
|
print("Warning! Reached max decoder steps") |
|
|
@ -459,8 +468,9 @@ class Tacotron2(nn.Module): |
|
|
|
self.n_frames_per_step = hparams.n_frames_per_step |
|
|
|
self.embedding = nn.Embedding( |
|
|
|
hparams.n_symbols, hparams.symbols_embedding_dim) |
|
|
|
torch.nn.init.xavier_uniform_(self.embedding.weight.data) |
|
|
|
|
|
|
|
std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) |
|
|
|
val = sqrt(3.0) * std # uniform bounds for std |
|
|
|
self.embedding.weight.data.uniform_(-val, val) |
|
|
|
self.encoder = Encoder(hparams) |
|
|
|
self.decoder = Decoder(hparams) |
|
|
|
self.postnet = Postnet(hparams) |
|
|
@ -469,8 +479,8 @@ class Tacotron2(nn.Module): |
|
|
|
text_padded, input_lengths, mel_padded, gate_padded, \ |
|
|
|
output_lengths = batch |
|
|
|
text_padded = to_gpu(text_padded).long() |
|
|
|
max_len = int(torch.max(input_lengths.data).numpy()) |
|
|
|
input_lengths = to_gpu(input_lengths).long() |
|
|
|
max_len = torch.max(input_lengths.data).item() |
|
|
|
mel_padded = to_gpu(mel_padded).float() |
|
|
|
gate_padded = to_gpu(gate_padded).float() |
|
|
|
output_lengths = to_gpu(output_lengths).long() |
|
|
@ -485,7 +495,7 @@ class Tacotron2(nn.Module): |
|
|
|
|
|
|
|
def parse_output(self, outputs, output_lengths=None): |
|
|
|
if self.mask_padding and output_lengths is not None: |
|
|
|
mask = ~get_mask_from_lengths(output_lengths+1) # +1 <stop> token |
|
|
|
mask = ~get_mask_from_lengths(output_lengths) |
|
|
|
mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) |
|
|
|
mask = mask.permute(1, 0, 2) |
|
|
|
|
|
|
@ -494,7 +504,6 @@ class Tacotron2(nn.Module): |
|
|
|
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies |
|
|
|
|
|
|
|
outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs |
|
|
|
|
|
|
|
return outputs |
|
|
|
|
|
|
|
def forward(self, inputs): |
|
|
@ -512,14 +521,6 @@ class Tacotron2(nn.Module): |
|
|
|
mel_outputs_postnet = self.postnet(mel_outputs) |
|
|
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet |
|
|
|
|
|
|
|
# DataParallel expects equal sized inputs/outputs, hence padding |
|
|
|
if input_lengths is not None: |
|
|
|
alignments = alignments.unsqueeze(0) |
|
|
|
alignments = nn.functional.pad( |
|
|
|
alignments, |
|
|
|
(0, max_len - alignments.size(3), 0, 0), |
|
|
|
"constant", 0) |
|
|
|
alignments = alignments.squeeze() |
|
|
|
return self.parse_output( |
|
|
|
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments], |
|
|
|
output_lengths) |
|
|
|