Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

145 lines
4.5 KiB

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import transformer.Constants as Constants
  5. from transformer.Layers import FFTBlock, PreNet, PostNet, Linear
  6. from text.symbols import symbols
  7. import hparams as hp
  8. def get_non_pad_mask(seq):
  9. assert seq.dim() == 2
  10. return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)
  11. def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
  12. ''' Sinusoid position encoding table '''
  13. def cal_angle(position, hid_idx):
  14. return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
  15. def get_posi_angle_vec(position):
  16. return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
  17. sinusoid_table = np.array([get_posi_angle_vec(pos_i)
  18. for pos_i in range(n_position)])
  19. sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
  20. sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
  21. if padding_idx is not None:
  22. # zero vector for padding dimension
  23. sinusoid_table[padding_idx] = 0.
  24. return torch.FloatTensor(sinusoid_table)
  25. def get_attn_key_pad_mask(seq_k, seq_q):
  26. ''' For masking out the padding part of key sequence. '''
  27. # Expand to fit the shape of key query attention matrix.
  28. len_q = seq_q.size(1)
  29. padding_mask = seq_k.eq(Constants.PAD)
  30. padding_mask = padding_mask.unsqueeze(
  31. 1).expand(-1, len_q, -1) # b x lq x lk
  32. return padding_mask
  33. class Encoder(nn.Module):
  34. ''' Encoder '''
  35. def __init__(self,
  36. n_src_vocab=len(symbols)+1,
  37. len_max_seq=hp.max_sep_len,
  38. d_word_vec=hp.word_vec_dim,
  39. n_layers=hp.encoder_n_layer,
  40. n_head=hp.encoder_head,
  41. d_k=64,
  42. d_v=64,
  43. d_model=hp.word_vec_dim,
  44. d_inner=hp.encoder_conv1d_filter_size,
  45. dropout=hp.dropout):
  46. super(Encoder, self).__init__()
  47. n_position = len_max_seq + 1
  48. self.src_word_emb = nn.Embedding(
  49. n_src_vocab, d_word_vec, padding_idx=Constants.PAD)
  50. self.position_enc = nn.Embedding.from_pretrained(
  51. get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
  52. freeze=True)
  53. self.layer_stack = nn.ModuleList([FFTBlock(
  54. d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])
  55. def forward(self, src_seq, src_pos, return_attns=False):
  56. enc_slf_attn_list = []
  57. # -- Prepare masks
  58. slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
  59. non_pad_mask = get_non_pad_mask(src_seq)
  60. # -- Forward
  61. enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)
  62. for enc_layer in self.layer_stack:
  63. enc_output, enc_slf_attn = enc_layer(
  64. enc_output,
  65. non_pad_mask=non_pad_mask,
  66. slf_attn_mask=slf_attn_mask)
  67. if return_attns:
  68. enc_slf_attn_list += [enc_slf_attn]
  69. return enc_output, non_pad_mask
  70. class Decoder(nn.Module):
  71. """ Decoder """
  72. def __init__(self,
  73. len_max_seq=hp.max_sep_len,
  74. d_word_vec=hp.word_vec_dim,
  75. n_layers=hp.decoder_n_layer,
  76. n_head=hp.decoder_head,
  77. d_k=64,
  78. d_v=64,
  79. d_model=hp.word_vec_dim,
  80. d_inner=hp.decoder_conv1d_filter_size,
  81. dropout=hp.dropout):
  82. super(Decoder, self).__init__()
  83. n_position = len_max_seq + 1
  84. self.position_enc = nn.Embedding.from_pretrained(
  85. get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
  86. freeze=True)
  87. self.layer_stack = nn.ModuleList([FFTBlock(
  88. d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])
  89. def forward(self, enc_seq, enc_pos, return_attns=False):
  90. dec_slf_attn_list = []
  91. # -- Prepare masks
  92. slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos)
  93. non_pad_mask = get_non_pad_mask(enc_pos)
  94. # -- Forward
  95. dec_output = enc_seq + self.position_enc(enc_pos)
  96. for dec_layer in self.layer_stack:
  97. dec_output, dec_slf_attn = dec_layer(
  98. dec_output,
  99. non_pad_mask=non_pad_mask,
  100. slf_attn_mask=slf_attn_mask)
  101. if return_attns:
  102. dec_slf_attn_list += [dec_slf_attn]
  103. return dec_output