diff --git a/stft.py b/stft.py index 03cd82f..edfc44a 100644 --- a/stft.py +++ b/stft.py @@ -124,6 +124,7 @@ class STFT(torch.nn.Module): np.where(window_sum > tiny(window_sum))[0]) window_sum = torch.autograd.Variable( torch.from_numpy(window_sum), requires_grad=False) + window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] # scale by hop ratio