import functools
import pickle
import time

import numpy as np
import torch
import torch.nn as nn

import models
|
|
# Wall-clock reference used by predict() to print elapsed time since startup.
start_time = time.time()

# Special vocabulary tokens; each must be a key in word2idx.
PAD = "<pad>"
BOS = "<bos>"
EOS = "<eos>"


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code — only load trusted artifacts.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Pre-built artifacts; the file names are resolved relative to the working directory.
word2idx = _load_pickle("word2idx.pkl")  # token -> integer id
wordvecs = _load_pickle("wordvecs.pkl")  # passed to models.CNNJoint as embeddings
slots = _load_pickle("slots.pkl")  # indexed by predicted slot id
slot_filters = _load_pickle("slot_filters.pkl")  # intent -> allowed slot names
intents = _load_pickle("intents.pkl")  # indexed by predicted intent id

num_words = len(word2idx)
num_intent = len(intents)
num_slot = len(slots)

# Model hyper-parameters. filter_count and embedding_dim feed the CNNJoint
# constructor, so presumably they must match the shapes stored in the
# "snips_joint" checkpoint for load_state_dict to succeed — TODO confirm.
filter_count = 300
dropout = 0
embedding_dim = 100
|
|
|
|
|
|
def pad_query(sequence, max_len=50):
    """Wrap token ids with BOS/EOS markers and pad or truncate to a fixed length.

    Args:
        sequence: iterable of integer token ids (values from ``word2idx``).
        max_len: length of the returned array. Defaults to 50, which is the
            length the rest of this script was written for.

    Returns:
        A 1-D ``int64`` numpy array of exactly ``max_len`` ids, laid out as
        ``<bos>``, the input ids, ``<eos>``, and then ``<pad>`` ids filling
        the remainder. When the wrapped sequence is longer than ``max_len``,
        it is cut on the right. This drops ``<eos>`` and possibly trailing
        words.

    Raises:
        ValueError: if ``max_len`` is negative.
    """
    if max_len < 0:
        raise ValueError("max_len must be non-negative, got %r" % (max_len,))
    wrapped = [word2idx[BOS]] + list(sequence) + [word2idx[EOS]]
    wrapped = wrapped[:max_len]
    # Explicit int64 keeps torch.from_numpy producing a LongTensor on every
    # platform. numpy's default int is 32-bit on Windows (numpy < 2).
    return np.pad(
        np.asarray(wrapped, dtype=np.int64),
        (0, max_len - len(wrapped)),
        mode="constant",
        constant_values=word2idx[PAD],
    )
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the CNNJoint model, load the "snips_joint" weights on CPU, and cache it."""
    model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot,
                            (filter_count,), 5, dropout, wordvecs)
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load("snips_joint", map_location=torch.device("cpu")))
    model.eval()
    return model


def _collect_slots(words, tags):
    """Group consecutive words that share a slot name into ``{slot_name: text}``.

    ``"O"`` means no slot. Any other tag has its first two characters (the
    ``B-``/``I-`` prefix) stripped. B- and I- are treated alike, so adjacent
    words with the same slot name are merged into one span. A later span that
    reuses a slot name overwrites the earlier one.
    """
    collected = {}
    active_name = None
    active_words = []
    for word, tag in zip(words, tags):
        name = None if tag == "O" else tag[2:]
        if name != active_name:
            # The span ended (slot changed or went back to "O"); flush it.
            if active_name is not None:
                collected[active_name] = " ".join(active_words)
            active_name = name
            active_words = []
        if name is not None:
            active_words.append(word)
    if active_name is not None:
        collected[active_name] = " ".join(active_words)
    return collected


def predict(query):
    """Predict the intent and slot tags of *query* and print the results.

    The query is lower-cased, and apostrophes and question marks become
    spaces. It is then split on whitespace. Words missing from ``word2idx``
    are encoded as the ``<pad>`` id. When the predicted intent has an entry
    in ``slot_filters``, only the slots listed there can be predicted.

    Prints the input, intent and per-word slot tags, the grouped slot values,
    and the elapsed time since module start.

    Returns:
        ``(intent, collected_slots)``: the predicted intent label and a dict
        mapping each slot name to its text.
    """
    q = query.lower().replace("'", " ").replace("?", " ").strip()
    words = q.split()
    unk_id = word2idx[PAD]
    token_ids = [word2idx.get(word, unk_id) for word in words]
    batch = torch.from_numpy(pad_query(token_ids)).unsqueeze(0)  # add batch dim

    model = _load_model()
    with torch.no_grad():  # inference only; no autograd graph needed
        pred_intent, pred_slots = model(batch)

    out_intent = intents[pred_intent.max(1)[1][0].item()]

    if out_intent in slot_filters:
        allowed = set(slot_filters[out_intent])
        mask = torch.tensor([slot in allowed for slot in slots], dtype=torch.bool)
        # pred_slots is indexed (batch, slot, position); the .max(1) below
        # picks a slot per position. Setting disallowed slots to -inf
        # excludes them from the argmax whatever the score sign. Multiplying
        # by 0 would let them win when scores are logits or log-probs.
        pred_slots = pred_slots.masked_fill(~mask.view(1, -1, 1), float("-inf"))

    tag_ids = pred_slots.max(1)[1][0].tolist()
    # Position 0 holds <bos>, so word i sits at position i + 1.
    out_slots = [slots[i] for i in tag_ids[1:len(words) + 1]]

    print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots))
    print("--- %s seconds ---" % (time.time() - start_time))

    collected_slots = _collect_slots(words, out_slots)
    print(collected_slots)
    print("--- %s seconds ---" % (time.time() - start_time))
    return out_intent, collected_slots
|
|
|
|
# Demo run: classify a fixed set of sample queries, in order.
for demo_query in (
    "What's the weather like in York PA right now?",
    "How's the weather in York PA right now?",
    "What's the weather like in Great Mills right now?",
    "What will the weather be like in Frederick Maryland tomorrow?",
    "Play some jazz",
    "Play some daft punk",
    "Play some hatsune miku",
):
    predict(demo_query)
|