"""Knowledge-distillation sweep.

Trains CNN student models with progressively fewer convolutional filters
against a saved teacher checkpoint and records the resulting test metrics,
one row per student size, into an output file.
"""
import time
import copy
import argparse
import torch
import torch.nn as nn
import dataset
import util
import models
from itertools import chain
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', required=True)       # path to the teacher checkpoint; also encodes dataset/task
    parser.add_argument('--filename', required=True)   # output CSV-ish results file (overwritten)
    # FIX: was untyped — a value given on the command line would arrive as a
    # str and crash `range(args.epochs)` below.
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--patience', type=int, default=5)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--alpha', type=float, default=0.2)
    args = parser.parse_args()

    # Infer dataset and task from the teacher checkpoint's filename.
    # Fail fast instead of the original's late AttributeError on no match.
    if 'atis' in args.name:
        args.dataset = 'atis'
    elif 'snips' in args.name:
        args.dataset = 'snips'
    else:
        parser.error("--name must contain 'atis' or 'snips'")
    if 'intent' in args.name:
        args.model = 'intent'
    elif 'slot' in args.name:
        args.model = 'slot'
    elif 'joint' in args.name:
        args.model = 'joint'
    else:
        parser.error("--name must contain 'intent', 'slot' or 'joint'")

    print(f"seed {util.rep(args.seed)}")
    cuda = torch.cuda.is_available()
    train, valid, test, num_words, num_intent, num_slot, wordvecs = dataset.load(args.dataset, batch_size=8, seq_len=50)

    # FIX: single 'w' open both truncates previous results and, via the
    # context manager, guarantees the handle is closed (the original opened
    # the file twice and never closed the append handle).
    with open(args.filename, 'w') as f:
        # Sweep student capacity: coarse steps 300→15, then fine steps 10→1.
        for filter_count in chain(range(300, 10, -5), range(10, 0, -1)):
            if args.model == 'intent':
                model = models.CNNIntent(num_words, 100, num_intent, (filter_count,), 5, args.dropout, wordvecs)
            elif args.model == 'slot':
                model = models.CNNSlot(num_words, 100, num_slot, (filter_count,), 5, args.dropout, wordvecs)
            elif args.model == 'joint':
                model = models.CNNJoint(num_words, 100, num_intent, num_slot, (filter_count,), 5, args.dropout, wordvecs)

            # Teacher is (re)loaded each iteration; kept as-is in case
            # distillation mutates its state (e.g. train/eval mode).
            teacher = util.load_model(args.model, num_words, num_intent, num_slot, args.dropout, wordvecs)
            teacher.load_state_dict(torch.load(args.name))

            criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)  # -1 marks padding positions
            distill_criterion = nn.KLDivLoss(reduction='batchmean')
            optimizer = torch.optim.Adam(model.parameters())
            if cuda:
                model = model.cuda()
                teacher = teacher.cuda()

            # Early-stopping bookkeeping: keep the deep copy with the lowest
            # validation loss; stop after `patience` epochs without improvement.
            best_valid_loss = float('inf')
            last_epoch_to_improve = 0
            best_model = model

            if args.model == 'intent':
                for epoch in range(args.epochs):
                    start_time = time.time()
                    train_loss, train_acc = util.distill_intent(teacher, model, 1.0, train, distill_criterion, optimizer, cuda)
                    valid_loss, valid_acc = util.valid_intent(model, valid, criterion, cuda)
                    elapsed_time = time.time() - start_time
                    print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
                    print(f"\tTrain Loss: {train_loss:.5f}, Acc: {train_acc:.5f}")
                    print(f"\tValid Loss: {valid_loss:.5f}, Acc: {valid_acc:.5f}")
                    if valid_loss < best_valid_loss:
                        last_epoch_to_improve = epoch
                        best_valid_loss = valid_loss
                        best_model = copy.deepcopy(model)
                        print("\tNew best valid loss!")
                    if last_epoch_to_improve + args.patience < epoch:
                        break
                _, test_acc = util.valid_intent(best_model, test, criterion, cuda)
                print(f"Test Acc: {test_acc:.5f}")
                print(f"{sum(best_model.filter_sizes)}, {test_acc:.5f}", file=f, flush=True)
            elif args.model == 'slot':
                for epoch in range(args.epochs):
                    start_time = time.time()
                    train_loss, train_f1 = util.distill_slot(teacher, model, 1.0, train, distill_criterion, optimizer, cuda)
                    valid_loss, valid_f1 = util.valid_slot(model, valid, criterion, cuda)
                    elapsed_time = time.time() - start_time
                    print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
                    print(f"\tTrain Loss: {train_loss:.5f}, F1: {train_f1:.5f}")
                    print(f"\tValid Loss: {valid_loss:.5f}, F1: {valid_f1:.5f}")
                    if valid_loss < best_valid_loss:
                        last_epoch_to_improve = epoch
                        best_valid_loss = valid_loss
                        best_model = copy.deepcopy(model)
                        print("\tNew best valid loss!")
                    if last_epoch_to_improve + args.patience < epoch:
                        break
                _, test_f1 = util.valid_slot(best_model, test, criterion, cuda)
                print(f"Test F1: {test_f1:.5f}")
                print(f"{sum(best_model.filter_sizes)}, {test_f1:.5f}", file=f, flush=True)
            elif args.model == 'joint':
                for epoch in range(args.epochs):
                    start_time = time.time()
                    train_loss, (intent_train_loss, intent_train_acc), (slot_train_loss, slot_train_f1) = util.distill_joint(teacher, model, 1.0, train, distill_criterion, optimizer, cuda, args.alpha)
                    valid_loss, (intent_valid_loss, intent_valid_acc), (slot_valid_loss, slot_valid_f1) = util.valid_joint(model, valid, criterion, cuda, args.alpha)
                    elapsed_time = time.time() - start_time
                    print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
                    print(f"\tTrain Loss: {train_loss:.5f}, (Intent Loss: {intent_train_loss:.5f}, Acc: {intent_train_acc:.5f}), (Slot Loss: {slot_train_loss:.5f}, F1: {slot_train_f1:.5f})")
                    print(f"\tValid Loss: {valid_loss:.5f}, (Intent Loss: {intent_valid_loss:.5f}, Acc: {intent_valid_acc:.5f}), (Slot Loss: {slot_valid_loss:.5f}, F1: {slot_valid_f1:.5f})")
                    if valid_loss < best_valid_loss:
                        last_epoch_to_improve = epoch
                        best_valid_loss = valid_loss
                        best_model = copy.deepcopy(model)
                        print("\tNew best valid loss!")
                    if last_epoch_to_improve + args.patience < epoch:
                        break
                _, (_, intent_test_acc), (_, slot_test_f1) = util.valid_joint(best_model, test, criterion, cuda, args.alpha)
                print(f"Test Intent Acc: {intent_test_acc:.5f}, Slot F1: {slot_test_f1:.5f}")
                print(f"{sum(best_model.filter_sizes)}, {intent_test_acc:.5f}, {slot_test_f1:.5f}", file=f, flush=True)