From 32b9a135d06b0a2bd00624d6cd014b7b392271bf Mon Sep 17 00:00:00 2001 From: rafaelvalle Date: Sun, 25 Nov 2018 22:34:38 -0800 Subject: [PATCH] utils.py: updating --- utils.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/utils.py b/utils.py index 633ecff..c843d95 100644 --- a/utils.py +++ b/utils.py @@ -4,29 +4,26 @@ import torch def get_mask_from_lengths(lengths): - max_len = torch.max(lengths) - ids = torch.arange(0, max_len).long().cuda() + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) mask = (ids < lengths.unsqueeze(1)).byte() return mask -def load_wav_to_torch(full_path, sr): +def load_wav_to_torch(full_path): sampling_rate, data = read(full_path) - assert sr == sampling_rate, "{} SR doesn't match {} on path {}".format( - sr, sampling_rate, full_path) - return torch.FloatTensor(data.astype(np.float32)) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate -def load_filepaths_and_text(filename, sort_by_length, split="|"): +def load_filepaths_and_text(filename, split="|"): with open(filename, encoding='utf-8') as f: filepaths_and_text = [line.strip().split(split) for line in f] - - if sort_by_length: - filepaths_and_text.sort(key=lambda x: len(x[1])) - return filepaths_and_text def to_gpu(x): - x = x.contiguous().cuda(async=True) + x = x.contiguous() + + if torch.cuda.is_available(): + x = x.cuda(non_blocking=True) return torch.autograd.Variable(x)