Skip to content

Commit

Permalink
Fix(executor): allow non-projected aggregates in ORDER BY (#1863)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Jun 30, 2023
1 parent 3800158 commit d6c1569
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sqlglot/executor/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ def eval_tuple(self, codes):
def table(self) -> Table:
if self._table is None:
self._table = list(self.tables.values())[0]

for other in self.tables.values():
if self._table.columns != other.columns:
raise Exception(f"Columns are different.")
if len(self._table.rows) != len(other.rows):
raise Exception(f"Rows are different.")

return self._table

def add_columns(self, *columns: str) -> None:
Expand Down
17 changes: 16 additions & 1 deletion sqlglot/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ def dag(self) -> t.Dict[Step, t.Set[Step]]:
while nodes:
node = nodes.pop()
dag[node] = set()

for dep in node.dependencies:
dag[node].add(dep)
nodes.add(dep)

self._dag = dag

return self._dag
Expand Down Expand Up @@ -128,13 +130,16 @@ def extract_agg_operands(expression):
agg_funcs = tuple(expression.find_all(exp.AggFunc))
if agg_funcs:
aggregations.add(expression)

for agg in agg_funcs:
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = next_operand_name()

operand.replace(exp.column(operands[operand], quoted=True))

return bool(agg_funcs)

for e in expression.expressions:
Expand Down Expand Up @@ -178,13 +183,14 @@ def extract_agg_operands(expression):
for k, v in aggregate.group.items():
intermediate[v] = k
if isinstance(v, exp.Column):
intermediate[v.alias_or_name] = k
intermediate[v.name] = k

for projection in projections:
for node, *_ in projection.walk():
name = intermediate.get(node)
if name:
node.replace(exp.column(name, step.name))

if aggregate.condition:
for node, *_ in aggregate.condition.walk():
name = intermediate.get(node) or intermediate.get(node.name)
Expand All @@ -197,6 +203,15 @@ def extract_agg_operands(expression):
order = expression.args.get("order")

if order:
if isinstance(step, Aggregate):
for ordered in order.expressions:
if ordered.find(exp.AggFunc):
operand_name = next_operand_name()
extract_agg_operands(exp.alias_(ordered.this, operand_name, quoted=True))
ordered.this.replace(exp.column(operand_name, quoted=True))

step.aggregations = list(aggregations)

sort = Sort()
sort.name = step.name
sort.key = order.expressions
Expand Down
5 changes: 5 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,11 @@ def test_group_by(self):
[(2, 25.0)],
("_col_0", "_col_1"),
),
(
"SELECT a FROM x GROUP BY a ORDER BY AVG(b)",
[(2,), (1,), (3,)],
("a",),
),
):
with self.subTest(sql):
result = execute(sql, tables=tables)
Expand Down

0 comments on commit d6c1569

Please sign in to comment.