import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
import numpy as np
import copy
import math

import hparams as hp
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
                               for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)


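# Example (illustrative sketch, not part of the original code): the table is
# typically wrapped in a frozen embedding so that integer positions (with 0
# reserved for padding) index into it, e.g.
#
#     table = get_sinusoid_encoding_table(1024, 256, padding_idx=0)
#     pos_emb = nn.Embedding.from_pretrained(table, freeze=True)
#     pe = pos_emb(torch.arange(1, 11).unsqueeze(0))  # shape (1, 10, 256)
#
# The sizes 1024 and 256 above are placeholders; FFTBlock below uses
# n_position=1024 and d_model from hparams.

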
def clones(module, N):
    """ Produce N identical (deep-copied) layers """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LengthRegulator(nn.Module):
    """ Length Regulator """

    def __init__(self):
        super(LengthRegulator, self).__init__()
        self.duration_predictor = DurationPredictor()

    def LR(self, x, duration_predictor_output, alpha=1.0, mel_max_length=None):
        output = list()

        for batch, expand_target in zip(x, duration_predictor_output):
            output.append(self.expand(batch, expand_target, alpha))

        if mel_max_length:
            output = utils.pad(output, mel_max_length)
        else:
            output = utils.pad(output)

        return output

    def expand(self, batch, predicted, alpha):
        out = list()

        for i, vec in enumerate(batch):
            # repeat each encoder frame according to its duration,
            # scaled by alpha to control speech rate
            expand_size = predicted[i].item()
            out.append(vec.expand(int(expand_size * alpha), -1))
        out = torch.cat(out, 0)

        return out

    def rounding(self, num):
        if num - int(num) >= 0.5:
            return int(num) + 1
        else:
            return int(num)

    def forward(self, x, alpha=1.0, target=None, mel_max_length=None):
        duration_predictor_output = self.duration_predictor(x)

        if self.training:
            # use ground-truth durations for expansion during training
            output = self.LR(x, target, mel_max_length=mel_max_length)
            return output, duration_predictor_output
        else:
            # round predicted durations before expansion at inference
            for idx, ele in enumerate(duration_predictor_output[0]):
                duration_predictor_output[0][idx] = self.rounding(ele)
            output = self.LR(x, duration_predictor_output, alpha)
            mel_pos = torch.stack(
                [torch.Tensor([i + 1 for i in range(output.size(1))])]).long().to(device)

            return output, mel_pos


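# Usage sketch (assumed shapes, not part of the original file): at inference
# the regulator expands encoder output (batch, T_text, d_model) to mel length
# and also returns mel frame positions, e.g.
#
#     length_regulator = LengthRegulator().to(device).eval()
#     encoder_out = torch.randn(1, 20, hp.d_model).to(device)
#     with torch.no_grad():
#         mel_input, mel_pos = length_regulator(encoder_out, alpha=1.0)
#
# During training, ground-truth durations are passed via `target` and the
# unrounded predictions are returned for the duration loss.

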
class DurationPredictor(nn.Module):
    """ Duration Predictor """

    def __init__(self):
        super(DurationPredictor, self).__init__()

        self.input_size = hp.d_model
        self.filter_size = hp.duration_predictor_filter_size
        self.kernel = hp.duration_predictor_kernel_size
        self.conv_output_size = hp.duration_predictor_filter_size
        self.dropout = hp.dropout

        # padding=1 preserves the sequence length when the kernel size is 3
        self.conv_layer = nn.Sequential(OrderedDict([
            ("conv1d_1", Conv(self.input_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=1)),
            ("layer_norm_1", nn.LayerNorm(self.filter_size)),
            ("relu_1", nn.ReLU()),
            ("dropout_1", nn.Dropout(self.dropout)),
            ("conv1d_2", Conv(self.filter_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=1)),
            ("layer_norm_2", nn.LayerNorm(self.filter_size)),
            ("relu_2", nn.ReLU()),
            ("dropout_2", nn.Dropout(self.dropout))
        ]))

        self.linear_layer = Linear(self.conv_output_size, 1)
        self.relu = nn.ReLU()

    def forward(self, encoder_output):
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)

        out = self.relu(out)

        # drop the trailing singleton channel: (batch, T, 1) -> (batch, T)
        out = out.squeeze()

        if not self.training:
            # squeeze() also removes the batch dimension when batch size is 1,
            # so restore it at inference
            out = out.unsqueeze(0)

        return out


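# Shape sketch (illustrative, placeholder sizes): the predictor maps encoder
# features to one non-negative duration per time step, e.g.
#
#     predictor = DurationPredictor()
#     enc = torch.randn(4, 20, hp.d_model)
#     durations = predictor(enc)   # (4, 20) in training mode

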
class Conv(nn.Module):
    """
    Convolution Module
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 w_init='linear'):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. If True, bias is included.
        :param w_init: str. Nonlinearity name passed to calculate_gain for
            Xavier initialization of the weights.
        """
        super(Conv, self).__init__()

        self.conv = nn.Conv1d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=bias)

        nn.init.xavier_uniform_(
            self.conv.weight, gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        # nn.Conv1d expects (batch, channels, time), so transpose in and out
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)

        return x


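# Note (sketch, placeholder sizes): because of the transposes in forward, Conv
# consumes and returns channel-last tensors, matching the rest of this file, e.g.
#
#     conv = Conv(256, 256, kernel_size=3, padding=1, w_init='relu')
#     y = conv(torch.randn(2, 50, 256))   # y has shape (2, 50, 256)

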
class Linear(nn.Module):
    """
    Linear Module
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
        """
        :param in_dim: dimension of input
        :param out_dim: dimension of output
        :param bias: boolean. If True, bias is included.
        :param w_init: str. Nonlinearity name passed to calculate_gain for
            Xavier initialization of the weights.
        """
        super(Linear, self).__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)

        nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        return self.linear_layer(x)


class FFN(nn.Module):
    """
    Positionwise Feed-Forward Network
    """

    def __init__(self, num_hidden):
        """
        :param num_hidden: dimension of hidden
        """
        super(FFN, self).__init__()
        self.w_1 = Conv(num_hidden, num_hidden * 4,
                        kernel_size=3, padding=1, w_init='relu')
        self.w_2 = Conv(num_hidden * 4, num_hidden, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=0.1)
        self.layer_norm = nn.LayerNorm(num_hidden)

    def forward(self, input_):
        # FFN network
        x = input_
        x = self.w_2(torch.relu(self.w_1(x)))

        # residual connection
        x = x + input_

        # dropout
        x = self.dropout(x)

        # layer normalization
        x = self.layer_norm(x)

        return x


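# Note (sketch, placeholder sizes): unlike the usual Transformer FFN, the
# position-wise layers here are kernel-3 convolutions, so neighbouring frames
# are mixed while the (batch, time, num_hidden) shape is preserved, e.g.
#
#     ffn = FFN(hp.d_model)
#     y = ffn(torch.randn(2, 50, hp.d_model))   # same shape as the input

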
class MultiheadAttention(nn.Module):
    """
    Multihead attention mechanism (dot attention)
    """

    def __init__(self, num_hidden_k):
        """
        :param num_hidden_k: dimension of hidden
        """
        super(MultiheadAttention, self).__init__()

        self.num_hidden_k = num_hidden_k
        self.attn_dropout = nn.Dropout(p=0.1)

    def forward(self, key, value, query, mask=None, query_mask=None):
        # Get attention score
        attn = torch.bmm(query, key.transpose(1, 2))
        attn = attn / math.sqrt(self.num_hidden_k)

        # Masking to ignore padding (key side)
        if mask is not None:
            attn = attn.masked_fill(mask, -2 ** 32 + 1)
        attn = torch.softmax(attn, dim=-1)

        # Masking to ignore padding (query side)
        if query_mask is not None:
            attn = attn * query_mask

        # Dropout
        attn = self.attn_dropout(attn)

        # Get context vector
        result = torch.bmm(attn, value)

        return result, attn


class Attention(nn.Module):
    """
    Attention Network
    """

    def __init__(self, num_hidden, h=2):
        """
        :param num_hidden: dimension of hidden
        :param h: num of heads
        """
        super(Attention, self).__init__()

        self.num_hidden = num_hidden
        self.num_hidden_per_attn = num_hidden // h
        self.h = h

        self.key = Linear(num_hidden, num_hidden, bias=False)
        self.value = Linear(num_hidden, num_hidden, bias=False)
        self.query = Linear(num_hidden, num_hidden, bias=False)

        self.multihead = MultiheadAttention(self.num_hidden_per_attn)

        self.residual_dropout = nn.Dropout(p=0.1)

        self.final_linear = Linear(num_hidden * 2, num_hidden)

        self.layer_norm_1 = nn.LayerNorm(num_hidden)

    def forward(self, memory, decoder_input, mask=None, query_mask=None):

        batch_size = memory.size(0)
        seq_k = memory.size(1)
        seq_q = decoder_input.size(1)

        # Repeat masks h times
        if query_mask is not None:
            query_mask = query_mask.unsqueeze(-1).repeat(1, 1, seq_k)
            query_mask = query_mask.repeat(self.h, 1, 1)
        if mask is not None:
            mask = mask.repeat(self.h, 1, 1)

        # Make multihead: (batch, seq, h, num_hidden_per_attn)
        key = self.key(memory).view(batch_size,
                                    seq_k,
                                    self.h,
                                    self.num_hidden_per_attn)
        value = self.value(memory).view(batch_size,
                                        seq_k,
                                        self.h,
                                        self.num_hidden_per_attn)
        query = self.query(decoder_input).view(batch_size,
                                               seq_q,
                                               self.h,
                                               self.num_hidden_per_attn)

        # Fold heads into the batch dimension: (h * batch, seq, num_hidden_per_attn)
        key = key.permute(2, 0, 1, 3).contiguous().view(-1,
                                                        seq_k,
                                                        self.num_hidden_per_attn)
        value = value.permute(2, 0, 1, 3).contiguous().view(-1,
                                                            seq_k,
                                                            self.num_hidden_per_attn)
        query = query.permute(2, 0, 1, 3).contiguous().view(-1,
                                                            seq_q,
                                                            self.num_hidden_per_attn)

        # Get context vector
        result, attns = self.multihead(
            key, value, query, mask=mask, query_mask=query_mask)

        # Concatenate all multihead context vectors
        result = result.view(self.h, batch_size, seq_q,
                             self.num_hidden_per_attn)
        result = result.permute(1, 2, 0, 3).contiguous().view(
            batch_size, seq_q, -1)

        # Concatenate context vector with input (most important)
        result = torch.cat([decoder_input, result], dim=-1)

        # Final linear
        result = self.final_linear(result)

        # Residual dropout & connection
        result = self.residual_dropout(result)
        result = result + decoder_input

        # Layer normalization
        result = self.layer_norm_1(result)

        return result, attns


class FFTBlock(nn.Module):
    """FFT Block"""

    def __init__(self,
                 d_model,
                 n_head=hp.Head):
        super(FFTBlock, self).__init__()
        self.slf_attn = clones(Attention(d_model), hp.N)
        self.pos_ffn = clones(FFN(d_model), hp.N)

        self.pos_emb = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(1024, d_model, padding_idx=0),
            freeze=True)

    def forward(self, x, pos, return_attns=False):
        # Get character mask
        if self.training:
            c_mask = pos.ne(0).type(torch.float)
            mask = pos.eq(0).unsqueeze(1).repeat(1, x.size(1), 1)
        else:
            c_mask, mask = None, None

        # Get positional embedding and add it to the input
        pos = self.pos_emb(pos)
        x = x + pos

        # Self-attention (encoder-encoder) stack
        attns = list()
        for slf_attn, ffn in zip(self.slf_attn, self.pos_ffn):
            x, attn = slf_attn(x, x, mask=mask, query_mask=c_mask)
            x = ffn(x)
            attns.append(attn)

        return x, attns
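

# Smoke-test sketch (assumes hparams defines d_model, Head and N as in the
# FastSpeech setup; the lengths below are placeholders):
#
#     block = FFTBlock(hp.d_model)
#     x = torch.randn(2, 30, hp.d_model)
#     pos = torch.arange(1, 31).unsqueeze(0).repeat(2, 1)   # 0 marks padding
#     out, attns = block(x, pos)   # out: (2, 30, hp.d_model)
#
# In training mode, `pos` entries equal to 0 are masked out on both the query
# and key side of the self-attention.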