Skip to content

Commit

Permalink
Factor out some computations
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed May 15, 2023
1 parent 42bf5b3 commit 6703b06
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,18 @@ def _expand_stars(scope, resolver, using_column_tables):
coalesced_columns = set()

# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
pivot_columns = None
pivot_output_columns = None
pivot = seq_get(scope.pivots, 0)

has_pivoted_source = pivot and not pivot.args.get("unpivot")
if has_pivoted_source:
pivot_columns = set(column.output_name for column in pivot.find_all(exp.Column))

pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
if not pivot_output_columns:
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

for expression in scope.selects:
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
Expand All @@ -318,8 +328,14 @@ def _expand_stars(scope, resolver, using_column_tables):
columns = resolver.get_source_columns(table, only_visible=True)

if columns and "*" not in columns:
if pivot and not pivot.args.get("unpivot"):
_add_pivot_columns(pivot, columns, new_selections)
if has_pivoted_source:
implicit_columns = list(set(columns) - pivot_columns)
new_selections.extend(
[
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
]
)
continue

table_id = id(table)
Expand Down Expand Up @@ -375,22 +391,6 @@ def _add_replace_columns(expression, tables, replace_columns):
replace_columns[id(table)] = columns


def _add_pivot_columns(pivot, source_columns, columns):
pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
if not pivot_output_columns:
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

pivot_columns = set(column.output_name for column in pivot.find_all(exp.Column))
implicit_columns = list(set(source_columns) - pivot_columns)

columns.extend(
[
exp.alias_(exp.column(name, table=pivot.alias), name)
for name in implicit_columns + pivot_output_columns
]
)


def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []
Expand Down

0 comments on commit 6703b06

Please sign in to comment.