@@ -221,7 +221,7 @@ class Decoder(nn.Module):
             [hparams.prenet_dim, hparams.prenet_dim])
 
         self.attention_rnn = nn.LSTMCell(
-            hparams.prenet_dim + hparams.encoder_embedding_dim,
+            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
             hparams.attention_rnn_dim)
 
         self.attention_layer = Attention(
@@ -230,7 +230,7 @@ class Decoder(nn.Module):
             hparams.attention_location_kernel_size)
 
         self.decoder_rnn = nn.LSTMCell(
-            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
+            hparams.prenet_dim + hparams.encoder_embedding_dim,
             hparams.decoder_rnn_dim, 1)
 
         self.linear_projection = LinearNorm(
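Review note: the two constructor hunks above mirror the decode-step rewiring below. The attention RNN's input size becomes decoder_rnn_dim + encoder_embedding_dim because it will now be fed the previous decoder hidden state concatenated with the attention context, and the decoder RNN's input size becomes prenet_dim + encoder_embedding_dim because it will now consume the prenet output directly.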
@@ -351,8 +351,8 @@ class Decoder(nn.Module):
         attention_weights:
         """
 
-        decoder_input = self.prenet(decoder_input)
-        cell_input = torch.cat((decoder_input, self.attention_context), -1)
+        prenet_output = self.prenet(decoder_input)
+        cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
         self.attention_hidden, self.attention_cell = self.attention_rnn(
             cell_input, (self.attention_hidden, self.attention_cell))
 
@@ -364,8 +364,7 @@ class Decoder(nn.Module):
             attention_weights_cat, self.mask)
 
         self.attention_weights_cum += self.attention_weights
-        decoder_input = torch.cat(
-            (self.attention_hidden, self.attention_context), -1)
+        decoder_input = torch.cat((prenet_output, self.attention_context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             decoder_input, (self.decoder_hidden, self.decoder_cell))
 
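Taken together, the four hunks reroute the prenet output away from the attention RNN and into the decoder RNN, while the attention RNN is instead conditioned on the previous decoder hidden state. Below is a minimal sketch of the decode step as it reads after the patch, for orientation only: the attention_layer call signature, the memory/processed_memory attributes, and the cumulative-weights bookkeeping are assumed to match NVIDIA's reference Tacotron 2 model.py, since they are not part of this diff.

    # Sketch of the post-patch decode step; not the verbatim patched file.
    def decode(self, decoder_input):
        # Prenet output is kept; it now feeds the decoder RNN further down
        # instead of the attention RNN.
        prenet_output = self.prenet(decoder_input)

        # Attention RNN input: previous decoder hidden state + last context,
        # hence decoder_rnn_dim + encoder_embedding_dim in the first hunk.
        cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))

        # Attention over encoder outputs (assumed unchanged from the reference).
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)
        self.attention_weights_cum += self.attention_weights

        # Decoder RNN input: prenet output + fresh attention context,
        # hence prenet_dim + encoder_embedding_dim in the second hunk.
        decoder_input = torch.cat((prenet_output, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))

One behavioural consequence worth flagging: if the decoder states are zero-initialized as in the reference initialize_decoder_states, the attention RNN sees no prenet information at all on the first frame, since self.decoder_hidden is still all zeros at that point.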