import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class CNNIntent(nn.Module):
    def __init__(self, input_dim, embedding_dim, output_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()

        if wordvecs is not None:
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)

        # stack of 1D convolutions: the first layer reads the embedding channels,
        # each later layer reads the previous layer's filter count
        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(filter_sizes[-1], output_dim)

        self.embedding_dim = embedding_dim
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size
        self.unpruned_count = sum(filter_sizes)

    def forward(self, query):  # query shape: [batch, seq len]
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            x = conv(x)
            x = F.rrelu(x, training=self.training)  # randomized leaky ReLU; respects train/eval mode
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        x, _ = torch.max(x, dim=1)  # max-pool over the sequence: [batch, filter count]
        return self.fc(self.dropout(x))

    def prune(self, count, norm=2):
        """Remove the `count` convolution filters with the smallest p-norm,
        shrinking the input of the succeeding layer to match."""
        if count >= sum(self.filter_sizes):  # ensure at least one filter is left over
            raise ValueError("pruning {} filters would leave no filters".format(count))

        rankings = []  # list of (conv #, filter #, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt.view(-1), p=norm).item()))
        rankings.sort(key=lambda x: x[2])

        for ranking in rankings[:count]:
            conv_num, filter_num, _ = ranking

            # remove the filter's weights and bias from its conv layer
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))

            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)

            # update the input channels of the succeeding layer
            if conv_num == len(self.filter_sizes) - 1:  # last conv layer: shrink the linear layer
                new_weight = torch.cat((self.fc.weight[:, :filter_num], self.fc.weight[:, filter_num + 1:]), dim=1)
                new_bias = self.fc.bias

                self.fc = nn.Linear(self.fc.in_features - 1, self.fc.out_features)
                self.fc.weight = nn.Parameter(new_weight)
                self.fc.bias = nn.Parameter(new_bias)
            else:  # shrink the next conv layer
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias

                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)

            self.filter_sizes = tuple(filter_size - 1 if i == conv_num else filter_size
                                      for i, filter_size in enumerate(self.filter_sizes))


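# Example usage for CNNIntent (a minimal sketch; the vocabulary size, label
# count, and hyperparameters below are illustrative assumptions, not values
# taken from this project):
#
#   model = CNNIntent(input_dim=10000, embedding_dim=100, output_dim=7,
#                     filter_sizes=(64, 64), kernel_size=3, dropout=0.5)
#   queries = torch.randint(0, 10000, (32, 20))  # [batch, seq len] of token ids
#   logits = model(queries)                      # [batch, output_dim]
#   model.prune(8)                               # drop the 8 lowest-norm filters

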
class CNNSlot(nn.Module):
    def __init__(self, input_dim, embedding_dim, output_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()

        if wordvecs is not None:
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(filter_sizes[-1], output_dim)

        # same-length padding (for odd kernel sizes), so every token keeps a slot prediction
        self.padding = (kernel_size - 1) // 2

        self.embedding_dim = embedding_dim
        self.unpruned_count = sum(filter_sizes)
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size

    def forward(self, query):  # query shape: [batch, seq len]
        x = self.embedding(query)  # [batch, seq len, embedding dim]
        x = x.permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            x = F.pad(x, (self.padding, self.padding))  # pad so the sequence length is preserved
            x = conv(x)  # [batch, filter count, seq len]
            x = F.rrelu(x, training=self.training)  # randomized leaky ReLU; respects train/eval mode
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        x = self.fc(self.dropout(x))  # [batch, seq len, output_dim]
        return x

    def prune(self, count, norm=2):
        """Remove the `count` convolution filters with the smallest p-norm,
        shrinking the input of the succeeding layer to match."""
        if count >= sum(self.filter_sizes):  # ensure at least one filter is left over
            raise ValueError("pruning {} filters would leave no filters".format(count))

        rankings = []  # list of (conv #, filter #, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt.view(-1), p=norm).item()))
        rankings.sort(key=lambda x: x[2])

        for ranking in rankings[:count]:
            conv_num, filter_num, _ = ranking

            # remove the filter's weights and bias from its conv layer
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))

            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)

            # update the input channels of the succeeding layer
            if conv_num == len(self.filter_sizes) - 1:  # last conv layer: shrink the linear layer
                new_weight = torch.cat((self.fc.weight[:, :filter_num], self.fc.weight[:, filter_num + 1:]), dim=1)
                new_bias = self.fc.bias

                self.fc = nn.Linear(self.fc.in_features - 1, self.fc.out_features)
                self.fc.weight = nn.Parameter(new_weight)
                self.fc.bias = nn.Parameter(new_bias)
            else:  # shrink the next conv layer
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias

                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)

            self.filter_sizes = tuple(filter_size - 1 if i == conv_num else filter_size
                                      for i, filter_size in enumerate(self.filter_sizes))


class CNNJoint(nn.Module):
    def __init__(self, input_dim, embedding_dim, intent_dim, slot_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()

        if wordvecs is not None:
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )

        self.intent_dropout = nn.Dropout(dropout)
        self.intent_fc = nn.Linear(filter_sizes[-1], intent_dim)
        self.slot_dropout = nn.Dropout(dropout)
        self.slot_fc = nn.Linear(filter_sizes[-1], slot_dim)

        self.padding = (kernel_size - 1) // 2
        self.unpruned_count = sum(filter_sizes)
        self.embedding_dim = embedding_dim
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size

    def forward(self, query):  # query shape: [batch, seq len]
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            x = F.pad(x, (self.padding, self.padding))
            x = conv(x)
            x = F.rrelu(x, training=self.training)  # randomized leaky ReLU; respects train/eval mode
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]

        # intent head max-pools over the sequence; slot head predicts per token
        intent_pred = self.intent_fc(self.intent_dropout(torch.max(x, dim=1)[0]))
        slot_pred = self.slot_fc(self.slot_dropout(x))

        return intent_pred, slot_pred.permute(0, 2, 1)  # [batch, intent dim], [batch, slot dim, seq len]

    def prune(self, count, norm=2):
        """Remove the `count` convolution filters with the smallest p-norm,
        shrinking the input of the succeeding layer(s) to match."""
        if count >= sum(self.filter_sizes):  # ensure at least one filter is left over
            raise ValueError("pruning {} filters would leave no filters".format(count))

        rankings = []  # list of (conv #, filter #, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt.view(-1), p=norm).item()))
        rankings.sort(key=lambda x: x[2])

        for ranking in rankings[:count]:
            conv_num, filter_num, _ = ranking

            # remove the filter's weights and bias from its conv layer
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))

            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)

            # update the input channels of the succeeding layer(s)
            if conv_num == len(self.filter_sizes) - 1:  # last conv layer: shrink both output heads
                new_intent_weight = torch.cat((self.intent_fc.weight[:, :filter_num],
                                               self.intent_fc.weight[:, filter_num + 1:]), dim=1)
                new_intent_bias = self.intent_fc.bias

                self.intent_fc = nn.Linear(self.intent_fc.in_features - 1, self.intent_fc.out_features)
                self.intent_fc.weight = nn.Parameter(new_intent_weight)
                self.intent_fc.bias = nn.Parameter(new_intent_bias)

                new_slot_weight = torch.cat((self.slot_fc.weight[:, :filter_num],
                                             self.slot_fc.weight[:, filter_num + 1:]), dim=1)
                new_slot_bias = self.slot_fc.bias

                self.slot_fc = nn.Linear(self.slot_fc.in_features - 1, self.slot_fc.out_features)
                self.slot_fc.weight = nn.Parameter(new_slot_weight)
                self.slot_fc.bias = nn.Parameter(new_slot_bias)
            else:  # shrink the next conv layer
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias

                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)

            self.filter_sizes = tuple(filter_size - 1 if i == conv_num else filter_size
                                      for i, filter_size in enumerate(self.filter_sizes))


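# Quick smoke test for the joint model (a minimal sketch; the vocabulary size,
# label counts, and hyperparameters are illustrative assumptions, not values
# taken from this project).
if __name__ == "__main__":
    model = CNNJoint(input_dim=1000, embedding_dim=50, intent_dim=5, slot_dim=10,
                     filter_sizes=(32, 32), kernel_size=3, dropout=0.5)
    queries = torch.randint(0, 1000, (4, 12))  # [batch, seq len] of token ids

    intent_pred, slot_pred = model(queries)
    print(intent_pred.shape)  # torch.Size([4, 5])
    print(slot_pred.shape)    # torch.Size([4, 10, 12])

    model.prune(4)  # remove the 4 filters with the smallest L2 norm
    intent_pred, slot_pred = model(queries)
    print(sum(model.filter_sizes))  # 60 filters remain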