Skip to content

Commit

Permalink
perf: optimize count(*) (#3845)
Browse files Browse the repository at this point in the history
* perf: optimize count(*)

Signed-off-by: Ruihang Xia <[email protected]>

* fallback to count(1) for temporary table

Signed-off-by: Ruihang Xia <[email protected]>

* handle alias expr in range plan

Signed-off-by: Ruihang Xia <[email protected]>

* handle subquery alias

Signed-off-by: Ruihang Xia <[email protected]>

* rename file

Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia committed Apr 30, 2024
1 parent 777bc3b commit e84b1ee
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/query/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod count_wildcard;
pub mod order_hint;
pub mod remove_duplicate;
pub mod string_normalization;
Expand All @@ -27,7 +28,7 @@ use crate::QueryEngineContext;

/// [`ExtensionAnalyzerRule`]s transform [`LogicalPlan`]s in some way to make
/// the plan valid prior to the rest of the DataFusion optimization process.
/// It's an extension of datafusion [`AnalyzerRule`]s but accepts [`QueryEngineContext` as the second parameter.
/// It's an extension of datafusion [`AnalyzerRule`]s but accepts [`QueryEngineContext`] as the second parameter.
pub trait ExtensionAnalyzerRule {
/// Rewrite `plan`
fn analyze(
Expand Down
156 changes: 156 additions & 0 deletions src/query/src/optimizer/count_wildcard.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use datafusion::datasource::DefaultTableSource;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
};
use datafusion_common::Result as DataFusionResult;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition};
use datafusion_optimizer::utils::NamePreserver;
use datafusion_optimizer::AnalyzerRule;
use table::table::adapter::DfTableProviderAdapter;

/// A replacement to DataFusion's [`CountWildcardRule`]. This rule
/// would prefer to use TIME INDEX for counting wildcard as it's
/// faster to read comparing to PRIMARY KEYs.
///
/// [`CountWildcardRule`]: datafusion::optimizer::analyzer::CountWildcardRule
pub struct CountWildcardToTimeIndexRule;

impl AnalyzerRule for CountWildcardToTimeIndexRule {
fn name(&self) -> &str {
"count_wildcard_to_time_index_rule"
}

fn analyze(
&self,
plan: LogicalPlan,
_config: &datafusion::config::ConfigOptions,
) -> DataFusionResult<LogicalPlan> {
plan.transform_down_with_subqueries(&Self::analyze_internal)
.data()
}
}

impl CountWildcardToTimeIndexRule {
fn analyze_internal(plan: LogicalPlan) -> DataFusionResult<Transformed<LogicalPlan>> {
let name_preserver = NamePreserver::new(&plan);
let new_arg = if let Some(time_index) = Self::try_find_time_index_col(&plan) {
vec![col(time_index)]
} else {
vec![lit(COUNT_STAR_EXPANSION)]
};
plan.map_expressions(|expr| {
let original_name = name_preserver.save(&expr)?;
let transformed_expr = expr.transform_up_mut(&mut |expr| match expr {
Expr::WindowFunction(mut window_function)
if Self::is_count_star_window_aggregate(&window_function) =>
{
window_function.args.clone_from(&new_arg);
Ok(Transformed::yes(Expr::WindowFunction(window_function)))
}
Expr::AggregateFunction(mut aggregate_function)
if Self::is_count_star_aggregate(&aggregate_function) =>
{
aggregate_function.args.clone_from(&new_arg);
Ok(Transformed::yes(Expr::AggregateFunction(
aggregate_function,
)))
}
_ => Ok(Transformed::no(expr)),
})?;
transformed_expr.map_data(|data| original_name.restore(data))
})
}

fn try_find_time_index_col(plan: &LogicalPlan) -> Option<String> {
let mut finder = TimeIndexFinder::default();
// Safety: `TimeIndexFinder` won't throw error.
plan.visit(&mut finder).unwrap();
finder.time_index
}
}

/// Utility functions from the original rule.
impl CountWildcardToTimeIndexRule {
fn is_wildcard(expr: &Expr) -> bool {
matches!(expr, Expr::Wildcard { qualifier: None })
}

fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
matches!(
&aggregate_function.func_def,
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::Count,
)
) && aggregate_function.args.len() == 1
&& Self::is_wildcard(&aggregate_function.args[0])
}

fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
matches!(
&window_function.fun,
WindowFunctionDefinition::AggregateFunction(
datafusion_expr::aggregate_function::AggregateFunction::Count,
)
) && window_function.args.len() == 1
&& Self::is_wildcard(&window_function.args[0])
}
}

#[derive(Default)]
struct TimeIndexFinder {
time_index: Option<String>,
table_alias: Option<String>,
}

impl TreeNodeVisitor for TimeIndexFinder {
type Node = LogicalPlan;

fn f_down(&mut self, node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
if let LogicalPlan::SubqueryAlias(subquery_alias) = node {
self.table_alias = Some(subquery_alias.alias.to_string());
}

if let LogicalPlan::TableScan(table_scan) = &node {
if let Some(source) = table_scan
.source
.as_any()
.downcast_ref::<DefaultTableSource>()
{
if let Some(adapter) = source
.table_provider
.as_any()
.downcast_ref::<DfTableProviderAdapter>()
{
let table_info = adapter.table().table_info();
let col_name = table_info.meta.schema.timestamp_column().map(|c| &c.name);
let table_name = self.table_alias.as_ref().unwrap_or(&table_info.name);
self.time_index = col_name.map(|s| format!("{}.{}", table_name, s));

return Ok(TreeNodeRecursion::Stop);
}
}
}

Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, _node: &Self::Node) -> DataFusionResult<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Stop)
}
}
12 changes: 11 additions & 1 deletion src/query/src/query_engine/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use table::table::adapter::DfTableProviderAdapter;
use table::TableRef;

use crate::dist_plan::{DistExtensionPlanner, DistPlannerAnalyzer};
use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
use crate::optimizer::order_hint::OrderHintRule;
use crate::optimizer::remove_duplicate::RemoveDuplicate;
use crate::optimizer::string_normalization::StringNormalizationRule;
Expand Down Expand Up @@ -89,18 +90,27 @@ impl QueryEngineState {
let session_config = SessionConfig::new().with_create_default_catalog_and_schema(false);
// Apply extension rules
let mut extension_rules = Vec::new();

// The [`TypeConversionRule`] must be at first
extension_rules.insert(0, Arc::new(TypeConversionRule) as _);

// Apply the datafusion rules
let mut analyzer = Analyzer::new();
analyzer.rules.insert(0, Arc::new(StringNormalizationRule));

// Use our custom rule instead to optimize the count(*) query
Self::remove_analyzer_rule(&mut analyzer.rules, CountWildcardRule {}.name());
analyzer.rules.insert(0, Arc::new(CountWildcardRule {}));
analyzer
.rules
.insert(0, Arc::new(CountWildcardToTimeIndexRule));

if with_dist_planner {
analyzer.rules.push(Arc::new(DistPlannerAnalyzer));
}

let mut optimizer = Optimizer::new();
optimizer.rules.push(Arc::new(OrderHintRule));

// add physical optimizer
let mut physical_optimizer = PhysicalOptimizer::new();
physical_optimizer.rules.push(Arc::new(RemoveDuplicate));
Expand Down
16 changes: 11 additions & 5 deletions src/query/src/range_select/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,13 @@ impl RangeSelect {
.range_expr
.iter()
.map(|range_fn| {
let expr = match &range_fn.expr {
let name = range_fn.expr.display_name()?;
let range_expr = match &range_fn.expr {
Expr::Alias(expr) => expr.expr.as_ref(),
others => others,
};

let expr = match &range_expr {
Expr::AggregateFunction(
aggr @ datafusion_expr::expr::AggregateFunction {
func_def:
Expand Down Expand Up @@ -778,7 +784,7 @@ impl RangeSelect {
&input_phy_exprs,
&order_by,
&input_schema,
range_fn.expr.display_name()?,
name,
false,
),
AggregateFunctionDefinition::UDF(fun) => create_aggr_udf_expr(
Expand All @@ -787,7 +793,7 @@ impl RangeSelect {
&[],
&[],
&input_schema,
range_fn.expr.display_name()?,
name,
false,
),
f => Err(DataFusionError::NotImplemented(format!(
Expand All @@ -796,8 +802,8 @@ impl RangeSelect {
}
}
_ => Err(DataFusionError::Plan(format!(
"Unexpected Expr:{} in RangeSelect",
range_fn.expr.display_name()?
"Unexpected Expr: {} in RangeSelect",
range_fn.expr.canonical_name()
))),
}?;
let args = expr.expressions();
Expand Down
25 changes: 21 additions & 4 deletions tests/cases/standalone/common/range/special_aggr.result
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ SELECT ts, host, first_value(addon ORDER BY val ASC, ts ASC) RANGE '5s', last_va
| 1970-01-01T00:00:20 | host2 | 28 | 30 |
+---------------------+-------+---------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------+

SELECT ts, host, count(val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
SELECT ts, host, count(val) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

+---------------------+-------+--------------------------+
| ts | host | COUNT(host.val) RANGE 5s |
Expand All @@ -160,7 +160,7 @@ SELECT ts, host, count(val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
| 1970-01-01T00:00:20 | host2 | 2 |
+---------------------+-------+--------------------------+

SELECT ts, host, count(distinct val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
SELECT ts, host, count(distinct val) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

+---------------------+-------+-----------------------------------+
| ts | host | COUNT(DISTINCT host.val) RANGE 5s |
Expand All @@ -177,7 +177,7 @@ SELECT ts, host, count(distinct val) RANGE '5s'FROM host ALIGN '5s' ORDER BY hos
| 1970-01-01T00:00:20 | host2 | 2 |
+---------------------+-------+-----------------------------------+

SELECT ts, host, count(*) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
SELECT ts, host, count(*) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

+---------------------+-------+-------------------+
| ts | host | COUNT(*) RANGE 5s |
Expand All @@ -194,7 +194,24 @@ SELECT ts, host, count(*) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
| 1970-01-01T00:00:20 | host2 | 3 |
+---------------------+-------+-------------------+

SELECT ts, host, count(distinct *) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
SELECT ts, host, count(1) RANGE '5s' as abc FROM host ALIGN '5s' ORDER BY host, ts;

+---------------------+-------+-----+
| ts | host | abc |
+---------------------+-------+-----+
| 1970-01-01T00:00:00 | host1 | 3 |
| 1970-01-01T00:00:05 | host1 | 3 |
| 1970-01-01T00:00:10 | host1 | 3 |
| 1970-01-01T00:00:15 | host1 | 3 |
| 1970-01-01T00:00:20 | host1 | 3 |
| 1970-01-01T00:00:00 | host2 | 3 |
| 1970-01-01T00:00:05 | host2 | 3 |
| 1970-01-01T00:00:10 | host2 | 3 |
| 1970-01-01T00:00:15 | host2 | 3 |
| 1970-01-01T00:00:20 | host2 | 3 |
+---------------------+-------+-----+

SELECT ts, host, count(distinct *) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

+---------------------+-------+----------------------------+
| ts | host | COUNT(DISTINCT *) RANGE 5s |
Expand Down
10 changes: 6 additions & 4 deletions tests/cases/standalone/common/range/special_aggr.sql
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@ SELECT ts, host, first_value(addon ORDER BY val ASC NULLS FIRST) RANGE '5s', las

SELECT ts, host, first_value(addon ORDER BY val ASC, ts ASC) RANGE '5s', last_value(addon ORDER BY val ASC, ts ASC) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

SELECT ts, host, count(val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
SELECT ts, host, count(val) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

SELECT ts, host, count(distinct val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
SELECT ts, host, count(distinct val) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

SELECT ts, host, count(*) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
SELECT ts, host, count(*) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

SELECT ts, host, count(distinct *) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts;
SELECT ts, host, count(1) RANGE '5s' as abc FROM host ALIGN '5s' ORDER BY host, ts;

SELECT ts, host, count(distinct *) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts;

-- Test error first_value/last_value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test;
|_|_Filter: test.j >= TimestampMillisecond(-300000, None) AND test.j <= TimestampMillisecond(300000, None)_|
|_|_TableScan: test_|
| logical_plan after apply_function_rewrites_| SAME TEXT AS ABOVE_|
| logical_plan after count_wildcard_rule_| SAME TEXT AS ABOVE_|
| logical_plan after count_wildcard_to_time_index_rule_| SAME TEXT AS ABOVE_|
| logical_plan after StringNormalizationRule_| SAME TEXT AS ABOVE_|
| logical_plan after inline_table_scan_| SAME TEXT AS ABOVE_|
| logical_plan after type_coercion_| SAME TEXT AS ABOVE_|
Expand Down

0 comments on commit e84b1ee

Please sign in to comment.