diff --git a/SqueezeWave/denoiser.py b/SqueezeWave/denoiser.py index 8f9ff57..2da8f3a 100644 --- a/SqueezeWave/denoiser.py +++ b/SqueezeWave/denoiser.py @@ -11,7 +11,7 @@ class Denoiser(torch.nn.Module): super(Denoiser, self).__init__() self.stft = STFT(filter_length=filter_length, hop_length=int(filter_length/n_overlap), - win_length=win_length).cuda() + win_length=win_length) if mode == 'zeros': mel_input = torch.zeros( (1, 80, 88), @@ -32,7 +32,7 @@ class Denoiser(torch.nn.Module): self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None]) def forward(self, audio, strength=0.1): - audio_spec, audio_angles = self.stft.transform(audio.cuda().float()) + audio_spec, audio_angles = self.stft.transform(audio.float()) audio_spec_denoised = audio_spec - self.bias_spec * strength audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles) diff --git a/SqueezeWave/glow.py b/SqueezeWave/glow.py index f692103..36199dc 100644 --- a/SqueezeWave/glow.py +++ b/SqueezeWave/glow.py @@ -103,9 +103,8 @@ class Invertible1x1Conv(torch.nn.Module): # Reverse computation W_inverse = W.float().inverse() W_inverse = Variable(W_inverse[..., None]) - if z.type() == 'torch.cuda.HalfTensor': - W_inverse = W_inverse.half() - self.W_inverse = W_inverse + self.W_inverse = W_inverse.half() + self.W_inverse = self.W_inverse.to(torch.float32) z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) return z else: @@ -148,8 +147,8 @@ class WN(torch.nn.Module): # depthwise separable convolution depthwise = torch.nn.Conv1d(n_channels, n_channels, 3, dilation=dilation, padding=padding, - groups=n_channels).cuda() - pointwise = torch.nn.Conv1d(n_channels, 2*n_channels, 1).cuda() + groups=n_channels) + pointwise = torch.nn.Conv1d(n_channels, 2*n_channels, 1) bn = torch.nn.BatchNorm1d(n_channels) self.in_layers.append(torch.nn.Sequential(bn, depthwise, pointwise)) # res_skip_layer @@ -245,12 +244,14 @@ class SqueezeWave(torch.nn.Module): def infer(self, spect, sigma=1.0): spect_size = spect.size() l = spect.size(2)*(256 // self.n_audio_channel) - if spect.type() == 'torch.cuda.HalfTensor': - audio = torch.cuda.HalfTensor(spect.size(0), + + spect = spect.to(torch.float32) + if spect.type() == 'torch.HalfTensor': + audio = torch.HalfTensor(spect.size(0), self.n_remaining_channels, l).normal_() else: - audio = torch.cuda.FloatTensor(spect.size(0), + audio = torch.FloatTensor(spect.size(0), self.n_remaining_channels, l).normal_() @@ -268,10 +269,10 @@ class SqueezeWave(torch.nn.Module): audio = self.convinv[k](audio, reverse=True) if k % self.n_early_every == 0 and k > 0: - if spect.type() == 'torch.cuda.HalfTensor': - z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, l).normal_() + if spect.type() == 'torch.HalfTensor': + z = torch.HalfTensor(spect.size(0), self.n_early_size, l).normal_() else: - z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, l).normal_() + z = torch.FloatTensor(spect.size(0), self.n_early_size, l).normal_() audio = torch.cat((sigma*z, audio),1) audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data diff --git a/SqueezeWave/inference.py b/SqueezeWave/inference.py index 568e6ce..bd0ff9b 100644 --- a/SqueezeWave/inference.py +++ b/SqueezeWave/inference.py @@ -32,27 +32,29 @@ from scipy.io.wavfile import write import torch from mel2samp import files_to_list, MAX_WAV_VALUE from denoiser import Denoiser - +import time def main(mel_files, squeezewave_path, sigma, output_dir, sampling_rate, is_fp16, denoiser_strength): mel_files = files_to_list(mel_files) - squeezewave = torch.load(squeezewave_path)['model'] + + #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device('cpu') + squeezewave = torch.load(squeezewave_path,map_location=device) ['model'] squeezewave = squeezewave.remove_weightnorm(squeezewave) - squeezewave.cuda().eval() + squeezewave.eval() if is_fp16: from apex import amp - squeezewave, _ = amp.initialize(squeezewave, [], opt_level="O3") + squeezewave, _ = amp.initialize(squeezewave,[],opt_level="O3") if denoiser_strength > 0: - denoiser = Denoiser(squeezewave).cuda() - + denoiser = Denoiser(squeezewave) + start = time.time() for i, file_path in enumerate(mel_files): file_name = os.path.splitext(os.path.basename(file_path))[0] - mel = torch.load(file_path) - mel = torch.autograd.Variable(mel.cuda()) - mel = torch.unsqueeze(mel, 0) - mel = mel.half() if is_fp16 else mel + mel = torch.load(file_path,map_location=device) + mel = torch.autograd.Variable(mel) + mel = mel.half() with torch.no_grad(): audio = squeezewave.infer(mel, sigma=sigma).float() if denoiser_strength > 0: @@ -65,6 +67,9 @@ def main(mel_files, squeezewave_path, sigma, output_dir, sampling_rate, is_fp16, output_dir, "{}_synthesis.wav".format(file_name)) write(audio_path, sampling_rate, audio) print(audio_path) + end = time.time() + print("Squeezewave vocoder time") + print(end-start) if __name__ == "__main__":