
Non-CUDA (CPU) inference and importing FastSpeech mel spectrograms

master
alokprasad committed 4 years ago
commit 00cd3dc8ec
3 changed files with 29 additions and 23 deletions
  1. +2 -2   SqueezeWave/denoiser.py
  2. +12 -11 SqueezeWave/glow.py
  3. +15 -10 SqueezeWave/inference.py

+2 -2  SqueezeWave/denoiser.py

@@ -11,7 +11,7 @@ class Denoiser(torch.nn.Module):
         super(Denoiser, self).__init__()
         self.stft = STFT(filter_length=filter_length,
                          hop_length=int(filter_length/n_overlap),
-                         win_length=win_length).cuda()
+                         win_length=win_length)
         if mode == 'zeros':
             mel_input = torch.zeros(
                 (1, 80, 88),
@@ -32,7 +32,7 @@ class Denoiser(torch.nn.Module):
         self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
     def forward(self, audio, strength=0.1):
-        audio_spec, audio_angles = self.stft.transform(audio.cuda().float())
+        audio_spec, audio_angles = self.stft.transform(audio.float())
         audio_spec_denoised = audio_spec - self.bias_spec * strength
         audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
         audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles)
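
Note on the change above: dropping the hard-coded .cuda() calls pins the Denoiser to the CPU. A device-agnostic alternative (only a sketch, not part of this commit) is to resolve the device once and move tensors to it, so the same code still uses a GPU when one is present:

# Sketch, not part of this commit: resolve the device once and move tensors to it.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

audio = torch.randn(1, 22050)        # dummy waveform, stands in for the real input
audio = audio.to(device).float()     # replaces the hard-coded audio.cuda().float()
print(audio.device)                  # cuda:0 when a GPU is available, cpu otherwise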

+12 -11  SqueezeWave/glow.py

@@ -103,9 +103,8 @@ class Invertible1x1Conv(torch.nn.Module):
             # Reverse computation
             W_inverse = W.float().inverse()
             W_inverse = Variable(W_inverse[..., None])
-            if z.type() == 'torch.cuda.HalfTensor':
-                W_inverse = W_inverse.half()
-            self.W_inverse = W_inverse
+            self.W_inverse = W_inverse.half()
+            self.W_inverse = self.W_inverse.to(torch.float32)
             z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
             return z
         else:
@@ -148,8 +147,8 @@ class WN(torch.nn.Module):
             # depthwise separable convolution
             depthwise = torch.nn.Conv1d(n_channels, n_channels, 3,
                                         dilation=dilation, padding=padding,
-                                        groups=n_channels).cuda()
-            pointwise = torch.nn.Conv1d(n_channels, 2*n_channels, 1).cuda()
+                                        groups=n_channels)
+            pointwise = torch.nn.Conv1d(n_channels, 2*n_channels, 1)
             bn = torch.nn.BatchNorm1d(n_channels)
             self.in_layers.append(torch.nn.Sequential(bn, depthwise, pointwise))
             # res_skip_layer
@@ -245,12 +244,14 @@ class SqueezeWave(torch.nn.Module):
     def infer(self, spect, sigma=1.0):
         spect_size = spect.size()
         l = spect.size(2)*(256 // self.n_audio_channel)
-        if spect.type() == 'torch.cuda.HalfTensor':
-            audio = torch.cuda.HalfTensor(spect.size(0),
+        spect = spect.to(torch.float32)
+        if spect.type() == 'torch.HalfTensor':
+            audio = torch.HalfTensor(spect.size(0),
                                      self.n_remaining_channels,
                                      l).normal_()
         else:
-            audio = torch.cuda.FloatTensor(spect.size(0),
+            audio = torch.FloatTensor(spect.size(0),
                                      self.n_remaining_channels,
                                      l).normal_()
@@ -268,10 +269,10 @@ class SqueezeWave(torch.nn.Module):
             audio = self.convinv[k](audio, reverse=True)
             if k % self.n_early_every == 0 and k > 0:
-                if spect.type() == 'torch.cuda.HalfTensor':
-                    z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, l).normal_()
+                if spect.type() == 'torch.HalfTensor':
+                    z = torch.HalfTensor(spect.size(0), self.n_early_size, l).normal_()
                 else:
-                    z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, l).normal_()
+                    z = torch.FloatTensor(spect.size(0), self.n_early_size, l).normal_()
                 audio = torch.cat((sigma*z, audio),1)
         audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data

+15 -10  SqueezeWave/inference.py

@@ -32,27 +32,29 @@ from scipy.io.wavfile import write
 import torch
 from mel2samp import files_to_list, MAX_WAV_VALUE
 from denoiser import Denoiser
+import time
 def main(mel_files, squeezewave_path, sigma, output_dir, sampling_rate, is_fp16,
          denoiser_strength):
     mel_files = files_to_list(mel_files)
-    squeezewave = torch.load(squeezewave_path)['model']
+    #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = torch.device('cpu')
+    squeezewave = torch.load(squeezewave_path,map_location=device)['model']
     squeezewave = squeezewave.remove_weightnorm(squeezewave)
-    squeezewave.cuda().eval()
+    squeezewave.eval()
     if is_fp16:
         from apex import amp
-        squeezewave, _ = amp.initialize(squeezewave, [], opt_level="O3")
+        squeezewave, _ = amp.initialize(squeezewave,[],opt_level="O3")
     if denoiser_strength > 0:
-        denoiser = Denoiser(squeezewave).cuda()
+        denoiser = Denoiser(squeezewave)
+    start = time.time()
     for i, file_path in enumerate(mel_files):
         file_name = os.path.splitext(os.path.basename(file_path))[0]
-        mel = torch.load(file_path)
-        mel = torch.autograd.Variable(mel.cuda())
-        mel = torch.unsqueeze(mel, 0)
-        mel = mel.half() if is_fp16 else mel
+        mel = torch.load(file_path,map_location=device)
+        mel = torch.autograd.Variable(mel)
+        mel = mel.half()
         with torch.no_grad():
             audio = squeezewave.infer(mel, sigma=sigma).float()
             if denoiser_strength > 0:
@@ -65,6 +67,9 @@ def main(mel_files, squeezewave_path, sigma, output_dir, sampling_rate, is_fp16,
             output_dir, "{}_synthesis.wav".format(file_name))
         write(audio_path, sampling_rate, audio)
         print(audio_path)
+    end = time.time()
+    print("Squeezewave vocoder time")
+    print(end-start)
 if __name__ == "__main__":
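
Note on the change above: with map_location=device each entry in the mel file list is loaded on the CPU, and the torch.unsqueeze call is dropped, so the saved mel is expected to already carry a batch dimension. A sketch of how a FastSpeech-generated mel could be saved for this script (shape, values, and file name are assumptions, not part of this commit):

# Sketch, not part of this commit: save a mel so inference.py can torch.load() it on CPU.
# The (1, 80, T) shape and the file name are assumptions for illustration only.
import torch

mel = torch.randn(1, 80, 500)                  # stand-in for a FastSpeech-generated mel, batch dim included
torch.save(mel.cpu(), 'ljspeech_0001_mel.pt')  # inference.py later loads it with torch.load(path, map_location=device)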
