[FEAT] Enable init args for stateful UDFs (Eventual-Inc#2956)
kevinzwang authored and sagiahrac committed Oct 7, 2024
1 parent 6cb4fc0 commit a509010
Showing 8 changed files with 130 additions and 36 deletions.
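For context, here is roughly what the feature looks like from the user's side. This is a minimal sketch distilled from the tests added in this commit (tests/expressions/test_udf.py); the RepeatN class and the with_init_args call mirror those tests.

import daft
from daft import col
from daft.datatype import DataType
from daft.series import Series
from daft.udf import udf


@udf(return_dtype=DataType.string())
class RepeatN:
    def __init__(self, initial_n: int = 2):
        # Runs once per UDF instance (or per actor), not once per row.
        self.n = initial_n

    def __call__(self, data):
        return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])


df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

# New in this commit: init args supplied at the call site are threaded
# through to the stateful UDF's __init__ on both the Py and Ray runners.
result = df.select(RepeatN.with_init_args(initial_n=3)(col("a")))
assert result.to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]}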
8 changes: 5 additions & 3 deletions daft/daft/__init__.pyi
@@ -9,7 +9,7 @@ from daft.io.scan import ScanOperator
 from daft.plan_scheduler.physical_plan_scheduler import PartitionT
 from daft.runners.partitioning import PartitionCacheEntry
 from daft.sql.sql_connection import SQLConnection
-from daft.udf import PartialStatefulUDF, PartialStatelessUDF
+from daft.udf import InitArgsType, PartialStatefulUDF, PartialStatelessUDF
 
 if TYPE_CHECKING:
     import pyarrow as pa
@@ -1150,12 +1150,14 @@ def stateful_udf(
     expressions: list[PyExpr],
     return_dtype: PyDataType,
     resource_request: ResourceRequest | None,
-    init_args: tuple[tuple[Any, ...], dict[str, Any]] | None,
+    init_args: InitArgsType,
     batch_size: int | None,
     concurrency: int | None,
 ) -> PyExpr: ...
 def check_column_name_validity(name: str, schema: PySchema): ...
-def extract_partial_stateful_udf_py(expression: PyExpr) -> dict[str, PartialStatefulUDF]: ...
+def extract_partial_stateful_udf_py(
+    expression: PyExpr,
+) -> dict[str, tuple[PartialStatefulUDF, InitArgsType]]: ...
 def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ...
 def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
 def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
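A note on the InitArgsType alias referenced above (defined in daft/udf.py later in this diff): it is Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]], i.e. either None or an (args, kwargs) pair. A hypothetical sketch of the shape the runners receive:

from typing import Any, Dict, Optional, Tuple

InitArgsType = Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]]

# Hypothetical value: with_init_args(initial_n=3) would be carried as
init_args: InitArgsType = ((), {"initial_n": 3})

# Both runners unpack it the same way (see pyrunner.py and ray_runner.py below):
if init_args is not None:
    args, kwargs = init_args  # -> func_cls(*args, **kwargs)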
13 changes: 9 additions & 4 deletions daft/runners/pyrunner.py
@@ -138,10 +138,15 @@ def initialize_actor_global_state(uninitialized_projection: ExpressionsProjectio
 
         logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys()))
 
-        # TODO: Account for Stateful Actor initialization arguments as well as user-provided batch_size
-        PyActorPool.initialized_stateful_udfs_process_singleton = {
-            name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items()
-        }
+        PyActorPool.initialized_stateful_udfs_process_singleton = {}
+        for name, (partial_udf, init_args) in partial_stateful_udfs.items():
+            if init_args is None:
+                PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls()
+            else:
+                args, kwargs = init_args
+                PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls(
+                    *args, **kwargs
+                )
 
     @staticmethod
     def build_partitions_with_stateful_project(
11 changes: 8 additions & 3 deletions daft/runners/ray_runner.py
@@ -931,9 +931,14 @@ def __init__(self, daft_execution_config: PyDaftExecutionConfig, uninitialized_p
             for name, psu in extract_partial_stateful_udf_py(expr._expr).items()
         }
         logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys()))
-        self.initialized_stateful_udfs = {
-            name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items()
-        }
+
+        self.initialized_stateful_udfs = {}
+        for name, (partial_udf, init_args) in partial_stateful_udfs.items():
+            if init_args is None:
+                self.initialized_stateful_udfs[name] = partial_udf.func_cls()
+            else:
+                args, kwargs = init_args
+                self.initialized_stateful_udfs[name] = partial_udf.func_cls(*args, **kwargs)
 
     @ray.method(num_returns=2)
     def run(
5 changes: 3 additions & 2 deletions daft/udf.py
@@ -4,7 +4,7 @@
 import functools
 import inspect
 from abc import abstractmethod
-from typing import Any, Callable, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from daft.context import get_context
 from daft.daft import PyDataType, ResourceRequest
@@ -13,6 +13,7 @@
 from daft.expressions import Expression
 from daft.series import PySeries, Series
 
+InitArgsType = Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]]
 UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", list]]
 
 
@@ -294,7 +295,7 @@ class StatefulUDF(UDF):
     name: str
     cls: type
     return_dtype: DataType
-    init_args: tuple[tuple[Any, ...], dict[str, Any]] | None = None
+    init_args: InitArgsType = None
     concurrency: int | None = None
 
     def __post_init__(self):
10 changes: 9 additions & 1 deletion src/common/daft-config/src/python.rs
@@ -27,13 +27,21 @@ impl PyDaftPlanningConfig {
         }
     }
 
-    fn with_config_values(&mut self, default_io_config: Option<PyIOConfig>) -> PyResult<Self> {
+    fn with_config_values(
+        &mut self,
+        default_io_config: Option<PyIOConfig>,
+        enable_actor_pool_projections: Option<bool>,
+    ) -> PyResult<Self> {
         let mut config = self.config.as_ref().clone();
 
         if let Some(default_io_config) = default_io_config {
            config.default_io_config = default_io_config.config;
         }
 
+        if let Some(enable_actor_pool_projections) = enable_actor_pool_projections {
+            config.enable_actor_pool_projections = enable_actor_pool_projections;
+        }
+
         Ok(Self {
             config: Arc::new(config),
         })
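For reference, the new parameter is presumably toggled from Python the same way the test fixture later in this diff does it:

from daft.context import get_context, set_planning_config

# Mirrors the actor_pool_enabled fixture in tests/expressions/test_udf.py.
set_planning_config(
    config=get_context().daft_planning_config.with_config_values(
        enable_actor_pool_projections=True,
    )
)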
17 changes: 14 additions & 3 deletions src/daft-dsl/src/functions/python/mod.rs
@@ -9,6 +9,8 @@ use common_resource_request::ResourceRequest;
 use common_treenode::{TreeNode, TreeNodeRecursion};
 use daft_core::datatypes::DataType;
 use itertools::Itertools;
+#[cfg(feature = "python")]
+use pyo3::{Py, PyAny};
 pub use runtime_py_object::RuntimePyObject;
 use serde::{Deserialize, Serialize};
 pub use udf_runtime_binding::UDFRuntimeBinding;
@@ -180,7 +182,7 @@ pub fn get_concurrency(exprs: &[ExprRef]) -> usize {
 #[cfg(feature = "python")]
 pub fn bind_stateful_udfs(
     expr: ExprRef,
-    initialized_funcs: &HashMap<String, pyo3::Py<pyo3::PyAny>>,
+    initialized_funcs: &HashMap<String, Py<PyAny>>,
 ) -> DaftResult<ExprRef> {
     expr.transform(|e| match e.as_ref() {
         Expr::Function {
@@ -213,20 +215,29 @@
 
 /// Helper function that extracts all PartialStatefulUDF python objects from a given expression tree
 #[cfg(feature = "python")]
-pub fn extract_partial_stateful_udf_py(expr: ExprRef) -> HashMap<String, pyo3::Py<pyo3::PyAny>> {
+pub fn extract_partial_stateful_udf_py(
+    expr: ExprRef,
+) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
     let mut py_partial_udfs = HashMap::new();
     expr.apply(|child| {
         if let Expr::Function {
             func:
                 FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
                     name,
                     stateful_partial_func: py_partial_udf,
+                    init_args,
                     ..
                 })),
             ..
         } = child.as_ref()
         {
-            py_partial_udfs.insert(name.as_ref().to_string(), py_partial_udf.as_ref().clone());
+            py_partial_udfs.insert(
+                name.as_ref().to_string(),
+                (
+                    py_partial_udf.as_ref().clone(),
+                    init_args.clone().map(|x| x.as_ref().clone()),
+                ),
+            );
         }
         Ok(TreeNodeRecursion::Continue)
     })
4 changes: 3 additions & 1 deletion src/daft-dsl/src/python.rs
@@ -222,7 +222,9 @@ pub fn stateful_udf(
 
 /// Extracts the `class PartialStatefulUDF` Python objects that are in the specified expression tree
 #[pyfunction]
-pub fn extract_partial_stateful_udf_py(expr: PyExpr) -> HashMap<String, Py<PyAny>> {
+pub fn extract_partial_stateful_udf_py(
+    expr: PyExpr,
+) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
     use crate::functions::python::extract_partial_stateful_udf_py;
     extract_partial_stateful_udf_py(expr.expr)
 }
98 changes: 79 additions & 19 deletions tests/expressions/test_udf.py
@@ -4,7 +4,9 @@
 import pyarrow as pa
 import pytest
 
+import daft
 from daft import col
+from daft.context import get_context, set_planning_config
 from daft.datatype import DataType
 from daft.expressions import Expression
 from daft.expressions.testing import expr_structurally_equal
@@ -13,6 +15,21 @@
 from daft.udf import udf
 
 
+@pytest.fixture(scope="function", params=[False, True])
+def actor_pool_enabled(request):
+    if request.param and get_context().daft_execution_config.enable_native_executor:
+        pytest.skip("Native executor does not support stateful UDFs")
+
+    original_config = get_context().daft_planning_config
+    try:
+        set_planning_config(
+            config=get_context().daft_planning_config.with_config_values(enable_actor_pool_projections=request.param)
+        )
+        yield request.param
+    finally:
+        set_planning_config(config=original_config)
+
+
 def test_udf():
     table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
 
@@ -30,8 +47,8 @@ def repeat_n(data, n):
 
 
 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf(batch_size):
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf(batch_size, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string(), batch_size=batch_size)
     class RepeatN:
@@ -41,18 +58,21 @@ def __init__(self):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     expr = RepeatN(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
 
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
    assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
 
 
 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf_init_args(batch_size):
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf_init_args(batch_size, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string(), batch_size=batch_size)
     class RepeatN:
@@ -62,24 +82,27 @@ def __init__(self, initial_n: int = 2):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     expr = RepeatN(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
     assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
 
     expr = RepeatN.with_init_args(initial_n=3)(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
     assert result.to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]}
 
 
 @pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
-def test_class_udf_init_args_no_default(batch_size):
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf_init_args_no_default(batch_size, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string(), batch_size=batch_size)
     class RepeatN:
@@ -89,18 +112,21 @@ def __init__(self, initial_n):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     with pytest.raises(ValueError, match="Cannot call StatefulUDF without initialization arguments."):
         RepeatN(col("a"))
 
     expr = RepeatN.with_init_args(initial_n=2)(col("a"))
-    field = expr._to_field(table.schema())
+    field = expr._to_field(df.schema())
     assert field.name == "a"
     assert field.dtype == DataType.string()
-    result = table.eval_expression_list([expr])
+    result = df.select(expr)
     assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
 
 
-def test_class_udf_init_args_bad_args():
+def test_class_udf_init_args_bad_args(actor_pool_enabled):
     @udf(return_dtype=DataType.string())
     class RepeatN:
         def __init__(self, initial_n):
@@ -109,10 +135,37 @@ def __init__(self, initial_n):
         def __call__(self, data):
             return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
 
+    if actor_pool_enabled:
+        RepeatN = RepeatN.with_concurrency(1)
+
     with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"):
         RepeatN.with_init_args(wrong=5)
 
 
+@pytest.mark.parametrize("concurrency", [1, 2, 4])
+@pytest.mark.parametrize("actor_pool_enabled", [True], indirect=True)
+def test_stateful_udf_concurrency(concurrency, actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
+
+    @udf(return_dtype=DataType.string(), batch_size=1)
+    class RepeatN:
+        def __init__(self):
+            self.n = 2
+
+        def __call__(self, data):
+            return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])
+
+    RepeatN = RepeatN.with_concurrency(concurrency)
+
+    expr = RepeatN(col("a"))
+    field = expr._to_field(df.schema())
+    assert field.name == "a"
+    assert field.dtype == DataType.string()
+
+    result = df.select(expr)
+    assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}
+
+
 def test_udf_kwargs():
     table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
 
@@ -208,8 +261,8 @@ def full_udf(e_arg, val, kwarg_val=None, kwarg_ex=None):
     full_udf()
 
 
-def test_class_udf_initialization_error():
-    table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
+def test_class_udf_initialization_error(actor_pool_enabled):
+    df = daft.from_pydict({"a": ["foo", "bar", "baz"]})
 
     @udf(return_dtype=DataType.string())
     class IdentityWithInitError:
@@ -219,9 +272,16 @@ def __init__(self):
         def __call__(self, data):
             return data
 
+    if actor_pool_enabled:
+        IdentityWithInitError = IdentityWithInitError.with_concurrency(1)
+
     expr = IdentityWithInitError(col("a"))
-    with pytest.raises(RuntimeError, match="UDF INIT ERROR"):
-        table.eval_expression_list([expr])
+    if actor_pool_enabled:
+        with pytest.raises(Exception):
+            df.select(expr).collect()
+    else:
+        with pytest.raises(RuntimeError, match="UDF INIT ERROR"):
+            df.select(expr).collect()
 
 
 def test_udf_equality():
