Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): improve support for user-defined functions that return scalars #16556

Merged
merged 13 commits into from
May 30, 2024

Conversation

itamarst
Copy link
Contributor

@itamarst itamarst commented May 28, 2024

Fixes #14748

First, using generalized ufuncs that returned scalars didn't work with the latest version of Polars, at least. Consider the following script:

import polars as pl
import numpy as np
import numba


@numba.guvectorize([(numba.int64[:], numba.int64[:])], "(n)->()")
def custom_sum(arr, result) -> None:
    total = 0
    for value in arr:
        total += value
    # Use guvectorize's weird API for returning scalars:
    result[0] = total


# Demonstrate NumPy as the canonical expected behavior:
assert custom_sum(np.array([10, 2, 3], dtype=np.int64)) == 15

# Direct call of the gufunc:
df = pl.DataFrame({"values": [10, 2, 3]})
assert custom_sum(df.get_column("values")) == 15

On the released version of Polars it fails with:

TypeError: Series constructor called with unsupported type 'int64' for the `values` parameter

This is fixed in this branch (see the new test in test_ufunc_expr.py).

Second, if your user-defined function returns a scalar, chances are you want to keep it as a scalar. Unfortunately map_batches() had no way to express this previously, so the scalar would be wrapped in a list, making it annoying to use.

This is analogous to the equivalent flag in pl.map_groups(): https://docs.pola.rs/py-polars/html/reference/expressions/api/polars.map_groups.html (unfortunately with the opposite default, for backwards compatibility reasons).

Here's what it look like in the new branch:

>>> import polars as pl
>>> df = pl.DataFrame({"a": [0, 1, 0, 1], "b": [1, 2, 3,4]})
>>> df.group_by("a").agg(pl.col("b").map_batches(max))
shape: (2, 2)
┌─────┬───────────┐
│ ab         │
│ ------       │
│ i64list[i64] │
╞═════╪═══════════╡
│ 1   ┆ [4]       │
│ 0   ┆ [3]       │
└─────┴───────────┘
>>> df.group_by("a").agg(pl.col("b").map_batches(max, returns_scalar=True))
shape: (2, 2)
┌─────┬─────┐
│ ab   │
│ ------ │
│ i64i64 │
╞═════╪═════╡
│ 14   │
│ 03   │
└─────┴─────┘

@github-actions github-actions bot added enhancement New feature or an improvement of an existing feature python Related to Python Polars labels May 28, 2024
@itamarst itamarst changed the title feat(python): in map_batches() with a user-defined functions that return scalars, allow keeping results as scalars fix(python): improve support for user-defined functions that return scalars May 28, 2024
@github-actions github-actions bot added the fix Bug fix label May 28, 2024
@itamarst itamarst marked this pull request as ready for review May 28, 2024 17:16
Copy link
Member

@ritchie46 ritchie46 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good @itamarst. Thanks!

@ritchie46 ritchie46 merged commit 84ba2d0 into pola-rs:main May 30, 2024
33 checks passed
Wouittone pushed a commit to Wouittone/polars that referenced this pull request Jun 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or an improvement of an existing feature fix Bug fix python Related to Python Polars
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Generalized ufunc functions have inconsistent behavior between NumPy and Polars when returning scalars
3 participants