Skip to content

Commit

Permalink
fix(python): improve support for user-defined functions that return scalars (pola-rs#16556)
Browse files Browse the repository at this point in the history

Co-authored-by: Itamar Turner-Trauring <[email protected]>
  • Loading branch information
2 people authored and Wouittone committed Jun 22, 2024
1 parent 5ebbc39 commit fa46225
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 29 deletions.
25 changes: 21 additions & 4 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,21 @@ pub struct PythonUdfExpression {
python_function: PyObject,
output_type: Option<DataType>,
is_elementwise: bool,
returns_scalar: bool,
}

impl PythonUdfExpression {
pub fn new(lambda: PyObject, output_type: Option<DataType>, is_elementwise: bool) -> Self {
pub fn new(
lambda: PyObject,
output_type: Option<DataType>,
is_elementwise: bool,
returns_scalar: bool,
) -> Self {
Self {
python_function: lambda,
output_type,
is_elementwise,
returns_scalar,
}
}

Expand All @@ -121,7 +128,7 @@ impl PythonUdfExpression {
// skip header
let buf = &buf[MAGIC_BYTE_MARK.len()..];
let mut reader = Cursor::new(buf);
let (output_type, is_elementwise): (Option<DataType>, bool) =
let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
ciborium::de::from_reader(&mut reader).map_err(map_err)?;

let remainder = &buf[reader.position() as usize..];
Expand All @@ -138,6 +145,7 @@ impl PythonUdfExpression {
python_function.into(),
output_type,
is_elementwise,
returns_scalar,
)) as Arc<dyn SeriesUdf>)
})
}
Expand Down Expand Up @@ -181,8 +189,15 @@ impl SeriesUdf for PythonUdfExpression {
#[cfg(feature = "serde")]
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
buf.extend_from_slice(MAGIC_BYTE_MARK);
ciborium::ser::into_writer(&(self.output_type.clone(), self.is_elementwise), &mut *buf)
.unwrap();
ciborium::ser::into_writer(
&(
self.output_type.clone(),
self.is_elementwise,
self.returns_scalar,
),
&mut *buf,
)
.unwrap();

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
Expand Down Expand Up @@ -222,6 +237,7 @@ impl Expr {
(ApplyOptions::GroupWise, "python_udf")
};

let returns_scalar = func.returns_scalar;
let return_dtype = func.output_type.clone();
let output_type = GetOutput::map_field(move |fld| match return_dtype {
Some(ref dt) => Field::new(fld.name(), dt.clone()),
Expand All @@ -239,6 +255,7 @@ impl Expr {
options: FunctionOptions {
collect_groups,
fmt_str: name,
returns_scalar,
..Default::default()
},
}
Expand Down
56 changes: 44 additions & 12 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4524,21 +4524,22 @@ def map_batches(
*,
agg_list: bool = False,
is_elementwise: bool = False,
returns_scalar: bool = False,
) -> Self:
"""
Apply a custom python function to a whole Series or sequence of Series.
The output of this custom function must be a Series (or a NumPy array, in which
case it will be automatically converted into a Series). If you want to apply a
The output of this custom function is presumed to be one of: a Series,
a NumPy array (in which case it will be automatically converted into
a Series), or a scalar that will be converted into a Series. If the
result is a scalar and you want it to stay as a scalar, pass in
``returns_scalar=True``. If you want to apply a custom function
elementwise over single values, see :func:`map_elements`.
A reasonable use case for `map` functions is transforming the values
represented by an expression using a third-party library.
.. warning::
If you are looking to map a function over a window function or group_by
context, refer to :func:`map_elements` instead.
Read more in `the book
<https://docs.pola.rs/user-guide/expressions/user-defined-functions>`_.
If your function returns a scalar, for example a float, use
:func:`map_to_scalar` instead.
Parameters
----------
Expand All @@ -4556,6 +4557,11 @@ def map_batches(
function. This parameter only works in a group-by context.
The function will be invoked only once on a list of groups, rather than
once per group.
returns_scalar
If the function returns a scalar, by default it will be wrapped in
a list in the output, since the assumption is that the function
always returns something Series-like. If you want to keep the
result as a scalar, set this argument to True.
Warnings
--------
Expand Down Expand Up @@ -4597,33 +4603,58 @@ def map_batches(
... }
... )
>>> df.group_by("a").agg(
... pl.col("b").map_batches(lambda x: x.max(), agg_list=False)
... pl.col("b").map_batches(lambda x: x + 2, agg_list=False)
... ) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬───────────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ list[i64] │
╞═════╪═══════════╡
│ 1 ┆ [4]
│ 0 ┆ [3]
│ 1 ┆ [4, 6]
│ 0 ┆ [3, 5]
└─────┴───────────┘
Using `agg_list=True` would be more efficient. In this example, the input of
the function is a Series of type `List(Int64)`.
>>> df.group_by("a").agg(
... pl.col("b").map_batches(lambda x: x.list.max(), agg_list=True)
... pl.col("b").map_batches(
... lambda x: x.list.eval(pl.element() + 2), agg_list=True
... )
... ) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬───────────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ list[i64] │
╞═════╪═══════════╡
│ 0 ┆ [3, 5] │
│ 1 ┆ [4, 6] │
└─────┴───────────┘
Here's an example of a function that returns a scalar, where we want it
to stay as a scalar:
>>> df = pl.DataFrame(
... {
... "a": [0, 1, 0, 1],
... "b": [1, 2, 3, 4],
... }
... )
>>> df.group_by("a").agg(
... pl.col("b").map_batches(lambda x: x.max(), returns_scalar=True)
... ) # doctest: +IGNORE_RESULT
shape: (2, 2)
┌─────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0   ┆ 3   │
│ 1   ┆ 4   │
└─────┴─────┘
"""
if return_dtype is not None:
return_dtype = py_type_to_dtype(return_dtype)
Expand All @@ -4634,6 +4665,7 @@ def map_batches(
return_dtype,
agg_list,
is_elementwise,
returns_scalar,
)
)

Expand Down
11 changes: 9 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,7 @@ def __array_ufunc__(

# Only generalized ufuncs have a signature set:
is_generalized_ufunc = bool(ufunc.signature)

if is_generalized_ufunc:
# Generalized ufuncs will operate on the whole array, so
# missing data can corrupt the results.
Expand All @@ -1392,7 +1393,13 @@ def __array_ufunc__(
# output size.
assert ufunc.signature is not None # pacify MyPy
ufunc_input, ufunc_output = ufunc.signature.split("->")
allocate_output = ufunc_input == ufunc_output
if ufunc_output == "()":
# If the result is a scalar, just let the function do its
# thing, no need for any song and dance involving
# allocation:
return ufunc(*args, dtype=dtype_char, **kwargs)
else:
allocate_output = ufunc_input == ufunc_output
else:
allocate_output = True

Expand All @@ -1409,6 +1416,7 @@ def __array_ufunc__(
lambda out: ufunc(*args, out=out, dtype=dtype_char, **kwargs),
allocate_output,
)

result = self._from_pyseries(series)
if is_generalized_ufunc:
# In this case we've disallowed passing in missing data, so no
Expand All @@ -1426,7 +1434,6 @@ def __array_ufunc__(
.select(F.when(validity_mask).then(F.col(self.name)))
.to_series(0)
)

else:
msg = (
"only `__call__` is implemented for numpy ufuncs on a Series, got "
Expand Down
12 changes: 10 additions & 2 deletions py-polars/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,15 +755,23 @@ impl PyExpr {
self.inner.clone().shrink_dtype().into()
}

#[pyo3(signature = (lambda, output_type, agg_list, is_elementwise))]
#[pyo3(signature = (lambda, output_type, agg_list, is_elementwise, returns_scalar))]
fn map_batches(
&self,
lambda: PyObject,
output_type: Option<Wrap<DataType>>,
agg_list: bool,
is_elementwise: bool,
returns_scalar: bool,
) -> Self {
map_single(self, lambda, output_type, agg_list, is_elementwise)
map_single(
self,
lambda,
output_type,
agg_list,
is_elementwise,
returns_scalar,
)
}

fn dot(&self, other: Self) -> Self {
Expand Down
4 changes: 3 additions & 1 deletion py-polars/src/map/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ pub fn map_single(
output_type: Option<Wrap<DataType>>,
agg_list: bool,
is_elementwise: bool,
returns_scalar: bool,
) -> PyExpr {
let output_type = output_type.map(|wrap| wrap.0);

let func = python_udf::PythonUdfExpression::new(lambda, output_type, is_elementwise);
let func =
python_udf::PythonUdfExpression::new(lambda, output_type, is_elementwise, returns_scalar);
pyexpr.inner.clone().map_python(func, agg_list).into()
}

Expand Down
38 changes: 38 additions & 0 deletions py-polars/tests/unit/interop/numpy/test_ufunc_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,44 @@ def test_grouped_ufunc() -> None:
df.group_by("id").agg(pl.col("values").log1p().sum().pipe(np.expm1))


def test_generalized_ufunc_scalar() -> None:
    """Scalar-returning generalized ufuncs (signature ``(n)->()``) work through
    ``map_batches``; ``returns_scalar`` decides whether the result stays a
    scalar or gets wrapped in a list."""
    numba = pytest.importorskip("numba")

    @numba.guvectorize([(numba.int64[:], numba.int64[:])], "(n)->()")  # type: ignore[misc]
    def my_custom_sum(arr, result) -> None:  # type: ignore[no-untyped-def]
        acc = 0
        for item in arr:
            acc += item
        result[0] = acc

    # Keep type checkers quiet about the numba-generated callable:
    custom_sum = cast(Callable[[object], object], my_custom_sum)

    # NumPy itself establishes the canonical expected behavior:
    assert custom_sum(np.array([10, 2, 3], dtype=np.int64)) == 15

    # Calling the gufunc directly on a Series:
    frame = pl.DataFrame({"values": [10, 2, 3]})
    assert custom_sum(frame.get_column("values")) == 15

    # Calling it through map_batches, with and without returns_scalar:
    got = frame.select(pl.col("values").map_batches(custom_sum, returns_scalar=True))
    assert_frame_equal(got, pl.DataFrame({"values": 15}))
    got = frame.select(pl.col("values").map_batches(custom_sum, returns_scalar=False))
    assert_frame_equal(got, pl.DataFrame({"values": [15]}))

    # And inside a group_by aggregation:
    frame = pl.DataFrame({"labels": ["a", "b", "a", "b"], "values": [10, 2, 3, 30]})
    got = (
        frame.group_by("labels")
        .agg(pl.col("values").map_batches(custom_sum, returns_scalar=True))
        .sort("labels")
    )
    assert_frame_equal(
        got, pl.DataFrame({"labels": ["a", "b"], "values": [13, 32]})
    )


def make_gufunc_mean() -> Callable[[pl.Series], pl.Series]:
numba = pytest.importorskip("numba")

Expand Down
14 changes: 6 additions & 8 deletions py-polars/tests/unit/operations/map/test_map_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def test_map_batches_group() -> None:
assert df.group_by("id").agg(pl.col("t").map_batches(lambda s: s.sum())).sort(
"id"
).to_dict(as_series=False) == {"id": [0, 1], "t": [[11], [35]]}
# If returns_scalar is True, the result won't be wrapped in a list:
assert df.group_by("id").agg(
pl.col("t").map_batches(lambda s: s.sum(), returns_scalar=True)
).sort("id").to_dict(as_series=False) == {"id": [0, 1], "t": [11, 35]}


def test_map_deprecated() -> None:
Expand All @@ -82,16 +86,10 @@ def test_map_deprecated() -> None:
def test_ufunc_args() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [2, 4, 6]})
result = df.select(
z=np.add( # type: ignore[call-overload]
pl.col("a"), pl.col("b")
)
z=np.add(pl.col("a"), pl.col("b")) # type: ignore[call-overload]
)
expected = pl.DataFrame({"z": [3, 6, 9]})
assert_frame_equal(result, expected)
result = df.select(
z=np.add( # type: ignore[call-overload]
2, pl.col("a")
)
)
result = df.select(z=np.add(2, pl.col("a"))) # type: ignore[call-overload]
expected = pl.DataFrame({"z": [3, 4, 5]})
assert_frame_equal(result, expected)

0 comments on commit fa46225

Please sign in to comment.