dataset.py
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
import config
import tiktoken
import pickle
import os


class WMT16(Dataset):
    """English-Turkish translation pairs from WMT16, tokenized with tiktoken."""

    def __init__(self,
                 from_disk=True,
                 dataset=config.dataset,
                 subset=config.subset,
                 split='train',
                 tokenizer=config.tokenizer,
                 cache_file='tokenized_dataset.pkl'):
        super().__init__()
        if from_disk:
            # Reuse the cached tokenized dataset instead of re-tokenizing on every run.
            if os.path.exists(cache_file):
                with open(cache_file, 'rb') as f:
                    self.dataset = pickle.load(f)
            else:
                raise ValueError(f"Tokenized dataset not found at {cache_file}. "
                                 f"Tokenize the dataset first by setting from_disk=False.")
        else:
            # Tokenize the raw dataset once and cache the result to disk.
            self.tokenizer = tiktoken.get_encoding(tokenizer)
            self.dataset = load_dataset(dataset, subset, split=split)
            # Each item becomes a pair of token-id lists: [English, Turkish].
            self.dataset = [[self.tokenizer.encode(sentence['translation']['en']),
                             self.tokenizer.encode(sentence['translation']['tr'])]
                            for sentence in self.dataset]
            with open(cache_file, 'wb') as f:
                pickle.dump(self.dataset, f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # TODO: implement per-batch padding and attention masks (e.g. in a collate_fn).
        en = torch.LongTensor(self.dataset[index][0]).to(config.device)
        tr = torch.LongTensor(self.dataset[index][1]).to(config.device)
        return en, tr

    def get_indices_sorted_by_length(self):
        # Return dataset indices sorted by English sequence length,
        # useful for length-based bucketing to reduce padding per batch.
        indices = list(range(len(self)))
        indices.sort(key=lambda x: len(self[x][0]))
        return indices

# [7553, 8143, 25438, 129268, 132928, 144660, 149875, 165566, 165591, 23236]
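

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one way to batch this dataset
# with per-batch padding and attention masks, as hinted at by the TODO in
# __getitem__, and to use get_indices_sorted_by_length() for length-bucketed
# batches. The pad id (0) and the batch size are assumptions, not values
# taken from the project's config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.nn.utils.rnn import pad_sequence
    from torch.utils.data import DataLoader

    PAD_ID = 0  # assumption: 0 is reserved as the padding token id

    def collate_fn(batch):
        # batch is a list of (en, tr) LongTensor pairs of varying lengths.
        en_seqs, tr_seqs = zip(*batch)
        en = pad_sequence(list(en_seqs), batch_first=True, padding_value=PAD_ID)
        tr = pad_sequence(list(tr_seqs), batch_first=True, padding_value=PAD_ID)
        en_mask = (en != PAD_ID).long()  # 1 for real tokens, 0 for padding
        tr_mask = (tr != PAD_ID).long()
        return en, tr, en_mask, tr_mask

    # First run: build the cache with from_disk=False; later runs load it.
    ds = WMT16(from_disk=True)

    # Group indices of similar length into batches to minimise padding.
    sorted_idx = ds.get_indices_sorted_by_length()
    batch_size = 32  # assumption
    batches = [sorted_idx[i:i + batch_size]
               for i in range(0, len(sorted_idx), batch_size)]

    loader = DataLoader(ds, batch_sampler=batches, collate_fn=collate_fn)
    en, tr, en_mask, tr_mask = next(iter(loader))
    print(en.shape, tr.shape)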