diff --git a/datafusion/core/src/datasource/cte_worktable.rs b/datafusion/core/src/datasource/cte_worktable.rs index de13e73e003b..71075839b9a0 100644 --- a/datafusion/core/src/datasource/cte_worktable.rs +++ b/datafusion/core/src/datasource/cte_worktable.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::not_impl_err; +use datafusion_physical_plan::work_table::WorkTableExec; use crate::{ error::Result, @@ -30,8 +30,6 @@ use crate::{ physical_plan::ExecutionPlan, }; -use datafusion_common::DataFusionError; - use crate::datasource::{TableProvider, TableType}; use crate::execution::context::SessionState; @@ -84,7 +82,11 @@ impl TableProvider for CteWorkTable { _filters: &[Expr], _limit: Option, ) -> Result> { - not_impl_err!("scan not implemented for CteWorkTable yet") + // TODO: pushdown filters and limits + Ok(Arc::new(WorkTableExec::new( + self.name.clone(), + self.table_schema.clone(), + ))) } fn supports_filter_pushdown( diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 1d1bee61805e..2d20c487e473 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -248,7 +248,7 @@ fn try_swapping_with_streaming_table( StreamingTableExec::try_new( streaming_table.partition_schema().clone(), streaming_table.partitions().clone(), - Some(&new_projections), + Some(new_projections.as_ref()), lex_orderings, streaming_table.is_infinite(), ) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ac7827fafc2c..ac3b7ebaeac1 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -58,6 +58,7 @@ use crate::physical_plan::joins::{ use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use 
crate::physical_plan::memory::MemoryExec; use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::recursive_query::RecursiveQueryExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::union::UnionExec; @@ -894,7 +895,7 @@ impl DefaultPhysicalPlanner { let filter = FilterExec::try_new(runtime_expr, physical_input)?; Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } - LogicalPlan::Union(Union { inputs, .. }) => { + LogicalPlan::Union(Union { inputs, schema: _ }) => { let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; Ok(Arc::new(UnionExec::new(physical_plans))) @@ -1288,8 +1289,10 @@ impl DefaultPhysicalPlanner { Ok(plan) } } - LogicalPlan::RecursiveQuery(RecursiveQuery { name: _, static_term: _, recursive_term: _, is_distinct: _,.. }) => { - not_impl_err!("Physical counterpart of RecursiveQuery is not implemented yet") + LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, recursive_term, is_distinct,.. 
}) => { + let static_term = self.create_initial_plan(static_term, session_state).await?; + let recursive_term = self.create_initial_plan(recursive_term, session_state).await?; + Ok(Arc::new(RecursiveQueryExec::try_new(name.clone(), static_term, recursive_term, *is_distinct)?)) } }; exec_plan diff --git a/datafusion/core/tests/data/recursive_cte/balance.csv b/datafusion/core/tests/data/recursive_cte/balance.csv new file mode 100644 index 000000000000..a77c742dd2e5 --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/balance.csv @@ -0,0 +1,5 @@ +time,name,account_balance +1,John,100 +1,Tim,200 +2,John,300 +2,Tim,400 \ No newline at end of file diff --git a/datafusion/core/tests/data/recursive_cte/growth.csv b/datafusion/core/tests/data/recursive_cte/growth.csv new file mode 100644 index 000000000000..912208bad2eb --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/growth.csv @@ -0,0 +1,4 @@ +name,account_growth +John,3 +Tim,20 +Eliza,150 \ No newline at end of file diff --git a/datafusion/core/tests/data/recursive_cte/prices.csv b/datafusion/core/tests/data/recursive_cte/prices.csv new file mode 100644 index 000000000000..b294ecfad774 --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/prices.csv @@ -0,0 +1,101 @@ +Index,product,price,prices_row_num +1,Holden,334.8,1 +2,Mercedes-Benz,623.22,2 +3,Aston Martin,363.48,3 +4,GMC,615.67,4 +5,Lincoln,521.13,5 +6,Mitsubishi,143.05,6 +7,Infiniti,861.82,7 +8,Ford,330.57,8 +9,GMC,136.87,9 +10,Toyota,106.29,10 +11,Pontiac,686.95,11 +12,Ford,197.48,12 +13,Honda,774.42,13 +14,Dodge,854.26,14 +15,Bentley,628.82,15 +16,Chevrolet,756.82,16 +17,Volkswagen,438.51,17 +18,Mazda,156.15,18 +19,Hyundai,322.43,19 +20,Oldsmobile,979.95,20 +21,Geo,359.59,21 +22,Ford,960.75,22 +23,Subaru,106.75,23 +24,Pontiac,13.4,24 +25,Mercedes-Benz,858.46,25 +26,Subaru,55.72,26 +27,BMW,316.69,27 +28,Chevrolet,290.32,28 +29,Mercury,296.8,29 +30,Dodge,410.78,30 +31,Oldsmobile,18.07,31 +32,Subaru,442.22,32 +33,Dodge,93.29,33 
+34,Honda,282.9,34 +35,Chevrolet,750.87,35 +36,Lexus,249.82,36 +37,Ford,732.33,37 +38,Toyota,680.78,38 +39,Nissan,657.01,39 +40,Mazda,200.76,40 +41,Nissan,251.44,41 +42,Buick,714.44,42 +43,Ford,436.2,43 +44,Volvo,865.53,44 +45,Saab,471.52,45 +46,Mercedes-Benz,51.13,46 +47,Chrysler,943.52,47 +48,Lamborghini,181.6,48 +49,Hyundai,634.89,49 +50,Ford,757.58,50 +51,Porsche,294.64,51 +52,Ford,261.34,52 +53,Chrysler,822.01,53 +54,Audi,430.68,54 +55,Mitsubishi,69.12,55 +56,Mazda,723.16,56 +57,Mazda,711.46,57 +58,Land Rover,435.15,58 +59,Buick,189.58,59 +60,GMC,651.92,60 +61,Mazda,491.37,61 +62,BMW,346.18,62 +63,Ford,456.25,63 +64,Ford,10.65,64 +65,Mazda,985.39,65 +66,Mercedes-Benz,955.79,66 +67,Honda,550.95,67 +68,Mitsubishi,127.6,68 +69,Mercedes-Benz,840.65,69 +70,Infiniti,647.45,70 +71,Bentley,827.26,71 +72,Lincoln,822.22,72 +73,Plymouth,970.55,73 +74,Ford,595.05,74 +75,Maybach,808.46,75 +76,Chevrolet,341.48,76 +77,Jaguar,759.03,77 +78,Land Rover,625.01,78 +79,Lincoln,289.13,79 +80,Suzuki,285.24,80 +81,GMC,253.4,81 +82,Oldsmobile,174.76,82 +83,Lincoln,434.17,83 +84,Dodge,887.38,84 +85,Mercedes-Benz,308.65,85 +86,GMC,182.71,86 +87,Ford,619.62,87 +88,Lexus,228.63,88 +89,Hyundai,901.06,89 +90,Chevrolet,615.65,90 +91,GMC,460.19,91 +92,Mercedes-Benz,729.28,92 +93,Dodge,414.69,93 +94,Maserati,300.83,94 +95,Suzuki,503.64,95 +96,Audi,275.05,96 +97,Ford,303.25,97 +98,Lotus,101.01,98 +99,Lincoln,721.05,99 +100,Kia,833.31,100 \ No newline at end of file diff --git a/datafusion/core/tests/data/recursive_cte/sales.csv b/datafusion/core/tests/data/recursive_cte/sales.csv new file mode 100644 index 000000000000..12299c39d635 --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/sales.csv @@ -0,0 +1,10 @@ +region_id,salesperson_id,sale_amount +101,1,1000 +102,2,500 +101,2,700 +103,3,800 +102,4,300 +101,4,400 +102,5,600 +103,6,500 +101,7,900 \ No newline at end of file diff --git a/datafusion/core/tests/data/recursive_cte/salespersons.csv 
b/datafusion/core/tests/data/recursive_cte/salespersons.csv new file mode 100644 index 000000000000..dc941c450246 --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/salespersons.csv @@ -0,0 +1,8 @@ +salesperson_id,manager_id +1, +2,1 +3,1 +4,2 +5,2 +6,3 +7,3 \ No newline at end of file diff --git a/datafusion/core/tests/data/recursive_cte/time.csv b/datafusion/core/tests/data/recursive_cte/time.csv new file mode 100644 index 000000000000..21026bd41a4a --- /dev/null +++ b/datafusion/core/tests/data/recursive_cte/time.csv @@ -0,0 +1,5 @@ +time,other +1,foo +2,bar +4,baz +5,qux diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 01d4f8941802..0a9eab5c8633 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -61,6 +61,7 @@ pub mod metrics; mod ordering; pub mod placeholder_row; pub mod projection; +pub mod recursive_query; pub mod repartition; pub mod sorts; pub mod stream; @@ -71,6 +72,7 @@ pub mod union; pub mod unnest; pub mod values; pub mod windows; +pub mod work_table; pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; pub use crate::metrics::Metric; diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs new file mode 100644 index 000000000000..614ab990ac49 --- /dev/null +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -0,0 +1,377 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the recursive query plan + +use std::any::Any; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use super::expressions::PhysicalSortExpr; +use super::metrics::BaselineMetrics; +use super::RecordBatchStream; +use super::{ + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + work_table::{WorkTable, WorkTableExec}, + SendableRecordBatchStream, Statistics, +}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::Partitioning; +use futures::{ready, Stream, StreamExt}; + +use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +/// Recursive query execution plan. +/// +/// This plan has two components: a base part (the static term) and +/// a dynamic part (the recursive term). The execution will start from +/// the base, and as long as the previous iteration produced at least +/// a single new row (taking care of the distinction) the recursive +/// part will be continuously executed. +/// +/// Before each execution of the dynamic part, the rows from the previous +/// iteration will be available in a "working table" (not a real table, +/// can be only accessed using a continuance operation). +/// +/// Note that there won't be any limit or checks applied to detect +/// an infinite recursion, so it is up to the planner to ensure that +/// it won't happen. 
+#[derive(Debug)] +pub struct RecursiveQueryExec { + /// Name of the query handler + name: String, + /// The working table of cte + work_table: Arc, + /// The base part (static term) + static_term: Arc, + /// The dynamic part (recursive term) + recursive_term: Arc, + /// Distinction + is_distinct: bool, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl RecursiveQueryExec { + /// Create a new RecursiveQueryExec + pub fn try_new( + name: String, + static_term: Arc, + recursive_term: Arc, + is_distinct: bool, + ) -> Result { + // Each recursive query needs its own work table + let work_table = Arc::new(WorkTable::new()); + // Use the same work table for both the WorkTableExec and the recursive term + let recursive_term = assign_work_table(recursive_term, work_table.clone())?; + Ok(RecursiveQueryExec { + name, + static_term, + recursive_term, + is_distinct, + work_table, + metrics: ExecutionPlanMetricsSet::new(), + }) + } +} + +impl ExecutionPlan for RecursiveQueryExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.static_term.schema() + } + + fn children(&self) -> Vec> { + vec![self.static_term.clone(), self.recursive_term.clone()] + } + + // Distribution on a recursive query is really tricky to handle. + // For now, we are going to use a single partition but in the + // future we might find a better way to handle this. + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + // TODO: control these hints and see whether we can + // infer some from the child plans (static/recursive terms). 
+ fn maintains_input_order(&self) -> Vec { + vec![false, false] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false, false] + } + + fn required_input_distribution(&self) -> Vec { + vec![ + datafusion_physical_expr::Distribution::SinglePartition, + datafusion_physical_expr::Distribution::SinglePartition, + ] + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(RecursiveQueryExec { + name: self.name.clone(), + static_term: children[0].clone(), + recursive_term: children[1].clone(), + is_distinct: self.is_distinct, + work_table: self.work_table.clone(), + metrics: self.metrics.clone(), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + // TODO: we might be able to handle multiple partitions in the future. + if partition != 0 { + return Err(DataFusionError::Internal(format!( + "RecursiveQueryExec got an invalid partition {} (expected 0)", + partition + ))); + } + + let static_stream = self.static_term.execute(partition, context.clone())?; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + Ok(Box::pin(RecursiveQueryStream::new( + context, + self.work_table.clone(), + self.recursive_term.clone(), + static_stream, + baseline_metrics, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +impl DisplayAs for RecursiveQueryExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "RecursiveQueryExec: name={}, is_distinct={}", + self.name, self.is_distinct + ) + } + } + } +} + +/// The actual logic of the recursive queries happens during the streaming +/// process. 
A simplified version of the algorithm is the following: +/// +/// buffer = [] +/// +/// while batch := static_stream.next(): +/// buffer.push(batch) +/// yield buffer +/// +/// while buffer.len() > 0: +/// sender, receiver = Channel() +/// register_continuation(handle_name, receiver) +/// sender.send(buffer.drain()) +/// recursive_stream = recursive_term.execute() +/// while batch := recursive_stream.next(): +/// buffer.append(batch) +/// yield buffer +/// +struct RecursiveQueryStream { + /// The context to be used for managing handlers & executing new tasks + task_context: Arc, + /// The working table state, representing the self referencing cte table + work_table: Arc, + /// The dynamic part (recursive term) as is (without being executed) + recursive_term: Arc, + /// The static part (static term) as a stream. If the processing of this + /// part is completed, then it will be None. + static_stream: Option, + /// The dynamic part (recursive term) as a stream. If the processing of this + /// part has not started yet, or has been completed, then it will be None. + recursive_stream: Option, + /// The schema of the output. + schema: SchemaRef, + /// In-memory buffer for storing a copy of the current results. Will be + /// cleared after each iteration. + buffer: Vec, + // /// Metrics. + _baseline_metrics: BaselineMetrics, +} + +impl RecursiveQueryStream { + /// Create a new recursive query stream + fn new( + task_context: Arc, + work_table: Arc, + recursive_term: Arc, + static_stream: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, + ) -> Self { + let schema = static_stream.schema(); + Self { + task_context, + work_table, + recursive_term, + static_stream: Some(static_stream), + recursive_stream: None, + schema, + buffer: vec![], + _baseline_metrics: baseline_metrics, + } + } + + /// Push a clone of the given batch to the in memory buffer, and then return + /// a poll with it. 
+ fn push_batch( + mut self: std::pin::Pin<&mut Self>, + batch: RecordBatch, + ) -> Poll>> { + self.buffer.push(batch.clone()); + Poll::Ready(Some(Ok(batch))) + } + + /// Start polling for the next iteration, will be called either after the static term + /// is completed or another term is completed. It will follow the algorithm above on + /// to check whether the recursion has ended. + fn poll_next_iteration( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let total_length = self + .buffer + .iter() + .fold(0, |acc, batch| acc + batch.num_rows()); + + if total_length == 0 { + return Poll::Ready(None); + } + + // Update the work table with the current buffer + let batches = self.buffer.drain(..).collect(); + self.work_table.write(batches); + + // We always execute (and re-execute iteratively) the first partition. + // Downstream plans should not expect any partitioning. + let partition = 0; + + self.recursive_stream = Some( + self.recursive_term + .execute(partition, self.task_context.clone())?, + ); + self.poll_next(cx) + } +} + +fn assign_work_table( + plan: Arc, + work_table: Arc, +) -> Result> { + let mut work_table_refs = 0; + plan.transform_down_mut(&mut |plan| { + if let Some(exec) = plan.as_any().downcast_ref::() { + if work_table_refs > 0 { + not_impl_err!( + "Multiple recursive references to the same CTE are not supported" + ) + } else { + work_table_refs += 1; + Ok(Transformed::Yes(Arc::new( + exec.with_work_table(work_table.clone()), + ))) + } + } else if plan.as_any().is::() { + not_impl_err!("Recursive queries cannot be nested") + } else { + Ok(Transformed::No(plan)) + } + }) +} + +impl Stream for RecursiveQueryStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + // TODO: we should use this poll to record some metrics! 
+ if let Some(static_stream) = &mut self.static_stream { + // While the static term's stream is available, we'll be forwarding the batches from it (also + // saving them for the initial iteration of the recursive term). + let batch_result = ready!(static_stream.poll_next_unpin(cx)); + match &batch_result { + None => { + // Once this is done, we can start running the setup for the recursive term. + self.static_stream = None; + self.poll_next_iteration(cx) + } + Some(Ok(batch)) => self.push_batch(batch.clone()), + _ => Poll::Ready(batch_result), + } + } else if let Some(recursive_stream) = &mut self.recursive_stream { + let batch_result = ready!(recursive_stream.poll_next_unpin(cx)); + match batch_result { + None => { + self.recursive_stream = None; + self.poll_next_iteration(cx) + } + Some(Ok(batch)) => self.push_batch(batch.clone()), + _ => Poll::Ready(batch_result), + } + } else { + Poll::Ready(None) + } + } +} + +impl RecordBatchStream for RecursiveQueryStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs new file mode 100644 index 000000000000..c74a596f3dae --- /dev/null +++ b/datafusion/physical-plan/src/work_table.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the work table query plan + +use std::any::Any; +use std::sync::{Arc, Mutex}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::Partitioning; + +use crate::memory::MemoryStream; +use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + +use super::expressions::PhysicalSortExpr; + +use super::{ + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + SendableRecordBatchStream, Statistics, +}; +use datafusion_common::{internal_err, DataFusionError, Result}; + +/// The name is from PostgreSQL's terminology. +/// See +/// This table serves as a mirror or buffer between each iteration of a recursive query. +#[derive(Debug)] +pub(super) struct WorkTable { + batches: Mutex>>, +} + +impl WorkTable { + /// Create a new work table. + pub(super) fn new() -> Self { + Self { + batches: Mutex::new(None), + } + } + + /// Take the previously written batches from the work table. + /// This will be called by the [`WorkTableExec`] when it is executed. + fn take(&self) -> Vec { + let batches = self.batches.lock().unwrap().take().unwrap_or_default(); + batches + } + + /// Write the results of a recursive query iteration to the work table. + pub(super) fn write(&self, input: Vec) { + self.batches.lock().unwrap().replace(input); + } +} + +/// A temporary "working table" operation where the input data will be +/// taken from the named handle during the execution and will be re-published +/// as is (kind of like a mirror). 
+/// +/// Most notably used in the implementation of recursive queries where the +/// underlying relation does not exist yet but the data will come as the previous +/// term is evaluated. The recursive plan writes the batches of each iteration +/// into a work table shared with this plan, and this plan takes those batches +/// from the work table and streams them back up so that they are available +/// in the next iteration. +#[derive(Clone, Debug)] +pub struct WorkTableExec { + /// Name of the relation handler + name: String, + /// The schema of the stream + schema: SchemaRef, + /// The work table + work_table: Arc, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +impl WorkTableExec { + /// Create a new execution plan for a worktable exec. + pub fn new(name: String, schema: SchemaRef) -> Self { + Self { + name, + schema, + metrics: ExecutionPlanMetricsSet::new(), + work_table: Arc::new(WorkTable::new()), + } + } + + pub(super) fn with_work_table(&self, work_table: Arc) -> Self { + Self { + name: self.name.clone(), + schema: self.schema.clone(), + metrics: ExecutionPlanMetricsSet::new(), + work_table, + } + } +} + +impl DisplayAs for WorkTableExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "WorkTableExec: name={}", self.name) + } + } + } +} + +impl ExecutionPlan for WorkTableExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![] + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(1) + } + + fn maintains_input_order(&self) -> Vec { + vec![false] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + 
Ok(self.clone()) + } + + /// Stream the batches that were written to the work table. + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + // WorkTable streams must be the plan base. + if partition != 0 { + return internal_err!( + "WorkTableExec got an invalid partition {partition} (expected 0)" + ); + } + + let batches = self.work_table.take(); + Ok(Box::pin(MemoryStream::try_new( + batches, + self.schema.clone(), + None, + )?)) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } +} + +#[cfg(test)] +mod tests {} diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index af0b91ae6c7e..ea8edd0771c8 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -54,7 +54,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let set_expr = query.body; if let Some(with) = query.with { // Process CTEs from top to bottom - let is_recursive = with.recursive; for cte in with.cte_tables { diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index d341833ba1b6..6b9db5589391 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -38,3 +38,616 @@ Projection: NUMBERS.a, NUMBERS.b, NUMBERS.c physical_plan ProjectionExec: expr=[1 as a, 2 as b, 3 as c] --PlaceholderRowExec + + + +# enable recursive CTEs +statement ok +set datafusion.execution.enable_recursive_ctes = true; + +# trivial recursive CTE works +query I rowsort +WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + SELECT id + 1 as id + FROM nodes + WHERE id < 10 +) +SELECT * FROM nodes +---- +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 + +# explain trivial recursive CTE +query TT +EXPLAIN WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + SELECT id + 1 as id + FROM nodes + WHERE id < 10 +) +SELECT * FROM nodes +---- +logical_plan +Projection: nodes.id 
+--SubqueryAlias: nodes +----RecursiveQuery: is_distinct=false +------Projection: Int64(1) AS id +--------EmptyRelation +------Projection: nodes.id + Int64(1) AS id +--------Filter: nodes.id < Int64(10) +----------TableScan: nodes +physical_plan +RecursiveQueryExec: name=nodes, is_distinct=false +--ProjectionExec: expr=[1 as id] +----PlaceholderRowExec +--CoalescePartitionsExec +----ProjectionExec: expr=[id@0 + 1 as id] +------CoalesceBatchesExec: target_batch_size=8192 +--------FilterExec: id@0 < 10 +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------WorkTableExec: name=nodes + +# setup +statement ok +CREATE EXTERNAL TABLE balance STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/balance.csv' + +# setup +statement ok +CREATE EXTERNAL TABLE growth STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/growth.csv' + +# setup +statement ok +set datafusion.execution.batch_size = 2; + +# recursive CTE with static term derived from table works. +# use explain to ensure that batch size is set to 2. 
This should produce multiple batches per iteration since the input +# table 'balances' has 4 rows +query TT +EXPLAIN WITH RECURSIVE balances AS ( + SELECT * from balance + UNION ALL + SELECT time + 1 as time, name, account_balance + 10 as account_balance + FROM balances + WHERE time < 10 +) +SELECT * FROM balances +ORDER BY time, name, account_balance +---- +logical_plan +Sort: balances.time ASC NULLS LAST, balances.name ASC NULLS LAST, balances.account_balance ASC NULLS LAST +--Projection: balances.time, balances.name, balances.account_balance +----SubqueryAlias: balances +------RecursiveQuery: is_distinct=false +--------Projection: balance.time, balance.name, balance.account_balance +----------TableScan: balance +--------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance +----------Filter: balances.time < Int64(10) +------------TableScan: balances +physical_plan +SortExec: expr=[time@0 ASC NULLS LAST,name@1 ASC NULLS LAST,account_balance@2 ASC NULLS LAST] +--RecursiveQueryExec: name=balances, is_distinct=false +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/balance.csv]]}, projection=[time, name, account_balance], has_header=true +----CoalescePartitionsExec +------ProjectionExec: expr=[time@0 + 1 as time, name@1 as name, account_balance@2 + 10 as account_balance] +--------CoalesceBatchesExec: target_batch_size=2 +----------FilterExec: time@0 < 10 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------WorkTableExec: name=balances + +# recursive CTE with static term derived from table works +# note that this is run with batch size set to 2. 
This should produce multiple batches per iteration since the input +# table 'balances' has 4 rows +query ITI +WITH RECURSIVE balances AS ( + SELECT * from balance + UNION ALL + SELECT time + 1 as time, name, account_balance + 10 as account_balance + FROM balances + WHERE time < 10 +) +SELECT * FROM balances +ORDER BY time, name, account_balance +---- +1 John 100 +1 Tim 200 +2 John 110 +2 John 300 +2 Tim 210 +2 Tim 400 +3 John 120 +3 John 310 +3 Tim 220 +3 Tim 410 +4 John 130 +4 John 320 +4 Tim 230 +4 Tim 420 +5 John 140 +5 John 330 +5 Tim 240 +5 Tim 430 +6 John 150 +6 John 340 +6 Tim 250 +6 Tim 440 +7 John 160 +7 John 350 +7 Tim 260 +7 Tim 450 +8 John 170 +8 John 360 +8 Tim 270 +8 Tim 460 +9 John 180 +9 John 370 +9 Tim 280 +9 Tim 470 +10 John 190 +10 John 380 +10 Tim 290 +10 Tim 480 + +# reset batch size to default +statement ok +set datafusion.execution.batch_size = 8192; + +# recursive CTE with recursive join works +query ITI +WITH RECURSIVE balances AS ( + SELECT time as time, name as name, account_balance as account_balance + FROM balance + UNION ALL + SELECT time + 1 as time, balances.name, account_balance + growth.account_growth as account_balance + FROM balances + JOIN growth + ON balances.name = growth.name + WHERE time < 10 +) +SELECT * FROM balances +ORDER BY time, name, account_balance +---- +1 John 100 +1 Tim 200 +2 John 103 +2 John 300 +2 Tim 220 +2 Tim 400 +3 John 106 +3 John 303 +3 Tim 240 +3 Tim 420 +4 John 109 +4 John 306 +4 Tim 260 +4 Tim 440 +5 John 112 +5 John 309 +5 Tim 280 +5 Tim 460 +6 John 115 +6 John 312 +6 Tim 300 +6 Tim 480 +7 John 118 +7 John 315 +7 Tim 320 +7 Tim 500 +8 John 121 +8 John 318 +8 Tim 340 +8 Tim 520 +9 John 124 +9 John 321 +9 Tim 360 +9 Tim 540 +10 John 127 +10 John 324 +10 Tim 380 +10 Tim 560 + +# recursive CTE with aggregations works +query I rowsort +WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + SELECT id + 1 as id + FROM nodes + WHERE id < 10 +) +SELECT sum(id) FROM nodes +---- +55 + +# setup +statement ok 
+CREATE TABLE t(a BIGINT) AS VALUES(1),(2),(3); + +# referencing CTE multiple times does not error +query II rowsort +WITH RECURSIVE my_cte AS ( + SELECT a from t + UNION ALL + SELECT a+2 as a + FROM my_cte + WHERE a<5 +) +SELECT * FROM my_cte t1, my_cte +---- +1 1 +1 2 +1 3 +1 3 +1 4 +1 5 +1 5 +1 6 +2 1 +2 2 +2 3 +2 3 +2 4 +2 5 +2 5 +2 6 +3 1 +3 1 +3 2 +3 2 +3 3 +3 3 +3 3 +3 3 +3 4 +3 4 +3 5 +3 5 +3 5 +3 5 +3 6 +3 6 +4 1 +4 2 +4 3 +4 3 +4 4 +4 5 +4 5 +4 6 +5 1 +5 1 +5 2 +5 2 +5 3 +5 3 +5 3 +5 3 +5 4 +5 4 +5 5 +5 5 +5 5 +5 5 +5 6 +5 6 +6 1 +6 2 +6 3 +6 3 +6 4 +6 5 +6 5 +6 6 + +# CTE within recursive CTE works and does not result in 'index out of bounds: the len is 0 but the index is 0' +query I +WITH RECURSIVE "recursive_cte" AS ( + SELECT 1 as "val" + UNION ALL ( + WITH "sub_cte" AS ( + SELECT + time, + 1 as "val" + FROM + (SELECT DISTINCT "time" FROM "balance") + ) + SELECT + 2 as "val" + FROM + "recursive_cte" + FULL JOIN "sub_cte" ON 1 = 1 + WHERE + "recursive_cte"."val" < 2 + ) +) +SELECT + * +FROM + "recursive_cte"; +---- +1 +2 +2 + +# setup +statement ok +CREATE EXTERNAL TABLE prices STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/prices.csv' + +# CTE within window function inside nested CTE works. This test demonstrates using a nested window function to recursively iterate over a column. 
+query RRII +WITH RECURSIVE "recursive_cte" AS ( + ( + WITH "min_prices_row_num_cte" AS ( + SELECT + MIN("prices"."prices_row_num") AS "prices_row_num" + FROM + "prices" + ), + "min_prices_row_num_cte_second" AS ( + SELECT + MIN("prices"."prices_row_num") AS "prices_row_num_advancement" + FROM + "prices" + WHERE + "prices"."prices_row_num" > ( + SELECT + "prices_row_num" + FROM + "min_prices_row_num_cte" + ) + ) + SELECT + 0.0 AS "beg", + (0.0 + 50) AS "end", + ( + SELECT + "prices_row_num" + FROM + "min_prices_row_num_cte" + ) AS "prices_row_num", + ( + SELECT + "prices_row_num_advancement" + FROM + "min_prices_row_num_cte_second" + ) AS "prices_row_num_advancement" + FROM + "prices" + WHERE + "prices"."prices_row_num" = ( + SELECT + DISTINCT "prices_row_num" + FROM + "min_prices_row_num_cte" + ) + ) + UNION ALL ( + WITH "min_prices_row_num_cte" AS ( + SELECT + "prices"."prices_row_num" AS "prices_row_num", + LEAD("prices"."prices_row_num", 1) OVER ( + ORDER BY "prices_row_num" + ) AS "prices_row_num_advancement" + FROM + ( + SELECT + DISTINCT "prices_row_num" + FROM + "prices" + ) AS "prices" + ) + SELECT + "recursive_cte"."end" AS "beg", + ("recursive_cte"."end" + 50) AS "end", + "min_prices_row_num_cte"."prices_row_num" AS "prices_row_num", + "min_prices_row_num_cte"."prices_row_num_advancement" AS "prices_row_num_advancement" + FROM + "recursive_cte" + FULL JOIN "prices" ON "prices"."prices_row_num" = "recursive_cte"."prices_row_num_advancement" + FULL JOIN "min_prices_row_num_cte" ON "min_prices_row_num_cte"."prices_row_num" = COALESCE( + "prices"."prices_row_num", + "recursive_cte"."prices_row_num_advancement" + ) + WHERE + "recursive_cte"."prices_row_num_advancement" IS NOT NULL + ) +) +SELECT + DISTINCT * +FROM + "recursive_cte" +ORDER BY + "prices_row_num" ASC; +---- +0 50 1 2 +50 100 2 3 +100 150 3 4 +150 200 4 5 +200 250 5 6 +250 300 6 7 +300 350 7 8 +350 400 8 9 +400 450 9 10 +450 500 10 11 +500 550 11 12 +550 600 12 13 +600 650 13 14 +650 700 14 15 
+700 750 15 16 +750 800 16 17 +800 850 17 18 +850 900 18 19 +900 950 19 20 +950 1000 20 21 +1000 1050 21 22 +1050 1100 22 23 +1100 1150 23 24 +1150 1200 24 25 +1200 1250 25 26 +1250 1300 26 27 +1300 1350 27 28 +1350 1400 28 29 +1400 1450 29 30 +1450 1500 30 31 +1500 1550 31 32 +1550 1600 32 33 +1600 1650 33 34 +1650 1700 34 35 +1700 1750 35 36 +1750 1800 36 37 +1800 1850 37 38 +1850 1900 38 39 +1900 1950 39 40 +1950 2000 40 41 +2000 2050 41 42 +2050 2100 42 43 +2100 2150 43 44 +2150 2200 44 45 +2200 2250 45 46 +2250 2300 46 47 +2300 2350 47 48 +2350 2400 48 49 +2400 2450 49 50 +2450 2500 50 51 +2500 2550 51 52 +2550 2600 52 53 +2600 2650 53 54 +2650 2700 54 55 +2700 2750 55 56 +2750 2800 56 57 +2800 2850 57 58 +2850 2900 58 59 +2900 2950 59 60 +2950 3000 60 61 +3000 3050 61 62 +3050 3100 62 63 +3100 3150 63 64 +3150 3200 64 65 +3200 3250 65 66 +3250 3300 66 67 +3300 3350 67 68 +3350 3400 68 69 +3400 3450 69 70 +3450 3500 70 71 +3500 3550 71 72 +3550 3600 72 73 +3600 3650 73 74 +3650 3700 74 75 +3700 3750 75 76 +3750 3800 76 77 +3800 3850 77 78 +3850 3900 78 79 +3900 3950 79 80 +3950 4000 80 81 +4000 4050 81 82 +4050 4100 82 83 +4100 4150 83 84 +4150 4200 84 85 +4200 4250 85 86 +4250 4300 86 87 +4300 4350 87 88 +4350 4400 88 89 +4400 4450 89 90 +4450 4500 90 91 +4500 4550 91 92 +4550 4600 92 93 +4600 4650 93 94 +4650 4700 94 95 +4700 4750 95 96 +4750 4800 96 97 +4800 4850 97 98 +4850 4900 98 99 +4900 4950 99 100 +4950 5000 100 NULL + +# setup: register the sales CSV fixture as an external table +statement ok +CREATE EXTERNAL TABLE sales STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/sales.csv' + +# setup: register the salespersons CSV fixture as an external table +statement ok +CREATE EXTERNAL TABLE salespersons STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/salespersons.csv' + + +# GROUP BY works within a recursive CTE. This test case demonstrates rolling up a hierarchy of salespeople to their managers. 
+query III +WITH RECURSIVE region_sales AS ( + -- Anchor member: total sale_amount per salesperson, at level 0 + SELECT + s.salesperson_id AS salesperson_id, + SUM(s.sale_amount) AS amount, + 0 as level + FROM + sales s + GROUP BY + s.salesperson_id + UNION ALL + -- Recursive member: sum the previous amounts per manager_id, one level higher + SELECT + sp.manager_id AS salesperson_id, + SUM(rs.amount) AS amount, + MIN(rs.level) + 1 as level + FROM + region_sales rs + INNER JOIN salespersons sp ON rs.salesperson_id = sp.salesperson_id + WHERE sp.manager_id IS NOT NULL + GROUP BY + sp.manager_id +) +SELECT + salesperson_id, + MAX(amount) as amount, + MAX(level) as hierarchy_level +FROM + region_sales +GROUP BY + salesperson_id +ORDER BY + hierarchy_level ASC, salesperson_id ASC; +---- +4 700 0 +5 600 0 +6 500 0 +7 900 0 +2 1300 1 +3 1400 1 +1 2700 2 + +# expect error when a recursive CTE is defined inside the recursive term of another recursive CTE +query error DataFusion error: This feature is not implemented: Recursive queries cannot be nested +WITH RECURSIVE outer_cte AS ( + SELECT 1 as a + UNION ALL ( + WITH RECURSIVE nested_cte AS ( + SELECT 1 as a + UNION ALL + SELECT a+2 as a + FROM nested_cte where a < 3 + ) + SELECT outer_cte.a +2 + FROM outer_cte JOIN nested_cte USING(a) + WHERE nested_cte.a < 4 + ) +) +SELECT a FROM outer_cte; + +# expect error when recursive CTE is referenced multiple times in the recursive term +query error DataFusion error: This feature is not implemented: Multiple recursive references to the same CTE are not supported +WITH RECURSIVE my_cte AS ( + SELECT 1 as a + UNION ALL + SELECT my_cte.a+2 as a + FROM my_cte join my_cte c2 using(a) + WHERE my_cte.a<5 +) +SELECT a FROM my_cte;