Skip to content

Commit

Permalink
fix: subquery selects (#1569)
Browse files Browse the repository at this point in the history
  • Loading branch information
barakalon authored May 9, 2023
1 parent c9103fe commit bcfae2c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
11 changes: 5 additions & 6 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,11 +846,7 @@ def alias_column_names(self):

@property
def selects(self):
alias = self.args.get("alias")

if alias:
return alias.columns
return []
return self.this.selects if isinstance(self.this, Subqueryable) else []

@property
def named_selects(self):
Expand Down Expand Up @@ -920,7 +916,10 @@ def except_(self, expression, distinct=True, dialect=None, **opts):


class UDTF(DerivedTable, Unionable):
pass
@property
def selects(self):
alias = self.args.get("alias")
return alias.columns if alias else []


class Cache(Expression):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ def test_selects(self):
expression = parse_one("SELECT a, b FROM x")
self.assertEqual([s.sql() for s in expression.selects], ["a", "b"])

expression = parse_one("(SELECT a, b FROM x)")
self.assertEqual([s.sql() for s in expression.selects], ["a", "b"])

def test_alias_column_names(self):
expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y")
subquery = expression.find(exp.Subquery)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,12 @@ def test_concat_annotation(self):
expression = annotate_types(parse_one("CONCAT('A', 'B')"))
self.assertEqual(expression.type.this, exp.DataType.Type.VARCHAR)

def test_root_subquery_annotation(self):
expression = annotate_types(parse_one("(SELECT 1, 2 FROM x) LIMIT 0"))
self.assertIsInstance(expression, exp.Subquery)
self.assertEqual(exp.DataType.Type.INT, expression.selects[0].type.this)
self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)

def test_recursive_cte(self):
query = parse_one(
"""
Expand Down

0 comments on commit bcfae2c

Please sign in to comment.