From dcd925f6c8047c5f128bfd5b28b3a7ce4b2d6de8 Mon Sep 17 00:00:00 2001 From: Rafael Valle Date: Sun, 6 May 2018 08:58:01 -0700 Subject: [PATCH] model.py: mixed squeeze target. fixing --- model.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/model.py b/model.py index d416945..73594b1 100644 --- a/model.py +++ b/model.py @@ -402,9 +402,8 @@ class Decoder(nn.Module): while len(mel_outputs) < decoder_inputs.size(0): mel_output, gate_output, attention_weights = self.decode( decoder_input) - - mel_outputs += [mel_output.squeeze(1)] - gate_outputs += [gate_output.squeeze()] + mel_outputs += [mel_output] + gate_outputs += [gate_output.squeeze(1)] alignments += [attention_weights] decoder_input = decoder_inputs[len(mel_outputs) - 1] @@ -431,12 +430,11 @@ class Decoder(nn.Module): self.initialize_decoder_states(memory, mask=None) mel_outputs, gate_outputs, alignments = [], [], [] - while True: mel_output, gate_output, alignment = self.decode(decoder_input) - mel_outputs += [mel_output.squeeze(1)] - gate_outputs += [gate_output.squeeze()] + mel_outputs += [mel_output] + gate_outputs += [gate_output.squeeze(1)] alignments += [alignment] if F.sigmoid(gate_output.data) > self.gate_threshold: