From 6703b06129d5bfceef89da2901c6dc8c809afabb Mon Sep 17 00:00:00 2001 From: George Sittas Date: Mon, 15 May 2023 19:46:08 +0300 Subject: [PATCH] Factor out some computations --- sqlglot/optimizer/qualify_columns.py | 36 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index de7e77b64e..263570d7c6 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -296,8 +296,18 @@ def _expand_stars(scope, resolver, using_column_tables): coalesced_columns = set() # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future + pivot_columns = None + pivot_output_columns = None pivot = seq_get(scope.pivots, 0) + has_pivoted_source = pivot and not pivot.args.get("unpivot") + if has_pivoted_source: + pivot_columns = set(column.output_name for column in pivot.find_all(exp.Column)) + + pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] + if not pivot_output_columns: + pivot_output_columns = [col.alias_or_name for col in pivot.expressions] + for expression in scope.selects: if isinstance(expression, exp.Star): tables = list(scope.selected_sources) @@ -318,8 +328,14 @@ def _expand_stars(scope, resolver, using_column_tables): columns = resolver.get_source_columns(table, only_visible=True) if columns and "*" not in columns: - if pivot and not pivot.args.get("unpivot"): - _add_pivot_columns(pivot, columns, new_selections) + if has_pivoted_source: + implicit_columns = list(set(columns) - pivot_columns) + new_selections.extend( + [ + exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) + for name in implicit_columns + pivot_output_columns + ] + ) continue table_id = id(table) @@ -375,22 +391,6 @@ def _add_replace_columns(expression, tables, replace_columns): replace_columns[id(table)] = columns -def _add_pivot_columns(pivot, source_columns, columns): - pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] - if not pivot_output_columns: - pivot_output_columns = [col.alias_or_name for col in pivot.expressions] - - pivot_columns = set(column.output_name for column in pivot.find_all(exp.Column)) - implicit_columns = list(set(source_columns) - pivot_columns) - - columns.extend( - [ - exp.alias_(exp.column(name, table=pivot.alias), name) - for name in implicit_columns + pivot_output_columns - ] - ) - - def _qualify_outputs(scope): """Ensure all output columns are aliased""" new_selections = []