You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

59 lines
1.7 KiB

import time
import copy
import argparse
import torch
import torch.nn as nn
import models
import dataset
import util
def infer_from_name(name, choices):
    """Return the first entry of *choices* that occurs as a substring of *name*.

    Candidates are tried in the order given, so an earlier entry wins when
    several of them appear in *name*. Returns ``None`` when none match.
    """
    return next((choice for choice in choices if choice in name), None)


def main():
    """Evaluate a saved checkpoint on the test split and print its metric(s).

    The dataset ('atis' or 'snips') and the model kind ('intent', 'slot' or
    'joint') are both inferred from substrings of the ``--name`` argument,
    which is also the path handed to ``torch.load``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', required=True)
    args = parser.parse_args()

    # Fail with a usage message instead of an AttributeError further down
    # when the checkpoint name does not encode the dataset / model kind.
    args.dataset = infer_from_name(args.name, ('atis', 'snips'))
    if args.dataset is None:
        parser.error("--name must contain 'atis' or 'snips' to select the dataset")
    args.model = infer_from_name(args.name, ('intent', 'slot', 'joint'))
    if args.model is None:
        parser.error("--name must contain 'intent', 'slot' or 'joint' to select the model")

    # Evaluation only, so dropout is disabled when building the model.
    args.dropout = 0
    cuda = torch.cuda.is_available()

    # Only the test split is needed here; the train/valid splits are discarded.
    _, _, test, num_words, num_intent, num_slot, wordvecs = dataset.load(
        args.dataset, batch_size=8, seq_len=50)

    model = util.load_model(
        args.model, num_words, num_intent, num_slot, args.dropout, wordvecs)
    model.eval()
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the model is moved to the GPU below when one is available.
    model.load_state_dict(torch.load(args.name, map_location='cpu'))

    # Targets equal to -1 are excluded from the loss
    # (presumably the padding label - confirm against dataset.load).
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
    if cuda:
        model = model.cuda()

    # Report the number of trainable parameters.
    count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(count)

    if args.model == 'intent':
        _, test_acc = util.valid_intent(model, test, criterion, cuda)
        print(f"Test Acc: {test_acc:.5f}")
    elif args.model == 'slot':
        _, test_f1 = util.valid_slot(model, test, criterion, cuda)
        print(f"Test F1: {test_f1:.5f}")
    elif args.model == 'joint':
        # NOTE(review): the meaning of the trailing 0 argument is defined in
        # util.valid_joint, which is not visible here - confirm.
        _, (_, intent_test_acc), (_, slot_test_f1) = util.valid_joint(
            model, test, criterion, cuda, 0)
        print(f"Test Intent Acc: {intent_test_acc:.5f}, Slot F1: {slot_test_f1:.5f}")


if __name__ == "__main__":
    main()