From 4af4ccb135811fe20c38fa4fa2cd51bb48e33597 Mon Sep 17 00:00:00 2001 From: rafaelvalle Date: Sun, 25 Nov 2018 22:33:38 -0800 Subject: [PATCH] model.py: rewrite --- model.py | 77 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/model.py b/model.py index 263faa6..6673b7c 100644 --- a/model.py +++ b/model.py @@ -1,3 +1,4 @@ +from math import sqrt import torch from torch.autograd import Variable from torch import nn @@ -56,7 +57,7 @@ class Attention(nn.Module): processed_query = self.query_layer(query.unsqueeze(1)) processed_attention_weights = self.location_layer(attention_weights_cat) - energies = self.v(F.tanh( + energies = self.v(torch.tanh( processed_query + processed_attention_weights + processed_memory)) energies = energies.squeeze(-1) @@ -107,7 +108,6 @@ class Postnet(nn.Module): def __init__(self, hparams): super(Postnet, self).__init__() - self.dropout = nn.Dropout(0.5) self.convolutions = nn.ModuleList() self.convolutions.append( @@ -141,9 +141,8 @@ class Postnet(nn.Module): def forward(self, x): for i in range(len(self.convolutions) - 1): - x = self.dropout(F.tanh(self.convolutions[i](x))) - - x = self.dropout(self.convolutions[-1](x)) + x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) + x = F.dropout(self.convolutions[-1](x), 0.5, self.training) return x @@ -155,7 +154,6 @@ class Encoder(nn.Module): """ def __init__(self, hparams): super(Encoder, self).__init__() - self.dropout = nn.Dropout(0.5) convolutions = [] for _ in range(hparams.encoder_n_convolutions): @@ -175,7 +173,7 @@ class Encoder(nn.Module): def forward(self, x, input_lengths): for conv in self.convolutions: - x = self.dropout(F.relu(conv(x))) + x = F.dropout(F.relu(conv(x)), 0.5, self.training) x = x.transpose(1, 2) @@ -194,7 +192,7 @@ class Encoder(nn.Module): def inference(self, x): for conv in self.convolutions: - x = self.dropout(F.relu(conv(x))) + x = F.dropout(F.relu(conv(x)), 0.5, self.training) x = x.transpose(1, 2) @@ -215,13 +213,15 @@ class Decoder(nn.Module): self.prenet_dim = hparams.prenet_dim self.max_decoder_steps = hparams.max_decoder_steps self.gate_threshold = hparams.gate_threshold + self.p_attention_dropout = hparams.p_attention_dropout + self.p_decoder_dropout = hparams.p_decoder_dropout self.prenet = Prenet( hparams.n_mel_channels * hparams.n_frames_per_step, [hparams.prenet_dim, hparams.prenet_dim]) self.attention_rnn = nn.LSTMCell( - hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, + hparams.prenet_dim + hparams.encoder_embedding_dim, hparams.attention_rnn_dim) self.attention_layer = Attention( @@ -230,12 +230,12 @@ class Decoder(nn.Module): hparams.attention_location_kernel_size) self.decoder_rnn = nn.LSTMCell( - hparams.prenet_dim + hparams.encoder_embedding_dim, + hparams.attention_rnn_dim + hparams.encoder_embedding_dim, hparams.decoder_rnn_dim, 1) self.linear_projection = LinearNorm( hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, - hparams.n_mel_channels*hparams.n_frames_per_step) + hparams.n_mel_channels * hparams.n_frames_per_step) self.gate_layer = LinearNorm( hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1, @@ -350,10 +350,13 @@ class Decoder(nn.Module): gate_output: gate output energies attention_weights: """ - - cell_input = torch.cat((self.decoder_hidden, self.attention_context), -1) + cell_input = torch.cat((decoder_input, self.attention_context), -1) self.attention_hidden, self.attention_cell = self.attention_rnn( cell_input, (self.attention_hidden, self.attention_cell)) + self.attention_hidden = F.dropout( + self.attention_hidden, self.p_attention_dropout, self.training) + self.attention_cell = F.dropout( + self.attention_cell, self.p_attention_dropout, self.training) attention_weights_cat = torch.cat( (self.attention_weights.unsqueeze(1), @@ -363,10 +366,14 @@ class Decoder(nn.Module): attention_weights_cat, self.mask) self.attention_weights_cum += self.attention_weights - prenet_output = self.prenet(decoder_input) - decoder_input = torch.cat((prenet_output, self.attention_context), -1) + decoder_input = torch.cat( + (self.attention_hidden, self.attention_context), -1) self.decoder_hidden, self.decoder_cell = self.decoder_rnn( decoder_input, (self.decoder_hidden, self.decoder_cell)) + self.decoder_hidden = F.dropout( + self.decoder_hidden, self.p_decoder_dropout, self.training) + self.decoder_cell = F.dropout( + self.decoder_cell, self.p_decoder_dropout, self.training) decoder_hidden_attention_context = torch.cat( (self.decoder_hidden, self.attention_context), dim=1) @@ -391,22 +398,23 @@ class Decoder(nn.Module): alignments: sequence of attention weights from the decoder """ - decoder_input = self.get_go_frame(memory) + decoder_input = self.get_go_frame(memory).unsqueeze(0) decoder_inputs = self.parse_decoder_inputs(decoder_inputs) + decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) + decoder_inputs = self.prenet(decoder_inputs) + self.initialize_decoder_states( memory, mask=~get_mask_from_lengths(memory_lengths)) mel_outputs, gate_outputs, alignments = [], [], [] - - while len(mel_outputs) < decoder_inputs.size(0): + while len(mel_outputs) < decoder_inputs.size(0) - 1: + decoder_input = decoder_inputs[len(mel_outputs)] mel_output, gate_output, attention_weights = self.decode( decoder_input) - mel_outputs += [mel_output] - gate_outputs += [gate_output.squeeze(1)] + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output.squeeze()] alignments += [attention_weights] - decoder_input = decoder_inputs[len(mel_outputs) - 1] - mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( mel_outputs, gate_outputs, alignments) @@ -430,13 +438,14 @@ class Decoder(nn.Module): mel_outputs, gate_outputs, alignments = [], [], [] while True: + decoder_input = self.prenet(decoder_input) mel_output, gate_output, alignment = self.decode(decoder_input) - mel_outputs += [mel_output] - gate_outputs += [gate_output.squeeze(1)] + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output] alignments += [alignment] - if F.sigmoid(gate_output.data) > self.gate_threshold: + if torch.sigmoid(gate_output.data) > self.gate_threshold: break elif len(mel_outputs) == self.max_decoder_steps: print("Warning! Reached max decoder steps") @@ -459,8 +468,9 @@ class Tacotron2(nn.Module): self.n_frames_per_step = hparams.n_frames_per_step self.embedding = nn.Embedding( hparams.n_symbols, hparams.symbols_embedding_dim) - torch.nn.init.xavier_uniform_(self.embedding.weight.data) - + std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) + val = sqrt(3.0) * std # uniform bounds for std + self.embedding.weight.data.uniform_(-val, val) self.encoder = Encoder(hparams) self.decoder = Decoder(hparams) self.postnet = Postnet(hparams) @@ -469,8 +479,8 @@ class Tacotron2(nn.Module): text_padded, input_lengths, mel_padded, gate_padded, \ output_lengths = batch text_padded = to_gpu(text_padded).long() - max_len = int(torch.max(input_lengths.data).numpy()) input_lengths = to_gpu(input_lengths).long() + max_len = torch.max(input_lengths.data).item() mel_padded = to_gpu(mel_padded).float() gate_padded = to_gpu(gate_padded).float() output_lengths = to_gpu(output_lengths).long() @@ -485,7 +495,7 @@ class Tacotron2(nn.Module): def parse_output(self, outputs, output_lengths=None): if self.mask_padding and output_lengths is not None: - mask = ~get_mask_from_lengths(output_lengths+1) # +1 token + mask = ~get_mask_from_lengths(output_lengths) mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) mask = mask.permute(1, 0, 2) @@ -494,7 +504,6 @@ class Tacotron2(nn.Module): outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs - return outputs def forward(self, inputs): @@ -512,14 +521,6 @@ class Tacotron2(nn.Module): mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet - # DataParallel expects equal sized inputs/outputs, hence padding - if input_lengths is not None: - alignments = alignments.unsqueeze(0) - alignments = nn.functional.pad( - alignments, - (0, max_len - alignments.size(3), 0, 0), - "constant", 0) - alignments = alignments.squeeze() return self.parse_output( [mel_outputs, mel_outputs_postnet, gate_outputs, alignments], output_lengths)