import torch
import torch.nn as nn
import numpy as np

import transformer.Constants as Constants
from transformer.Layers import FFTBlock, PreNet, PostNet, Linear
from text.symbols import symbols
import hparams as hp


def get_non_pad_mask(seq):
    """Return a (batch, len, 1) float mask that is 0 at padded positions."""
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
                               for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)


def get_attn_key_pad_mask(seq_k, seq_q):
    ''' For masking out the padding part of key sequence. '''

    # Expand to fit the shape of key-query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(Constants.PAD)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk

    return padding_mask


class Encoder(nn.Module):
    ''' Encoder: token embedding + sinusoidal position encoding + a stack of FFT blocks. '''

    def __init__(self,
                 n_src_vocab=len(symbols) + 1,
                 len_max_seq=hp.max_sep_len,
                 d_word_vec=hp.word_vec_dim,
                 n_layers=hp.encoder_n_layer,
                 n_head=hp.encoder_head,
                 d_k=64,
                 d_v=64,
                 d_model=hp.word_vec_dim,
                 d_inner=hp.encoder_conv1d_filter_size,
                 dropout=hp.dropout):
        super(Encoder, self).__init__()

        n_position = len_max_seq + 1

        self.src_word_emb = nn.Embedding(
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([FFTBlock(
            d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, return_attns=False):
        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        return enc_output, non_pad_mask


class Decoder(nn.Module):
    """ Decoder: adds sinusoidal position encodings to its input and runs a stack of FFT blocks. """

    def __init__(self,
                 len_max_seq=hp.max_sep_len,
                 d_word_vec=hp.word_vec_dim,
                 n_layers=hp.decoder_n_layer,
                 n_head=hp.decoder_head,
                 d_k=64,
                 d_v=64,
                 d_model=hp.word_vec_dim,
                 d_inner=hp.decoder_conv1d_filter_size,
                 dropout=hp.dropout):
        super(Decoder, self).__init__()

        n_position = len_max_seq + 1

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([FFTBlock(
            d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, enc_seq, enc_pos, return_attns=False):
        dec_slf_attn_list = []

        # -- Prepare masks (padding is inferred from the position sequence)
        slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos)
        non_pad_mask = get_non_pad_mask(enc_pos)

        # -- Forward
        dec_output = enc_seq + self.position_enc(enc_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn = dec_layer(
                dec_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]

        return dec_output
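

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): a smoke test that
# pushes dummy batches through Encoder and Decoder. It assumes the project
# modules imported above (transformer.*, text.symbols, hparams) are on the
# path, that Constants.PAD is the 0 index reserved for padding, and that
# seq_len does not exceed hp.max_sep_len.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len = 2, 16

    # Token ids in [1, len(symbols)] (index 0 assumed to be padding) and
    # 1-based positions; no padding in this toy batch.
    src_seq = torch.randint(1, len(symbols) + 1, (batch_size, seq_len))
    src_pos = torch.arange(1, seq_len + 1).unsqueeze(0).expand(batch_size, -1)

    encoder = Encoder()
    decoder = Decoder()

    enc_output, non_pad_mask = encoder(src_seq, src_pos)
    dec_output = decoder(enc_output, src_pos)

    print(enc_output.size())  # expected: (batch_size, seq_len, hp.word_vec_dim)
    print(dec_output.size())  # expected: (batch_size, seq_len, hp.word_vec_dim)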