Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave


import torch
import torch.nn as nn
import numpy as np


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        # Similarity scores between queries and keys: (batch, len_q, len_k)
        attn = torch.bmm(q, k.transpose(1, 2))
        # Scale by the temperature (typically sqrt of the key dimension)
        attn = attn / self.temperature

        if mask is not None:
            # Masked positions get -inf so they receive zero attention weight
            attn = attn.masked_fill(mask, -np.inf)

        # Normalize scores to attention weights and apply dropout
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        # Weighted sum of the values: (batch, len_q, d_v)
        output = torch.bmm(attn, v)

        return output, attn
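
For reference, a minimal usage sketch of the module above; the batch size, sequence lengths, and the sqrt(d_k) temperature are illustrative assumptions, not values taken from this repository.

# Usage sketch (shapes and temperature chosen for illustration only)
import torch

d_k = 64
attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

q = torch.randn(8, 10, d_k)   # (batch, len_q, d_k)
k = torch.randn(8, 12, d_k)   # (batch, len_k, d_k)
v = torch.randn(8, 12, d_k)   # (batch, len_k, d_v)

output, attn = attention(q, k, v)
print(output.shape)  # torch.Size([8, 10, 64])
print(attn.shape)    # torch.Size([8, 10, 12])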