Browse Source

model.py: renaming variables, removing dropout from lstm cell state, removing conversions now handled by amp

master
rafaelvalle 5 years ago
parent
commit
1480f82908
1 changed files with 5 additions and 17 deletions
  1. +5
    -17
      model.py

+ 5
- 17
model.py View File

@ -5,7 +5,6 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from layers import ConvNorm, LinearNorm from layers import ConvNorm, LinearNorm
from utils import to_gpu, get_mask_from_lengths from utils import to_gpu, get_mask_from_lengths
from fp16_optimizer import fp32_to_fp16, fp16_to_fp32
class LocationLayer(nn.Module): class LocationLayer(nn.Module):
@ -355,8 +354,6 @@ class Decoder(nn.Module):
cell_input, (self.attention_hidden, self.attention_cell)) cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout( self.attention_hidden = F.dropout(
self.attention_hidden, self.p_attention_dropout, self.training) self.attention_hidden, self.p_attention_dropout, self.training)
self.attention_cell = F.dropout(
self.attention_cell, self.p_attention_dropout, self.training)
attention_weights_cat = torch.cat( attention_weights_cat = torch.cat(
(self.attention_weights.unsqueeze(1), (self.attention_weights.unsqueeze(1),
@ -372,8 +369,6 @@ class Decoder(nn.Module):
decoder_input, (self.decoder_hidden, self.decoder_cell)) decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout( self.decoder_hidden = F.dropout(
self.decoder_hidden, self.p_decoder_dropout, self.training) self.decoder_hidden, self.p_decoder_dropout, self.training)
self.decoder_cell = F.dropout(
self.decoder_cell, self.p_decoder_dropout, self.training)
decoder_hidden_attention_context = torch.cat( decoder_hidden_attention_context = torch.cat(
(self.decoder_hidden, self.attention_context), dim=1) (self.decoder_hidden, self.attention_context), dim=1)
@ -489,10 +484,6 @@ class Tacotron2(nn.Module):
(text_padded, input_lengths, mel_padded, max_len, output_lengths), (text_padded, input_lengths, mel_padded, max_len, output_lengths),
(mel_padded, gate_padded)) (mel_padded, gate_padded))
def parse_input(self, inputs):
inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
return inputs
def parse_output(self, outputs, output_lengths=None): def parse_output(self, outputs, output_lengths=None):
if self.mask_padding and output_lengths is not None: if self.mask_padding and output_lengths is not None:
mask = ~get_mask_from_lengths(output_lengths) mask = ~get_mask_from_lengths(output_lengths)
@ -503,20 +494,18 @@ class Tacotron2(nn.Module):
outputs[1].data.masked_fill_(mask, 0.0) outputs[1].data.masked_fill_(mask, 0.0)
outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
return outputs return outputs
def forward(self, inputs): def forward(self, inputs):
inputs, input_lengths, targets, max_len, \
output_lengths = self.parse_input(inputs)
input_lengths, output_lengths = input_lengths.data, output_lengths.data
text_inputs, text_lengths, mels, max_len, output_lengths = inputs
text_lengths, output_lengths = text_lengths.data, output_lengths.data
embedded_inputs = self.embedding(inputs).transpose(1, 2)
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, input_lengths)
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
mel_outputs, gate_outputs, alignments = self.decoder( mel_outputs, gate_outputs, alignments = self.decoder(
encoder_outputs, targets, memory_lengths=input_lengths)
encoder_outputs, mels, memory_lengths=text_lengths)
mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@ -526,7 +515,6 @@ class Tacotron2(nn.Module):
output_lengths) output_lengths)
def inference(self, inputs): def inference(self, inputs):
inputs = self.parse_input(inputs)
embedded_inputs = self.embedding(inputs).transpose(1, 2) embedded_inputs = self.embedding(inputs).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs) encoder_outputs = self.encoder.inference(embedded_inputs)
mel_outputs, gate_outputs, alignments = self.decoder.inference( mel_outputs, gate_outputs, alignments = self.decoder.inference(

Loading…
Cancel
Save