Skip to content

Commit

Permalink
Specialize PrimitiveCursor (apache#5882)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Apr 6, 2023
1 parent 7a8b225 commit 5f7a3d6
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 5 deletions.
2 changes: 2 additions & 0 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ unicode_expressions = ["datafusion-physical-expr/unicode_expressions", "datafusi
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
apache-avro = { version = "0.14", optional = true }
arrow = { workspace = true }
arrow-schema = { workspace = true }
arrow-array = { workspace = true }
async-compression = { version = "0.3.14", features = ["bzip2", "gzip", "xz", "zstd", "futures-io", "tokio"], optional = true }
async-trait = "0.1.41"
bytes = "1.4"
Expand Down
232 changes: 232 additions & 0 deletions datafusion/core/src/physical_plan/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use crate::physical_plan::sorts::sort::SortOptions;
use arrow::buffer::ScalarBuffer;
use arrow::datatypes::ArrowNativeTypeOp;
use arrow::row::{Row, Rows};
use std::cmp::Ordering;

Expand Down Expand Up @@ -93,3 +96,232 @@ impl Cursor for RowCursor {
t
}
}

/// A cursor over sorted, nullable [`ArrowNativeTypeOp`]
///
/// Nullness is not tracked per element. The nulls form one contiguous run at
/// the start (`nulls_first`) or at the end of `values`, and `null_threshold`
/// marks the boundary between that run and the non-null elements.
///
/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering
#[derive(Debug)]
pub struct PrimitiveCursor<T: ArrowNativeTypeOp> {
// Backing buffer; `cmp` only reads an element when its position is non-null
values: ScalarBuffer<T>,
// Index of the element the cursor currently points at
offset: usize,
// If nulls first, the first non-null index
// Otherwise, the first null index
null_threshold: usize,
// Decides where the null run sits and the direction of value comparisons
options: SortOptions,
}

impl<T: ArrowNativeTypeOp> PrimitiveCursor<T> {
    /// Builds a [`PrimitiveCursor`] positioned at the first element of `values`.
    ///
    /// `values` is expected to be sorted according to `options`, with its
    /// `null_count` null slots grouped at the front when `options.nulls_first`
    /// is set and at the back otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `null_count` is greater than `values.len()`.
    pub fn new(options: SortOptions, values: ScalarBuffer<T>, null_count: usize) -> Self {
        assert!(null_count <= values.len());

        // Boundary between the null run and the non-null run.
        let null_threshold = if options.nulls_first {
            null_count
        } else {
            values.len() - null_count
        };

        Self {
            values,
            offset: 0,
            null_threshold,
            options,
        }
    }

    /// Returns `true` when the current position falls inside the null run.
    fn is_null(&self) -> bool {
        let before_threshold = self.offset < self.null_threshold;
        if self.options.nulls_first {
            // Nulls occupy `[0, null_threshold)`.
            before_threshold
        } else {
            // Nulls occupy `[null_threshold, len)`.
            !before_threshold
        }
    }

    /// Returns the element stored at the current position.
    ///
    /// Panics if the cursor has been advanced past the end of `values`.
    fn value(&self) -> T {
        let position = self.offset;
        self.values[position]
    }
}

impl<T: ArrowNativeTypeOp> PartialEq for PrimitiveCursor<T> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}

// Marker impl: equality is derived from the total order defined by `cmp`,
// and `Ord` requires `Eq` to be implemented.
impl<T: ArrowNativeTypeOp> Eq for PrimitiveCursor<T> {}
impl<T: ArrowNativeTypeOp> PartialOrd for PrimitiveCursor<T> {
    /// Always returns `Some`: the answer is delegated to the total order
    /// provided by [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}

impl<T: ArrowNativeTypeOp> Ord for PrimitiveCursor<T> {
    /// Orders two cursors by the elements at their current positions.
    ///
    /// * Two nulls compare as equal.
    /// * A null sorts before a non-null when `nulls_first` is set, and after
    ///   it otherwise.
    /// * Two non-nulls are ordered with `compare`, with the operands swapped
    ///   when `descending` is set.
    ///
    /// Only `self.options` is consulted; `other.options` is never read.
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs_null = self.is_null();
        let rhs_null = other.is_null();

        if lhs_null && rhs_null {
            return Ordering::Equal;
        }

        if lhs_null != rhs_null {
            // Exactly one side is null. `self` is `Less` when it is the null
            // and nulls go first, or when it is the value and nulls go last.
            return if lhs_null == self.options.nulls_first {
                Ordering::Less
            } else {
                Ordering::Greater
            };
        }

        let (lhs, rhs) = (self.value(), other.value());
        if self.options.descending {
            rhs.compare(lhs)
        } else {
            lhs.compare(rhs)
        }
    }
}

impl<T: ArrowNativeTypeOp> Cursor for PrimitiveCursor<T> {
    /// Returns `true` once the cursor sits exactly one past the last element.
    fn is_finished(&self) -> bool {
        self.values.len() == self.offset
    }

    /// Moves the cursor one element forward and returns the index it pointed
    /// at before the move.
    fn advance(&mut self) -> usize {
        let next = self.offset + 1;
        std::mem::replace(&mut self.offset, next)
    }
}

#[cfg(test)]
mod tests {
use super::*;

// Exercises `PrimitiveCursor` comparisons for all four combinations of
// `descending` and `nulls_first` (not only nulls-first, despite the name).
//
// In every buffer the slots inside the null run hold arbitrary placeholder
// values; the comments below write them as NULL.
#[test]
fn test_primitive_nulls_first() {
// --- ascending, nulls first ---
let options = SortOptions {
descending: false,
nulls_first: true,
};

// a = [NULL, 1, 2, 3]
let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]);
let mut a = PrimitiveCursor::new(options, buffer, 1);
// b = [NULL, NULL, -2, -1, 1, 9]
let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]);
let mut b = PrimitiveCursor::new(options, buffer, 2);

// NULL == NULL
assert_eq!(a.cmp(&b), Ordering::Equal);
assert_eq!(a, b);

// NULL == NULL
b.advance();
assert_eq!(a.cmp(&b), Ordering::Equal);
assert_eq!(a, b);

// NULL < -2
b.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// 1 > -2
a.advance();
assert_eq!(a.cmp(&b), Ordering::Greater);

// 1 > -1
b.advance();
assert_eq!(a.cmp(&b), Ordering::Greater);

// 1 == 1
b.advance();
assert_eq!(a.cmp(&b), Ordering::Equal);
assert_eq!(a, b);

// 1 < 9
b.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// 2 < 9
a.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// --- ascending, nulls last ---
let options = SortOptions {
descending: false,
nulls_first: false,
};

// a = [0, 1, NULL, NULL]
let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]);
let mut a = PrimitiveCursor::new(options, buffer, 2);
// b = [-1, NULL, NULL]
let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]);
let mut b = PrimitiveCursor::new(options, buffer, 2);

// 0 > -1
assert_eq!(a.cmp(&b), Ordering::Greater);

// 0 < NULL
b.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// 1 < NULL
a.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// NULL = NULL
a.advance();
assert_eq!(a.cmp(&b), Ordering::Equal);
assert_eq!(a, b);

// --- descending, nulls last ---
// Below, "sorts before/after" refers to the position in descending order.
let options = SortOptions {
descending: true,
nulls_first: false,
};

// a = [6, NULL, NULL, NULL]
let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]);
let mut a = PrimitiveCursor::new(options, buffer, 3);
// b = [67, -3, NULL, NULL]
let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]);
let mut b = PrimitiveCursor::new(options, buffer, 2);

// 6 sorts after 67
assert_eq!(a.cmp(&b), Ordering::Greater);

// 6 sorts before -3
b.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// 6 sorts before NULL
b.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// 6 sorts before NULL
b.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// NULL == NULL
a.advance();
assert_eq!(a.cmp(&b), Ordering::Equal);
assert_eq!(a, b);

// --- descending, nulls first ---
let options = SortOptions {
descending: true,
nulls_first: true,
};

// a = [NULL, NULL, 6, 3]
let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]);
let mut a = PrimitiveCursor::new(options, buffer, 2);
// b = [NULL, 4546, -3]
let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]);
let mut b = PrimitiveCursor::new(options, buffer, 1);

// NULL == NULL
assert_eq!(a.cmp(&b), Ordering::Equal);
assert_eq!(a, b);

// NULL == NULL
a.advance();
assert_eq!(a.cmp(&b), Ordering::Equal);
assert_eq!(a, b);

// NULL sorts before 4546
b.advance();
assert_eq!(a.cmp(&b), Ordering::Less);

// 6 sorts after 4546
a.advance();
assert_eq!(a.cmp(&b), Ordering::Greater);

// 6 sorts before -3
b.advance();
assert_eq!(a.cmp(&b), Ordering::Less);
}
}
27 changes: 25 additions & 2 deletions datafusion/core/src/physical_plan/sorts/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,31 @@ use crate::common::Result;
use crate::physical_plan::metrics::MemTrackingMetrics;
use crate::physical_plan::sorts::builder::BatchBuilder;
use crate::physical_plan::sorts::cursor::Cursor;
use crate::physical_plan::sorts::stream::{PartitionedStream, RowCursorStream};
use crate::physical_plan::sorts::stream::{
PartitionedStream, PrimitiveCursorStream, RowCursorStream,
};
use crate::physical_plan::{
PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream,
};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow_array::downcast_primitive;
use futures::Stream;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

/// Wraps the input streams in a [`PrimitiveCursorStream`] for the primitive
/// type `$t`, builds a [`SortPreservingMergeStream`] over it, and `return`s
/// that stream from the function the macro is expanded in.
///
/// Intended as the helper passed to `downcast_primitive!`, which supplies the
/// leading `$t` argument; the remaining arguments are identifiers forwarded
/// unchanged from the call site.
macro_rules! primitive_merge_helper {
    ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
        let cursor_stream = PrimitiveCursorStream::<$t>::new($sort, $streams);
        let merged = SortPreservingMergeStream::new(
            Box::new(cursor_stream),
            $schema,
            $tracking_metrics,
            $batch_size,
        );
        return Ok(Box::pin(merged));
    }};
}

/// Perform a streaming merge of [`SendableRecordBatchStream`]
pub(crate) fn streaming_merge(
streams: Vec<SendableRecordBatchStream>,
Expand All @@ -37,8 +52,16 @@ pub(crate) fn streaming_merge(
tracking_metrics: MemTrackingMetrics,
batch_size: usize,
) -> Result<SendableRecordBatchStream> {
let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?;
if expressions.len() == 1 {
let sort = expressions[0].clone();
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, tracking_metrics, batch_size),
_ => {}
}
}

let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?;
Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
schema,
Expand Down
Loading

0 comments on commit 5f7a3d6

Please sign in to comment.