import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
import numpy as np
import copy
import math

import hparams as hp
import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i)
                               for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)
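
# Illustrative usage (mirrors FFTBlock below): wrap the table in a frozen
# nn.Embedding so integer position indices can be looked up like token ids,
# with index 0 reserved for padding.
#   pos_table = get_sinusoid_encoding_table(1024, hp.d_model, padding_idx=0)
#   pos_emb = nn.Embedding.from_pretrained(pos_table, freeze=True)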


def clones(module, N):
    """ Produce a ModuleList of N deep copies of the given module """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LengthRegulator(nn.Module):
    """ Length Regulator """

    def __init__(self):
        super(LengthRegulator, self).__init__()
        self.duration_predictor = DurationPredictor()

    def LR(self, x, duration_predictor_output, alpha=1.0, mel_max_length=None):
        # Expand every sequence in the batch according to its durations,
        # then pad the expanded sequences to a common length.
        output = list()

        for batch, expand_target in zip(x, duration_predictor_output):
            output.append(self.expand(batch, expand_target, alpha))

        if mel_max_length:
            output = utils.pad(output, mel_max_length)
        else:
            output = utils.pad(output)

        return output

    def expand(self, batch, predicted, alpha):
        # Repeat the i-th encoder frame predicted[i] * alpha times.
        out = list()

        for i, vec in enumerate(batch):
            expand_size = predicted[i].item()
            out.append(vec.expand(int(expand_size * alpha), -1))
        out = torch.cat(out, 0)

        return out

    def rounding(self, num):
        # Round half up (e.g. 2.5 -> 3).
        if num - int(num) >= 0.5:
            return int(num) + 1
        else:
            return int(num)

    def forward(self, x, alpha=1.0, target=None, mel_max_length=None):
        duration_predictor_output = self.duration_predictor(x)

        if self.training:
            # Training: expand with the ground-truth durations and also return
            # the predicted durations for the duration loss.
            output = self.LR(x, target, mel_max_length=mel_max_length)
            return output, duration_predictor_output
        else:
            # Inference: round the predicted durations and expand with them.
            for idx, ele in enumerate(duration_predictor_output[0]):
                duration_predictor_output[0][idx] = self.rounding(ele)
            output = self.LR(x, duration_predictor_output, alpha)
            # 1-based frame positions for the decoder, shape (1, mel_len).
            mel_pos = torch.arange(
                1, output.size(1) + 1).long().unsqueeze(0).to(device)

            return output, mel_pos
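
# Expansion example (illustrative): with predicted durations [2, 1, 3] and
# alpha=1.0, an encoder output [a, b, c] becomes the frame-level sequence
# [a, a, b, c, c, c], so the regulated length matches the mel spectrogram.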


class DurationPredictor(nn.Module):
    """ Duration Predictor """

    def __init__(self):
        super(DurationPredictor, self).__init__()

        self.input_size = hp.d_model
        self.filter_size = hp.duration_predictor_filter_size
        self.kernel = hp.duration_predictor_kernel_size
        self.conv_output_size = hp.duration_predictor_filter_size
        self.dropout = hp.dropout

        self.conv_layer = nn.Sequential(OrderedDict([
            ("conv1d_1", Conv(self.input_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=1)),
            ("layer_norm_1", nn.LayerNorm(self.filter_size)),
            ("relu_1", nn.ReLU()),
            ("dropout_1", nn.Dropout(self.dropout)),
            ("conv1d_2", Conv(self.filter_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=1)),
            ("layer_norm_2", nn.LayerNorm(self.filter_size)),
            ("relu_2", nn.ReLU()),
            ("dropout_2", nn.Dropout(self.dropout))
        ]))

        self.linear_layer = Linear(self.conv_output_size, 1)
        self.relu = nn.ReLU()

    def forward(self, encoder_output):
        # (B, T, d_model) -> (B, T, filter_size) -> (B, T, 1)
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)

        out = self.relu(out)

        out = out.squeeze()

        if not self.training:
            # Restore the batch dimension that squeeze() drops when the
            # batch size is 1 (the usual inference case).
            out = out.unsqueeze(0)

        return out
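
# Note: the predictor outputs one non-negative duration per encoder position
# (shape (B, T) after the squeeze); rounding to integer frame counts happens
# in LengthRegulator at inference time.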


class Conv(nn.Module):
    """
    Convolution Module
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 w_init='linear'):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. nonlinearity name used to compute the gain for
            Xavier initialization (e.g. 'linear' or 'relu').
        """
        super(Conv, self).__init__()

        self.conv = nn.Conv1d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=bias)

        nn.init.xavier_uniform_(
            self.conv.weight, gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        # nn.Conv1d expects (B, C, T); transpose in and out so the module
        # works on (B, T, C) tensors like the rest of the model.
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)

        return x


class Linear(nn.Module):
    """
    Linear Module
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
        """
        :param in_dim: dimension of input
        :param out_dim: dimension of output
        :param bias: boolean. if True, bias is included.
        :param w_init: str. nonlinearity name used to compute the gain for
            Xavier initialization (e.g. 'linear' or 'relu').
        """
        super(Linear, self).__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)

        nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        return self.linear_layer(x)


class FFN(nn.Module):
    """
    Positionwise Feed-Forward Network
    """

    def __init__(self, num_hidden):
        """
        :param num_hidden: dimension of hidden
        """
        super(FFN, self).__init__()
        self.w_1 = Conv(num_hidden, num_hidden * 4,
                        kernel_size=3, padding=1, w_init='relu')
        self.w_2 = Conv(num_hidden * 4, num_hidden, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=0.1)
        self.layer_norm = nn.LayerNorm(num_hidden)

    def forward(self, input_):
        # two convolutions with a ReLU in between
        x = input_
        x = self.w_2(torch.relu(self.w_1(x)))

        # residual connection
        x = x + input_

        # dropout
        x = self.dropout(x)

        # layer normalization
        x = self.layer_norm(x)

        return x
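
# Design note: unlike the position-wise linear layers of the original
# Transformer FFN, this block uses kernel-size-3 convolutions (FastSpeech
# style), expanding to 4 * num_hidden before projecting back.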


class MultiheadAttention(nn.Module):
    """
    Multihead attention mechanism (dot attention)
    """

    def __init__(self, num_hidden_k):
        """
        :param num_hidden_k: dimension of hidden
        """
        super(MultiheadAttention, self).__init__()

        self.num_hidden_k = num_hidden_k
        self.attn_dropout = nn.Dropout(p=0.1)

    def forward(self, key, value, query, mask=None, query_mask=None):
        # Get attention score
        attn = torch.bmm(query, key.transpose(1, 2))
        attn = attn / math.sqrt(self.num_hidden_k)

        # Masking to ignore padding (key side)
        if mask is not None:
            attn = attn.masked_fill(mask, -2 ** 32 + 1)
        attn = torch.softmax(attn, dim=-1)

        # Masking to ignore padding (query side)
        if query_mask is not None:
            attn = attn * query_mask

        # Dropout
        attn = self.attn_dropout(attn)

        # Get context vector
        result = torch.bmm(attn, value)

        return result, attn
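
# Shape note (as called from Attention below): key/value/query arrive already
# split per head as (h * B, T, num_hidden_k); attn is (h * B, T_q, T_k) and
# result is (h * B, T_q, num_hidden_k).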


class Attention(nn.Module):
    """
    Attention Network
    """

    def __init__(self, num_hidden, h=2):
        """
        :param num_hidden: dimension of hidden
        :param h: num of heads
        """
        super(Attention, self).__init__()

        self.num_hidden = num_hidden
        self.num_hidden_per_attn = num_hidden // h
        self.h = h

        self.key = Linear(num_hidden, num_hidden, bias=False)
        self.value = Linear(num_hidden, num_hidden, bias=False)
        self.query = Linear(num_hidden, num_hidden, bias=False)

        self.multihead = MultiheadAttention(self.num_hidden_per_attn)

        self.residual_dropout = nn.Dropout(p=0.1)

        self.final_linear = Linear(num_hidden * 2, num_hidden)

        self.layer_norm_1 = nn.LayerNorm(num_hidden)

    def forward(self, memory, decoder_input, mask=None, query_mask=None):

        batch_size = memory.size(0)
        seq_k = memory.size(1)
        seq_q = decoder_input.size(1)

        # Repeat masks h times
        if query_mask is not None:
            query_mask = query_mask.unsqueeze(-1).repeat(1, 1, seq_k)
            query_mask = query_mask.repeat(self.h, 1, 1)
        if mask is not None:
            mask = mask.repeat(self.h, 1, 1)

        # Make multihead
        key = self.key(memory).view(batch_size,
                                    seq_k,
                                    self.h,
                                    self.num_hidden_per_attn)
        value = self.value(memory).view(batch_size,
                                        seq_k,
                                        self.h,
                                        self.num_hidden_per_attn)
        query = self.query(decoder_input).view(batch_size,
                                               seq_q,
                                               self.h,
                                               self.num_hidden_per_attn)

        key = key.permute(2, 0, 1, 3).contiguous().view(
            -1, seq_k, self.num_hidden_per_attn)
        value = value.permute(2, 0, 1, 3).contiguous().view(
            -1, seq_k, self.num_hidden_per_attn)
        query = query.permute(2, 0, 1, 3).contiguous().view(
            -1, seq_q, self.num_hidden_per_attn)

        # Get context vector
        result, attns = self.multihead(
            key, value, query, mask=mask, query_mask=query_mask)

        # Concatenate all multihead context vectors
        result = result.view(self.h, batch_size, seq_q,
                             self.num_hidden_per_attn)
        result = result.permute(1, 2, 0, 3).contiguous().view(
            batch_size, seq_q, -1)

        # Concatenate context vector with input (most important)
        result = torch.cat([decoder_input, result], dim=-1)

        # Final linear
        result = self.final_linear(result)

        # Residual dropout & connection
        result = self.residual_dropout(result)
        result = result + decoder_input

        # Layer normalization
        result = self.layer_norm_1(result)

        return result, attns
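
# Shape summary (illustrative): memory (B, T_k, H) and decoder_input (B, T_q, H)
# are projected and split into h heads of size H // h, attended per head,
# merged back to (B, T_q, H), concatenated with decoder_input to (B, T_q, 2H),
# and mapped back to (B, T_q, H) by final_linear before the residual and norm.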


class FFTBlock(torch.nn.Module):
    """FFT Block"""

    def __init__(self,
                 d_model,
                 n_head=hp.Head):
        super(FFTBlock, self).__init__()
        self.slf_attn = clones(Attention(d_model, h=n_head), hp.N)
        self.pos_ffn = clones(FFN(d_model), hp.N)

        self.pos_emb = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(1024, d_model, padding_idx=0),
            freeze=True)

    def forward(self, x, pos, return_attns=False):
        # Get character mask
        if self.training:
            c_mask = pos.ne(0).type(torch.float)
            mask = pos.eq(0).unsqueeze(1).repeat(1, x.size(1), 1)
        else:
            c_mask, mask = None, None

        # Add positional embedding
        pos = self.pos_emb(pos)
        x = x + pos

        # Self-attention followed by position-wise FFN, repeated hp.N times
        attns = list()
        for slf_attn, ffn in zip(self.slf_attn, self.pos_ffn):
            x, attn = slf_attn(x, x, mask=mask, query_mask=c_mask)
            x = ffn(x)
            attns.append(attn)

        return x, attns
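
# Minimal usage sketch (illustrative; assumes hp.d_model, hp.Head and hp.N are
# set in hparams.py):
#   block = FFTBlock(hp.d_model)
#   x = torch.zeros(2, 50, hp.d_model)                   # (batch, time, d_model)
#   pos = torch.arange(1, 51).unsqueeze(0).repeat(2, 1)  # 1-based positions, 0 = pad
#   out, attns = block(x, pos)                           # out: (2, 50, hp.d_model)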