-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_text_categorizer.py
87 lines (72 loc) · 2.71 KB
/
train_text_categorizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#! /usr/bin/env python
#
# Trains text categorizer using Scikit-Learn. See
# http://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html
#
# Notes:
# - This is a simple wrapper around the code in text_categorizer.py, created
# to clarify how training is done.
#
# TODO:
# - Rename to something like apply_text_categorizer.py, as the script now
# supports usage with a pre-trained model.
#
#
"""Trains text categorization"""
import sys
import debug
import system
from text_categorizer import TextCategorizer
# Boolean option read from the SHOW_REPORT environment variable (default False);
# passed as the report argument when evaluating over the testing file in main().
SHOW_REPORT = system.getenv_bool("SHOW_REPORT", False)
def usage():
    """Show command-line usage on standard error.

    The script is shown by its base name only (i.e., without the directory),
    and "n/a" is used when the script path is not available.
    """
    import os    # local import keeps this function self-contained

    # Note: __file__ can be undefined (e.g., interactive or embedded use), so it
    # is looked up defensively rather than referenced directly.
    script = os.path.basename(globals().get("__file__") or "") or "n/a"
    system.print_stderr("Usage: {scr} training-file model-file [testing]".format(scr=script))
    system.print_stderr("")
    system.print_stderr("Notes:")
    system.print_stderr("- Use - to indicate the file is not needed (e.g., existing training model).")
    system.print_stderr("- You need to supply either training file or model file.")
    system.print_stderr("- The testing file is optional when training.")
def main(args=None):
    """Entry point for script.

    Arguments:
    - args: optional argument list using the sys.argv layout (i.e., script name,
      training file, model file, and optional testing file); defaults to sys.argv.

    Notes:
    - When only the script name is present, additional arguments are taken from
      the ARGS environment variable (whitespace separated), if set.
    - A filename of - means the corresponding file is not used.
    - Usage is shown if too few arguments are given or if nothing was done.
    - Always returns None.
    """
    debug.trace_fmtd(4, "main(): args={a}", a=args)

    # Check command line arguments
    # Note: A copy is used so that neither sys.argv nor the caller's list is
    # modified in place when the ARGS environment values are appended below.
    args = list(sys.argv if args is None else args)
    debug.trace_fmtd(4, "len(args)={l}", l=len(args))
    if len(args) == 1:
        env_args = system.getenv_text("ARGS")
        if env_args:
            args += env_args.split()
    debug.trace_fmtd(4, "len(args)={l}; args={a}", l=len(args), a=args)
    if len(args) <= 2:
        usage()
        return
    training_filename = args[1]
    model_filename = args[2]
    testing_filename = args[3] if (len(args) > 3) else None

    # Train text categorizer and save model to specified file
    # (or load the model from the file when no training is done)
    text_cat = TextCategorizer()
    new_model = False
    accuracy = None
    if training_filename and (training_filename != "-"):
        text_cat.train(training_filename)
        new_model = True
    if model_filename and (model_filename != "-"):
        if new_model:
            text_cat.save(model_filename)
        else:
            text_cat.load(model_filename)

    # Evaluate over the optional testing file
    if testing_filename and (testing_filename != "-"):
        accuracy = text_cat.test(testing_filename, report=SHOW_REPORT)
        print("Accuracy over {f}: {acc}".format(acc=accuracy, f=testing_filename))

    # Show usage if nothing actually done (e.g., due to too many -'s for filenames)
    # Note: accuracy is compared against None because a result of 0 still
    # means that testing was performed.
    if (not new_model) and (accuracy is None):
        usage()
    return
#------------------------------------------------------------------------
# Invoke the command-line entry point only when run as a script (not on import)
if __name__ == "__main__":
    main()