Skip to content

Commit

Permalink
Use tournament loser tree for k-way sort-merging (#4301)
Browse files Browse the repository at this point in the history
Co-authored-by: zhangli20 <[email protected]>
  • Loading branch information
richox and zhangli20 authored Nov 28, 2022
1 parent 52e198e commit 0d334cf
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 60 deletions.
12 changes: 9 additions & 3 deletions datafusion/core/src/physical_plan/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,14 @@ impl PartialOrd for SortKeyCursor {

impl Ord for SortKeyCursor {
fn cmp(&self, other: &Self) -> Ordering {
self.current()
.cmp(&other.current())
.then_with(|| self.stream_idx.cmp(&other.stream_idx))
match (self.is_finished(), other.is_finished()) {
(true, true) => Ordering::Equal,
(_, true) => Ordering::Less,
(true, _) => Ordering::Greater,
_ => self
.current()
.cmp(&other.current())
.then_with(|| self.stream_idx.cmp(&other.stream_idx)),
}
}
}
162 changes: 105 additions & 57 deletions datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
//! Defines the sort preserving merge plan

use std::any::Any;
use std::cmp::Reverse;
use std::collections::{BinaryHeap, VecDeque};
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -304,10 +303,6 @@ pub(crate) struct SortPreservingMergeStream {
/// their rows have been yielded to the output
batches: Vec<VecDeque<RecordBatch>>,

/// Maintain a flag for each stream denoting if the current cursor
/// has finished and needs to poll from the stream
cursor_finished: Vec<bool>,

/// The accumulated row indexes for the next record batch
in_progress: Vec<RowIndex>,

Expand All @@ -323,8 +318,17 @@ pub(crate) struct SortPreservingMergeStream {
/// An id to uniquely identify the input stream batch
next_batch_id: usize,

/// Heap that yields [`SortKeyCursor`] in increasing order
heap: BinaryHeap<Reverse<SortKeyCursor>>,
/// Vector that holds all [`SortKeyCursor`]s
cursors: Vec<Option<SortKeyCursor>>,

/// The loser tree that always produces the minimum cursor
///
/// Node 0 stores the top winner, Nodes 1..num_streams store
/// the loser nodes
loser_tree: Vec<usize>,

/// Identify whether the loser tree is adjusted
loser_tree_adjusted: bool,

/// target batch size
batch_size: usize,
Expand Down Expand Up @@ -361,14 +365,15 @@ impl SortPreservingMergeStream {
Ok(Self {
schema,
batches,
cursor_finished: vec![true; stream_count],
streams: MergingStreams::new(wrappers),
column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
tracking_metrics,
aborted: false,
in_progress: vec![],
next_batch_id: 0,
heap: BinaryHeap::with_capacity(stream_count),
cursors: (0..stream_count).into_iter().map(|_| None).collect(),
loser_tree: Vec::with_capacity(stream_count),
loser_tree_adjusted: false,
batch_size,
row_converter,
})
Expand All @@ -382,7 +387,11 @@ impl SortPreservingMergeStream {
cx: &mut Context<'_>,
idx: usize,
) -> Poll<ArrowResult<()>> {
if !self.cursor_finished[idx] {
if self.cursors[idx]
.as_ref()
.map(|cursor| !cursor.is_finished())
.unwrap_or(false)
{
// Cursor is not finished - don't need a new RecordBatch yet
return Poll::Ready(Ok(()));
}
Expand Down Expand Up @@ -418,14 +427,12 @@ impl SortPreservingMergeStream {
}
};

let cursor = SortKeyCursor::new(
self.cursors[idx] = Some(SortKeyCursor::new(
idx,
self.next_batch_id, // assign this batch an ID
rows,
);
));
self.next_batch_id += 1;
self.heap.push(Reverse(cursor));
self.cursor_finished[idx] = false;
self.batches[idx].push_back(batch)
} else {
empty_batch = true;
Expand Down Expand Up @@ -551,17 +558,46 @@ impl SortPreservingMergeStream {
if self.aborted {
return Poll::Ready(None);
}
let num_streams = self.streams.num_streams();

// Init all cursors and the loser tree in the first poll
if self.loser_tree.is_empty() {
// Ensure all non-exhausted streams have a cursor from which
// rows can be pulled
for i in 0..num_streams {
match futures::ready!(self.maybe_poll_stream(cx, i)) {
Ok(_) => {}
Err(e) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
}
}

// Ensure all non-exhausted streams have a cursor from which
// rows can be pulled
for i in 0..self.streams.num_streams() {
match futures::ready!(self.maybe_poll_stream(cx, i)) {
Ok(_) => {}
Err(e) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
// Init loser tree
self.loser_tree.resize(num_streams, usize::MAX);
for i in 0..num_streams {
let mut winner = i;
let mut cmp_node = (num_streams + i) / 2;
while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
let challenger = self.loser_tree[cmp_node];
let challenger_win =
match (&self.cursors[winner], &self.cursors[challenger]) {
(None, _) => true,
(_, None) => false,
(Some(winner), Some(challenger)) => challenger < winner,
};
if challenger_win {
self.loser_tree[cmp_node] = winner;
winner = challenger;
} else {
self.loser_tree[cmp_node] = challenger;
}
cmp_node /= 2;
}
self.loser_tree[cmp_node] = winner;
}
self.loser_tree_adjusted = true;
}

// NB timer records time taken on drop, so there are no
Expand All @@ -570,45 +606,57 @@ impl SortPreservingMergeStream {
let _timer = elapsed_compute.timer();

loop {
match self.heap.pop() {
Some(Reverse(mut cursor)) => {
let stream_idx = cursor.stream_idx();
let batch_idx = self.batches[stream_idx].len() - 1;
let row_idx = cursor.advance();

let mut cursor_finished = false;
// insert the cursor back to heap if the record batch is not exhausted
if !cursor.is_finished() {
self.heap.push(Reverse(cursor));
} else {
cursor_finished = true;
self.cursor_finished[stream_idx] = true;
// Adjust the loser tree if necessary
if !self.loser_tree_adjusted {
let mut winner = self.loser_tree[0];
match futures::ready!(self.maybe_poll_stream(cx, winner)) {
Ok(_) => {}
Err(e) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
}

self.in_progress.push(RowIndex {
stream_idx,
batch_idx,
row_idx,
});

if self.in_progress.len() == self.batch_size {
return Poll::Ready(Some(self.build_record_batch()));
let mut cmp_node = (num_streams + winner) / 2;
while cmp_node != 0 {
let challenger = self.loser_tree[cmp_node];
let challenger_win =
match (&self.cursors[winner], &self.cursors[challenger]) {
(None, _) => true,
(_, None) => false,
(Some(winner), Some(challenger)) => challenger < winner,
};
if challenger_win {
self.loser_tree[cmp_node] = winner;
winner = challenger;
}
cmp_node /= 2;
}
self.loser_tree[0] = winner;
self.loser_tree_adjusted = true;
}

// If removed the last row from the cursor, need to fetch a new record
// batch if possible, before looping round again
if cursor_finished {
match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) {
Ok(_) => {}
Err(e) => {
self.aborted = true;
return Poll::Ready(Some(Err(e)));
}
}
}
let min_cursor_idx = self.loser_tree[0];
let next = self.cursors[min_cursor_idx]
.as_mut()
.filter(|cursor| !cursor.is_finished())
.map(|cursor| (cursor.stream_idx(), cursor.advance()));

if let Some((stream_idx, row_idx)) = next {
self.loser_tree_adjusted = false;
let batch_idx = self.batches[stream_idx].len() - 1;
self.in_progress.push(RowIndex {
stream_idx,
batch_idx,
row_idx,
});
if self.in_progress.len() == self.batch_size {
return Poll::Ready(Some(self.build_record_batch()));
}
None if self.in_progress.is_empty() => return Poll::Ready(None),
None => return Poll::Ready(Some(self.build_record_batch())),
} else if !self.in_progress.is_empty() {
return Poll::Ready(Some(self.build_record_batch()));
} else {
return Poll::Ready(None);
}
}
}
Expand Down

0 comments on commit 0d334cf

Please sign in to comment.