forked from RikVN/AMR
-
Notifications
You must be signed in to change notification settings - Fork 1
/
amr_utils.py
498 lines (442 loc) · 20 KB
/
amr_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
#!/usr/bin/env python
# -*- coding: utf8 -*-
'''General utils and AMR specific utils'''
from collections import defaultdict
import json
import os
import codecs
def get_default_amr():
    '''Return the fallback AMR string: "the boy wants to go" in one-line AMR notation'''
    return '(w / want-01 :ARG0 (b / boy) :ARG1 (g / go-01 :ARG0 b))'
def write_to_file(lst, file_new, split=True):
    '''Write lst to the UTF-8 encoded file file_new (the file is overwritten).

    lst      -- with split=True an iterable of strings, each written stripped and
                followed by a newline; with split=False a single string that is
                written as-is
    file_new -- path of the file to create
    split    -- selects between the two modes described above'''
    # the with-statement closes the file, no explicit close() is needed
    with codecs.open(file_new, 'w', 'utf-8') as out_f:
        if split:
            for line in lst:
                out_f.write(line.strip() + '\n')
        else:
            out_f.write(lst)
def get_files_by_ext(direc, ext):
    '''Walk direc recursively and return the paths of all files whose name ends with ext'''
    return [os.path.join(parent, name)
            for parent, _subdirs, names in os.walk(direc)
            for name in names
            if name.endswith(ext)]
def tokenize_line(line):
    '''Put spaces around every parenthesis and collapse each whitespace run to one space'''
    spaced = line
    for paren in '()':
        spaced = spaced.replace(paren, ' ' + paren + ' ')
    tokens = spaced.split()
    return " ".join(tokens)
def reverse_tokenize(new_line):
    '''Undo tokenize_line-style spacing: drop spaces after "(" and before ")"'''
    restored = new_line
    while True:
        # one pass removes a single space per parenthesis; repeat until nothing changes
        squeezed = restored.replace(' )', ')').replace('( ', '(')
        if squeezed == restored:
            return restored
        restored = squeezed
def load_dict(d):
    '''Function that loads a json dictionary

    d -- path of the json file, e.g. the reference dict (based on training data)
         that is used to settle disputes based on frequency

    Returns the object decoded by json.load'''
    # the with-statement closes the file, no explicit close() is needed
    with open(d, 'r') as in_f:
        return json.load(in_f)
def add_to_dict(d, key, base):
    '''Accumulate base under key in d: start the entry with base when key is new,
    otherwise add base to the existing value. The (mutated) dictionary is returned'''
    if key not in d:
        d[key] = base
    else:
        d[key] += base
    return d
def countparens(text):
    ''' check that parentheses are properly nested: True when every ")" closes an
    earlier "(" and no "(" is left open, False otherwise '''
    depth = 0
    for ch in text:
        if ch == ")":
            # a closing paren with nothing open can never be repaired later
            if depth == 0:
                return False
            depth -= 1
        elif ch == "(":
            depth += 1
    return depth == 0
def valid_amr(amrtext):
    '''Return True when amrtext has balanced parentheses and can be parsed by the
    smatch AMR parser (amr.AMR.parse_AMR_line), False otherwise.

    Any exception raised while parsing is reported on stdout as "Error: ..." and
    treated as an invalid AMR.'''
    import amr
    if not countparens(amrtext): ## wrong parentheses, return false
        return False
    try:
        theamr = amr.AMR.parse_AMR_line(amrtext)
    except Exception as e:
        print('Error: ' + str(e))
        return False
    # the parser returns None when it could not build an AMR out of the text
    return theamr is not None
class AMR(object):
    """
    AMR is a rooted, labeled graph to represent semantics.
    This class has the following members:
    nodes: list of node in the graph. Its ith element is the name of the ith node. For example, a node name
           could be "a1", "b", "g2", etc.
    node_values: list of node labels (values) of the graph. Its ith element is the value associated with node i in
                 nodes list. In AMR, such value is usually a semantic concept (e.g. "boy", "want-01")
    root: root node name
    relations: list of edges connecting two nodes in the graph. Its ith element is a dict for node i that maps
               the name of the other node to the relation name. In AMR, such link denotes the relation between
               two semantic concepts. For example, "arg0" means that one of the concepts is the 0th argument
               of the other.
    attributes: list of edges connecting a node to an attribute name and its value. Its ith element is a dict for
                node i that maps attribute name to attribute value. For example, if the polarity of some node is
                negative, there should be an edge connecting this node and "-". It can also be viewed as a relation.

    Class attributes:
    ERROR_LOG: stream that parse errors are written to
    DEBUG_LOG: stream that output_amr() writes to
    """
    from collections import defaultdict
    import sys

    # change this if needed
    ERROR_LOG = sys.stderr

    # change this if needed
    DEBUG_LOG = sys.stderr

    def __init__(self, node_list=None, node_value_list=None, relation_list=None, attribute_list=None):
        """
        node_list: names of nodes in AMR graph, e.g. "a11", "n"
        node_value_list: values of nodes in AMR graph, e.g. "group" for a node named "g"
        relation_list: list of relations between two nodes
        attribute_list: list of attributes (links between one node and one constant value)

        Every list is shallow-copied; a missing list becomes an empty one.
        """
        # initialize AMR graph nodes using list of nodes name
        self.nodes = [] if node_list is None else node_list[:]
        # root, by default, is the first in node_list
        self.root = self.nodes[0] if self.nodes else None
        self.node_values = [] if node_value_list is None else node_value_list[:]
        self.relations = [] if relation_list is None else relation_list[:]
        self.attributes = [] if attribute_list is None else attribute_list[:]

    @staticmethod
    def _log_error(*parts):
        """
        Write parts, separated by single spaces and followed by a newline, to AMR.ERROR_LOG.
        The stream is looked up on the class at call time, so reassigning AMR.ERROR_LOG takes effect.
        """
        AMR.ERROR_LOG.write(" ".join("%s" % (part,) for part in parts) + "\n")

    def rename_node(self, prefix):
        """
        Rename AMR graph nodes to prefix + node_index to avoid nodes with the same name in two different AMRs.
        Node names are updated in self.nodes and in the keys of self.relations.
        """
        # map each node to its new name (e.g. "a1")
        node_map_dict = {}
        for i, name in enumerate(self.nodes):
            node_map_dict[name] = prefix + str(i)
        # update node name
        for i, name in enumerate(self.nodes):
            self.nodes[i] = node_map_dict[name]
        # update node name in relations
        for i, rel in enumerate(self.relations):
            new_dict = {}
            for k, v in rel.items():
                new_dict[node_map_dict[k]] = v
            self.relations[i] = new_dict

    def get_triples(self):
        """
        Get the triples in three lists.
        instance_triple: a triple representing an instance. E.g. instance(w, want-01)
        attribute triple: relation of attributes, e.g. polarity(w, - )
        and relation triple, e.g. arg0 (w, b)

        Returns (instance_triple, attribute_triple, relation_triple), in that order.
        """
        instance_triple = []
        relation_triple = []
        attribute_triple = []
        for i, node in enumerate(self.nodes):
            instance_triple.append(("instance", node, self.node_values[i]))
            # k is the other node this node has relation with
            # v is relation name
            for k, v in self.relations[i].items():
                relation_triple.append((v, node, k))
            # k2 is the attribute name
            # v2 is the attribute value
            for k2, v2 in self.attributes[i].items():
                attribute_triple.append((k2, node, v2))
        return instance_triple, attribute_triple, relation_triple

    def get_triples2(self):
        """
        Get the triples in two lists:
        instance_triple: a triple representing an instance. E.g. instance(w, want-01)
        relation_triple: a triple representing all relations. E.g arg0 (w, b) or E.g. polarity(w, - )
        Note that we do not differentiate between attribute triple and relation triple. Both are considered as relation
        triples.
        All triples are represented by (triple_type, argument 1 of the triple, argument 2 of the triple)

        Returns (instance_triple, relation_triple).
        """
        instance_triple = []
        relation_triple = []
        for i, node in enumerate(self.nodes):
            # an instance triple is instance(node name, node value).
            # For example, instance(b, boy).
            instance_triple.append(("instance", node, self.node_values[i]))
            # k is the other node this node has relation with
            # v is relation name
            for k, v in self.relations[i].items():
                relation_triple.append((v, node, k))
            # k2 is the attribute name
            # v2 is the attribute value
            for k2, v2 in self.attributes[i].items():
                relation_triple.append((k2, node, v2))
        return instance_triple, relation_triple

    def __str__(self):
        """
        Generate AMR string for better readability
        """
        lines = []
        for i, node in enumerate(self.nodes):
            lines.append("Node " + str(i) + " " + node)
            lines.append("Value: " + self.node_values[i])
            lines.append("Relations:")
            for k, v in self.relations[i].items():
                lines.append("Node " + k + " via " + v)
            for k2, v2 in self.attributes[i].items():
                lines.append("Attribute: " + k2 + " value " + v2)
        return "\n".join(lines)

    def __repr__(self):
        return self.__str__()

    def output_amr(self):
        """
        Output AMR string (the __str__ form plus a newline) to DEBUG_LOG
        """
        # DEBUG_LOG is a class attribute, so it has to be reached through self
        self.DEBUG_LOG.write(self.__str__() + "\n")

    @staticmethod
    def parse_AMR_line(line):
        """
        Parse a AMR from line representation to an AMR object.
        This parsing algorithm scans the line once and process each character, in a shift-reduce style.

        Returns the AMR object, or None when the line cannot be parsed (including a line that
        contains no node at all); the reason is written to AMR.ERROR_LOG.
        """
        log_error = AMR._log_error
        # strip once, so that the positions from enumerate() below also fit the slices used in error messages
        text = line.strip()
        # Current state. It denotes the last significant symbol encountered. 1 for (, 2 for :, 3 for /,
        # and 0 for start state or ')'
        # Last significant symbol is ( --- start processing node name
        # Last significant symbol is : --- start processing relation name
        # Last significant symbol is / --- start processing node value (concept name)
        # Last significant symbol is ) --- current node processing is complete
        # Note that if these symbols are inside a quoted string, they are not significant symbols.
        state = 0
        # node stack for parsing
        stack = []
        # current not-yet-reduced character sequence
        cur_charseq = []
        # key: node name value: node value
        node_dict = {}
        # node name list (order: occurrence of the node)
        node_name_list = []
        # key: node name: value: list of (relation name, the other node name)
        node_relation_dict1 = defaultdict(list)
        # key: node name, value: list of (attribute name, const value) or (relation name, unseen node name)
        node_relation_dict2 = defaultdict(list)
        # current relation name
        cur_relation_name = ""
        # having unmatched quote string
        in_quote = False
        for i, c in enumerate(text):
            if c == " ":
                # allow space in relation name
                if state == 2:
                    cur_charseq.append(c)
                continue
            if c == "\"":
                # flip in_quote value when a quote symbol is encountered
                # insert placeholder if in_quote from last symbol
                if in_quote:
                    cur_charseq.append('_')
                in_quote = not in_quote
            elif c == "(":
                # not significant symbol if inside quote
                if in_quote:
                    cur_charseq.append(c)
                    continue
                # get the attribute name
                # e.g :arg0 (x ...
                # at this point we get "arg0"
                if state == 2:
                    # in this state, current relation name should be empty
                    if cur_relation_name != "":
                        log_error("Format error when processing ", text[0:i+1])
                        return None
                    # update current relation name for future use
                    cur_relation_name = "".join(cur_charseq).strip()
                    cur_charseq[:] = []
                state = 1
            elif c == ":":
                # not significant symbol if inside quote
                if in_quote:
                    cur_charseq.append(c)
                    continue
                # Last significant symbol is "/". Now we encounter ":"
                # Example:
                # :OR (o2 / *OR*
                #    :mod (o3 / official)
                #  gets node value "*OR*" at this point
                if state == 3:
                    node_value = "".join(cur_charseq)
                    # clear current char sequence
                    cur_charseq[:] = []
                    # pop node name ("o2" in the above example)
                    cur_node_name = stack[-1]
                    # update node name/value map
                    node_dict[cur_node_name] = node_value
                # Last significant symbol is ":". Now we encounter ":"
                # Example:
                # :op1 w :quant 30
                # or :day 14 :month 3
                # the problem is that we cannot decide if node value is attribute value (constant)
                # or node value (variable) at this moment
                elif state == 2:
                    temp_attr_value = "".join(cur_charseq)
                    cur_charseq[:] = []
                    parts = temp_attr_value.split()
                    if len(parts) < 2:
                        log_error("Error in processing; part len < 2", text[0:i+1])
                        return None
                    # For the above example, node name is "op1", and node value is "w"
                    # Note that this node name might not be encountered before
                    relation_name = parts[0].strip()
                    relation_value = parts[1].strip()
                    # We need to link upper level node to the current
                    # top of stack is upper level node
                    if len(stack) == 0:
                        log_error("Error in processing", text[:i], relation_name, relation_value)
                        return None
                    # if we have not seen this node name before
                    if relation_value not in node_dict:
                        node_relation_dict2[stack[-1]].append((relation_name, relation_value))
                    else:
                        node_relation_dict1[stack[-1]].append((relation_name, relation_value))
                state = 2
            elif c == "/":
                if in_quote:
                    cur_charseq.append(c)
                    continue
                # Last significant symbol is "(". Now we encounter "/"
                # Example:
                # (d / default-01
                # get "d" here
                if state == 1:
                    node_name = "".join(cur_charseq)
                    cur_charseq[:] = []
                    # if this node name is already in node_dict, it is duplicate
                    if node_name in node_dict:
                        log_error("Duplicate node name ", node_name, " in parsing AMR")
                        return None
                    # push the node name to stack
                    stack.append(node_name)
                    # add it to node name list
                    node_name_list.append(node_name)
                    # if this node is part of the relation
                    # Example:
                    # :arg1 (n / nation)
                    # cur_relation_name is arg1
                    # node name is n
                    # we have a relation arg1(upper level node, n)
                    if cur_relation_name != "":
                        # a relation needs an upper level node below the node just pushed;
                        # there is none when the relation precedes the root node
                        if len(stack) < 2:
                            log_error("Error in parsing AMR", text[0:i+1])
                            return None
                        # if relation name ends with "-of", e.g."arg0-of",
                        # it is reverse of some relation. For example, if a is "arg0-of" b,
                        # we can also say b is "arg0" a.
                        # If the relation name ends with "-of", we store the reverse relation.
                        if not cur_relation_name.endswith("-of"):
                            # stack[-2] is upper_level node we encountered, as we just add node_name to stack
                            node_relation_dict1[stack[-2]].append((cur_relation_name, node_name))
                        else:
                            # cur_relation_name[:-3] is to delete "-of"
                            node_relation_dict1[node_name].append((cur_relation_name[:-3], stack[-2]))
                        # clear current_relation_name
                        cur_relation_name = ""
                else:
                    # error if in other state
                    log_error("Error in parsing AMR", text[0:i+1])
                    return None
                state = 3
            elif c == ")":
                if in_quote:
                    cur_charseq.append(c)
                    continue
                # stack should be non-empty to find upper level node
                if len(stack) == 0:
                    log_error("Unmatched parenthesis at position", i, "in processing", text[0:i+1])
                    return None
                # Last significant symbol is ":". Now we encounter ")"
                # Example:
                # :op2 "Brown") or :op2 w)
                # get \"Brown\" or w here
                if state == 2:
                    temp_attr_value = "".join(cur_charseq)
                    cur_charseq[:] = []
                    parts = temp_attr_value.split()
                    if len(parts) < 2:
                        log_error("Error processing", text[:i+1], temp_attr_value)
                        return None
                    relation_name = parts[0].strip()
                    relation_value = parts[1].strip()
                    # store reverse of the relation
                    # we are sure relation_value is a node here, as "-of" relation is only between two nodes
                    if relation_name.endswith("-of"):
                        node_relation_dict1[relation_value].append((relation_name[:-3], stack[-1]))
                    # attribute value not seen before
                    # Note that it might be a constant attribute value, or an unseen node
                    # process this after we have seen all the node names
                    elif relation_value not in node_dict:
                        node_relation_dict2[stack[-1]].append((relation_name, relation_value))
                    else:
                        node_relation_dict1[stack[-1]].append((relation_name, relation_value))
                # Last significant symbol is "/". Now we encounter ")"
                # Example:
                # :arg1 (n / nation)
                # we get "nation" here
                elif state == 3:
                    node_value = "".join(cur_charseq)
                    cur_charseq[:] = []
                    cur_node_name = stack[-1]
                    # map node name to its value
                    node_dict[cur_node_name] = node_value
                # pop from stack, as the current node has been processed
                stack.pop()
                cur_relation_name = ""
                state = 0
            else:
                # not significant symbols, so we just shift.
                cur_charseq.append(c)
        # without a single node there is no root to attach the TOP attribute to
        if not node_name_list:
            log_error("Error: no node found in AMR", text)
            return None
        # create data structures to initialize an AMR
        node_value_list = []
        relation_list = []
        attribute_list = []
        for v in node_name_list:
            if v not in node_dict:
                log_error("Error: Node name not found", v)
                return None
            node_value_list.append(node_dict[v])
            # build relation map and attribute map for this node
            relation_dict = {}
            attribute_dict = {}
            if v in node_relation_dict1:
                for v1 in node_relation_dict1[v]:
                    relation_dict[v1[1]] = v1[0]
            if v in node_relation_dict2:
                for v2 in node_relation_dict2[v]:
                    # if value is in quote, it is a constant value
                    # strip the quote and put it in attribute map
                    if v2[1][0] == "\"" and v2[1][-1] == "\"":
                        attribute_dict[v2[0]] = v2[1][1:-1]
                    # if value is a node name
                    elif v2[1] in node_dict:
                        relation_dict[v2[1]] = v2[0]
                    else:
                        attribute_dict[v2[0]] = v2[1]
            # each node has a relation map and attribute map
            relation_list.append(relation_dict)
            attribute_list.append(attribute_dict)
        # add TOP as an attribute. The attribute value is the top node value
        attribute_list[0]["TOP"] = node_value_list[0]
        return AMR(node_name_list, node_value_list, relation_list, attribute_list)