Skip to content

Commit

Permalink
gh-121210: handle nodes with missing attributes/fields in `ast.compar…
Browse files Browse the repository at this point in the history
…e` (#121211)
  • Loading branch information
picnixz authored Jul 2, 2024
1 parent 7a807c3 commit 15232a0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
19 changes: 15 additions & 4 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ def compare(
might differ in whitespace or similar details.
"""

sentinel = object() # handle the possibility of a missing attribute/field

def _compare(a, b):
# Compare two fields on an AST object, which may themselves be
# AST objects, lists of AST objects, or primitive ASDL types
Expand Down Expand Up @@ -449,8 +451,14 @@ def _compare_fields(a, b):
if a._fields != b._fields:
return False
for field in a._fields:
a_field = getattr(a, field)
b_field = getattr(b, field)
a_field = getattr(a, field, sentinel)
b_field = getattr(b, field, sentinel)
if a_field is sentinel and b_field is sentinel:
# both nodes are missing a field at runtime
continue
if a_field is sentinel or b_field is sentinel:
# one of the node is missing a field
return False
if not _compare(a_field, b_field):
return False
else:
Expand All @@ -461,8 +469,11 @@ def _compare_attributes(a, b):
return False
# Attributes are always ints.
for attr in a._attributes:
a_attr = getattr(a, attr)
b_attr = getattr(b, attr)
a_attr = getattr(a, attr, sentinel)
b_attr = getattr(b, attr, sentinel)
if a_attr is sentinel and b_attr is sentinel:
# both nodes are missing an attribute at runtime
continue
if a_attr != b_attr:
return False
else:
Expand Down
19 changes: 19 additions & 0 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,15 @@ def test_compare_fieldless(self):
self.assertTrue(ast.compare(ast.Add(), ast.Add()))
self.assertFalse(ast.compare(ast.Sub(), ast.Add()))

# test that missing runtime fields is handled in ast.compare()
a1, a2 = ast.Name('a'), ast.Name('a')
self.assertTrue(ast.compare(a1, a2))
self.assertTrue(ast.compare(a1, a2))
del a1.id
self.assertFalse(ast.compare(a1, a2))
del a2.id
self.assertTrue(ast.compare(a1, a2))

def test_compare_modes(self):
for mode, sources in (
("exec", exec_tests),
Expand All @@ -970,6 +979,16 @@ def parse(a, b):
self.assertTrue(ast.compare(a, b, compare_attributes=False))
self.assertFalse(ast.compare(a, b, compare_attributes=True))

def test_compare_attributes_option_missing_attribute(self):
# test that missing runtime attributes is handled in ast.compare()
a1, a2 = ast.Name('a', lineno=1), ast.Name('a', lineno=1)
self.assertTrue(ast.compare(a1, a2))
self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
del a1.lineno
self.assertFalse(ast.compare(a1, a2, compare_attributes=True))
del a2.lineno
self.assertTrue(ast.compare(a1, a2, compare_attributes=True))

def test_positional_only_feature_version(self):
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Handle AST nodes with missing runtime fields or attributes in
:func:`ast.compare`. Patch by Bénédikt Tran.

0 comments on commit 15232a0

Please sign in to comment.