@ -1,3 +1,4 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
@ -56,7 +57,7 @@ class Attention(nn.Module):
processed_query = self . query_layer ( query . unsqueeze ( 1 ) )
processed_attention_weights = self . location_layer ( attention_weights_cat )
energies = self . v ( F . tanh (
energies = self . v ( torch . tanh (
processed_query + processed_attention_weights + processed_memory ) )
energies = energies . squeeze ( - 1 )
@ -107,7 +108,6 @@ class Postnet(nn.Module):
def __init__ ( self , hparams ) :
super ( Postnet , self ) . __init__ ( )
self . dropout = nn . Dropout ( 0.5 )
self . convolutions = nn . ModuleList ( )
self . convolutions . append (
@ -141,9 +141,8 @@ class Postnet(nn.Module):
def forward ( self , x ) :
for i in range ( len ( self . convolutions ) - 1 ) :
x = self . dropout ( F . tanh ( self . convolutions [ i ] ( x ) ) )
x = self . dropout ( self . convolutions [ - 1 ] ( x ) )
x = F . dropout ( torch . tanh ( self . convolutions [ i ] ( x ) ) , 0.5 , self . training )
x = F . dropout ( self . convolutions [ - 1 ] ( x ) , 0.5 , self . training )
return x
@ -155,7 +154,6 @@ class Encoder(nn.Module):
"""
def __init__ ( self , hparams ) :
super ( Encoder , self ) . __init__ ( )
self . dropout = nn . Dropout ( 0.5 )
convolutions = [ ]
for _ in range ( hparams . encoder_n_convolutions ) :
@ -175,7 +173,7 @@ class Encoder(nn.Module):
def forward ( self , x , input_lengths ) :
for conv in self . convolutions :
x = self . dropout ( F . relu ( conv ( x ) ) )
x = F . dropout ( F . relu ( conv ( x ) ) , 0.5 , self . training )
x = x . transpose ( 1 , 2 )
@ -194,7 +192,7 @@ class Encoder(nn.Module):
def inference ( self , x ) :
for conv in self . convolutions :
x = self . dropout ( F . relu ( conv ( x ) ) )
x = F . dropout ( F . relu ( conv ( x ) ) , 0.5 , self . training )
x = x . transpose ( 1 , 2 )
@ -215,13 +213,15 @@ class Decoder(nn.Module):
self . prenet_dim = hparams . prenet_dim
self . max_decoder_steps = hparams . max_decoder_steps
self . gate_threshold = hparams . gate_threshold
self . p_attention_dropout = hparams . p_attention_dropout
self . p_decoder_dropout = hparams . p_decoder_dropout
self . prenet = Prenet (
hparams . n_mel_channels * hparams . n_frames_per_step ,
[ hparams . prenet_dim , hparams . prenet_dim ] )
self . attention_rnn = nn . LSTMCell (
hparams . decoder_rnn _dim + hparams . encoder_embedding_dim ,
hparams . prenet _dim + hparams . encoder_embedding_dim ,
hparams . attention_rnn_dim )
self . attention_layer = Attention (
@ -230,12 +230,12 @@ class Decoder(nn.Module):
hparams . attention_location_kernel_size )
self . decoder_rnn = nn . LSTMCell (
hparams . prenet _dim + hparams . encoder_embedding_dim ,
hparams . attention_rnn _dim + hparams . encoder_embedding_dim ,
hparams . decoder_rnn_dim , 1 )
self . linear_projection = LinearNorm (
hparams . decoder_rnn_dim + hparams . encoder_embedding_dim ,
hparams . n_mel_channels * hparams . n_frames_per_step )
hparams . n_mel_channels * hparams . n_frames_per_step )
self . gate_layer = LinearNorm (
hparams . decoder_rnn_dim + hparams . encoder_embedding_dim , 1 ,
@ -350,10 +350,13 @@ class Decoder(nn.Module):
gate_output : gate output energies
attention_weights :
"""
cell_input = torch . cat ( ( self . decoder_hidden , self . attention_context ) , - 1 )
cell_input = torch . cat ( ( decoder_input , self . attention_context ) , - 1 )
self . attention_hidden , self . attention_cell = self . attention_rnn (
cell_input , ( self . attention_hidden , self . attention_cell ) )
self . attention_hidden = F . dropout (
self . attention_hidden , self . p_attention_dropout , self . training )
self . attention_cell = F . dropout (
self . attention_cell , self . p_attention_dropout , self . training )
attention_weights_cat = torch . cat (
( self . attention_weights . unsqueeze ( 1 ) ,
@ -363,10 +366,14 @@ class Decoder(nn.Module):
attention_weights_cat , self . mask )
self . attention_weights_cum + = self . attention_weights
prenet_output = self . prenet ( decoder_input )
decoder_input = torch . cat ( ( prenet_output , self . attention_context ) , - 1 )
decoder_input = torch . cat (
( self . attention_hidden , self . attention_context ) , - 1 )
self . decoder_hidden , self . decoder_cell = self . decoder_rnn (
decoder_input , ( self . decoder_hidden , self . decoder_cell ) )
self . decoder_hidden = F . dropout (
self . decoder_hidden , self . p_decoder_dropout , self . training )
self . decoder_cell = F . dropout (
self . decoder_cell , self . p_decoder_dropout , self . training )
decoder_hidden_attention_context = torch . cat (
( self . decoder_hidden , self . attention_context ) , dim = 1 )
@ -391,22 +398,23 @@ class Decoder(nn.Module):
alignments : sequence of attention weights from the decoder
"""
decoder_input = self . get_go_frame ( memory )
decoder_input = self . get_go_frame ( memory ) . unsqueeze ( 0 )
decoder_inputs = self . parse_decoder_inputs ( decoder_inputs )
decoder_inputs = torch . cat ( ( decoder_input , decoder_inputs ) , dim = 0 )
decoder_inputs = self . prenet ( decoder_inputs )
self . initialize_decoder_states (
memory , mask = ~ get_mask_from_lengths ( memory_lengths ) )
mel_outputs , gate_outputs , alignments = [ ] , [ ] , [ ]
while len ( mel_outputs ) < decoder_inputs . size ( 0 ) :
while len ( mel_outputs ) < decoder_inputs . size ( 0 ) - 1 :
decoder_input = decoder_inputs [ len ( mel_outputs ) ]
mel_output , gate_output , attention_weights = self . decode (
decoder_input )
mel_outputs + = [ mel_output ]
gate_outputs + = [ gate_output . squeeze ( 1 ) ]
mel_outputs + = [ mel_output . squeeze ( 1 ) ]
gate_outputs + = [ gate_output . squeeze ( ) ]
alignments + = [ attention_weights ]
decoder_input = decoder_inputs [ len ( mel_outputs ) - 1 ]
mel_outputs , gate_outputs , alignments = self . parse_decoder_outputs (
mel_outputs , gate_outputs , alignments )
@ -430,13 +438,14 @@ class Decoder(nn.Module):
mel_outputs , gate_outputs , alignments = [ ] , [ ] , [ ]
while True :
decoder_input = self . prenet ( decoder_input )
mel_output , gate_output , alignment = self . decode ( decoder_input )
mel_outputs + = [ mel_output ]
gate_outputs + = [ gate_output . squeeze ( 1 ) ]
mel_outputs + = [ mel_output . squeeze ( 1 ) ]
gate_outputs + = [ gate_output ]
alignments + = [ alignment ]
if F . sigmoid ( gate_output . data ) > self . gate_threshold :
if torch . sigmoid ( gate_output . data ) > self . gate_threshold :
break
elif len ( mel_outputs ) == self . max_decoder_steps :
print ( " Warning! Reached max decoder steps " )
@ -459,8 +468,9 @@ class Tacotron2(nn.Module):
self . n_frames_per_step = hparams . n_frames_per_step
self . embedding = nn . Embedding (
hparams . n_symbols , hparams . symbols_embedding_dim )
torch . nn . init . xavier_uniform_ ( self . embedding . weight . data )
std = sqrt ( 2.0 / ( hparams . n_symbols + hparams . symbols_embedding_dim ) )
val = sqrt ( 3.0 ) * std # uniform bounds for std
self . embedding . weight . data . uniform_ ( - val , val )
self . encoder = Encoder ( hparams )
self . decoder = Decoder ( hparams )
self . postnet = Postnet ( hparams )
@ -469,8 +479,8 @@ class Tacotron2(nn.Module):
text_padded , input_lengths , mel_padded , gate_padded , \
output_lengths = batch
text_padded = to_gpu ( text_padded ) . long ( )
max_len = int ( torch . max ( input_lengths . data ) . numpy ( ) )
input_lengths = to_gpu ( input_lengths ) . long ( )
max_len = torch . max ( input_lengths . data ) . item ( )
mel_padded = to_gpu ( mel_padded ) . float ( )
gate_padded = to_gpu ( gate_padded ) . float ( )
output_lengths = to_gpu ( output_lengths ) . long ( )
@ -485,7 +495,7 @@ class Tacotron2(nn.Module):
def parse_output ( self , outputs , output_lengths = None ) :
if self . mask_padding and output_lengths is not None :
mask = ~ get_mask_from_lengths ( output_lengths + 1 ) # +1 <stop> token
mask = ~ get_mask_from_lengths ( output_lengths )
mask = mask . expand ( self . n_mel_channels , mask . size ( 0 ) , mask . size ( 1 ) )
mask = mask . permute ( 1 , 0 , 2 )
@ -494,7 +504,6 @@ class Tacotron2(nn.Module):
outputs [ 2 ] . data . masked_fill_ ( mask [ : , 0 , : ] , 1e3 ) # gate energies
outputs = fp16_to_fp32 ( outputs ) if self . fp16_run else outputs
return outputs
def forward ( self , inputs ) :
@ -512,14 +521,6 @@ class Tacotron2(nn.Module):
mel_outputs_postnet = self . postnet ( mel_outputs )
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
# DataParallel expects equal sized inputs/outputs, hence padding
if input_lengths is not None :
alignments = alignments . unsqueeze ( 0 )
alignments = nn . functional . pad (
alignments ,
( 0 , max_len - alignments . size ( 3 ) , 0 , 0 ) ,
" constant " , 0 )
alignments = alignments . squeeze ( )
return self . parse_output (
[ mel_outputs , mel_outputs_postnet , gate_outputs , alignments ] ,
output_lengths )