Skip to content

Commit

Permalink
[FEAT] Allow returning of pyarrow arrays from UDFs (#2252)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaychia and Jay Chia authored May 8, 2024
1 parent f1d6570 commit 62f9dd6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
9 changes: 9 additions & 0 deletions daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@
except ImportError:
_NUMPY_AVAILABLE = False

_PYARROW_AVAILABLE = True
try:
import pyarrow as pa
except ImportError:
_PYARROW_AVAILABLE = False

if TYPE_CHECKING:
import numpy as np
import pyarrow as pa

# Type of the user-supplied callable. Its return value may be a Series, a numpy
# array, a pyarrow Array/ChunkedArray or a plain list. The numpy and pyarrow
# types are written as strings because both libraries are imported optionally.
UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", "pa.Array", "pa.ChunkedArray", list]]

Expand Down Expand Up @@ -114,6 +121,8 @@ def __call__(self, evaluated_expressions: list[Series]) -> PySeries:
return Series.from_pylist(result, name=name, pyobj="allow").cast(self.udf.return_dtype)._series
elif _NUMPY_AVAILABLE and isinstance(result, np.ndarray):
return Series.from_numpy(result, name=name).cast(self.udf.return_dtype)._series
elif _PYARROW_AVAILABLE and isinstance(result, (pa.Array, pa.ChunkedArray)):
return Series.from_arrow(result, name=name).cast(self.udf.return_dtype)._series
else:
raise NotImplementedError(f"Return type not supported for UDF: {type(result)}")

Expand Down
12 changes: 12 additions & 0 deletions tests/expressions/test_udf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import numpy as np
import pyarrow as pa
import pytest

from daft import col
Expand Down Expand Up @@ -235,3 +236,14 @@ def add_cols_elementwise(*args, multiplier: float):

result = table.eval_expression_list([expr])
assert result.to_pydict() == {"a": [6, 12, 18]}


def test_udf_return_pyarrow():
    """Check that a UDF returning a pyarrow array evaluates to the expected values."""
    # Import the compute submodule explicitly: `import pyarrow as pa` on its own
    # does not guarantee that `pa.compute` is available as an attribute, so
    # relying on it would make this test depend on import side effects elsewhere.
    import pyarrow.compute as pc

    table = MicroPartition.from_pydict({"a": [1, 2, 3]})

    @udf(return_dtype=DataType.int64())
    def add_1(data):
        # Returns a pyarrow array rather than a Series / numpy array / list.
        return pc.add(data.to_arrow(), 1)

    result = table.eval_expression_list([add_1(col("a"))])
    assert result.to_pydict() == {"a": [2, 3, 4]}

0 comments on commit 62f9dd6

Please sign in to comment.