This repository has been archived by the owner on May 11, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
classifier.py
98 lines (83 loc) · 3.2 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
from sklearn import svm
import helpers
import heuristics
from config import Credentials
import psycopg2
# Open the module-wide Postgres connection from the settings in
# config.Credentials; `cursor` is shared by the functions below.
connection = psycopg2.connect(**{
    'dbname': Credentials.PG_DATABASE,
    'user': Credentials.PG_USERNAME,
    'password': Credentials.PG_PASSWORD,
    'host': Credentials.PG_HOST,
    'port': Credentials.PG_PORT,
})
cursor = connection.cursor()
def create():
    """Fit and return an SVM classifier from the labelled areas in Postgres.

    Runs one query on the module-level ``cursor`` that joins ``areas`` with
    ``area_labels`` and ``labels``. In each returned row the first three
    columns (area_id, page_no, label) are bookkeeping; everything after them
    is used as the feature vector, and the label name is the training target.

    Returns:
        The fitted ``sklearn.svm.SVC`` instance (probability estimates
        enabled).
    """
    query = """
SELECT
a.area_id,
page_no,
labels.name AS label,
has_words::int,
line_intersect::int,
small_text::int,
small_leading::int,
is_line::int,
is_top_or_bottom::int,
mostly_blank::int,
very_separated_words::int,
little_word_coverage::int,
normal_word_separation::int,
normal_word_coverage::int,
best_caption::int,
good_caption::int,
ok_caption::int,
overlap::int,
offset_words::int,
proportion_alpha::float,
area::float / (select max(area) from areas) as area,
n_gaps::float / (select max(n_gaps) from areas) as n_gaps,
n_lines::float / (select max(n_lines) from areas) as n_lines
FROM areas a
JOIN area_labels al ON al.area_id = a.area_id
JOIN labels ON labels.label_id = al.label_id
"""
    cursor.execute(query)
    rows = cursor.fetchall()

    # Split every row into its target (column 2, the label name) and its
    # features (column 3 onwards); area_id and page_no are not used.
    features = []
    targets = []
    for row in rows:
        targets.append(row[2])
        features.append(list(row[3:]))
    targets = np.array(targets)

    # gamma: reach of a single training example (low = far, high = close).
    # C: how much freedom the fit is given (low = less, high = more).
    model = svm.SVC(
        kernel='rbf',
        gamma=1,
        C=100,
        probability=True,
        cache_size=500,
    )
    model.fit(features, targets)
    return model
def _caption_reaches_graphic(caption, graphics, bodies):
    """Return True if `caption` can be joined to a graphic without hitting body.

    For each candidate graphic the caption is expanded with
    ``helpers.enlarge_extract(caption, graphic)``; the caption is accepted as
    soon as one expansion intersects none of the `bodies` areas (checked with
    ``helpers.rectangles_intersect``). With no graphics the result is False.
    """
    for graphic in graphics:
        expanded = helpers.enlarge_extract(caption, graphic)
        if not any(helpers.rectangles_intersect(expanded, body) for body in bodies):
            return True
    return False


def classify(pages, doc_stats, min_probability=0.6):
    """Assign a ``'label'`` to every area of every page, in place.

    A classifier is built with ``create()`` and each area is labelled in two
    passes per page:

    1. The feature list from ``heuristics.classify_list`` is fed to the
       classifier. The predicted label is kept only if the highest
       probability among the classes other than ``'other'`` reaches
       `min_probability`; otherwise the area is labelled ``'unknown'``.
    2. Every ``'graphic caption'`` is re-checked: if it cannot be expanded
       towards any ``'graphic'`` area on the page without running into a
       ``'body'`` area, it is relabelled ``'unknown'``.

    Args:
        pages: iterable of page dicts; each must have an ``'areas'`` list of
            area dicts, which receive a ``'label'`` key.
        doc_stats: document statistics, passed through unchanged to
            ``heuristics.classify_list``.
        min_probability: confidence threshold for accepting a prediction
            (defaults to the previously hard-coded 0.6).

    Returns:
        The same `pages` object, with labels written onto its areas.
    """
    clf = create()
    # predict_proba columns are ordered like clf.classes_.
    class_names = list(clf.classes_)

    for page in pages:
        areas = page['areas']

        for area in areas:
            features = heuristics.classify_list(area, doc_stats, areas)
            estimated_label = clf.predict([features])[0]
            probabilities = clf.predict_proba([features])[0]

            candidates = [
                probability
                for name, probability in zip(class_names, probabilities)
                if name != 'other'
            ]
            # Guard against max() on an empty sequence: with no class besides
            # 'other' there is nothing to be confident about.
            best_p = max(candidates) if candidates else 0.0
            if best_p < min_probability:
                estimated_label = 'unknown'

            area['label'] = estimated_label

        # Second pass: a caption that can't be expanded to a graphic without
        # running into body text is not a caption. The previous check accepted
        # a caption as soon as *one* body area did not intersect (and rejected
        # every caption on pages without body areas); it now requires that
        # *no* body area intersects the expanded region.
        # Relabelling captions never changes these two lists, so build once.
        graphics = [a for a in areas if a['label'] == 'graphic']
        bodies = [a for a in areas if a['label'] == 'body']
        for area in areas:
            if area['label'] != 'graphic caption':
                continue
            if not _caption_reaches_graphic(area, graphics, bodies):
                area['label'] = 'unknown'

    return pages