You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

251 lines
12 KiB

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class CNNIntent(nn.Module):
    """Stacked 1-D CNN producing one prediction per input sequence.

    Each conv layer is followed by ``torch.rrelu``; the final feature map is
    max-pooled over the sequence axis, passed through dropout and a linear
    head. No padding is applied, so every conv shortens the sequence by
    ``kernel_size - 1``.

    Args:
        input_dim: number of embeddings (unused when ``wordvecs`` is given).
        embedding_dim: embedding width and input channel count of the first conv.
        output_dim: output features of ``fc`` (presumably the number of intents).
        filter_sizes: output channel count of each conv layer, in order.
        kernel_size: kernel width shared by all conv layers.
        dropout: dropout probability applied before ``fc``.
        wordvecs: optional matrix passed to ``nn.Embedding.from_pretrained``.
    """

    def __init__(self, input_dim, embedding_dim, output_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()
        if wordvecs is not None:
            # NOTE(review): the first conv is sized from embedding_dim, so it must
            # equal wordvecs.shape[1] when pretrained vectors are supplied.
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size) for i in range(len(filter_sizes))]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(filter_sizes[-1], output_dim)
        self.embedding_dim = embedding_dim
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size
        self.unpruned_count = sum(filter_sizes)  # total filters before any pruning

    def forward(self, query):  # query shape: [batch, seq len]
        """Return ``fc`` outputs of shape [batch, output_dim] for ``query``."""
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            # NOTE(review): torch.rrelu is called without training=True, so it uses a
            # fixed negative slope rather than a randomized one -- confirm intent.
            x = torch.rrelu(conv(x))
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        x, _ = torch.max(x, dim=1)  # max-pool over the sequence axis
        return self.fc(self.dropout(x))

    def prune(self, count, norm=2):
        """Remove the ``count`` conv filters with the smallest weight norm.

        Filters are ranked globally across all conv layers by the ``norm``-norm
        of their flattened weights (ties keep layer/filter order). The matching
        input channels of the next conv -- or input features of ``fc`` for the
        last layer -- are removed too, and ``filter_sizes`` is updated. Every
        layer keeps at least one filter.

        Args:
            count: number of filters to remove; ``count <= 0`` is a no-op.
            norm: ``p`` passed to ``torch.norm`` when ranking filters.

        Raises:
            SystemExit: with status 0 when ``count`` filters cannot be removed
                while leaving at least one filter in every layer.
        """
        if count <= 0:
            return
        if sum(self.filter_sizes) - count < len(self.filter_sizes):
            # NOTE(review): ending the process from a model method is kept for
            # backward compatibility; an exception may suit callers better.
            raise SystemExit(0)
        removed = self._choose_filters(count, norm)
        in_keep = None  # None: every input channel of the current layer survives
        for layer, gone in enumerate(removed):
            conv = self.convs[layer]
            out_keep = [i for i in range(conv.out_channels) if i not in gone]
            if gone or in_keep is not None:
                self.convs[layer] = self._sliced_conv(conv, out_keep, in_keep)
            in_keep = out_keep if gone else None
        if in_keep is not None:  # last conv lost filters -> drop fc input features
            self.fc = self._sliced_linear(self.fc, in_keep)
        self.filter_sizes = tuple(size - len(gone) for size, gone in zip(self.filter_sizes, removed))

    def _choose_filters(self, count, norm):
        """Return, per conv layer, the set of filter indices to remove.

        All targets are chosen before any tensor is modified, so indices always
        refer to the current (pre-prune) layout.
        """
        rankings = []  # (norm value, layer index, filter index)
        for layer, conv in enumerate(self.convs):
            norms = torch.norm(conv.weight.detach().flatten(start_dim=1), p=norm, dim=1)
            rankings.extend((value, layer, idx) for idx, value in enumerate(norms.tolist()))
        rankings.sort(key=lambda r: r[0])  # stable: ties stay in layer/filter order
        remaining = [conv.out_channels for conv in self.convs]
        removed = [set() for _ in self.convs]
        for _, layer, idx in rankings:
            if count == 0:
                break
            if remaining[layer] > 1:  # never empty a layer
                removed[layer].add(idx)
                remaining[layer] -= 1
                count -= 1
        return removed

    def _sliced_conv(self, conv, out_keep, in_keep):
        """Return a new Conv1d with ``conv``'s weights restricted to kept channels.

        ``out_keep`` lists surviving filters; ``in_keep`` lists surviving input
        channels, or is None to keep all of them.
        """
        with torch.no_grad():
            weight = conv.weight[out_keep]
            if in_keep is not None:
                weight = weight[:, in_keep]
            bias = conv.bias[out_keep]
        new_conv = nn.Conv1d(weight.shape[1], weight.shape[0], self.kernel_size)
        new_conv.weight = nn.Parameter(weight)
        new_conv.bias = nn.Parameter(bias)
        return new_conv

    @staticmethod
    def _sliced_linear(fc, in_keep):
        """Return a new Linear with ``fc``'s weights restricted to input features ``in_keep``."""
        with torch.no_grad():
            weight = fc.weight[:, in_keep]
            bias = fc.bias.clone()
        new_fc = nn.Linear(len(in_keep), fc.out_features)
        new_fc.weight = nn.Parameter(weight)
        new_fc.bias = nn.Parameter(bias)
        return new_fc
class CNNSlot(nn.Module):
    """Stacked 1-D CNN producing one prediction per input position.

    Each conv input is padded with ``kernel_size - 1`` columns in total so the
    sequence length is preserved; each conv is followed by ``torch.rrelu``,
    then dropout and a linear head are applied at every position.

    Args:
        input_dim: number of embeddings (unused when ``wordvecs`` is given).
        embedding_dim: embedding width and input channel count of the first conv.
        output_dim: output features of ``fc`` (presumably the number of slot tags).
        filter_sizes: output channel count of each conv layer, in order.
        kernel_size: kernel width shared by all conv layers.
        dropout: dropout probability applied before ``fc``.
        wordvecs: optional matrix passed to ``nn.Embedding.from_pretrained``.
    """

    def __init__(self, input_dim, embedding_dim, output_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()
        if wordvecs is not None:
            # NOTE(review): the first conv is sized from embedding_dim, so it must
            # equal wordvecs.shape[1] when pretrained vectors are supplied.
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size) for i in range(len(filter_sizes))]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(filter_sizes[-1], output_dim)
        self.padding = (kernel_size - 1) // 2  # left-side padding per conv
        self.embedding_dim = embedding_dim
        self.unpruned_count = sum(filter_sizes)  # total filters before any pruning
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size

    def forward(self, query):  # query shape: [batch, seq len]
        """Return ``fc`` outputs of shape [batch, seq len, output_dim] for ``query``."""
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        # Pad kernel_size - 1 columns in total so each conv keeps seq len; for even
        # kernels the extra column goes on the right.
        pad = (self.padding, self.kernel_size - 1 - self.padding)
        for conv in self.convs:
            # NOTE(review): torch.rrelu is called without training=True, so it uses a
            # fixed negative slope rather than a randomized one -- confirm intent.
            x = torch.rrelu(conv(F.pad(x, pad)))  # [batch, filter count, seq len]
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        return self.fc(self.dropout(x))

    def prune(self, count, norm=2):
        """Remove the ``count`` conv filters with the smallest weight norm.

        Filters are ranked globally across all conv layers by the ``norm``-norm
        of their flattened weights (ties keep layer/filter order). The matching
        input channels of the next conv -- or input features of ``fc`` for the
        last layer -- are removed too, and ``filter_sizes`` is updated. Every
        layer keeps at least one filter.

        Args:
            count: number of filters to remove; ``count <= 0`` is a no-op.
            norm: ``p`` passed to ``torch.norm`` when ranking filters.

        Raises:
            SystemExit: with status 0 when ``count`` filters cannot be removed
                while leaving at least one filter in every layer.
        """
        if count <= 0:
            return
        if sum(self.filter_sizes) - count < len(self.filter_sizes):
            # NOTE(review): ending the process from a model method is kept for
            # backward compatibility; an exception may suit callers better.
            raise SystemExit(0)
        removed = self._choose_filters(count, norm)
        in_keep = None  # None: every input channel of the current layer survives
        for layer, gone in enumerate(removed):
            conv = self.convs[layer]
            out_keep = [i for i in range(conv.out_channels) if i not in gone]
            if gone or in_keep is not None:
                self.convs[layer] = self._sliced_conv(conv, out_keep, in_keep)
            in_keep = out_keep if gone else None
        if in_keep is not None:  # last conv lost filters -> drop fc input features
            self.fc = self._sliced_linear(self.fc, in_keep)
        self.filter_sizes = tuple(size - len(gone) for size, gone in zip(self.filter_sizes, removed))

    def _choose_filters(self, count, norm):
        """Return, per conv layer, the set of filter indices to remove.

        All targets are chosen before any tensor is modified, so indices always
        refer to the current (pre-prune) layout.
        """
        rankings = []  # (norm value, layer index, filter index)
        for layer, conv in enumerate(self.convs):
            norms = torch.norm(conv.weight.detach().flatten(start_dim=1), p=norm, dim=1)
            rankings.extend((value, layer, idx) for idx, value in enumerate(norms.tolist()))
        rankings.sort(key=lambda r: r[0])  # stable: ties stay in layer/filter order
        remaining = [conv.out_channels for conv in self.convs]
        removed = [set() for _ in self.convs]
        for _, layer, idx in rankings:
            if count == 0:
                break
            if remaining[layer] > 1:  # never empty a layer
                removed[layer].add(idx)
                remaining[layer] -= 1
                count -= 1
        return removed

    def _sliced_conv(self, conv, out_keep, in_keep):
        """Return a new Conv1d with ``conv``'s weights restricted to kept channels.

        ``out_keep`` lists surviving filters; ``in_keep`` lists surviving input
        channels, or is None to keep all of them.
        """
        with torch.no_grad():
            weight = conv.weight[out_keep]
            if in_keep is not None:
                weight = weight[:, in_keep]
            bias = conv.bias[out_keep]
        new_conv = nn.Conv1d(weight.shape[1], weight.shape[0], self.kernel_size)
        new_conv.weight = nn.Parameter(weight)
        new_conv.bias = nn.Parameter(bias)
        return new_conv

    @staticmethod
    def _sliced_linear(fc, in_keep):
        """Return a new Linear with ``fc``'s weights restricted to input features ``in_keep``."""
        with torch.no_grad():
            weight = fc.weight[:, in_keep]
            bias = fc.bias.clone()
        new_fc = nn.Linear(len(in_keep), fc.out_features)
        new_fc.weight = nn.Parameter(weight)
        new_fc.bias = nn.Parameter(bias)
        return new_fc
class CNNJoint(nn.Module):
    """Stacked 1-D CNN with a sequence-level head and a per-position head.

    Each conv input is padded with ``kernel_size - 1`` columns in total so the
    sequence length is preserved; each conv is followed by ``torch.rrelu``.
    ``intent_fc`` reads the sequence max-pool of the final features;
    ``slot_fc`` reads the features at every position.

    Args:
        input_dim: number of embeddings (unused when ``wordvecs`` is given).
        embedding_dim: embedding width and input channel count of the first conv.
        intent_dim: output features of ``intent_fc``.
        slot_dim: output features of ``slot_fc``.
        filter_sizes: output channel count of each conv layer, in order.
        kernel_size: kernel width shared by all conv layers.
        dropout: dropout probability applied before each head.
        wordvecs: optional matrix passed to ``nn.Embedding.from_pretrained``.
    """

    def __init__(self, input_dim, embedding_dim, intent_dim, slot_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()
        if wordvecs is not None:
            # NOTE(review): the first conv is sized from embedding_dim, so it must
            # equal wordvecs.shape[1] when pretrained vectors are supplied.
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size) for i in range(len(filter_sizes))]
        )
        self.intent_dropout = nn.Dropout(dropout)
        self.intent_fc = nn.Linear(filter_sizes[-1], intent_dim)
        self.slot_dropout = nn.Dropout(dropout)
        self.slot_fc = nn.Linear(filter_sizes[-1], slot_dim)
        self.padding = (kernel_size - 1) // 2  # left-side padding per conv
        self.unpruned_count = sum(filter_sizes)  # total filters before any pruning
        self.embedding_dim = embedding_dim
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size

    def forward(self, query):  # query shape: [batch, seq len] (per embedding lookup)
        """Return ``(intent_pred, slot_pred)``.

        ``intent_pred`` is [batch, intent_dim]; ``slot_pred`` is
        [batch, slot_dim, seq len].
        """
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        # Pad kernel_size - 1 columns in total so each conv keeps seq len; for even
        # kernels the extra column goes on the right.
        pad = (self.padding, self.kernel_size - 1 - self.padding)
        for conv in self.convs:
            # NOTE(review): torch.rrelu is called without training=True, so it uses a
            # fixed negative slope rather than a randomized one -- confirm intent.
            x = torch.rrelu(conv(F.pad(x, pad)))
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        intent_pred = self.intent_fc(self.intent_dropout(torch.max(x, dim=1)[0]))
        slot_pred = self.slot_fc(self.slot_dropout(x))
        return intent_pred, slot_pred.permute(0, 2, 1)

    def prune(self, count, norm=2):
        """Remove the ``count`` conv filters with the smallest weight norm.

        Filters are ranked globally across all conv layers by the ``norm``-norm
        of their flattened weights (ties keep layer/filter order). The matching
        input channels of the next conv -- or input features of both
        ``intent_fc`` and ``slot_fc`` for the last layer -- are removed too, and
        ``filter_sizes`` is updated. Every layer keeps at least one filter.

        Args:
            count: number of filters to remove; ``count <= 0`` is a no-op.
            norm: ``p`` passed to ``torch.norm`` when ranking filters.

        Raises:
            SystemExit: with status 0 when ``count`` filters cannot be removed
                while leaving at least one filter in every layer.
        """
        if count <= 0:
            return
        if sum(self.filter_sizes) - count < len(self.filter_sizes):
            # NOTE(review): ending the process from a model method is kept for
            # backward compatibility; an exception may suit callers better.
            raise SystemExit(0)
        removed = self._choose_filters(count, norm)
        in_keep = None  # None: every input channel of the current layer survives
        for layer, gone in enumerate(removed):
            conv = self.convs[layer]
            out_keep = [i for i in range(conv.out_channels) if i not in gone]
            if gone or in_keep is not None:
                self.convs[layer] = self._sliced_conv(conv, out_keep, in_keep)
            in_keep = out_keep if gone else None
        if in_keep is not None:  # last conv lost filters -> drop head input features
            self.intent_fc = self._sliced_linear(self.intent_fc, in_keep)
            self.slot_fc = self._sliced_linear(self.slot_fc, in_keep)
        self.filter_sizes = tuple(size - len(gone) for size, gone in zip(self.filter_sizes, removed))

    def _choose_filters(self, count, norm):
        """Return, per conv layer, the set of filter indices to remove.

        All targets are chosen before any tensor is modified, so indices always
        refer to the current (pre-prune) layout.
        """
        rankings = []  # (norm value, layer index, filter index)
        for layer, conv in enumerate(self.convs):
            norms = torch.norm(conv.weight.detach().flatten(start_dim=1), p=norm, dim=1)
            rankings.extend((value, layer, idx) for idx, value in enumerate(norms.tolist()))
        rankings.sort(key=lambda r: r[0])  # stable: ties stay in layer/filter order
        remaining = [conv.out_channels for conv in self.convs]
        removed = [set() for _ in self.convs]
        for _, layer, idx in rankings:
            if count == 0:
                break
            if remaining[layer] > 1:  # never empty a layer
                removed[layer].add(idx)
                remaining[layer] -= 1
                count -= 1
        return removed

    def _sliced_conv(self, conv, out_keep, in_keep):
        """Return a new Conv1d with ``conv``'s weights restricted to kept channels.

        ``out_keep`` lists surviving filters; ``in_keep`` lists surviving input
        channels, or is None to keep all of them.
        """
        with torch.no_grad():
            weight = conv.weight[out_keep]
            if in_keep is not None:
                weight = weight[:, in_keep]
            bias = conv.bias[out_keep]
        new_conv = nn.Conv1d(weight.shape[1], weight.shape[0], self.kernel_size)
        new_conv.weight = nn.Parameter(weight)
        new_conv.bias = nn.Parameter(bias)
        return new_conv

    @staticmethod
    def _sliced_linear(fc, in_keep):
        """Return a new Linear with ``fc``'s weights restricted to input features ``in_keep``."""
        with torch.no_grad():
            weight = fc.weight[:, in_keep]
            bias = fc.bias.clone()
        new_fc = nn.Linear(len(in_keep), fc.out_features)
        new_fc.weight = nn.Parameter(weight)
        new_fc.bias = nn.Parameter(bias)
        return new_fc