-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocess.py
More file actions
79 lines (67 loc) · 2.84 KB
/
preprocess.py
File metadata and controls
79 lines (67 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from torchtext.vocab import GloVe
from torchtext import data
from tokenizer import post_ptbtokenizer
class PreprocessData:
    """Load SQuAD-style CSVs via torchtext and expose batched iterators.

    After construction the instance carries the torchtext fields
    (``RAW``, ``WORDS``, ``CHAR``, ``INDEX``), the ``train``/``dev``
    datasets, and ``train_iter``/``dev_iter`` bucket iterators.
    """

    def __init__(self,
                 data_path,
                 glove_size,
                 batch_size,
                 train_file='train.csv',
                 dev_file='dev.csv'):
        # Prefer the first GPU when available, otherwise run on CPU.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Field definitions: ids pass through untouched; word-level text is
        # PTB-tokenised and lowercased (lengths kept for packing); the
        # character view nests a per-character field inside PTB tokens; the
        # answer-span indices are plain non-sequential integers.
        self.RAW = data.RawField(is_target=False)
        self.WORDS = data.Field(batch_first=True,
                                tokenize=post_ptbtokenizer,
                                lower=True,
                                include_lengths=True)
        char_inner = data.Field(batch_first=True, tokenize=list, lower=True)
        self.CHAR = data.NestedField(char_inner, tokenize=post_ptbtokenizer)
        self.INDEX = data.Field(sequential=False,
                                unk_token=None,
                                use_vocab=False)

        # Map each CSV column to its (attribute, field) pairs; every
        # *_ptb_tok column feeds both a word-level and a char-level view.
        fields = {'id': ('id', self.RAW)}
        for column in ('context', 'question', 'answer'):
            fields[column + '_ptb_tok'] = [(column + '_words', self.WORDS),
                                           (column + '_char', self.CHAR)]
        fields['start_idx'] = ('start_idx', self.INDEX)
        fields['end_idx'] = ('end_idx', self.INDEX)

        print('Loading CSV Data Into Torch Tabular Dataset')
        self.train, self.dev = data.TabularDataset.splits(path=data_path,
                                                          train=train_file,
                                                          validation=dev_file,
                                                          format='csv',
                                                          fields=fields)

        print('Building Vocabulary')
        self.CHAR.build_vocab(self.train, self.dev)
        # Word vocabulary is seeded with pre-trained GloVe embeddings
        # (6B-token release; dimensionality chosen by the caller).
        self.WORDS.build_vocab(self.train, self.dev,
                               vectors=GloVe(name='6B', dim=glove_size))

        print('Creating Iterators')
        self.train_iter = PreprocessData.create_train_iterator(self.train, device, batch_size)
        self.dev_iter = PreprocessData.create_dev_iterator(self.dev, device, batch_size)

    @staticmethod
    def create_train_iterator(train, device, batch_size):
        """Return a shuffled BucketIterator over the training dataset,
        batching examples of similar context length together."""
        return data.BucketIterator(train,
                                   batch_size=batch_size,
                                   device=device,
                                   repeat=False,
                                   shuffle=True,
                                   sort_within_batch=True,
                                   sort_key=lambda ex: len(ex.context_words))

    @staticmethod
    def create_dev_iterator(dev, device, batch_size):
        """Return an (unshuffled) BucketIterator over the dev dataset,
        sorted within each batch by context length."""
        return data.BucketIterator(dev,
                                   batch_size=batch_size,
                                   device=device,
                                   repeat=False,
                                   sort_within_batch=True,
                                   sort_key=lambda ex: len(ex.context_words))