
Merge pull request #23 from NVIDIA/attention_full_mel

model.py: attending to full mel instead of prenet and dropout mel
Rafael Valle authored 6 years ago; committed by GitHub
Commit 064629c9bc
1 changed file with 5 additions and 6 deletions

model.py (+5, -6)

@@ -221,7 +221,7 @@ class Decoder(nn.Module):
             [hparams.prenet_dim, hparams.prenet_dim])
         self.attention_rnn = nn.LSTMCell(
-            hparams.prenet_dim + hparams.encoder_embedding_dim,
+            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
             hparams.attention_rnn_dim)
         self.attention_layer = Attention(
@@ -230,7 +230,7 @@ class Decoder(nn.Module):
             hparams.attention_location_kernel_size)
         self.decoder_rnn = nn.LSTMCell(
-            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
+            hparams.prenet_dim + hparams.encoder_embedding_dim,
             hparams.decoder_rnn_dim, 1)
         self.linear_projection = LinearNorm(
@@ -351,8 +351,7 @@ class Decoder(nn.Module):
             attention_weights:
         """
-        decoder_input = self.prenet(decoder_input)
-        cell_input = torch.cat((decoder_input, self.attention_context), -1)
+        cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1)
         self.attention_hidden, self.attention_cell = self.attention_rnn(
             cell_input, (self.attention_hidden, self.attention_cell))
@@ -364,8 +363,8 @@ class Decoder(nn.Module):
             attention_weights_cat, self.mask)
         self.attention_weights_cum += self.attention_weights
-        decoder_input = torch.cat(
-            (self.attention_hidden, self.attention_context), -1)
+        prenet_output = self.prenet(decoder_input)
+        decoder_input = torch.cat((prenet_output, self.attention_context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             decoder_input, (self.decoder_hidden, self.decoder_cell))
