# data_loading.py
import json
import random
from collections import Counter
from pathlib import Path
from typing import List, Optional, Dict
from datasets import load_dataset
from fire import Fire
from pydantic import BaseModel, Field
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer, BatchEncoding, AutoTokenizer
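
# Converts instruction-tuning data (Alpaca, GPT4All, ShareGPT) into a common
# source/target JSONL format and provides a PyTorch Dataset over it for
# seq2seq fine-tuning.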
class TokensLengthAnalyzer(BaseModel, arbitrary_types_allowed=True):
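    """Token-length statistics (min/max/median/95th percentile) for a list of
    texts, lazily loading the named HuggingFace tokenizer on first use."""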
name: str
    tokenizer: Optional[PreTrainedTokenizer] = None
def load(self):
if self.tokenizer is None:
self.tokenizer = AutoTokenizer.from_pretrained(
self.name, model_max_length=99999
)
def run(self, texts: List[str], limit: int = 0) -> Dict[str, float]:
if limit:
texts = texts[:limit]
self.load()
tokens = self.tokenizer(texts).input_ids
lengths = sorted(len(lst) for lst in tokens)
        info = dict(min=lengths[0], max=lengths[-1], median=lengths[len(lengths) // 2])
        # Clamp so the 95th-percentile index cannot run one past the end
        info["95_percentile"] = lengths[min(round(len(lengths) * 0.95), len(lengths) - 1)]
return info
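
# A single source/target training pair; TextToTextData serializes collections
# of these as one JSON object per line (JSONL).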
class TextToTextSample(BaseModel):
source: str
target: str
class TextToTextData(BaseModel):
samples: List[TextToTextSample]
@classmethod
def load(cls, path: str):
with open(path) as f:
all_lines = tqdm(f.readlines(), desc=path)
samples = [TextToTextSample(**json.loads(line)) for line in all_lines]
return cls(samples=samples)
def save(self, path: str):
Path(path).parent.mkdir(exist_ok=True, parents=True)
with open(path, "w") as f:
for sample in self.samples:
print(sample.json(), file=f)
def analyze(self, num: int = 10, tokenizer_name: str = "t5-base"):
random.seed(num)
for sample in random.sample(self.samples, k=num):
print(sample.json(indent=2))
token_checker = TokensLengthAnalyzer(name=tokenizer_name)
info = dict(
total_samples=len(self.samples),
source=str(token_checker.run([sample.source for sample in self.samples])),
target=str(token_checker.run([sample.target for sample in self.samples])),
)
print(json.dumps(info, indent=2))
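
# Alpaca records carry an instruction, an optional input (context), and the
# expected output; as_data folds instruction and input into a single source.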
class AlpacaSample(BaseModel):
instruction: str
input: str
output: str
class AlpacaData(BaseModel):
samples: List[AlpacaSample]
@classmethod
def load(cls, path: str):
with open(path) as f:
raw = json.load(f)
return cls(samples=[AlpacaSample(**r) for r in raw])
def save(self, path: str):
raw = [sample.dict() for sample in self.samples]
Path(path).parent.mkdir(exist_ok=True, parents=True)
with open(path, "w") as f:
json.dump(raw, f)
def as_data(self) -> TextToTextData:
self.analyze()
samples = []
for raw in self.samples:
source = raw.instruction.strip()
if raw.input.strip():
source = source + "\n" + raw.input
samples.append(TextToTextSample(source=source, target=raw.output))
return TextToTextData(samples=samples)
def analyze(self):
info = dict(
alpaca_samples=len(self.samples),
with_context=sum(sample.input.strip() != "" for sample in self.samples),
)
print(json.dumps(info, indent=2))
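
# Tokenizes samples on the fly: targets are right-truncated by the tokenizer,
# while sources are left-truncated manually (see tokenize below) so that the
# most recent context is kept.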
class TextToTextDataset(Dataset):
def __init__(
self,
path: str,
tokenizer: PreTrainedTokenizer,
max_source_length: int,
max_target_length: int,
):
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.tokenizer = tokenizer
self.data = TextToTextData.load(path)
def __len__(self) -> int:
return len(self.data.samples)
def tokenize(self, text: str, is_source: bool) -> BatchEncoding:
x = self.tokenizer(
text,
max_length=self.max_source_length if is_source else self.max_target_length,
padding="max_length",
truncation=not is_source,
return_tensors="pt",
        )
        # T5 truncates on the right by default, but we can safely truncate on
        # the left for the encoder input, since there is no special token at
        # the start of the sequence.
if is_source:
assert x.input_ids.ndim == 2
assert x.input_ids.shape == x.attention_mask.shape
length = x.input_ids.shape[1]
start = max(length - self.max_source_length, 0)
x.input_ids = x.input_ids[:, start:]
x.attention_mask = x.attention_mask[:, start:]
assert x.input_ids.shape[1] == self.max_source_length
return x
def __getitem__(self, i: int) -> dict:
x = self.tokenize(self.data.samples[i].source, is_source=True)
y = self.tokenize(self.data.samples[i].target, is_source=False)
return {
"source_ids": x.input_ids.squeeze(),
"source_mask": x.attention_mask.squeeze(),
"target_ids": y.input_ids.squeeze(),
"target_mask": y.attention_mask.squeeze(),
}
def to_human_readable(self, raw: dict) -> dict:
source = self.tokenizer.decode(raw["source_ids"])
target = self.tokenizer.decode(raw["target_ids"])
return dict(source=source, target=target)
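
# Entry point: convert the local Alpaca JSON file into the common JSONL format.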
def preprocess_alpaca(
path_in: str = "data/alpaca.json", path_out: str = "data/train.json"
):
data = AlpacaData.load(path_in).as_data()
data.analyze()
data.save(path_out)
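
# GPT4All responses can contain stray HTML markup; strip the most common tags.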
def clean_gpt4all_text(text: str) -> str:
text = text.replace("<p>", "")
text = text.replace("</p>", "")
text = text.replace("<pre><code>", "")
text = text.replace("</code></pre>", "")
return text
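
# Entry point: stream the GPT4All dataset from the HuggingFace Hub, clean it,
# and write source/target pairs straight to JSONL; only a random subset is
# passed through analyze() to keep tokenizer statistics fast.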
def preprocess_gpt4all(
    path_in: str = "nomic-ai/gpt4all_prompt_generations",
    path_out: str = "data/train_gpt4all.json",
):
data = []
for raw in tqdm(load_dataset(path_in, split="train"), desc=path_in):
prompt = clean_gpt4all_text(raw["prompt"])
response = clean_gpt4all_text(raw["response"])
data.append(dict(source=prompt, target=response))
    random.seed(0)
    # Analyze a random subset so tokenizer statistics stay fast on large dumps
    TextToTextData(
        samples=[
            TextToTextSample(**raw)
            for raw in random.sample(data, k=min(1000, len(data)))
        ]
    ).analyze()
with open(path_out, "w") as f:
for raw in tqdm(data, desc=path_out):
print(json.dumps(raw), file=f)
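
# ShareGPT conversations use "from"/"value" keys; "from" is a reserved word in
# Python, hence the field alias below.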
class ShareGPTConversation(BaseModel):
speaker: str = Field(alias="from")
value: str
class ShareGPTSample(BaseModel):
id: str
conversations: List[ShareGPTConversation]
def contains_texts(self, texts: List[str], do_lower: bool = True) -> bool:
for c in self.conversations:
for t in texts:
if do_lower and t.lower() in c.value.lower():
return True
elif not do_lower and t in c.value:
return True
return False
    def has_empty_text(self) -> bool:
        # True if any conversation turn contains no alphanumeric characters
        for c in self.conversations:
            if not any(char.isalnum() for char in c.value):
                return True
        return False
class ShareGPTData(BaseModel):
samples: List[ShareGPTSample]
@classmethod
def load(cls, path: str):
with open(path) as f:
samples = [ShareGPTSample(**raw) for raw in json.load(f)]
return cls(samples=samples)
def analyze(self):
speakers = [conv.speaker for s in self.samples for conv in s.conversations]
info = dict(samples=len(self.samples), speakers=Counter(speakers))
print(json.dumps(info, indent=2))
    def clean(self, phrases: Optional[List[str]] = None):
        """
        ~100k ShareGPT conversations were narrowed down to 48k by:
        - removing non-English conversations
        - removing excessive unicode (usually indicative of Chinese or Korean text)
        - removing excessive repeated characters
        - removing various instances of "AI moralizing": conversations containing
          the phrases below were removed
        """
if phrases is None:
phrases = [
"prioritize human safety",
"ethical principles",
"harmful to human beings",
"September 2021",
"as a language model",
"ethical guidelines",
"as an AI language model",
"my guidelines",
"As an AI",
"prioritize user safety",
"adhere to ethical guidelines",
"harmful consequences",
"potentially harmful",
"dangerous activities",
"promote safety",
"well-being of all users",
"responsible information sharing",
"jeopardize the safety",
"illegal actions or intentions",
"undermine the stability",
"promote the well-being",
"illegal activities or actions",
"adherence to the law",
"potentially be harmful",
"illegal substances or activities",
"committed to promoting",
"safe information",
"lawful information",
"cannot provide guidance",
"cannot provide information",
"unable to offer assistance",
"cannot engage in discussions",
"programming prohibits",
"follow ethical guidelines",
"ensure the safety",
"involves an illegal subject",
"prioritize safety",
"illegal subject",
"prioritize user well-being",
"cannot support or promote",
"activities that could harm",
"pose a risk to others",
"against my programming",
"activities that could undermine",
"potentially dangerous",
"not within the scope",
"designed to prioritize safety",
"not able to provide",
"maintain user safety",
"adhere to safety guidelines",
"dangerous or harmful",
"cannot provide any information",
"focus on promoting safety",
]
        self.samples = [
            s for s in tqdm(self.samples, desc="clean empty") if not s.has_empty_text()
        ]
self.analyze()
self.samples = [
s
for s in tqdm(self.samples, desc="clean phrases")
if not s.contains_texts(phrases, do_lower=True)
]
self.analyze()
def as_data(self) -> TextToTextData:
samples = []
for s in tqdm(self.samples, desc="as_data"):
            for i, conv in enumerate(s.conversations):
                # When i == 0, prev is conv itself, so the speaker test below
                # can never pass and no pair is emitted for the first turn
                prev = s.conversations[max(i - 1, 0)]
                if conv.speaker == "gpt" and prev.speaker == "human":
                    target = conv.value
                    # All preceding turns, joined, form the source context
                    source = "\n\n".join(c.value for c in s.conversations[:i])
samples.append(TextToTextSample(source=source, target=target))
return TextToTextData(samples=samples)
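
# Entry point: load ShareGPT conversations, filter out samples with empty or
# "AI moralizing" turns, and convert the rest to source/target JSONL.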
def preprocess_sharegpt(
    path_in: str = "data/sharegpt.json",
    path_out: str = "data/train_sharegpt.json",
):
# See: https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered
raw = ShareGPTData.load(path_in)
raw.analyze()
raw.clean()
data = raw.as_data()
data.analyze()
data.save(path_out)
if __name__ == "__main__":
Fire()
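
# Example invocations, assuming python-fire's default dispatch (the first
# CLI argument selects a function defined in this module):
#
#   python data_loading.py preprocess_alpaca --path_in data/alpaca.json
#   python data_loading.py preprocess_gpt4all
#   python data_loading.py preprocess_sharegpt --path_in data/sharegpt.json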