diff --git a/test_query.py b/test_query.py index da8c839..e2febb7 100644 --- a/test_query.py +++ b/test_query.py @@ -33,9 +33,10 @@ def pad_query(sequence): return sequence -query = "What's the weather like in Great Mills right now?" +query = "What's the weather like in York PA right now?" q = query.lower().replace("'", " ").replace("?", " ").strip() -true_length = [len(q.split()), 0, 0, 0, 0, 0, 0 ,0] +# true_length = [len(q.split()), 0, 0, 0, 0, 0, 0 ,0] +true_length = [len(q.split())] qq = torch.from_numpy(pad_query([word2idx[word] for word in q.split()])) model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot, (filter_count,), 5, dropout, wordvecs) @@ -45,7 +46,8 @@ model.load_state_dict(torch.load('snips_joint', map_location=torch.device('cpu') criterion = torch.nn.CrossEntropyLoss(ignore_index=-1) pad_tensor = torch.from_numpy(pad_query([word2idx[w] for w in []])) -batch = torch.stack([qq, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor]) +# batch = torch.stack([qq, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor]) +batch = torch.stack([qq]) pred_intent, pred_slots = model(batch) @@ -57,4 +59,37 @@ out_intent = intents[itnt] print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots)) +print("--- %s seconds ---" % (time.time() - start_time)) + +# Write to output file +out = "" +collected_slots = {} +active_slot_words = [] +active_slot_name = None +for words, slot_preds, intent_pred in zip([q.split()], [out_slots], [out_intent]): + line = "" + for word, pred in zip(words, slot_preds): + line = line + word + " " + if pred == 'O': + if active_slot_name: + collected_slots[active_slot_name] = " ".join(active_slot_words) + active_slot_words = [] + active_slot_name = None + else: + # Naive BIO handling: treat B- and I- the same... + new_slot_name = pred[2:] + if active_slot_name is None: + active_slot_words.append(word) + active_slot_name = new_slot_name + elif new_slot_name == active_slot_name: + active_slot_words.append(word) + else: + collected_slots[active_slot_name] = " ".join(active_slot_words) + active_slot_words = [word] + active_slot_name = new_slot_name + out = line.strip() +if active_slot_name: + collected_slots[active_slot_name] = " ".join(active_slot_words) + +print(collected_slots) print("--- %s seconds ---" % (time.time() - start_time)) \ No newline at end of file