You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

59 lines
1.7 KiB

4 years ago
  1. import time
  2. import copy
  3. import argparse
  4. import torch
  5. import torch.nn as nn
  6. import models
  7. import dataset
  8. import util
  9. if __name__ == "__main__":
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--name', required=True)
  12. args = parser.parse_args()
  13. if 'atis' in args.name:
  14. args.dataset = 'atis'
  15. elif 'snips' in args.name:
  16. args.dataset = 'snips'
  17. if 'intent' in args.name:
  18. args.model = 'intent'
  19. elif 'slot' in args.name:
  20. args.model = 'slot'
  21. elif 'joint' in args.name:
  22. args.model = 'joint'
  23. args.dropout = 0
  24. cuda = torch.cuda.is_available()
  25. train, valid, test, num_words, num_intent, num_slot, wordvecs = dataset.load(args.dataset, batch_size=8, seq_len=50)
  26. model = util.load_model(args.model, num_words, num_intent, num_slot, args.dropout, wordvecs)
  27. model.eval()
  28. model.load_state_dict(torch.load(args.name))
  29. criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
  30. if cuda:
  31. model = model.cuda()
  32. best_valid_loss = float('inf')
  33. last_epoch_to_improve = 0
  34. best_model = model
  35. count = sum(p.numel() for p in model.parameters() if p.requires_grad)
  36. print(count)
  37. if args.model == 'intent':
  38. _, test_acc = util.valid_intent(model, test, criterion, cuda)
  39. print(f"Test Acc: {test_acc:.5f}")
  40. elif args.model == 'slot':
  41. _, test_f1 = util.valid_slot(model, test, criterion, cuda)
  42. print(f"Test F1: {test_f1:.5f}")
  43. elif args.model == 'joint':
  44. _, (_, intent_test_acc), (_, slot_test_f1) = util.valid_joint(model, test, criterion, cuda, 0)
  45. print(f"Test Intent Acc: {intent_test_acc:.5f}, Slot F1: {slot_test_f1:.5f}")