Fork of https://github.com/alokprasad/fastspeech_squeezewave to also fix denoising in squeezewave
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

29 lines
1016 B

import torch
import torch.nn as nn
class FastSpeechLoss(nn.Module):
    """FastSpeech training loss.

    Bundles three independent terms and returns them separately so the
    caller decides how to weight / sum them:

    * MSE between ``mel`` and ``mel_target``
    * MSE between ``mel_postnet`` and ``mel_target``
    * L1 between ``duration_predicted`` and ``duration_predictor_target``
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def forward(
        self,
        mel: torch.Tensor,
        mel_postnet: torch.Tensor,
        duration_predicted: torch.Tensor,
        mel_target: torch.Tensor,
        duration_predictor_target: torch.Tensor,
    ) -> tuple:
        """Compute the three loss terms.

        Args:
            mel: Predicted mel output, compared against ``mel_target``
                (presumably the pre-postnet prediction -- confirm with caller).
            mel_postnet: Second mel prediction, compared against the same
                ``mel_target`` (presumably the postnet output -- confirm).
            duration_predicted: Predicted durations.
            mel_target: Ground-truth mel; treated as a constant.
            duration_predictor_target: Ground-truth durations; treated as a
                constant and cast to float before the L1 comparison.

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, duration_predictor_loss)``
            of scalar tensors (default ``mean`` reduction).
        """
        # Detach instead of assigning ``requires_grad = False``: assignment
        # mutates the caller's tensor and raises RuntimeError on non-leaf
        # tensors, whereas detach() blocks gradients without side effects.
        mel_reference = mel_target.detach()
        duration_reference = duration_predictor_target.detach().float()

        mel_loss = self.mse_loss(mel, mel_reference)
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_reference)

        # The duration target is used as given (no log(d + 1) transform is
        # applied here); it is only cast to float for the L1 comparison.
        duration_predictor_loss = self.l1_loss(
            duration_predicted, duration_reference
        )

        return mel_loss, mel_postnet_loss, duration_predictor_loss