forked from codertimo/BERT-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
build_dataset.py
46 lines (32 loc) · 1.54 KB
/
build_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from dataset.dataset import BERTDatasetCreator
from dataset import WordVocab
from multiprocessing import Pool
import argparse
import tqdm
# Command-line interface for the dataset builder.  The flag order below is
# preserved exactly so `--help` output matches the original script.
parser = argparse.ArgumentParser()
for short_flag, long_flag, options in [
    ("-v", "--vocab_path", dict(required=True, type=str)),
    ("-c", "--corpus_path", dict(required=True, type=str)),
    ("-e", "--encoding", dict(default="utf-8", type=str)),
    ("-o", "--output_path", dict(required=True, type=str)),
    ("-w", "--workers", dict(default=4, type=int)),
]:
    parser.add_argument(short_flag, long_flag, **options)
args = parser.parse_args()

# Load the pre-built vocabulary from disk and report its size.
print("Loading Word Vocab", args.vocab_path)
word_vocab = WordVocab.load_vocab(args.vocab_path)
print("VOCAB SIZE=", len(word_vocab))

# Dataset creator that yields masked/NSP samples per corpus line.
# NOTE(review): seq_len=None presumably means "no truncation" — confirm
# against BERTDatasetCreator's implementation.
builder = BERTDatasetCreator(corpus_path=args.corpus_path,
                             vocab=word_vocab, seq_len=None,
                             encoding=args.encoding)
def work(i):
    """Fetch sample *i* from the module-level ``builder`` and flatten it.

    The four token/label fields (lists of ints) are each serialized to a
    comma-separated string so the sample can be written as one TSV line.

    :param i: index of the sample within ``builder``
    :return: the sample dict with the four list fields replaced by strings
    """
    data = builder[i]
    # Original code shadowed the parameter `i` inside its comprehensions and
    # duplicated the join logic twice; one loop over the four keys is
    # equivalent and clearer.
    for key in ("t1_random", "t2_random", "t1_label", "t2_label"):
        data[key] = ",".join(str(token) for token in data[key])
    return data
# One tab-separated record per sample:
#   t1 tokens, t2 tokens, t1 labels, t2 labels, is_next flag.
output_form = "%s\t%s\t%s\t%s\t%d\n"
# `with` guarantees the file is closed even if work()/write() raises — the
# original opened and closed the handle manually, leaking it on error.
# buffering=1 keeps line buffering so partial progress reaches disk.
with open(args.output_path, 'w', encoding=args.encoding, buffering=1) as f:
    for i in tqdm.tqdm(range(len(builder)), total=len(builder), desc="Building Dataset"):
        d = work(i)
        output = output_form % (d["t1_random"], d["t2_random"],
                                d["t1_label"], d["t2_label"], d["is_next"])
        f.write(output)