-
Notifications
You must be signed in to change notification settings - Fork 2
/
sample_multilingual_corpus.py
99 lines (84 loc) · 2.69 KB
/
sample_multilingual_corpus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import json
import glob
import re
import os
import math
import numpy as np
import random
def _build_parser():
    """Build the command-line parser for the corpus sampling script."""
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--lang_prob_path",
        default=None,
        type=str,
        required=True,
        help="path to lang prob file",
    )
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help="input_dir",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        required=True,
        help="output_path.",
    )
    # The option carries a default, so it must not also be required;
    # otherwise the default could never take effect.
    parser.add_argument(
        "--n_sample",
        default=20000000,
        type=int,
        required=False,
        help="lines to sample.",
    )
    parser.add_argument(
        "--beta",
        default=0.7,
        type=float,
        required=False,
        help="beta",
    )
    parser.add_argument(
        "--rescale",
        action="store_true"
    )
    return parser


def main(argv=None):
    """Sample lines from per-language files according to language weights.

    The first line of ``--lang_prob_path`` is parsed as a JSON object mapping
    a language code to a weight. Weights are normalized to sum to 1 and, when
    ``--rescale`` is given and ``--beta`` is not 1, each probability is raised
    to the power ``beta`` and normalized again. Each language then receives a
    quota of ``int(probability * n_sample)`` lines.

    Every entry of ``--input_dir`` whose name up to the first ``.`` is a known
    language code is read, shuffled with a fixed seed, and its first *quota*
    lines are written (stripped, newline-terminated) to ``--output_path``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If an input file holds fewer lines than its language quota.
    """
    args = _build_parser().parse_args(argv)

    # Only the first line of the file is used, as a single JSON object.
    with open(args.lang_prob_path, "r", encoding="utf-8") as fin:
        lang_prob_dict = json.loads(fin.readline())

    z = sum(lang_prob_dict.values())
    lang_prob_dict = {lang: weight / z for lang, weight in lang_prob_dict.items()}
    print(sum([lang_prob_dict[lang] for lang in lang_prob_dict]))
    print(lang_prob_dict)

    if args.rescale and args.beta != 1:
        print("renorm lang_prob with beta = {}.".format(args.beta))
        powered = {
            lang: math.pow(prob, args.beta)
            for lang, prob in lang_prob_dict.items()
        }
        z = sum(powered.values())
        lang_prob_dict = {lang: value / z for lang, value in powered.items()}

    # int() truncates, so the quotas may add up to slightly less than n_sample.
    n_sample = {
        lang: int(prob * args.n_sample)
        for lang, prob in lang_prob_dict.items()
    }

    random.seed(1)
    with open(args.output_path, "w", encoding="utf-8") as fout:
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # seeded shuffles (and therefore the output) reproducible.
        for file_name in sorted(os.listdir(args.input_dir)):
            input_path = os.path.join(args.input_dir, file_name)
            if not os.path.exists(input_path):
                # Reachable for e.g. a dangling symlink inside input_dir.
                print("{} does not exist.".format(input_path))
                continue
            print("processing {}.".format(file_name))

            lang = file_name.split(".")[0]
            if lang not in n_sample:
                print("skipping language {}.".format(lang))
                continue
            quota = n_sample[lang]
            print(lang, quota)

            with open(input_path, "r", encoding="utf-8") as fin:
                lines = fin.readlines()
            # Explicit check instead of assert: asserts vanish under -O, which
            # would silently produce an under-sampled corpus.
            if len(lines) < quota:
                raise ValueError(
                    "{} has {} lines but {} are required for language {}.".format(
                        input_path, len(lines), quota, lang
                    )
                )
            random.shuffle(lines)
            fout.writelines(line.strip() + "\n" for line in lines[:quota])


if __name__ == "__main__":
    main()