Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave
import torch
import torch.nn as nn
import numpy as np
import transformer.Constants as Constants
from transformer.Layers import FFTBlock, PreNet, PostNet, Linear
from text.symbols import symbols
import hparams as hp
def get_non_pad_mask(seq):
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
                               for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)
def get_attn_key_pad_mask(seq_k, seq_q):
    ''' For masking out the padding part of key sequence. '''

    # Expand to fit the shape of key query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(Constants.PAD)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk

    return padding_mask
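# Shape summary for the two mask helpers above, for a padded (B, L) batch:
#   get_non_pad_mask(seq)               -> float mask of shape (B, L, 1); 1.0 at
#                                          real tokens, 0.0 at Constants.PAD entries.
#   get_attn_key_pad_mask(seq_k, seq_q) -> bool mask of shape (B, L_q, L_k); True
#                                          wherever the key position is padding.
# Both are handed to FFTBlock below, which is expected (it is defined in
# transformer.Layers, not in this file) to use them to ignore padded positions.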
class Encoder(nn.Module):
    ''' Encoder '''

    def __init__(self,
                 n_src_vocab=len(symbols) + 1,
                 len_max_seq=hp.max_sep_len,
                 d_word_vec=hp.word_vec_dim,
                 n_layers=hp.encoder_n_layer,
                 n_head=hp.encoder_head,
                 d_k=64,
                 d_v=64,
                 d_model=hp.word_vec_dim,
                 d_inner=hp.encoder_conv1d_filter_size,
                 dropout=hp.dropout):
        super(Encoder, self).__init__()

        n_position = len_max_seq + 1

        self.src_word_emb = nn.Embedding(
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([FFTBlock(
            d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, return_attns=False):
        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        return enc_output, non_pad_mask
class Decoder(nn.Module):
    """ Decoder """

    def __init__(self,
                 len_max_seq=hp.max_sep_len,
                 d_word_vec=hp.word_vec_dim,
                 n_layers=hp.decoder_n_layer,
                 n_head=hp.decoder_head,
                 d_k=64,
                 d_v=64,
                 d_model=hp.word_vec_dim,
                 d_inner=hp.decoder_conv1d_filter_size,
                 dropout=hp.dropout):
        super(Decoder, self).__init__()

        n_position = len_max_seq + 1

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([FFTBlock(
            d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, enc_seq, enc_pos, return_attns=False):
        dec_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos)
        non_pad_mask = get_non_pad_mask(enc_pos)

        # -- Forward
        dec_output = enc_seq + self.position_enc(enc_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn = dec_layer(
                dec_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]

        return dec_output
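if __name__ == "__main__":
    # Minimal smoke-test sketch: builds a toy padded batch and runs it through
    # Encoder and Decoder to check tensor shapes. The batch size, sequence
    # length, and the way padding is simulated here are arbitrary choices, and
    # the length regulator that normally sits between the two modules (defined
    # outside this file) is skipped, so the decoder simply reuses the encoder
    # output and positions.
    batch_size, seq_len = 2, 16

    # Random token ids with the last four positions padded out.
    src_seq = torch.randint(1, len(symbols), (batch_size, seq_len))
    src_seq[:, -4:] = Constants.PAD

    # Positions are 1-based; padded positions get index 0, which maps to the
    # zeroed row of the sinusoid table (padding_idx=0).
    src_pos = torch.arange(1, seq_len + 1).unsqueeze(0).repeat(batch_size, 1)
    src_pos = src_pos * src_seq.ne(Constants.PAD).long()

    encoder = Encoder()
    decoder = Decoder()

    enc_output, non_pad_mask = encoder(src_seq, src_pos)
    dec_output = decoder(enc_output, src_pos)

    # Both outputs should be (batch_size, seq_len, hp.word_vec_dim).
    print(enc_output.shape, non_pad_mask.shape, dec_output.shape)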