Skip to content

Commit

Permalink
feat: include order by to commutativity rule set (#4753)
Browse files Browse the repository at this point in the history
* feat: include order by to commutativity rule set

Signed-off-by: Ruihang Xia <[email protected]>

* tune sqlness replace interceptor

Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Sep 23, 2024
1 parent 0f99218 commit 2feddca
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 19 deletions.
101 changes: 83 additions & 18 deletions src/query/src/dist_plan/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::sync::Arc;

use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result as DfResult;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::{col, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{col as col_fn, Expr, LogicalPlan, LogicalPlanBuilder, Subquery};
use datafusion_optimizer::analyzer::AnalyzerRule;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
Expand Down Expand Up @@ -104,7 +106,7 @@ impl DistPlannerAnalyzer {
let project_exprs = output_schema
.fields()
.iter()
.map(|f| col(f.name()))
.map(|f| col_fn(f.name()))
.collect::<Vec<_>>();
rewrote_subquery = LogicalPlanBuilder::from(rewrote_subquery)
.project(project_exprs)?
Expand Down Expand Up @@ -137,6 +139,7 @@ struct PlanRewriter {
status: RewriterStatus,
/// Partition columns of the table in current pass
partition_cols: Option<Vec<String>>,
column_requirements: HashSet<String>,
}

impl PlanRewriter {
Expand All @@ -162,20 +165,23 @@ impl PlanRewriter {
Commutativity::Commutative => {}
Commutativity::PartialCommutative => {
if let Some(plan) = partial_commutative_transformer(plan) {
self.update_column_requirements(&plan);
self.stage.push(plan)
}
}
Commutativity::ConditionalCommutative(transformer) => {
if let Some(transformer) = transformer
&& let Some(plan) = transformer(plan)
{
self.update_column_requirements(&plan);
self.stage.push(plan)
}
}
Commutativity::TransformedCommutative(transformer) => {
if let Some(transformer) = transformer
&& let Some(plan) = transformer(plan)
{
self.update_column_requirements(&plan);
self.stage.push(plan)
}
}
Expand All @@ -189,6 +195,18 @@ impl PlanRewriter {
false
}

fn update_column_requirements(&mut self, plan: &LogicalPlan) {
let mut container = HashSet::new();
for expr in plan.expressions() {
// this method won't fail
let _ = expr_to_columns(&expr, &mut container);
}

for col in container {
self.column_requirements.insert(col.flat_name());
}
}

fn is_expanded(&self) -> bool {
self.status == RewriterStatus::Expanded
}
Expand Down Expand Up @@ -238,6 +256,67 @@ impl PlanRewriter {
self.level -= 1;
self.stack.pop();
}

fn expand(&mut self, mut on_node: LogicalPlan) -> DfResult<LogicalPlan> {
let mut rewriter = EnforceDistRequirementRewriter {
column_requirements: std::mem::take(&mut self.column_requirements),
};
on_node = on_node.rewrite(&mut rewriter)?.data;

// add merge scan as the new root
let mut node = MergeScanLogicalPlan::new(on_node, false).into_logical_plan();
// expand stages
for new_stage in self.stage.drain(..) {
node = new_stage.with_new_exprs(new_stage.expressions(), vec![node.clone()])?
}
self.set_expanded();

Ok(node)
}
}

/// Implementation of the [`TreeNodeRewriter`] trait which is responsible for rewriting
/// logical plans to enforce various requirement for distributed query.
///
/// Requirements enforced by this rewriter:
/// - Enforce column requirements for `LogicalPlan::Projection` nodes. Makes sure the
/// required columns are available in the sub plan.
struct EnforceDistRequirementRewriter {
column_requirements: HashSet<String>,
}

impl TreeNodeRewriter for EnforceDistRequirementRewriter {
type Node = LogicalPlan;

fn f_down(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
if let LogicalPlan::Projection(ref projection) = node {
let mut column_requirements = std::mem::take(&mut self.column_requirements);
if column_requirements.is_empty() {
return Ok(Transformed::no(node));
}

for expr in &projection.expr {
column_requirements.remove(&expr.name_for_alias()?);
}
if column_requirements.is_empty() {
return Ok(Transformed::no(node));
}

let mut new_exprs = projection.expr.clone();
for col in &column_requirements {
new_exprs.push(col_fn(col));
}
let new_node =
node.with_new_exprs(new_exprs, node.inputs().into_iter().cloned().collect())?;
return Ok(Transformed::yes(new_node));
}

Ok(Transformed::no(node))
}

fn f_up(&mut self, node: Self::Node) -> DfResult<Transformed<Self::Node>> {
Ok(Transformed::no(node))
}
}

impl TreeNodeRewriter for PlanRewriter {
Expand Down Expand Up @@ -274,29 +353,15 @@ impl TreeNodeRewriter for PlanRewriter {
self.maybe_set_partitions(&node);

let Some(parent) = self.get_parent() else {
// add merge scan as the new root
let mut node = MergeScanLogicalPlan::new(node, false).into_logical_plan();
// expand stages
for new_stage in self.stage.drain(..) {
node = new_stage.with_new_exprs(node.expressions(), vec![node.clone()])?
}
self.set_expanded();

let node = self.expand(node)?;
self.pop_stack();
return Ok(Transformed::yes(node));
};

// TODO(ruihang): avoid this clone
if self.should_expand(&parent.clone()) {
// TODO(ruihang): does this work for nodes with multiple children?;
// replace the current node with expanded one
let mut node = MergeScanLogicalPlan::new(node, false).into_logical_plan();
// expand stages
for new_stage in self.stage.drain(..) {
node = new_stage.with_new_exprs(node.expressions(), vec![node.clone()])?
}
self.set_expanded();

let node = self.expand(node)?;
self.pop_stack();
return Ok(Transformed::yes(node));
}
Expand Down
2 changes: 1 addition & 1 deletion src/query/src/dist_plan/commutativity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl Categorizer {

// sort plan needs to consider column priority
// We can implement a merge-sort on partial ordered data
Commutativity::Unimplemented
Commutativity::PartialCommutative
}
LogicalPlan::Join(_) => Commutativity::NonCommutative,
LogicalPlan::CrossJoin(_) => Commutativity::NonCommutative,
Expand Down
99 changes: 99 additions & 0 deletions tests/cases/standalone/common/order/order_by.result
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,102 @@ DROP TABLE test;

Affected Rows: 0

-- ORDER BY for partition table
CREATE TABLE IF NOT EXISTS `t` (
`tag` STRING NULL,
`ts` TIMESTAMP(3) NOT NULL,
`num` BIGINT NULL,
TIME INDEX (`ts`),
PRIMARY KEY (`tag`)
)
PARTITION ON COLUMNS (`tag`) (
tag <= 'z',
tag > 'z'
);

Affected Rows: 0

INSERT INTO t (tag, ts, num) VALUES
('abc', 0, 1),
('abc', 3000, 2),
('abc', 6000, 3),
('abc', 9000, 4),
('abc', 12000, 5),
('zzz', 3000, 6),
('zzz', 6000, 7),
('zzz', 9000, 8),
('zzz', 0, 9),
('zzz', 3000, 10);

Affected Rows: 10

select * from t where num > 3 order by ts desc limit 2;

+-----+---------------------+-----+
| tag | ts | num |
+-----+---------------------+-----+
| abc | 1970-01-01T00:00:12 | 5 |
| abc | 1970-01-01T00:00:09 | 4 |
+-----+---------------------+-----+

select tag from t where num > 6 order by ts desc limit 2;

+-----+---------------------+
| tag | ts |
+-----+---------------------+
| zzz | 1970-01-01T00:00:09 |
| zzz | 1970-01-01T00:00:06 |
+-----+---------------------+

select tag from t where num > 6 order by ts;

+-----+---------------------+
| tag | ts |
+-----+---------------------+
| zzz | 1970-01-01T00:00:00 |
| zzz | 1970-01-01T00:00:03 |
| zzz | 1970-01-01T00:00:06 |
| zzz | 1970-01-01T00:00:09 |
+-----+---------------------+

-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (peers.*) REDACTED
-- SQLNESS REPLACE (metrics.*) REDACTED
-- SQLNESS REPLACE region=\d+\(\d+,\s+\d+\) region=REDACTED
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
explain analyze select tag from t where num > 6 order by ts desc limit 2;

+-+-+-+
| stage | node | plan_|
+-+-+-+
| 0_| 0_|_GlobalLimitExec: skip=0, fetch=2 REDACTED
|_|_|_SortPreservingMergeExec: [ts@1 DESC] REDACTED
|_|_|_SortExec: TopK(fetch=2), expr=[ts@1 DESC], preserve_partitioning=[true] REDACTED
|_|_|_MergeScanExec: REDACTED
|_|_|_|
| 1_| 0_|_GlobalLimitExec: skip=0, fetch=2 REDACTED
|_|_|_SortPreservingMergeExec: [ts@1 DESC] REDACTED
|_|_|_SortExec: TopK(fetch=2), expr=[ts@1 DESC], preserve_partitioning=[true] REDACTED
|_|_|_ProjectionExec: expr=[tag@0 as tag, ts@1 as ts] REDACTED
|_|_|_CoalesceBatchesExec: target_batch_size=8192 REDACTED
|_|_|_FilterExec: num@2 > 6 REDACTED
|_|_|_RepartitionExec: partitioning=REDACTED
|_|_|_SeqScan: region=REDACTED, partition_count=1 (1 memtable ranges, 0 file 0 ranges) REDACTED
|_|_|_|
| 1_| 1_|_GlobalLimitExec: skip=0, fetch=2 REDACTED
|_|_|_SortPreservingMergeExec: [ts@1 DESC] REDACTED
|_|_|_SortExec: TopK(fetch=2), expr=[ts@1 DESC], preserve_partitioning=[true] REDACTED
|_|_|_ProjectionExec: expr=[tag@0 as tag, ts@1 as ts] REDACTED
|_|_|_CoalesceBatchesExec: target_batch_size=8192 REDACTED
|_|_|_FilterExec: num@2 > 6 REDACTED
|_|_|_RepartitionExec: partitioning=REDACTED
|_|_|_SeqScan: region=REDACTED, partition_count=1 (1 memtable ranges, 0 file 0 ranges) REDACTED
|_|_|_|
|_|_| Total rows: 2_|
+-+-+-+

drop table t;

Affected Rows: 0

41 changes: 41 additions & 0 deletions tests/cases/standalone/common/order/order_by.sql
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,44 @@ SELECT a-10 AS k FROM test UNION SELECT a-10 AS l FROM test ORDER BY a-10;
SELECT a-10 AS k FROM test UNION SELECT a-11 AS l FROM test ORDER BY a-11;

DROP TABLE test;

-- ORDER BY for partition table
CREATE TABLE IF NOT EXISTS `t` (
`tag` STRING NULL,
`ts` TIMESTAMP(3) NOT NULL,
`num` BIGINT NULL,
TIME INDEX (`ts`),
PRIMARY KEY (`tag`)
)
PARTITION ON COLUMNS (`tag`) (
tag <= 'z',
tag > 'z'
);

INSERT INTO t (tag, ts, num) VALUES
('abc', 0, 1),
('abc', 3000, 2),
('abc', 6000, 3),
('abc', 9000, 4),
('abc', 12000, 5),
('zzz', 3000, 6),
('zzz', 6000, 7),
('zzz', 9000, 8),
('zzz', 0, 9),
('zzz', 3000, 10);

select * from t where num > 3 order by ts desc limit 2;

select tag from t where num > 6 order by ts desc limit 2;

select tag from t where num > 6 order by ts;

-- SQLNESS REPLACE (-+) -
-- SQLNESS REPLACE (\s\s+) _
-- SQLNESS REPLACE (peers.*) REDACTED
-- SQLNESS REPLACE (metrics.*) REDACTED
-- SQLNESS REPLACE region=\d+\(\d+,\s+\d+\) region=REDACTED
-- SQLNESS REPLACE (RoundRobinBatch.*) REDACTED
explain analyze select tag from t where num > 6 order by ts desc limit 2;

drop table t;

0 comments on commit 2feddca

Please sign in to comment.