Skip to content

Commit

Permalink
Propertly deserialize three-levels nested structures
Browse files Browse the repository at this point in the history
In some cases, the "functions" attribute of a RangeFunction to mention
one, the value is a tuple-of-tuples-of-dicts: replace existing
two-levels hardcoded logic with a recursive solution.

This fixes issue #153.
  • Loading branch information
lelit committed Aug 23, 2024
1 parent 7b472c5 commit effc266
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 11 deletions.
25 changes: 15 additions & 10 deletions pglast/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@
SlotTypeInfo = namedtuple('SlotTypeInfo', ['c_type', 'py_type', 'adaptor'])


def _deserialize_value(value):
if isinstance(value, dict) and '@' in value:
G = globals()
if len(value) > 1:
result = G[value['@']](value)
else:
result = G[value['@']]()
elif isinstance(value, (tuple, list)):
result = tuple(_deserialize_value(item) for item in value)
else:
result = value
return result


def _serialize_node(n, depth, ellipsis, skip_none):
d = {'@': n.__class__.__name__}
for a in n:
Expand Down Expand Up @@ -69,19 +83,10 @@ def __init__(self, data):
raise ValueError(f'Bad argument, wrong "@" value, expected'
f' {self.__class__.__name__!r}, got {data["@"]!r}')

G = globals()
for a in self:
v = data.get(a)
if v is not None:
if isinstance(v, dict) and '@' in v:
if len(v) > 1:
v = G[v['@']](v)
else:
v = G[v['@']]()
elif isinstance(v, (tuple, list)):
v = tuple((G[i['@']](i) if len(i) > 1 else G[i['@']]())
if isinstance(i, dict) and '@' in i else i
for i in v)
v = _deserialize_value(v)
setattr(self, a, v)

def __iter__(self):
Expand Down
44 changes: 43 additions & 1 deletion tests/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# :Created: sab 29 mag 2021, 21:25:46
# :Author: Lele Gaifax <[email protected]>
# :License: GNU General Public License version 3 or later
# :Copyright: © 2021, 2022, 2023 Lele Gaifax
# :Copyright: © 2021, 2022, 2023, 2024 Lele Gaifax
#

import pytest
Expand Down Expand Up @@ -104,3 +104,45 @@ def test_issue_97():
def test_issue_138():
raw = parse_sql('select * from foo')[0]
ast.RawStmt(raw())


def test_issue_153():
selstmt = parse_sql('select t.y from f(5) as t')[0].stmt
serialized = selstmt()
assert serialized['@'] == 'SelectStmt'
clone = ast.SelectStmt(serialized)
orig_fromc = selstmt.fromClause[0]
orig_fc_funcs = orig_fromc.functions
clone_fromc = clone.fromClause[0]
clone_fc_funcs = clone_fromc.functions
assert orig_fc_funcs == clone_fc_funcs
assert selstmt == clone


def test_issue_153b():
serialized = {
'@': 'RangeFunction',
'alias': {'@': 'Alias', 'aliasname': 'tmp', 'colnames': None},
'coldeflist': None,
'functions': (({'@': 'FuncCall',
'agg_distinct': False,
'agg_filter': None,
'agg_order': None,
'agg_star': False,
'agg_within_group': False,
'args': ({'@': 'A_Const',
'isnull': False,
'val': {'@': 'Integer', 'ival': 5}},),
'func_variadic': False,
'funcformat': {'#': 'CoercionForm',
'name': 'COERCE_EXPLICIT_CALL',
'value': 0},
'funcname': ({'@': 'String', 'sval': 'f'},),
'location': 21,
'over': None},
None),),
'is_rowsfrom': False,
'lateral': False,
'ordinality': False}
rf = ast.RangeFunction(serialized)
assert isinstance(rf.functions[0][0], ast.FuncCall)
17 changes: 17 additions & 0 deletions tests/test_printers_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,20 @@ def test_pg_regress_corpus(filename):

assert orig_ast == serialized_ast, "Statement “%s” from %s at line %d != %r" % (
trimmed_stmt, rel_src, lineno, serialized)


@pytest.mark.parametrize('src,lineno,statement',
((src, lineno, statement)
for src in sorted(tests_dir.glob('**/*.sql'))
for (lineno, statement) in statements(src)),
ids=make_id)
def test_ast_serialization_roundtrip(src, lineno, statement):
try:
orig_ast = parse_sql(statement)
except: # noqa
raise RuntimeError("%s:%d:Could not parse %r" % (src, lineno, statement))

stmt = orig_ast[0].stmt
serialized = stmt()
clone = stmt.__class__(serialized)
assert stmt == clone

0 comments on commit effc266

Please sign in to comment.