diff --git a/Cargo.lock b/Cargo.lock index 214c1790bf..324b861e8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1835,6 +1835,7 @@ version = "0.3.0-dev0" dependencies = [ "async-stream", "async-trait", + "common-daft-config", "common-display", "common-error", "common-tracing", diff --git a/daft/context.py b/daft/context.py index 872985dc05..f286c77c7b 100644 --- a/daft/context.py +++ b/daft/context.py @@ -288,6 +288,7 @@ def set_execution_config( read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, + default_morsel_size: int | None = None, ) -> DaftContext: """Globally sets various configuration parameters which control various aspects of Daft execution. These configuration values are used when a Dataframe is executed (e.g. calls to `.write_*`, `.collect()` or `.show()`) @@ -323,6 +324,7 @@ def set_execution_config( read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB enable_aqe: Enables Adaptive Query Execution, Defaults to False enable_native_executor: Enables new local executor. Defaults to False + default_morsel_size: Default size of morsels used for the new local executor. Defaults to 131072 rows. """ # Replace values in the DaftExecutionConfig with user-specified overrides ctx = get_context() @@ -346,6 +348,7 @@ def set_execution_config( read_sql_partition_size_bytes=read_sql_partition_size_bytes, enable_aqe=enable_aqe, enable_native_executor=enable_native_executor, + default_morsel_size=default_morsel_size, ) ctx._daft_execution_config = new_daft_execution_config diff --git a/daft/daft.pyi b/daft/daft.pyi index 8a023e61cf..bb701bd3fb 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1715,7 +1715,9 @@ class NativeExecutor: def from_logical_plan_builder( logical_plan_builder: LogicalPlanBuilder, ) -> NativeExecutor: ... - def run(self, psets: dict[str, list[PartitionT]]) -> Iterator[PyMicroPartition]: ... + def run( + self, psets: dict[str, list[PartitionT]], cfg: PyDaftExecutionConfig, results_buffer_size: int | None + ) -> Iterator[PyMicroPartition]: ... class PyDaftExecutionConfig: @staticmethod @@ -1739,6 +1741,7 @@ class PyDaftExecutionConfig: read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, + default_morsel_size: int | None = None, ) -> PyDaftExecutionConfig: ... @property def scan_tasks_min_size_bytes(self) -> int: ... @@ -1772,6 +1775,8 @@ class PyDaftExecutionConfig: def enable_aqe(self) -> bool: ... @property def enable_native_executor(self) -> bool: ... + @property + def default_morsel_size(self) -> int: ... class PyDaftPlanningConfig: @staticmethod diff --git a/daft/execution/native_executor.py b/daft/execution/native_executor.py index 3898c31406..3c790fbfb5 100644 --- a/daft/execution/native_executor.py +++ b/daft/execution/native_executor.py @@ -5,6 +5,7 @@ from daft.daft import ( NativeExecutor as _NativeExecutor, ) +from daft.daft import PyDaftExecutionConfig from daft.logical.builder import LogicalPlanBuilder from daft.runners.partitioning import ( MaterializedResult, @@ -25,10 +26,16 @@ def from_logical_plan_builder(cls, builder: LogicalPlanBuilder) -> NativeExecuto executor = _NativeExecutor.from_logical_plan_builder(builder._builder) return cls(executor) - def run(self, psets: dict[str, list[MaterializedResult[PartitionT]]]) -> Iterator[PyMaterializedResult]: + def run( + self, + psets: dict[str, list[MaterializedResult[PartitionT]]], + daft_execution_config: PyDaftExecutionConfig, + results_buffer_size: int | None, + ) -> Iterator[PyMaterializedResult]: from daft.runners.pyrunner import PyMaterializedResult psets_mp = {part_id: [part.vpartition()._micropartition for part in parts] for part_id, parts in psets.items()} return ( - PyMaterializedResult(MicroPartition._from_pymicropartition(part)) for part in self._executor.run(psets_mp) + PyMaterializedResult(MicroPartition._from_pymicropartition(part)) + for part in self._executor.run(psets_mp, daft_execution_config, results_buffer_size) ) diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 0820af6987..04f67b9eb2 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -195,7 +195,9 @@ def run_iter( logger.info("Using native executor") executor = NativeExecutor.from_logical_plan_builder(builder) results_gen = executor.run( - {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()} + {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}, + daft_execution_config, + results_buffer_size, ) yield from results_gen else: diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 0462f9118f..10d2b75cff 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -58,6 +58,7 @@ pub struct DaftExecutionConfig { pub read_sql_partition_size_bytes: usize, pub enable_aqe: bool, pub enable_native_executor: bool, + pub default_morsel_size: usize, } impl Default for DaftExecutionConfig { @@ -80,6 +81,7 @@ impl Default for DaftExecutionConfig { read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB enable_aqe: false, enable_native_executor: false, + default_morsel_size: 128 * 1024, } } } diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 0a55b82aa7..5ce219d1aa 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -114,6 +114,7 @@ impl PyDaftExecutionConfig { read_sql_partition_size_bytes: Option, enable_aqe: Option, enable_native_executor: Option, + default_morsel_size: Option, ) -> PyResult { let mut config = self.config.as_ref().clone(); @@ -173,6 +174,9 @@ impl PyDaftExecutionConfig { if let Some(enable_native_executor) = enable_native_executor { config.enable_native_executor = enable_native_executor; } + if let Some(default_morsel_size) = default_morsel_size { + config.default_morsel_size = default_morsel_size; + } Ok(PyDaftExecutionConfig { config: Arc::new(config), @@ -256,6 +260,10 @@ impl PyDaftExecutionConfig { fn enable_native_executor(&self) -> PyResult { Ok(self.config.enable_native_executor) } + #[getter] + fn default_morsel_size(&self) -> PyResult { + Ok(self.config.default_morsel_size) + } fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec,))> { let bin_data = bincode::serialize(self.config.as_ref()) diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index cc0c696805..07b9bc4682 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,6 +1,7 @@ [dependencies] async-stream = {workspace = true} async-trait = {workspace = true} +common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} common-tracing = {path = "../common/tracing", default-features = false} @@ -26,7 +27,7 @@ tokio = {workspace = true} tracing = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python", "daft-scan/python", "common-display/python"] +python = ["dep:pyo3", "common-daft-config/python", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python", "daft-scan/python", "common-display/python"] [package] edition = {workspace = true} diff --git a/src/daft-local-execution/src/intermediate_ops/buffer.rs b/src/daft-local-execution/src/intermediate_ops/buffer.rs new file mode 100644 index 0000000000..e0301d90ec --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/buffer.rs @@ -0,0 +1,78 @@ +use std::{collections::VecDeque, sync::Arc}; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; +use std::cmp::Ordering::*; + +pub struct OperatorBuffer { + pub buffer: VecDeque>, + pub curr_len: usize, + pub threshold: usize, +} + +impl OperatorBuffer { + pub fn new(threshold: usize) -> Self { + assert!(threshold > 0); + Self { + buffer: VecDeque::new(), + curr_len: 0, + threshold, + } + } + + pub fn push(&mut self, part: Arc) { + self.curr_len += part.len(); + self.buffer.push_back(part); + } + + pub fn try_clear(&mut self) -> Option>> { + match self.curr_len.cmp(&self.threshold) { + Less => None, + Equal => self.clear_all(), + Greater => Some(self.clear_enough()), + } + } + + fn clear_enough(&mut self) -> DaftResult> { + assert!(self.curr_len > self.threshold); + + let mut to_concat = Vec::with_capacity(self.buffer.len()); + let mut remaining = self.threshold; + + while remaining > 0 { + let part = self.buffer.pop_front().expect("Buffer should not be empty"); + let part_len = part.len(); + if part_len <= remaining { + remaining -= part_len; + to_concat.push(part); + } else { + let (head, tail) = part.split_at(remaining)?; + remaining = 0; + to_concat.push(Arc::new(head)); + self.buffer.push_front(Arc::new(tail)); + break; + } + } + assert_eq!(remaining, 0); + + self.curr_len -= self.threshold; + match to_concat.len() { + 1 => Ok(to_concat.pop().unwrap()), + _ => MicroPartition::concat(&to_concat.iter().map(|x| x.as_ref()).collect::>()) + .map(Arc::new), + } + } + + pub fn clear_all(&mut self) -> Option>> { + if self.buffer.is_empty() { + return None; + } + + let concated = + MicroPartition::concat(&self.buffer.iter().map(|x| x.as_ref()).collect::>()) + .map(Arc::new); + self.buffer.clear(); + self.curr_len = 0; + Some(concated) + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 0dbe69f6af..ecb7182fab 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -17,7 +17,8 @@ use crate::{ ExecutionRuntimeHandle, NUM_CPUS, }; -use super::state::OperatorTaskState; +use super::buffer::OperatorBuffer; + pub trait IntermediateOperator: Send + Sync { fn execute(&self, input: &Arc) -> DaftResult>; fn name(&self) -> &'static str; @@ -61,19 +62,12 @@ impl IntermediateNode { sender: SingleSender, rt_context: Arc, ) -> DaftResult<()> { - let mut state = OperatorTaskState::new(); let span = info_span!("IntermediateOp::execute"); let sender = CountingSender::new(sender, rt_context.clone()); while let Some(morsel) = receiver.recv().await { rt_context.mark_rows_received(morsel.len() as u64); let result = rt_context.in_span(&span, || op.execute(&morsel))?; - state.add(result); - if let Some(part) = state.try_clear() { - let _ = sender.send(part?).await; - } - } - if let Some(part) = state.clear() { - let _ = sender.send(part?).await; + let _ = sender.send(result).await; } Ok(()) } @@ -105,16 +99,31 @@ impl IntermediateNode { pub async fn send_to_workers( mut receiver: MultiReceiver, worker_senders: Vec, + morsel_size: usize, ) -> DaftResult<()> { let mut next_worker_idx = 0; + let mut send_to_next_worker = |morsel: Arc| { + let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); + next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); + next_worker_sender.send(morsel) + }; + let mut buffer = OperatorBuffer::new(morsel_size); + while let Some(morsel) = receiver.recv().await { - if morsel.is_empty() { - continue; + buffer.push(morsel); + if let Some(ready) = buffer.try_clear() { + let _ = send_to_next_worker(ready?).await; } + } - let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); - let _ = next_worker_sender.send(morsel).await; - next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); + // Buffer may still have some morsels left above the threshold + while let Some(ready) = buffer.try_clear() { + let _ = send_to_next_worker(ready?).await; + } + + // Clear all remaining morsels + if let Some(last_morsel) = buffer.clear_all() { + let _ = send_to_next_worker(last_morsel?).await; } Ok(()) } @@ -169,7 +178,14 @@ impl PipelineNode for IntermediateNode { child.start(sender, runtime_handle).await?; let worker_senders = self.spawn_workers(&mut destination, runtime_handle).await; - runtime_handle.spawn(Self::send_to_workers(receiver, worker_senders), self.name()); + runtime_handle.spawn( + Self::send_to_workers( + receiver, + worker_senders, + runtime_handle.default_morsel_size(), + ), + self.intermediate_op.name(), + ); Ok(()) } fn as_tree_display(&self) -> &dyn TreeDisplay { diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 31e9f65ab1..290f4bf184 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,5 +1,5 @@ pub mod aggregate; +pub mod buffer; pub mod filter; pub mod intermediate_op; pub mod project; -pub mod state; diff --git a/src/daft-local-execution/src/intermediate_ops/state.rs b/src/daft-local-execution/src/intermediate_ops/state.rs deleted file mode 100644 index 2e145cb29e..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/state.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::sync::Arc; - -use common_error::DaftResult; -use daft_micropartition::MicroPartition; - -use crate::DEFAULT_MORSEL_SIZE; - -/// State of an operator task, used to buffer data and output it when a threshold is reached. -pub struct OperatorTaskState { - pub buffer: Vec>, - pub curr_len: usize, - pub threshold: usize, -} - -impl OperatorTaskState { - pub fn new() -> Self { - Self { - buffer: vec![], - curr_len: 0, - threshold: DEFAULT_MORSEL_SIZE, - } - } - - // Add a micro partition to the buffer. - pub fn add(&mut self, part: Arc) { - self.curr_len += part.len(); - self.buffer.push(part); - } - - // Try to clear the buffer if the threshold is reached. - pub fn try_clear(&mut self) -> Option>> { - if self.curr_len >= self.threshold { - self.clear() - } else { - None - } - } - - // Clear the buffer and return the concatenated MicroPartition. - pub fn clear(&mut self) -> Option>> { - if self.buffer.is_empty() { - return None; - } - - let concated = - MicroPartition::concat(&self.buffer.iter().map(|x| x.as_ref()).collect::>()) - .map(Arc::new); - self.buffer.clear(); - self.curr_len = 0; - Some(concated) - } -} diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index a8c7545de6..92e42e6fe8 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -16,19 +16,15 @@ lazy_static! { } pub struct ExecutionRuntimeHandle { - pub worker_set: tokio::task::JoinSet>, -} - -impl Default for ExecutionRuntimeHandle { - fn default() -> Self { - Self::new() - } + worker_set: tokio::task::JoinSet>, + default_morsel_size: usize, } impl ExecutionRuntimeHandle { - pub fn new() -> Self { + pub fn new(default_morsel_size: usize) -> Self { Self { worker_set: tokio::task::JoinSet::new(), + default_morsel_size, } } pub fn spawn( @@ -48,9 +44,11 @@ impl ExecutionRuntimeHandle { pub async fn shutdown(&mut self) { self.worker_set.shutdown().await; } -} -const DEFAULT_MORSEL_SIZE: usize = 1000; + pub fn default_morsel_size(&self) -> usize { + self.default_morsel_size + } +} #[cfg(feature = "python")] use pyo3::prelude::*; diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index e4c854f9de..1e8a03ba24 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -1,3 +1,8 @@ +use common_daft_config::DaftExecutionConfig; +use common_error::DaftResult; +use common_tracing::refresh_chrome_trace; +use daft_micropartition::MicroPartition; +use daft_physical_plan::{translate, LocalPhysicalPlan}; use std::{ collections::HashMap, fs::File, @@ -9,22 +14,18 @@ use std::{ time::{SystemTime, UNIX_EPOCH}, }; -use common_error::DaftResult; -use common_tracing::refresh_chrome_trace; -use daft_micropartition::MicroPartition; -use daft_physical_plan::{translate, LocalPhysicalPlan}; - #[cfg(feature = "python")] use { + common_daft_config::PyDaftExecutionConfig, daft_micropartition::python::PyMicroPartition, daft_plan::PyLogicalPlanBuilder, pyo3::{pyclass, pymethods, IntoPy, PyObject, PyRef, PyRefMut, PyResult, Python}, }; use crate::{ - channel::create_channel, + channel::{create_channel, create_single_channel, SingleReceiver}, pipeline::{physical_plan_to_pipeline, viz_pipeline}, - Error, ExecutionRuntimeHandle, + Error, ExecutionRuntimeHandle, NUM_CPUS, }; #[cfg(feature = "python")] @@ -71,6 +72,8 @@ impl NativeExecutor { &self, py: Python, psets: HashMap>, + cfg: PyDaftExecutionConfig, + results_buffer_size: Option, ) -> PyResult { let native_psets: HashMap>> = psets .into_iter() @@ -84,7 +87,14 @@ impl NativeExecutor { ) }) .collect(); - let out = py.allow_threads(|| run_local(&self.local_physical_plan, native_psets))?; + let out = py.allow_threads(|| { + run_local( + &self.local_physical_plan, + native_psets, + cfg.config, + results_buffer_size, + ) + })?; let iter = Box::new(out.map(|part| { part.map(|p| pyo3::Python::with_gil(|py| PyMicroPartition::from(p).into_py(py))) })); @@ -107,53 +117,90 @@ fn should_enable_explain_analyze() -> bool { pub fn run_local( physical_plan: &LocalPhysicalPlan, psets: HashMap>>, + cfg: Arc, + results_buffer_size: Option, ) -> DaftResult>> + Send>> { refresh_chrome_trace(); - let runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .max_blocking_threads(10) - .thread_name_fn(|| { - static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); - let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); - format!("Executor-Worker-{}", id) + let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; + let (tx, rx) = create_single_channel(results_buffer_size.unwrap_or(1)); + let handle = std::thread::spawn(move || { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .max_blocking_threads(10) + .thread_name_fn(|| { + static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("Executor-Worker-{}", id) + }) + .build() + .expect("Failed to create tokio runtime"); + runtime.block_on(async { + let (sender, mut receiver) = create_channel(*NUM_CPUS, true); + + let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); + pipeline.start(sender, &mut runtime_handle).await?; + while let Some(val) = receiver.recv().await { + let _ = tx.send(val).await; + } + + while let Some(result) = runtime_handle.join_next().await { + match result { + Ok(Err(e)) => { + runtime_handle.shutdown().await; + return DaftResult::Err(e.into()); + } + Err(e) => { + runtime_handle.shutdown().await; + return DaftResult::Err(Error::JoinError { source: e }.into()); + } + _ => {} + } + } + if should_enable_explain_analyze() { + let curr_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(); + let file_name = format!("explain-analyze-{}-mermaid.md", curr_ms); + let mut file = File::create(file_name)?; + writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; + } + Ok(()) }) - .build() - .expect("Failed to create tokio runtime"); + }); - let res = runtime.block_on(async { - let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; - let (sender, mut receiver) = create_channel(1, true); + struct ReceiverIterator { + receiver: SingleReceiver, + handle: Option>>, + } - let mut runtime_handle = ExecutionRuntimeHandle::default(); - pipeline.start(sender, &mut runtime_handle).await?; - let mut result = vec![]; - while let Some(val) = receiver.recv().await { - result.push(Ok(val)); - } + impl Iterator for ReceiverIterator { + type Item = DaftResult>; - while let Some(result) = runtime_handle.join_next().await { - match result { - Ok(Err(e)) => { - runtime_handle.shutdown().await; - return DaftResult::Err(e.into()); - } - Err(e) => { - runtime_handle.shutdown().await; - return DaftResult::Err(Error::JoinError { source: e }.into()); + fn next(&mut self) -> Option { + match self.receiver.blocking_recv() { + Some(part) => Some(Ok(part)), + None => { + if self.handle.is_some() { + let join_result = self + .handle + .take() + .unwrap() + .join() + .expect("Execution engine thread panicked"); + match join_result { + Ok(_) => None, + Err(e) => Some(Err(e)), + } + } else { + None + } } - _ => {} } } - if should_enable_explain_analyze() { - let curr_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_millis(); - let file_name = format!("explain-analyze-{}-mermaid.md", curr_ms); - let mut file = File::create(file_name)?; - writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; - } - Ok(result.into_iter()) - }); - Ok(Box::new(res?)) + } + Ok(Box::new(ReceiverIterator { + receiver: rx, + handle: Some(handle), + })) } diff --git a/src/daft-local-execution/src/sinks/hash_join.rs b/src/daft-local-execution/src/sinks/hash_join.rs index c0cb2c36fc..eabe56024e 100644 --- a/src/daft-local-execution/src/sinks/hash_join.rs +++ b/src/daft-local-execution/src/sinks/hash_join.rs @@ -367,7 +367,11 @@ impl PipelineNode for HashJoinNode { .spawn_workers(&mut destination, runtime_handle) .await; runtime_handle.spawn( - IntermediateNode::send_to_workers(streaming_receiver, worker_senders), + IntermediateNode::send_to_workers( + streaming_receiver, + worker_senders, + runtime_handle.default_morsel_size(), + ), self.name(), ); Ok(()) diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 6f62e3567c..b82649713d 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use crate::{ channel::{MultiSender, SingleSender}, runtime_stats::{CountingSender, RuntimeStatsContext}, - ExecutionRuntimeHandle, DEFAULT_MORSEL_SIZE, + ExecutionRuntimeHandle, }; use super::source::{Source, SourceStream}; @@ -69,7 +69,7 @@ impl Source for ScanTaskSource { runtime_stats: Arc, io_stats: IOStatsRef, ) -> crate::Result<()> { - let morsel_size = DEFAULT_MORSEL_SIZE; + let morsel_size = runtime_handle.default_morsel_size(); let maintain_order = destination.in_order(); for scan_task in self.scan_tasks.clone() { let sender = destination.get_next_sender(); diff --git a/src/daft-micropartition/src/ops/slice.rs b/src/daft-micropartition/src/ops/slice.rs index e9e6fa2161..fa6cb858f1 100644 --- a/src/daft-micropartition/src/ops/slice.rs +++ b/src/daft-micropartition/src/ops/slice.rs @@ -34,7 +34,7 @@ impl MicroPartition { if offset_so_far == 0 && rows_needed >= tab_rows { slices_tables.push(tab.clone()); - rows_needed += tab_rows; + rows_needed -= tab_rows; } else { let new_end = (rows_needed + offset_so_far).min(tab_rows); let sliced = tab.slice(offset_so_far, new_end)?; @@ -54,4 +54,8 @@ impl MicroPartition { pub fn head(&self, num: usize) -> DaftResult { self.slice(0, num) } + + pub fn split_at(&self, idx: usize) -> DaftResult<(Self, Self)> { + Ok((self.head(idx)?, self.slice(idx, self.len())?)) + } }