diff --git a/Cargo.lock b/Cargo.lock index d1c64e69b2..c2844d6da0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -904,8 +904,10 @@ dependencies = [ name = "common-daft-config" version = "0.1.10" dependencies = [ + "bincode", "lazy_static", "pyo3", + "serde", ] [[package]] diff --git a/daft/context.py b/daft/context.py index 2377521db3..9954b29260 100644 --- a/daft/context.py +++ b/daft/context.py @@ -53,27 +53,25 @@ def _get_runner_config_from_env() -> _RunnerConfig: raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner}") -# Global Runner singleton, initialized when accessed through the DaftContext -_RUNNER: Runner | None = None - - -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass class DaftContext: """Global context for the current Daft execution environment""" daft_config: PyDaftConfig = PyDaftConfig() runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env) disallow_set_runner: bool = False + _runner: Runner | None = None def runner(self) -> Runner: - global _RUNNER - if _RUNNER is not None: - return _RUNNER + if self._runner is not None: + return self._runner + if self.runner_config.name == "ray": from daft.runners.ray_runner import RayRunner assert isinstance(self.runner_config, _RayRunnerConfig) - _RUNNER = RayRunner( + self._runner = RayRunner( + daft_config=self.daft_config, address=self.runner_config.address, max_task_backlog=self.runner_config.max_task_backlog, ) @@ -94,20 +92,16 @@ def runner(self) -> Runner: pass assert isinstance(self.runner_config, _PyRunnerConfig) - _RUNNER = PyRunner(use_thread_pool=self.runner_config.use_thread_pool) + self._runner = PyRunner(daft_config=self.daft_config, use_thread_pool=self.runner_config.use_thread_pool) else: raise NotImplementedError(f"Runner config implemented: {self.runner_config.name}") # Mark DaftContext as having the runner set, which prevents any subsequent setting of the config # after the runner has been initialized once - global _DaftContext - _DaftContext = dataclasses.replace( - _DaftContext, - disallow_set_runner=True, - ) + self.disallow_set_runner = True - return _RUNNER + return self._runner @property def is_ray_runner(self) -> bool: @@ -121,11 +115,24 @@ def get_context() -> DaftContext: return _DaftContext -def _set_context(ctx: DaftContext): +def set_context(ctx: DaftContext) -> DaftContext: global _DaftContext + pop_context() _DaftContext = ctx + return _DaftContext + + +def pop_context() -> DaftContext: + """Helper used in tests and test fixtures to clear the global runner and allow for re-setting of configs.""" + global _DaftContext + + old_daft_context = _DaftContext + _DaftContext = DaftContext() + + return old_daft_context + def set_runner_ray( address: str | None = None, @@ -150,24 +157,21 @@ def set_runner_ray( Returns: DaftContext: Daft context after setting the Ray runner """ - old_ctx = get_context() - if old_ctx.disallow_set_runner: + ctx = get_context() + if ctx.disallow_set_runner: if noop_if_initialized: warnings.warn( "Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call." 
) - return old_ctx + return ctx raise RuntimeError("Cannot set runner more than once") - new_ctx = dataclasses.replace( - old_ctx, - runner_config=_RayRunnerConfig( - address=address, - max_task_backlog=max_task_backlog, - ), - disallow_set_runner=True, + + ctx.runner_config = _RayRunnerConfig( + address=address, + max_task_backlog=max_task_backlog, ) - _set_context(new_ctx) - return new_ctx + ctx.disallow_set_runner = True + return ctx def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext: @@ -178,25 +182,25 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext: Returns: DaftContext: Daft context after setting the Py runner """ - old_ctx = get_context() - if old_ctx.disallow_set_runner: + ctx = get_context() + if ctx.disallow_set_runner: raise RuntimeError("Cannot set runner more than once") - new_ctx = dataclasses.replace( - old_ctx, - runner_config=_PyRunnerConfig(use_thread_pool=use_thread_pool), - disallow_set_runner=True, - ) - _set_context(new_ctx) - return new_ctx + + ctx.runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool) + ctx.disallow_set_runner = True + return ctx def set_config( + config: PyDaftConfig | None = None, merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, ) -> DaftContext: """Globally sets various configuration parameters which control various aspects of Daft execution Args: + config: A PyDaftConfig object to set the config to, before applying other kwargs. Defaults to None which indicates + that the old (current) config should be used. merge_scan_tasks_min_size_bytes: Minimum size in bytes when merging ScanTasks when reading files from storage. Increasing this value will make Daft perform more merging of files into a single partition before yielding, which leads to bigger but fewer partitions. (Defaults to 64MB) @@ -204,18 +208,19 @@ def set_config( Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but fewer partitions. (Defaults to 512MB) """ - old_ctx = get_context() + ctx = get_context() + if ctx.disallow_set_runner: + raise RuntimeError( + "Cannot call `set_config` after the runner has already been created. " + "Please call `set_config` before any calls to set the runner and before any dataframe creation or execution." 
+ ) # Replace values in the DaftConfig with user-specified overrides - old_daft_config = old_ctx.daft_config + old_daft_config = ctx.daft_config if config is None else config new_daft_config = old_daft_config.with_config_values( merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes, merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes, ) - new_ctx = dataclasses.replace( - old_ctx, - daft_config=new_daft_config, - ) - _set_context(new_ctx) - return new_ctx + ctx.daft_config = new_daft_config + return ctx diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index be4dc16feb..55076c2156 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -8,11 +8,11 @@ import psutil -from daft.context import get_context from daft.daft import ( FileFormatConfig, FileInfos, IOConfig, + PyDaftConfig, ResourceRequest, StorageConfig, ) @@ -105,8 +105,9 @@ def get_schema_from_first_filepath( class PyRunner(Runner[MicroPartition]): - def __init__(self, use_thread_pool: bool | None) -> None: + def __init__(self, daft_config: PyDaftConfig, use_thread_pool: bool | None) -> None: super().__init__() + self.daft_config = daft_config self._use_thread_pool: bool = use_thread_pool if use_thread_pool is not None else True self.num_cpus = multiprocessing.cpu_count() @@ -132,13 +133,11 @@ def run_iter( # NOTE: PyRunner does not run any async execution, so it ignores `results_buffer_size` which is essentially 0 results_buffer_size: int | None = None, ) -> Iterator[PyMaterializedResult]: - daft_config = get_context().daft_config - # Optimize the logical plan. builder = builder.optimize() # Finalize the logical plan and get a physical plan scheduler for translating the # physical plan to executable tasks. - plan_scheduler = builder.to_physical_plan_scheduler(daft_config) + plan_scheduler = builder.to_physical_plan_scheduler(self.daft_config) psets = { key: entry.value.values() for key, entry in self._part_set_cache._uuid_to_partition_set.items() diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 14fceae5fc..b845296c7a 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -11,7 +11,7 @@ import pyarrow as pa -from daft.context import get_context +from daft.context import set_config from daft.logical.builder import LogicalPlanBuilder from daft.plan_scheduler import PhysicalPlanScheduler from daft.runners.progress_bar import ProgressBar @@ -30,6 +30,7 @@ FileFormatConfig, FileInfos, IOConfig, + PyDaftConfig, ResourceRequest, StorageConfig, ) @@ -75,10 +76,13 @@ @ray.remote def _glob_path_into_file_infos( + daft_config: PyDaftConfig, paths: list[str], file_format_config: FileFormatConfig | None, io_config: IOConfig | None, ) -> MicroPartition: + set_config(daft_config) + file_infos = FileInfos() file_format = file_format_config.file_format() if file_format_config is not None else None for path in paths: @@ -91,7 +95,9 @@ def _glob_path_into_file_infos( @ray.remote -def _make_ray_block_from_vpartition(partition: MicroPartition) -> RayDatasetBlock: +def _make_ray_block_from_vpartition(daft_config: PyDaftConfig, partition: MicroPartition) -> RayDatasetBlock: + set_config(daft_config) + try: return partition.to_arrow(cast_tensors_to_ray_tensor_dtype=True) except pa.ArrowInvalid: @@ -100,15 +106,20 @@ def _make_ray_block_from_vpartition(partition: MicroPartition) -> RayDatasetBloc @ray.remote def _make_daft_partition_from_ray_dataset_blocks( - ray_dataset_block: pa.MicroPartition, daft_schema: Schema + daft_config: PyDaftConfig, 
ray_dataset_block: pa.MicroPartition, daft_schema: Schema ) -> MicroPartition: + set_config(daft_config) + return MicroPartition.from_arrow(ray_dataset_block) @ray.remote(num_returns=2) def _make_daft_partition_from_dask_dataframe_partitions( + daft_config: PyDaftConfig, dask_df_partition: pd.DataFrame, ) -> tuple[MicroPartition, pa.Schema]: + set_config(daft_config) + vpart = MicroPartition.from_pandas(dask_df_partition) return vpart, vpart.schema() @@ -127,17 +138,21 @@ def _to_pandas_ref(df: pd.DataFrame | ray.ObjectRef[pd.DataFrame]) -> ray.Object @ray.remote def sample_schema_from_filepath( + daft_config: PyDaftConfig, first_file_path: str, file_format_config: FileFormatConfig, storage_config: StorageConfig, ) -> Schema: """Ray remote function to run schema sampling on top of a MicroPartition containing a single filepath""" + set_config(daft_config) + # Currently just samples the Schema from the first file return runner_io.sample_schema(first_file_path, file_format_config, storage_config) @dataclass class RayPartitionSet(PartitionSet[ray.ObjectRef]): + _daft_config_objref: ray.ObjectRef _results: dict[PartID, RayMaterializedResult] def items(self) -> list[tuple[PartID, ray.ObjectRef]]: @@ -156,7 +171,10 @@ def to_ray_dataset(self) -> RayDataset: "Unable to import `ray.data.from_arrow_refs`. Please ensure that you have a compatible version of Ray >= 1.10 installed." ) - blocks = [_make_ray_block_from_vpartition.remote(self._results[k].partition()) for k in self._results.keys()] + blocks = [ + _make_ray_block_from_vpartition.remote(self._daft_config_objref, self._results[k].partition()) + for k in self._results.keys() + ] # NOTE: although the Ray method is called `from_arrow_refs`, this method works also when the blocks are List[T] types # instead of Arrow tables as the codepath for Dataset creation is the same. return from_arrow_refs(blocks) @@ -205,6 +223,10 @@ def wait(self) -> None: class RayRunnerIO(runner_io.RunnerIO): + def __init__(self, daft_config_objref: ray.ObjectRef, *args, **kwargs): + self.daft_config_objref = daft_config_objref + super().__init__(*args, **kwargs) + def glob_paths_details( self, source_paths: list[str], @@ -213,7 +235,11 @@ def glob_paths_details( ) -> FileInfos: # Synchronously fetch the file infos, for now. 
return FileInfos.from_table( - ray.get(_glob_path_into_file_infos.remote(source_paths, file_format_config, io_config=io_config)) + ray.get( + _glob_path_into_file_infos.remote( + self.daft_config_objref, source_paths, file_format_config, io_config=io_config + ) + ) .to_table() ._table ) @@ -230,6 +256,7 @@ def get_schema_from_first_filepath( first_path = file_infos[0].file_path return ray.get( sample_schema_from_filepath.remote( + self.daft_config_objref, first_path, file_format_config, storage_config, @@ -267,9 +294,20 @@ def partition_set_from_ray_dataset( # NOTE: This materializes the entire Ray Dataset - we could make this more intelligent by creating a new RayDatasetScan node # which can iterate on Ray Dataset blocks and materialize as-needed daft_vpartitions = [ - _make_daft_partition_from_ray_dataset_blocks.remote(block, daft_schema) for block in block_refs + _make_daft_partition_from_ray_dataset_blocks.remote(self.daft_config_objref, block, daft_schema) + for block in block_refs ] - return RayPartitionSet({i: RayMaterializedResult(obj) for i, obj in enumerate(daft_vpartitions)}), daft_schema + + return ( + RayPartitionSet( + _daft_config_objref=self.daft_config_objref, + _results={ + i: RayMaterializedResult(obj, _daft_config_objref=self.daft_config_objref) + for i, obj in enumerate(daft_vpartitions) + }, + ), + daft_schema, + ) def partition_set_from_dask_dataframe( self, @@ -283,7 +321,9 @@ def partition_set_from_dask_dataframe( raise ValueError("Can't convert an empty Dask DataFrame (with no partitions) to a Daft DataFrame.") persisted_partitions = dask.persist(*partitions, scheduler=ray_dask_get) parts = [_to_pandas_ref(next(iter(part.dask.values()))) for part in persisted_partitions] - daft_vpartitions, schemas = zip(*map(_make_daft_partition_from_dask_dataframe_partitions.remote, parts)) + daft_vpartitions, schemas = zip( + *(_make_daft_partition_from_dask_dataframe_partitions.remote(self.daft_config_objref, p) for p in parts) + ) schemas = ray.get(list(schemas)) # Dask shouldn't allow inconsistent schemas across partitions, but we double-check here. 
if not all(schemas[0] == schema for schema in schemas[1:]): @@ -291,7 +331,16 @@ def partition_set_from_dask_dataframe( "Can't convert a Dask DataFrame with inconsistent schemas across partitions to a Daft DataFrame:", schemas, ) - return RayPartitionSet({i: RayMaterializedResult(obj) for i, obj in enumerate(daft_vpartitions)}), schemas[0] + return ( + RayPartitionSet( + _daft_config_objref=self.daft_config_objref, + _results={ + i: RayMaterializedResult(obj, _daft_config_objref=self.daft_config_objref) + for i, obj in enumerate(daft_vpartitions) + }, + ), + schemas[0], + ) def _get_ray_task_options(resource_request: ResourceRequest) -> dict[str, Any]: @@ -328,38 +377,43 @@ def build_partitions( @ray.remote def single_partition_pipeline( - instruction_stack: list[Instruction], *inputs: MicroPartition + daft_config: PyDaftConfig, instruction_stack: list[Instruction], *inputs: MicroPartition ) -> list[list[PartitionMetadata] | MicroPartition]: + set_config(daft_config) return build_partitions(instruction_stack, *inputs) @ray.remote def fanout_pipeline( - instruction_stack: list[Instruction], *inputs: MicroPartition + daft_config: PyDaftConfig, instruction_stack: list[Instruction], *inputs: MicroPartition ) -> list[list[PartitionMetadata] | MicroPartition]: + set_config(daft_config) return build_partitions(instruction_stack, *inputs) @ray.remote(scheduling_strategy="SPREAD") def reduce_pipeline( - instruction_stack: list[Instruction], inputs: list + daft_config: PyDaftConfig, instruction_stack: list[Instruction], inputs: list ) -> list[list[PartitionMetadata] | MicroPartition]: import ray + set_config(daft_config) return build_partitions(instruction_stack, *ray.get(inputs)) @ray.remote(scheduling_strategy="SPREAD") def reduce_and_fanout( - instruction_stack: list[Instruction], inputs: list + daft_config: PyDaftConfig, instruction_stack: list[Instruction], inputs: list ) -> list[list[PartitionMetadata] | MicroPartition]: import ray + set_config(daft_config) return build_partitions(instruction_stack, *ray.get(inputs)) @ray.remote -def get_meta(partition: MicroPartition) -> PartitionMetadata: +def get_meta(daft_config: PyDaftConfig, partition: MicroPartition) -> PartitionMetadata: + set_config(daft_config) return PartitionMetadata.from_table(partition) @@ -386,7 +440,7 @@ def _ray_num_cpus_provider(ttl_seconds: int = 1) -> Generator[int, None, None]: class Scheduler: - def __init__(self, max_task_backlog: int | None, use_ray_tqdm: bool) -> None: + def __init__(self, daft_config_objref: ray.ObjectRef, max_task_backlog: int | None, use_ray_tqdm: bool) -> None: """ max_task_backlog: Max number of inflight tasks waiting for cores. 
""" @@ -403,6 +457,7 @@ def __init__(self, max_task_backlog: int | None, use_ray_tqdm: bool) -> None: self.reserved_cores = 0 + self.daft_config_objref = daft_config_objref self.threads_by_df: dict[str, threading.Thread] = dict() self.results_by_df: dict[str, Queue] = {} self.active_by_df: dict[str, bool] = dict() @@ -522,7 +577,10 @@ def place_in_queue(item): logger.debug("Running task synchronously in main thread: %s", next_step) assert isinstance(next_step, SingleOutputPartitionTask) next_step.set_result( - [RayMaterializedResult(partition) for partition in next_step.inputs] + [ + RayMaterializedResult(partition, _daft_config_objref=self.daft_config_objref) + for partition in next_step.inputs + ] ) next_step = next(tasks) @@ -542,7 +600,7 @@ def place_in_queue(item): break for task in tasks_to_dispatch: - results = _build_partitions(task) + results = _build_partitions(self.daft_config_objref, task) logger.debug("%s -> %s", task, results) inflight_tasks[task.id()] = task for result in results: @@ -621,7 +679,7 @@ def __init__(self, *n, **kw) -> None: self.reserved_cores = 1 -def _build_partitions(task: PartitionTask[ray.ObjectRef]) -> list[ray.ObjectRef]: +def _build_partitions(daft_config_objref: ray.ObjectRef, task: PartitionTask[ray.ObjectRef]) -> list[ray.ObjectRef]: """Run a PartitionTask and return the resulting list of partitions.""" ray_options: dict[str, Any] = { "num_returns": task.num_results + 1, @@ -633,17 +691,27 @@ def _build_partitions(task: PartitionTask[ray.ObjectRef]) -> list[ray.ObjectRef] if isinstance(task.instructions[0], ReduceInstruction): build_remote = reduce_and_fanout if isinstance(task.instructions[-1], FanoutInstruction) else reduce_pipeline build_remote = build_remote.options(**ray_options) - [metadatas_ref, *partitions] = build_remote.remote(task.instructions, task.inputs) + [metadatas_ref, *partitions] = build_remote.remote(daft_config_objref, task.instructions, task.inputs) else: build_remote = ( fanout_pipeline if isinstance(task.instructions[-1], FanoutInstruction) else single_partition_pipeline ) build_remote = build_remote.options(**ray_options) - [metadatas_ref, *partitions] = build_remote.remote(task.instructions, *task.inputs) + [metadatas_ref, *partitions] = build_remote.remote(daft_config_objref, task.instructions, *task.inputs) metadatas_accessor = PartitionMetadataAccessor(metadatas_ref) - task.set_result([RayMaterializedResult(partition, metadatas_accessor, i) for i, partition in enumerate(partitions)]) + task.set_result( + [ + RayMaterializedResult( + _partition=partition, + _daft_config_objref=daft_config_objref, + _metadatas=metadatas_accessor, + _metadata_index=i, + ) + for i, partition in enumerate(partitions) + ] + ) return partitions @@ -651,6 +719,7 @@ def _build_partitions(task: PartitionTask[ray.ObjectRef]) -> list[ray.ObjectRef] class RayRunner(Runner[ray.ObjectRef]): def __init__( self, + daft_config: PyDaftConfig, address: str | None, max_task_backlog: int | None, ) -> None: @@ -659,13 +728,19 @@ def __init__( logger.warning(f"Ray has already been initialized, Daft will reuse the existing Ray context.") self.ray_context = ray.init(address=address, ignore_reinit_error=True) + # We put a frozen copy of the Daft config into the cluster to be used across all subsequent Daft function calls + self.daft_config_objref = ray.put(daft_config) + self.daft_config = daft_config + if isinstance(self.ray_context, ray.client_builder.ClientContext): # Run scheduler remotely if the cluster is connected remotely. 
self.scheduler_actor = SchedulerActor.remote( # type: ignore - max_task_backlog=max_task_backlog, use_ray_tqdm=True + daft_config_objref=self.daft_config_objref, max_task_backlog=max_task_backlog, use_ray_tqdm=True ) else: - self.scheduler = Scheduler(max_task_backlog=max_task_backlog, use_ray_tqdm=False) + self.scheduler = Scheduler( + daft_config_objref=self.daft_config_objref, max_task_backlog=max_task_backlog, use_ray_tqdm=False + ) def active_plans(self) -> list[str]: if isinstance(self.ray_context, ray.client_builder.ClientContext): @@ -676,13 +751,12 @@ def active_plans(self) -> list[str]: def run_iter( self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None ) -> Iterator[RayMaterializedResult]: - daft_config = get_context().daft_config - # Optimize the logical plan. builder = builder.optimize() + # Finalize the logical plan and get a physical plan scheduler for translating the # physical plan to executable tasks. - plan_scheduler = builder.to_physical_plan_scheduler(daft_config) + plan_scheduler = builder.to_physical_plan_scheduler(self.daft_config) psets = { key: entry.value.values() @@ -699,7 +773,6 @@ def run_iter( results_buffer_size=results_buffer_size, ) ) - else: self.scheduler.start_plan( plan_scheduler=plan_scheduler, @@ -707,6 +780,7 @@ def run_iter( result_uuid=result_uuid, results_buffer_size=results_buffer_size, ) + try: while True: if isinstance(self.ray_context, ray.client_builder.ClientContext): @@ -734,7 +808,7 @@ def run_iter_tables( yield ray.get(result.partition()) def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry: - result_pset = RayPartitionSet({}) + result_pset = RayPartitionSet(_daft_config_objref=self.daft_config_objref, _results={}) results_iter = self.run_iter(builder) @@ -747,17 +821,24 @@ def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry: def put_partition_set_into_cache(self, pset: PartitionSet) -> PartitionCacheEntry: if isinstance(pset, LocalPartitionSet): - pset = RayPartitionSet({pid: RayMaterializedResult(ray.put(val)) for pid, val in pset._partitions.items()}) + pset = RayPartitionSet( + _daft_config_objref=self.daft_config_objref, + _results={ + pid: RayMaterializedResult(ray.put(val), _daft_config_objref=self.daft_config_objref) + for pid, val in pset._partitions.items() + }, + ) return self._part_set_cache.put_partition_set(pset=pset) def runner_io(self) -> RayRunnerIO: - return RayRunnerIO() + return RayRunnerIO(daft_config_objref=self.daft_config_objref) @dataclass(frozen=True) class RayMaterializedResult(MaterializedResult[ray.ObjectRef]): _partition: ray.ObjectRef + _daft_config_objref: ray.ObjectRef _metadatas: PartitionMetadataAccessor | None = None _metadata_index: int | None = None @@ -771,7 +852,7 @@ def metadata(self) -> PartitionMetadata: if self._metadatas is not None and self._metadata_index is not None: return self._metadatas.get_index(self._metadata_index) else: - return ray.get(get_meta.remote(self._partition)) + return ray.get(get_meta.remote(self._daft_config_objref, self._partition)) def cancel(self) -> None: return ray.cancel(self._partition) diff --git a/src/common/daft-config/Cargo.toml b/src/common/daft-config/Cargo.toml index bbd9878d6c..6d759e8d8e 100644 --- a/src/common/daft-config/Cargo.toml +++ b/src/common/daft-config/Cargo.toml @@ -1,6 +1,8 @@ [dependencies] +bincode = {workspace = true} lazy_static = {workspace = true} pyo3 = {workspace = true, optional = true} +serde = {workspace = true} [features] default = ["python"] diff --git 
a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 77a199d867..b71280ee36 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -1,4 +1,6 @@ -#[derive(Clone)] +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Serialize, Deserialize)] pub struct DaftConfig { pub merge_scan_tasks_min_size_bytes: usize, pub merge_scan_tasks_max_size_bytes: usize, diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 3dd76aad98..f327e5a55d 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -1,11 +1,12 @@ use std::sync::Arc; -use pyo3::prelude::*; +use pyo3::{prelude::*, PyTypeInfo}; +use serde::{Deserialize, Serialize}; use crate::DaftConfig; -#[derive(Clone, Default)] -#[pyclass] +#[derive(Clone, Default, Serialize, Deserialize)] +#[pyclass(module = "daft.daft")] pub struct PyDaftConfig { pub config: Arc<DaftConfig>, } @@ -45,4 +46,24 @@ impl PyDaftConfig { fn get_merge_scan_tasks_max_size_bytes(&self) -> PyResult<usize> { Ok(self.config.merge_scan_tasks_max_size_bytes) } + + fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec<u8>,))> { + let bin_data = bincode::serialize(self.config.as_ref()) + .expect("DaftConfig should be serializable to bytes"); + Ok(( + Self::type_object(py) + .getattr("_from_serialized")? + .to_object(py), + (bin_data,), + )) + } + + #[staticmethod] + fn _from_serialized(bin_data: Vec<u8>) -> PyResult<PyDaftConfig> { + let daft_config: DaftConfig = bincode::deserialize(bin_data.as_slice()) + .expect("DaftConfig should be deserializable from bytes"); + Ok(PyDaftConfig { + config: daft_config.into(), + }) + } } diff --git a/tests/conftest.py b/tests/conftest.py index 805b33e288..6516864390 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,10 @@ @pytest.fixture(scope="session", autouse=True) def set_configs(): """Sets global Daft config for testing""" + + # Pop the old context, which gets rid of the old Runner as well + daft.context.pop_context() + daft.context.set_config( # Disables merging of ScanTasks merge_scan_tasks_min_size_bytes=0, diff --git a/tests/io/test_merge_scan_tasks.py b/tests/io/test_merge_scan_tasks.py index 6c6e9c1ba8..3ed35eb467 100644 --- a/tests/io/test_merge_scan_tasks.py +++ b/tests/io/test_merge_scan_tasks.py @@ -10,9 +10,7 @@ @contextlib.contextmanager def override_merge_scan_tasks_configs(merge_scan_tasks_min_size_bytes: int, merge_scan_tasks_max_size_bytes: int): - config = daft.context.get_context().daft_config - original_merge_scan_tasks_min_size_bytes = config.merge_scan_tasks_min_size_bytes - original_merge_scan_tasks_max_size_bytes = config.merge_scan_tasks_max_size_bytes + old_context = daft.context.pop_context() try: daft.context.set_config( @@ -21,10 +19,7 @@ def override_merge_scan_tasks_configs(merge_scan_tasks_min_size_bytes: int, merg ) yield finally: - daft.context.set_config( - merge_scan_tasks_min_size_bytes=original_merge_scan_tasks_min_size_bytes, - merge_scan_tasks_max_size_bytes=original_merge_scan_tasks_max_size_bytes, - ) + daft.context.set_context(old_context) @pytest.fixture(scope="function")
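
Usage sketch for the reworked daft.context API introduced above (a hedged example against a build of this branch, not prescribed by the diff): configs must be set before the runner is created, and tests reset the global context with pop_context()/set_context() rather than swapping frozen contexts.

import daft

# Configs must be set before the runner exists: once set_runner_py()/set_runner_ray()
# (or a first execution) flips disallow_set_runner, set_config() raises a RuntimeError.
daft.context.set_config(
    merge_scan_tasks_min_size_bytes=0,  # disables ScanTask merging, as in tests/conftest.py
    merge_scan_tasks_max_size_bytes=0,
)
daft.context.set_runner_py()  # the runner later built from this context captures daft_config

# In tests, pop the context to drop the old runner and allow configs to be set again,
# then restore the previous context afterwards (as in override_merge_scan_tasks_configs).
old_ctx = daft.context.pop_context()
try:
    daft.context.set_config(merge_scan_tasks_min_size_bytes=64 * 1024 * 1024)
finally:
    daft.context.set_context(old_ctx)

On the Ray runner the same PyDaftConfig is ray.put() once and passed to every remote function, which is why PyDaftConfig gains __reduce__/_from_serialized via bincode in this change.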