Skip to content

Commit

Permalink
Fix: simplify FROM to a single expression, closes #1620 (#1620)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao authored May 15, 2023
1 parent 966dfbb commit 4833953
Show file tree
Hide file tree
Showing 33 changed files with 462 additions and 492 deletions.
14 changes: 6 additions & 8 deletions sqlglot/dataframe/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,13 @@ def createDataFrame(
select_kwargs = {
"expressions": sel_columns,
"from": exp.From(
expressions=[
exp.Values(
expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
this=exp.Values(
expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
],
),
),
}

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dataframe/sql/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.T
if not expression.args.get("joins"):
return []

left_table = expression.args["from"].args["expressions"][0]
left_table = expression.args["from"].this
other_tables = [join.this for join in expression.args["joins"]]
return [left_table] + other_tables
19 changes: 6 additions & 13 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,12 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
These are added by the optimizer's qualify_column step.
"""
if isinstance(expression, exp.Select):
unnests = {
unnest.alias
for unnest in expression.args.get("from", exp.From(expressions=[])).expressions
if isinstance(unnest, exp.Unnest) and unnest.alias
}

if unnests:
expression = expression.copy()

for select in expression.expressions:
for column in select.find_all(exp.Column):
if column.table in unnests:
column.set("table", None)
for unnest in expression.find_all(exp.Unnest):
if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
for select in expression.selects:
for column in select.find_all(exp.Column):
if column.table == unnest.alias:
column.set("table", None)

return expression

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _parse_update(self) -> exp.Expression:
exp.Update,
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"from": self._parse_from(),
"from": self._parse_from(modifiers=True),
"expressions": self._match(TokenType.SET)
and self._parse_csv(self._parse_equality),
"where": self._parse_where(),
Expand Down
39 changes: 22 additions & 17 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,13 @@ class Into(Expression):


class From(Expression):
arg_types = {"expressions": True}
@property
def name(self) -> str:
return self.this.name

@property
def alias_or_name(self) -> str:
return self.this.alias_or_name


class Having(Expression):
Expand Down Expand Up @@ -2318,7 +2324,9 @@ class Select(Subqueryable):
**QUERY_MODIFIERS,
}

def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
def from_(
self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts
) -> Select:
"""
Set the FROM expression.
Expand All @@ -2327,27 +2335,24 @@ def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> S
'SELECT x FROM tbl'
Args:
*expressions (str | Expression): the SQL code strings to parse.
expression: the SQL code string to parse.
If a `From` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `From`.
append (bool): if `True`, add to any existing expressions.
Otherwise, this flattens all the `From` expression into a single expression.
dialect (str): the dialect used to parse the input expression.
copy (bool): if `False`, modify this expression instance in-place.
opts (kwargs): other options to use to parse the input expressions.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
Select: the modified expression.
"""
return _apply_child_list_builder(
*expressions,
return _apply_builder(
expression=expression,
instance=self,
arg="from",
append=append,
copy=copy,
prefix="FROM",
into=From,
prefix="FROM",
dialect=dialect,
copy=copy,
**opts,
)

Expand Down Expand Up @@ -4624,7 +4629,7 @@ def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Selec
return Select().select(*expressions, dialect=dialect, **opts)


def from_(*expressions, dialect=None, **opts) -> Select:
def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select:
"""
Initializes a syntax tree from a FROM expression.
Expand All @@ -4633,17 +4638,17 @@ def from_(*expressions, dialect=None, **opts) -> Select:
'SELECT col1, col2 FROM tbl'
Args:
*expressions (str | Expression): the SQL code string to parse as the FROM expressions of a
expression: the SQL code string to parse as the FROM expression of a
SELECT statement. If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression (in the case that the
dialect: the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).
Returns:
Select: the syntax tree for the SELECT statement.
"""
return Select().from_(*expressions, dialect=dialect, **opts)
return Select().from_(expression, dialect=dialect, **opts)


def update(
Expand Down
37 changes: 18 additions & 19 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,8 +1238,7 @@ def into_sql(self, expression: exp.Into) -> str:
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"

def from_sql(self, expression: exp.From) -> str:
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
return f"{self.seg('FROM')} {self.sql(expression, 'this')}"

def group_sql(self, expression: exp.Group) -> str:
group_by = self.op_expressions("GROUP BY", expression)
Expand Down Expand Up @@ -1280,37 +1279,37 @@ def having_sql(self, expression: exp.Having) -> str:
return f"{self.seg('HAVING')}{self.sep()}{this}"

def join_sql(self, expression: exp.Join) -> str:
op_sql = self.seg(
" ".join(
op
for op in (
"NATURAL" if expression.args.get("natural") else None,
"GLOBAL" if expression.args.get("global") else None,
expression.side,
expression.kind,
expression.hint if self.JOIN_HINTS else None,
"JOIN",
)
if op
op_sql = " ".join(
op
for op in (
"NATURAL" if expression.args.get("natural") else None,
"GLOBAL" if expression.args.get("global") else None,
expression.side,
expression.kind,
expression.hint if self.JOIN_HINTS else None,
)
if op
)
on_sql = self.sql(expression, "on")
using = expression.args.get("using")

if not on_sql and using:
on_sql = csv(*(self.sql(column) for column in using))

this_sql = self.sql(expression, "this")

if on_sql:
on_sql = self.indent(on_sql, skip_first=True)
space = self.seg(" " * self.pad) if self.pretty else " "
if using:
on_sql = f"{space}USING ({on_sql})"
else:
on_sql = f"{space}ON {on_sql}"
elif not op_sql:
return f", {this_sql}"

expression_sql = self.sql(expression, "expression")
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
return f"{self.seg(op_sql)} {this_sql}{on_sql}"

def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
Expand Down Expand Up @@ -1487,9 +1486,9 @@ def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:

return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
*[self.sql(join) for join in expression.args.get("joins") or []],
self.sql(expression, "match"),
*[self.sql(sql) for sql in expression.args.get("laterals") or []],
*[self.sql(lateral) for lateral in expression.args.get("laterals") or []],
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/optimizer/eliminate_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def eliminate_subqueries(expression):
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y'
This also deduplicates common subqueries:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z")
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z")
>>> eliminate_subqueries(expression).sql()
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z'
'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z'
Args:
expression (sqlglot.Expression): expression
Expand Down
24 changes: 0 additions & 24 deletions sqlglot/optimizer/expand_multi_table_selects.py

This file was deleted.

14 changes: 7 additions & 7 deletions sqlglot/optimizer/merge_subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def merge_subqueries(expression, leave_tables_isolated=False):
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression).sql()
'SELECT x.a FROM x JOIN y'
'SELECT x.a FROM x CROSS JOIN y'
If `leave_tables_isolated` is True, this will not merge inner queries into outer
queries if it would result in multiple table selects in a single query:
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y")
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
>>> merge_subqueries(expression, leave_tables_isolated=True).sql()
'SELECT a FROM (SELECT x.a FROM x) JOIN y'
'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y'
Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html
Expand Down Expand Up @@ -154,7 +154,7 @@ def _outer_select_joins_on_inner_select_join():
inner_from = inner_scope.expression.args.get("from")
if not inner_from:
return False
inner_from_table = inner_from.expressions[0].alias_or_name
inner_from_table = inner_from.alias_or_name
inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
return any(
col.table != inner_from_table
Expand Down Expand Up @@ -228,7 +228,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
node_to_replace (exp.Subquery|exp.Table)
alias (str)
"""
new_subquery = inner_scope.expression.args.get("from").expressions[0]
new_subquery = inner_scope.expression.args["from"].this
node_to_replace.replace(new_subquery)
for join_hint in outer_scope.join_hints:
tables = join_hint.find_all(exp.Table)
Expand Down Expand Up @@ -319,7 +319,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
# Merge predicates from an outer join to the ON clause
# if it only has columns that are already joined
from_ = expression.args.get("from")
sources = {table.alias_or_name for table in from_.expressions} if from_ else {}
sources = {from_.alias_or_name} if from_ else {}

for join in expression.args["joins"]:
source = join.alias_or_name
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/optimizer/optimize_joins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from sqlglot import exp
from sqlglot.helper import tsort

JOIN_ATTRS = ("on", "side", "kind", "using", "natural")


def optimize_joins(expression):
"""
Expand Down Expand Up @@ -45,7 +47,7 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
head = from_.expressions[0]
head = from_.this
parent = from_.parent
joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
dag = {head.alias_or_name: []}
Expand All @@ -65,6 +67,9 @@ def normalize(expression):
Remove INNER and OUTER from joins as they are optional.
"""
for join in expression.find_all(exp.Join):
if not any(join.args.get(k) for k in JOIN_ATTRS):
join.set("kind", "CROSS")

if join.kind != "CROSS":
join.set("kind", None)
return expression
Expand Down
2 changes: 0 additions & 2 deletions sqlglot/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
Expand All @@ -33,7 +32,6 @@
validate_qualify_columns,
normalize,
unnest_subqueries,
expand_multi_table_selects,
pushdown_predicates,
optimize_joins,
eliminate_subqueries,
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def validate_qualify_columns(expression):
if scope.external_columns and not scope.is_correlated_subquery:
column = scope.external_columns[0]
raise OptimizeError(
f"""Column '{column}' could not be resolved{" for table: '{column.table}'" if column.table else ''}"""
f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
)

if unqualified_columns:
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def _traverse_tables(scope):
expressions = []
from_ = scope.expression.args.get("from")
if from_:
expressions.extend(from_.expressions)
expressions.append(from_.this)

for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
Expand Down
Loading

0 comments on commit 4833953

Please sign in to comment.