You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

104 lines
3.7 KiB

import functools
import pickle
import time

import numpy as np
import torch
import torch.nn as nn

import models
# Wall-clock reference for the "--- N seconds ---" timing lines printed later.
start_time = time.time()

# Special vocabulary tokens looked up in ``word2idx``.
PAD = "<pad>"
BOS = "<bos>"
EOS = "<eos>"


def _load_pickle(path):
    """Unpickle the object stored at ``path`` and close the file afterwards."""
    # NOTE(review): pickle can execute arbitrary code while loading; these
    # artifacts must come from a trusted source.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Token -> integer index mapping (queried with ``in`` and ``[]`` below).
word2idx = _load_pickle("word2idx.pkl")
# Passed straight to ``models.CNNJoint``; presumably pretrained word
# embeddings -- TODO confirm against the model definition.
wordvecs = _load_pickle("wordvecs.pkl")
# Slot label names, indexed by the predicted slot id.
slots = _load_pickle("slots.pkl")
# Maps an intent name to the slot labels permitted for that intent.
slot_filters = _load_pickle("slot_filters.pkl")
# Intent names, indexed by the predicted intent id.
intents = _load_pickle("intents.pkl")

# Sizes derived from the loaded artifacts.
num_words = len(word2idx)
num_intent = len(intents)
num_slot = len(slots)

# Constructor arguments forwarded to ``models.CNNJoint``.
filter_count = 300
dropout = 0
embedding_dim = 100
def pad_query(sequence, max_len=50):
    """Wrap token ids in BOS/EOS and force the result to a fixed length.

    Args:
        sequence: list of integer token ids for one query (without BOS/EOS).
        max_len: length of the returned array. Defaults to 50, the length
            that was previously hard-coded.

    Returns:
        A 1-D ``numpy`` int64 array of exactly ``max_len`` ids: BOS, the
        tokens, EOS, then PAD ids as filler. Sequences that are too long are
        truncated on the right, which can drop the EOS marker.
    """
    ids = [word2idx[BOS]] + list(sequence) + [word2idx[EOS]]
    ids = ids[:max_len]
    # Explicit int64 keeps the index dtype independent of the platform's
    # default integer width.
    return np.pad(
        np.asarray(ids, dtype=np.int64),
        (0, max_len - len(ids)),
        mode='constant',
        constant_values=(word2idx[PAD],),
    )
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the CNNJoint network once and load the 'snips_joint' weights on CPU."""
    model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot, (filter_count,), 5, dropout, wordvecs)
    model.eval()
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load('snips_joint', map_location=torch.device('cpu')))
    return model


def _collect_slots(words, slot_labels):
    """Group consecutive words that share a slot name into ``{name: "w1 w2"}``."""
    collected = {}
    active_words = []
    active_name = None
    for word, label in zip(words, slot_labels):
        if label == 'O':
            # An outside tag closes whatever slot was being accumulated.
            if active_name:
                collected[active_name] = " ".join(active_words)
                active_words = []
                active_name = None
            continue
        # Naive BIO handling: the two-character "B-"/"I-" prefix is dropped
        # and both tag kinds are treated the same.
        name = label[2:]
        if active_name is None or name == active_name:
            active_words.append(word)
            active_name = name
        else:
            # A different slot starts: flush the previous one first.
            collected[active_name] = " ".join(active_words)
            active_words = [word]
            active_name = name
    if active_name:
        collected[active_name] = " ".join(active_words)
    return collected


def predict(query):
    """Print the predicted intent and slots for one natural-language query.

    The query is lower-cased, apostrophes and question marks are replaced by
    spaces, and the text is split on whitespace. Words missing from
    ``word2idx`` are mapped to the PAD id. The model is built and its weights
    are loaded on the first call only; later calls reuse the cached instance.

    Args:
        query: the raw query string.

    Returns:
        None. The input, intent, per-word slot labels, the grouped slot
        dictionary and two elapsed-time lines are written to stdout.
    """
    q = query.lower().replace("'", " ").replace("?", " ").strip()
    words = q.split()
    token_ids = [word2idx[word] if word in word2idx else word2idx["<pad>"] for word in words]
    batch = torch.stack([torch.from_numpy(pad_query(token_ids))])

    model = _load_model()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pred_intent, pred_slots = model(batch)

    out_intent = intents[pred_intent.max(1)[1].tolist()[0]]
    if out_intent in slot_filters:
        allowed = slot_filters[out_intent]
        flags = [1 if x in allowed else 0 for x in slots]
        # Expand the per-slot 0/1 flags to (1, num_slot, 50) so they can be
        # multiplied element-wise with the slot scores.
        # NOTE(review): multiplying zeroes the disallowed scores instead of
        # pushing them to -inf; if the model can emit negative scores a
        # disallowed slot could still win the max -- confirm the output range.
        mask = torch.stack([torch.FloatTensor([flags]).repeat(50, 1).transpose(0, 1)])
        pred_slots = torch.mul(pred_slots, mask)

    # Position 0 holds the BOS token added by pad_query, so the labels for
    # the actual words start at position 1.
    slot_ids = pred_slots.max(1)[1].tolist()[0]
    out_slots = [slots[slot_id] for slot_id in slot_ids[1:len(words) + 1]]

    print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots))
    print("--- %s seconds ---" % (time.time() - start_time))

    # Merge the per-word labels into slot name -> phrase and report them.
    collected_slots = _collect_slots(words, out_slots)
    print(collected_slots)
    print("--- %s seconds ---" % (time.time() - start_time))
# Sample queries run at import time, in this order, to exercise predict().
for sample_query in (
    "What's the weather like in York PA right now?",
    "How's the weather in York PA right now?",
    "What's the weather like in Great Mills right now?",
    "What will the weather be like in Frederick Maryland tomorrow?",
    "Play some jazz",
    "Play some daft punk",
    "Play some hatsune miku",
):
    predict(sample_query)