Feat(duckdb): add support for simplified pivot syntax (#1714)
* Feat(duckdb): add support for simplified pivot syntax

* Fixups

* Fixups

* Cleanup

* Use explicit kwarg name

* Fixup
georgesittas authored Jun 1, 2023
1 parent 1b1d9f2 commit 17dc0e1
Showing 6 changed files with 94 additions and 24 deletions.
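Before the per-file diffs, here is the shape of DuckDB's simplified PIVOT statement that this commit adds support for (https://duckdb.org/docs/sql/statements/pivot): `PIVOT <table> ON <columns> USING <aggregates> [GROUP BY <columns>]`. A minimal sketch of the round trip, using the table and column names from the tests below and assuming a sqlglot build that includes this commit:

```python
import sqlglot

# Parse DuckDB's simplified PIVOT syntax and generate it back unchanged.
sql = "PIVOT Cities ON Year USING SUM(Population) GROUP BY Country"
expression = sqlglot.parse_one(sql, read="duckdb")
print(expression.sql(dialect="duckdb"))
# PIVOT Cities ON Year USING SUM(Population) GROUP BY Country
```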
4 changes: 1 addition & 3 deletions sqlglot/dialects/clickhouse.py
@@ -148,9 +148,7 @@ def _parse_placeholder(self) -> t.Optional[exp.Expression]:

return self.expression(exp.Placeholder, this=this, kind=kind)

def _parse_in(
self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:
def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In:
this = super()._parse_in(this)
this.set("is_global", is_global)
return this
1 change: 1 addition & 0 deletions sqlglot/dialects/duckdb.py
@@ -108,6 +108,7 @@ class Tokenizer(tokens.Tokenizer):
"INT1": TokenType.TINYINT,
"LOGICAL": TokenType.BOOLEAN,
"NUMERIC": TokenType.DOUBLE,
"PIVOT_WIDER": TokenType.PIVOT,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"UBIGINT": TokenType.UBIGINT,
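DuckDB also accepts PIVOT_WIDER as an alias for PIVOT; mapping it to TokenType.PIVOT means the keyword is normalized to PIVOT on output. A small sketch mirroring the test added further down:

```python
import sqlglot

# PIVOT_WIDER is tokenized as PIVOT, so generation normalizes the keyword.
print(sqlglot.transpile("PIVOT_WIDER Cities ON Year USING SUM(Population)", read="duckdb")[0])
# PIVOT Cities ON Year USING SUM(Population)
```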
9 changes: 7 additions & 2 deletions sqlglot/expressions.py
@@ -3120,12 +3120,17 @@ class Tag(Expression):
}


# Represents both the standard SQL PIVOT operator and DuckDB's "simplified" PIVOT syntax
# https://duckdb.org/docs/sql/statements/pivot
class Pivot(Expression):
arg_types = {
"this": False,
"alias": False,
"expressions": True,
"field": True,
"unpivot": True,
"field": False,
"unpivot": False,
"using": False,
"group": False,
"columns": False,
}

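With `this`, `using`, `group`, and `columns` optional, the same exp.Pivot node can represent both forms; for the simplified form, the parsed node carries the pivoted table in `this` and the aggregates in `using`. A quick way to inspect this (a sketch, arg names as in this diff):

```python
from sqlglot import exp, parse_one

pivot = parse_one("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country", read="duckdb")
assert isinstance(pivot, exp.Pivot)

print(pivot.args["this"])         # the pivoted table (Cities)
print(pivot.args["expressions"])  # the ON columns ([Year])
print(pivot.args["using"])        # the USING aggregates ([SUM(Population)])
print(pivot.args["group"])        # the GROUP BY clause, if present
```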
11 changes: 10 additions & 1 deletion sqlglot/generator.py
@@ -1220,11 +1220,20 @@ def tablesample_sql(
return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"

def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)

if expression.this:
this = self.sql(expression, "this")
on = f"{self.seg('ON')} {expressions}"
using = self.expressions(expression, key="using", flat=True)
using = f"{self.seg('USING')} {using}" if using else ""
group = self.sql(expression, "group")
return f"PIVOT {this}{on}{using}{group}"

alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
expressions = self.expressions(expression, flat=True)
field = self.sql(expression, "field")
return f"{direction}({expressions} FOR {field}){alias}"

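The generator now branches on `expression.this`: when it is set (the simplified DuckDB form) it emits `PIVOT <table> ON ... USING ... GROUP BY ...`, otherwise it falls back to the parenthesized `PIVOT(... FOR ...)` operator. A hedged sketch of both paths:

```python
from sqlglot import parse_one

# Simplified form: exp.Pivot has `this` set, so the new branch is taken.
print(parse_one("PIVOT Cities ON Year USING SUM(Population)", read="duckdb").sql(dialect="duckdb"))
# PIVOT Cities ON Year USING SUM(Population)

# Standard form: the Pivot node hangs off a table and has no `this`,
# so the original PIVOT(<aggregates> FOR <field>) branch is used.
print(parse_one("SELECT * FROM Cities PIVOT(SUM(Population) FOR Year IN ('2000', '2010'))").sql())
```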
58 changes: 42 additions & 16 deletions sqlglot/parser.py
@@ -440,31 +440,31 @@ class Parser(metaclass=_Parser):
}

EXPRESSION_PARSERS = {
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
exp.Column: lambda self: self._parse_column(),
exp.Condition: lambda self: self._parse_conjunction(),
exp.DataType: lambda self: self._parse_types(),
exp.Expression: lambda self: self._parse_statement(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
exp.Having: lambda self: self._parse_having(),
exp.Identifier: lambda self: self._parse_id_var(),
exp.Lateral: lambda self: self._parse_lateral(),
exp.Join: lambda self: self._parse_join(),
exp.Order: lambda self: self._parse_order(),
exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
exp.Lambda: lambda self: self._parse_lambda(),
exp.Lateral: lambda self: self._parse_lateral(),
exp.Limit: lambda self: self._parse_limit(),
exp.Offset: lambda self: self._parse_offset(),
exp.TableAlias: lambda self: self._parse_table_alias(),
exp.Table: lambda self: self._parse_table_parts(),
exp.Condition: lambda self: self._parse_conjunction(),
exp.Expression: lambda self: self._parse_statement(),
exp.Properties: lambda self: self._parse_properties(),
exp.Where: lambda self: self._parse_where(),
exp.Order: lambda self: self._parse_order(),
exp.Ordered: lambda self: self._parse_ordered(),
exp.Having: lambda self: self._parse_having(),
exp.With: lambda self: self._parse_with(),
exp.Window: lambda self: self._parse_named_window(),
exp.Properties: lambda self: self._parse_properties(),
exp.Qualify: lambda self: self._parse_qualify(),
exp.Returning: lambda self: self._parse_returning(),
exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"),
exp.Table: lambda self: self._parse_table_parts(),
exp.TableAlias: lambda self: self._parse_table_alias(),
exp.Where: lambda self: self._parse_where(),
exp.Window: lambda self: self._parse_named_window(),
exp.With: lambda self: self._parse_with(),
"JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
}

@@ -480,9 +480,13 @@ class Parser(metaclass=_Parser):
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.FROM: lambda self: exp.select("*").from_(
t.cast(exp.From, self._parse_from(skip_from_token=True))
),
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD: lambda self: self._parse_load(),
TokenType.MERGE: lambda self: self._parse_merge(),
TokenType.PIVOT: lambda self: self._parse_simplified_pivot(),
TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
@@ -1897,6 +1901,10 @@ def _parse_select(
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
elif self._match(TokenType.PIVOT):
this = self._parse_simplified_pivot()
elif self._match(TokenType.FROM):
this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True)))
else:
this = None

@@ -2000,8 +2008,10 @@ def _parse_into(self) -> t.Optional[exp.Expression]:
exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
)

def _parse_from(self, modifiers: bool = False) -> t.Optional[exp.Expression]:
if not self._match(TokenType.FROM):
def _parse_from(
self, modifiers: bool = False, skip_from_token: bool = False
) -> t.Optional[exp.From]:
if not skip_from_token and not self._match(TokenType.FROM):
return None

comments = self._prev_comments
@@ -2416,6 +2426,22 @@ def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Expre
def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
return list(iter(self._parse_pivot, None))

# https://duckdb.org/docs/sql/statements/pivot
def _parse_simplified_pivot(self) -> exp.Pivot:
def _parse_on() -> t.Optional[exp.Expression]:
this = self._parse_bitwise()
return self._parse_in(this) if self._match(TokenType.IN) else this

this = self._parse_table()
expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on)
using = self._match(TokenType.USING) and self._parse_csv(
lambda: self._parse_alias(self._parse_function())
)
group = self._parse_group()
return self.expression(
exp.Pivot, this=this, expressions=expressions, using=using, group=group
)

def _parse_pivot(self) -> t.Optional[exp.Expression]:
index = self._index

@@ -2740,7 +2766,7 @@ def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expressi
this = self.expression(exp.Is, this=this, expression=expression)
return self.expression(exp.Not, this=this) if negate else this

def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.Expression:
def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In:
unnest = self._parse_unnest()
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
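The new TokenType.FROM statement parser, together with the `skip_from_token` kwarg, also enables DuckDB's FROM-first shorthand, which the writer expands to an explicit SELECT *. A small sketch mirroring the tests below:

```python
import sqlglot

# DuckDB's FROM-first shorthand is expanded to an explicit SELECT *.
print(sqlglot.transpile("FROM tbl", read="duckdb")[0])
# SELECT * FROM tbl

print(sqlglot.transpile("FROM (FROM tbl)", read="duckdb")[0])
# SELECT * FROM (SELECT * FROM tbl)
```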
35 changes: 33 additions & 2 deletions tests/dialects/test_duckdb.py
Expand Up @@ -132,6 +132,12 @@ def test_duckdb(self):
parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b"
)

self.validate_identity("PIVOT Cities ON Year USING SUM(Population)")
self.validate_identity("PIVOT Cities ON Year USING FIRST(Population)")
self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country")
self.validate_identity("PIVOT Cities ON Country, Name USING SUM(Population)")
self.validate_identity("PIVOT Cities ON Country || '_' || Name USING SUM(Population)")
self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country, Name")
self.validate_identity("SELECT {'a': 1} AS x")
self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x")
self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}")
@@ -146,9 +152,36 @@ def test_duckdb(self):
self.validate_identity(
"SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)"
)
self.validate_identity(
"PIVOT Cities ON Year IN (2000, 2010) USING SUM(Population) GROUP BY Country"
)
self.validate_identity(
"PIVOT Cities ON Year USING SUM(Population) AS total, MAX(Population) AS max GROUP BY Country"
)
self.validate_identity(
"WITH pivot_alias AS (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) SELECT * FROM pivot_alias"
)
self.validate_identity(
"SELECT * FROM (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) AS pivot_alias"
)

self.validate_all("FROM (FROM tbl)", write={"duckdb": "SELECT * FROM (SELECT * FROM tbl)"})
self.validate_all("FROM tbl", write={"duckdb": "SELECT * FROM tbl"})
self.validate_all("0b1010", write={"": "0 AS b1010"})
self.validate_all("0x1010", write={"": "0 AS x1010"})
self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"})
self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'})
self.validate_all(
"PIVOT_WIDER Cities ON Year USING SUM(Population)",
write={"duckdb": "PIVOT Cities ON Year USING SUM(Population)"},
)
self.validate_all(
"WITH t AS (SELECT 1) FROM t", write={"duckdb": "WITH t AS (SELECT 1) SELECT * FROM t"}
)
self.validate_all(
"WITH t AS (SELECT 1) SELECT * FROM (FROM t)",
write={"duckdb": "WITH t AS (SELECT 1) SELECT * FROM (SELECT * FROM t)"},
)
self.validate_all(
"""SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""",
write={
Expand All @@ -163,8 +196,6 @@ def test_duckdb(self):
"trino": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))",
},
)
self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"})
self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'})
self.validate_all(
"WITH 'x' AS (SELECT 1) SELECT * FROM x",
write={"duckdb": 'WITH "x" AS (SELECT 1) SELECT * FROM x'},