diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi
index 9206c9da4d..08ec0860ba 100644
--- a/daft/daft/__init__.pyi
+++ b/daft/daft/__init__.pyi
@@ -9,7 +9,7 @@ from daft.io.scan import ScanOperator
 from daft.plan_scheduler.physical_plan_scheduler import PartitionT
 from daft.runners.partitioning import PartitionCacheEntry
 from daft.sql.sql_connection import SQLConnection
-from daft.udf import PartialStatefulUDF, PartialStatelessUDF
+from daft.udf import InitArgsType, PartialStatefulUDF, PartialStatelessUDF
 
 if TYPE_CHECKING:
     import pyarrow as pa
@@ -1150,12 +1150,14 @@ def stateful_udf(
     expressions: list[PyExpr],
     return_dtype: PyDataType,
     resource_request: ResourceRequest | None,
-    init_args: tuple[tuple[Any, ...], dict[str, Any]] | None,
+    init_args: InitArgsType,
     batch_size: int | None,
     concurrency: int | None,
 ) -> PyExpr: ...
 def check_column_name_validity(name: str, schema: PySchema): ...
-def extract_partial_stateful_udf_py(expression: PyExpr) -> dict[str, PartialStatefulUDF]: ...
+def extract_partial_stateful_udf_py(
+    expression: PyExpr,
+) -> dict[str, tuple[PartialStatefulUDF, InitArgsType]]: ...
 def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ...
 def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
 def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py
index e80acb03cb..31c56c3ad4 100644
--- a/daft/runners/pyrunner.py
+++ b/daft/runners/pyrunner.py
@@ -138,10 +138,15 @@ def initialize_actor_global_state(uninitialized_projection: ExpressionsProjectio
 
         logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys()))
 
-        # TODO: Account for Stateful Actor initialization arguments as well as user-provided batch_size
-        PyActorPool.initialized_stateful_udfs_process_singleton = {
-            name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items()
-        }
+        PyActorPool.initialized_stateful_udfs_process_singleton = {}
+        for name, (partial_udf, init_args) in partial_stateful_udfs.items():
+            if init_args is None:
+                PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls()
+            else:
+                args, kwargs = init_args
+                PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls(
+                    *args, **kwargs
+                )
 
     @staticmethod
     def build_partitions_with_stateful_project(
diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index 7c84f7b9dc..d29a15c9f2 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -931,9 +931,14 @@ def __init__(self, daft_execution_config: PyDaftExecutionConfig, uninitialized_p
             for name, psu in extract_partial_stateful_udf_py(expr._expr).items()
         }
         logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys()))
-        self.initialized_stateful_udfs = {
-            name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items()
-        }
+
+        self.initialized_stateful_udfs = {}
+        for name, (partial_udf, init_args) in partial_stateful_udfs.items():
+            if init_args is None:
+                self.initialized_stateful_udfs[name] = partial_udf.func_cls()
+            else:
+                args, kwargs = init_args
+                self.initialized_stateful_udfs[name] = partial_udf.func_cls(*args, **kwargs)
 
     @ray.method(num_returns=2)
     def run(
diff --git a/daft/udf.py b/daft/udf.py
index e2afb495ff..c662dc6ced 100644
--- a/daft/udf.py
+++ b/daft/udf.py
@@ -4,7 +4,7 @@
 import functools
 import inspect
 from abc import abstractmethod
-from typing import Any, Callable, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from daft.context import get_context
 from daft.daft import PyDataType, ResourceRequest
@@ -13,6 +13,7 @@ from daft.expressions import Expression
 from daft.series import PySeries, Series
 
+InitArgsType = Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]]
 UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", list]]
@@ -294,7 +295,7 @@ class StatefulUDF(UDF):
     name: str
     cls: type
     return_dtype: DataType
-    init_args: tuple[tuple[Any, ...], dict[str, Any]] | None = None
+    init_args: InitArgsType = None
     concurrency: int | None = None
 
     def __post_init__(self):
diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs
index 44bb95c1b0..4da0140e01 100644
--- a/src/common/daft-config/src/python.rs
+++ b/src/common/daft-config/src/python.rs
@@ -27,13 +27,21 @@ impl PyDaftPlanningConfig {
         }
     }
 
-    fn with_config_values(&mut self, default_io_config: Option<PyIOConfig>) -> PyResult<Self> {
+    fn with_config_values(
+        &mut self,
+        default_io_config: Option<PyIOConfig>,
+        enable_actor_pool_projections: Option<bool>,
+    ) -> PyResult<Self> {
         let mut config = self.config.as_ref().clone();
 
         if let Some(default_io_config) = default_io_config {
             config.default_io_config = default_io_config.config;
         }
 
+        if let Some(enable_actor_pool_projections) = enable_actor_pool_projections {
+            config.enable_actor_pool_projections = enable_actor_pool_projections;
+        }
+
         Ok(Self {
             config: Arc::new(config),
         })
diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs
index 378611851a..adbb2830e7 100644
--- a/src/daft-dsl/src/functions/python/mod.rs
+++ b/src/daft-dsl/src/functions/python/mod.rs
@@ -9,6 +9,8 @@ use common_resource_request::ResourceRequest;
 use common_treenode::{TreeNode, TreeNodeRecursion};
 use daft_core::datatypes::DataType;
 use itertools::Itertools;
+#[cfg(feature = "python")]
+use pyo3::{Py, PyAny};
 pub use runtime_py_object::RuntimePyObject;
 use serde::{Deserialize, Serialize};
 pub use udf_runtime_binding::UDFRuntimeBinding;
@@ -180,7 +182,7 @@ pub fn get_concurrency(exprs: &[ExprRef]) -> usize {
 #[cfg(feature = "python")]
 pub fn bind_stateful_udfs(
     expr: ExprRef,
-    initialized_funcs: &HashMap<String, pyo3::Py<pyo3::PyAny>>,
+    initialized_funcs: &HashMap<String, Py<PyAny>>,
 ) -> DaftResult<ExprRef> {
     expr.transform(|e| match e.as_ref() {
         Expr::Function {
@@ -213,7 +215,9 @@
 /// Helper function that extracts all PartialStatefulUDF python objects from a given expression tree
 #[cfg(feature = "python")]
-pub fn extract_partial_stateful_udf_py(expr: ExprRef) -> HashMap<String, pyo3::Py<pyo3::PyAny>> {
+pub fn extract_partial_stateful_udf_py(
+    expr: ExprRef,
+) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
     let mut py_partial_udfs = HashMap::new();
     expr.apply(|child| {
         if let Expr::Function {
diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs
--- a/src/daft-dsl/src/python.rs
+++ b/src/daft-dsl/src/python.rs
@@ -221,12 +225,19 @@
-pub fn extract_partial_stateful_udf_py(expr: PyExpr) -> HashMap<String, Py<PyAny>> {
+pub fn extract_partial_stateful_udf_py(
+    expr: PyExpr,
+) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
     use crate::functions::python::extract_partial_stateful_udf_py;
     extract_partial_stateful_udf_py(expr.expr)
 }
diff --git a/tests/expressions/test_udf.py b/tests/expressions/test_udf.py
index 2572eb1adc..5ac09387e0 100644
--- a/tests/expressions/test_udf.py
+++ b/tests/expressions/test_udf.py
@@ -4,7 +4,9 @@
 import pyarrow as pa
 import pytest
 
+import daft
 from daft import col
+from daft.context import get_context, set_planning_config
 from daft.datatype import DataType
 from daft.expressions import Expression
 from daft.expressions.testing import expr_structurally_equal
@@ -13,6 +15,21 @@
 from daft.udf import udf
 
+@pytest.fixture(scope="function", params=[False, True])
+def actor_pool_enabled(request):
+    if request.param and get_context().daft_execution_config.enable_native_executor:
+        pytest.skip("Native executor does not support stateful UDFs")
+
+    original_config = get_context().daft_planning_config
+    try:
+        set_planning_config(
+            config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=request.param)
+        )
+        yield request.param
+    finally:
+        set_planning_config(config=original_config)
+
+
 def test_udf():
     table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
 
@@ -30,8 +47,8 @@ def repeat_n(data, n):
 
 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf(batch_size):
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf(batch_size, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string(), batch_size=batch_size)
     class RepeatN:
@@ -41,18 +58,21 @@ def __init__(self):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     expr = RepeatN(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
 
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
     assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
 
 
 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf_init_args(batch_size):
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf_init_args(batch_size, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string(), batch_size=batch_size)
     class RepeatN:
@@ -62,24 +82,27 @@ def __init__(self, initial_n: int = 2):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     expr = RepeatN(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
 
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
     assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
 
     expr = RepeatN.with_init_args(initial_n=3)(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
 
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
     assert result.to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]}
 
 
 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf_init_args_no_default(batch_size):
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf_init_args_no_default(batch_size, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string(), batch_size=batch_size)
     class RepeatN:
@@ -89,18 +112,21 @@ def __init__(self, initial_n):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     with pytest.raises(ValueError, match="Cannot call StatefulUDF without initialization arguments."):
         RepeatN(col("a"))
 
     expr = RepeatN.with_init_args(initial_n=2)(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
 
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
     assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
 
 
-def test_class_udf_init_args_bad_args():
+def test_class_udf_init_args_bad_args(actor_pool_enabled):
     @udf(return_dtype=DataType.string())
     class RepeatN:
         def __init__(self, initial_n):
@@ -109,10 +135,37 @@ def __init__(self, initial_n):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"):
         RepeatN.with_init_args(wrong=5)
 
 
+@pytest.mark.parametrize("concurrency", [1, 2, 4])
+@pytest.mark.parametrize("actor_pool_enabled", [True], indirect=True)
+def test_stateful_udf_concurrency(concurrency, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
+
+    @udf(return_dtype=DataType.string(), batch_size=1)
+    class RepeatN:
+        def __init__(self):
+            self.n = 2
+
+        def __call__(self, data):
+            return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
+
+    RepeatN = RepeatN.with_concurrency(concurrency)
+
+    expr = RepeatN(col("a"))
+    field = expr._to_field(df.schema())
+    assert field.name == "a"
+    assert field.dtype == DataType.string()
+
+    result = df.select(expr)
+    assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
+
+
 def test_udf_kwargs():
     table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
 
@@ -208,8 +261,8 @@ def full_udf(e_arg, val, kwarg_val=None, kwarg_ex=None):
         full_udf()
 
 
-def test_class_udf_initialization_error():
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf_initialization_error(actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string())
     class IdentityWithInitError:
@@ -219,9 +272,16 @@ def __init__(self):
         def __call__(self, data):
             return data
 
+    if actor_pool_enabled:
+        IdentityWithInitError = IdentityWithInitError.with_concurrency(1)
+
     expr = IdentityWithInitError(col("a"))
-    with pytest.raises(RuntimeError, match="UDF INIT ERROR"):
-        table.eval_expression_list([expr])
+    if actor_pool_enabled:
+        with pytest.raises(Exception):
+            df.select(expr).collect()
+    else:
+        with pytest.raises(RuntimeError, match="UDF INIT ERROR"):
+            df.select(expr).collect()
 
 
 def test_udf_equality():
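
Usage note: taken together, these changes carry a stateful UDF's `init_args` through `extract_partial_stateful_udf_py` so that both the PyRunner actor pool and the Ray actor construct the UDF class with the user-provided arguments, gated behind the new `enable_actor_pool_projections` planning-config flag. A minimal end-to-end sketch of the behavior the new tests exercise (assuming a build with this patch applied; the flag name and the `with_concurrency`/`with_init_args` composition mirror the test fixture and tests above):

```python
import daft
from daft import col
from daft.context import get_context, set_planning_config
from daft.datatype import DataType
from daft.series import Series
from daft.udf import udf

# Opt in to actor pool projections via the new planning-config flag.
set_planning_config(
    config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=True)
)


@udf(return_dtype=DataType.string())
class RepeatN:
    def __init__(self, initial_n: int = 2):
        self.n = initial_n

    def __call__(self, data):
        return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])


# Run the UDF on an actor pool of size 2; init_args now travel alongside the
# partial UDF and are applied when each actor initializes the class.
RepeatN = RepeatN.with_concurrency(2)
expr = RepeatN.with_init_args(initial_n=3)(col("a"))

df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
assert df.select(expr).to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]}
```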