diff --git a/model.py b/model.py index 8ea9a2c..263faa6 100644 --- a/model.py +++ b/model.py @@ -459,6 +459,8 @@ class Tacotron2(nn.Module): self.n_frames_per_step = hparams.n_frames_per_step self.embedding = nn.Embedding( hparams.n_symbols, hparams.symbols_embedding_dim) + torch.nn.init.xavier_uniform_(self.embedding.weight.data) + self.encoder = Encoder(hparams) self.decoder = Decoder(hparams) self.postnet = Postnet(hparams)