import torch
|
|
import torch.nn as nn
|
|
|
|
from transformer.Models import Encoder, Decoder
|
|
from transformer.Layers import Linear, PostNet
|
|
from modules import LengthRegulator
|
|
import hparams as hp
|
|
|
|
# Default compute device for this module: prefer CUDA when a GPU is visible
# to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
|
|
|
|
class FastSpeech(nn.Module):
    """FastSpeech text-to-mel model.

    The modules are wired in this order by ``forward``: ``Encoder`` ->
    ``LengthRegulator`` -> ``Decoder`` -> ``mel_linear`` projection ->
    ``PostNet``. The PostNet output is added back onto the projection.
    The sizes of the mel projection come from the project's ``hparams``
    module (``hp.decoder_output_size``, ``hp.num_mels``).
    """

    def __init__(self):
        super(FastSpeech, self).__init__()

        self.encoder = Encoder()
        self.length_regulator = LengthRegulator()
        self.decoder = Decoder()

        # Maps decoder features to mel bins. This assumes Linear(in, out) has
        # the same argument order as nn.Linear (TODO confirm in transformer.Layers).
        self.mel_linear = Linear(hp.decoder_output_size, hp.num_mels)
        self.postnet = PostNet()

    def _mel_head(self, decoder_output):
        """Project decoder output to mels and apply the PostNet residual.

        Both the training and inference paths of ``forward`` call this, so
        the output head cannot diverge between the two modes.

        Args:
            decoder_output: output of ``self.decoder``.

        Returns:
            A tuple ``(mel_output, mel_output_postnet)``, where
            ``mel_output_postnet = self.postnet(mel_output) + mel_output``.
        """
        mel_output = self.mel_linear(decoder_output)
        # Residual connection: the PostNet output is added to the projection.
        mel_output_postnet = self.postnet(mel_output) + mel_output
        return mel_output, mel_output_postnet

    def forward(self, src_seq, src_pos, mel_pos=None, mel_max_length=None, length_target=None, alpha=1.0):
        """Run the model. The branch taken depends on ``self.training``.

        Args:
            src_seq: source sequence passed to the encoder. Presumably
                padded token IDs of shape (batch, src_len); TODO confirm.
            src_pos: source positions passed to the encoder with ``src_seq``.
            mel_pos: decoder positions. Used only in training mode.
            mel_max_length: forwarded to the length regulator in training
                mode only.
            length_target: forwarded as ``target`` to the length regulator
                in training mode only. Presumably ground-truth durations;
                verify against ``LengthRegulator``.
            alpha: forwarded to the length regulator in both modes.
                Presumably a duration/speed scale; confirm in
                ``LengthRegulator``.

        Returns:
            In training mode: ``(mel_output, mel_output_postnet,
            duration_predictor_output)``.
            In eval mode: ``(mel_output, mel_output_postnet)``. Here the
            decoder positions are produced by the length regulator itself.
        """
        encoder_output, _ = self.encoder(src_seq, src_pos)

        if self.training:
            length_regulator_output, duration_predictor_output = self.length_regulator(
                encoder_output,
                target=length_target,
                alpha=alpha,
                mel_max_length=mel_max_length)
            decoder_output = self.decoder(length_regulator_output, mel_pos)
            mel_output, mel_output_postnet = self._mel_head(decoder_output)
            return mel_output, mel_output_postnet, duration_predictor_output

        # Inference: no target is given, and the length regulator returns the
        # decoder positions alongside its output.
        length_regulator_output, decoder_pos = self.length_regulator(encoder_output, alpha=alpha)
        decoder_output = self.decoder(length_regulator_output, decoder_pos)
        return self._mel_head(decoder_output)
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model and print its total parameter count.
    fastspeech = FastSpeech()
    total_params = 0
    for weight in fastspeech.parameters():
        total_params += weight.numel()
    print(total_params)
|