You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

103 lines
3.7 KiB

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import models
  5. import pickle
  6. import time
  7. start_time = time.time()
  8. PAD = "<pad>"
  9. BOS = "<bos>"
  10. EOS = "<eos>"
  11. word2idx = pickle.load(open("word2idx.pkl", "rb"))
  12. wordvecs = pickle.load(open("wordvecs.pkl", "rb"))
  13. slots = pickle.load(open("slots.pkl", "rb"))
  14. slot_filters = pickle.load(open("slot_filters.pkl", "rb"))
  15. intents = pickle.load(open("intents.pkl", "rb"))
  16. num_words = len(word2idx)
  17. num_intent = len(intents)
  18. num_slot = len(slots)
  19. filter_count = 300
  20. dropout = 0
  21. embedding_dim = 100
  22. def pad_query(sequence):
  23. sequence = [word2idx[BOS]] + sequence + [word2idx[EOS]]
  24. sequence = sequence[:50]
  25. sequence = np.pad(sequence, (0, 50 - len(sequence)), mode='constant', constant_values=(word2idx[PAD],))
  26. return sequence
  27. def predict(query):
  28. q = query.lower().replace("'", " ").replace("?", " ").strip()
  29. true_length = [len(q.split())]
  30. qq = torch.from_numpy(pad_query([word2idx[word] if word in word2idx else word2idx["<pad>"] for word in q.split()]))
  31. model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot, (filter_count,), 5, dropout, wordvecs)
  32. model.eval()
  33. model.load_state_dict(torch.load('snips_joint', map_location=torch.device('cpu')))
  34. batch = torch.stack([qq])
  35. pred_intent, pred_slots = model(batch)
  36. itnt = pred_intent.max(1)[1].tolist()[0]
  37. out_intent = intents[itnt]
  38. if out_intent in slot_filters:
  39. b = [1 if x in slot_filters[out_intent] else 0 for x in slots]
  40. zz = torch.stack([torch.FloatTensor([b]).repeat(50,1).transpose(0,1)])
  41. pred_slots = torch.mul(pred_slots, zz)
  42. slt = [str(item) for batch_num, sublist in enumerate(pred_slots.max(1)[1].tolist()) for item in sublist[1:true_length[batch_num] + 1]]
  43. out_slots = [slots[int(c)] for c in slt]
  44. print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots))
  45. print("--- %s seconds ---" % (time.time() - start_time))
  46. # Write to output file
  47. out = ""
  48. collected_slots = {}
  49. active_slot_words = []
  50. active_slot_name = None
  51. for words, slot_preds, intent_pred in zip([q.split()], [out_slots], [out_intent]):
  52. line = ""
  53. for word, pred in zip(words, slot_preds):
  54. line = line + word + " "
  55. if pred == 'O':
  56. if active_slot_name:
  57. collected_slots[active_slot_name] = " ".join(active_slot_words)
  58. active_slot_words = []
  59. active_slot_name = None
  60. else:
  61. # Naive BIO handling: treat B- and I- the same...
  62. new_slot_name = pred[2:]
  63. if active_slot_name is None:
  64. active_slot_words.append(word)
  65. active_slot_name = new_slot_name
  66. elif new_slot_name == active_slot_name:
  67. active_slot_words.append(word)
  68. else:
  69. collected_slots[active_slot_name] = " ".join(active_slot_words)
  70. active_slot_words = [word]
  71. active_slot_name = new_slot_name
  72. out = line.strip()
  73. if active_slot_name:
  74. collected_slots[active_slot_name] = " ".join(active_slot_words)
  75. print(collected_slots)
  76. print("--- %s seconds ---" % (time.time() - start_time))
  77. predict("What's the weather like in York PA right now?")
  78. predict("How's the weather in York PA right now?")
  79. predict("What's the weather like in Great Mills right now?")
  80. predict("What will the weather be like in Frederick Maryland tomorrow?")
  81. predict("Play some jazz")
  82. predict("Play some daft punk")
  83. predict("Play some hatsune miku")