-
Notifications
You must be signed in to change notification settings - Fork 0
/
csv_to_arff.py
80 lines (69 loc) · 2.55 KB
/
csv_to_arff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Convert csv file to arff file.
"""
import pandas as pd
import sys
import os
def print_help(prog=None):
    """
    Print the command-line usage message.

    Parameters
    ----------
    prog: str, optional
        Script name shown in the usage line. Defaults to the basename of
        ``sys.argv[0]``, falling back to ``csv_to_arff.py`` when that is empty.

    Returns
    -------
    None
    """
    if prog is None:
        # The old message hard-coded a different script name
        # (pandas_to_arff.py); derive it from how the script was invoked.
        argv0 = sys.argv[0] if sys.argv else ''
        prog = os.path.basename(argv0) or 'csv_to_arff.py'
    print(f'\nUsage\n======\npython {prog} input-file output-file\n')
# Characters that force an ARFF token to be quoted: whitespace, the field
# separator, nominal-list braces, the comment marker and quote/escape chars.
_ARFF_SPECIAL_CHARS = frozenset(' \t\r\n,{}%\'"\\')


def _quote(text):
    """Return ``text`` as a single-quoted ARFF token (backslashes and single quotes escaped)."""
    escaped = str(text).replace('\\', '\\\\').replace("'", "\\'")
    return "'" + escaped + "'"


def _format_value(value):
    """Render one cell for the @data section; missing values become '?'."""
    if pd.isna(value):
        return '?'
    text = str(value)
    # Quote only when required so ordinary values are written unchanged.
    if text in ('', '?') or any(ch in _ARFF_SPECIAL_CHARS for ch in text):
        return _quote(text)
    return text


def _checked_stem(path, extension):
    """
    Return the file name of ``path`` without its extension.

    Raises NameError when the last extension of ``path`` is not
    ``extension`` (compared case-insensitively).
    """
    stem, ext = os.path.splitext(os.path.basename(path))
    if ext.lower() != extension:
        raise NameError('The extension name of file is wrong!')
    return stem


def _attribute_line(column, series):
    """
    Return the ``@attribute`` declaration for one DataFrame column.

    Raises ValueError for dtypes that have no ARFF mapping here.
    """
    name = _quote(column)
    types = pd.api.types
    if types.is_integer_dtype(series):
        # pd.unique keeps order of appearance, so compare as a set: a column
        # starting with 1 is still the binary {0, 1} case.
        if set(series.unique()) == {0, 1}:
            return f"@attribute {name} {{'0', '1'}}\n"
        return f'@attribute {name} real\n'
    if types.is_float_dtype(series):
        return f'@attribute {name} real\n'
    if types.is_object_dtype(series) or types.is_string_dtype(series):
        # Missing cells are emitted as '?' in @data, so they are not
        # declared as nominal values (and NaN is not a string).
        values = ','.join(_quote(v) for v in series.dropna().unique())
        return f'@attribute {name} {{{values}}}\n'
    raise ValueError(f'Unknown data type: {series.dtype} at column {column}')


def main(argv):
    """
    The main entrance of the program: convert a CSV file to an ARFF file.

    Integer columns holding exactly the values 0 and 1 become the nominal
    attribute {'0', '1'}; other integer and float columns become ``real``;
    text columns become nominal attributes listing their distinct values.
    Missing cells are written as ``?``.

    Parameters
    ----------
    argv: list of str
        The program arguments. The length of argv must be 3: the program
        name, the input ``.csv`` path and the output ``.arff`` path.
        Otherwise the usage message is printed and nothing is converted.

    Returns
    -------
    None

    Raises
    ------
    NameError
        If the input does not end in ``.csv`` or the output in ``.arff``.
    ValueError
        If a column has a dtype with no ARFF mapping (pandas parse errors
        are also ValueError subclasses).
    """
    if len(argv) != 3:
        print_help()
        return
    input_file, output_file = argv[1], argv[2]
    relation_name = _checked_stem(input_file, '.csv')
    _checked_stem(output_file, '.arff')

    frame = pd.read_csv(input_file)
    # Build the whole header first so an unsupported dtype aborts before the
    # output file is created or truncated.
    header = [f'@relation {_quote(relation_name)}\n']
    header.extend(_attribute_line(column, series) for column, series in frame.items())
    header.append('@data\n')

    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(header)
        # itertuples keeps each column's own type; iterrows builds a Series
        # per row and upcasts ints to floats in all-numeric frames ("1.0").
        for row in frame.itertuples(index=False, name=None):
            f.write(','.join(_format_value(value) for value in row) + '\n')
if __name__ == '__main__':
    try:
        main(sys.argv)
    except (ValueError, NameError, OSError) as e:
        # Report expected user-facing errors (bad extension, unsupported
        # column type, unparsable or unreadable file) without a traceback,
        # but signal failure to the calling shell with a non-zero status.
        print(e, file=sys.stderr)
        sys.exit(1)