Browse Source

model.py: mixed squeeze target. fixing

master
Rafael Valle 6 years ago
parent
commit
dcd925f6c8
1 changed files with 4 additions and 6 deletions
  1. +4
    -6
      model.py

+ 4
- 6
model.py View File

@ -402,9 +402,8 @@ class Decoder(nn.Module):
while len(mel_outputs) < decoder_inputs.size(0): while len(mel_outputs) < decoder_inputs.size(0):
mel_output, gate_output, attention_weights = self.decode( mel_output, gate_output, attention_weights = self.decode(
decoder_input) decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze()]
mel_outputs += [mel_output]
gate_outputs += [gate_output.squeeze(1)]
alignments += [attention_weights] alignments += [attention_weights]
decoder_input = decoder_inputs[len(mel_outputs) - 1] decoder_input = decoder_inputs[len(mel_outputs) - 1]
@ -431,12 +430,11 @@ class Decoder(nn.Module):
self.initialize_decoder_states(memory, mask=None) self.initialize_decoder_states(memory, mask=None)
mel_outputs, gate_outputs, alignments = [], [], [] mel_outputs, gate_outputs, alignments = [], [], []
while True: while True:
mel_output, gate_output, alignment = self.decode(decoder_input) mel_output, gate_output, alignment = self.decode(decoder_input)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze()]
mel_outputs += [mel_output]
gate_outputs += [gate_output.squeeze(1)]
alignments += [alignment] alignments += [alignment]
if F.sigmoid(gate_output.data) > self.gate_threshold: if F.sigmoid(gate_output.data) > self.gate_threshold:

Loading…
Cancel
Save