Skip to content

Commit

Permalink
Fix: make map gen more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed May 26, 2023
1 parent 425af88 commit 8f0fbad
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion sqlglot/executor/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def interval(this, unit):
"LOWER": null_if_any(lambda arg: arg.lower()),
"LT": null_if_any(lambda this, e: this < e),
"LTE": null_if_any(lambda this, e: this <= e),
"MAP": null_if_any(lambda k, v: dict(zip(k, v))),
"MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore
"MOD": null_if_any(lambda e, this: e % this),
"MUL": null_if_any(lambda e, this: e * this),
"NEQ": null_if_any(lambda this, e: this != e),
Expand Down
18 changes: 13 additions & 5 deletions sqlglot/executor/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,19 @@ def _ordered_py(self, expression):

def _rename(self, e):
try:
if "expressions" in e.args:
this = self.sql(e, "this")
this = f"{this}, " if this else ""
return f"{e.key.upper()}({this}{self.expressions(e)})"
return self.func(e.key, *e.args.values())
values = list(e.args.values())

if len(values) == 1:
values = values[0]
if not isinstance(values, list):
return self.func(e.key, values)
return self.func(e.key, *values)

if isinstance(e, exp.Func) and e.is_var_len_args:
*head, tail = values
return self.func(e.key, *head, *tail)

return self.func(e.key, *values)
except Exception as ex:
raise Exception(f"Could not rename {repr(e)}") from ex

Expand Down
2 changes: 2 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def rename_anonymous(self, source, target):

def test_py_dialect(self):
self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''")
self.assertEqual(Python().generate(parse_one("MAP([1], [2])")), "MAP([1], [2])")

def test_optimized_tpch(self):
for i, (sql, optimized) in enumerate(self.sqls[:20], start=1):
Expand Down Expand Up @@ -596,6 +597,7 @@ def test_scalar_functions(self):
("1::bool", True),
("0::bool", False),
("MAP(['a'], [1]).a", 1),
("MAP()", {}),
("STRFTIME('%j', '2023-03-23 15:00:00')", "082"),
("STRFTIME('%j', NULL)", None),
("DATESTRTODATE('2022-01-01')", date(2022, 1, 1)),
Expand Down

0 comments on commit 8f0fbad

Please sign in to comment.