From 1480f82908fea60a69a656850db8b6f7e32ed4b8 Mon Sep 17 00:00:00 2001
From: rafaelvalle
Date: Wed, 3 Apr 2019 13:36:35 -0800
Subject: [PATCH] model.py: renaming variables, removing dropout from lstm cell
 state, removing conversions now handled by amp

---
 model.py | 22 +++++-----------------
 1 file changed, 5 insertions(+), 17 deletions(-)

diff --git a/model.py b/model.py
index 6673b7c..4c7d7d2 100644
--- a/model.py
+++ b/model.py
@@ -5,7 +5,6 @@ from torch import nn
 from torch.nn import functional as F
 from layers import ConvNorm, LinearNorm
 from utils import to_gpu, get_mask_from_lengths
-from fp16_optimizer import fp32_to_fp16, fp16_to_fp32
 
 
 class LocationLayer(nn.Module):
@@ -355,8 +354,6 @@ class Decoder(nn.Module):
             cell_input, (self.attention_hidden, self.attention_cell))
         self.attention_hidden = F.dropout(
             self.attention_hidden, self.p_attention_dropout, self.training)
-        self.attention_cell = F.dropout(
-            self.attention_cell, self.p_attention_dropout, self.training)
 
         attention_weights_cat = torch.cat(
             (self.attention_weights.unsqueeze(1),
@@ -372,8 +369,6 @@ class Decoder(nn.Module):
             decoder_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(
             self.decoder_hidden, self.p_decoder_dropout, self.training)
-        self.decoder_cell = F.dropout(
-            self.decoder_cell, self.p_decoder_dropout, self.training)
 
         decoder_hidden_attention_context = torch.cat(
             (self.decoder_hidden, self.attention_context), dim=1)
@@ -489,10 +484,6 @@ class Tacotron2(nn.Module):
             (text_padded, input_lengths, mel_padded, max_len, output_lengths),
             (mel_padded, gate_padded))
 
-    def parse_input(self, inputs):
-        inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
-        return inputs
-
     def parse_output(self, outputs, output_lengths=None):
         if self.mask_padding and output_lengths is not None:
             mask = ~get_mask_from_lengths(output_lengths)
@@ -503,20 +494,18 @@ class Tacotron2(nn.Module):
             outputs[1].data.masked_fill_(mask, 0.0)
             outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies
 
-        outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
         return outputs
 
     def forward(self, inputs):
-        inputs, input_lengths, targets, max_len, \
-            output_lengths = self.parse_input(inputs)
-        input_lengths, output_lengths = input_lengths.data, output_lengths.data
+        text_inputs, text_lengths, mels, max_len, output_lengths = inputs
+        text_lengths, output_lengths = text_lengths.data, output_lengths.data
 
-        embedded_inputs = self.embedding(inputs).transpose(1, 2)
+        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
 
-        encoder_outputs = self.encoder(embedded_inputs, input_lengths)
+        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
 
         mel_outputs, gate_outputs, alignments = self.decoder(
-            encoder_outputs, targets, memory_lengths=input_lengths)
+            encoder_outputs, mels, memory_lengths=text_lengths)
 
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@@ -526,7 +515,6 @@ class Tacotron2(nn.Module):
             output_lengths)
 
     def inference(self, inputs):
-        inputs = self.parse_input(inputs)
         embedded_inputs = self.embedding(inputs).transpose(1, 2)
         encoder_outputs = self.encoder.inference(embedded_inputs)
         mel_outputs, gate_outputs, alignments = self.decoder.inference(
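
Note (illustrative, not part of the patch): with parse_input and the fp32_to_fp16/fp16_to_fp32 calls removed, mixed-precision casting is expected to happen in the training script rather than inside model.py. A minimal sketch of how that could look with NVIDIA's apex amp (the "amp" the subject line refers to) is below; the model/optimizer/criterion construction, the data loader, and opt_level "O1" are assumptions for illustration, not code from this repository.

# Sketch only -- assumes apex is installed and that the training script
# builds the Tacotron2 model, optimizer, loss, and data loader as usual.
import torch
from apex import amp

model = Tacotron2(hparams).cuda()                      # hypothetical construction
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# amp.initialize patches model and optimizer so the forward pass runs in fp16
# where safe and fp32 where needed; no manual fp32_to_fp16/fp16_to_fp32
# conversions are required inside Tacotron2.forward anymore.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for batch in train_loader:                             # hypothetical loader
    model.zero_grad()
    x, y = model.parse_batch(batch)                    # parse_batch is kept by the patch
    y_pred = model(x)
    loss = criterion(y_pred, y)                        # hypothetical loss module
    # Loss scaling also lives in amp, replacing the old fp16_optimizer module.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()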