import torch
import torch.nn as nn
import numpy as np

import transformer.Constants as Constants
from transformer.Layers import FFTBlock, PreNet, PostNet, Linear
from text.symbols import symbols
import hparams as hp


def get_non_pad_mask(seq):
    ''' Float mask of shape (batch, seq_len, 1): 1.0 for real tokens, 0.0 for padding. '''
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
                               for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)


def get_attn_key_pad_mask(seq_k, seq_q):
    ''' For masking out the padding part of the key sequence. '''

    # Expand to fit the shape of the key-query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(Constants.PAD)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk

    return padding_mask


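# Shape sketch for the two mask helpers above (illustrative only; it assumes
# Constants.PAD == 0):
#
#   seq = torch.tensor([[5, 7, 0]])                    # batch 1, length 3, last token padded
#   get_non_pad_mask(seq)            # -> (1, 3, 1) float: [[[1.], [1.], [0.]]]
#   get_attn_key_pad_mask(seq, seq)  # -> (1, 3, 3) bool: the padded key column is True
#
# FFTBlock is expected to zero out padded time steps with the first mask and to
# ignore padded keys in self-attention with the second.

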
class Encoder(nn.Module):
    ''' Encoder '''

    def __init__(self,
                 n_src_vocab=len(symbols)+1,
                 len_max_seq=hp.max_sep_len,
                 d_word_vec=hp.word_vec_dim,
                 n_layers=hp.encoder_n_layer,
                 n_head=hp.encoder_head,
                 d_k=64,
                 d_v=64,
                 d_model=hp.word_vec_dim,
                 d_inner=hp.encoder_conv1d_filter_size,
                 dropout=hp.dropout):

        super(Encoder, self).__init__()

        n_position = len_max_seq + 1

        self.src_word_emb = nn.Embedding(
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([FFTBlock(
            d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, return_attns=False):

        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        return enc_output, non_pad_mask


class Decoder(nn.Module):
    """ Decoder """

    def __init__(self,
                 len_max_seq=hp.max_sep_len,
                 d_word_vec=hp.word_vec_dim,
                 n_layers=hp.decoder_n_layer,
                 n_head=hp.decoder_head,
                 d_k=64,
                 d_v=64,
                 d_model=hp.word_vec_dim,
                 d_inner=hp.decoder_conv1d_filter_size,
                 dropout=hp.dropout):

        super(Decoder, self).__init__()

        n_position = len_max_seq + 1

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([FFTBlock(
            d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])

    def forward(self, enc_seq, enc_pos, return_attns=False):

        dec_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos)
        non_pad_mask = get_non_pad_mask(enc_pos)

        # -- Forward
        dec_output = enc_seq + self.position_enc(enc_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn = dec_layer(
                dec_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]

        return dec_output
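

# Minimal smoke-test sketch (not part of the original module). It assumes the
# project-level hparams (hp.max_sep_len, hp.word_vec_dim, hp.encoder_*, hp.decoder_*,
# hp.dropout) are defined and large enough for a length-10 batch; tensor values
# below are illustrative only. In the full model the decoder consumes the
# length-regulated encoder output, but feeding enc_out directly is enough to
# check shapes.
if __name__ == "__main__":
    batch, length = 2, 10

    # Random non-pad token ids and 1-based positions; no padding for simplicity.
    src_seq = torch.randint(1, len(symbols), (batch, length))
    src_pos = torch.arange(1, length + 1).unsqueeze(0).expand(batch, -1)

    encoder = Encoder()
    decoder = Decoder()

    enc_out, non_pad_mask = encoder(src_seq, src_pos)
    print(enc_out.size(), non_pad_mask.size())   # (batch, length, d_model), (batch, length, 1)

    dec_out = decoder(enc_out, src_pos)
    print(dec_out.size())                        # (batch, length, d_model)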