# Fork of https://github.com/alokprasad/fastspeech_squeezewave,
# modified to also fix denoising in SqueezeWave.
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from collections import OrderedDict

from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward
from text.symbols import symbols
  8. class Linear(nn.Module):
  9. """
  10. Linear Module
  11. """
  12. def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
  13. """
  14. :param in_dim: dimension of input
  15. :param out_dim: dimension of output
  16. :param bias: boolean. if True, bias is included.
  17. :param w_init: str. weight inits with xavier initialization.
  18. """
  19. super(Linear, self).__init__()
  20. self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
  21. nn.init.xavier_uniform_(
  22. self.linear_layer.weight,
  23. gain=nn.init.calculate_gain(w_init))
  24. def forward(self, x):
  25. return self.linear_layer(x)
  26. class PreNet(nn.Module):
  27. """
  28. Pre Net before passing through the network
  29. """
  30. def __init__(self, input_size, hidden_size, output_size, p=0.5):
  31. """
  32. :param input_size: dimension of input
  33. :param hidden_size: dimension of hidden unit
  34. :param output_size: dimension of output
  35. """
  36. super(PreNet, self).__init__()
  37. self.input_size = input_size
  38. self.output_size = output_size
  39. self.hidden_size = hidden_size
  40. self.layer = nn.Sequential(OrderedDict([
  41. ('fc1', Linear(self.input_size, self.hidden_size)),
  42. ('relu1', nn.ReLU()),
  43. ('dropout1', nn.Dropout(p)),
  44. ('fc2', Linear(self.hidden_size, self.output_size)),
  45. ('relu2', nn.ReLU()),
  46. ('dropout2', nn.Dropout(p)),
  47. ]))
  48. def forward(self, input_):
  49. out = self.layer(input_)
  50. return out
  51. class Conv(nn.Module):
  52. """
  53. Convolution Module
  54. """
  55. def __init__(self,
  56. in_channels,
  57. out_channels,
  58. kernel_size=1,
  59. stride=1,
  60. padding=0,
  61. dilation=1,
  62. bias=True,
  63. w_init='linear'):
  64. """
  65. :param in_channels: dimension of input
  66. :param out_channels: dimension of output
  67. :param kernel_size: size of kernel
  68. :param stride: size of stride
  69. :param padding: size of padding
  70. :param dilation: dilation rate
  71. :param bias: boolean. if True, bias is included.
  72. :param w_init: str. weight inits with xavier initialization.
  73. """
  74. super(Conv, self).__init__()
  75. self.conv = nn.Conv1d(in_channels,
  76. out_channels,
  77. kernel_size=kernel_size,
  78. stride=stride,
  79. padding=padding,
  80. dilation=dilation,
  81. bias=bias)
  82. nn.init.xavier_uniform_(
  83. self.conv.weight, gain=nn.init.calculate_gain(w_init))
  84. def forward(self, x):
  85. x = self.conv(x)
  86. return x
  87. class FFTBlock(torch.nn.Module):
  88. """FFT Block"""
  89. def __init__(self,
  90. d_model,
  91. d_inner,
  92. n_head,
  93. d_k,
  94. d_v,
  95. dropout=0.1):
  96. super(FFTBlock, self).__init__()
  97. self.slf_attn = MultiHeadAttention(
  98. n_head, d_model, d_k, d_v, dropout=dropout)
  99. self.pos_ffn = PositionwiseFeedForward(
  100. d_model, d_inner, dropout=dropout)
  101. def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
  102. enc_output, enc_slf_attn = self.slf_attn(
  103. enc_input, enc_input, enc_input, mask=slf_attn_mask)
  104. enc_output *= non_pad_mask
  105. enc_output = self.pos_ffn(enc_output)
  106. enc_output *= non_pad_mask
  107. return enc_output, enc_slf_attn
  108. class ConvNorm(torch.nn.Module):
  109. def __init__(self,
  110. in_channels,
  111. out_channels,
  112. kernel_size=1,
  113. stride=1,
  114. padding=None,
  115. dilation=1,
  116. bias=True,
  117. w_init_gain='linear'):
  118. super(ConvNorm, self).__init__()
  119. if padding is None:
  120. assert(kernel_size % 2 == 1)
  121. padding = int(dilation * (kernel_size - 1) / 2)
  122. self.conv = torch.nn.Conv1d(in_channels,
  123. out_channels,
  124. kernel_size=kernel_size,
  125. stride=stride,
  126. padding=padding,
  127. dilation=dilation,
  128. bias=bias)
  129. torch.nn.init.xavier_uniform_(
  130. self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
  131. def forward(self, signal):
  132. conv_signal = self.conv(signal)
  133. return conv_signal
  134. class PostNet(nn.Module):
  135. """
  136. PostNet: Five 1-d convolution with 512 channels and kernel size 5
  137. """
  138. def __init__(self,
  139. n_mel_channels=80,
  140. postnet_embedding_dim=512,
  141. postnet_kernel_size=5,
  142. postnet_n_convolutions=5):
  143. super(PostNet, self).__init__()
  144. self.convolutions = nn.ModuleList()
  145. self.convolutions.append(
  146. nn.Sequential(
  147. ConvNorm(n_mel_channels,
  148. postnet_embedding_dim,
  149. kernel_size=postnet_kernel_size,
  150. stride=1,
  151. padding=int((postnet_kernel_size - 1) / 2),
  152. dilation=1,
  153. w_init_gain='tanh'),
  154. nn.BatchNorm1d(postnet_embedding_dim))
  155. )
  156. for i in range(1, postnet_n_convolutions - 1):
  157. self.convolutions.append(
  158. nn.Sequential(
  159. ConvNorm(postnet_embedding_dim,
  160. postnet_embedding_dim,
  161. kernel_size=postnet_kernel_size,
  162. stride=1,
  163. padding=int((postnet_kernel_size - 1) / 2),
  164. dilation=1,
  165. w_init_gain='tanh'),
  166. nn.BatchNorm1d(postnet_embedding_dim))
  167. )
  168. self.convolutions.append(
  169. nn.Sequential(
  170. ConvNorm(postnet_embedding_dim,
  171. n_mel_channels,
  172. kernel_size=postnet_kernel_size,
  173. stride=1,
  174. padding=int((postnet_kernel_size - 1) / 2),
  175. dilation=1,
  176. w_init_gain='linear'),
  177. nn.BatchNorm1d(n_mel_channels))
  178. )
  179. def forward(self, x):
  180. x = x.contiguous().transpose(1, 2)
  181. for i in range(len(self.convolutions) - 1):
  182. x = F.dropout(torch.tanh(
  183. self.convolutions[i](x)), 0.5, self.training)
  184. x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
  185. x = x.contiguous().transpose(1, 2)
  186. return x