diff --git a/.gitignore b/.gitignore index 9574f00..2591f2d 100644 --- a/.gitignore +++ b/.gitignore @@ -121,4 +121,6 @@ dmypy.json # End of https://www.gitignore.io/api/python,visualstudiocode glove/ -venv/ \ No newline at end of file +venv/ +env/ +glove.6B.zip diff --git a/intents.pkl b/intents.pkl new file mode 100644 index 0000000..9f58d84 Binary files /dev/null and b/intents.pkl differ diff --git a/num_words.pkl b/num_words.pkl new file mode 100644 index 0000000..7b29d74 --- /dev/null +++ b/num_words.pkl @@ -0,0 +1 @@ +€Mi/. \ No newline at end of file diff --git a/slots.pkl b/slots.pkl new file mode 100644 index 0000000..f882f5b Binary files /dev/null and b/slots.pkl differ diff --git a/snips_joint b/snips_joint new file mode 100644 index 0000000..e6da05b Binary files /dev/null and b/snips_joint differ diff --git a/test_query.py b/test_query.py new file mode 100644 index 0000000..da8c839 --- /dev/null +++ b/test_query.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn + +import numpy as np + +import models + +import pickle + +import time +start_time = time.time() + +PAD = "" +BOS = "" +EOS = "" + +word2idx = pickle.load(open("word2idx.pkl", "rb")) +wordvecs = pickle.load(open("wordvecs.pkl", "rb")) +slots = pickle.load(open("slots.pkl", "rb")) +intents = pickle.load(open("intents.pkl", "rb")) +num_words = len(word2idx) +num_intent = 7 +num_slot = 72 +filter_count = 300 +dropout = 0 +embedding_dim = 100 + + +def pad_query(sequence): + sequence = [word2idx[BOS]] + sequence + [word2idx[EOS]] + sequence = sequence[:50] + sequence = np.pad(sequence, (0, 50 - len(sequence)), mode='constant', constant_values=(word2idx[PAD],)) + return sequence + + +query = "What's the weather like in Great Mills right now?" +q = query.lower().replace("'", " ").replace("?", " ").strip() +true_length = [len(q.split()), 0, 0, 0, 0, 0, 0 ,0] +qq = torch.from_numpy(pad_query([word2idx[word] for word in q.split()])) + +model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot, (filter_count,), 5, dropout, wordvecs) +model.eval() + +model.load_state_dict(torch.load('snips_joint', map_location=torch.device('cpu'))) +criterion = torch.nn.CrossEntropyLoss(ignore_index=-1) + +pad_tensor = torch.from_numpy(pad_query([word2idx[w] for w in []])) +batch = torch.stack([qq, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor]) + +pred_intent, pred_slots = model(batch) + +slt = [str(item) for batch_num, sublist in enumerate(pred_slots.max(1)[1].tolist()) for item in sublist[1:true_length[batch_num] + 1]] +out_slots = [slots[int(c)] for c in slt] + +itnt = pred_intent.max(1)[1].tolist()[0] +out_intent = intents[itnt] + +print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots)) + +print("--- %s seconds ---" % (time.time() - start_time)) \ No newline at end of file diff --git a/word2idx.pkl b/word2idx.pkl new file mode 100644 index 0000000..5b84799 Binary files /dev/null and b/word2idx.pkl differ diff --git a/wordvecs.pkl b/wordvecs.pkl new file mode 100644 index 0000000..d3e0758 Binary files /dev/null and b/wordvecs.pkl differ