|
|
- import torch
- import torch.nn as nn
-
-
class FastSpeechLoss(nn.Module):
    """FastSpeech training loss.

    Combines three terms, each returned separately so the caller can weight
    or log them individually:

    * MSE between the predicted mel output and the mel target,
    * MSE between the post-net mel output and the same mel target,
    * L1 between the predicted durations and the duration target.
    """

    def __init__(self):
        super().__init__()
        # Both criteria use PyTorch's default reduction ("mean").
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def forward(self, mel, mel_postnet, duration_predicted, mel_target,
                duration_predictor_target):
        """Compute the three loss terms.

        Args:
            mel: Predicted mel output (before the post-net).
            mel_postnet: Predicted mel output after the post-net.
            duration_predicted: Output of the duration predictor.
            mel_target: Ground-truth mel; compared against both ``mel`` and
                ``mel_postnet``.
            duration_predictor_target: Ground-truth durations; cast to float
                before comparison, so integer tensors are accepted.

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, duration_predictor_loss)``
            of mean-reduced loss tensors.
        """
        # Detach the targets instead of assigning ``requires_grad = False``:
        # that assignment raises RuntimeError for non-leaf tensors and
        # silently mutates the caller's tensor for leaf ones. ``detach()``
        # blocks gradient flow into the targets without either side effect.
        mel_reference = mel_target.detach()
        duration_reference = duration_predictor_target.detach().float()

        mel_loss = self.mse_loss(mel, mel_reference)
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_reference)

        # NOTE: durations are compared as given (no log transform is applied
        # to the target here).
        duration_predictor_loss = self.l1_loss(
            duration_predicted, duration_reference)

        return mel_loss, mel_postnet_loss, duration_predictor_loss
|