diff --git a/model.py b/model.py index 7356b19..d416945 100644 --- a/model.py +++ b/model.py @@ -470,8 +470,8 @@ class Tacotron2(nn.Module): text_padded, input_lengths, mel_padded, gate_padded, \ output_lengths = batch text_padded = to_gpu(text_padded).long() + max_len = int(torch.max(input_lengths.data).numpy()) input_lengths = to_gpu(input_lengths).long() - max_len = torch.max(input_lengths.data).cpu().numpy()[0] mel_padded = to_gpu(mel_padded).float() gate_padded = to_gpu(gate_padded).float() output_lengths = to_gpu(output_lengths).long()