Feat(clickhouse): add support for clickhouse's placeholders (#1628)
* Feat(clickhouse): add support for clickhouse's placeholders

* Formatting

* Fixup
georgesittas authored May 15, 2023
1 parent 8610298 commit 50025ea
Showing 4 changed files with 112 additions and 78 deletions.
67 changes: 46 additions & 21 deletions sqlglot/dialects/clickhouse.py
@@ -28,24 +28,25 @@ class Tokenizer(tokens.Tokenizer):
**tokens.Tokenizer.KEYWORDS,
"ASOF": TokenType.ASOF,
"ATTACH": TokenType.COMMAND,
"DATETIME64": TokenType.DATETIME64,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
"GLOBAL": TokenType.GLOBAL,
"INT128": TokenType.INT128,
"INT16": TokenType.SMALLINT,
"INT256": TokenType.INT256,
"INT32": TokenType.INT,
"INT64": TokenType.BIGINT,
"INT8": TokenType.TINYINT,
"MAP": TokenType.MAP,
"TUPLE": TokenType.STRUCT,
"UINT128": TokenType.UINT128,
"UINT16": TokenType.USMALLINT,
"UINT256": TokenType.UINT256,
"UINT32": TokenType.UINT,
"UINT64": TokenType.UBIGINT,
"UINT8": TokenType.UTINYINT,
}

class Parser(parser.Parser):
@@ -116,6 +117,27 @@ def _parse_ternary(self) -> t.Optional[exp.Expression]:

return this

def _parse_placeholder(self) -> t.Optional[exp.Expression]:
"""
Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier}
https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters
"""
if not self._match(TokenType.L_BRACE):
return None

this = self._parse_id_var()
self._match(TokenType.COLON)
kind = self._parse_types(check_func=False) or (
self._match_text_seq("IDENTIFIER") and "Identifier"
)

if not kind:
self.raise_error("Expecting a placeholder type or 'Identifier' for tables")
elif not self._match(TokenType.R_BRACE):
self.raise_error("Expecting }")

return self.expression(exp.Placeholder, this=this, kind=kind)

def _parse_in(
self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:
@@ -220,25 +242,25 @@ class Generator(generator.Generator):

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,  # type: ignore
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATETIME64: "DateTime64",
exp.DataType.Type.DOUBLE: "Float64",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.INT: "Int32",
exp.DataType.Type.INT128: "Int128",
exp.DataType.Type.INT256: "Int256",
exp.DataType.Type.MAP: "Map",
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.SMALLINT: "Int16",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
exp.DataType.Type.UBIGINT: "UInt64",
exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.UINT128: "UInt128",
exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.UTINYINT: "UInt8",
}

TRANSFORMS = {
@@ -285,3 +307,6 @@ def after_limit_modifiers(self, expression):
def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
params = self.expressions(expression, "params", flat=True)
return self.func(expression.name, *expression.expressions) + f"({params})"

def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
100 changes: 50 additions & 50 deletions sqlglot/expressions.py
@@ -3020,7 +3020,7 @@ class SessionParameter(Expression):


class Placeholder(Expression):
arg_types = {"this": False}
arg_types = {"this": False, "kind": False}


class Null(Condition):
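Since Placeholder now accepts an optional kind, a placeholder node can also be built programmatically. A small sketch (not from the diff), using a made-up parameter name and a plain string kind, as the ClickHouse parser itself produces for the Identifier case; when parsed from SQL, kind is normally a DataType node instead.

from sqlglot import exp

node = exp.Placeholder(this="limit", kind="UInt32")
print(node.sql(dialect="clickhouse"))  # {limit: UInt32}
print(node.sql())                      # :limit (the default dialect ignores the kind)
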
@@ -3049,69 +3049,69 @@ class DataType(Expression):
}

class Type(AutoName):
ARRAY = auto()
BIGDECIMAL = auto()
BIGINT = auto()
BIGSERIAL = auto()
BINARY = auto()
BIT = auto()
BOOLEAN = auto()
CHAR = auto()
DATE = auto()
DATETIME = auto()
DATETIME64 = auto()
DECIMAL = auto()
DOUBLE = auto()
FLOAT = auto()
GEOGRAPHY = auto()
GEOMETRY = auto()
HLLSKETCH = auto()
HSTORE = auto()
IMAGE = auto()
INET = auto()
INT = auto()
INT128 = auto()
INT256 = auto()
INTERVAL = auto()
JSON = auto()
JSONB = auto()
LONGBLOB = auto()
LONGTEXT = auto()
MAP = auto()
MEDIUMBLOB = auto()
MEDIUMTEXT = auto()
MONEY = auto()
NCHAR = auto()
NULL = auto()
NULLABLE = auto()
NVARCHAR = auto()
OBJECT = auto()
ROWVERSION = auto()
SERIAL = auto()
SMALLINT = auto()
SMALLMONEY = auto()
SMALLSERIAL = auto()
STRUCT = auto()
SUPER = auto()
TEXT = auto()
TIME = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
TINYINT = auto()
UBIGINT = auto()
UINT = auto()
USMALLINT = auto()
UTINYINT = auto()
UNKNOWN = auto()  # Sentinel value, useful for type annotation
UINT128 = auto()
UINT256 = auto()
UNIQUEIDENTIFIER = auto()
UUID = auto()
VARBINARY = auto()
VARCHAR = auto()
VARIANT = auto()
XML = auto()

TEXT_TYPES = {
Type.CHAR,
10 changes: 3 additions & 7 deletions sqlglot/parser.py
@@ -112,8 +112,8 @@ class Parser(metaclass=_Parser):
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.MAP,
TokenType.NULLABLE,
TokenType.STRUCT,
}

TYPE_TOKENS = {
@@ -2255,6 +2255,7 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
(not schema and self._parse_function())
or self._parse_id_var(any_token=False)
or self._parse_string_as_identifier()
or self._parse_placeholder()
)

def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
@@ -2284,22 +2285,18 @@ def _parse_table(
self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()

if lateral:
return lateral

unnest = self._parse_unnest()

if unnest:
return unnest

values = self._parse_derived_table_values()

if values:
return values

subquery = self._parse_select(table=True)

if subquery:
if not subquery.args.get("pivots"):
subquery.set("pivots", self._parse_pivots())
@@ -2314,7 +2311,6 @@
table_sample = self._parse_table_sample()

alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)

if alias:
this.set("alias", alias)

@@ -2835,7 +2831,7 @@ def _parse_type(self) -> t.Optional[exp.Expression]:
if parser:
return parser(self, this, data_type)
return self.expression(exp.Cast, this=this, to=data_type)
if not data_type.expressions:
self._retreat(index)
return self._parse_column()
return data_type
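
Because _parse_table_part now also tries _parse_placeholder, ClickHouse table parameters parse as well. A rough sketch (not part of the diff), mirroring the new test below:

import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT * FROM {table: Identifier}", read="clickhouse")
print(ast.find(exp.Placeholder).args.get("kind"))  # Identifier
print(ast.sql(dialect="clickhouse"))               # SELECT * FROM {table: Identifier}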
13 changes: 13 additions & 0 deletions tests/dialects/test_clickhouse.py
@@ -139,6 +139,19 @@ def test_ternary(self):
self.assertIsInstance(nested_ternary.args["true"], exp.Literal)
self.assertIsInstance(nested_ternary.args["false"], exp.Literal)

def test_parameterization(self):
self.validate_all(
"SELECT {abc: UInt32}, {b: String}, {c: DateTime},{d: Map(String, Array(UInt8))}, {e: Tuple(UInt8, String)}",
write={
"clickhouse": "SELECT {abc: UInt32}, {b: TEXT}, {c: DATETIME}, {d: Map(TEXT, Array(UInt8))}, {e: Tuple(UInt8, String)}",
"": "SELECT :abc, :b, :c, :d, :e",
},
)
self.validate_all(
"SELECT * FROM {table: Identifier}",
write={"clickhouse": "SELECT * FROM {table: Identifier}"},
)

def test_signed_and_unsigned_types(self):
data_types = [
"UInt8",
