Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import math
import os

import hparams
import audio as Audio
from text import text_to_sequence
from utils import process_text, pad_1D, pad_2D

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class FastSpeechDataset(Dataset):
    """ LJSpeech """

    def __init__(self):
        self.text = process_text(os.path.join("data", "train.txt"))

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        # Load the precomputed ground-truth mel-spectrogram and the
        # duration (alignment) array for this utterance.
        mel_gt_name = os.path.join(
            hparams.mel_ground_truth, "ljspeech-mel-%05d.npy" % (idx+1))
        mel_gt_target = np.load(mel_gt_name)
        D = np.load(os.path.join(hparams.alignment_path, str(idx)+".npy"))

        # Strip the trailing newline and convert the text into a sequence of ids.
        character = self.text[idx][0:len(self.text[idx])-1]
        character = np.array(text_to_sequence(
            character, hparams.text_cleaners))

        sample = {"text": character,
                  "mel_target": mel_gt_target,
                  "D": D}

        return sample
def reprocess(batch, cut_list):
    # Build one padded sub-batch from the samples selected by cut_list.
    texts = [batch[ind]["text"] for ind in cut_list]
    mel_targets = [batch[ind]["mel_target"] for ind in cut_list]
    Ds = [batch[ind]["D"] for ind in cut_list]

    # 1-based position indices for the text, zero-padded to the longest text.
    length_text = np.array([])
    for text in texts:
        length_text = np.append(length_text, text.shape[0])

    src_pos = list()
    max_len = int(max(length_text))
    for length_src_row in length_text:
        src_pos.append(np.pad([i+1 for i in range(int(length_src_row))],
                              (0, max_len-int(length_src_row)), 'constant'))
    src_pos = np.array(src_pos)

    # 1-based position indices for the mel frames, zero-padded to the longest mel.
    length_mel = np.array(list())
    for mel in mel_targets:
        length_mel = np.append(length_mel, mel.shape[0])

    mel_pos = list()
    max_mel_len = int(max(length_mel))
    for length_mel_row in length_mel:
        mel_pos.append(np.pad([i+1 for i in range(int(length_mel_row))],
                              (0, max_mel_len-int(length_mel_row)), 'constant'))
    mel_pos = np.array(mel_pos)

    # Pad the variable-length arrays themselves.
    texts = pad_1D(texts)
    Ds = pad_1D(Ds)
    mel_targets = pad_2D(mel_targets)

    out = {"text": texts,
           "mel_target": mel_targets,
           "D": Ds,
           "mel_pos": mel_pos,
           "src_pos": src_pos,
           "mel_max_len": max_mel_len}

    return out
def collate_fn(batch):
    # Sort samples by descending text length, then split the loaded batch into
    # sqrt(batchsize) sub-batches so each sub-batch groups texts of similar
    # length (less padding). The training loop therefore iterates over a list
    # of sub-batches rather than a single batch.
    len_arr = np.array([d["text"].shape[0] for d in batch])
    index_arr = np.argsort(-len_arr)
    batchsize = len(batch)
    real_batchsize = int(math.sqrt(batchsize))

    cut_list = list()
    for i in range(real_batchsize):
        cut_list.append(index_arr[i*real_batchsize:(i+1)*real_batchsize])

    output = list()
    for i in range(real_batchsize):
        output.append(reprocess(batch, cut_list[i]))

    return output
if __name__ == "__main__":
    # Test: count how many targets have a mel length equal to the sum of
    # their durations D (the alignment consistency check).
    dataset = FastSpeechDataset()
    training_loader = DataLoader(dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 collate_fn=collate_fn,
                                 drop_last=True,
                                 num_workers=0)
    total_step = hparams.epochs * len(training_loader) * hparams.batch_size

    cnt = 0
    for i, batchs in enumerate(training_loader):
        for j, data_of_batch in enumerate(batchs):
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            # print(mel_target.size())
            # print(D.sum())
            print(cnt)
            if mel_target.size(1) == D.sum().item():
                cnt += 1

    print(cnt)
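
For reference, below is a minimal sketch (not part of the repository) of how collate_fn behaves. It feeds synthetic samples straight into collate_fn, bypassing FastSpeechDataset, so no LJSpeech mels or alignments are needed on disk. It assumes the file above is importable as `dataset` (a hypothetical module name), that it is run from the repository root so the module's own imports (hparams, audio, text, utils) resolve, and that mel-spectrograms have 80 channels. With a loader batch of 16, collate_fn should return sqrt(16) = 4 padded sub-batches.

# Sketch only: exercise collate_fn on synthetic samples and inspect the
# padded sub-batch shapes. The module name `dataset` is an assumption.
import numpy as np
import dataset

rng = np.random.default_rng(0)
batch = []
for _ in range(16):                     # loader batch of 16 -> sqrt(16) = 4 sub-batches
    n_text = int(rng.integers(5, 20))   # fake phoneme-id sequence length
    durations = rng.integers(1, 6, size=n_text)
    n_mel = int(durations.sum())        # mel length consistent with the durations
    batch.append({"text": rng.integers(1, 100, size=n_text),
                  "mel_target": rng.random((n_mel, 80)),  # 80 mel channels assumed
                  "D": durations})

for sub in dataset.collate_fn(batch):
    print(sub["text"].shape, sub["mel_target"].shape,
          sub["src_pos"].shape, sub["mel_pos"].shape, sub["mel_max_len"])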