-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
24 lines (18 loc) · 721 Bytes
/
dataset.py
File metadata and controls
24 lines (18 loc) · 721 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from transformers import GPT2TokenizerFast
import numpy as np, os
# One-time preprocessing: tokenize the Shakespeare corpus with the GPT-2
# BPE tokenizer and cache the token ids to disk so later runs skip this step.
if not os.path.exists("tokens.npy"):
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    # Context manager guarantees the file handle is closed; explicit
    # encoding avoids platform-dependent default decoding of the corpus.
    with open("shakespeare.txt", "r", encoding="utf-8") as f:
        text = f.read()
    tokens = np.array(tokenizer.encode(text))
    np.save("tokens.npy", tokens)
class TokenDataset:
    """Sliding-window next-token dataset over a pre-tokenized corpus.

    Each item is a ``(x, y)`` pair where ``x`` is a window of
    ``max_tokens`` consecutive token ids starting at ``index`` and ``y``
    is the single token id that immediately follows the window (the
    next-token prediction target).
    """

    def __init__(self, max_tokens=1024, path="tokens.npy"):
        """Load the cached token array.

        Args:
            max_tokens: context window length per item.
            path: location of the saved ``np.save`` token array
                (defaults to ``"tokens.npy"`` for backward compatibility).
        """
        self.max_tokens = max_tokens
        self.tokens = np.load(path)

    def __len__(self):
        # The last valid start index is len(tokens) - max_tokens - 1, so
        # the target index ``index + max_tokens`` never runs off the end.
        return len(self.tokens) - self.max_tokens

    def __getitem__(self, index):
        x = self.tokens[index:index + self.max_tokens]
        # Target is the token immediately after the context window.
        y = self.tokens[index + self.max_tokens]
        return x, y