You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

116 lines
4.9 KiB

4 years ago
import argparse
import copy
import os
import time

import torch

import dataset
import util
  7. if __name__ == "__main__":
  8. parser = argparse.ArgumentParser()
  9. parser.add_argument('--dataset', choices=['atis', 'snips'])
  10. parser.add_argument('--model', choices=['intent', 'slot', 'joint'])
  11. parser.add_argument('--name', default=None)
  12. parser.add_argument('--epochs', default=50)
  13. parser.add_argument('--seed', type=int, default=None)
  14. parser.add_argument('--patience', type=int, default=5)
  15. parser.add_argument('--dropout', type=float, default=0.5)
  16. parser.add_argument('--alpha', type=float, default=0.2)
  17. args = parser.parse_args()
  18. print(f"seed {util.rep(args.seed)}")
  19. cuda = torch.cuda.is_available()
  20. train, valid, test, num_words, num_intent, num_slot, wordvecs = dataset.load(args.dataset, batch_size=8, seq_len=50)
  21. model = util.load_model(args.model, num_words, num_intent, num_slot, args.dropout, wordvecs)
  22. criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
  23. optimizer = torch.optim.Adam(model.parameters())
  24. if cuda:
  25. model = model.cuda()
  26. best_valid_loss = float('inf')
  27. last_epoch_to_improve = 0
  28. best_model = model
  29. model_filename = f"models/{args.dataset}_{args.model}" if not args.name else args.name
  30. if args.model == 'intent':
  31. for epoch in range(0, args.epochs):
  32. start_time = time.time()
  33. train_loss, train_acc = util.train_intent(model, train, criterion, optimizer, cuda)
  34. valid_loss, valid_acc = util.valid_intent(model, valid, criterion, cuda)
  35. end_time = time.time()
  36. elapsed_time = end_time - start_time
  37. print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
  38. print(f"\tTrain Loss: {train_loss:.5f}, Acc: {train_acc:.5f}")
  39. print(f"\tValid Loss: {valid_loss:.5f}, Acc: {valid_acc:.5f}")
  40. if valid_loss < best_valid_loss:
  41. last_epoch_to_improve = epoch
  42. best_valid_loss = valid_loss
  43. best_model = copy.deepcopy(model)
  44. torch.save(model.state_dict(), model_filename)
  45. print("\tNew best valid loss!")
  46. if last_epoch_to_improve + args.patience < epoch:
  47. break
  48. _, test_acc = util.valid_intent(best_model, test, criterion, cuda)
  49. print(f"Test Acc: {test_acc:.5f}")
  50. elif args.model == 'slot':
  51. for epoch in range(0, args.epochs):
  52. start_time = time.time()
  53. train_loss, train_f1 = util.train_slot(model, train, criterion, optimizer, cuda)
  54. valid_loss, valid_f1 = util.valid_slot(model, valid, criterion, cuda)
  55. end_time = time.time()
  56. elapsed_time = end_time - start_time
  57. print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
  58. print(f"\tTrain Loss: {train_loss:.5f}, F1: {train_f1:.5f}")
  59. print(f"\tValid Loss: {valid_loss:.5f}, F1: {valid_f1:.5f}")
  60. if valid_loss < best_valid_loss:
  61. last_epoch_to_improve = epoch
  62. best_valid_loss = valid_loss
  63. best_model = copy.deepcopy(model)
  64. torch.save(model.state_dict(), model_filename)
  65. print("\tNew best valid loss!")
  66. if last_epoch_to_improve + args.patience < epoch:
  67. break
  68. _, test_f1 = util.valid_slot(best_model, test, criterion, cuda)
  69. print(f"Test F1: {test_f1:.5f}")
  70. elif args.model == 'joint':
  71. for epoch in range(0, args.epochs):
  72. start_time = time.time()
  73. train_loss, (intent_train_loss, intent_train_acc), (slot_train_loss, slot_train_f1) = util.train_joint(model, train, criterion, optimizer, cuda, args.alpha)
  74. valid_loss, (intent_valid_loss, intent_valid_acc), (slot_valid_loss, slot_valid_f1) = util.valid_joint(model, valid, criterion, cuda, args.alpha)
  75. end_time = time.time()
  76. elapsed_time = end_time - start_time
  77. print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
  78. print(f"\tTrain Loss: {train_loss:.5f}, (Intent Loss: {intent_train_loss:.5f}, Acc: {intent_train_acc:.5f}), (Slot Loss: {slot_train_loss:.5f}, F1: {slot_train_f1:.5f})")
  79. print(f"\tValid Loss: {valid_loss:.5f}, (Intent Loss: {intent_valid_loss:.5f}, Acc: {intent_valid_acc:.5f}), (Slot Loss: {slot_valid_loss:.5f}, F1: {slot_valid_f1:.5f})")
  80. if valid_loss < best_valid_loss:
  81. last_epoch_to_improve = epoch
  82. best_valid_loss = valid_loss
  83. best_model = copy.deepcopy(model)
  84. torch.save(model.state_dict(), model_filename)
  85. print("\tNew best valid loss!")
  86. if last_epoch_to_improve + args.patience < epoch:
  87. break
  88. _, (_, intent_test_acc), (_, slot_test_f1) = util.valid_joint(best_model, test, criterion, cuda, args.alpha)
  89. print(f"Test Intent Acc: {intent_test_acc:.5f}, Slot F1: {slot_test_f1:.5f}")