diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 4fc93bfb40..5376dffdb5 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -620,7 +620,16 @@ def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat return self.sql(this) -# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator +def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: + bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) + if bad_args: + self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") + + return self.func( + "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") + ) + + def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: names = [] for agg in aggregations: diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index d7e5a436fc..1d8a7fbb7b 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -15,6 +15,7 @@ no_properties_sql, no_safe_divide_sql, pivot_column_names, + regexp_extract_sql, rename_func, str_position_sql, str_to_time_sql, @@ -88,19 +89,6 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) -def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str: - bad_args = list(filter(expression.args.get, ("position", "occurrence"))) - if bad_args: - self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}") - - return self.func( - "REGEXP_EXTRACT", - expression.args.get("this"), - expression.args.get("expression"), - expression.args.get("group"), - ) - - def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: sql = self.func("TO_JSON", expression.this, expression.args.get("options")) return f"CAST({sql} AS TEXT)" @@ -156,6 +144,9 @@ class Parser(parser.Parser): "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), "STRING_SPLIT": exp.Split.from_arg_list, @@ -227,7 +218,7 @@ class Generator(generator.Generator): exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.Properties: no_properties_sql, - exp.RegexpExtract: _regexp_extract_sql, + exp.RegexpExtract: regexp_extract_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index d119eeb980..f968f6aceb 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -17,6 +17,7 @@ no_recursive_cte_sql, no_safe_divide_sql, no_trycast_sql, + regexp_extract_sql, rename_func, right_to_substring_sql, strposition_to_locate_sql, @@ -230,23 +231,24 @@ class Parser(parser.Parser): **parser.Parser.FUNCTIONS, "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, + "COLLECT_SET": exp.SetAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), - "DATEDIFF": lambda args: exp.DateDiff( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - expression=exp.TsOrDsToDate(this=seq_get(args, 1)), + "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( + [ + exp.TimeStrToTime(this=seq_get(args, 0)), + seq_get(args, 1), + ] ), "DATE_SUB": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)), unit=exp.Literal.string("DAY"), ), - "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( - [ - exp.TimeStrToTime(this=seq_get(args, 0)), - seq_get(args, 1), - ] + "DATEDIFF": lambda args: exp.DateDiff( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=exp.TsOrDsToDate(this=seq_get(args, 1)), ), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), @@ -256,7 +258,9 @@ class Parser(parser.Parser): "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, - "COLLECT_SET": exp.SetAgg.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "SIZE": exp.ArraySize.from_arg_list, "SPLIT": exp.RegexpSplit.from_arg_list, "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), @@ -363,6 +367,7 @@ class Generator(generator.Generator): exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), + exp.RegexpExtract: regexp_extract_sql, exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), exp.RegexpSplit: rename_func("SPLIT"), exp.Right: right_to_substring_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 172158819c..7d35c67143 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -12,6 +12,7 @@ no_ilike_sql, no_pivot_sql, no_safe_divide_sql, + regexp_extract_sql, rename_func, right_to_substring_sql, struct_extract_sql, @@ -215,6 +216,9 @@ class Parser(parser.Parser): this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") ), "NOW": exp.CurrentTimestamp.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "SEQUENCE": exp.GenerateSeries.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) @@ -293,6 +297,7 @@ class Generator(generator.Generator): exp.LogicalOr: rename_func("BOOL_OR"), exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, + exp.RegexpExtract: regexp_extract_sql, exp.Right: right_to_substring_sql, exp.SafeBracket: lambda self, e: self.func( "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 19924cd743..715a84cfb0 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -223,13 +223,14 @@ class Parser(parser.Parser): "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, + "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, "TO_ARRAY": exp.Array.from_arg_list, - "TO_VARCHAR": exp.ToChar.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, + "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, } @@ -361,12 +362,12 @@ class Generator(generator.Generator): "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), ), + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, - exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.TimeToStr: lambda self, e: self.func( "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e) ), - exp.TimestampTrunc: timestamptrunc_sql, + exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), @@ -390,6 +391,24 @@ class Generator(generator.Generator): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: + # Other dialects don't support all of the following parameters, so we need to + # generate default values as necessary to ensure the transpilation is correct + group = expression.args.get("group") + parameters = expression.args.get("parameters") or (group and exp.Literal.string("c")) + occurrence = expression.args.get("occurrence") or (parameters and exp.Literal.number(1)) + position = expression.args.get("position") or (occurrence and exp.Literal.number(1)) + + return self.func( + "REGEXP_SUBSTR", + expression.this, + expression.expression, + position, + occurrence, + parameters, + group, + ) + def except_op(self, expression: exp.Except) -> str: if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 34394ce70a..1efedc7b78 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4420,6 +4420,7 @@ class RegexpExtract(Func): "expression": True, "position": False, "occurrence": False, + "parameters": False, "group": False, } diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index f7bab4d23e..e20045be55 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -1,3 +1,5 @@ +from unittest import mock + from sqlglot import UnsupportedError, exp, parse_one from tests.dialects.test_dialect import Validator @@ -309,6 +311,7 @@ def test_snowflake(self): "SELECT IFF(TRUE, 'true', 'false')", write={ "snowflake": "SELECT IFF(TRUE, 'true', 'false')", + "spark": "SELECT IF(TRUE, 'true', 'false')", }, ) self.validate_all( @@ -870,6 +873,46 @@ def test_parse_like_any(self): self.assertIsInstance(ilike, exp.ILikeAny) like.sql() # check that this doesn't raise + @mock.patch("sqlglot.generator.logger") + def test_regexp_substr(self, logger): + self.validate_all( + "REGEXP_SUBSTR(subject, pattern, pos, occ, params, group)", + write={ + "bigquery": "REGEXP_EXTRACT(subject, pattern, pos, occ)", + "hive": "REGEXP_EXTRACT(subject, pattern, group)", + "presto": "REGEXP_EXTRACT(subject, pattern, group)", + "snowflake": "REGEXP_SUBSTR(subject, pattern, pos, occ, params, group)", + "spark": "REGEXP_EXTRACT(subject, pattern, group)", + }, + ) + self.validate_all( + "REGEXP_SUBSTR(subject, pattern)", + read={ + "bigquery": "REGEXP_EXTRACT(subject, pattern)", + "hive": "REGEXP_EXTRACT(subject, pattern)", + "presto": "REGEXP_EXTRACT(subject, pattern)", + "spark": "REGEXP_EXTRACT(subject, pattern)", + }, + write={ + "bigquery": "REGEXP_EXTRACT(subject, pattern)", + "hive": "REGEXP_EXTRACT(subject, pattern)", + "presto": "REGEXP_EXTRACT(subject, pattern)", + "snowflake": "REGEXP_SUBSTR(subject, pattern)", + "spark": "REGEXP_EXTRACT(subject, pattern)", + }, + ) + self.validate_all( + "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', group)", + read={ + "bigquery": "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', group)", + "duckdb": "REGEXP_EXTRACT(subject, pattern, group)", + "hive": "REGEXP_EXTRACT(subject, pattern, group)", + "presto": "REGEXP_EXTRACT(subject, pattern, group)", + "snowflake": "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', group)", + "spark": "REGEXP_EXTRACT(subject, pattern, group)", + }, + ) + def test_match_recognize(self): for row in ( "ONE ROW PER MATCH",