import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import accuracy_score
from seqeval.metrics import f1_score

import models

kernel_size = 5
def load_model(model_name, num_words, num_intent, num_slot, dropout, wordvecs=None, embedding_dim=100, filter_count=300):
    # Build the requested CNN model: intent-only, slot-only, or joint intent + slot.
    if model_name == 'intent':
        model = models.CNNIntent(num_words, embedding_dim, num_intent, (filter_count,), kernel_size, dropout, wordvecs)
    elif model_name == 'slot':
        model = models.CNNSlot(num_words, embedding_dim, num_slot, (filter_count,), kernel_size, dropout, wordvecs)
    elif model_name == 'joint':
        model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot, (filter_count,), kernel_size, dropout, wordvecs)
    else:
        raise ValueError(f"unknown model_name: {model_name}")
    return model
def rep(seed=None):
    # Seed all RNGs for reproducibility and return the seed that was used.
    if seed is None:
        seed = random.randint(0, 10000)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    # CUDA: make cuDNN deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return seed
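
# --- Usage sketch (illustrative, not part of the original training scripts) ---
# A minimal example of how rep() and load_model() might be combined. The vocabulary
# and label counts below are placeholders, and wordvecs would normally come from
# pretrained embeddings prepared elsewhere in the project.
def _example_setup():
    seed = rep(1234)                     # fix all RNG seeds and return the seed used
    model = load_model('joint',
                       num_words=10000,  # placeholder vocabulary size
                       num_intent=21,    # placeholder intent label count
                       num_slot=120,     # placeholder slot label count
                       dropout=0.5)
    return seed, model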
def train_intent(model, iter, criterion, optimizer, cuda):
    # One training epoch for the intent classifier; returns (mean loss, accuracy).
    model.train()
    epoch_loss = 0
    true_intents = []
    pred_intents = []
    for i, batch in enumerate(iter):
        optimizer.zero_grad()
        query = batch[0]
        true_intent = batch[1]
        if cuda:
            query = query.cuda()
            true_intent = true_intent.cuda()
        pred_intent = model(query)
        true_intents += true_intent.tolist()
        pred_intents += pred_intent.max(1)[1].tolist()
        loss = criterion(pred_intent, true_intent)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iter), accuracy_score(true_intents, pred_intents)
def distill_intent(teacher, student, temperature, iter, criterion, optimizer, cuda):
    # One knowledge-distillation epoch: the student is trained to match the
    # teacher's temperature-softened intent distribution.
    teacher.eval()
    student.train()
    true_intents = []
    pred_intents = []
    epoch_loss = 0
    for i, batch in enumerate(iter):
        optimizer.zero_grad()
        query = batch[0]
        true_intent = batch[1]
        if cuda:
            query = query.cuda()
            true_intent = true_intent.cuda()
        # Teacher logits are targets only, so no gradient is needed for them.
        with torch.no_grad():
            teacher_pred_intent = teacher(query)
        student_pred_intent = student(query)
        true_intents += true_intent.tolist()
        pred_intents += student_pred_intent.max(1)[1].tolist()
        loss = criterion(F.log_softmax(student_pred_intent / temperature, dim=-1),
                         F.softmax(teacher_pred_intent / temperature, dim=-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iter), accuracy_score(true_intents, pred_intents)
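
# --- Distillation criterion sketch (illustrative) ---
# distill_intent() feeds log-probabilities from the student and probabilities from
# the teacher into `criterion`, so a KL-divergence loss is the natural fit. This is
# only a sketch of how the criterion might be constructed; the optimizer choice and
# learning rate are assumptions, not values taken from the original project.
def _example_distill_intent_setup(student):
    criterion = nn.KLDivLoss(reduction='batchmean')   # expects log-probs vs. probs
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
    return criterion, optimizer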
def valid_intent(model, iter, criterion, cuda):
    model.eval()
    epoch_loss = 0
    true_intents = []
    pred_intents = []
    for i, batch in enumerate(iter):
        query = batch[0]
        true_intent = batch[1]
        if cuda:
            query = query.cuda()
            true_intent = true_intent.cuda()
        pred_intent = model(query)
        true_intents += true_intent.tolist()
        pred_intents += pred_intent.max(1)[1].tolist()
        loss = criterion(pred_intent, true_intent)
        epoch_loss += loss.item()
    return epoch_loss / len(iter), accuracy_score(true_intents, pred_intents)
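
# --- Epoch loop sketch (illustrative) ---
# One way the intent helpers above might be driven. `train_iter` and `valid_iter`
# are assumed to be DataLoaders yielding (query, intent, slots, length) batches,
# matching the batch indexing used throughout this file; the epoch count and
# optimizer are assumptions.
def _example_intent_training(model, train_iter, valid_iter, cuda=False, epochs=10):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    for epoch in range(epochs):
        train_loss, train_acc = train_intent(model, train_iter, criterion, optimizer, cuda)
        valid_loss, valid_acc = valid_intent(model, valid_iter, criterion, cuda)
        print(f'epoch {epoch}: train loss {train_loss:.4f} acc {train_acc:.4f} | '
              f'valid loss {valid_loss:.4f} acc {valid_acc:.4f}')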
def train_slot(model, iter, criterion, optimizer, cuda):
    # One training epoch for the slot tagger; returns (mean loss, seqeval F1).
    model.train()
    epoch_loss = 0
    true_history = []
    pred_history = []
    for i, batch in enumerate(iter):
        optimizer.zero_grad()
        query = batch[0]
        true_slots = batch[2]
        true_length = batch[3]
        if cuda:
            query = query.cuda()
            true_slots = true_slots.cuda()
        pred_slots = model(query).permute(0, 2, 1)  # batch * slots * seq len
        # Collect gold and predicted slot labels, trimmed to each query's true length
        # (index 0 is skipped, presumably a leading special token).
        true_history += [str(item) for batch_num, sublist in enumerate(true_slots.tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        pred_history += [str(item) for batch_num, sublist in enumerate(pred_slots.max(1)[1].tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        loss = criterion(pred_slots, true_slots)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iter), f1_score(true_history, pred_history)
def distill_slot(teacher, student, temperature, iter, criterion, optimizer, cuda):
    # One knowledge-distillation epoch for the slot tagger.
    teacher.eval()
    student.train()
    true_history = []
    pred_history = []
    epoch_loss = 0
    for i, batch in enumerate(iter):
        optimizer.zero_grad()
        query = batch[0]
        true_slots = batch[2]
        true_length = batch[3]
        if cuda:
            query = query.cuda()
            true_slots = true_slots.cuda()
            true_length = true_length.cuda()
        with torch.no_grad():
            teacher_pred_slot = teacher(query).permute(0, 2, 1)  # batch * slot * seq len
        student_pred_slot = student(query).permute(0, 2, 1)
        true_history += [str(item) for batch_num, sublist in enumerate(true_slots.tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        pred_history += [str(item) for batch_num, sublist in enumerate(student_pred_slot.max(1)[1].tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        # Softmax over dim=1, the slot-label dimension after the permute above.
        loss = criterion(F.log_softmax(student_pred_slot / temperature, dim=1),
                         F.softmax(teacher_pred_slot / temperature, dim=1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iter), f1_score(true_history, pred_history)
def valid_slot(model, iter, criterion, cuda):
    model.eval()
    epoch_loss = 0
    true_history = []
    pred_history = []
    for i, batch in enumerate(iter):
        query = batch[0]
        true_slots = batch[2]
        true_length = batch[3]
        if cuda:
            query = query.cuda()
            true_slots = true_slots.cuda()
        pred_slots = model(query).permute(0, 2, 1)  # batch * slots * seq len
        true_history += [str(item) for batch_num, sublist in enumerate(true_slots.tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        pred_history += [str(item) for batch_num, sublist in enumerate(pred_slots.max(1)[1].tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        loss = criterion(pred_slots, true_slots)
        epoch_loss += loss.item()
    return epoch_loss / len(iter), f1_score(true_history, pred_history)
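
# --- Slot criterion sketch (illustrative) ---
# train_slot()/valid_slot() permute the model output to (batch, num_slots, seq_len),
# which is the layout nn.CrossEntropyLoss expects for per-token classification.
# Padding positions could be excluded with ignore_index; the pad label id below is
# an assumption about how the slot labels are encoded, not a value from the project.
def _example_slot_criterion(pad_label_id=0):
    return nn.CrossEntropyLoss(ignore_index=pad_label_id)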
def train_joint(model, iter, criterion, optimizer, cuda, alpha):
    # One training epoch for the joint model; alpha weights the intent loss against
    # the slot loss. Returns (loss, (intent loss, accuracy), (slot loss, F1)).
    model.train()
    epoch_loss = 0
    epoch_intent_loss = 0
    true_intents = []
    pred_intents = []
    epoch_slot_loss = 0
    true_history = []
    pred_history = []
    for i, batch in enumerate(iter):
        optimizer.zero_grad()
        query = batch[0]
        true_intent = batch[1]
        true_slots = batch[2]
        true_length = batch[3]
        if cuda:
            query = query.cuda()
            true_intent = true_intent.cuda()
            true_slots = true_slots.cuda()
            true_length = true_length.cuda()
        pred_intent, pred_slots = model(query)
        true_intents += true_intent.tolist()
        pred_intents += pred_intent.max(1)[1].tolist()
        intent_loss = criterion(pred_intent, true_intent)
        epoch_intent_loss += intent_loss.item()  # accumulate a scalar, not a tensor
        #pred_slots.permute(0, 2, 1)
        true_history += [str(item) for batch_num, sublist in enumerate(true_slots.tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        pred_history += [str(item) for batch_num, sublist in enumerate(pred_slots.max(1)[1].tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        slot_loss = criterion(pred_slots, true_slots)
        epoch_slot_loss += slot_loss.item()
        loss = alpha * intent_loss + (1 - alpha) * slot_loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return (epoch_loss / len(iter),
            (epoch_intent_loss / len(iter), accuracy_score(true_intents, pred_intents)),
            (epoch_slot_loss / len(iter), f1_score(true_history, pred_history)))
def distill_joint(teacher, student, temperature, iter, criterion, optimizer, cuda, alpha):
    # One knowledge-distillation epoch for the joint model.
    teacher.eval()
    student.train()
    epoch_loss = 0
    epoch_intent_loss = 0
    true_intents = []
    pred_intents = []
    epoch_slot_loss = 0
    true_history = []
    pred_history = []
    for i, batch in enumerate(iter):
        optimizer.zero_grad()
        query = batch[0]
        true_intent = batch[1]
        true_slots = batch[2]
        true_length = batch[3]
        if cuda:
            query = query.cuda()
            true_intent = true_intent.cuda()
            true_slots = true_slots.cuda()
            true_length = true_length.cuda()
        with torch.no_grad():
            teacher_pred_intent, teacher_pred_slot = teacher(query)
        student_pred_intent, student_pred_slot = student(query)
        true_intents += true_intent.tolist()
        pred_intents += student_pred_intent.max(1)[1].tolist()
        intent_loss = criterion(F.log_softmax(student_pred_intent / temperature, dim=-1),
                                F.softmax(teacher_pred_intent / temperature, dim=-1))
        epoch_intent_loss += intent_loss.item()
        true_history += [str(item) for batch_num, sublist in enumerate(true_slots.tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        pred_history += [str(item) for batch_num, sublist in enumerate(student_pred_slot.max(1)[1].tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        slot_loss = criterion(F.log_softmax(student_pred_slot / temperature, dim=1),
                              F.softmax(teacher_pred_slot / temperature, dim=1))
        epoch_slot_loss += slot_loss.item()
        loss = alpha * intent_loss + (1 - alpha) * slot_loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return (epoch_loss / len(iter),
            (epoch_intent_loss / len(iter), accuracy_score(true_intents, pred_intents)),
            (epoch_slot_loss / len(iter), f1_score(true_history, pred_history)))
def valid_joint(model, iter, criterion, cuda, alpha):
    model.eval()
    epoch_loss = 0
    epoch_intent_loss = 0
    true_intents = []
    pred_intents = []
    epoch_slot_loss = 0
    true_history = []
    pred_history = []
    for i, batch in enumerate(iter):
        query = batch[0]
        true_intent = batch[1]
        true_slots = batch[2]
        true_length = batch[3]
        if cuda:
            query = query.cuda()
            true_intent = true_intent.cuda()
            true_slots = true_slots.cuda()
            true_length = true_length.cuda()
        pred_intent, pred_slots = model(query)
        true_intents += true_intent.tolist()
        pred_intents += pred_intent.max(1)[1].tolist()
        intent_loss = criterion(pred_intent, true_intent)
        epoch_intent_loss += intent_loss.item()
        #pred_slots.permute(0, 2, 1)
        true_history += [str(item) for batch_num, sublist in enumerate(true_slots.tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        pred_history += [str(item) for batch_num, sublist in enumerate(pred_slots.max(1)[1].tolist())
                         for item in sublist[1:true_length[batch_num].item() + 1]]
        slot_loss = criterion(pred_slots, true_slots)
        epoch_slot_loss += slot_loss.item()
        loss = alpha * intent_loss + (1 - alpha) * slot_loss
        epoch_loss += loss.item()
    return (epoch_loss / len(iter),
            (epoch_intent_loss / len(iter), accuracy_score(true_intents, pred_intents)),
            (epoch_slot_loss / len(iter), f1_score(true_history, pred_history)))
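
# --- Joint training sketch (illustrative) ---
# A minimal example of driving train_joint()/valid_joint(). `alpha` weights the
# intent loss against the slot loss exactly as in train_joint(); the alpha value,
# epoch count, and optimizer settings below are assumptions, not values taken from
# the original project.
def _example_joint_training(model, train_iter, valid_iter, cuda=False, epochs=10, alpha=0.5):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    for epoch in range(epochs):
        train_loss, (_, train_acc), (_, train_f1) = train_joint(model, train_iter, criterion, optimizer, cuda, alpha)
        valid_loss, (_, valid_acc), (_, valid_f1) = valid_joint(model, valid_iter, criterion, cuda, alpha)
        print(f'epoch {epoch}: train loss {train_loss:.4f} (acc {train_acc:.4f}, f1 {train_f1:.4f}) | '
              f'valid loss {valid_loss:.4f} (acc {valid_acc:.4f}, f1 {valid_f1:.4f})')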