Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

332 lines
13 KiB

  1. # We retain the copyright notice by NVIDIA from the original code. However, we
  2. # reserve our rights on the modifications based on the original code.
  3. #
  4. # *****************************************************************************
  5. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  6. #
  7. # Redistribution and use in source and binary forms, with or without
  8. # modification, are permitted provided that the following conditions are met:
  9. # * Redistributions of source code must retain the above copyright
  10. # notice, this list of conditions and the following disclaimer.
  11. # * Redistributions in binary form must reproduce the above copyright
  12. # notice, this list of conditions and the following disclaimer in the
  13. # documentation and/or other materials provided with the distribution.
  14. # * Neither the name of the NVIDIA CORPORATION nor the
  15. # names of its contributors may be used to endorse or promote products
  16. # derived from this software without specific prior written permission.
  17. #
  18. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  19. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  20. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  21. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  22. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  23. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  24. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  25. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  26. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  27. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  28. #
  29. # *****************************************************************************
  30. import torch
  31. from torch.autograd import Variable
  32. import torch.nn.functional as F
  33. import numpy as np
  34. @torch.jit.script
  35. def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
  36. n_channels_int = n_channels[0]
  37. in_act = input_a+input_b
  38. t_act = torch.tanh(in_act[:, :n_channels_int, :])
  39. s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
  40. acts = t_act * s_act
  41. return acts
  42. class Upsample1d(torch.nn.Module):
  43. def __init__(self, scale=2):
  44. super(Upsample1d, self).__init__()
  45. self.scale = scale
  46. def forward(self, x):
  47. y = F.interpolate(
  48. x, scale_factor=self.scale, mode='nearest')
  49. return y
  50. class SqueezeWaveLoss(torch.nn.Module):
  51. def __init__(self, sigma=1.0):
  52. super(SqueezeWaveLoss, self).__init__()
  53. self.sigma = sigma
  54. def forward(self, model_output):
  55. z, log_s_list, log_det_W_list = model_output
  56. for i, log_s in enumerate(log_s_list):
  57. if i == 0:
  58. log_s_total = torch.sum(log_s)
  59. log_det_W_total = log_det_W_list[i]
  60. else:
  61. log_s_total = log_s_total + torch.sum(log_s)
  62. log_det_W_total += log_det_W_list[i]
  63. loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total
  64. return loss/(z.size(0)*z.size(1)*z.size(2))
  65. class Invertible1x1Conv(torch.nn.Module):
  66. """
  67. The layer outputs both the convolution, and the log determinant
  68. of its weight matrix. If reverse=True it does convolution with
  69. inverse
  70. """
  71. def __init__(self, c):
  72. super(Invertible1x1Conv, self).__init__()
  73. self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
  74. bias=False)
  75. # Sample a random orthonormal matrix to initialize weights
  76. W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
  77. # Ensure determinant is 1.0 not -1.0
  78. if torch.det(W) < 0:
  79. W[:,0] = -1*W[:,0]
  80. W = W.view(c, c, 1)
  81. self.conv.weight.data = W
  82. def forward(self, z, reverse=False):
  83. # shape
  84. batch_size, group_size, n_of_groups = z.size()
  85. W = self.conv.weight.squeeze()
  86. if reverse:
  87. if not hasattr(self, 'W_inverse'):
  88. # Reverse computation
  89. W_inverse = W.float().inverse()
  90. W_inverse = Variable(W_inverse[..., None])
  91. self.W_inverse = W_inverse.half()
  92. self.W_inverse = self.W_inverse.to(torch.float32)
  93. z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
  94. return z
  95. else:
  96. # Forward computation
  97. log_det_W = batch_size * n_of_groups * torch.logdet(W)
  98. z = self.conv(z)
  99. return z, log_det_W
  100. class WN(torch.nn.Module):
  101. """
  102. This is the WaveNet like layer for the affine coupling. The primary difference
  103. from WaveNet is the convolutions need not be causal. There is also no dilation
  104. size reset. The dilation only doubles on each layer
  105. """
  106. def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
  107. kernel_size):
  108. super(WN, self).__init__()
  109. assert(kernel_size % 2 == 1)
  110. assert(n_channels % 2 == 0)
  111. self.n_layers = n_layers
  112. self.n_channels = n_channels
  113. self.in_layers = torch.nn.ModuleList()
  114. self.res_skip_layers = torch.nn.ModuleList()
  115. self.upsample = Upsample1d(2)
  116. start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
  117. start = torch.nn.utils.weight_norm(start, name='weight')
  118. self.start = start
  119. end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
  120. end.weight.data.zero_()
  121. end.bias.data.zero_()
  122. self.end = end
  123. # cond_layer
  124. cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1)
  125. self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
  126. for i in range(n_layers):
  127. dilation = 1
  128. padding = int((kernel_size*dilation - dilation)/2)
  129. # depthwise separable convolution
  130. depthwise = torch.nn.Conv1d(n_channels, n_channels, 3,
  131. dilation=dilation, padding=padding,
  132. groups=n_channels)
  133. pointwise = torch.nn.Conv1d(n_channels, 2*n_channels, 1)
  134. bn = torch.nn.BatchNorm1d(n_channels)
  135. self.in_layers.append(torch.nn.Sequential(bn, depthwise, pointwise))
  136. # res_skip_layer
  137. res_skip_layer = torch.nn.Conv1d(n_channels, n_channels, 1)
  138. res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
  139. self.res_skip_layers.append(res_skip_layer)
  140. def forward(self, forward_input):
  141. audio, spect = forward_input
  142. audio = self.start(audio)
  143. n_channels_tensor = torch.IntTensor([self.n_channels])
  144. # pass all the mel_spectrograms to cond_layer
  145. spect = self.cond_layer(spect)
  146. for i in range(self.n_layers):
  147. # split the corresponding mel_spectrogram
  148. spect_offset = i*2*self.n_channels
  149. spec = spect[:,spect_offset:spect_offset+2*self.n_channels,:]
  150. if audio.size(2) > spec.size(2):
  151. cond = self.upsample(spec)
  152. else:
  153. cond = spec
  154. acts = fused_add_tanh_sigmoid_multiply(
  155. self.in_layers[i](audio),
  156. cond,
  157. n_channels_tensor)
  158. # res_skip
  159. res_skip_acts = self.res_skip_layers[i](acts)
  160. audio = audio + res_skip_acts
  161. return self.end(audio)
  162. class SqueezeWave(torch.nn.Module):
  163. def __init__(self, n_mel_channels, n_flows, n_audio_channel, n_early_every,
  164. n_early_size, WN_config):
  165. super(SqueezeWave, self).__init__()
  166. assert(n_audio_channel % 2 == 0)
  167. self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
  168. n_mel_channels,
  169. 1024, stride=256)
  170. self.n_flows = n_flows
  171. self.n_audio_channel = n_audio_channel
  172. self.n_early_every = n_early_every
  173. self.n_early_size = n_early_size
  174. self.WN = torch.nn.ModuleList()
  175. self.convinv = torch.nn.ModuleList()
  176. n_half = int(n_audio_channel / 2)
  177. # Set up layers with the right sizes based on how many dimensions
  178. # have been output already
  179. n_remaining_channels = n_audio_channel
  180. for k in range(n_flows):
  181. if k % self.n_early_every == 0 and k > 0:
  182. n_half = n_half - int(self.n_early_size/2)
  183. n_remaining_channels = n_remaining_channels - self.n_early_size
  184. self.convinv.append(Invertible1x1Conv(n_remaining_channels))
  185. self.WN.append(WN(n_half, n_mel_channels, **WN_config))
  186. self.n_remaining_channels = n_remaining_channels # Useful during inference
  187. def forward(self, forward_input):
  188. """
  189. forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames
  190. forward_input[1] = audio: batch x time
  191. """
  192. spect, audio = forward_input
  193. audio = audio.unfold(
  194. 1, self.n_audio_channel, self.n_audio_channel).permute(0, 2, 1)
  195. output_audio = []
  196. log_s_list = []
  197. log_det_W_list = []
  198. for k in range(self.n_flows):
  199. if k % self.n_early_every == 0 and k > 0:
  200. output_audio.append(audio[:,:self.n_early_size,:])
  201. audio = audio[:,self.n_early_size:,:]
  202. audio, log_det_W = self.convinv[k](audio)
  203. log_det_W_list.append(log_det_W)
  204. n_half = int(audio.size(1)/2)
  205. audio_0 = audio[:,:n_half,:]
  206. audio_1 = audio[:,n_half:,:]
  207. output = self.WN[k]((audio_0, spect))
  208. log_s = output[:, n_half:, :]
  209. b = output[:, :n_half, :]
  210. audio_1 = (torch.exp(log_s))*audio_1 + b
  211. log_s_list.append(log_s)
  212. audio = torch.cat([audio_0, audio_1], 1)
  213. output_audio.append(audio)
  214. return torch.cat(output_audio, 1), log_s_list, log_det_W_list
  215. def infer(self, spect, sigma=1.0):
  216. spect_size = spect.size()
  217. l = spect.size(2)*(256 // self.n_audio_channel)
  218. spect = spect.to(torch.float32)
  219. if spect.type() == 'torch.HalfTensor':
  220. audio = torch.HalfTensor(spect.size(0),
  221. self.n_remaining_channels,
  222. l).normal_()
  223. else:
  224. audio = torch.FloatTensor(spect.size(0),
  225. self.n_remaining_channels,
  226. l).normal_()
  227. for k in reversed(range(self.n_flows)):
  228. n_half = int(audio.size(1)/2)
  229. audio_0 = audio[:,:n_half,:]
  230. audio_1 = audio[:,n_half:,:]
  231. output = self.WN[k]((audio_0, spect))
  232. s = output[:, n_half:, :]
  233. b = output[:, :n_half, :]
  234. audio_1 = (audio_1 - b)/torch.exp(s)
  235. audio = torch.cat([audio_0, audio_1],1)
  236. audio = self.convinv[k](audio, reverse=True)
  237. if k % self.n_early_every == 0 and k > 0:
  238. if spect.type() == 'torch.HalfTensor':
  239. z = torch.HalfTensor(spect.size(0), self.n_early_size, l).normal_()
  240. else:
  241. z = torch.FloatTensor(spect.size(0), self.n_early_size, l).normal_()
  242. audio = torch.cat((sigma*z, audio),1)
  243. audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
  244. return audio
  245. @staticmethod
  246. def remove_weightnorm(model):
  247. squeezewave = model
  248. for WN in squeezewave.WN:
  249. WN.start = torch.nn.utils.remove_weight_norm(WN.start)
  250. WN.in_layers = remove_batch_norm(WN.in_layers)
  251. WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
  252. WN.res_skip_layers = remove(WN.res_skip_layers)
  253. return squeezewave
  254. def fuse_conv_and_bn(conv, bn):
  255. fusedconv = torch.nn.Conv1d(
  256. conv.in_channels,
  257. conv.out_channels,
  258. kernel_size = conv.kernel_size,
  259. padding=conv.padding,
  260. bias=True,
  261. groups=conv.groups)
  262. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  263. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
  264. w_bn = w_bn.clone()
  265. fusedconv.weight.data = torch.mm(w_bn, w_conv).view(fusedconv.weight.size())
  266. if conv.bias is not None:
  267. b_conv = conv.bias
  268. else:
  269. b_conv = torch.zeros( conv.weight.size(0) )
  270. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  271. b_bn = torch.unsqueeze(b_bn, 1)
  272. bn_3 = b_bn.expand(-1, 3)
  273. b = torch.matmul(w_conv, torch.transpose(bn_3, 0, 1))[range(b_bn.size()[0]), range(b_bn.size()[0])]
  274. fusedconv.bias.data = ( b_conv + b )
  275. return fusedconv
  276. def remove_batch_norm(conv_list):
  277. new_conv_list = torch.nn.ModuleList()
  278. for old_conv in conv_list:
  279. depthwise = fuse_conv_and_bn(old_conv[1], old_conv[0])
  280. pointwise = old_conv[2]
  281. new_conv_list.append(torch.nn.Sequential(depthwise, pointwise))
  282. return new_conv_list
  283. def remove(conv_list):
  284. new_conv_list = torch.nn.ModuleList()
  285. for old_conv in conv_list:
  286. old_conv = torch.nn.utils.remove_weight_norm(old_conv)
  287. new_conv_list.append(old_conv)
  288. return new_conv_list