import torch
import torch.nn as nn


class FastSpeechLoss(nn.Module):
    """FastSpeech training loss.

    Computes three terms and returns them separately so the caller can
    weight and/or sum them:

    * MSE between the mel output before the post-net and the target mel,
    * MSE between the post-net mel output and the target mel,
    * L1 between the predicted and the target durations.
    """

    def __init__(self):
        super().__init__()
        # Default reduction ("mean") for both criteria -> scalar losses.
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def forward(self, mel, mel_postnet, duration_predicted, mel_target,
                duration_predictor_target):
        """Compute the three FastSpeech loss terms.

        Args:
            mel: Mel output before the post-net; must match ``mel_target``'s
                shape (presumably (batch, frames, n_mels) -- TODO confirm).
            mel_postnet: Mel output after the post-net; same shape as ``mel``.
            duration_predicted: Output of the duration predictor.
            mel_target: Ground-truth mel spectrogram.
            duration_predictor_target: Ground-truth durations. May be an
                integer tensor; it is cast to float before the L1 loss.

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, duration_predictor_loss)``
            of scalar tensors.
        """
        # Detach instead of assigning ``requires_grad = False`` in place:
        # the in-place form mutated the caller's tensors and raises a
        # RuntimeError when a target is a non-leaf tensor. Detaching gives
        # the same "no gradient flows into the targets" behaviour.
        mel_target = mel_target.detach()
        duration_predictor_target = duration_predictor_target.detach()

        mel_loss = self.mse_loss(mel, mel_target)
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_target)

        # Durations are compared in the linear domain (no log(d + 1)
        # transform is applied), so ``duration_predicted`` must be on the
        # same scale as the raw target durations.
        duration_predictor_loss = self.l1_loss(
            duration_predicted, duration_predictor_target.float())

        return mel_loss, mel_postnet_loss, duration_predictor_loss