Skip to content

Commit

Permalink
fix: Refactor exp.RegexpExtract (follow up 4326) (#4341)
Browse files Browse the repository at this point in the history
  • Loading branch information
VaggelisD authored Nov 4, 2024
1 parent c09b6a2 commit def4f1e
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 44 deletions.
17 changes: 7 additions & 10 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,13 @@ class Generator(generator.Generator):
exp.MD5Digest: rename_func("MD5"),
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.RegexpExtract: lambda self, e: self.func(
"REGEXP_EXTRACT",
e.this,
e.expression,
e.args.get("position"),
e.args.get("occurrence"),
),
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.ReturnsProperty: _returnsproperty_sql,
Expand Down Expand Up @@ -1042,13 +1049,3 @@ def version_sql(self, expression: exp.Version) -> str:
if expression.name == "TIMESTAMP":
expression.set("this", "SYSTEM_TIME")
return super().version_sql(expression)

@generator.unsupported_args("group", "parameters")
def regexpextract_sql(self, e: exp.RegexpExtract) -> str:
return self.func(
"REGEXP_EXTRACT",
e.this,
e.expression,
e.args.get("position"),
e.args.get("occurrence"),
)
21 changes: 9 additions & 12 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
build_regexp_extract,
explode_to_unnest_sql,
)
from sqlglot.generator import unsupported_args
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
from sqlglot.parser import binary_range_parser
Expand Down Expand Up @@ -103,7 +104,7 @@ def _timediff_sql(self: DuckDB.Generator, expression: exp.TimeDiff) -> str:
return self.func("DATE_DIFF", unit_to_str(expression), expr, this)


@generator.unsupported_args(("expression", "DuckDB's ARRAY_SORT does not support a comparator."))
@unsupported_args(("expression", "DuckDB's ARRAY_SORT does not support a comparator."))
def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str:
return self.func("ARRAY_SORT", expression.this)

Expand Down Expand Up @@ -953,22 +954,18 @@ def arraytostring_sql(self, expression: exp.ArrayToString) -> str:

return self.func("ARRAY_TO_STRING", this, expression.expression)

@generator.unsupported_args("position", "occurrence")
@unsupported_args("position", "occurrence")
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
group = expression.args.get("group")
params = expression.args.get("parameters")

if params and params.name == "":
params = None

# Do not render group if it's the default value for this dialect
if (
not params
and group
and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP)
):
if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP):
group = None

return self.func(
"REGEXP_EXTRACT", expression.this, expression.expression, group, params
"REGEXP_EXTRACT",
expression.this,
expression.expression,
group,
expression.args.get("parameters"),
)
3 changes: 2 additions & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
no_timestamp_sql,
timestampdiff_sql,
)
from sqlglot.generator import unsupported_args
from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType

Expand Down Expand Up @@ -1067,7 +1068,7 @@ def struct_sql(self, expression: exp.Struct) -> str:

return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values)))

@generator.unsupported_args("weight", "accuracy")
@unsupported_args("weight", "accuracy")
def approxquantile_sql(self, expression: exp.ApproxQuantile) -> str:
return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile"))

Expand Down
21 changes: 1 addition & 20 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,14 +761,7 @@ def test_duckdb(self):
write="duckdb",
unsupported_level=ErrorLevel.IMMEDIATE,
)
with self.assertRaises(UnsupportedError):
# duckdb has the group arg, but bq doesn't
transpile(
"SELECT REGEXP_EXTRACT(a, 'pattern', 2) from table",
read="duckdb",
write="bigquery",
unsupported_level=ErrorLevel.IMMEDIATE,
)

self.validate_all(
"SELECT REGEXP_EXTRACT(a, 'pattern') FROM t",
read={
Expand All @@ -782,21 +775,9 @@ def test_duckdb(self):
"snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern') FROM t",
},
)
self.validate_all(
"SELECT REGEXP_EXTRACT(a, 'pattern', 2) FROM t",
read={
"duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern', 2) FROM t",
"snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern', 1, 1, '', 2) FROM t",
},
write={
"duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern', 2) FROM t",
"snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern', 1, 1, 'c', 2) FROM t",
},
)
self.validate_all(
"SELECT REGEXP_EXTRACT(a, 'pattern', 2, 'i') FROM t",
read={
"duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern', 2, 'i') FROM t",
"snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern', 1, 1, 'i', 2) FROM t",
},
write={
Expand Down
1 change: 0 additions & 1 deletion tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,7 +1763,6 @@ def test_regexp_substr(self, logger):
"REGEXP_SUBSTR(subject, pattern)",
read={
"bigquery": "REGEXP_EXTRACT(subject, pattern)",
"snowflake": "REGEXP_EXTRACT(subject, pattern)",
},
write={
"bigquery": "REGEXP_EXTRACT(subject, pattern)",
Expand Down

0 comments on commit def4f1e

Please sign in to comment.