Fork of https://github.com/alokprasad/fastspeech_squeezewave that also fixes denoising in SqueezeWave.
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np
import copy
import math
import hparams as hp
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
                               for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)
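

# Illustrative usage sketch (not in the original file): the table has one row
# per position and d_hid columns, with sin in the even and cos in the odd
# dimensions. The sizes below are arbitrary example values.
def _sinusoid_table_example():
    table = get_sinusoid_encoding_table(n_position=8, d_hid=4, padding_idx=0)
    assert table.shape == (8, 4)
    # padding_idx=0 zeroes the row reserved for padded positions
    assert torch.all(table[0] == 0)
    return table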


def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LengthRegulator(nn.Module):
    """ Length Regulator """

    def __init__(self):
        super(LengthRegulator, self).__init__()
        self.duration_predictor = DurationPredictor()

    def LR(self, x, duration_predictor_output, alpha=1.0, mel_max_length=None):
        output = list()

        for batch, expand_target in zip(x, duration_predictor_output):
            output.append(self.expand(batch, expand_target, alpha))

        if mel_max_length:
            output = utils.pad(output, mel_max_length)
        else:
            output = utils.pad(output)

        return output

    def expand(self, batch, predicted, alpha):
        out = list()

        for i, vec in enumerate(batch):
            expand_size = predicted[i].item()
            out.append(vec.expand(int(expand_size * alpha), -1))

        out = torch.cat(out, 0)

        return out

    def rounding(self, num):
        if num - int(num) >= 0.5:
            return int(num) + 1
        else:
            return int(num)

    def forward(self, x, alpha=1.0, target=None, mel_max_length=None):
        duration_predictor_output = self.duration_predictor(x)

        if self.training:
            output = self.LR(x, target, mel_max_length=mel_max_length)
            return output, duration_predictor_output
        else:
            for idx, ele in enumerate(duration_predictor_output[0]):
                duration_predictor_output[0][idx] = self.rounding(ele)

            output = self.LR(x, duration_predictor_output, alpha)
            mel_pos = torch.stack(
                [torch.Tensor([i + 1 for i in range(output.size(1))])]).long().to(device)

            return output, mel_pos
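

# Hedged sketch (not part of the original code): this mirrors what
# LengthRegulator.expand does for a single utterance, without constructing the
# hparams-dependent DurationPredictor. Shapes and durations are made up.
def _length_regulation_sketch():
    frames = torch.randn(4, 8)                  # (encoder steps, channels)
    durations = torch.tensor([1., 2., 3., 1.])  # per-step durations
    out = torch.cat(
        [frames[i].expand(int(durations[i].item()), -1)
         for i in range(frames.size(0))], 0)
    assert out.shape == (7, 8)                  # 1 + 2 + 3 + 1 expanded steps
    return out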


class DurationPredictor(nn.Module):
    """ Duration Predictor """

    def __init__(self):
        super(DurationPredictor, self).__init__()

        self.input_size = hp.d_model
        self.filter_size = hp.duration_predictor_filter_size
        self.kernel = hp.duration_predictor_kernel_size
        self.conv_output_size = hp.duration_predictor_filter_size
        self.dropout = hp.dropout

        self.conv_layer = nn.Sequential(OrderedDict([
            ("conv1d_1", Conv(self.input_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=1)),
            ("layer_norm_1", nn.LayerNorm(self.filter_size)),
            ("relu_1", nn.ReLU()),
            ("dropout_1", nn.Dropout(self.dropout)),
            ("conv1d_2", Conv(self.filter_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=1)),
            ("layer_norm_2", nn.LayerNorm(self.filter_size)),
            ("relu_2", nn.ReLU()),
            ("dropout_2", nn.Dropout(self.dropout))
        ]))

        self.linear_layer = Linear(self.conv_output_size, 1)
        self.relu = nn.ReLU()

    def forward(self, encoder_output):
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = self.relu(out)
        out = out.squeeze()

        if not self.training:
            out = out.unsqueeze(0)

        return out
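

# Hedged sketch (not in the original file): the predictor maps encoder output
# (batch, time, hp.d_model) to one non-negative duration per time step. It
# relies on hparams defining d_model, duration_predictor_filter_size,
# duration_predictor_kernel_size and dropout (which the constructor already
# requires), and assumes the usual kernel size of 3 so padding=1 preserves the
# time axis.
def _duration_predictor_example():
    predictor = DurationPredictor().train()
    encoder_output = torch.randn(2, 16, hp.d_model)
    durations = predictor(encoder_output)   # squeezed to (batch, time)
    assert durations.shape == (2, 16)
    assert bool((durations >= 0).all())     # final ReLU keeps durations >= 0
    return durations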


class Conv(nn.Module):
    """
    Convolution Module
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 w_init='linear'):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. weight inits with xavier initialization.
        """
        super(Conv, self).__init__()

        self.conv = nn.Conv1d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=bias)

        nn.init.xavier_uniform_(
            self.conv.weight, gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)

        return x
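

# Illustrative sketch (not in the original file): Conv expects channel-last
# input and transposes internally, so (batch, time, channels) goes in and the
# same layout comes out. Sizes below are arbitrary.
def _conv_shape_example():
    conv = Conv(in_channels=8, out_channels=16, kernel_size=3, padding=1)
    x = torch.randn(2, 50, 8)      # (batch, time, channels)
    y = conv(x)                    # transpose -> nn.Conv1d -> transpose back
    assert y.shape == (2, 50, 16)  # padding=1 with kernel 3 keeps the length
    return y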


class Linear(nn.Module):
    """
    Linear Module
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
        """
        :param in_dim: dimension of input
        :param out_dim: dimension of output
        :param bias: boolean. if True, bias is included.
        :param w_init: str. weight inits with xavier initialization.
        """
        super(Linear, self).__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)

        nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        return self.linear_layer(x)


class FFN(nn.Module):
    """
    Positionwise Feed-Forward Network
    """

    def __init__(self, num_hidden):
        """
        :param num_hidden: dimension of hidden
        """
        super(FFN, self).__init__()
        self.w_1 = Conv(num_hidden, num_hidden * 4,
                        kernel_size=3, padding=1, w_init='relu')
        self.w_2 = Conv(num_hidden * 4, num_hidden, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=0.1)
        self.layer_norm = nn.LayerNorm(num_hidden)

    def forward(self, input_):
        # FFN Network
        x = input_
        x = self.w_2(torch.relu(self.w_1(x)))

        # residual connection
        x = x + input_

        # dropout
        x = self.dropout(x)

        # layer normalization
        x = self.layer_norm(x)

        return x
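

# Illustrative sketch (not in the original file): FFN is shape-preserving, so
# it can be stacked after attention inside the FFT block. Sizes are arbitrary.
def _ffn_shape_example():
    ffn = FFN(num_hidden=32)
    x = torch.randn(2, 20, 32)     # (batch, time, hidden)
    y = ffn(x)                     # conv-relu-conv + residual + layer norm
    assert y.shape == x.shape
    return y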


class MultiheadAttention(nn.Module):
    """
    Multihead attention mechanism (dot attention)
    """

    def __init__(self, num_hidden_k):
        """
        :param num_hidden_k: dimension of hidden
        """
        super(MultiheadAttention, self).__init__()

        self.num_hidden_k = num_hidden_k
        self.attn_dropout = nn.Dropout(p=0.1)

    def forward(self, key, value, query, mask=None, query_mask=None):
        # Get attention score
        attn = torch.bmm(query, key.transpose(1, 2))
        attn = attn / math.sqrt(self.num_hidden_k)

        # Masking to ignore padding (key side)
        if mask is not None:
            attn = attn.masked_fill(mask, -2 ** 32 + 1)
        attn = torch.softmax(attn, dim=-1)

        # Masking to ignore padding (query side)
        if query_mask is not None:
            attn = attn * query_mask

        # Dropout
        attn = self.attn_dropout(attn)

        # Get Context Vector
        result = torch.bmm(attn, value)

        return result, attn
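

# Hedged sketch (not in the original file): this module works on tensors that
# are already split per head, i.e. (batch * heads, time, head_dim), which is
# how Attention below calls it. Sizes are arbitrary.
def _scaled_dot_attention_example():
    attention = MultiheadAttention(num_hidden_k=16)
    q = torch.randn(4, 10, 16)     # (batch * heads, query steps, head dim)
    k = torch.randn(4, 12, 16)     # (batch * heads, key steps, head dim)
    v = torch.randn(4, 12, 16)
    context, attn = attention(k, v, q)   # forward signature is (key, value, query)
    assert context.shape == (4, 10, 16)
    assert attn.shape == (4, 10, 12)
    return context, attn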


class Attention(nn.Module):
    """
    Attention Network
    """

    def __init__(self, num_hidden, h=2):
        """
        :param num_hidden: dimension of hidden
        :param h: num of heads
        """
        super(Attention, self).__init__()

        self.num_hidden = num_hidden
        self.num_hidden_per_attn = num_hidden // h
        self.h = h

        self.key = Linear(num_hidden, num_hidden, bias=False)
        self.value = Linear(num_hidden, num_hidden, bias=False)
        self.query = Linear(num_hidden, num_hidden, bias=False)

        self.multihead = MultiheadAttention(self.num_hidden_per_attn)

        self.residual_dropout = nn.Dropout(p=0.1)

        self.final_linear = Linear(num_hidden * 2, num_hidden)

        self.layer_norm_1 = nn.LayerNorm(num_hidden)

    def forward(self, memory, decoder_input, mask=None, query_mask=None):
        batch_size = memory.size(0)
        seq_k = memory.size(1)
        seq_q = decoder_input.size(1)

        # Repeat masks h times
        if query_mask is not None:
            query_mask = query_mask.unsqueeze(-1).repeat(1, 1, seq_k)
            query_mask = query_mask.repeat(self.h, 1, 1)
        if mask is not None:
            mask = mask.repeat(self.h, 1, 1)

        # Make multihead
        key = self.key(memory).view(batch_size,
                                    seq_k,
                                    self.h,
                                    self.num_hidden_per_attn)
        value = self.value(memory).view(batch_size,
                                        seq_k,
                                        self.h,
                                        self.num_hidden_per_attn)
        query = self.query(decoder_input).view(batch_size,
                                               seq_q,
                                               self.h,
                                               self.num_hidden_per_attn)

        key = key.permute(2, 0, 1, 3).contiguous().view(-1,
                                                        seq_k,
                                                        self.num_hidden_per_attn)
        value = value.permute(2, 0, 1, 3).contiguous().view(-1,
                                                            seq_k,
                                                            self.num_hidden_per_attn)
        query = query.permute(2, 0, 1, 3).contiguous().view(-1,
                                                            seq_q,
                                                            self.num_hidden_per_attn)

        # Get context vector
        result, attns = self.multihead(
            key, value, query, mask=mask, query_mask=query_mask)

        # Concatenate all multihead context vector
        result = result.view(self.h, batch_size, seq_q,
                             self.num_hidden_per_attn)
        result = result.permute(1, 2, 0, 3).contiguous().view(
            batch_size, seq_q, -1)

        # Concatenate context vector with input (most important)
        result = torch.cat([decoder_input, result], dim=-1)

        # Final linear
        result = self.final_linear(result)

        # Residual dropout & connection
        result = self.residual_dropout(result)
        result = result + decoder_input

        # Layer normalization
        result = self.layer_norm_1(result)

        return result, attns
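

# Hedged sketch (not in the original file): self-attention with the default
# two heads; memory and decoder_input are the same tensor in the FFT block.
# num_hidden must be divisible by the number of heads. Sizes are arbitrary.
def _self_attention_example():
    attention = Attention(num_hidden=32, h=2)
    x = torch.randn(2, 10, 32)        # (batch, time, hidden)
    out, attn = attention(x, x)       # memory == decoder_input (self-attention)
    assert out.shape == (2, 10, 32)
    assert attn.shape == (4, 10, 10)  # (batch * heads, query steps, key steps)
    return out, attn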


class FFTBlock(torch.nn.Module):
    """FFT Block"""

    def __init__(self,
                 d_model,
                 n_head=hp.Head):
        super(FFTBlock, self).__init__()
        self.slf_attn = clones(Attention(d_model, h=n_head), hp.N)
        self.pos_ffn = clones(FFN(d_model), hp.N)
        self.pos_emb = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(1024, d_model, padding_idx=0),
            freeze=True)

    def forward(self, x, pos, return_attns=False):
        # Get character mask
        if self.training:
            c_mask = pos.ne(0).type(torch.float)
            mask = pos.eq(0).unsqueeze(1).repeat(1, x.size(1), 1)
        else:
            c_mask, mask = None, None

        # Get positional embedding and add
        pos = self.pos_emb(pos)
        x = x + pos

        # Self-attention followed by position-wise FFN, repeated hp.N times
        attns = list()
        for slf_attn, ffn in zip(self.slf_attn, self.pos_ffn):
            x, attn = slf_attn(x, x, mask=mask, query_mask=c_mask)
            x = ffn(x)
            attns.append(attn)

        return x, attns
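

# Hedged smoke test (not in the original file): runs one forward pass through
# an FFT block. It relies on hparams defining Head and N (which the class
# already requires) and uses a small d_model to keep the example cheap.
if __name__ == "__main__":
    d_model = 64
    block = FFTBlock(d_model).to(device).eval()
    x = torch.randn(2, 10, d_model, device=device)
    # positions are 1-based; index 0 is reserved for padding in the table
    pos = torch.arange(1, 11, device=device).unsqueeze(0).repeat(2, 1)
    with torch.no_grad():
        out, attns = block(x, pos)
    print(out.shape, len(attns))   # expected: torch.Size([2, 10, 64]) and hp.N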