diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 978d446b6b..39f87547e2 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -474,12 +474,12 @@ def _parse_types( dtype = super()._parse_types( check_func=check_func, schema=schema, allow_identifiers=allow_identifiers ) - if isinstance(dtype, exp.DataType): - # Mark every type as non-nullable which is ClickHouse's default. This marker - # helps us transpile types from other dialects to ClickHouse, so that we can - # e.g. produce `CAST(x AS Nullable(String))` from `CAST(x AS TEXT)`. If there - # is a `NULL` value in `x`, the former would fail in ClickHouse without the - # `Nullable` type constructor + if isinstance(dtype, exp.DataType) and dtype.args.get("nullable") is not True: + # Mark every type as non-nullable which is ClickHouse's default, unless it's + # already marked as nullable. This marker helps us transpile types from other + # dialects to ClickHouse, so that we can e.g. produce `CAST(x AS Nullable(String))` + # from `CAST(x AS TEXT)`. If there is a `NULL` value in `x`, the former would + # fail in ClickHouse without the `Nullable` type constructor. dtype.set("nullable", False) return dtype @@ -815,7 +815,6 @@ class Generator(generator.Generator): exp.DataType.Type.LOWCARDINALITY: "LowCardinality", exp.DataType.Type.MAP: "Map", exp.DataType.Type.NESTED: "Nested", - exp.DataType.Type.NULLABLE: "Nullable", exp.DataType.Type.SMALLINT: "Int16", exp.DataType.Type.STRUCT: "Tuple", exp.DataType.Type.TINYINT: "Int8", @@ -921,7 +920,6 @@ class Generator(generator.Generator): NON_NULLABLE_TYPES = { exp.DataType.Type.ARRAY, exp.DataType.Type.MAP, - exp.DataType.Type.NULLABLE, exp.DataType.Type.STRUCT, } @@ -1004,8 +1002,9 @@ def datatype_sql(self, expression: exp.DataType) -> str: # String or FixedString (possibly LowCardinality) or UUID or IPv6" # - It's not a composite type, e.g. `Nullable(Array(...))` is not a valid type parent = expression.parent - if ( - expression.args.get("nullable") is not False + nullable = expression.args.get("nullable") + if nullable is True or ( + nullable is None and not ( isinstance(parent, exp.DataType) and parent.is_type(exp.DataType.Type.MAP, check_nullable=True) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 894fc0820c..78a935939a 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -187,9 +187,6 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[ if enum not in ("", "bigquery"): klass.generator_class.SELECT_KINDS = () - if enum not in ("", "clickhouse"): - klass.generator_class.SUPPORTS_NULLABLE_TYPES = False - if enum not in ("", "athena", "presto", "trino"): klass.generator_class.TRY_SUPPORTED = False klass.generator_class.SUPPORTS_UESCAPE = False diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1119c45fea..a14c32d376 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4118,7 +4118,6 @@ class Type(AutoName): NCHAR = auto() NESTED = auto() NULL = auto() - NULLABLE = auto() NUMMULTIRANGE = auto() NUMRANGE = auto() NVARCHAR = auto() @@ -4312,32 +4311,19 @@ def is_type(self, *dtypes: DATA_TYPE, check_nullable: bool = False) -> bool: Returns: True, if and only if there is a type in `dtypes` which is equal to this DataType. """ - if ( - not check_nullable - and self.this == DataType.Type.NULLABLE - and len(self.expressions) == 1 - ): - this_type = self.expressions[0] - else: - this_type = self - + self_is_nullable = self.args.get("nullable") for dtype in dtypes: other_type = DataType.build(dtype, copy=False, udt=True) - if ( - not check_nullable - and other_type.this == DataType.Type.NULLABLE - and len(other_type.expressions) == 1 - ): - other_type = other_type.expressions[0] - + other_is_nullable = other_type.args.get("nullable") if ( other_type.expressions - or this_type.this == DataType.Type.USERDEFINED + or (check_nullable and (self_is_nullable or other_is_nullable)) + or self.this == DataType.Type.USERDEFINED or other_type.this == DataType.Type.USERDEFINED ): - matches = this_type == other_type + matches = self == other_type else: - matches = this_type.this == other_type.this + matches = self.this == other_type.this if matches: return True diff --git a/sqlglot/generator.py b/sqlglot/generator.py index f0470d2880..0e8888a6f0 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -390,9 +390,6 @@ class Generator(metaclass=_Generator): # Whether CONVERT_TIMEZONE() is supported; if not, it will be generated as exp.AtTimeZone SUPPORTS_CONVERT_TIMEZONE = False - # Whether nullable types can be constructed, e.g. `Nullable(Int64)` - SUPPORTS_NULLABLE_TYPES = True - # The name to generate for the JSONPath expression. If `None`, only `this` will be generated PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON" @@ -1239,14 +1236,12 @@ def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"): type_sql = self.sql(expression, "kind") - elif type_value != exp.DataType.Type.NULLABLE or self.SUPPORTS_NULLABLE_TYPES: + else: type_sql = ( self.TYPE_MAPPING.get(type_value, type_value.value) if isinstance(type_value, exp.DataType.Type) else type_value ) - else: - return interior if interior: if expression.args.get("nested"): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 8fb3c14d51..2ae75e1780 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -4753,6 +4753,11 @@ def _parse_types( check_func=check_func, schema=schema, allow_identifiers=allow_identifiers ) ) + if type_token == TokenType.NULLABLE and len(expressions) == 1: + this = expressions[0] + this.set("nullable", True) + self._match_r_paren() + return this elif type_token in self.ENUM_TYPE_TOKENS: expressions = self._parse_csv(self._parse_equality) elif is_aggregate: diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 9bb00de50e..8e2ac45bd7 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1035,7 +1035,6 @@ def test_data_type_builder(self): self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY") self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY") self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT") - self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE") self.assertEqual(exp.DataType.build("HLLSKETCH", dialect="redshift").sql(), "HLLSKETCH") self.assertEqual(exp.DataType.build("HSTORE", dialect="postgres").sql(), "HSTORE") self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")