Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

124 lines
3.7 KiB

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import os
import hparams
import audio as Audio
from text import text_to_sequence
from utils import process_text, pad_1D, pad_2D
# Use CUDA when a GPU is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class FastSpeechDataset(Dataset):
    """LJSpeech training set: token ids, ground-truth mels and alignment targets.

    Sample ``idx`` pairs entry ``idx`` of ``data/train.txt`` with the mel file
    ``ljspeech-mel-%05d.npy`` numbered ``idx + 1`` (1-based) under
    ``hparams.mel_ground_truth`` and the file ``<idx>.npy`` (0-based) under
    ``hparams.alignment_path``.
    """

    def __init__(self):
        # Transcripts are read once up front; mels and alignments are loaded
        # lazily in __getitem__.
        self.text = process_text(os.path.join("data", "train.txt"))

    def __len__(self):
        return len(self.text)

    @staticmethod
    def _strip_line_terminator(line):
        """Return ``line`` with one trailing newline character removed, if present.

        Only a real newline is removed, so an entry that has none (e.g. the
        final line of a file without a trailing newline) keeps its last
        character.  NOTE(review): assumes ``process_text`` returns lines as
        read from the file, newline included -- confirm in utils.
        """
        return line[:-1] if line.endswith("\n") else line

    def __getitem__(self, idx):
        """Load sample ``idx`` as a dict with keys ``"text"``, ``"mel_target"``, ``"D"``.

        ``"text"`` is the numpy array of ids produced by ``text_to_sequence``
        with ``hparams.text_cleaners``; ``"mel_target"`` and ``"D"`` are the
        arrays loaded from the ``.npy`` files described on the class.
        """
        # Mel files are numbered from 1, alignment files from 0.
        mel_gt_name = os.path.join(
            hparams.mel_ground_truth, "ljspeech-mel-%05d.npy" % (idx+1))
        mel_gt_target = np.load(mel_gt_name)
        D = np.load(os.path.join(hparams.alignment_path, str(idx)+".npy"))

        character = self._strip_line_terminator(self.text[idx])
        character = np.array(text_to_sequence(
            character, hparams.text_cleaners))

        sample = {"text": character,
                  "mel_target": mel_gt_target,
                  "D": D}
        return sample
def _pad_positions(lengths):
    """Return a ``(len(lengths), max(lengths))`` int array of 1-based positions.

    Row ``r`` holds ``1, 2, ..., lengths[r]`` followed by zeros as padding.
    """
    positions = np.zeros((len(lengths), max(lengths)), dtype=int)
    for row, length in enumerate(lengths):
        positions[row, :length] = np.arange(1, length + 1)
    return positions


def reprocess(batch, cut_list):
    """Collate the samples ``batch[i]`` for ``i`` in ``cut_list`` into one padded sub-batch.

    Args:
        batch: sequence of sample dicts with keys ``"text"``, ``"mel_target"``
            and ``"D"`` (numpy arrays; lengths are read from axis 0).
        cut_list: indices into ``batch`` selecting the samples to collate.

    Returns:
        dict with ``"text"`` and ``"D"`` padded by ``pad_1D``, ``"mel_target"``
        padded by ``pad_2D``, ``"src_pos"`` / ``"mel_pos"`` 1-based position
        matrices (zeros mark padding), and ``"mel_max_len"``, the longest mel
        length as a Python int.

    Raises:
        ValueError: if ``cut_list`` is empty.
    """
    # len() rather than truthiness: cut_list is typically a numpy index array.
    if len(cut_list) == 0:
        raise ValueError("reprocess() needs at least one sample index")

    samples = [batch[ind] for ind in cut_list]
    texts = [sample["text"] for sample in samples]
    mel_targets = [sample["mel_target"] for sample in samples]
    Ds = [sample["D"] for sample in samples]

    text_lengths = [text.shape[0] for text in texts]
    mel_lengths = [mel.shape[0] for mel in mel_targets]

    out = {"text": pad_1D(texts),
           "mel_target": pad_2D(mel_targets),
           "D": pad_1D(Ds),
           "mel_pos": _pad_positions(mel_lengths),
           "src_pos": _pad_positions(text_lengths),
           "mel_max_len": max(mel_lengths)}
    return out
def collate_fn(batch):
len_arr = np.array([d["text"].shape[0] for d in batch])
index_arr = np.argsort(-len_arr)
batchsize = len(batch)
real_batchsize = int(math.sqrt(batchsize))
cut_list = list()
for i in range(real_batchsize):
cut_list.append(index_arr[i*real_batchsize:(i+1)*real_batchsize])
output = list()
for i in range(real_batchsize):
output.append(reprocess(batch, cut_list[i]))
return output
if __name__ == "__main__":
    # Sanity check: for every sub-batch, compare the padded mel length with
    # the sum of the alignment targets D and count how many agree.
    # With batch_size=1, collate_fn yields one single-sample sub-batch per
    # step, so the comparison is effectively per utterance.
    # NOTE(review): assumes axis 1 of the padded mel batch is time -- confirm
    # against pad_2D.
    dataset = FastSpeechDataset()
    training_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 collate_fn=collate_fn,
                                 drop_last=True,
                                 num_workers=0)
    cnt = 0
    checked = 0
    for batchs in training_loader:
        for data_of_batch in batchs:
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            # Progress: number of matches so far.
            print(cnt)
            checked += 1
            if mel_target.size(1) == D.sum().item():
                cnt += 1
    print("%d / %d sub-batches have mel length == sum(D)" % (cnt, checked))