Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

import numpy as np
import copy
import math

import hparams as hp
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
                               for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)
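
# Reference (sketch): the table above implements the standard sinusoidal
# positional encoding, PE(pos, 2i) = sin(pos / 10000^(2i / d_hid)) and
# PE(pos, 2i+1) = cos(pos / 10000^(2i / d_hid)). For example,
# get_sinusoid_encoding_table(1024, hp.d_model, padding_idx=0) returns a
# FloatTensor of shape (1024, hp.d_model) whose row 0 is all zeros, which is
# how FFTBlock below builds its frozen position embedding.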


def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LengthRegulator(nn.Module):
    """ Length Regulator """

    def __init__(self):
        super(LengthRegulator, self).__init__()
        self.duration_predictor = DurationPredictor()

    def LR(self, x, duration_predictor_output, alpha=1.0, mel_max_length=None):
        output = list()

        for batch, expand_target in zip(x, duration_predictor_output):
            output.append(self.expand(batch, expand_target, alpha))

        if mel_max_length:
            output = utils.pad(output, mel_max_length)
        else:
            output = utils.pad(output)

        return output

    def expand(self, batch, predicted, alpha):
        out = list()

        for i, vec in enumerate(batch):
            expand_size = predicted[i].item()
            out.append(vec.expand(int(expand_size * alpha), -1))
        out = torch.cat(out, 0)

        return out

    def rounding(self, num):
        if num - int(num) >= 0.5:
            return int(num) + 1
        else:
            return int(num)

    def forward(self, x, alpha=1.0, target=None, mel_max_length=None):
        duration_predictor_output = self.duration_predictor(x)

        if self.training:
            output = self.LR(x, target, mel_max_length=mel_max_length)
            return output, duration_predictor_output
        else:
            for idx, ele in enumerate(duration_predictor_output[0]):
                duration_predictor_output[0][idx] = self.rounding(ele)
            output = self.LR(x, duration_predictor_output, alpha)
            mel_pos = torch.stack(
                [torch.Tensor([i + 1 for i in range(output.size(1))])]).long().to(device)
            return output, mel_pos
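
# Worked example (sketch): during inference, an encoder output of 3 phoneme
# frames with predicted (rounded) durations [2, 1, 3] is expanded frame by
# frame to 2 + 1 + 3 = 6 mel frames; alpha rescales those durations to control
# overall speaking rate before the int() truncation in expand().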


class DurationPredictor(nn.Module):
    """ Duration Predictor """

    def __init__(self):
        super(DurationPredictor, self).__init__()

        self.input_size = hp.d_model
        self.filter_size = hp.duration_predictor_filter_size
        self.kernel = hp.duration_predictor_kernel_size
        self.conv_output_size = hp.duration_predictor_filter_size
        self.dropout = hp.dropout

        self.conv_layer = nn.Sequential(OrderedDict([
            ("conv1d_1", Conv(self.input_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=1)),
            ("layer_norm_1", nn.LayerNorm(self.filter_size)),
            ("relu_1", nn.ReLU()),
            ("dropout_1", nn.Dropout(self.dropout)),
            ("conv1d_2", Conv(self.filter_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=1)),
            ("layer_norm_2", nn.LayerNorm(self.filter_size)),
            ("relu_2", nn.ReLU()),
            ("dropout_2", nn.Dropout(self.dropout))
        ]))

        self.linear_layer = Linear(self.conv_output_size, 1)
        self.relu = nn.ReLU()

    def forward(self, encoder_output):
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = self.relu(out)
        out = out.squeeze()
        if not self.training:
            out = out.unsqueeze(0)
        return out
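
# Shape note (sketch): forward() maps encoder output of shape
# (batch, seq_len, hp.d_model) to per-phoneme durations of shape
# (batch, seq_len) after squeeze(); at inference the batch dimension that
# squeeze() drops for batch size 1 is restored by unsqueeze(0).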


class Conv(nn.Module):
    """
    Convolution Module
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 w_init='linear'):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. weight inits with xavier initialization.
        """
        super(Conv, self).__init__()

        self.conv = nn.Conv1d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=bias)

        nn.init.xavier_uniform_(
            self.conv.weight, gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        # nn.Conv1d expects (batch, channels, time), so transpose in and out
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)

        return x


class Linear(nn.Module):
    """
    Linear Module
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
        """
        :param in_dim: dimension of input
        :param out_dim: dimension of output
        :param bias: boolean. if True, bias is included.
        :param w_init: str. weight inits with xavier initialization.
        """
        super(Linear, self).__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)

        nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        return self.linear_layer(x)


class FFN(nn.Module):
    """
    Positionwise Feed-Forward Network
    """

    def __init__(self, num_hidden):
        """
        :param num_hidden: dimension of hidden
        """
        super(FFN, self).__init__()
        self.w_1 = Conv(num_hidden, num_hidden * 4,
                        kernel_size=3, padding=1, w_init='relu')
        self.w_2 = Conv(num_hidden * 4, num_hidden, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=0.1)
        self.layer_norm = nn.LayerNorm(num_hidden)

    def forward(self, input_):
        # FFN Network
        x = input_
        x = self.w_2(torch.relu(self.w_1(x)))

        # residual connection
        x = x + input_

        # dropout
        x = self.dropout(x)

        # layer normalization
        x = self.layer_norm(x)

        return x
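
# Design note (sketch): instead of the position-wise linear layers of the
# original Transformer FFN, this variant uses two kernel-size-3 Conv layers
# (num_hidden -> 4 * num_hidden -> num_hidden), so each position also sees its
# immediate neighbours before the residual add and layer norm.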


class MultiheadAttention(nn.Module):
    """
    Multihead attention mechanism (dot attention)
    """

    def __init__(self, num_hidden_k):
        """
        :param num_hidden_k: dimension of hidden
        """
        super(MultiheadAttention, self).__init__()

        self.num_hidden_k = num_hidden_k
        self.attn_dropout = nn.Dropout(p=0.1)

    def forward(self, key, value, query, mask=None, query_mask=None):
        # Get attention score
        attn = torch.bmm(query, key.transpose(1, 2))
        attn = attn / math.sqrt(self.num_hidden_k)

        # Masking to ignore padding (key side)
        if mask is not None:
            attn = attn.masked_fill(mask, -2 ** 32 + 1)
        attn = torch.softmax(attn, dim=-1)

        # Masking to ignore padding (query side)
        if query_mask is not None:
            attn = attn * query_mask

        # Dropout
        attn = self.attn_dropout(attn)

        # Get Context Vector
        result = torch.bmm(attn, value)

        return result, attn
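
# Reference (sketch): this is scaled dot-product attention,
# Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, with padded key positions
# pushed to a large negative score before the softmax and padded query
# positions zeroed out afterwards via query_mask.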


class Attention(nn.Module):
    """
    Attention Network
    """

    def __init__(self, num_hidden, h=2):
        """
        :param num_hidden: dimension of hidden
        :param h: num of heads
        """
        super(Attention, self).__init__()

        self.num_hidden = num_hidden
        self.num_hidden_per_attn = num_hidden // h
        self.h = h

        self.key = Linear(num_hidden, num_hidden, bias=False)
        self.value = Linear(num_hidden, num_hidden, bias=False)
        self.query = Linear(num_hidden, num_hidden, bias=False)

        self.multihead = MultiheadAttention(self.num_hidden_per_attn)

        self.residual_dropout = nn.Dropout(p=0.1)

        self.final_linear = Linear(num_hidden * 2, num_hidden)

        self.layer_norm_1 = nn.LayerNorm(num_hidden)

    def forward(self, memory, decoder_input, mask=None, query_mask=None):
        batch_size = memory.size(0)
        seq_k = memory.size(1)
        seq_q = decoder_input.size(1)

        # Repeat masks h times
        if query_mask is not None:
            query_mask = query_mask.unsqueeze(-1).repeat(1, 1, seq_k)
            query_mask = query_mask.repeat(self.h, 1, 1)
        if mask is not None:
            mask = mask.repeat(self.h, 1, 1)

        # Make multihead
        key = self.key(memory).view(batch_size,
                                    seq_k,
                                    self.h,
                                    self.num_hidden_per_attn)
        value = self.value(memory).view(batch_size,
                                        seq_k,
                                        self.h,
                                        self.num_hidden_per_attn)
        query = self.query(decoder_input).view(batch_size,
                                               seq_q,
                                               self.h,
                                               self.num_hidden_per_attn)

        key = key.permute(2, 0, 1, 3).contiguous().view(
            -1, seq_k, self.num_hidden_per_attn)
        value = value.permute(2, 0, 1, 3).contiguous().view(
            -1, seq_k, self.num_hidden_per_attn)
        query = query.permute(2, 0, 1, 3).contiguous().view(
            -1, seq_q, self.num_hidden_per_attn)

        # Get context vector
        result, attns = self.multihead(
            key, value, query, mask=mask, query_mask=query_mask)

        # Concatenate all multihead context vector
        result = result.view(self.h, batch_size, seq_q,
                             self.num_hidden_per_attn)
        result = result.permute(1, 2, 0, 3).contiguous().view(
            batch_size, seq_q, -1)

        # Concatenate context vector with input (most important)
        result = torch.cat([decoder_input, result], dim=-1)

        # Final linear
        result = self.final_linear(result)

        # Residual dropout & connection
        result = self.residual_dropout(result)
        result = result + decoder_input

        # Layer normalization
        result = self.layer_norm_1(result)

        return result, attns
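
# Shape note (sketch): num_hidden is split across h heads (num_hidden // h per
# head), the heads are folded into the batch dimension as (h * batch, seq, d_k)
# for MultiheadAttention, then unfolded, concatenated with decoder_input
# (giving 2 * num_hidden features) and projected back to num_hidden by
# final_linear.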


class FFTBlock(torch.nn.Module):
    """FFT Block"""

    def __init__(self,
                 d_model,
                 n_head=hp.Head):
        super(FFTBlock, self).__init__()
        self.slf_attn = clones(Attention(d_model), hp.N)
        self.pos_ffn = clones(FFN(d_model), hp.N)
        self.pos_emb = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(1024, d_model, padding_idx=0),
            freeze=True)

    def forward(self, x, pos, return_attns=False):
        # Get character mask
        if self.training:
            c_mask = pos.ne(0).type(torch.float)
            mask = pos.eq(0).unsqueeze(1).repeat(1, x.size(1), 1)
        else:
            c_mask, mask = None, None

        # Get positional embedding and add
        pos = self.pos_emb(pos)
        x = x + pos

        # Attention encoder-encoder
        attns = list()
        for slf_attn, ffn in zip(self.slf_attn, self.pos_ffn):
            x, attn = slf_attn(x, x, mask=mask, query_mask=c_mask)
            x = ffn(x)
            attns.append(attn)

        return x, attns
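

# Minimal smoke test (sketch, not part of the original training pipeline). It
# assumes hparams provides d_model, N and Head as the classes above already
# require, and that d_model is even so the two attention heads split cleanly;
# the shapes below follow only the forward() signatures defined in this file.
if __name__ == "__main__":
    batch_size, seq_len = 2, 16

    block = FFTBlock(hp.d_model).to(device)
    x = torch.randn(batch_size, seq_len, hp.d_model, device=device)

    # Positions are 1-based; index 0 is reserved for padding.
    pos = torch.arange(1, seq_len + 1, device=device)
    pos = pos.unsqueeze(0).repeat(batch_size, 1)

    out, attns = block(x, pos)
    print(out.shape)   # expected: (batch_size, seq_len, hp.d_model)
    print(len(attns))  # one attention map per stacked layer (hp.N)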