import torch
import torch.nn as nn

from transformer.Models import Encoder, Decoder
from transformer.Layers import Linear, PostNet
from modules import LengthRegulator
import hparams as hp

# Prefer GPU when one is visible; fall back to CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class FastSpeech(nn.Module):
    """FastSpeech: non-autoregressive text-to-mel-spectrogram model.

    Pipeline: phoneme encoder -> length regulator (duration-based
    upsampling of encoder frames) -> decoder -> linear projection to
    mel channels -> PostNet residual refinement.
    """

    def __init__(self):
        super(FastSpeech, self).__init__()
        self.encoder = Encoder()
        self.length_regulator = LengthRegulator()
        self.decoder = Decoder()
        # Maps decoder hidden states onto mel-spectrogram bins.
        self.mel_linear = Linear(hp.decoder_output_size, hp.num_mels)
        self.postnet = PostNet()

    def forward(self, src_seq, src_pos, mel_pos=None, mel_max_length=None,
                length_target=None, alpha=1.0):
        """Run the full text-to-mel pipeline.

        Args:
            src_seq: source phoneme/token id sequence.
            src_pos: positional indices for the source sequence.
            mel_pos: mel-frame positional indices (training only).
            mel_max_length: cap on the regulated length (training only).
            length_target: ground-truth durations (training only).
            alpha: speed-control factor for the length regulator.

        Returns:
            Training (``self.training``): ``(mel, mel_postnet,
            duration_predictor_output)`` — the third term feeds the
            duration loss. Inference: ``(mel, mel_postnet)``.
        """
        enc_out, _ = self.encoder(src_seq, src_pos)

        if not self.training:
            # Inference path: durations come from the predictor, and the
            # regulator also builds the decoder position indices itself.
            reg_out, dec_pos = self.length_regulator(enc_out, alpha=alpha)
            return self._decode(reg_out, dec_pos)

        # Training path: expand with ground-truth durations and keep the
        # duration predictor's raw output for its loss term.
        reg_out, dur_pred = self.length_regulator(enc_out,
                                                  target=length_target,
                                                  alpha=alpha,
                                                  mel_max_length=mel_max_length)
        mel, mel_postnet = self._decode(reg_out, mel_pos)
        return mel, mel_postnet, dur_pred

    def _decode(self, reg_out, pos):
        """Decode regulated features into (mel, PostNet-refined mel)."""
        dec_out = self.decoder(reg_out, pos)
        mel = self.mel_linear(dec_out)
        # PostNet predicts a residual correction on top of the raw mel.
        return mel, self.postnet(mel) + mel


if __name__ == "__main__":
    # Quick sanity check: instantiate and report total parameter count.
    model = FastSpeech()
    total = sum(p.numel() for p in model.parameters())
    print(total)