diff --git a/model.py b/model.py index 7bd8170..8ea9a2c 100644 --- a/model.py +++ b/model.py @@ -351,7 +351,6 @@ class Decoder(nn.Module): attention_weights: """ - prenet_output = self.prenet(decoder_input) cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1) self.attention_hidden, self.attention_cell = self.attention_rnn( cell_input, (self.attention_hidden, self.attention_cell)) @@ -364,6 +363,7 @@ class Decoder(nn.Module): attention_weights_cat, self.mask) self.attention_weights_cum += self.attention_weights + prenet_output = self.prenet(decoder_input) decoder_input = torch.cat((prenet_output, self.attention_context), -1) self.decoder_hidden, self.decoder_cell = self.decoder_rnn( decoder_input, (self.decoder_hidden, self.decoder_cell))