Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.
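For context, "denoising" in WaveGlow-style vocoders such as SqueezeWave usually refers to bias removal: the vocoder is run on an all-zero mel spectrogram to capture its constant bias noise, and a scaled copy of that noise's magnitude spectrum is subtracted from generated audio. The sketch below is a hypothetical illustration of that technique, not this fork's actual patch; the `vocoder.infer` and `stft` interfaces are assumptions.

import torch

class Denoiser(torch.nn.Module):
    """Removes a vocoder's constant bias noise via spectral subtraction.

    Hypothetical sketch: `vocoder` is assumed to expose infer(mel) -> audio,
    and `stft` an STFT helper with transform()/inverse() methods.
    """

    def __init__(self, vocoder, stft, n_mels=80, bias_frames=88):
        super().__init__()
        self.stft = stft
        with torch.no_grad():
            # Synthesize from an all-zero mel: the output is pure model bias.
            zero_mel = torch.zeros(1, n_mels, bias_frames)
            bias_audio = vocoder.infer(zero_mel).float()
            bias_spec, _ = self.stft.transform(bias_audio)
        # Keep one frame of the bias magnitude spectrum as the template.
        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.01):
        spec, angles = self.stft.transform(audio.float())
        # Subtract the scaled bias magnitude, clamping at zero, then invert
        # using the original phase.
        spec_denoised = torch.clamp(spec - self.bias_spec * strength, min=0.0)
        return self.stft.inverse(spec_denoised, angles)

Below is the fork's FastSpeech acoustic model, which produces the mel spectrograms that SqueezeWave then vocodes: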
import torch
import torch.nn as nn

from transformer.Models import Encoder, Decoder
from transformer.Layers import Linear, PostNet
from modules import LengthRegulator
import hparams as hp

# Run on GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class FastSpeech(nn.Module):
    """ FastSpeech """

    def __init__(self):
        super(FastSpeech, self).__init__()

        self.encoder = Encoder()
        self.length_regulator = LengthRegulator()
        self.decoder = Decoder()

        # Project decoder states to mel bins, then refine with a residual PostNet.
        self.mel_linear = Linear(hp.decoder_output_size, hp.num_mels)
        self.postnet = PostNet()
    def forward(self, src_seq, src_pos, mel_pos=None, mel_max_length=None,
                length_target=None, alpha=1.0):
        encoder_output, _ = self.encoder(src_seq, src_pos)

        if self.training:
            # Training: expand encoder states with the ground-truth durations
            # and also return the duration predictor's output for its loss term.
            length_regulator_output, duration_predictor_output = self.length_regulator(
                encoder_output,
                target=length_target,
                alpha=alpha,
                mel_max_length=mel_max_length)
            decoder_output = self.decoder(length_regulator_output, mel_pos)

            mel_output = self.mel_linear(decoder_output)
            mel_output_postnet = self.postnet(mel_output) + mel_output

            return mel_output, mel_output_postnet, duration_predictor_output
        else:
            # Inference: expand with predicted durations; alpha rescales them
            # to control speaking rate.
            length_regulator_output, decoder_pos = self.length_regulator(
                encoder_output, alpha=alpha)
            decoder_output = self.decoder(length_regulator_output, decoder_pos)

            mel_output = self.mel_linear(decoder_output)
            mel_output_postnet = self.postnet(mel_output) + mel_output

            return mel_output, mel_output_postnet
if __name__ == "__main__":
    # Quick sanity check: build the model and report its parameter count.
    model = FastSpeech()
    print(sum(param.numel() for param in model.parameters()))
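For reference, the inference branch can be exercised roughly as follows. This is a hedged sketch, assuming the repo's hparams and transformer modules are configured; the token IDs are dummies, not a real phoneme sequence:

model = FastSpeech().to(device)
model.eval()

with torch.no_grad():
    src_seq = torch.randint(1, 100, (1, 12)).to(device)    # dummy phoneme IDs
    src_pos = torch.arange(1, 13).unsqueeze(0).to(device)  # 1-based positions (0 is padding)
    # alpha rescales the predicted durations to change speaking rate.
    mel, mel_postnet = model(src_seq, src_pos, alpha=1.0)

The post-net output (`mel_postnet`) is the one normally passed to the SqueezeWave vocoder, since it includes the residual refinement on top of the raw linear projection.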