Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.
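For context, the denoising referred to above is, in WaveGlow-style vocoders such as SqueezeWave, usually spectral subtraction of a fixed "bias" spectrum that the vocoder emits even for silent input. The sketch below only illustrates that idea; the `vocoder` object, its `infer()` call, and every shape and constant here are assumptions, not code from this repository.

import torch

def make_bias_spec(vocoder, n_mels=80, n_fft=1024, hop=256, device='cpu'):
    # Run the vocoder on an all-zero mel spectrogram; whatever audio comes out
    # is the model's constant "bias" noise.
    with torch.no_grad():
        zero_mel = torch.zeros(1, n_mels, 88, device=device)  # 88 frames is arbitrary
        bias_audio = vocoder.infer(zero_mel)                   # assumed vocoder API
    window = torch.hann_window(n_fft, device=device)
    bias_spec = torch.stft(bias_audio, n_fft, hop_length=hop,
                           window=window, return_complex=True).abs()
    return bias_spec[:, :, 0:1], window  # keep one frame of the bias magnitude

def denoise(audio, bias_spec, window, strength=0.01, n_fft=1024, hop=256):
    # Subtract a scaled copy of the bias magnitude from the audio's STFT
    # magnitude, keep the original phase, and resynthesize.
    spec = torch.stft(audio, n_fft, hop_length=hop, window=window, return_complex=True)
    mag, phase = spec.abs(), spec.angle()
    mag = (mag - strength * bias_spec).clamp(min=0.0)
    return torch.istft(mag * torch.exp(1j * phase), n_fft, hop_length=hop, window=window)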

import torch
import torch.nn as nn

from transformer.Models import Encoder, Decoder
from transformer.Layers import Linear, PostNet
from modules import LengthRegulator
import hparams as hp

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class FastSpeech(nn.Module):
    """ FastSpeech """

    def __init__(self):
        super(FastSpeech, self).__init__()

        # Phoneme encoder -> duration-based length regulator -> mel decoder
        self.encoder = Encoder()
        self.length_regulator = LengthRegulator()
        self.decoder = Decoder()

        # Project decoder output to mel bins; PostNet adds a residual refinement
        self.mel_linear = Linear(hp.decoder_output_size, hp.num_mels)
        self.postnet = PostNet()

    def forward(self, src_seq, src_pos, mel_pos=None, mel_max_length=None,
                length_target=None, alpha=1.0):
        encoder_output, _ = self.encoder(src_seq, src_pos)

        if self.training:
            # Training: expand the encoder output with ground-truth durations and
            # also return the duration predictor output for its loss term.
            length_regulator_output, duration_predictor_output = self.length_regulator(
                encoder_output,
                target=length_target,
                alpha=alpha,
                mel_max_length=mel_max_length)
            decoder_output = self.decoder(length_regulator_output, mel_pos)

            mel_output = self.mel_linear(decoder_output)
            mel_output_postnet = self.postnet(mel_output) + mel_output

            return mel_output, mel_output_postnet, duration_predictor_output
        else:
            # Inference: expand with predicted durations, scaled by alpha.
            length_regulator_output, decoder_pos = self.length_regulator(
                encoder_output, alpha=alpha)
            decoder_output = self.decoder(length_regulator_output, decoder_pos)

            mel_output = self.mel_linear(decoder_output)
            mel_output_postnet = self.postnet(mel_output) + mel_output

            return mel_output, mel_output_postnet


if __name__ == "__main__":
    # Quick sanity check: build the model and print its parameter count.
    model = FastSpeech()
    print(sum(param.numel() for param in model.parameters()))
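
A minimal usage sketch for the inference branch of forward() above. It assumes the file is importable as `fastspeech` (hypothetical module name), that the dummy phoneme IDs fall within hp.vocab_size, and that the encoder expects 1-based position indices; none of this is confirmed by the listing itself.

import torch
from fastspeech import FastSpeech  # hypothetical module name for the file above

model = FastSpeech().eval()  # eval() selects the inference branch of forward()

src_seq = torch.LongTensor([[12, 47, 5, 33]])  # dummy phoneme IDs (assumed < hp.vocab_size)
src_pos = torch.LongTensor([[1, 2, 3, 4]])     # assumed 1-based position convention

with torch.no_grad():
    mel, mel_postnet = model(src_seq, src_pos, alpha=1.0)
print(mel_postnet.shape)  # (batch, expanded_length, hp.num_mels)

Since alpha multiplies the predicted durations, values above 1.0 stretch the output (slower speech) and values below 1.0 compress it.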