import torch
import torch.nn as nn
import numpy as np

import transformer.Constants as Constants
from transformer.Layers import FFTBlock, PreNet, PostNet, Linear
from text.symbols import symbols
import hparams as hp


def get_non_pad_mask(seq):
    ''' Float mask of shape (batch, len, 1): 1.0 for real tokens, 0.0 for padding. '''
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
                               for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)


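# Illustrative sketch (not part of the original module): what the positional
# table looks like for a small, made-up size. Even-indexed columns hold sines,
# odd-indexed columns cosines, and row `padding_idx` is zeroed so padded
# positions contribute no positional signal.
def _sinusoid_table_example():
    table = get_sinusoid_encoding_table(n_position=8, d_hid=6, padding_idx=0)
    # table.shape == torch.Size([8, 6]); table[0] is the all-zero padding row
    return table

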
def get_attn_key_pad_mask(seq_k, seq_q):
    ''' For masking out the padding part of key sequence. '''

    # Expand to fit the shape of key query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(Constants.PAD)
    padding_mask = padding_mask.unsqueeze(
        1).expand(-1, len_q, -1)  # b x lq x lk

    return padding_mask


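# Illustrative sketch (not part of the original module): how the two mask
# helpers behave on a toy padded batch, assuming Constants.PAD == 0 as in the
# upstream Transformer implementation this file builds on.
def _mask_helpers_example():
    seq = torch.tensor([[5, 7, 2, 0, 0],
                        [3, 9, 0, 0, 0]])        # (batch=2, len=5), 0 = padding
    non_pad = get_non_pad_mask(seq)              # (2, 5, 1) float, 0.0 at padded steps
    attn_pad = get_attn_key_pad_mask(seq, seq)   # (2, 5, 5), True where the key is padding
    return non_pad, attn_pad

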
class Encoder(nn.Module):
    ''' Encoder: symbol embedding plus sinusoidal positions, followed by a stack of FFT blocks '''

    def __init__(self,
                 n_src_vocab=len(symbols)+1,
                 len_max_seq=hp.max_sep_len,
                 d_word_vec=hp.word_vec_dim,
                 n_layers=hp.encoder_n_layer,
                 n_head=hp.encoder_head,
                 d_k=64,
                 d_v=64,
                 d_model=hp.word_vec_dim,
                 d_inner=hp.encoder_conv1d_filter_size,
                 dropout=hp.dropout):

        super(Encoder, self).__init__()

        n_position = len_max_seq + 1

        self.src_word_emb = nn.Embedding(
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([FFTBlock(
            d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, return_attns=False):

        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, non_pad_mask, enc_slf_attn_list

        return enc_output, non_pad_mask


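# Illustrative sketch (not part of the original module): calling the encoder on
# a dummy symbol batch. The ids and lengths are made up; the actual dimensions
# come from hparams.py, and Constants.PAD is assumed to be 0.
def _encoder_example():
    encoder = Encoder()
    src_seq = torch.tensor([[12, 27, 5, 0, 0]])   # (batch=1, text_len=5), 0 = padding
    src_pos = torch.tensor([[1, 2, 3, 0, 0]])     # 1-based positions, 0 at padded steps
    enc_output, non_pad_mask = encoder(src_seq, src_pos)
    # enc_output: (1, 5, hp.word_vec_dim); non_pad_mask: (1, 5, 1)
    return enc_output, non_pad_mask

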
class Decoder(nn.Module):
    """ Decoder: sinusoidal positions added to the input features, followed by a stack of FFT blocks """

    def __init__(self,
                 len_max_seq=hp.max_sep_len,
                 d_word_vec=hp.word_vec_dim,
                 n_layers=hp.decoder_n_layer,
                 n_head=hp.decoder_head,
                 d_k=64,
                 d_v=64,
                 d_model=hp.word_vec_dim,
                 d_inner=hp.decoder_conv1d_filter_size,
                 dropout=hp.dropout):

        super(Decoder, self).__init__()

        n_position = len_max_seq + 1

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([FFTBlock(
            d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])

    def forward(self, enc_seq, enc_pos, return_attns=False):

        dec_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos)
        non_pad_mask = get_non_pad_mask(enc_pos)

        # -- Forward
        dec_output = enc_seq + self.position_enc(enc_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn = dec_layer(
                dec_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]

        if return_attns:
            return dec_output, dec_slf_attn_list

        return dec_output
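

# Illustrative sketch (not part of the original module): the decoder takes a
# frame-level feature tensor plus its 1-based frame positions rather than token
# ids. The batch size and length here are made up.
def _decoder_example():
    decoder = Decoder()
    enc_seq = torch.randn(1, 7, hp.word_vec_dim)   # (batch=1, frames=7, d_model)
    enc_pos = torch.tensor([[1, 2, 3, 4, 5, 6, 7]])
    dec_output = decoder(enc_seq, enc_pos)
    # dec_output: (1, 7, hp.word_vec_dim)
    return dec_output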