Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(optimizer): optimize pivots #1617

Merged
merged 12 commits into from
May 16, 2023
2 changes: 1 addition & 1 deletion sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:

def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str:
self.unsupported("PIVOT unsupported")
return self.sql(expression)
return ""
georgesittas marked this conversation as resolved.
Show resolved Hide resolved


def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class Snowflake(Dialect):
}

class Parser(parser.Parser):
QUOTED_PIVOT_COLUMNS = True
IDENTIFY_PIVOT_STRINGS = True

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
Expand Down
1 change: 0 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2953,7 +2953,6 @@ class Tag(Expression):

class Pivot(Expression):
arg_types = {
"this": False,
"alias": False,
"expressions": True,
"field": True,
Expand Down
19 changes: 8 additions & 11 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,9 +1158,10 @@ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:

alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
hints = self.expressions(expression, key="hints", flat=True)
hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep="")
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="")
laterals = self.expressions(expression, key="laterals", sep="")
system_time = expression.args.get("system_time")
Expand Down Expand Up @@ -1197,14 +1198,13 @@ def tablesample_sql(
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"

def pivot_sql(self, expression: exp.Pivot) -> str:
this = self.sql(expression, "this")
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
expressions = self.expressions(expression, key="expressions")
expressions = self.expressions(expression, flat=True)
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field}){alias}"
return f"{direction}({expressions} FOR {field}){alias}"

def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
Expand Down Expand Up @@ -1562,13 +1562,10 @@ def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""

sql = self.query_modifiers(
expression,
self.wrap(expression),
alias,
self.expressions(expression, key="pivots", sep=" "),
)
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""

sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
return self.prepend_ctes(expression, sql)

def qualify_sql(self, expression: exp.Qualify) -> str:
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/optimizer/eliminate_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def _eliminate_union(scope, existing_ctes, taken):


def _eliminate_derived_table(scope, existing_ctes, taken):
# This ensures we don't drop the "pivot" arg from a pivoted subquery
if scope.parent.pivots:
return None
georgesittas marked this conversation as resolved.
Show resolved Hide resolved

parent = scope.expression.parent
name, cte = _new_cte(scope, existing_ctes, taken)

Expand Down
1 change: 1 addition & 0 deletions sqlglot/optimizer/merge_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def _outer_select_joins_on_inner_select_join():
and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not outer_scope.pivots
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
and not (
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ def optimize(
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)

for rule in rules:
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
rule_kwargs = {
param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
}
expression = rule(expression, **rule_kwargs)

return expression
5 changes: 3 additions & 2 deletions sqlglot/optimizer/pushdown_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
for scope in reversed(traverse_scope(expression)):
parent_selections = referenced_columns.get(scope, {SELECT_ALL})

if scope.expression.args.get("distinct"):
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT
if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
# We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
# we select from a pivoted source in the parent scope.
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
parent_selections = {SELECT_ALL}

if isinstance(scope.expression, exp.Union):
Expand Down
42 changes: 40 additions & 2 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema

Expand Down Expand Up @@ -65,7 +66,7 @@ def validate_qualify_columns(expression):
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
unqualified_columns.extend(scope.unqualified_columns)
if scope.external_columns and not scope.is_correlated_subquery:
if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
column = scope.external_columns[0]
raise OptimizeError(
f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
Expand Down Expand Up @@ -249,6 +250,12 @@ def _qualify_columns(scope, resolver):
raise OptimizeError(f"Unknown column: {column_name}")

if not column_table:
if scope.pivots and not column.find_ancestor(exp.Pivot):
# If the column is under the Pivot expression, we need to qualify it
# using the name of the pivoted source instead of the pivot's alias
column.set("table", exp.to_identifier(scope.pivots[0].alias))
continue
georgesittas marked this conversation as resolved.
Show resolved Hide resolved

column_table = resolver.get_table(column_name)

# column_table can be a '' because bigquery unnest has no table alias
Expand All @@ -272,6 +279,13 @@ def _qualify_columns(scope, resolver):
if column_table:
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))

for pivot in scope.pivots:
for column in pivot.find_all(exp.Column):
if not column.table and column.name in resolver.all_columns:
column_table = resolver.get_table(column.name)
if column_table:
column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
Expand All @@ -281,6 +295,20 @@ def _expand_stars(scope, resolver, using_column_tables):
replace_columns = {}
coalesced_columns = set()

# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
pivot_columns = None
pivot_output_columns = None
pivot = seq_get(scope.pivots, 0)

has_pivoted_source = pivot and not pivot.args.get("unpivot")
if has_pivoted_source:
# We're using a dictionary here in order to preserve order
pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
if not pivot_output_columns:
pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

for expression in scope.selects:
if isinstance(expression, exp.Star):
tables = list(scope.selected_sources)
Expand All @@ -297,9 +325,18 @@ def _expand_stars(scope, resolver, using_column_tables):
for table in tables:
if table not in scope.sources:
raise OptimizeError(f"Unknown table: {table}")

columns = resolver.get_source_columns(table, only_visible=True)

if columns and "*" not in columns:
if has_pivoted_source:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
)
continue

table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
Expand All @@ -319,12 +356,13 @@ def _expand_stars(scope, resolver, using_column_tables):
)
elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
column = exp.column(name, table=table)
new_selections.append(
alias(column, alias_, copy=False) if alias_ != name else column
)
else:
return

scope.expression.set("expressions", new_selections)


Expand Down
12 changes: 10 additions & 2 deletions sqlglot/optimizer/qualify_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))

if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
alias_ = next_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
scope.rename_source(None, alias_)

pivots = derived_table.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))

for name, source in scope.sources.items():
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
Expand All @@ -60,12 +64,16 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
source = source.replace(
alias(
source,
name if name else next_name(),
name or source.name or next_name(),
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
copy=True,
table=True,
)
)

pivots = source.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_name())))

if schema and isinstance(source.this, exp.ReadCSV):
with csv_reader(source.this) as reader:
header = next(reader)
Expand Down
21 changes: 19 additions & 2 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def clear_cache(self):
self._columns = None
self._external_columns = None
self._join_hints = None
self._pivots = None

def branch(self, expression, scope_type, chain_sources=None, **kwargs):
"""Branch from the current scope to a new, inner scope"""
Expand Down Expand Up @@ -372,6 +373,17 @@ def join_hints(self):
return []
return self._join_hints

@property
def pivots(self):
if not self._pivots:
self._pivots = [
pivot
for node in self.tables + self.derived_tables
for pivot in node.args.get("pivots") or []
]

return self._pivots

def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.
Expand Down Expand Up @@ -603,8 +615,13 @@ def _traverse_tables(scope):
source_name = expression.alias_or_name

if table_name in scope.sources:
# This is a reference to a parent source (e.g. a CTE), not an actual table.
sources[source_name] = scope.sources[table_name]
# This is a reference to a parent source (e.g. a CTE), not an actual table, unless
# it is pivoted, because then we get back a new table and hence a new source.
pivots = expression.args.get("pivots")
if pivots:
sources[pivots[0].alias] = expression
else:
sources[source_name] = scope.sources[table_name]
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
elif source_name in sources:
sources[find_new_name(sources, table_name)] = expression
else:
Expand Down
11 changes: 6 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,8 +777,8 @@ class Parser(metaclass=_Parser):

CONVERT_TYPE_FIRST = False

QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None
PREFIXED_PIVOT_COLUMNS = False
IDENTIFY_PIVOT_STRINGS = False

LOG_BASE_FIRST = True
LOG_DEFAULTS_TO_LN = False
Expand Down Expand Up @@ -2465,14 +2465,15 @@ def _parse_pivot(self) -> t.Optional[exp.Expression]:
names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions))

columns: t.List[exp.Expression] = []
for col in pivot.args["field"].expressions:
for fld in pivot.args["field"].expressions:
field_name = fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name
for name in names:
if self.PREFIXED_PIVOT_COLUMNS:
name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name
name = f"{name}_{field_name}" if name else field_name
else:
name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name
name = f"{field_name}_{name}" if name else field_name

columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS))
columns.append(exp.to_identifier(name))

pivot.set("columns", columns)

Expand Down
Loading