Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(ir): make impure ibis.random() and ibis.uuid() functions return unique node instances #8967

Merged
merged 12 commits into from
Apr 15, 2024
5 changes: 3 additions & 2 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ class BigQueryCompiler(SQLGlotCompiler):
ops.RPad: "rpad",
ops.Levenshtein: "edit_distance",
ops.Modulus: "mod",
ops.RandomScalar: "rand",
ops.RandomUUID: "generate_uuid",
ops.RegexReplace: "regexp_replace",
ops.RegexSearch: "regexp_contains",
ops.Time: "time",
Expand Down Expand Up @@ -698,3 +696,6 @@ def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_RandomUUID(self, op, **kwargs):
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
return self.f.generate_uuid()
8 changes: 6 additions & 2 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.NotNull: "isNotNull",
ops.NullIf: "nullIf",
ops.RStrip: "trimRight",
ops.RandomScalar: "randCanonical",
ops.RandomUUID: "generateUUIDv4",
ops.RegexReplace: "replaceRegexpAll",
ops.RowNumber: "row_number",
ops.StartsWith: "startsWith",
Expand Down Expand Up @@ -637,6 +635,12 @@ def visit_TimestampRange(self, op, *, start, stop, step):
def visit_RegexSplit(self, op, *, arg, pattern):
# Emits splitByRegexp(pattern, arg): the pattern is passed first, and arg is
# cast to a non-nullable String before the call.
# NOTE(review): presumably splitByRegexp does not accept a Nullable input — confirm.
return self.f.splitByRegexp(pattern, self.cast(arg, dt.String(nullable=False)))

def visit_RandomScalar(self, op, **kwargs):
# Compiles RandomScalar to a zero-argument randCanonical() call.
# Any keyword arguments handed to the visitor are accepted and ignored.
# NOTE(review): presumably the op now carries a field that makes each node
# unique, which would arrive here as a keyword — confirm against the op definition.
return self.f.randCanonical()

def visit_RandomUUID(self, op, **kwargs):
# Compiles RandomUUID to a zero-argument generateUUIDv4() call.
# Any keyword arguments handed to the visitor are accepted and ignored.
return self.f.generateUUIDv4()

@staticmethod
def _generate_groups(groups):
return groups
1 change: 0 additions & 1 deletion ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class DataFusionCompiler(SQLGlotCompiler):
ops.Last: "last_value",
ops.Median: "median",
ops.StringLength: "character_length",
ops.RandomUUID: "uuid",
ops.RegexSplit: "regex_split",
ops.EndsWith: "ends_with",
ops.ArrayIntersect: "array_intersect",
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/druid/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class DruidCompiler(SQLGlotCompiler):
ops.Median,
ops.MultiQuantile,
ops.Quantile,
ops.RandomUUID,
ops.RegexReplace,
ops.RegexSplit,
ops.RowID,
Expand Down
7 changes: 6 additions & 1 deletion ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class DuckDBCompiler(SQLGlotCompiler):
ops.MapMerge: "map_concat",
ops.MapValues: "map_values",
ops.Mode: "mode",
ops.RandomUUID: "uuid",
ops.TimeFromHMS: "make_time",
ops.TypeOf: "typeof",
ops.GeoPoint: "st_point",
Expand Down Expand Up @@ -418,3 +417,9 @@ def visit_StructField(self, op, *, arg, field):
expression=sg.to_identifier(field, quoted=self.quoted),
)
return super().visit_StructField(op, arg=arg, field=field)

def visit_RandomScalar(self, op, **kwargs):
# Compiles RandomScalar to a zero-argument random() call; keyword arguments
# passed to the visitor are accepted and ignored.
return self.f.random()

def visit_RandomUUID(self, op, **kwargs):
# Compiles RandomUUID to a zero-argument uuid() call; keyword arguments
# passed to the visitor are accepted and ignored.
return self.f.uuid()
1 change: 1 addition & 0 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class ExasolCompiler(SQLGlotCompiler):
ops.Median,
ops.MultiQuantile,
ops.Quantile,
ops.RandomUUID,
ops.ReductionVectorizedUDF,
ops.RegexExtract,
ops.RegexReplace,
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/flink/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ class FlinkCompiler(SQLGlotCompiler):
ops.MapKeys: "map_keys",
ops.MapValues: "map_values",
ops.Power: "power",
ops.RandomScalar: "rand",
ops.RandomUUID: "uuid",
ops.RegexSearch: "regexp",
ops.StrRight: "right",
ops.StringLength: "char_length",
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/impala/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class ImpalaCompiler(SQLGlotCompiler):
ops.Hash: "fnv_hash",
ops.LStrip: "ltrim",
ops.Ln: "ln",
ops.RandomUUID: "uuid",
ops.RStrip: "rtrim",
ops.Strip: "trim",
ops.TypeOf: "typeof",
Expand Down Expand Up @@ -146,7 +145,7 @@ def visit_CountDistinct(self, op, *, arg, where):
def visit_Xor(self, op, *, left, right):
# Builds exclusive-or out of AND/OR/NOT:
# (left OR right) AND NOT (left AND right).
return sg.and_(sg.or_(left, right), sg.not_(sg.and_(left, right)))

def visit_RandomScalar(self, op):
def visit_RandomScalar(self, op, **_):
return self.f.rand(self.f.utc_to_unix_micros(self.f.utc_timestamp()))

def visit_DayOfWeekIndex(self, op, *, arg):
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ class MSSQLCompiler(SQLGlotCompiler):
ops.Ln: "log",
ops.Log10: "log10",
ops.Power: "power",
ops.RandomScalar: "rand",
ops.RandomUUID: "newid",
ops.Repeat: "replicate",
ops.Reverse: "reverse",
ops.StringAscii: "ascii",
Expand Down Expand Up @@ -172,6 +170,9 @@ def _minimize_spec(start, end, spec):
return None
return spec

def visit_RandomUUID(self, op, **kwargs):
# Compiles RandomUUID to a zero-argument newid() call; keyword arguments
# passed to the visitor are accepted and ignored.
return self.f.newid()

def visit_StringLength(self, op, *, arg):
"""The MSSQL LEN function doesn't count trailing spaces.

Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/oracle/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class OracleCompiler(SQLGlotCompiler):
ops.ExtractWeekOfYear,
ops.ExtractDayOfYear,
ops.RowID,
ops.RandomUUID,
)
)

Expand Down Expand Up @@ -221,7 +222,7 @@ def visit_Log(self, op, *, arg, base):
def visit_IsInf(self, op, *, arg):
# Infinity check expressed as membership: arg is tested against the
# compiler's POS_INF and NEG_INF values with isin().
return arg.isin(self.POS_INF, self.NEG_INF)

def visit_RandomScalar(self, op):
def visit_RandomScalar(self, op, **_):
# Not using FuncGen here because of dotted function call
return sg.func("dbms_random.value")

Expand Down
8 changes: 0 additions & 8 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,6 @@ def visit(cls, op: ops.Cast, arg, to):
else:
return PandasConverter.convert_scalar(arg, to)

@classmethod
def visit(cls, op: ops.TypeOf, arg):
raise OperationNotDefinedError("TypeOf is not implemented")

@classmethod
def visit(cls, op: ops.RandomScalar):
raise OperationNotDefinedError("RandomScalar is not implemented")

@classmethod
def visit(cls, op: ops.Greatest, arg):
return cls.columnwise(lambda df: df.max(axis=1), arg)
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class PostgresCompiler(SQLGlotCompiler):
ops.MapContains: "exist",
ops.MapKeys: "akeys",
ops.MapValues: "avals",
ops.RandomUUID: "gen_random_uuid",
ops.RegexSearch: "regexp_like",
ops.TimeFromHMS: "make_time",
}
Expand All @@ -111,6 +110,9 @@ def _aggregate(self, funcname: str, *args, where):
return sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

def visit_RandomUUID(self, op, **kwargs):
# Compiles RandomUUID to a zero-argument gen_random_uuid() call; keyword
# arguments passed to the visitor are accepted and ignored.
return self.f.gen_random_uuid()

def visit_Mode(self, op, *, arg, where):
expr = self.f.mode()
expr = sge.WithinGroup(
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class PySparkCompiler(SQLGlotCompiler):
(
ops.RowID,
ops.TimestampBucket,
ops.RandomUUID,
)
)

Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
ops.Hash: "hash",
ops.Median: "median",
ops.Mode: "mode",
ops.RandomUUID: "uuid_string",
ops.StringToTimestamp: "to_timestamp_tz",
ops.TimeFromHMS: "time_from_parts",
ops.TimestampFromYMDHMS: "timestamp_from_parts",
Expand Down Expand Up @@ -241,11 +240,14 @@
def visit_Log(self, op, *, arg, base):
# Emits log(base, arg): the base is passed first, then the value, and the
# compiler's dialect is forwarded explicitly to the function builder.
# NOTE(review): presumably the dialect is needed so the argument order is
# rendered as written — confirm against the function builder.
return self.f.log(base, arg, dialect=self.dialect)

def visit_RandomScalar(self, op):
def visit_RandomScalar(self, op, **kwargs):
return self.f.uniform(
self.f.to_double(0.0), self.f.to_double(1.0), self.f.random()
)

def visit_RandomUUID(self, op, **kwargs):
# Compiles RandomUUID to a zero-argument uuid_string() call; keyword
# arguments passed to the visitor are accepted and ignored.
return self.f.uuid_string()

Check warning on line 250 in ibis/backends/snowflake/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/snowflake/compiler.py#L250

Added line #L250 was not covered by tests
def visit_ApproxMedian(self, op, *, arg, where):
return self.agg.approx_percentile(arg, 0.5, where=where)

Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ class SQLGlotCompiler(abc.ABC):
ops.Power: "pow",
ops.RPad: "rpad",
ops.Radians: "radians",
ops.RandomScalar: "random",
ops.RegexSearch: "regexp_like",
ops.RegexSplit: "regexp_split",
ops.Repeat: "repeat",
Expand Down Expand Up @@ -687,6 +686,14 @@ def visit_Round(self, op, *, arg, digits):
return sge.Round(this=arg, decimals=digits)
return sge.Round(this=arg)

### Random Noise

def visit_RandomScalar(self, op, **kwargs):
# Default translation of RandomScalar: a zero-argument rand() call built
# through self.f. Keyword arguments passed to the visitor are accepted and ignored.
# NOTE(review): the concrete SQL spelling presumably depends on the dialect — confirm.
return self.f.rand()

def visit_RandomUUID(self, op, **kwargs):
# Default translation of RandomUUID: a zero-argument uuid() call built
# through self.f. Keyword arguments passed to the visitor are accepted and ignored.
return self.f.uuid()

### Dtype Dysmorphia

def visit_TryCast(self, op, *, arg, to):
Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,13 @@ def merge_select_select(_, **kwargs):
from the inner Select are inlined into the outer Select.
"""
# don't merge if either the outer or the inner select has window functions
blocking = (ops.WindowFunction, ops.ExistsSubquery, ops.InSubquery, ops.Unnest)
blocking = (
ops.WindowFunction,
ops.ExistsSubquery,
ops.InSubquery,
ops.Unnest,
ops.Impure,
)
if _.find_below(blocking, filter=ops.Value):
return _
if _.parent.find_below(blocking, filter=ops.Value):
Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.Mode: "_ibis_mode",
ops.Time: "time",
ops.Date: "date",
ops.RandomUUID: "uuid",
}

def _aggregate(self, funcname: str, *args, where):
Expand Down Expand Up @@ -213,7 +212,7 @@ def visit_Clip(self, op, *, arg, lower, upper):

return arg

def visit_RandomScalar(self, op):
def visit_RandomScalar(self, op, **kwargs):
return 0.5 + self.f.random() / sge.Literal.number(float(-1 << 64))

def visit_Cot(self, op, *, arg):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
`t1`.`x`,
`t1`.`y`,
`t1`.`z`,
IF(`t1`.`y` = `t1`.`z`, 'big', 'small') AS `size`
FROM (
SELECT
`t0`.`x`,
RAND() AS `y`,
RAND() AS `z`
FROM `t` AS `t0`
) AS `t1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
`t1`.`x`,
`t1`.`y`,
`t1`.`z`,
IF(`t1`.`y` = `t1`.`z`, 'big', 'small') AS `size`
FROM (
SELECT
`t0`.`x`,
generate_uuid() AS `y`,
generate_uuid() AS `z`
FROM `t` AS `t0`
) AS `t1`
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
randCanonical() AS "y",
randCanonical() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
generateUUIDv4() AS "y",
generateUUIDv4() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
RANDOM() AS "y",
RANDOM() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
UUID() AS "y",
UUID() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
RANDOM() AS "y",
RANDOM() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
RANDOM() AS "y",
RANDOM() AS "z"
FROM "t" AS "t0"
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."x",
"t1"."y",
"t1"."z",
CASE WHEN "t1"."y" = "t1"."z" THEN 'big' ELSE 'small' END AS "size"
FROM (
SELECT
"t0"."x",
UUID() AS "y",
UUID() AS "z"
FROM "t" AS "t0"
) AS "t1"
Loading
Loading