Skip to content

Commit

Permalink
Fix: dont expand bq pseudocolumns in optimizer star expansion (#1826)
Browse files Browse the repository at this point in the history
* fix: dont expand bq pseudocolumns in star if present in schema

* test: add test for correct expansion with pseudocols in schema

* ci: more typing for mypy and explicit guards
  • Loading branch information
z3z1ma authored Jun 26, 2023
1 parent 451dad2 commit 8aef4c3
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 19 deletions.
53 changes: 34 additions & 19 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def qualify_columns(
return expression


def validate_qualify_columns(expression):
def validate_qualify_columns(expression: E) -> E:
"""Raise an `OptimizeError` if any columns aren't qualified"""
unqualified_columns = []
for scope in traverse_scope(expression):
Expand All @@ -79,7 +79,7 @@ def validate_qualify_columns(expression):
return expression


def _pop_table_column_aliases(derived_tables):
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE]) -> None:
"""
Remove table column aliases.
Expand All @@ -91,13 +91,13 @@ def _pop_table_column_aliases(derived_tables):
table_alias.args.pop("columns", None)


def _expand_using(scope, resolver):
def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
joins = list(scope.find_all(exp.Join))
names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]

# Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables = {}
column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

for join in joins:
using = join.args.get("using")
Expand Down Expand Up @@ -174,7 +174,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:

def replace_columns(
node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False
):
) -> None:
if not node:
return

Expand All @@ -201,7 +201,7 @@ def replace_columns(
scope.clear_cache()


def _expand_group_by(scope, resolver):
def _expand_group_by(scope: Scope, resolver: Resolver):
group = scope.expression.args.get("group")
if not group:
return
Expand All @@ -210,7 +210,7 @@ def _expand_group_by(scope, resolver):
scope.expression.set("group", group)


def _expand_order_by(scope):
def _expand_order_by(scope: Scope):
order = scope.expression.args.get("order")
if not order:
return
Expand All @@ -229,7 +229,7 @@ def _expand_order_by(scope):
ordered.set("this", selects.get(ordered.this, ordered.this))


def _expand_positional_references(scope, expressions):
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
new_nodes = []
for node in expressions:
if node.is_int:
Expand All @@ -247,7 +247,7 @@ def _expand_positional_references(scope, expressions):
return new_nodes


def _qualify_columns(scope, resolver):
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns:
column_table = column.table
Expand Down Expand Up @@ -296,21 +296,23 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)


def _expand_stars(scope, resolver, using_column_tables):
def _expand_stars(
scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
) -> None:
"""Expand stars to lists of column selections"""

new_selections = []
except_columns = {}
replace_columns = {}
except_columns: t.Dict[int, t.Set[str]] = {}
replace_columns: t.Dict[int, t.Dict[str, str]] = {}
coalesced_columns = set()

# TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
pivot_columns = None
pivot_output_columns = None
pivot = seq_get(scope.pivots, 0)
pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

has_pivoted_source = pivot and not pivot.args.get("unpivot")
if has_pivoted_source:
if pivot and has_pivoted_source:
pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))

pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
Expand All @@ -336,8 +338,17 @@ def _expand_stars(scope, resolver, using_column_tables):

columns = resolver.get_source_columns(table, only_visible=True)

# The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
# https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
if resolver.schema.dialect == "bigquery":
columns = [
name
for name in columns
if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
]

if columns and "*" not in columns:
if has_pivoted_source:
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
Expand Down Expand Up @@ -374,7 +385,9 @@ def _expand_stars(scope, resolver, using_column_tables):
scope.expression.set("expressions", new_selections)


def _add_except_columns(expression, tables, except_columns):
def _add_except_columns(
expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
except_ = expression.args.get("except")

if not except_:
Expand All @@ -386,7 +399,9 @@ def _add_except_columns(expression, tables, except_columns):
except_columns[id(table)] = columns


def _add_replace_columns(expression, tables, replace_columns):
def _add_replace_columns(
expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
replace = expression.args.get("replace")

if not replace:
Expand All @@ -398,7 +413,7 @@ def _add_replace_columns(expression, tables, replace_columns):
replace_columns[id(table)] = columns


def _qualify_outputs(scope):
def _qualify_outputs(scope: Scope):
"""Ensure all output columns are aliased"""
new_selections = []

Expand Down Expand Up @@ -435,7 +450,7 @@ class Resolver:
This is a class so we can lazily load some things and easily share them across functions.
"""

def __init__(self, scope, schema, infer_schema: bool = True):
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
self.scope = scope
self.schema = schema
self._source_columns = None
Expand Down
17 changes: 17 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,3 +727,20 @@ def test_quotes(self):
source_query = parse_one('SELECT * FROM example."source"', read="snowflake")
transformed = func(source_query, dialect="snowflake", schema=schema)
self.assertEqual(transformed.sql(pretty=True, dialect="snowflake"), expected)

def test_no_pseudocolumn_expansion(self):
schema = {
"a": {
"a": "text",
"b": "text",
"_PARTITIONDATE": "date",
"_PARTITIONTIME": "timestamp",
}
}

self.assertEqual(
optimizer.optimize(
parse_one("SELECT * FROM a"), schema=MappingSchema(schema, dialect="bigquery")
),
parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'),
)

0 comments on commit 8aef4c3

Please sign in to comment.