Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Enable init args for stateful UDFs #2956

Merged
merged 6 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ from daft.io.scan import ScanOperator
from daft.plan_scheduler.physical_plan_scheduler import PartitionT
from daft.runners.partitioning import PartitionCacheEntry
from daft.sql.sql_connection import SQLConnection
from daft.udf import PartialStatefulUDF, PartialStatelessUDF
from daft.udf import InitArgsType, PartialStatefulUDF, PartialStatelessUDF

if TYPE_CHECKING:
import pyarrow as pa
Expand Down Expand Up @@ -1150,12 +1150,14 @@ def stateful_udf(
expressions: list[PyExpr],
return_dtype: PyDataType,
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[str, Any]] | None,
init_args: InitArgsType,
batch_size: int | None,
concurrency: int | None,
) -> PyExpr: ...
def check_column_name_validity(name: str, schema: PySchema): ...
def extract_partial_stateful_udf_py(expression: PyExpr) -> dict[str, PartialStatefulUDF]: ...
def extract_partial_stateful_udf_py(
expression: PyExpr,
) -> dict[str, tuple[PartialStatefulUDF, InitArgsType]]: ...
def bind_stateful_udfs(expression: PyExpr, initialized_funcs: dict[str, Callable]) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
Expand Down
13 changes: 9 additions & 4 deletions daft/runners/pyrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,15 @@

logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys()))

# TODO: Account for Stateful Actor initialization arguments as well as user-provided batch_size
PyActorPool.initialized_stateful_udfs_process_singleton = {
name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items()
}
PyActorPool.initialized_stateful_udfs_process_singleton = {}
for name, (partial_udf, init_args) in partial_stateful_udfs.items():
if init_args is None:
PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls()

Check warning on line 144 in daft/runners/pyrunner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/pyrunner.py#L141-L144

Added lines #L141 - L144 were not covered by tests
else:
args, kwargs = init_args
PyActorPool.initialized_stateful_udfs_process_singleton[name] = partial_udf.func_cls(

Check warning on line 147 in daft/runners/pyrunner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/pyrunner.py#L146-L147

Added lines #L146 - L147 were not covered by tests
*args, **kwargs
)

@staticmethod
def build_partitions_with_stateful_project(
Expand Down
11 changes: 8 additions & 3 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,9 +931,14 @@
for name, psu in extract_partial_stateful_udf_py(expr._expr).items()
}
logger.info("Initializing stateful UDFs: %s", ", ".join(partial_stateful_udfs.keys()))
self.initialized_stateful_udfs = {
name: partial_udf.func_cls() for name, partial_udf in partial_stateful_udfs.items()
}

self.initialized_stateful_udfs = {}
for name, (partial_udf, init_args) in partial_stateful_udfs.items():
if init_args is None:
self.initialized_stateful_udfs[name] = partial_udf.func_cls()
else:
args, kwargs = init_args
self.initialized_stateful_udfs[name] = partial_udf.func_cls(*args, **kwargs)

Check warning on line 941 in daft/runners/ray_runner.py

View check run for this annotation

Codecov / codecov/patch

daft/runners/ray_runner.py#L940-L941

Added lines #L940 - L941 were not covered by tests

@ray.method(num_returns=2)
def run(
Expand Down
5 changes: 3 additions & 2 deletions daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import inspect
from abc import abstractmethod
from typing import Any, Callable, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

from daft.context import get_context
from daft.daft import PyDataType, ResourceRequest
Expand All @@ -13,6 +13,7 @@
from daft.expressions import Expression
from daft.series import PySeries, Series

InitArgsType = Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]]
UserProvidedPythonFunction = Callable[..., Union[Series, "np.ndarray", list]]


Expand Down Expand Up @@ -294,7 +295,7 @@ class StatefulUDF(UDF):
name: str
cls: type
return_dtype: DataType
init_args: tuple[tuple[Any, ...], dict[str, Any]] | None = None
init_args: InitArgsType = None
concurrency: int | None = None

def __post_init__(self):
Expand Down
10 changes: 9 additions & 1 deletion src/common/daft-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,21 @@ impl PyDaftPlanningConfig {
}
}

fn with_config_values(&mut self, default_io_config: Option<PyIOConfig>) -> PyResult<Self> {
fn with_config_values(
&mut self,
default_io_config: Option<PyIOConfig>,
enable_actor_pool_projections: Option<bool>,
) -> PyResult<Self> {
let mut config = self.config.as_ref().clone();

if let Some(default_io_config) = default_io_config {
config.default_io_config = default_io_config.config;
}

if let Some(enable_actor_pool_projections) = enable_actor_pool_projections {
config.enable_actor_pool_projections = enable_actor_pool_projections;
}

Ok(Self {
config: Arc::new(config),
})
Expand Down
17 changes: 14 additions & 3 deletions src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
use common_treenode::{TreeNode, TreeNodeRecursion};
use daft_core::datatypes::DataType;
use itertools::Itertools;
#[cfg(feature = "python")]
use pyo3::{Py, PyAny};
pub use runtime_py_object::RuntimePyObject;
use serde::{Deserialize, Serialize};
pub use udf_runtime_binding::UDFRuntimeBinding;
Expand Down Expand Up @@ -180,7 +182,7 @@
#[cfg(feature = "python")]
pub fn bind_stateful_udfs(
expr: ExprRef,
initialized_funcs: &HashMap<String, pyo3::Py<pyo3::PyAny>>,
initialized_funcs: &HashMap<String, Py<PyAny>>,

Check warning on line 185 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L185

Added line #L185 was not covered by tests
) -> DaftResult<ExprRef> {
expr.transform(|e| match e.as_ref() {
Expr::Function {
Expand Down Expand Up @@ -213,20 +215,29 @@

/// Helper function that extracts all PartialStatefulUDF python objects, each paired with its optional initialization arguments, from a given expression tree
#[cfg(feature = "python")]
pub fn extract_partial_stateful_udf_py(expr: ExprRef) -> HashMap<String, pyo3::Py<pyo3::PyAny>> {
pub fn extract_partial_stateful_udf_py(
expr: ExprRef,
) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
let mut py_partial_udfs = HashMap::new();
expr.apply(|child| {
if let Expr::Function {
func:
FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
name,
stateful_partial_func: py_partial_udf,
init_args,
..
})),
..
} = child.as_ref()
{
py_partial_udfs.insert(name.as_ref().to_string(), py_partial_udf.as_ref().clone());
py_partial_udfs.insert(
name.as_ref().to_string(),
(
py_partial_udf.as_ref().clone(),
init_args.clone().map(|x| x.as_ref().clone()),
),
);
}
Ok(TreeNodeRecursion::Continue)
})
Expand Down
4 changes: 3 additions & 1 deletion src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ pub fn stateful_udf(

/// Extracts the `class PartialStatefulUDF` Python objects, together with their optional initialization arguments, that are in the specified expression tree
#[pyfunction]
pub fn extract_partial_stateful_udf_py(expr: PyExpr) -> HashMap<String, Py<PyAny>> {
pub fn extract_partial_stateful_udf_py(
expr: PyExpr,
) -> HashMap<String, (Py<PyAny>, Option<Py<PyAny>>)> {
use crate::functions::python::extract_partial_stateful_udf_py;
extract_partial_stateful_udf_py(expr.expr)
}
Expand Down
98 changes: 79 additions & 19 deletions tests/expressions/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import pyarrow as pa
import pytest

import daft
from daft import col
from daft.context import get_context, set_planning_config
from daft.datatype import DataType
from daft.expressions import Expression
from daft.expressions.testing import expr_structurally_equal
Expand All @@ -13,6 +15,21 @@
from daft.udf import udf


@pytest.fixture(scope="function", params=[False, True])
def actor_pool_enabled(request):
    """Run the dependent test once with actor-pool projections off and once with them on.

    Yields the boolean parameter for the current run. The planning config that
    was active before the test is reinstated on teardown, whether or not the
    test body raised. The enabled variant is skipped when the native executor
    is turned on in the execution config.
    """
    use_actor_pool = request.param
    if use_actor_pool and get_context().daft_execution_config.enable_native_executor:
        pytest.skip("Native executor does not support stateful UDFs")

    # Snapshot the config first so it can be restored in the finally clause.
    original_config = get_context().daft_planning_config
    try:
        overridden_config = original_config.with_config_values(enable_actor_pool_projections=use_actor_pool)
        set_planning_config(config=overridden_config)
        yield use_actor_pool
    finally:
        set_planning_config(config=original_config)


def test_udf():
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})

Expand All @@ -30,8 +47,8 @@ def repeat_n(data, n):


@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
def test_class_udf(batch_size):
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
def test_class_udf(batch_size, actor_pool_enabled):
df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

@udf(return_dtype=DataType.string(), batch_size=batch_size)
class RepeatN:
Expand All @@ -41,18 +58,21 @@ def __init__(self):
def __call__(self, data):
return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

if actor_pool_enabled:
RepeatN = RepeatN.with_concurrency(1)

expr = RepeatN(col("a"))
field = expr._to_field(table.schema())
field = expr._to_field(df.schema())
assert field.name == "a"
assert field.dtype == DataType.string()

result = table.eval_expression_list([expr])
result = df.select(expr)
assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}


@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
def test_class_udf_init_args(batch_size):
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
def test_class_udf_init_args(batch_size, actor_pool_enabled):
df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

@udf(return_dtype=DataType.string(), batch_size=batch_size)
class RepeatN:
Expand All @@ -62,24 +82,27 @@ def __init__(self, initial_n: int = 2):
def __call__(self, data):
return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

if actor_pool_enabled:
RepeatN = RepeatN.with_concurrency(1)

expr = RepeatN(col("a"))
field = expr._to_field(table.schema())
field = expr._to_field(df.schema())
assert field.name == "a"
assert field.dtype == DataType.string()
result = table.eval_expression_list([expr])
result = df.select(expr)
assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}

expr = RepeatN.with_init_args(initial_n=3)(col("a"))
field = expr._to_field(table.schema())
field = expr._to_field(df.schema())
assert field.name == "a"
assert field.dtype == DataType.string()
result = table.eval_expression_list([expr])
result = df.select(expr)
assert result.to_pydict() == {"a": ["foofoofoo", "barbarbar", "bazbazbaz"]}


@pytest.mark.parametrize("batch_size", [None, 1, 2, 3, 10])
def test_class_udf_init_args_no_default(batch_size):
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
def test_class_udf_init_args_no_default(batch_size, actor_pool_enabled):
df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

@udf(return_dtype=DataType.string(), batch_size=batch_size)
class RepeatN:
Expand All @@ -89,18 +112,21 @@ def __init__(self, initial_n):
def __call__(self, data):
return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

if actor_pool_enabled:
RepeatN = RepeatN.with_concurrency(1)

with pytest.raises(ValueError, match="Cannot call StatefulUDF without initialization arguments."):
RepeatN(col("a"))

expr = RepeatN.with_init_args(initial_n=2)(col("a"))
field = expr._to_field(table.schema())
field = expr._to_field(df.schema())
assert field.name == "a"
assert field.dtype == DataType.string()
result = table.eval_expression_list([expr])
result = df.select(expr)
assert result.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}


def test_class_udf_init_args_bad_args():
def test_class_udf_init_args_bad_args(actor_pool_enabled):
@udf(return_dtype=DataType.string())
class RepeatN:
def __init__(self, initial_n):
Expand All @@ -109,10 +135,37 @@ def __init__(self, initial_n):
def __call__(self, data):
return Series.from_pylist([d.as_py() * self.n for d in data.to_arrow()])

if actor_pool_enabled:
RepeatN = RepeatN.with_concurrency(1)

with pytest.raises(TypeError, match="missing a required argument: 'initial_n'"):
RepeatN.with_init_args(wrong=5)


@pytest.mark.parametrize("concurrency", [1, 2, 4])
@pytest.mark.parametrize("actor_pool_enabled", [True], indirect=True)
def test_stateful_udf_concurrency(concurrency, actor_pool_enabled):
    """A class UDF given an explicit concurrency still doubles every input string.

    Checks both the resolved output field (name and dtype) and the
    materialized result for each parametrized concurrency value.
    """
    source_df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

    @udf(return_dtype=DataType.string(), batch_size=1)
    class RepeatN:
        def __init__(self):
            self.n = 2

        def __call__(self, data):
            repeated = [value.as_py() * self.n for value in data.to_arrow()]
            return Series.from_pylist(repeated)

    concurrent_udf = RepeatN.with_concurrency(concurrency)
    repeat_expr = concurrent_udf(col("a"))

    resolved_field = repeat_expr._to_field(source_df.schema())
    assert resolved_field.name == "a"
    assert resolved_field.dtype == DataType.string()

    selected = source_df.select(repeat_expr)
    assert selected.to_pydict() == {"a": ["foofoo", "barbar", "bazbaz"]}


def test_udf_kwargs():
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})

Expand Down Expand Up @@ -208,8 +261,8 @@ def full_udf(e_arg, val, kwarg_val=None, kwarg_ex=None):
full_udf()


def test_class_udf_initialization_error():
table = MicroPartition.from_pydict({"a": ["foo", "bar", "baz"]})
def test_class_udf_initialization_error(actor_pool_enabled):
df = daft.from_pydict({"a": ["foo", "bar", "baz"]})

@udf(return_dtype=DataType.string())
class IdentityWithInitError:
Expand All @@ -219,9 +272,16 @@ def __init__(self):
def __call__(self, data):
return data

if actor_pool_enabled:
IdentityWithInitError = IdentityWithInitError.with_concurrency(1)

expr = IdentityWithInitError(col("a"))
with pytest.raises(RuntimeError, match="UDF INIT ERROR"):
table.eval_expression_list([expr])
if actor_pool_enabled:
with pytest.raises(Exception):
df.select(expr).collect()
else:
with pytest.raises(RuntimeError, match="UDF INIT ERROR"):
df.select(expr).collect()


def test_udf_equality():
Expand Down
Loading