You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

133 lines
5.3 KiB

import collections
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
# Reserved vocabulary tokens; Corpus registers PAD, BOS, EOS (in that order)
# before any dataset word, so they occupy the lowest word ids.
PAD, BOS, EOS = "<pad>", "<bos>", "<eos>"
def build_glove(word2idx, idx2word, dim=100, path=None):
    """Build an embedding matrix for a vocabulary from a GloVe text file.

    Each line of the file is expected to be ``word v1 v2 ... v{dim}``
    (space separated). Row ``i`` of the result is the file's vector for
    ``idx2word[i]`` when that word appears in the file; otherwise it is drawn
    from a normal distribution with mean 0 and std 0.5 using NumPy's global
    random state (seed NumPy for reproducible results).

    Args:
        word2idx: mapping word -> index; only these words are looked up.
        idx2word: mapping index -> word covering indices ``0 .. len(idx2word)-1``.
        dim: embedding size; selects the default file and is checked against it.
        path: GloVe file to read. Defaults to ``glove/glove.6B.{dim}d.txt``.

    Returns:
        A float32 tensor of shape ``(len(idx2word), dim)``.

    Raises:
        ValueError: if a vocabulary word's vector in the file does not have
            exactly ``dim`` values.
    """
    if path is None:
        path = f'glove/glove.6B.{dim}d.txt'

    word2vecs = {}
    # Decode explicitly as UTF-8 rather than relying on the locale default.
    with open(path, encoding='utf-8') as glove_file:
        for line in glove_file:
            # Blank lines yield word == '' which is never a vocabulary token.
            word, *values = line.rstrip().split(' ')
            if word not in word2idx:
                continue
            if len(values) != dim:
                raise ValueError(
                    f'GloVe vector for {word!r} in {path} has {len(values)} '
                    f'values, expected {dim}'
                )
            word2vecs[word] = np.array(values, dtype=float)

    vectors = []
    for idx in range(len(idx2word)):
        word = idx2word[idx]
        if word in word2vecs:
            vectors.append(torch.from_numpy(word2vecs[word]).float())
        else:
            # Words missing from GloVe get a random vector (global NumPy RNG).
            vectors.append(torch.from_numpy(np.random.normal(0, 0.5, size=(dim,))).float())
    if not vectors:
        # torch.stack rejects an empty list; an empty vocabulary maps to a (0, dim) matrix.
        return torch.empty((0, dim))
    return torch.stack(vectors)
class Corpus(Dataset):
    """One split of an intent-detection / slot-filling dataset.

    Files are read from ``datasets/{dataset}/{split}/``:

    * ``label``   -- one intent label per line,
    * ``seq.in``  -- one whitespace-tokenised query per line,
    * ``seq.out`` -- one line of slot tags per query, one tag per word.

    The word, intent and slot vocabularies are built over the ``train``,
    ``valid`` and ``test`` splits together (in that order), so every
    ``Corpus`` of the same dataset gets identical index mappings. Only the
    examples of ``split_name`` are stored.
    """

    # Splits whose files feed the shared vocabularies; the order fixes the indices.
    _VOCAB_SPLITS = ('train', 'valid', 'test')

    def __init__(self, dataset, split_name, seq_len: int):
        """Load ``split_name`` of ``dataset``; sequences are padded/truncated to ``seq_len``.

        Raises:
            ValueError: if the split's label / seq.in / seq.out files disagree on
                the number of examples, or a query and its slot tags differ in length.
        """
        self.seq_len = seq_len
        self.word2idx = {}
        self.idx2word = {}
        self.intent2idx = {}
        self.slot2idx = {}
        # Special tokens are registered first, so PAD/BOS/EOS get word ids 0/1/2.
        for token in (PAD, BOS, EOS):
            self._register(token)

        loaded = {}
        for split in self._VOCAB_SPLITS:
            intents, queries, slotses = self._read_split(dataset, split)
            loaded[split] = (intents, queries, slotses)
            for intent in intents:
                # len() is evaluated before insertion, so new labels get the next id.
                self.intent2idx.setdefault(intent, len(self.intent2idx))
            for query in queries:
                for word in query:
                    self._register(word)
            for slots in slotses:
                for slot in slots:
                    self.slot2idx.setdefault(slot, len(self.slot2idx))

        # Reuse the files already read when split_name is one of the vocabulary splits.
        if split_name in loaded:
            intents, queries, slotses = loaded[split_name]
        else:
            intents, queries, slotses = self._read_split(dataset, split_name)
        self._check_aligned(split_name, intents, queries, slotses)
        self.intents = intents
        self.queries = queries
        self.slots = slotses

        self.idx2intent = {v: k for k, v in self.intent2idx.items()}
        self.idx2slot = {v: k for k, v in self.slot2idx.items()}

    @staticmethod
    def _read_split(dataset, split):
        """Return ``(intents, queries, slot_sequences)`` read from one split directory."""
        base = f'datasets/{dataset}/{split}'
        with open(f'{base}/label', encoding='utf-8') as intent_file:
            intents = [line.rstrip() for line in intent_file]
        with open(f'{base}/seq.in', encoding='utf-8') as queries_file:
            queries = [line.rstrip().split() for line in queries_file]
        with open(f'{base}/seq.out', encoding='utf-8') as slotses_file:
            slotses = [line.rstrip().split() for line in slotses_file]
        return intents, queries, slotses

    @staticmethod
    def _check_aligned(split_name, intents, queries, slotses):
        """Raise ValueError unless every query has one intent and one slot tag per word."""
        if not (len(intents) == len(queries) == len(slotses)):
            raise ValueError(
                f'split {split_name!r}: {len(intents)} labels, {len(queries)} queries, '
                f'{len(slotses)} slot lines'
            )
        for line_no, (query, slots) in enumerate(zip(queries, slotses), start=1):
            if len(query) != len(slots):
                raise ValueError(
                    f'split {split_name!r}, line {line_no}: {len(query)} words '
                    f'but {len(slots)} slot tags'
                )

    def _register(self, word):
        """Add ``word`` to the word vocabulary with the next free index (no-op if present)."""
        if word in self.word2idx:
            return
        assert(len(self.idx2word) == len(self.word2idx))
        idx = len(self.idx2word)
        self.idx2word[idx] = word
        self.word2idx[word] = idx

    def _pad(self, ids, start, end, fill):
        """Frame ``ids`` with ``start``/``end``, truncate to seq_len, right-pad with ``fill``."""
        framed = [start, *ids, end][:self.seq_len]
        # Explicit int64: NumPy's default integer dtype is platform dependent,
        # and torch.from_numpy keeps whatever dtype the array has.
        padded = np.full(self.seq_len, fill, dtype=np.int64)
        padded[:len(framed)] = framed
        return padded

    def pad_query(self, sequence):
        """Return ``[BOS] + sequence + [EOS]`` as word ids, truncated/PAD-padded to seq_len."""
        return self._pad(sequence, self.word2idx[BOS], self.word2idx[EOS], self.word2idx[PAD])

    def pad_slots(self, sequence):
        """Return slot ids aligned with ``pad_query``; BOS/EOS/padding positions hold -1."""
        return self._pad(sequence, -1, -1, -1)

    def __getitem__(self, i):
        """Return example ``i``.

        The tuple is ``(query, intent, slots, true_length, raw, vocabs)``:
        ``query``/``slots`` are int64 tensors of length seq_len, ``intent`` is a
        scalar tensor, ``true_length`` is the number of query words (BOS/EOS not
        counted) capped at seq_len, ``raw`` is ``(words, intent, slot_tags)`` and
        ``vocabs`` is ``(idx2word, idx2intent, idx2slot)``.

        NOTE(review): because BOS occupies position 0, a truncated query holds
        fewer than ``true_length`` words -- confirm this is what consumers expect.
        The raw word/tag lists also vary in length between examples; check that
        the DataLoader's collate function accepts that.
        """
        words, intent, tags = self.queries[i], self.intents[i], self.slots[i]
        query = torch.from_numpy(self.pad_query([self.word2idx[word] for word in words]))
        intent_id = torch.tensor(self.intent2idx[intent])
        slots = torch.from_numpy(self.pad_slots([self.slot2idx[tag] for tag in tags]))
        true_length = torch.tensor(min(len(words), self.seq_len))
        return (query, intent_id, slots, true_length,
                (words, intent, tags),
                (self.idx2word, self.idx2intent, self.idx2slot))

    def __len__(self):
        """Number of examples in this split (alignment was validated in ``__init__``)."""
        return len(self.queries)
def load(dataset, batch_size, seq_len, glove_dim=100):
    """Build the train/valid/test DataLoaders and GloVe embeddings for ``dataset``.

    Args:
        dataset: dataset directory name under ``datasets/``.
        batch_size: batch size used by all three loaders.
        seq_len: fixed (padded/truncated) sequence length passed to ``Corpus``.
        glove_dim: GloVe vector size passed to ``build_glove`` (default 100,
            the value previously hard-coded).

    Returns:
        ``(train_loader, valid_loader, test_loader, num_words, num_intents,
        num_slots, wordvecs)``; only the train loader shuffles.

    Raises:
        RuntimeError: if the three splits ended up with different vocabularies.
    """
    train_corpus = Corpus(dataset, 'train', seq_len)
    valid_corpus = Corpus(dataset, 'valid', seq_len)
    test_corpus = Corpus(dataset, 'test', seq_len)

    # Sanity check: the counts and the embeddings below come from the train corpus
    # alone, so every split must share exactly the same mappings, not just the
    # same sizes. Raised explicitly so the check survives `python -O`.
    for attr in ('word2idx', 'intent2idx', 'slot2idx'):
        reference = getattr(train_corpus, attr)
        if not (reference == getattr(valid_corpus, attr) == getattr(test_corpus, attr)):
            raise RuntimeError(
                f'{attr} differs between the train, valid and test splits of {dataset!r}'
            )

    num_words = len(train_corpus.word2idx)
    num_intents = len(train_corpus.intent2idx)
    num_slots = len(train_corpus.slot2idx)
    wordvecs = build_glove(train_corpus.word2idx, train_corpus.idx2word, dim=glove_dim)
    return (DataLoader(train_corpus, batch_size, shuffle=True),
            DataLoader(valid_corpus, batch_size, shuffle=False),
            DataLoader(test_corpus, batch_size),
            num_words, num_intents, num_slots, wordvecs)