From 1ec0e5e8cd1152f3a551027201a26de0389e1235 Mon Sep 17 00:00:00 2001 From: rafaelvalle Date: Sun, 25 Nov 2018 22:33:32 -0800 Subject: [PATCH] layers.py: rewrite --- layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/layers.py b/layers.py index f4935d5..615a64a 100644 --- a/layers.py +++ b/layers.py @@ -10,7 +10,7 @@ class LinearNorm(torch.nn.Module): super(LinearNorm, self).__init__() self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) - torch.nn.init.xavier_uniform( + torch.nn.init.xavier_uniform_( self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) @@ -31,7 +31,7 @@ class ConvNorm(torch.nn.Module): padding=padding, dilation=dilation, bias=bias) - torch.nn.init.xavier_uniform( + torch.nn.init.xavier_uniform_( self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) def forward(self, signal): @@ -42,7 +42,7 @@ class ConvNorm(torch.nn.Module): class TacotronSTFT(torch.nn.Module): def __init__(self, filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, - mel_fmax=None): + mel_fmax=8000.0): super(TacotronSTFT, self).__init__() self.n_mel_channels = n_mel_channels self.sampling_rate = sampling_rate