Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

29 lines
1016 B

  1. import torch
  2. import torch.nn as nn
  3. class FastSpeechLoss(nn.Module):
  4. """ FastSPeech Loss """
  5. def __init__(self):
  6. super(FastSpeechLoss, self).__init__()
  7. self.mse_loss = nn.MSELoss()
  8. self.l1_loss = nn.L1Loss()
  9. def forward(self, mel, mel_postnet, duration_predicted, mel_target, duration_predictor_target):
  10. mel_target.requires_grad = False
  11. mel_loss = self.mse_loss(mel, mel_target)
  12. mel_postnet_loss = self.mse_loss(mel_postnet, mel_target)
  13. duration_predictor_target.requires_grad = False
  14. # duration_predictor_target = duration_predictor_target + 1
  15. # duration_predictor_target = torch.log(
  16. # duration_predictor_target.float())
  17. # print(duration_predictor_target)
  18. # print(duration_predicted)
  19. duration_predictor_loss = self.l1_loss(
  20. duration_predicted, duration_predictor_target.float())
  21. return mel_loss, mel_postnet_loss, duration_predictor_loss