import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import math
import os

import hparams
import audio as Audio
from text import text_to_sequence
from utils import process_text, pad_1D, pad_2D

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class FastSpeechDataset(Dataset):
    """ LJSpeech """

    def __init__(self):
        self.text = process_text(os.path.join("data", "train.txt"))

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        # Pre-computed ground-truth mel-spectrogram for this utterance.
        mel_gt_name = os.path.join(
            hparams.mel_ground_truth, "ljspeech-mel-%05d.npy" % (idx+1))
        mel_gt_target = np.load(mel_gt_name)
        # Per-phoneme duration (alignment) targets.
        D = np.load(os.path.join(hparams.alignment_path, str(idx)+".npy"))

        # Strip the trailing newline, then convert the text to a phoneme ID sequence.
        character = self.text[idx][0:len(self.text[idx])-1]
        character = np.array(text_to_sequence(
            character, hparams.text_cleaners))

        sample = {"text": character,
                  "mel_target": mel_gt_target,
                  "D": D}

        return sample


def reprocess(batch, cut_list):
    """Build one padded sub-batch from the samples selected by cut_list."""
    texts = [batch[ind]["text"] for ind in cut_list]
    mel_targets = [batch[ind]["mel_target"] for ind in cut_list]
    Ds = [batch[ind]["D"] for ind in cut_list]

    # Positional indices (1..length) for the text, zero-padded to the longest sequence.
    length_text = np.array([])
    for text in texts:
        length_text = np.append(length_text, text.shape[0])

    src_pos = list()
    max_len = int(max(length_text))
    for length_src_row in length_text:
        src_pos.append(np.pad([i+1 for i in range(int(length_src_row))],
                              (0, max_len-int(length_src_row)), 'constant'))
    src_pos = np.array(src_pos)

    # Positional indices for the mel frames, zero-padded to the longest spectrogram.
    length_mel = np.array([])
    for mel in mel_targets:
        length_mel = np.append(length_mel, mel.shape[0])

    mel_pos = list()
    max_mel_len = int(max(length_mel))
    for length_mel_row in length_mel:
        mel_pos.append(np.pad([i+1 for i in range(int(length_mel_row))],
                              (0, max_mel_len-int(length_mel_row)), 'constant'))
    mel_pos = np.array(mel_pos)

    texts = pad_1D(texts)
    Ds = pad_1D(Ds)
    mel_targets = pad_2D(mel_targets)

    out = {"text": texts,
           "mel_target": mel_targets,
           "D": Ds,
           "mel_pos": mel_pos,
           "src_pos": src_pos,
           "mel_max_len": max_mel_len}

    return out


def collate_fn(batch):
    """Sort a DataLoader batch by text length (descending) and split it into
    sqrt(len(batch)) sub-batches of sqrt(len(batch)) samples each, so that
    utterances of similar length are padded together."""
    len_arr = np.array([d["text"].shape[0] for d in batch])
    index_arr = np.argsort(-len_arr)
    batchsize = len(batch)
    real_batchsize = int(math.sqrt(batchsize))

    cut_list = list()
    for i in range(real_batchsize):
        cut_list.append(index_arr[i*real_batchsize:(i+1)*real_batchsize])

    output = list()
    for i in range(real_batchsize):
        output.append(reprocess(batch, cut_list[i]))

    return output


if __name__ == "__main__":
    # Test: count how many samples have a mel-spectrogram whose frame count
    # matches the sum of the duration targets D.
    dataset = FastSpeechDataset()
    training_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 collate_fn=collate_fn,
                                 drop_last=True,
                                 num_workers=0)
    total_step = hparams.epochs * len(training_loader) * hparams.batch_size

    cnt = 0
    for i, batchs in enumerate(training_loader):
        for j, data_of_batch in enumerate(batchs):
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            # print(mel_target.size())
            # print(D.sum())
            print(cnt)
            if mel_target.size(1) == D.sum().item():
                cnt += 1
    print(cnt)
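

# --- Illustrative sketch (editor addition, hypothetical, not part of the
# original training pipeline). It demonstrates how collate_fn groups a
# DataLoader batch: samples are sorted by text length and split into
# sqrt(batch_size) sub-batches of sqrt(batch_size) samples, so similarly
# long utterances are padded together. The dict keys mirror
# FastSpeechDataset.__getitem__, but the lengths and values below are
# invented purely for demonstration.
def _demo_collate_grouping():
    rng = np.random.default_rng(0)
    toy_batch = []
    for n in (5, 12, 7, 9):                      # toy phoneme-sequence lengths
        toy_batch.append({
            "text": rng.integers(1, 100, size=n),
            "mel_target": rng.standard_normal((n * 4, 80)).astype(np.float32),
            "D": np.full(n, 4, dtype=np.int64),  # each toy phoneme lasts 4 frames
        })
    # With 4 samples, collate_fn returns 2 sub-batches of 2 padded samples each.
    for k, sub in enumerate(collate_fn(toy_batch)):
        print("sub-batch", k, sub["text"].shape,
              sub["mel_target"].shape, "mel_max_len =", sub["mel_max_len"])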