You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

59 lines
1.9 KiB

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import models
  5. import pickle
  6. import time
  7. start_time = time.time()
  8. PAD = "<pad>"
  9. BOS = "<bos>"
  10. EOS = "<eos>"
  11. word2idx = pickle.load(open("word2idx.pkl", "rb"))
  12. wordvecs = pickle.load(open("wordvecs.pkl", "rb"))
  13. slots = pickle.load(open("slots.pkl", "rb"))
  14. intents = pickle.load(open("intents.pkl", "rb"))
  15. num_words = len(word2idx)
  16. num_intent = 7
  17. num_slot = 72
  18. filter_count = 300
  19. dropout = 0
  20. embedding_dim = 100
  21. def pad_query(sequence):
  22. sequence = [word2idx[BOS]] + sequence + [word2idx[EOS]]
  23. sequence = sequence[:50]
  24. sequence = np.pad(sequence, (0, 50 - len(sequence)), mode='constant', constant_values=(word2idx[PAD],))
  25. return sequence
  26. query = "What's the weather like in Great Mills right now?"
  27. q = query.lower().replace("'", " ").replace("?", " ").strip()
  28. true_length = [len(q.split()), 0, 0, 0, 0, 0, 0 ,0]
  29. qq = torch.from_numpy(pad_query([word2idx[word] for word in q.split()]))
  30. model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot, (filter_count,), 5, dropout, wordvecs)
  31. model.eval()
  32. model.load_state_dict(torch.load('snips_joint', map_location=torch.device('cpu')))
  33. criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
  34. pad_tensor = torch.from_numpy(pad_query([word2idx[w] for w in []]))
  35. batch = torch.stack([qq, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor])
  36. pred_intent, pred_slots = model(batch)
  37. slt = [str(item) for batch_num, sublist in enumerate(pred_slots.max(1)[1].tolist()) for item in sublist[1:true_length[batch_num] + 1]]
  38. out_slots = [slots[int(c)] for c in slt]
  39. itnt = pred_intent.max(1)[1].tolist()[0]
  40. out_intent = intents[itnt]
  41. print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots))
  42. print("--- %s seconds ---" % (time.time() - start_time))