diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 846c2bc9441b..f832cbfc57c8 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -57,6 +57,8 @@ unicode_expressions = ["datafusion-physical-expr/unicode_expressions", "datafusi ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } apache-avro = { version = "0.14", optional = true } arrow = { workspace = true } +arrow-schema = { workspace = true } +arrow-array = { workspace = true } async-compression = { version = "0.3.14", features = ["bzip2", "gzip", "xz", "zstd", "futures-io", "tokio"], optional = true } async-trait = "0.1.41" bytes = "1.4" diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs index 3507a5b224d2..7e8d600542fd 100644 --- a/datafusion/core/src/physical_plan/sorts/cursor.rs +++ b/datafusion/core/src/physical_plan/sorts/cursor.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +use crate::physical_plan::sorts::sort::SortOptions; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::ArrowNativeTypeOp; use arrow::row::{Row, Rows}; use std::cmp::Ordering; @@ -93,3 +96,232 @@ impl Cursor for RowCursor { t } } + +/// A cursor over sorted, nullable [`ArrowNativeTypeOp`] +/// +/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering +#[derive(Debug)] +pub struct PrimitiveCursor { + values: ScalarBuffer, + offset: usize, + // If nulls first, the first non-null index + // Otherwise, the first null index + null_threshold: usize, + options: SortOptions, +} + +impl PrimitiveCursor { + /// Create a new [`PrimitiveCursor`] from the provided `values` sorted according to `options` + pub fn new(options: SortOptions, values: ScalarBuffer, null_count: usize) -> Self { + assert!(null_count <= values.len()); + + let null_threshold = match options.nulls_first { + true => null_count, + false => values.len() - null_count, + }; + + Self { + values, + offset: 0, + null_threshold, + options, + } + } + + fn is_null(&self) -> bool { + (self.offset < self.null_threshold) == self.options.nulls_first + } + + fn value(&self) -> T { + self.values[self.offset] + } +} + +impl PartialEq for PrimitiveCursor { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } +} + +impl Eq for PrimitiveCursor {} +impl PartialOrd for PrimitiveCursor { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PrimitiveCursor { + fn cmp(&self, other: &Self) -> Ordering { + match (self.is_null(), other.is_null()) { + (true, true) => Ordering::Equal, + (true, false) => match self.options.nulls_first { + true => Ordering::Less, + false => Ordering::Greater, + }, + (false, true) => match self.options.nulls_first { + true => Ordering::Greater, + false => Ordering::Less, + }, + (false, false) => { + let s_v = self.value(); + let o_v = other.value(); + + match self.options.descending { + true => o_v.compare(s_v), + false => s_v.compare(o_v), + } + } + } + } +} + +impl Cursor for PrimitiveCursor { + fn is_finished(&self) -> bool { + self.offset == self.values.len() + } + + fn advance(&mut self) -> usize { + let t = self.offset; + self.offset += 1; + t + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_primitive_nulls_first() { + let options = SortOptions { + descending: false, + nulls_first: true, + }; + + let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]); + let mut a = PrimitiveCursor::new(options, buffer, 1); + let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]); + let mut b = PrimitiveCursor::new(options, buffer, 2); + + // NULL == NULL + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // NULL == NULL + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // NULL < -2 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 1 > -2 + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 1 > -1 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 1 == 1 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // 9 > 1 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 9 > 2 + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + let options = SortOptions { + descending: false, + nulls_first: false, + }; + + let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]); + let mut a = PrimitiveCursor::new(options, buffer, 2); + let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]); + let mut b = PrimitiveCursor::new(options, buffer, 2); + + // 0 > -1 + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 0 < NULL + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 1 < NULL + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // NULL = NULL + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + let options = SortOptions { + descending: true, + nulls_first: false, + }; + + let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]); + let mut a = PrimitiveCursor::new(options, buffer, 3); + let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]); + let mut b = PrimitiveCursor::new(options, buffer, 2); + + // 6 > 67 + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 6 < -3 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 6 < NULL + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 6 < NULL + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // NULL == NULL + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + let options = SortOptions { + descending: true, + nulls_first: true, + }; + + let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]); + let mut a = PrimitiveCursor::new(options, buffer, 2); + let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]); + let mut b = PrimitiveCursor::new(options, buffer, 1); + + // NULL == NULL + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // NULL == NULL + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Equal); + assert_eq!(a, b); + + // NULL < 4546 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + + // 6 > 4546 + a.advance(); + assert_eq!(a.cmp(&b), Ordering::Greater); + + // 6 < -3 + b.advance(); + assert_eq!(a.cmp(&b), Ordering::Less); + } +} diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs index 69c3fb82c2d2..d7be0c972d72 100644 --- a/datafusion/core/src/physical_plan/sorts/merge.rs +++ b/datafusion/core/src/physical_plan/sorts/merge.rs @@ -19,16 +19,31 @@ use crate::common::Result; use crate::physical_plan::metrics::MemTrackingMetrics; use crate::physical_plan::sorts::builder::BatchBuilder; use crate::physical_plan::sorts::cursor::Cursor; -use crate::physical_plan::sorts::stream::{PartitionedStream, RowCursorStream}; +use crate::physical_plan::sorts::stream::{ + PartitionedStream, PrimitiveCursorStream, RowCursorStream, +}; use crate::physical_plan::{ PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream, }; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use arrow_array::downcast_primitive; use futures::Stream; use std::pin::Pin; use std::task::{ready, Context, Poll}; +macro_rules! primitive_merge_helper { + ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{ + let streams = PrimitiveCursorStream::<$t>::new($sort, $streams); + return Ok(Box::pin(SortPreservingMergeStream::new( + Box::new(streams), + $schema, + $tracking_metrics, + $batch_size, + ))); + }}; +} + /// Perform a streaming merge of [`SendableRecordBatchStream`] pub(crate) fn streaming_merge( streams: Vec, @@ -37,8 +52,16 @@ pub(crate) fn streaming_merge( tracking_metrics: MemTrackingMetrics, batch_size: usize, ) -> Result { - let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?; + if expressions.len() == 1 { + let sort = expressions[0].clone(); + let data_type = sort.expr.data_type(schema.as_ref())?; + downcast_primitive! { + data_type => (primitive_merge_helper, sort, streams, schema, tracking_metrics, batch_size), + _ => {} + } + } + let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?; Ok(Box::pin(SortPreservingMergeStream::new( Box::new(streams), schema, diff --git a/datafusion/core/src/physical_plan/sorts/stream.rs b/datafusion/core/src/physical_plan/sorts/stream.rs index 3fe68624f7c1..509c53203188 100644 --- a/datafusion/core/src/physical_plan/sorts/stream.rs +++ b/datafusion/core/src/physical_plan/sorts/stream.rs @@ -16,13 +16,16 @@ // under the License. use crate::common::Result; -use crate::physical_plan::sorts::cursor::RowCursor; +use crate::physical_plan::sorts::cursor::{PrimitiveCursor, RowCursor}; use crate::physical_plan::SendableRecordBatchStream; use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr}; +use arrow::array::{Array, ArrowPrimitiveType}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; +use datafusion_common::cast::as_primitive_array; use futures::stream::{Fuse, StreamExt}; +use std::marker::PhantomData; use std::sync::Arc; use std::task::{ready, Context, Poll}; @@ -75,7 +78,7 @@ impl FusedStreams { /// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`] /// and computes [`RowCursor`] based on the provided [`PhysicalSortExpr`] #[derive(Debug)] -pub(crate) struct RowCursorStream { +pub struct RowCursorStream { /// Converter to convert output of physical expressions converter: RowConverter, /// The physical expressions to sort by @@ -85,7 +88,7 @@ pub(crate) struct RowCursorStream { } impl RowCursorStream { - pub(crate) fn try_new( + pub fn try_new( schema: &Schema, expressions: &[PhysicalSortExpr], streams: Vec, @@ -139,3 +142,67 @@ impl PartitionedStream for RowCursorStream { })) } } + +pub struct PrimitiveCursorStream { + /// The physical expressions to sort by + sort: PhysicalSortExpr, + /// Input streams + streams: FusedStreams, + phantom: PhantomData T>, +} + +impl std::fmt::Debug for PrimitiveCursorStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrimitiveCursorStream") + .field("data_type", &T::DATA_TYPE) + .field("num_streams", &self.streams) + .finish() + } +} + +impl PrimitiveCursorStream { + pub fn new(sort: PhysicalSortExpr, streams: Vec) -> Self { + let streams = streams.into_iter().map(|s| s.fuse()).collect(); + Self { + sort, + streams: FusedStreams(streams), + phantom: Default::default(), + } + } + + fn convert_batch( + &mut self, + batch: &RecordBatch, + ) -> Result> { + let value = self.sort.expr.evaluate(batch)?; + let array = value.into_array(batch.num_rows()); + let array = as_primitive_array::(array.as_ref())?; + + Ok(PrimitiveCursor::new( + self.sort.options, + array.values().clone(), + array.null_count(), + )) + } +} + +impl PartitionedStream for PrimitiveCursorStream { + type Output = Result<(PrimitiveCursor, RecordBatch)>; + + fn partitions(&self) -> usize { + self.streams.0.len() + } + + fn poll_next( + &mut self, + cx: &mut Context<'_>, + stream_idx: usize, + ) -> Poll> { + Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| { + r.and_then(|batch| { + let cursor = self.convert_batch(&batch)?; + Ok((cursor, batch)) + }) + })) + } +}