From 977cb37cea06f2ca0d0c747d4a899b830ed83887 Mon Sep 17 00:00:00 2001 From: Rafael Valle Date: Fri, 18 May 2018 06:59:09 -0700 Subject: [PATCH 1/2] model.py: attending to full mel instead of prenet and dropout mel --- model.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/model.py b/model.py index 73594b1..7bd8170 100644 --- a/model.py +++ b/model.py @@ -221,7 +221,7 @@ class Decoder(nn.Module): [hparams.prenet_dim, hparams.prenet_dim]) self.attention_rnn = nn.LSTMCell( - hparams.prenet_dim + hparams.encoder_embedding_dim, + hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, hparams.attention_rnn_dim) self.attention_layer = Attention( @@ -230,7 +230,7 @@ class Decoder(nn.Module): hparams.attention_location_kernel_size) self.decoder_rnn = nn.LSTMCell( - hparams.attention_rnn_dim + hparams.encoder_embedding_dim, + hparams.prenet_dim + hparams.encoder_embedding_dim, hparams.decoder_rnn_dim, 1) self.linear_projection = LinearNorm( @@ -351,8 +351,8 @@ class Decoder(nn.Module): attention_weights: """ - decoder_input = self.prenet(decoder_input) - cell_input = torch.cat((decoder_input, self.attention_context), -1) + prenet_output = self.prenet(decoder_input) + cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1) self.attention_hidden, self.attention_cell = self.attention_rnn( cell_input, (self.attention_hidden, self.attention_cell)) @@ -364,8 +364,7 @@ class Decoder(nn.Module): attention_weights_cat, self.mask) self.attention_weights_cum += self.attention_weights - decoder_input = torch.cat( - (self.attention_hidden, self.attention_context), -1) + decoder_input = torch.cat((prenet_output, self.attention_context), -1) self.decoder_hidden, self.decoder_cell = self.decoder_rnn( decoder_input, (self.decoder_hidden, self.decoder_cell)) From d5b64729d14b8063634a856fd1abc53b83085de5 Mon Sep 17 00:00:00 2001 From: Rafael Valle Date: Sun, 20 May 2018 12:22:06 -0700 Subject: [PATCH 2/2] model.py: moving for better readibility --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 7bd8170..8ea9a2c 100644 --- a/model.py +++ b/model.py @@ -351,7 +351,6 @@ class Decoder(nn.Module): attention_weights: """ - prenet_output = self.prenet(decoder_input) cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1) self.attention_hidden, self.attention_cell = self.attention_rnn( cell_input, (self.attention_hidden, self.attention_cell)) @@ -364,6 +363,7 @@ class Decoder(nn.Module): attention_weights_cat, self.mask) self.attention_weights_cum += self.attention_weights + prenet_output = self.prenet(decoder_input) decoder_input = torch.cat((prenet_output, self.attention_context), -1) self.decoder_hidden, self.decoder_cell = self.decoder_rnn( decoder_input, (self.decoder_hidden, self.decoder_cell))