diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs
index d30bbc5e29a63..fc063e8fbaf93 100644
--- a/crates/polars-plan/src/dsl/python_udf.rs
+++ b/crates/polars-plan/src/dsl/python_udf.rs
@@ -104,14 +104,21 @@ pub struct PythonUdfExpression {
     python_function: PyObject,
     output_type: Option<DataType>,
     is_elementwise: bool,
+    returns_scalar: bool,
 }
 
 impl PythonUdfExpression {
-    pub fn new(lambda: PyObject, output_type: Option<DataType>, is_elementwise: bool) -> Self {
+    pub fn new(
+        lambda: PyObject,
+        output_type: Option<DataType>,
+        is_elementwise: bool,
+        returns_scalar: bool,
+    ) -> Self {
         Self {
             python_function: lambda,
             output_type,
             is_elementwise,
+            returns_scalar,
         }
     }
 
@@ -121,7 +128,7 @@ impl PythonUdfExpression {
         // skip header
         let buf = &buf[MAGIC_BYTE_MARK.len()..];
         let mut reader = Cursor::new(buf);
-        let (output_type, is_elementwise): (Option<DataType>, bool) =
+        let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
             ciborium::de::from_reader(&mut reader).map_err(map_err)?;
         let remainder = &buf[reader.position() as usize..];
 
@@ -138,6 +145,7 @@ impl PythonUdfExpression {
                 python_function.into(),
                 output_type,
                 is_elementwise,
+                returns_scalar,
             )) as Arc<dyn SeriesUdf>)
         })
     }
@@ -181,8 +189,15 @@ impl SeriesUdf for PythonUdfExpression {
     #[cfg(feature = "serde")]
     fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
         buf.extend_from_slice(MAGIC_BYTE_MARK);
-        ciborium::ser::into_writer(&(self.output_type.clone(), self.is_elementwise), &mut *buf)
-            .unwrap();
+        ciborium::ser::into_writer(
+            &(
+                self.output_type.clone(),
+                self.is_elementwise,
+                self.returns_scalar,
+            ),
+            &mut *buf,
+        )
+        .unwrap();
 
         Python::with_gil(|py| {
             let pickle = PyModule::import_bound(py, "cloudpickle")
@@ -222,6 +237,7 @@ impl Expr {
             (ApplyOptions::GroupWise, "python_udf")
         };
 
+        let returns_scalar = func.returns_scalar;
         let return_dtype = func.output_type.clone();
         let output_type = GetOutput::map_field(move |fld| match return_dtype {
             Some(ref dt) => Field::new(fld.name(), dt.clone()),
@@ -239,6 +255,7 @@
             options: FunctionOptions {
                 collect_groups,
                 fmt_str: name,
+                returns_scalar,
                 ..Default::default()
             },
         }
diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py
index 86f8c92c5a7da..d0058821fd7ac 100644
--- a/py-polars/polars/expr/expr.py
+++ b/py-polars/polars/expr/expr.py
@@ -4524,21 +4524,22 @@ def map_batches(
         *,
         agg_list: bool = False,
         is_elementwise: bool = False,
+        returns_scalar: bool = False,
     ) -> Self:
         """
         Apply a custom python function to a whole Series or sequence of Series.
 
-        The output of this custom function must be a Series (or a NumPy array, in which
-        case it will be automatically converted into a Series). If you want to apply a
+        The output of this custom function is presumed to be either a Series,
+        a NumPy array (in which case it will be automatically converted into
+        a Series), or a scalar that will be converted into a Series. If the
+        result is a scalar and you want it to stay as a scalar, pass in
+        ``returns_scalar=True``. If you want to apply a
         custom function elementwise over single values, see :func:`map_elements`.
         A reasonable use case for `map` functions is transforming the values
         represented by an expression using a third-party library.
 
-        .. warning::
-            If you are looking to map a function over a window function or group_by
-            context, refer to :func:`map_elements` instead.
-            Read more in `the book
-            <https://docs.pola.rs/user-guide/expressions/user-defined-functions>`_.
+        If your function returns a scalar, for example a float, use
+        :func:`map_to_scalar` instead.
 
         Parameters
         ----------
@@ -4556,6 +4557,11 @@ def map_batches(
             function. This parameter only works in a group-by context.
             The function will be invoked only once on a list of groups, rather than
             once per group.
+        returns_scalar
+            If the function returns a scalar, by default it will be wrapped in
+            a list in the output, since the assumption is that the function
+            always returns something Series-like. If you want to keep the
+            result as a scalar, set this argument to True.
 
         Warnings
         --------
@@ -4597,7 +4603,7 @@ def map_batches(
         ...     }
         ... )
         >>> df.group_by("a").agg(
-        ...     pl.col("b").map_batches(lambda x: x.max(), agg_list=False)
+        ...     pl.col("b").map_batches(lambda x: x + 2, agg_list=False)
         ... )  # doctest: +IGNORE_RESULT
         shape: (2, 2)
         ┌─────┬───────────┐
@@ -4605,15 +4611,39 @@ def map_batches(
         │ a   ┆ b         │
         │ --- ┆ ---       │
         │ i64 ┆ list[i64] │
         ╞═════╪═══════════╡
-        │ 1   ┆ [4]       │
-        │ 0   ┆ [3]       │
+        │ 1   ┆ [4, 6]    │
+        │ 0   ┆ [3, 5]    │
         └─────┴───────────┘
 
         Using `agg_list=True` would be more efficient. In this example, the input of
         the function is a Series of type `List(Int64)`.
 
         >>> df.group_by("a").agg(
-        ...     pl.col("b").map_batches(lambda x: x.list.max(), agg_list=True)
+        ...     pl.col("b").map_batches(
+        ...         lambda x: x.list.eval(pl.element() + 2), agg_list=True
+        ...     )
+        ... )  # doctest: +IGNORE_RESULT
+        shape: (2, 2)
+        ┌─────┬───────────┐
+        │ a   ┆ b         │
+        │ --- ┆ ---       │
+        │ i64 ┆ list[i64] │
+        ╞═════╪═══════════╡
+        │ 0   ┆ [3, 5]    │
+        │ 1   ┆ [4, 6]    │
+        └─────┴───────────┘
+
+        Here's an example of a function that returns a scalar, where we want it
+        to stay as a scalar:
+
+        >>> df = pl.DataFrame(
+        ...     {
+        ...         "a": [0, 1, 0, 1],
+        ...         "b": [1, 2, 3, 4],
+        ...     }
+        ... )
+        >>> df.group_by("a").agg(
+        ...     pl.col("b").map_batches(lambda x: x.max(), returns_scalar=True)
         ... )  # doctest: +IGNORE_RESULT
         shape: (2, 2)
@@ -4621,9 +4651,10 @@ def map_batches(
         ┌─────┬─────┐
         │ a   ┆ b   │
         │ --- ┆ --- │
         │ i64 ┆ i64 │
         ╞═════╪═════╡
-        │ 0   ┆ 3   │
         │ 1   ┆ 4   │
+        │ 0   ┆ 3   │
         └─────┴─────┘
+
         """
         if return_dtype is not None:
             return_dtype = py_type_to_dtype(return_dtype)
@@ -4634,6 +4665,7 @@
                 return_dtype,
                 agg_list,
                 is_elementwise,
+                returns_scalar,
             )
         )
diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py
index e61213175d6f5..7c62d1765045d 100644
--- a/py-polars/polars/series/series.py
+++ b/py-polars/polars/series/series.py
@@ -1380,6 +1380,7 @@ def __array_ufunc__(
             # Only generalized ufuncs have a signature set:
             is_generalized_ufunc = bool(ufunc.signature)
+
             if is_generalized_ufunc:
                 # Generalized ufuncs will operate on the whole array, so
                 # missing data can corrupt the results.
@@ -1392,7 +1393,13 @@ def __array_ufunc__(
                 # output size.
                 assert ufunc.signature is not None  # pacify MyPy
                 ufunc_input, ufunc_output = ufunc.signature.split("->")
-                allocate_output = ufunc_input == ufunc_output
+                if ufunc_output == "()":
+                    # If the result is a scalar, just let the function do its
+                    # thing, no need for any song and dance involving
+                    # allocation:
+                    return ufunc(*args, dtype=dtype_char, **kwargs)
+                else:
+                    allocate_output = ufunc_input == ufunc_output
             else:
                 allocate_output = True
 
@@ -1409,6 +1416,7 @@ def __array_ufunc__(
                 lambda out: ufunc(*args, out=out, dtype=dtype_char, **kwargs),
                 allocate_output,
             )
+
             result = self._from_pyseries(series)
             if is_generalized_ufunc:
                 # In this case we've disallowed passing in missing data, so no
@@ -1426,7 +1434,6 @@ def __array_ufunc__(
                 .select(F.when(validity_mask).then(F.col(self.name)))
                 .to_series(0)
             )
-
         else:
             msg = (
                 "only `__call__` is implemented for numpy ufuncs on a Series, got "
diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs
index c3aa832d94292..9e2475d996581 100644
--- a/py-polars/src/expr/general.rs
+++ b/py-polars/src/expr/general.rs
@@ -755,15 +755,23 @@ impl PyExpr {
         self.inner.clone().shrink_dtype().into()
     }
 
-    #[pyo3(signature = (lambda, output_type, agg_list, is_elementwise))]
+    #[pyo3(signature = (lambda, output_type, agg_list, is_elementwise, returns_scalar))]
     fn map_batches(
         &self,
         lambda: PyObject,
         output_type: Option<Wrap<DataType>>,
         agg_list: bool,
         is_elementwise: bool,
+        returns_scalar: bool,
     ) -> Self {
-        map_single(self, lambda, output_type, agg_list, is_elementwise)
+        map_single(
+            self,
+            lambda,
+            output_type,
+            agg_list,
+            is_elementwise,
+            returns_scalar,
+        )
     }
 
     fn dot(&self, other: Self) -> Self {
diff --git a/py-polars/src/map/lazy.rs b/py-polars/src/map/lazy.rs
index 88b05beb16cc2..5779ca0c021b0 100644
--- a/py-polars/src/map/lazy.rs
+++ b/py-polars/src/map/lazy.rs
@@ -129,10 +129,12 @@ pub fn map_single(
     output_type: Option<Wrap<DataType>>,
     agg_list: bool,
     is_elementwise: bool,
+    returns_scalar: bool,
 ) -> PyExpr {
     let output_type = output_type.map(|wrap| wrap.0);
 
-    let func = python_udf::PythonUdfExpression::new(lambda, output_type, is_elementwise);
+    let func =
+        python_udf::PythonUdfExpression::new(lambda, output_type, is_elementwise, returns_scalar);
 
     pyexpr.inner.clone().map_python(func, agg_list).into()
 }
diff --git a/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py b/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
index de50592da38ef..ad288c03593d5 100644
--- a/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
+++ b/py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
@@ -133,6 +133,44 @@ def test_grouped_ufunc() -> None:
     df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))
 
 
+def test_generalized_ufunc_scalar() -> None:
+    numba = pytest.importorskip("numba")
+
+    @numba.guvectorize([(numba.int64[:], numba.int64[:])], "(n)->()")  # type: ignore[misc]
+    def my_custom_sum(arr, result) -> None:  # type: ignore[no-untyped-def]
+        total = 0
+        for value in arr:
+            total += value
+        result[0] = total
+
+    # Make type checkers happy:
+    custom_sum = cast(Callable[[object], object], my_custom_sum)
+
+    # Demonstrate NumPy as the canonical expected behavior:
+    assert custom_sum(np.array([10, 2, 3], dtype=np.int64)) == 15
+
+    # Direct call of the gufunc:
+    df = pl.DataFrame({"values": [10, 2, 3]})
+    assert custom_sum(df.get_column("values")) == 15
+
+    # Indirect call of the gufunc:
+    indirect = df.select(pl.col("values").map_batches(custom_sum, returns_scalar=True))
+    assert_frame_equal(indirect, pl.DataFrame({"values": 15}))
+    indirect = df.select(pl.col("values").map_batches(custom_sum, returns_scalar=False))
+    assert_frame_equal(indirect, pl.DataFrame({"values": [15]}))
+
+    # group_by()
+    df = pl.DataFrame({"labels": ["a", "b", "a", "b"], "values": [10, 2, 3, 30]})
+    indirect = (
+        df.group_by("labels")
+        .agg(pl.col("values").map_batches(custom_sum, returns_scalar=True))
+        .sort("labels")
+    )
+    assert_frame_equal(
+        indirect, pl.DataFrame({"labels": ["a", "b"], "values": [13, 32]})
+    )
+
+
 def make_gufunc_mean() -> Callable[[pl.Series], pl.Series]:
     numba = pytest.importorskip("numba")
 
diff --git a/py-polars/tests/unit/operations/map/test_map_batches.py b/py-polars/tests/unit/operations/map/test_map_batches.py
index 457df189fa000..e4929a6cb13ea 100644
--- a/py-polars/tests/unit/operations/map/test_map_batches.py
+++ b/py-polars/tests/unit/operations/map/test_map_batches.py
@@ -68,6 +68,10 @@ def test_map_batches_group() -> None:
     assert df.group_by("id").agg(pl.col("t").map_batches(lambda s: s.sum())).sort(
         "id"
     ).to_dict(as_series=False) == {"id": [0, 1], "t": [[11], [35]]}
+    # If returns_scalar is True, the result won't be wrapped in a list:
+    assert df.group_by("id").agg(
+        pl.col("t").map_batches(lambda s: s.sum(), returns_scalar=True)
+    ).sort("id").to_dict(as_series=False) == {"id": [0, 1], "t": [11, 35]}
 
 
 def test_map_deprecated() -> None:
@@ -82,16 +86,10 @@ def test_ufunc_args() -> None:
     df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]})
     result = df.select(
-        z=np.add(  # type: ignore[call-overload]
-            pl.col("a"), pl.col("b")
-        )
+        z=np.add(pl.col("a"), pl.col("b"))  # type: ignore[call-overload]
     )
     expected = pl.DataFrame({"z": [3, 6, 9]})
     assert_frame_equal(result, expected)
-    result = df.select(
-        z=np.add(  # type: ignore[call-overload]
-            2, pl.col("a")
-        )
-    )
+    result = df.select(z=np.add(2, pl.col("a")))  # type: ignore[call-overload]
     expected = pl.DataFrame({"z": [3, 4, 5]})
     assert_frame_equal(result, expected)
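Usage sketch (reviewer note, not part of the patch): a minimal illustration of the new flag on a build that includes this change; the frame below reuses the data from `test_map_batches_group`. By default a scalar returned by a `map_batches` UDF in a group-by context is wrapped in a one-element list; with `returns_scalar=True` it stays a plain value, matching what a built-in aggregation produces.

    import polars as pl

    df = pl.DataFrame({"id": [0, 0, 1], "t": [2, 9, 35]})

    # Default: each group's scalar result is wrapped in a list -> dtype list[i64].
    wrapped = df.group_by("id").agg(pl.col("t").map_batches(lambda s: s.sum()))

    # returns_scalar=True keeps the scalar as-is -> dtype i64, like pl.col("t").sum().
    flat = df.group_by("id").agg(
        pl.col("t").map_batches(lambda s: s.sum(), returns_scalar=True)
    )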