diff --git a/Cargo.toml b/Cargo.toml index aa1ba1f214d5..6a6928e25bdd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.75" +rust-version = "1.76" version = "39.0.0" [workspace.dependencies] diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index c5b34df4f1cf..28312fee79a7 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -387,7 +387,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -714,9 +714,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.72" +version = "0.3.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17c6a35df3749d2e8bb1b7b21a976d82b15548788d2735b9d82f329268f71a11" +checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" dependencies = [ "addr2line", "cc", @@ -1099,7 +1099,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1278,7 +1278,6 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", - "datafusion-physical-expr", "hashbrown 0.14.5", "hex", "itertools", @@ -1359,7 +1358,6 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", @@ -1687,7 +1685,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -1911,12 +1909,12 @@ dependencies = [ [[package]] name = "http-body-util" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", - "futures-core", + "futures-util", "http 1.1.0", "http-body 1.0.0", "pin-project-lite", @@ -1924,9 +1922,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.8.0" +version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9" [[package]] name = "httpdate" @@ -2001,18 +1999,19 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.26.0" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" +checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", "hyper 1.3.1", "hyper-util", - "rustls 0.22.4", + "rustls 0.23.10", + "rustls-native-certs 0.7.0", "rustls-pki-types", "tokio", - "tokio-rustls 0.25.0", + "tokio-rustls 0.26.0", "tower-service", ] @@ -2148,9 +2147,9 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = 
"bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lexical-core" @@ -2326,9 +2325,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mimalloc" @@ -2347,9 +2346,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87dfd01fe195c66b572b37921ad8803d010623c0aca821bea2302239d155cdae" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ "adler", ] @@ -2483,9 +2482,9 @@ dependencies = [ [[package]] name = "object" -version = "0.35.0" +version = "0.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8ec7ab813848ba4522158d5517a6093db1ded27575b070f4177b8d12b41db5e" +checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" dependencies = [ "memchr", ] @@ -2699,7 +2698,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -2788,9 +2787,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.85" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -2811,6 +2810,53 @@ dependencies = [ "serde", ] +[[package]] +name = "quinn" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4ceeeeabace7857413798eb1ffa1e9c905a9946a57d81fb69b4b71c4d8eb3ad" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls 0.23.10", + "thiserror", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf517c03a109db8100448a4be38d498df8a210a99fe0e1b9eaf39e78c640efe" +dependencies = [ + "bytes", + "rand", + "ring 0.17.8", + "rustc-hash", + "rustls 0.23.10", + "slab", + "thiserror", + "tinyvec", + "tracing", +] + +[[package]] +name = "quinn-udp" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9096629c45860fc7fb143e125eb826b5e721e10be3263160c7d60ca832cf8c46" +dependencies = [ + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.36" @@ -2862,9 +2908,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" dependencies = [ "bitflags 2.5.0", ] @@ -2917,9 +2963,9 @@ checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "reqwest" -version = "0.12.4" +version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"566cafdd92868e0939d3fb961bd0dc25fcfaaed179291093b3d43e6b3150ea10" +checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" dependencies = [ "base64 0.22.1", "bytes", @@ -2930,7 +2976,7 @@ dependencies = [ "http-body 1.0.0", "http-body-util", "hyper 1.3.1", - "hyper-rustls 0.26.0", + "hyper-rustls 0.27.2", "hyper-util", "ipnet", "js-sys", @@ -2939,7 +2985,8 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.22.4", + "quinn", + "rustls 0.23.10", "rustls-native-certs 0.7.0", "rustls-pemfile 2.1.2", "rustls-pki-types", @@ -2948,7 +2995,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-rustls 0.25.0", + "tokio-rustls 0.26.0", "tokio-util", "tower-service", "url", @@ -3027,6 +3074,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -3063,11 +3116,11 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.4" +version = "0.23.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402" dependencies = [ - "log", + "once_cell", "ring 0.17.8", "rustls-pki-types", "rustls-webpki", @@ -3257,7 +3310,7 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3392,7 +3445,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3438,7 +3491,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3451,14 +3504,14 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "0d0208408ba0c3df17ed26eb06992cb1a1268d41b2c0e12e65203fbe3972cee5" [[package]] name = "syn" @@ -3473,9 +3526,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.66" +version = "2.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = "ff8655ed1d86f3af4ee3fd3263786bc14245ad17c4c7e85ba7187fb3ae028c90" dependencies = [ "proc-macro2", "quote", @@ -3484,9 +3537,9 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "0.1.2" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = "tempfile" @@ -3538,7 +3591,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3633,7 +3686,7 @@ checksum = 
"5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3649,11 +3702,11 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.25.0" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.22.4", + "rustls 0.23.10", "rustls-pki-types", "tokio", ] @@ -3730,7 +3783,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3775,7 +3828,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] @@ -3831,9 +3884,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.0" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", "idna", @@ -3929,7 +3982,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", "wasm-bindgen-shared", ] @@ -3963,7 +4016,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4228,7 +4281,7 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.67", ] [[package]] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 8f4b3cd81f36..8578476ed43d 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.75" +rust-version = "1.76" readme = "README.md" [dependencies] diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index faa657da6511..c11eb3280c20 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -236,13 +236,14 @@ mod tests { fn setup_context() -> (SessionContext, Arc) { let mut ctx = SessionContext::new(); ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new( - ctx.state().catalog_list(), + ctx.state().catalog_list().clone(), ctx.state_weak_ref(), ))); - let provider = - &DynamicFileCatalog::new(ctx.state().catalog_list(), ctx.state_weak_ref()) - as &dyn CatalogProviderList; + let provider = &DynamicFileCatalog::new( + ctx.state().catalog_list().clone(), + ctx.state_weak_ref(), + ) as &dyn CatalogProviderList; let catalog = provider .catalog(provider.catalog_names().first().unwrap()) .unwrap(); diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index c4c92be1525d..b78f32e0ac48 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -21,7 +21,6 @@ use std::collections::HashMap; use std::fs::File; use std::io::prelude::*; use std::io::BufReader; -use std::str::FromStr; use crate::cli_context::CliSessionContext; use 
crate::helper::split_from_semicolon; @@ -35,6 +34,7 @@ use crate::{ use datafusion::common::instant::Instant; use datafusion::common::plan_datafusion_err; +use datafusion::config::ConfigFileType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; @@ -42,7 +42,6 @@ use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; -use datafusion::common::FileType; use datafusion::sql::sqlparser; use rustyline::error::ReadlineError; use rustyline::Editor; @@ -291,6 +290,15 @@ impl AdjustedPrintOptions { } } +fn config_file_type_from_str(ext: &str) -> Option { + match ext.to_lowercase().as_str() { + "csv" => Some(ConfigFileType::CSV), + "json" => Some(ConfigFileType::JSON), + "parquet" => Some(ConfigFileType::PARQUET), + _ => None, + } +} + async fn create_plan( ctx: &mut dyn CliSessionContext, statement: Statement, @@ -302,7 +310,7 @@ async fn create_plan( // will raise Configuration errors. if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { // To support custom formats, treat error as None - let format = FileType::from_str(&cmd.file_type).ok(); + let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( ctx, &cmd.location, @@ -313,13 +321,13 @@ async fn create_plan( } if let LogicalPlan::Copy(copy_to) = &mut plan { - let format: FileType = (©_to.format_options).into(); + let format = config_file_type_from_str(©_to.file_type.get_ext()); register_object_store_and_config_extensions( ctx, ©_to.output_url, ©_to.options, - Some(format), + format, ) .await?; } @@ -357,7 +365,7 @@ pub(crate) async fn register_object_store_and_config_extensions( ctx: &dyn CliSessionContext, location: &String, options: &HashMap, - format: Option, + format: Option, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -374,7 +382,7 @@ pub(crate) async fn register_object_store_and_config_extensions( // Clone and modify the default table options based on the provided options let mut table_options = ctx.session_state().default_table_options().clone(); if let Some(format) = format { - table_options.set_file_format(format); + table_options.set_config_format(format); } table_options.alter_with_string_hash_map(options)?; @@ -392,7 +400,6 @@ pub(crate) async fn register_object_store_and_config_extensions( mod tests { use super::*; - use datafusion::common::config::FormatOptions; use datafusion::common::plan_err; use datafusion::prelude::SessionContext; @@ -403,7 +410,7 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { - let format = FileType::from_str(&cmd.file_type).ok(); + let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( &ctx, &cmd.location, @@ -429,12 +436,12 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Copy(cmd) = &plan { - let format: FileType = (&cmd.format_options).into(); + let format = config_file_type_from_str(&cmd.file_type.get_ext()); register_object_store_and_config_extensions( &ctx, &cmd.output_url, &cmd.options, - Some(format), + format, ) .await?; } else { @@ -484,7 +491,7 @@ mod tests { let mut plan = create_plan(&mut ctx, 
statement).await?; if let LogicalPlan::Copy(copy_to) = &mut plan { assert_eq!(copy_to.output_url, location); - assert!(matches!(copy_to.format_options, FormatOptions::PARQUET(_))); + assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); ctx.runtime_env() .object_store_registry .get_store(&Url::parse(©_to.output_url).unwrap())?; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index f469fda4f960..6266ae6f561a 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -180,7 +180,7 @@ async fn main_inner() -> Result<()> { ctx.refresh_catalogs().await?; // install dynamic catalog provider that knows how to open files ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new( - ctx.state().catalog_list(), + ctx.state().catalog_list().clone(), ctx.state_weak_ref(), ))); // register `parquet_metadata` table function to get metadata from parquet files diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index c96aa7ae3951..52e3a5525717 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -64,6 +64,7 @@ datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-optimizer = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true, default-features = true } +datafusion-proto = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 6150c551c900..52702361e623 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -19,7 +19,8 @@ # DataFusion Examples -This crate includes several examples of how to use various DataFusion APIs and help you on your way. +This crate includes end to end, highly commented examples of how to use +various DataFusion APIs to help you get started. ## Prerequisites: @@ -27,7 +28,7 @@ Run `git submodule update --init` to init test files. 
 ## Running Examples
 
-To run the examples, use the `cargo run` command, such as:
+To run an example, use the `cargo run` command, such as:
 
 ```bash
 git clone https://github.com/apache/datafusion
@@ -45,8 +46,10 @@ cargo run --example csv_sql
 - [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF)
 - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF)
 - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF)
+- [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files
 - [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file
 - [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog
+- [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization
 - [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file
 - [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file
 - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider)
@@ -61,10 +64,12 @@ cargo run --example csv_sql
 - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros
 - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function
 - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es
-- ['parquet_index.rs'](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries
+- [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates
+- [`parquet_index.rs`](examples/parquet_index.rs): Create a secondary index over several parquet files and use it to speed up queries
 - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file
 - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files
-- ['parquet_exec_visitor.rs'](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution
+- [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution
+- [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into Datafusion `Expr`.
- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from Datafusion `Expr` and `LogicalPlan` - [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs new file mode 100644 index 000000000000..9bf71e52c3de --- /dev/null +++ b/datafusion-examples/examples/advanced_parquet_index.rs @@ -0,0 +1,664 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use bytes::Bytes; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::parquet::{ + ParquetAccessPlan, ParquetExecBuilder, +}; +use datafusion::datasource::physical_plan::{ + parquet::ParquetFileReaderFactory, FileMeta, FileScanConfig, +}; +use datafusion::datasource::TableProvider; +use datafusion::execution::context::SessionState; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::arrow_reader::{ + ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, RowSelector, +}; +use datafusion::parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::metadata::ParquetMetaData; +use datafusion::parquet::file::properties::{EnabledStatistics, WriterProperties}; +use datafusion::parquet::schema::types::ColumnPath; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_optimizer::pruning::PruningPredicate; +use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::*; +use datafusion_common::{ + internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::utils::conjunction; +use datafusion_expr::{TableProviderFilterPushDown, TableType}; +use datafusion_physical_expr::utils::{Guarantee, LiteralGuarantee}; +use futures::future::BoxFuture; +use futures::FutureExt; +use object_store::ObjectStore; +use std::any::Any; +use std::collections::{HashMap, HashSet}; +use std::fs::File; +use std::ops::Range; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use tempfile::TempDir; +use url::Url; + +/// This example demonstrates using low level DataFusion APIs to read only +/// certain row groups and ranges from parquet files, based on external +/// information. 
+/// +/// Using these APIs, you can instruct DataFusion's parquet reader to skip +/// ("prune") portions of files that do not contain relevant data. These APIs +/// can be useful for doing low latency queries over a large number of Parquet +/// files on remote storage (e.g. S3) where the cost of reading the metadata for +/// each file is high (e.g. because it requires a network round trip to the +/// storage service). +/// +/// Depending on the information from the index, DataFusion can make a request +/// to the storage service (e.g. S3) to read only the necessary data. +/// +/// Note that this example uses a hard coded index implementation. For a more +/// realistic example of creating an index to prune files, see the +/// `parquet_index.rs` example. +/// +/// Specifically, this example illustrates how to: +/// 1. Use [`ParquetFileReaderFactory`] to avoid re-reading parquet metadata on each query +/// 2. Use [`PruningPredicate`] for predicate analysis +/// 3. Pass a row group selection to [`ParuetExec`] +/// 4. Pass a row selection (within a row group) to [`ParquetExec`] +/// +/// Note this is a *VERY* low level example for people who want to build their +/// own custom indexes (e.g. for low latency queries). Most users should use +/// higher level APIs for reading parquet files: +/// [`SessionContext::read_parquet`] or [`ListingTable`], which also do file +/// pruning based on parquet statistics (using the same underlying APIs) +/// +/// # Diagram +/// +/// This diagram shows how the `ParquetExec` is configured to do only a single +/// (range) read from a parquet file, for the data that is needed. It does +/// not read the file footer or any of the row groups that are not needed. +/// +/// ```text +/// ┌───────────────────────┐ The TableProvider configures the +/// │ ┌───────────────────┐ │ ParquetExec: +/// │ │ │ │ +/// │ └───────────────────┘ │ +/// │ ┌───────────────────┐ │ +/// Row │ │ │ │ 1. To read only specific Row +/// Groups │ └───────────────────┘ │ Groups (the ParquetExec tries +/// │ ┌───────────────────┐ │ to reduce this further based +/// │ │ │ │ on metadata) +/// │ └───────────────────┘ │ ┌────────────────────┐ +/// │ ┌───────────────────┐ │ │ │ +/// │ │ │◀┼ ─ ─ ┐ │ ParquetExec │ +/// │ └───────────────────┘ │ │ (Parquet Reader) │ +/// │ ... │ └ ─ ─ ─ ─│ │ +/// │ ┌───────────────────┐ │ │ ╔═══════════════╗ │ +/// │ │ │ │ │ ║ParquetMetadata║ │ +/// │ └───────────────────┘ │ │ ╚═══════════════╝ │ +/// │ ╔═══════════════════╗ │ └────────────────────┘ +/// │ ║ Thrift metadata ║ │ +/// │ ╚═══════════════════╝ │ 1. With cached ParquetMetadata, so +/// └───────────────────────┘ the ParquetExec does not re-read / +/// Parquet File decode the thrift footer +/// +/// ``` +/// +/// Within a Row Group, Column Chunks store data in DataPages. This example also +/// shows how to configure the ParquetExec to read a `RowSelection` (row ranges) +/// which will skip unneeded data pages. This requires that the Parquet file has +/// a [Page Index]. +/// +/// ```text +/// ┌───────────────────────┐ If the RowSelection does not include any +/// │ ... │ rows from a particular Data Page, that +/// │ │ Data Page is not fetched or decoded. +/// │ ┌───────────────────┐ │ Note this requires a PageIndex +/// │ │ ┌──────────┐ │ │ +/// Row │ │ │DataPage 0│ │ │ ┌────────────────────┐ +/// Groups │ │ └──────────┘ │ │ │ │ +/// │ │ ┌──────────┐ │ │ │ ParquetExec │ +/// │ │ ... 
│DataPage 1│ ◀┼ ┼ ─ ─ ─ │ (Parquet Reader) │ +/// │ │ └──────────┘ │ │ └ ─ ─ ─ ─ ─│ │ +/// │ │ ┌──────────┐ │ │ │ ╔═══════════════╗ │ +/// │ │ │DataPage 2│ │ │ If only rows │ ║ParquetMetadata║ │ +/// │ │ └──────────┘ │ │ from DataPage 1 │ ╚═══════════════╝ │ +/// │ └───────────────────┘ │ are selected, └────────────────────┘ +/// │ │ only DataPage 1 +/// │ ... │ is fetched and +/// │ │ decoded +/// │ ╔═══════════════════╗ │ +/// │ ║ Thrift metadata ║ │ +/// │ ╚═══════════════════╝ │ +/// └───────────────────────┘ +/// Parquet File +/// ``` +/// +/// [`ListingTable`]: datafusion::datasource::listing::ListingTable +/// [Page Index](https://github.com/apache/parquet-format/blob/master/PageIndex.md) +#[tokio::main] +async fn main() -> Result<()> { + // the object store is used to read the parquet files (in this case, it is + // a local file system, but in a real system it could be S3, GCS, etc) + let object_store: Arc = + Arc::new(object_store::local::LocalFileSystem::new()); + + // Create a custom table provider with our special index. + let provider = Arc::new(IndexTableProvider::try_new(Arc::clone(&object_store))?); + + // SessionContext for running queries that has the table provider + // registered as "index_table" + let ctx = SessionContext::new(); + ctx.register_table("index_table", Arc::clone(&provider) as _)?; + + // register object store provider for urls like `file://` work + let url = Url::try_from("file://").unwrap(); + ctx.register_object_store(&url, object_store); + + // Select data from the table without any predicates (and thus no pruning) + println!("** Select data, no predicates:"); + ctx.sql("SELECT avg(id), max(text) FROM index_table") + .await? + .show() + .await?; + // the underlying parquet reader makes 10 IO requests, one for each row group + + // Now, run a query that has a predicate that our index can handle + // + // For this query, the access plan specifies skipping 8 row groups + // and scanning 2 of them. The skipped row groups are not read at all: + // + // [Skip, Skip, Scan, Skip, Skip, Skip, Skip, Scan, Skip, Skip] + // + // Note that the parquet reader makes 2 IO requests - one for the data from + // each row group. + println!("** Select data, predicate `id IN (250, 750)`"); + ctx.sql("SELECT text FROM index_table WHERE id IN (250, 750)") + .await? + .show() + .await?; + + // Finally, demonstrate scanning sub ranges within the row groups. + // Parquet's minimum decode unit is a page, so specifying ranges + // within a row group can be used to skip pages within a row group. + // + // For this query, the access plan specifies skipping all but the last row + // group and within the last row group, reading only the row with id 950 + // + // [Skip, Skip, Skip, Skip, Skip, Skip, Skip, Skip, Skip, Selection(skip 49, select 1, skip 50)] + // + // Note that the parquet reader makes a single IO request - for the data + // pages that must be decoded + // + // Note: in order to prune pages, the Page Index must be loaded and the + // ParquetExec will load it on demand if not present. To avoid a second IO + // during query, this example loaded the Page Index pre-emptively by setting + // `ArrowReader::with_page_index` in `IndexedFile::try_new` + provider.set_use_row_selection(true); + println!("** Select data, predicate `id = 950`"); + ctx.sql("SELECT text FROM index_table WHERE id = 950") + .await? + .show() + .await?; + + Ok(()) +} + +/// DataFusion `TableProvider` that uses knowledge of how data is distributed in +/// a file to prune row groups and rows from the file. 
+///
+/// `file1.parquet` contains values `0..1000`
+#[derive(Debug)]
+pub struct IndexTableProvider {
+    /// Where the file is stored (cleanup on drop)
+    #[allow(dead_code)]
+    tmpdir: TempDir,
+    /// The file that is being read.
+    indexed_file: IndexedFile,
+    /// The underlying object store
+    object_store: Arc<dyn ObjectStore>,
+    /// if true, use row selections in addition to row group selections
+    use_row_selections: AtomicBool,
+}
+impl IndexTableProvider {
+    /// Create a new IndexTableProvider
+    /// * `object_store` - the object store implementation to use for reading files
+    pub fn try_new(object_store: Arc<dyn ObjectStore>) -> Result<Self> {
+        let tmpdir = TempDir::new().expect("Can't make temporary directory");
+
+        let indexed_file =
+            IndexedFile::try_new(tmpdir.path().join("indexed_file.parquet"), 0..1000)?;
+
+        Ok(Self {
+            indexed_file,
+            tmpdir,
+            object_store,
+            use_row_selections: AtomicBool::new(false),
+        })
+    }
+
+    /// set the value of use row selections
+    pub fn set_use_row_selection(&self, use_row_selections: bool) {
+        self.use_row_selections
+            .store(use_row_selections, Ordering::SeqCst);
+    }
+
+    /// return the value of use row selections
+    pub fn use_row_selections(&self) -> bool {
+        self.use_row_selections.load(Ordering::SeqCst)
+    }
+
+    /// convert filters like `a = 1`, `b = 2`
+    /// to a single predicate like `a = 1 AND b = 2` suitable for execution
+    fn filters_to_predicate(
+        &self,
+        state: &SessionState,
+        filters: &[Expr],
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        let df_schema = DFSchema::try_from(self.schema())?;
+
+        let predicate = conjunction(filters.to_vec());
+        let predicate = predicate
+            .map(|predicate| state.create_physical_expr(predicate, &df_schema))
+            .transpose()?
+            // if there are no filters, use a literal true to have a predicate
+            // that always evaluates to true we can pass to the index
+            .unwrap_or_else(|| datafusion_physical_expr::expressions::lit(true));
+
+        Ok(predicate)
+    }
+
+    /// Returns a [`ParquetAccessPlan`] that specifies how to scan the
+    /// parquet file.
+    ///
+    /// A `ParquetAccessPlan` specifies which row groups and which rows within
+    /// those row groups to scan.
+    fn create_plan(
+        &self,
+        predicate: &Arc<dyn PhysicalExpr>,
+    ) -> Result<ParquetAccessPlan> {
+        // In this example, we use the PruningPredicate's literal guarantees to
+        // analyze the predicate. In a real system, using
+        // `PruningPredicate::prune` would likely be easier to do.
+        let pruning_predicate =
+            PruningPredicate::try_new(Arc::clone(predicate), self.schema().clone())?;
+
+        // The PruningPredicate's guarantees must all be satisfied in order for
+        // the predicate to possibly evaluate to true.
+        let guarantees = pruning_predicate.literal_guarantees();
+        let Some(constants) = self.value_constants(guarantees) else {
+            return Ok(self.indexed_file.scan_all_plan());
+        };
+
+        // Begin with a plan that skips all row groups.
+ let mut plan = self.indexed_file.scan_none_plan(); + + // determine which row groups have the values in the guarantees + for value in constants { + let ScalarValue::Int32(Some(val)) = value else { + // if we have unexpected type of constant, no pruning is possible + return Ok(self.indexed_file.scan_all_plan()); + }; + + // Since we know the values in the files are between 0..1000 and + // evenly distributed between in row groups, calculate in what row + // group this value appears and tell the parquet reader to read it + let val = *val as usize; + let num_rows_in_row_group = 1000 / plan.len(); + let row_group_index = val / num_rows_in_row_group; + plan.scan(row_group_index); + + // If we want to use row selections, which the parquet reader can + // use to skip data pages when the parquet file has a "page index" + // and the reader is configured to read it, add a row selection + if self.use_row_selections() { + let offset_in_row_group = val - row_group_index * num_rows_in_row_group; + let selection = RowSelection::from(vec![ + // skip rows before the desired row + RowSelector::skip(offset_in_row_group.saturating_sub(1)), + // select the actual row + RowSelector::select(1), + // skip any remaining rows in the group + RowSelector::skip(num_rows_in_row_group - offset_in_row_group), + ]); + + plan.scan_selection(row_group_index, selection); + } + } + + Ok(plan) + } + + /// Returns the set of constants that the `"id"` column must take in order + /// for the predicate to be true. + /// + /// If `None` is returned, we can't extract the necessary information from + /// the guarantees. + fn value_constants<'a>( + &self, + guarantees: &'a [LiteralGuarantee], + ) -> Option<&'a HashSet> { + // only handle a single guarantee for column in this example + if guarantees.len() != 1 { + return None; + } + let guarantee = guarantees.first()?; + + // Only handle IN guarantees for the "in" column + if guarantee.guarantee != Guarantee::In || guarantee.column.name() != "id" { + return None; + } + Some(&guarantee.literals) + } +} + +/// Stores information needed to scan a file +#[derive(Debug)] +struct IndexedFile { + /// File name + file_name: String, + /// The path of the file + path: PathBuf, + /// The size of the file + file_size: u64, + /// The pre-parsed parquet metadata for the file + metadata: Arc, + /// The arrow schema of the file + schema: SchemaRef, +} + +impl IndexedFile { + fn try_new(path: impl AsRef, value_range: Range) -> Result { + let path = path.as_ref(); + // write the actual file + make_demo_file(path, value_range)?; + + // Now, open the file and read its size and metadata + let file_name = path + .file_name() + .ok_or_else(|| internal_datafusion_err!("Invalid path"))? + .to_str() + .ok_or_else(|| internal_datafusion_err!("Invalid filename"))? 
+ .to_string(); + let file_size = path.metadata()?.len(); + + let file = File::open(path).map_err(|e| { + DataFusionError::from(e).context(format!("Error opening file {path:?}")) + })?; + + let options = ArrowReaderOptions::new() + // Load the page index when reading metadata to cache + // so it is available to interpret row selections + .with_page_index(true); + let reader = + ParquetRecordBatchReaderBuilder::try_new_with_options(file, options)?; + let metadata = reader.metadata().clone(); + let schema = reader.schema().clone(); + + // canonicalize after writing the file + let path = std::fs::canonicalize(path)?; + + Ok(Self { + file_name, + path, + file_size, + metadata, + schema, + }) + } + + /// Return a `PartitionedFile` to scan the underlying file + /// + /// The returned value does not have any `ParquetAccessPlan` specified in + /// its extensions. + fn partitioned_file(&self) -> PartitionedFile { + PartitionedFile::new(self.path.display().to_string(), self.file_size) + } + + /// Return a `ParquetAccessPlan` that scans all row groups in the file + fn scan_all_plan(&self) -> ParquetAccessPlan { + ParquetAccessPlan::new_all(self.metadata.num_row_groups()) + } + + /// Return a `ParquetAccessPlan` that scans no row groups in the file + fn scan_none_plan(&self) -> ParquetAccessPlan { + ParquetAccessPlan::new_none(self.metadata.num_row_groups()) + } +} + +/// Implement the TableProvider trait for IndexTableProvider +/// so that we can query it as a table. +#[async_trait] +impl TableProvider for IndexTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.indexed_file.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &SessionState, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let indexed_file = &self.indexed_file; + let predicate = self.filters_to_predicate(state, filters)?; + + // Figure out which row groups to scan based on the predicate + let access_plan = self.create_plan(&predicate)?; + println!("{access_plan:?}"); + + let partitioned_file = indexed_file + .partitioned_file() + // provide the starting access plan to the ParquetExec by + // storing it as "extensions" on PartitionedFile + .with_extensions(Arc::new(access_plan) as _); + + // Prepare for scanning + let schema = self.schema(); + let object_store_url = ObjectStoreUrl::parse("file://")?; + let file_scan_config = FileScanConfig::new(object_store_url, schema) + .with_limit(limit) + .with_projection(projection.cloned()) + .with_file(partitioned_file); + + // Configure a factory interface to avoid re-reading the metadata for each file + let reader_factory = + CachedParquetFileReaderFactory::new(Arc::clone(&self.object_store)) + .with_file(indexed_file); + + // Finally, put it all together into a ParquetExec + Ok(ParquetExecBuilder::new(file_scan_config) + // provide the predicate so the ParquetExec can try and prune + // row groups internally + .with_predicate(predicate) + // provide the factory to create parquet reader without re-reading metadata + .with_parquet_file_reader_factory(Arc::new(reader_factory)) + .build_arc()) + } + + /// Tell DataFusion to push filters down to the scan method + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + // Inexact because the pruning can't handle all expressions and pruning + // is not done at the row level -- there may be rows in returned files + // that do not pass the filter + 
Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } +} + +/// A custom [`ParquetFileReaderFactory`] that handles opening parquet files +/// from object storage, and uses pre-loaded metadata. + +#[derive(Debug)] +struct CachedParquetFileReaderFactory { + /// The underlying object store implementation for reading file data + object_store: Arc, + /// The parquet metadata for each file in the index, keyed by the file name + /// (e.g. `file1.parquet`) + metadata: HashMap>, +} + +impl CachedParquetFileReaderFactory { + fn new(object_store: Arc) -> Self { + Self { + object_store, + metadata: HashMap::new(), + } + } + /// Add the pre-parsed information about the file to the factor + fn with_file(mut self, indexed_file: &IndexedFile) -> Self { + self.metadata.insert( + indexed_file.file_name.clone(), + Arc::clone(&indexed_file.metadata), + ); + self + } +} + +impl ParquetFileReaderFactory for CachedParquetFileReaderFactory { + fn create_reader( + &self, + _partition_index: usize, + file_meta: FileMeta, + metadata_size_hint: Option, + _metrics: &ExecutionPlanMetricsSet, + ) -> Result> { + // for this example we ignore the partition index and metrics + // but in a real system you would likely use them to report details on + // the performance of the reader. + let filename = file_meta + .location() + .parts() + .last() + .expect("No path in location") + .as_ref() + .to_string(); + + let object_store = Arc::clone(&self.object_store); + let mut inner = ParquetObjectReader::new(object_store, file_meta.object_meta); + + if let Some(hint) = metadata_size_hint { + inner = inner.with_footer_size_hint(hint) + }; + + let metadata = self + .metadata + .get(&filename) + .expect("metadata for file not found: {filename}"); + Ok(Box::new(ParquetReaderWithCache { + filename, + metadata: Arc::clone(metadata), + inner, + })) + } +} + +/// wrapper around a ParquetObjectReader that caches metadata +struct ParquetReaderWithCache { + filename: String, + metadata: Arc, + inner: ParquetObjectReader, +} + +impl AsyncFileReader for ParquetReaderWithCache { + fn get_bytes( + &mut self, + range: Range, + ) -> BoxFuture<'_, datafusion::parquet::errors::Result> { + println!("get_bytes: {} Reading range {:?}", self.filename, range); + self.inner.get_bytes(range) + } + + fn get_byte_ranges( + &mut self, + ranges: Vec>, + ) -> BoxFuture<'_, datafusion::parquet::errors::Result>> { + println!( + "get_byte_ranges: {} Reading ranges {:?}", + self.filename, ranges + ); + self.inner.get_byte_ranges(ranges) + } + + fn get_metadata( + &mut self, + ) -> BoxFuture<'_, datafusion::parquet::errors::Result>> { + println!("get_metadata: {} returning cached metadata", self.filename); + + // return the cached metadata so the parquet reader does not read it + let metadata = self.metadata.clone(); + async move { Ok(metadata) }.boxed() + } +} + +/// Creates a new parquet file at the specified path. 
+/// +/// * id: Int32 +/// * text: Utf8 +/// +/// The `id` column increases sequentially from `min_value` to `max_value` +/// The `text` column is a repeating sequence of `TheTextValue{i}` +/// +/// Each row group has 100 rows +fn make_demo_file(path: impl AsRef, value_range: Range) -> Result<()> { + let path = path.as_ref(); + let file = File::create(path)?; + + let id = Int32Array::from_iter_values(value_range.clone()); + let text = + StringArray::from_iter_values(value_range.map(|i| format!("TheTextValue{i}"))); + + let batch = RecordBatch::try_from_iter(vec![ + ("id", Arc::new(id) as ArrayRef), + ("text", Arc::new(text) as ArrayRef), + ])?; + + let schema = batch.schema(); + + // enable page statistics for the tag column, + // for everything else. + let props = WriterProperties::builder() + .set_max_row_group_size(100) + // compute column chunk (per row group) statistics by default + .set_statistics_enabled(EnabledStatistics::Chunk) + // compute column page statistics for the tag column + .set_column_statistics_enabled(ColumnPath::from("tag"), EnabledStatistics::Page) + .build(); + + // write the actual values to the file + let mut writer = ArrowWriter::try_new(file, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + + Ok(()) +} diff --git a/datafusion-examples/examples/composed_extension_codec.rs b/datafusion-examples/examples/composed_extension_codec.rs new file mode 100644 index 000000000000..43c6daba211a --- /dev/null +++ b/datafusion-examples/examples/composed_extension_codec.rs @@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example demonstrates how to compose multiple PhysicalExtensionCodecs +//! +//! This can be helpful when an Execution plan tree has different nodes from different crates +//! that need to be serialized. +//! +//! For example if your plan has `ShuffleWriterExec` from `datafusion-ballista` and `DeltaScan` from `deltalake` +//! both crates both provide PhysicalExtensionCodec and this example shows how to combine them together +//! +//! ```text +//! ShuffleWriterExec +//! ProjectionExec +//! ... +//! DeltaScan +//! 
``` + +use datafusion::common::Result; +use datafusion::physical_plan::{DisplayAs, ExecutionPlan}; +use datafusion::prelude::SessionContext; +use datafusion_common::internal_err; +use datafusion_expr::registry::FunctionRegistry; +use datafusion_expr::ScalarUDF; +use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}; +use datafusion_proto::protobuf; +use std::any::Any; +use std::fmt::Debug; +use std::ops::Deref; +use std::sync::Arc; + +#[tokio::main] +async fn main() { + // build execution plan that has both types of nodes + // + // Note each node requires a different `PhysicalExtensionCodec` to decode + let exec_plan = Arc::new(ParentExec { + input: Arc::new(ChildExec {}), + }); + let ctx = SessionContext::new(); + + let composed_codec = ComposedPhysicalExtensionCodec { + codecs: vec![ + Arc::new(ParentPhysicalExtensionCodec {}), + Arc::new(ChildPhysicalExtensionCodec {}), + ], + }; + + // serialize execution plan to proto + let proto: protobuf::PhysicalPlanNode = + protobuf::PhysicalPlanNode::try_from_physical_plan( + exec_plan.clone(), + &composed_codec, + ) + .expect("to proto"); + + // deserialize proto back to execution plan + let runtime = ctx.runtime_env(); + let result_exec_plan: Arc = proto + .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec) + .expect("from proto"); + + // assert that the original and deserialized execution plans are equal + assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); +} + +/// This example has two types of nodes: `ParentExec` and `ChildExec` which can only +/// be serialized with different `PhysicalExtensionCodec`s +#[derive(Debug)] +struct ParentExec { + input: Arc, +} + +impl DisplayAs for ParentExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "ParentExec") + } +} + +impl ExecutionPlan for ParentExec { + fn name(&self) -> &str { + "ParentExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + unreachable!() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + unreachable!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!() + } +} + +/// A PhysicalExtensionCodec that can serialize and deserialize ParentExec +#[derive(Debug)] +struct ParentPhysicalExtensionCodec; + +impl PhysicalExtensionCodec for ParentPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + if buf == "ParentExec".as_bytes() { + Ok(Arc::new(ParentExec { + input: inputs[0].clone(), + })) + } else { + internal_err!("Not supported") + } + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + if node.as_any().downcast_ref::().is_some() { + buf.extend_from_slice("ParentExec".as_bytes()); + Ok(()) + } else { + internal_err!("Not supported") + } + } +} + +#[derive(Debug)] +struct ChildExec {} + +impl DisplayAs for ChildExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "ChildExec") + } +} + +impl ExecutionPlan for ChildExec { + fn name(&self) -> &str { + "ChildExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + unreachable!() + } + + fn children(&self) -> 
Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + unreachable!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!() + } +} + +/// A PhysicalExtensionCodec that can serialize and deserialize ChildExec +#[derive(Debug)] +struct ChildPhysicalExtensionCodec; + +impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + if buf == "ChildExec".as_bytes() { + Ok(Arc::new(ChildExec {})) + } else { + internal_err!("Not supported") + } + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + if node.as_any().downcast_ref::().is_some() { + buf.extend_from_slice("ChildExec".as_bytes()); + Ok(()) + } else { + internal_err!("Not supported") + } + } +} + +/// A PhysicalExtensionCodec that tries one of multiple inner codecs +/// until one works +#[derive(Debug)] +struct ComposedPhysicalExtensionCodec { + codecs: Vec>, +} + +impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + registry: &dyn FunctionRegistry, + ) -> Result> { + let mut last_err = None; + for codec in &self.codecs { + match codec.try_decode(buf, inputs, registry) { + Ok(plan) => return Ok(plan), + Err(e) => last_err = Some(e), + } + } + Err(last_err.unwrap()) + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + let mut last_err = None; + for codec in &self.codecs { + match codec.try_encode(node.clone(), buf) { + Ok(_) => return Ok(()), + Err(e) => last_err = Some(e), + } + } + Err(last_err.unwrap()) + } + + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { + let mut last_err = None; + for codec in &self.codecs { + match codec.try_decode_udf(name, _buf) { + Ok(plan) => return Ok(plan), + Err(e) => last_err = Some(e), + } + } + Err(last_err.unwrap()) + } + + fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { + let mut last_err = None; + for codec in &self.codecs { + match codec.try_encode_udf(_node, _buf) { + Ok(_) => return Ok(()), + Err(e) => last_err = Some(e), + } + } + Err(last_err.unwrap()) + } +} diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index 9fb61008b9f6..e798751b3353 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -19,6 +19,7 @@ use arrow_schema::DataType; use std::sync::Arc; use datafusion::error::Result; +use datafusion::functions_aggregate::average::avg; use datafusion::prelude::*; use datafusion::test_util::arrow_test_data; use datafusion_common::ScalarValue; diff --git a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs index 4d71ed758912..e75ba5dd5328 100644 --- a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs +++ b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs @@ -20,10 +20,10 @@ use std::sync::Arc; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; -use datafusion_common::{FileType, GetExt}; use object_store::aws::AmazonS3Builder; use url::Url; @@ -54,7 +54,7 @@ 
async fn main() -> Result<()> { let path = format!("s3://{bucket_name}/test_data/"); let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); ctx.register_listing_table("test", &path, listing_options, None, None) .await?; @@ -79,7 +79,7 @@ async fn main() -> Result<()> { let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); ctx.register_listing_table("test2", &out_path, listing_options, None, None) .await?; diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs new file mode 100644 index 000000000000..057852946341 --- /dev/null +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow_schema::DataType; +use datafusion::prelude::SessionContext; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ + BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, +}; +use datafusion_optimizer::optimizer::ApplyOrder; +use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; +use std::any::Any; +use std::sync::Arc; + +/// This example demonstrates how to add your own [`OptimizerRule`] +/// to DataFusion. +/// +/// [`OptimizerRule`]s transform [`LogicalPlan`]s into an equivalent (but +/// hopefully faster) form. +/// +/// See [analyzer_rule.rs] for an example of AnalyzerRules, which are for +/// changing plan semantics. +#[tokio::main] +pub async fn main() -> Result<()> { + // DataFusion includes many built in OptimizerRules for tasks such as outer + // to inner join conversion and constant folding. 
+ // + // Note you can change the order of optimizer rules using the lower level + // `SessionState` API + let ctx = SessionContext::new(); + ctx.add_optimizer_rule(Arc::new(MyOptimizerRule {})); + + // Now, let's plan and run queries with the new rule + ctx.register_batch("person", person_batch())?; + let sql = "SELECT * FROM person WHERE age = 22"; + let plan = ctx.sql(sql).await?.into_optimized_plan()?; + + // We can see the effect of our rewrite in the output plan: the filter + // has been rewritten to `my_eq` + // + // Filter: my_eq(person.age, Int32(22)) + // TableScan: person projection=[name, age] + println!("Logical Plan:\n\n{}\n", plan.display_indent()); + + // The query below doesn't respect the filter `where age = 22` because + // the plan has been rewritten to use a UDF which always returns true + // + // And the output verifies the predicates have been changed (as the my_eq + // function always returns true) + // + // +--------+-----+ + // | name | age | + // +--------+-----+ + // | Andy | 11 | + // | Andrew | 22 | + // | Oleks | 33 | + // +--------+-----+ + ctx.sql(sql).await?.show().await?; + + // However, we can see the rule doesn't trigger for queries with predicates + // other than `=` + // + // +-------+-----+ + // | name | age | + // +-------+-----+ + // | Andy | 11 | + // | Oleks | 33 | + // +-------+-----+ + ctx.sql("SELECT * FROM person WHERE age <> 22") + .await? + .show() + .await?; + + Ok(()) +} + +/// An example OptimizerRule that replaces all `col = <const>` predicates with a +/// user defined function +struct MyOptimizerRule {} + +impl OptimizerRule for MyOptimizerRule { + fn name(&self) -> &str { + "my_optimizer_rule" + } + + // New OptimizerRules should use the "rewrite" API as it is more efficient + fn supports_rewrite(&self) -> bool { + true + } + + /// Ask the optimizer to handle the plan recursion. `rewrite` will be called + /// on each plan node. + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::BottomUp) + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result<Transformed<LogicalPlan>> { + plan.map_expressions(|expr| { + // This closure is called for all expressions in the current plan + // + // For example, given a plan like `SELECT a + b, 5 + 10` + // + // The closure would be called twice: + // 1. once for `a + b` + // 2.
once for `5 + 10` + self.rewrite_expr(expr) + }) + } +} + +impl MyOptimizerRule { + /// Rewrites an Expr replacing all `<col> = <const>` expressions with + /// a call to my_eq udf + fn rewrite_expr(&self, expr: Expr) -> Result<Transformed<Expr>> { + // do a bottom up rewrite of the expression tree + expr.transform_up(|expr| { + // Closure called for each sub tree + match expr { + Expr::BinaryExpr(binary_expr) if is_binary_eq(&binary_expr) => { + // destructure the expression + let BinaryExpr { left, op: _, right } = binary_expr; + // rewrite to `my_eq(left, right)` + let udf = ScalarUDF::new_from_impl(MyEq::new()); + let call = udf.call(vec![*left, *right]); + Ok(Transformed::yes(call)) + } + _ => Ok(Transformed::no(expr)), + } + }) + // Note that the TreeNode API handles propagating the transformed flag + // and errors up the call chain + } +} + +/// Return true if the expression is an equality expression for a literal or +/// column reference +fn is_binary_eq(binary_expr: &BinaryExpr) -> bool { + binary_expr.op == Operator::Eq + && is_lit_or_col(binary_expr.left.as_ref()) + && is_lit_or_col(binary_expr.right.as_ref()) +} + +/// Return true if the expression is a literal or column reference +fn is_lit_or_col(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +/// A simple user defined filter function +#[derive(Debug, Clone)] +struct MyEq { + signature: Signature, +} + +impl MyEq { + fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Stable), + } + } +} + +impl ScalarUDFImpl for MyEq { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_eq" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + Ok(DataType::Boolean) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { + // this example simply returns "true" which is not what a real + // implementation would do. + Ok(ColumnarValue::Scalar(ScalarValue::from(true))) + } +} + +/// Return a RecordBatch with made up data +fn person_batch() -> RecordBatch { + let name: ArrayRef = + Arc::new(StringArray::from_iter_values(["Andy", "Andrew", "Oleks"])); + let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33])); + RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap() +} diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs new file mode 100644 index 000000000000..6444eb68b6b2 --- /dev/null +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::{ + assert_batches_eq, + error::Result, + prelude::{ParquetReadOptions, SessionContext}, +}; +use datafusion_common::DFSchema; +use datafusion_expr::{col, lit}; +use datafusion_sql::unparser::Unparser; + +/// This example demonstrates the programmatic parsing of SQL expressions using +/// the DataFusion [`SessionContext::parse_sql_expr`] API or the [`DataFrame::parse_sql_expr`] API. +/// +/// +/// The code in this example shows how to: +/// +/// 1. [`simple_session_context_parse_sql_expr_demo`]: Parse a simple SQL text into a logical +/// expression using a schema at [`SessionContext`]. +/// +/// 2. [`simple_dataframe_parse_sql_expr_demo`]: Parse a simple SQL text into a logical expression +/// using a schema at [`DataFrame`]. +/// +/// 3. [`query_parquet_demo`]: Query a parquet file using the parsed_sql_expr from a DataFrame. +/// +/// 4. [`round_trip_parse_sql_expr_demo`]: Parse a SQL text and convert it back to SQL using [`Unparser`]. + +#[tokio::main] +async fn main() -> Result<()> { + // See how to evaluate expressions + simple_session_context_parse_sql_expr_demo()?; + simple_dataframe_parse_sql_expr_demo().await?; + query_parquet_demo().await?; + round_trip_parse_sql_expr_demo().await?; + Ok(()) +} + +/// DataFusion can parse a SQL text to a logical expression against a schema at [`SessionContext`]. +fn simple_session_context_parse_sql_expr_demo() -> Result<()> { + let sql = "a < 5 OR a = 8"; + let expr = col("a").lt(lit(5_i64)).or(col("a").eq(lit(8_i64))); + + // provide type information that `a` is an Int32 + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + let ctx = SessionContext::new(); + + let parsed_expr = ctx.parse_sql_expr(sql, &df_schema)?; + + assert_eq!(parsed_expr, expr); + + Ok(()) +} + +/// DataFusion can parse a SQL text into a logical expression using the schema of a [`DataFrame`]. +async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> { + let sql = "int_col < 5 OR double_col = 8.0"; + let expr = col("int_col") + .lt(lit(5_i64)) + .or(col("double_col").eq(lit(8.0_f64))); + + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let parsed_expr = df.parse_sql_expr(sql)?; + + assert_eq!(parsed_expr, expr); + + Ok(()) +} + +async fn query_parquet_demo() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let df = df + .clone() + .select(vec![ + df.parse_sql_expr("int_col")?, + df.parse_sql_expr("double_col")?, + ])? + .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .aggregate( + vec![df.parse_sql_expr("double_col")?], + vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + )? + // Directly parsing the SQL text into a sort expression is not supported yet, so + // construct it programmatically + .sort(vec![col("double_col").sort(false, false)])?
+ .limit(0, Some(1))?; + + let result = df.collect().await?; + + assert_batches_eq!( + &[ + "+------------+----------------------+", + "| double_col | sum(?table?.int_col) |", + "+------------+----------------------+", + "| 10.1 | 4 |", + "+------------+----------------------+", + ], + &result + ); + + Ok(()) +} + +/// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. +async fn round_trip_parse_sql_expr_demo() -> Result<()> { + let sql = "((int_col < 5) OR (double_col = 8))"; + + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let parsed_expr = df.parse_sql_expr(sql)?; + + let unparser = Unparser::default(); + let round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); + + assert_eq!(sql, round_trip_sql); + + Ok(()) +} diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs index 3fa35049a8da..c090cd2bcca9 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/pruning.rs @@ -33,6 +33,11 @@ use std::sync::Arc; /// quickly eliminate entire files / partitions / row groups of data from /// consideration using statistical information from a catalog or other /// metadata. +/// +/// This example uses a user defined catalog to supply pruning information, as +/// one might do as part of a higher level storage engine. See +/// `parquet_index.rs` for an example that uses pruning in the context of an +/// individual query. #[tokio::main] async fn main() { // In this example, we'll use the PruningPredicate to determine if diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 556687a46ab4..06286d5d66ed 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -131,14 +131,6 @@ impl OptimizerRule for MyOptimizerRule { "my_optimizer_rule" } - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - unreachable!() - } - fn apply_order(&self) -> Option { Some(ApplyOrder::BottomUp) } diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index d2c8c6a86c7c..aedc511c62fe 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -15,21 +15,21 @@ // specific language governing permissions and limitations // under the License. 
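The `round_trip_parse_sql_expr_demo` above leans on [`Unparser`] to turn a parsed expression back into SQL text. The same API also accepts expressions built programmatically; a minimal sketch (the exact text produced, including parenthesization, is version dependent and not guaranteed):

```rust
use datafusion_common::Result;
use datafusion_expr::{col, lit};
use datafusion_sql::unparser::Unparser;

fn expr_to_sql_text() -> Result<String> {
    // build `int_col < 5 OR double_col = 8` without going through the SQL parser
    let expr = col("int_col")
        .lt(lit(5_i64))
        .or(col("double_col").eq(lit(8_i64)));

    // render the expression back to SQL text, as the round trip demo does
    Ok(Unparser::default().expr_to_sql(&expr)?.to_string())
}
```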
-use arrow_schema::{Field, Schema}; -use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs}; -use datafusion_expr::simplify::SimplifyInfo; - use std::{any::Any, sync::Arc}; +use arrow_schema::{Field, Schema}; + use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch}; use datafusion::error::Result; +use datafusion::functions_aggregate::average::avg_udaf; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; use datafusion::{assert_batches_eq, prelude::*}; use datafusion_common::cast::as_float64_array; +use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs}; +use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - expr::{AggregateFunction, AggregateFunctionDefinition}, - function::AccumulatorArgs, - Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, + expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF, + AggregateUDFImpl, GroupsAccumulator, Signature, }; /// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user @@ -92,18 +92,16 @@ impl AggregateUDFImpl for BetterAvgUdaf { // with build-in aggregate function to illustrate the use let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction, _: &dyn SimplifyInfo| { - Ok(Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - // yes it is the same Avg, `BetterAvgUdaf` was just a - // marketing pitch :) - datafusion_expr::aggregate_function::AggregateFunction::Avg, - ), - args: aggregate_function.args, - distinct: aggregate_function.distinct, - filter: aggregate_function.filter, - order_by: aggregate_function.order_by, - null_treatment: aggregate_function.null_treatment, - })) + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + avg_udaf(), + // yes it is the same Avg, `BetterAvgUdaf` was just a + // marketing pitch :) + aggregate_function.args, + aggregate_function.distinct, + aggregate_function.filter, + aggregate_function.order_by, + aggregate_function.null_treatment, + ))) }; Some(Box::new(simplify)) diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs index 4e8d03c38e00..a17e45dba2a3 100644 --- a/datafusion-examples/examples/simplify_udwf_expression.rs +++ b/datafusion-examples/examples/simplify_udwf_expression.rs @@ -18,12 +18,14 @@ use std::any::Any; use arrow_schema::DataType; + use datafusion::execution::context::SessionContext; +use datafusion::functions_aggregate::average::avg_udaf; use datafusion::{error::Result, execution::options::CsvReadOptions}; use datafusion_expr::function::WindowFunctionSimplification; use datafusion_expr::{ - expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr, - PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, + expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature, + Volatility, WindowUDF, WindowUDFImpl, }; /// This UDWF will show how to use the WindowUDFImpl::simplify() API @@ -71,9 +73,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { let simplify = |window_function: datafusion_expr::expr::WindowFunction, _: &dyn SimplifyInfo| { Ok(Expr::WindowFunction(WindowFunction { - fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction( - AggregateFunction::Avg, - ), + fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()), args: 
window_function.args, partition_by: window_function.partition_by, order_by: window_function.order_by, diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 3e2bc0ad7c3a..e36a4f890644 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -127,6 +127,13 @@ impl Column { }) } + /// return the column's name. + /// + /// Note: This ignores the relation and returns the column name only. + pub fn name(&self) -> &str { + &self.name + } + /// Serialize column into a flat name string pub fn flat_name(&self) -> String { match &self.relation { diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 1c431d04cd35..7e3871e6b795 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -24,7 +24,7 @@ use std::str::FromStr; use crate::error::_config_err; use crate::parsers::CompressionTypeVariant; -use crate::{DataFusionError, FileType, Result}; +use crate::{DataFusionError, Result}; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -204,6 +204,11 @@ config_namespace! { /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. pub dialect: String, default = "generic".to_string() + /// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but + /// ignore the length. If false, error if a `VARCHAR` with a length is + /// specified. The Arrow type system does not have a notion of maximum + /// string length and thus DataFusion can not enforce such limits. + pub support_varchar_with_length: bool, default = true } } @@ -303,6 +308,7 @@ config_namespace! { /// statistics into the same file groups. /// Currently experimental pub split_file_groups_by_statistics: bool, default = false + } } @@ -1116,6 +1122,16 @@ macro_rules! extensions_options { } } +/// These file types have special built in behavior for configuration. +/// Use TableOptions::Extensions for configuring other file types. +#[derive(Debug, Clone)] +pub enum ConfigFileType { + CSV, + #[cfg(feature = "parquet")] + PARQUET, + JSON, +} + /// Represents the configuration options available for handling different table formats within a data processing application. /// This struct encompasses options for various file formats including CSV, Parquet, and JSON, allowing for flexible configuration /// of parsing and writing behaviors specific to each format. Additionally, it supports extending functionality through custom extensions. @@ -1134,7 +1150,7 @@ pub struct TableOptions { /// The current file format that the table operations should assume. This option allows /// for dynamic switching between the supported file types (e.g., CSV, Parquet, JSON). - pub current_format: Option, + pub current_format: Option, /// Optional extensions that can be used to extend or customize the behavior of the table /// options. 
Extensions can be registered using `Extensions::insert` and might include @@ -1152,10 +1168,9 @@ impl ConfigField for TableOptions { if let Some(file_type) = &self.current_format { match file_type { #[cfg(feature = "parquet")] - FileType::PARQUET => self.parquet.visit(v, "format", ""), - FileType::CSV => self.csv.visit(v, "format", ""), - FileType::JSON => self.json.visit(v, "format", ""), - _ => {} + ConfigFileType::PARQUET => self.parquet.visit(v, "format", ""), + ConfigFileType::CSV => self.csv.visit(v, "format", ""), + ConfigFileType::JSON => self.json.visit(v, "format", ""), } } else { self.csv.visit(v, "csv", ""); @@ -1188,12 +1203,9 @@ impl ConfigField for TableOptions { match key { "format" => match format { #[cfg(feature = "parquet")] - FileType::PARQUET => self.parquet.set(rem, value), - FileType::CSV => self.csv.set(rem, value), - FileType::JSON => self.json.set(rem, value), - _ => { - _config_err!("Config value \"{key}\" is not supported on {}", format) - } + ConfigFileType::PARQUET => self.parquet.set(rem, value), + ConfigFileType::CSV => self.csv.set(rem, value), + ConfigFileType::JSON => self.json.set(rem, value), }, _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } @@ -1210,15 +1222,6 @@ impl TableOptions { Self::default() } - /// Sets the file format for the table. - /// - /// # Parameters - /// - /// * `format`: The file format to use (e.g., CSV, Parquet). - pub fn set_file_format(&mut self, format: FileType) { - self.current_format = Some(format); - } - /// Creates a new `TableOptions` instance initialized with settings from a given session config. /// /// # Parameters @@ -1249,6 +1252,15 @@ impl TableOptions { clone } + /// Sets the file format for the table. + /// + /// # Parameters + /// + /// * `format`: The file format to use (e.g., CSV, Parquet). + pub fn set_config_format(&mut self, format: ConfigFileType) { + self.current_format = Some(format); + } + /// Sets the extensions for this `TableOptions` instance. /// /// # Parameters @@ -1393,6 +1405,13 @@ pub struct TableParquetOptions { pub key_value_metadata: HashMap>, } +impl TableParquetOptions { + /// Return new default TableParquetOptions + pub fn new() -> Self { + Self::default() + } +} + impl ConfigField for TableParquetOptions { fn visit(&self, v: &mut V, key_prefix: &str, description: &'static str) { self.global.visit(v, key_prefix, description); @@ -1559,6 +1578,7 @@ config_namespace! { pub delimiter: u8, default = b',' pub quote: u8, default = b'"' pub escape: Option, default = None + pub double_quote: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED pub schema_infer_max_rec: usize, default = 100 pub date_format: Option, default = None @@ -1624,6 +1644,13 @@ impl CsvOptions { self } + /// Set true to indicate that the CSV quotes should be doubled. + /// - default to true + pub fn with_double_quote(mut self, double_quote: bool) -> Self { + self.double_quote = Some(double_quote); + self + } + /// Set a `CompressionTypeVariant` of CSV /// - defaults to `CompressionTypeVariant::UNCOMPRESSED` pub fn with_file_compression_type( @@ -1658,6 +1685,8 @@ config_namespace! 
{ } } +pub trait FormatOptionsExt: Display {} + #[derive(Debug, Clone, PartialEq)] #[allow(clippy::large_enum_variant)] pub enum FormatOptions { @@ -1668,6 +1697,7 @@ pub enum FormatOptions { AVRO, ARROW, } + impl Display for FormatOptions { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let out = match self { @@ -1682,28 +1712,15 @@ impl Display for FormatOptions { } } -impl From for FormatOptions { - fn from(value: FileType) -> Self { - match value { - FileType::ARROW => FormatOptions::ARROW, - FileType::AVRO => FormatOptions::AVRO, - #[cfg(feature = "parquet")] - FileType::PARQUET => FormatOptions::PARQUET(TableParquetOptions::default()), - FileType::CSV => FormatOptions::CSV(CsvOptions::default()), - FileType::JSON => FormatOptions::JSON(JsonOptions::default()), - } - } -} - #[cfg(test)] mod tests { use std::any::Any; use std::collections::HashMap; use crate::config::{ - ConfigEntry, ConfigExtension, ExtensionOptions, Extensions, TableOptions, + ConfigEntry, ConfigExtension, ConfigFileType, ExtensionOptions, Extensions, + TableOptions, }; - use crate::FileType; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -1761,7 +1778,7 @@ mod tests { let mut extension = Extensions::new(); extension.insert(TestExtensionConfig::default()); let mut table_config = TableOptions::new().with_extensions(extension); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.set("format.delimiter", ";").unwrap(); assert_eq!(table_config.csv.delimiter, b';'); table_config.set("test.bootstrap.servers", "asd").unwrap(); @@ -1778,7 +1795,7 @@ mod tests { #[test] fn csv_u8_table_options() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.set("format.delimiter", ";").unwrap(); assert_eq!(table_config.csv.delimiter as char, ';'); table_config.set("format.escape", "\"").unwrap(); @@ -1791,7 +1808,7 @@ mod tests { #[test] fn parquet_table_options() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config .set("format.bloom_filter_enabled::col1", "true") .unwrap(); @@ -1805,7 +1822,7 @@ mod tests { #[test] fn parquet_table_options_config_entry() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config .set("format.bloom_filter_enabled::col1", "true") .unwrap(); @@ -1819,7 +1836,7 @@ mod tests { #[test] fn parquet_table_options_config_metadata_entry() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config.set("format.metadata::key1", "").unwrap(); table_config.set("format.metadata::key2", "value2").unwrap(); table_config diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index 4f948a29adc4..5792cfdba9e0 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -69,6 +69,12 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { if let Some(v) = &value.null_value { builder = builder.with_null(v.into()) } + if let Some(v) = &value.escape { + builder = builder.with_escape(*v) + } + if let Some(v) = &value.double_quote { + builder = builder.with_double_quote(*v) + } 
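With `escape` and `double_quote` now forwarded into arrow's `WriterBuilder`, a writer configuration can opt out of quote doubling in favor of backslash escaping. A rough sketch (it assumes the `escape` field of `CsvOptions` can be set directly, since the config macro exposes it as a public field):

```rust
use datafusion_common::config::CsvOptions;
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::Result;

fn csv_writer_options() -> Result<CsvWriterOptions> {
    // escape embedded quotes with `\` instead of doubling them
    let mut opts = CsvOptions::default().with_double_quote(false);
    opts.escape = Some(b'\\');

    // the TryFrom impl above carries both settings into arrow's WriterBuilder
    CsvWriterOptions::try_from(&opts)
}
```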
Ok(CsvWriterOptions { writer_options: builder, compression: value.compression, diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index fc0bb7445645..2648f7289798 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -17,11 +17,8 @@ //! File type abstraction -use std::fmt::{self, Display}; -use std::str::FromStr; - -use crate::config::FormatOptions; -use crate::error::{DataFusionError, Result}; +use std::any::Any; +use std::fmt::Display; /// The default file extension of arrow files pub const DEFAULT_ARROW_EXTENSION: &str = ".arrow"; @@ -40,107 +37,10 @@ pub trait GetExt { fn get_ext(&self) -> String; } -/// Readable file type -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum FileType { - /// Apache Arrow file - ARROW, - /// Apache Avro file - AVRO, - /// Apache Parquet file - #[cfg(feature = "parquet")] - PARQUET, - /// CSV file - CSV, - /// JSON file - JSON, -} - -impl From<&FormatOptions> for FileType { - fn from(value: &FormatOptions) -> Self { - match value { - FormatOptions::CSV(_) => FileType::CSV, - FormatOptions::JSON(_) => FileType::JSON, - #[cfg(feature = "parquet")] - FormatOptions::PARQUET(_) => FileType::PARQUET, - FormatOptions::AVRO => FileType::AVRO, - FormatOptions::ARROW => FileType::ARROW, - } - } -} - -impl GetExt for FileType { - fn get_ext(&self) -> String { - match self { - FileType::ARROW => DEFAULT_ARROW_EXTENSION.to_owned(), - FileType::AVRO => DEFAULT_AVRO_EXTENSION.to_owned(), - #[cfg(feature = "parquet")] - FileType::PARQUET => DEFAULT_PARQUET_EXTENSION.to_owned(), - FileType::CSV => DEFAULT_CSV_EXTENSION.to_owned(), - FileType::JSON => DEFAULT_JSON_EXTENSION.to_owned(), - } - } -} - -impl Display for FileType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let out = match self { - FileType::CSV => "csv", - FileType::JSON => "json", - #[cfg(feature = "parquet")] - FileType::PARQUET => "parquet", - FileType::AVRO => "avro", - FileType::ARROW => "arrow", - }; - write!(f, "{}", out) - } -} - -impl FromStr for FileType { - type Err = DataFusionError; - - fn from_str(s: &str) -> Result { - let s = s.to_uppercase(); - match s.as_str() { - "ARROW" => Ok(FileType::ARROW), - "AVRO" => Ok(FileType::AVRO), - #[cfg(feature = "parquet")] - "PARQUET" => Ok(FileType::PARQUET), - "CSV" => Ok(FileType::CSV), - "JSON" | "NDJSON" => Ok(FileType::JSON), - _ => Err(DataFusionError::NotImplemented(format!( - "Unknown FileType: {s}" - ))), - } - } -} - -#[cfg(test)] -#[cfg(feature = "parquet")] -mod tests { - use std::str::FromStr; - - use crate::error::DataFusionError; - use crate::FileType; - - #[test] - fn from_str() { - for (ext, file_type) in [ - ("csv", FileType::CSV), - ("CSV", FileType::CSV), - ("json", FileType::JSON), - ("JSON", FileType::JSON), - ("avro", FileType::AVRO), - ("AVRO", FileType::AVRO), - ("parquet", FileType::PARQUET), - ("PARQUET", FileType::PARQUET), - ] { - assert_eq!(FileType::from_str(ext).unwrap(), file_type); - } - - assert!(matches!( - FileType::from_str("Unknown"), - Err(DataFusionError::NotImplemented(_)) - )); - } +/// Defines the functionality needed for logical planning for +/// a type of file which will be read or written to storage. +pub trait FileType: GetExt + Display + Send + Sync { + /// Returns the table source as [`Any`] so that it can be + /// downcast to a specific implementation. 
+ fn as_any(&self) -> &dyn Any; } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index 59040b4290b0..77781457d0d2 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -32,10 +32,10 @@ mod tests { use super::parquet_writer::ParquetWriterOptions; use crate::{ - config::TableOptions, + config::{ConfigFileType, TableOptions}, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - FileType, Result, + Result, }; use parquet::{ @@ -76,7 +76,7 @@ mod tests { option_map.insert("format.bloom_filter_ndv".to_owned(), "123".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -181,7 +181,7 @@ mod tests { ); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -284,7 +284,7 @@ mod tests { option_map.insert("format.delimiter".to_owned(), ";".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.alter_with_string_hash_map(&option_map)?; let csv_options = CsvWriterOptions::try_from(&table_config.csv)?; @@ -306,7 +306,7 @@ mod tests { option_map.insert("format.compression".to_owned(), "gzip".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::JSON); + table_config.set_config_format(ConfigFileType::JSON); table_config.alter_with_string_hash_map(&option_map)?; let json_options = JsonWriterOptions::try_from(&table_config.json)?; diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 7eecdec8abef..c972536c4d23 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -36,7 +36,7 @@ use crate::error::{Result, _internal_err}; // Combines two hashes into one hash #[inline] -fn combine_hashes(l: u64, r: u64) -> u64 { +pub fn combine_hashes(l: u64, r: u64) -> u64 { let hash = (17 * 37u64).wrapping_add(l); hash.wrapping_mul(37).wrapping_add(r) } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index e64acd0bfefe..c275152642f0 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -54,8 +54,8 @@ pub use error::{ SharedResult, }; pub use file_options::file_type::{ - FileType, GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, - DEFAULT_CSV_EXTENSION, DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, + GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, + DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, }; pub use functional_dependencies::{ aggregate_functional_dependencies, get_required_group_by_exprs_indices, diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 3daf347ae4ff..5b9c4a223de6 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -29,6 +29,7 @@ use std::iter::repeat; use std::str::FromStr; use std::sync::Arc; +use crate::arrow_datafusion_err; use crate::cast::{ 
as_decimal128_array, as_decimal256_array, as_dictionary_array, as_fixed_size_binary_array, as_fixed_size_list_array, @@ -1077,7 +1078,7 @@ impl ScalarValue { DataType::Float64 => ScalarValue::Float64(Some(10.0)), _ => { return _not_impl_err!( - "Can't create a negative one scalar from data_type \"{datatype:?}\"" + "Can't create a ten scalar from data_type \"{datatype:?}\"" ); } }) @@ -1168,6 +1169,13 @@ impl ScalarValue { /// Calculate arithmetic negation for a scalar value pub fn arithmetic_negate(&self) -> Result { + fn neg_checked_with_ctx( + v: T, + ctx: impl Fn() -> String, + ) -> Result { + v.neg_checked() + .map_err(|e| arrow_datafusion_err!(e).context(ctx())) + } match self { ScalarValue::Int8(None) | ScalarValue::Int16(None) @@ -1177,40 +1185,91 @@ impl ScalarValue { | ScalarValue::Float64(None) => Ok(self.clone()), ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), - ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))), - ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))), - ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))), - ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))), - ScalarValue::IntervalYearMonth(Some(v)) => { - Ok(ScalarValue::IntervalYearMonth(Some(-v))) - } + ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))), + ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))), + ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))), + ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))), + ScalarValue::IntervalYearMonth(Some(v)) => Ok( + ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx(*v, || { + format!("In negation of IntervalYearMonth({v})") + })?)), + ), ScalarValue::IntervalDayTime(Some(v)) => { let (days, ms) = IntervalDayTimeType::to_parts(*v); - let val = IntervalDayTimeType::make_value(-days, -ms); + let val = IntervalDayTimeType::make_value( + neg_checked_with_ctx(days, || { + format!("In negation of days {days} in IntervalDayTime") + })?, + neg_checked_with_ctx(ms, || { + format!("In negation of milliseconds {ms} in IntervalDayTime") + })?, + ); Ok(ScalarValue::IntervalDayTime(Some(val))) } ScalarValue::IntervalMonthDayNano(Some(v)) => { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); - let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos); + let val = IntervalMonthDayNanoType::make_value( + neg_checked_with_ctx(months, || { + format!("In negation of months {months} of IntervalMonthDayNano") + })?, + neg_checked_with_ctx(days, || { + format!("In negation of days {days} of IntervalMonthDayNano") + })?, + neg_checked_with_ctx(nanos, || { + format!("In negation of nanos {nanos} of IntervalMonthDayNano") + })?, + ); Ok(ScalarValue::IntervalMonthDayNano(Some(val))) } ScalarValue::Decimal128(Some(v), precision, scale) => { - Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale)) + Ok(ScalarValue::Decimal128( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal128({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) + } + ScalarValue::Decimal256(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal256( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal256({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) } - ScalarValue::Decimal256(Some(v), precision, scale) => Ok( - ScalarValue::Decimal256(Some(v.neg_wrapping()), 
*precision, *scale), - ), ScalarValue::TimestampSecond(Some(v), tz) => { - Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone())) + Ok(ScalarValue::TimestampSecond( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampSecond({v})") + })?), + tz.clone(), + )) } ScalarValue::TimestampNanosecond(Some(v), tz) => { - Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone())) + Ok(ScalarValue::TimestampNanosecond( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampNanoSecond({v})") + })?), + tz.clone(), + )) } ScalarValue::TimestampMicrosecond(Some(v), tz) => { - Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone())) + Ok(ScalarValue::TimestampMicrosecond( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampMicroSecond({v})") + })?), + tz.clone(), + )) } ScalarValue::TimestampMillisecond(Some(v), tz) => { - Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone())) + Ok(ScalarValue::TimestampMillisecond( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampMilliSecond({v})") + })?), + tz.clone(), + )) } value => _internal_err!( "Can not run arithmetic negative on scalar value {value:?}" @@ -1839,7 +1898,7 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let result = ScalarValue::new_list(&scalars, &DataType::Int32); + /// let result = ScalarValue::new_list(&scalars, &DataType::Int32, true); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ @@ -1848,13 +1907,25 @@ impl ScalarValue { /// /// assert_eq!(*result, expected); /// ``` - pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> Arc { + pub fn new_list( + values: &[ScalarValue], + data_type: &DataType, + nullable: bool, + ) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { Self::iter_to_array(values.iter().cloned()).unwrap() }; - Arc::new(array_into_list_array(values)) + Arc::new(array_into_list_array(values, nullable)) + } + + /// Same as [`ScalarValue::new_list`] but with nullable set to true. 
+ pub fn new_list_nullable( + values: &[ScalarValue], + data_type: &DataType, + ) -> Arc { + Self::new_list(values, data_type, true) } /// Converts `IntoIterator` where each element has type corresponding to @@ -1873,7 +1944,7 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32); + /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32, true); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ @@ -1885,13 +1956,14 @@ impl ScalarValue { pub fn new_list_from_iter( values: impl IntoIterator + ExactSizeIterator, data_type: &DataType, + nullable: bool, ) -> Arc { let values = if values.len() == 0 { new_empty_array(data_type) } else { Self::iter_to_array(values).unwrap() }; - Arc::new(array_into_list_array(values)) + Arc::new(array_into_list_array(values, nullable)) } /// Converts `Vec` where each element has type corresponding to @@ -2305,7 +2377,7 @@ impl ScalarValue { /// use datafusion_common::ScalarValue; /// use arrow::array::ListArray; /// use arrow::datatypes::{DataType, Int32Type}; - /// use datafusion_common::utils::array_into_list_array; + /// use datafusion_common::utils::array_into_list_array_nullable; /// use std::sync::Arc; /// /// let list_arr = ListArray::from_iter_primitive::(vec![ @@ -2314,7 +2386,7 @@ impl ScalarValue { /// ]); /// /// // Wrap into another layer of list, we got nested array as [ [[1,2,3], [4,5]] ] - /// let list_arr = array_into_list_array(Arc::new(list_arr)); + /// let list_arr = array_into_list_array_nullable(Arc::new(list_arr)); /// /// // Convert the array into Scalar Values for each row, we got 1D arrays in this example /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); @@ -2400,11 +2472,12 @@ impl ScalarValue { typed_cast!(array, index, LargeStringArray, LargeUtf8)? } DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?, - DataType::List(_) => { + DataType::List(field) => { let list_array = array.as_list::(); let nested_array = list_array.value(index); // Produces a single element `ListArray` with the value at `index`. 
- let arr = Arc::new(array_into_list_array(nested_array)); + let arr = + Arc::new(array_into_list_array(nested_array, field.is_nullable())); ScalarValue::List(arr) } @@ -3499,8 +3572,10 @@ mod tests { }; use crate::assert_batches_eq; + use crate::utils::array_into_list_array_nullable; use arrow::buffer::OffsetBuffer; use arrow::compute::{is_null, kernels}; + use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; use arrow_buffer::Buffer; use arrow_schema::Fields; @@ -3646,9 +3721,9 @@ mod tests { ScalarValue::from("data-fusion"), ]; - let result = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + let result = ScalarValue::new_list_nullable(scalars.as_slice(), &DataType::Utf8); - let expected = array_into_list_array(Arc::new(StringArray::from(vec![ + let expected = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ "rust", "arrow", "data-fusion", @@ -3860,10 +3935,12 @@ mod tests { #[test] fn iter_to_array_string_test() { - let arr1 = - array_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let arr2 = - array_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"]))); + let arr1 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ + "foo", "bar", "baz", + ]))); + let arr2 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ + "rust", "world", + ]))); let scalars = vec![ ScalarValue::List(Arc::new(arr1)), @@ -4270,7 +4347,7 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array = ScalarValue::new_list(&[], &DataType::UInt64); + let list_array = ScalarValue::new_list_nullable(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -4291,7 +4368,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array = ScalarValue::new_list(&values, &DataType::UInt64); + let list_array = ScalarValue::new_list_nullable(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -5216,13 +5293,13 @@ mod tests { // Define list-of-structs scalars let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap(); - let nl0 = ScalarValue::List(Arc::new(array_into_list_array(nl0_array))); + let nl0 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl0_array))); let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap(); - let nl1 = ScalarValue::List(Arc::new(array_into_list_array(nl1_array))); + let nl1 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl1_array))); let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap(); - let nl2 = ScalarValue::List(Arc::new(array_into_list_array(nl2_array))); + let nl2 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl2_array))); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); @@ -5494,6 +5571,89 @@ mod tests { Ok(()) } + #[test] + #[allow(arithmetic_overflow)] // we want to test them + fn test_scalar_negative_overflows() -> Result<()> { + macro_rules! test_overflow_on_value { + ($($val:expr),* $(,)?) 
=> {$( + { + let value: ScalarValue = $val; + let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}"); + let root_err = err.find_root(); + match root_err{ + DataFusionError::ArrowError( + ArrowError::ComputeError(_), + _, + ) => {} + _ => return Err(err), + }; + } + )*}; + } + test_overflow_on_value!( + // the integers + i8::MIN.into(), + i16::MIN.into(), + i32::MIN.into(), + i64::MIN.into(), + // for decimals, only value needs to be tested + ScalarValue::try_new_decimal128(i128::MIN, 10, 5)?, + ScalarValue::Decimal256(Some(i256::MIN), 20, 5), + // interval, check all possible values + ScalarValue::IntervalYearMonth(Some(i32::MIN)), + ScalarValue::new_interval_dt(i32::MIN, 999), + ScalarValue::new_interval_dt(1, i32::MIN), + ScalarValue::new_interval_mdn(i32::MIN, 15, 123_456), + ScalarValue::new_interval_mdn(12, i32::MIN, 123_456), + ScalarValue::new_interval_mdn(12, 15, i64::MIN), + // tz doesn't matter when negating + ScalarValue::TimestampSecond(Some(i64::MIN), None), + ScalarValue::TimestampMillisecond(Some(i64::MIN), None), + ScalarValue::TimestampMicrosecond(Some(i64::MIN), None), + ScalarValue::TimestampNanosecond(Some(i64::MIN), None), + ); + + let float_cases = [ + ( + ScalarValue::Float16(Some(f16::MIN)), + ScalarValue::Float16(Some(f16::MAX)), + ), + ( + ScalarValue::Float16(Some(f16::MAX)), + ScalarValue::Float16(Some(f16::MIN)), + ), + (f32::MIN.into(), f32::MAX.into()), + (f32::MAX.into(), f32::MIN.into()), + (f64::MIN.into(), f64::MAX.into()), + (f64::MAX.into(), f64::MIN.into()), + ]; + // skip float 16 because they aren't supported + for (test, expected) in float_cases.into_iter().skip(2) { + assert_eq!(test.arithmetic_negate()?, expected); + } + Ok(()) + } + + #[test] + #[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")] + fn f16_test_overflow() { + // TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case + let cases = [ + ( + ScalarValue::Float16(Some(f16::MIN)), + ScalarValue::Float16(Some(f16::MAX)), + ), + ( + ScalarValue::Float16(Some(f16::MAX)), + ScalarValue::Float16(Some(f16::MIN)), + ), + ]; + + for (test, expected) in cases { + assert_eq!(test.arithmetic_negate().unwrap(), expected); + } + } + macro_rules! expect_operation_error { ($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => { #[test] @@ -6008,7 +6168,7 @@ mod tests { #[test] fn test_build_timestamp_millisecond_list() { let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; - let arr = ScalarValue::new_list( + let arr = ScalarValue::new_list_nullable( &values, &DataType::Timestamp(TimeUnit::Millisecond, None), ); @@ -6019,7 +6179,7 @@ mod tests { fn test_newlist_timestamp_zone() { let s: &'static str = "UTC"; let values = vec![ScalarValue::TimestampMillisecond(Some(1), Some(s.into()))]; - let arr = ScalarValue::new_list( + let arr = ScalarValue::new_list_nullable( &values, &DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), ); diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index a0e4d1a76c03..dd7b80333cf8 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -351,10 +351,19 @@ pub fn longest_consecutive_prefix>( /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -pub fn array_into_list_array(arr: ArrayRef) -> ListArray { +/// The field in the list array is nullable. 
+pub fn array_into_list_array_nullable(arr: ArrayRef) -> ListArray { + array_into_list_array(arr, true) +} + +/// Array Utils + +/// Wrap an array into a single element `ListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); ListArray::new( - Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), nullable)), offsets, arr, None, diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 45617d88dc0c..532ca8fde9e7 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.75" +rust-version = "1.76" [lints] workspace = true diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 398f59e35d10..8e55da8c3ad0 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -26,6 +26,9 @@ use std::sync::Arc; use crate::arrow::record_batch::RecordBatch; use crate::arrow::util::pretty; +use crate::datasource::file_format::csv::CsvFormatFactory; +use crate::datasource::file_format::format_as_file_type; +use crate::datasource::file_format::json::JsonFormatFactory; use crate::datasource::{provider_as_source, MemTable, TableProvider}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; @@ -44,17 +47,15 @@ use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field}; use arrow_schema::{Schema, SchemaRef}; -use datafusion_common::config::{CsvOptions, FormatOptions, JsonOptions}; +use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, }; -use datafusion_expr::lit; +use datafusion_expr::{case, is_null, lit}; use datafusion_expr::{ - avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, - UNNAMED_TABLE, + max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; -use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum}; +use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum}; use async_trait::async_trait; @@ -177,6 +178,33 @@ impl DataFrame { } } + /// Creates logical expression from a SQL query text. + /// The expression is created and processed againt the current schema. + /// + /// # Example: Parsing SQL queries + /// ``` + /// # use arrow::datatypes::{DataType, Field, Schema}; + /// # use datafusion::prelude::*; + /// # use datafusion_common::{DFSchema, Result}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// // datafusion will parse number as i64 first. 
+ /// let sql = "a > 1 and b in (1, 10)"; + /// let expected = col("a").gt(lit(1 as i64)) + /// .and(col("b").in_list(vec![lit(1 as i64), lit(10 as i64)], false)); + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let expr = df.parse_sql_expr(sql)?; + /// assert_eq!(expected, expr); + /// # Ok(()) + /// # } + /// ``` + pub fn parse_sql_expr(&self, sql: &str) -> Result { + let df_schema = self.schema(); + + self.session_state.create_logical_expr(sql, df_schema) + } + /// Consume the DataFrame and produce a physical plan pub async fn create_physical_plan(self) -> Result> { self.session_state.create_physical_plan(&self.plan).await @@ -534,7 +562,7 @@ impl DataFrame { /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? - /// // Return a single row (a, b) for each distinct value of a + /// // Return a single row (a, b) for each distinct value of a /// .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?; /// # Ok(()) /// # } @@ -1304,13 +1332,19 @@ impl DataFrame { "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), )); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().csv); + + let format = if let Some(csv_opts) = writer_options { + Arc::new(CsvFormatFactory::new_with_options(csv_opts)) + } else { + Arc::new(CsvFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::CSV(props), + file_type, HashMap::new(), options.partition_by, )? @@ -1359,13 +1393,18 @@ impl DataFrame { )); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().json); + let format = if let Some(json_opts) = writer_options { + Arc::new(JsonFormatFactory::new_with_options(json_opts)) + } else { + Arc::new(JsonFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::JSON(props), + file_type, Default::default(), options.partition_by, )? 
@@ -2018,7 +2057,7 @@ mod tests { assert_batches_sorted_eq!( ["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", - "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |", + "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |", "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", "| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |", "| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |", @@ -3171,7 +3210,7 @@ mod tests { let sql = r#" SELECT - COUNT(1) + count(1) FROM test GROUP BY @@ -3420,6 +3459,82 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_except_nested_struct() -> Result<()> { + use arrow::array::StructArray; + + let nested_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("lat", DataType::Int32, true), + Field::new("long", DataType::Int32, true), + ])); + let schema = Arc::new(Schema::new(vec![ + Field::new("value", DataType::Int32, true), + Field::new( + "nested", + DataType::Struct(nested_schema.fields.clone()), + true, + ), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])), + Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("lat", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("long", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ])), + ], + ) + .unwrap(); + + let updated_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)])), + Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("lat", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("long", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ])), + ], + ) + .unwrap(); + + let ctx = SessionContext::new(); + let before = ctx.read_batch(batch).expect("Failed to make DataFrame"); + let after = ctx + .read_batch(updated_batch) + .expect("Failed to make DataFrame"); + + let diff = before + .except(after) + .expect("Failed to except") + .collect() + .await?; + assert_eq!(diff.len(), 1); + Ok(()) + } + #[tokio::test] async fn nested_explain_should_fail() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 0ec46df0ae5d..1abb550f5c98 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -15,11 +15,17 @@ // specific language governing permissions 
and limitations // under the License. +use std::sync::Arc; + +use crate::datasource::file_format::{ + format_as_file_type, parquet::ParquetFormatFactory, +}; + use super::{ DataFrame, DataFrameWriteOptions, DataFusionError, LogicalPlanBuilder, RecordBatch, }; -use datafusion_common::config::{FormatOptions, TableParquetOptions}; +use datafusion_common::config::TableParquetOptions; impl DataFrame { /// Execute the `DataFrame` and write the results to Parquet file(s). @@ -57,13 +63,18 @@ impl DataFrame { )); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().parquet); + let format = if let Some(parquet_opts) = writer_options { + Arc::new(ParquetFormatFactory::new_with_options(parquet_opts)) + } else { + Arc::new(ParquetFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::PARQUET(props), + file_type, Default::default(), options.partition_by, )? diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 8c6790541597..478a11d7e76e 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -21,12 +21,14 @@ use std::any::Any; use std::borrow::Cow; +use std::collections::HashMap; use std::fmt::{self, Debug}; use std::sync::Arc; use super::file_compression_type::FileCompressionType; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; +use super::FileFormatFactory; use crate::datasource::file_format::FileFormat; use crate::datasource::physical_plan::{ ArrowExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, @@ -40,7 +42,10 @@ use arrow::ipc::reader::FileReader; use arrow::ipc::writer::IpcWriteOptions; use arrow::ipc::{root_as_message, CompressionType}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use datafusion_common::{not_impl_err, DataFusionError, Statistics}; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{ + not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION, +}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; @@ -61,6 +66,38 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// If the buffered Arrow data exceeds this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; +#[derive(Default)] +/// Factory struct used to create [ArrowFormat] +pub struct ArrowFormatFactory; + +impl ArrowFormatFactory { + /// Creates an instance of [ArrowFormatFactory] + pub fn new() -> Self { + Self {} + } +} + +impl FileFormatFactory for ArrowFormatFactory { + fn create( + &self, + _state: &SessionState, + _format_options: &HashMap, + ) -> Result> { + Ok(Arc::new(ArrowFormat)) + } + + fn default(&self) -> Arc { + Arc::new(ArrowFormat) + } +} + +impl GetExt for ArrowFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_ARROW_EXTENSION[1..].to_string() + } +} + /// Arrow `FileFormat` implementation. 
#[derive(Default, Debug)] pub struct ArrowFormat; @@ -71,6 +108,23 @@ impl FileFormat for ArrowFormat { self } + fn get_ext(&self) -> String { + ArrowFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Arrow FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, _state: &SessionState, diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 7b2c26a2c4f9..f4f9adcba7ed 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -18,15 +18,22 @@ //! [`AvroFormat`] Apache Avro [`FileFormat`] abstractions use std::any::Any; +use std::collections::HashMap; use std::sync::Arc; use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use async_trait::async_trait; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::DataFusionError; +use datafusion_common::GetExt; +use datafusion_common::DEFAULT_AVRO_EXTENSION; use datafusion_physical_expr::PhysicalExpr; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use super::file_compression_type::FileCompressionType; use super::FileFormat; +use super::FileFormatFactory; use crate::datasource::avro_to_arrow::read_avro_schema_from_reader; use crate::datasource::physical_plan::{AvroExec, FileScanConfig}; use crate::error::Result; @@ -34,6 +41,38 @@ use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +#[derive(Default)] +/// Factory struct used to create [AvroFormat] +pub struct AvroFormatFactory; + +impl AvroFormatFactory { + /// Creates an instance of [AvroFormatFactory] + pub fn new() -> Self { + Self {} + } +} + +impl FileFormatFactory for AvroFormatFactory { + fn create( + &self, + _state: &SessionState, + _format_options: &HashMap, + ) -> Result> { + Ok(Arc::new(AvroFormat)) + } + + fn default(&self) -> Arc { + Arc::new(AvroFormat) + } +} + +impl GetExt for AvroFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_AVRO_EXTENSION[1..].to_string() + } +} + /// Avro `FileFormat` implementation. #[derive(Default, Debug)] pub struct AvroFormat; @@ -44,6 +83,23 @@ impl FileFormat for AvroFormat { self } + fn get_ext(&self) -> String { + AvroFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Avro FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, _state: &SessionState, diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 2139b35621f2..92cb11e2b47a 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -18,12 +18,12 @@ //! 
[`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions use std::any::Any; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug}; use std::sync::Arc; use super::write::orchestration::stateless_multipart_put; -use super::FileFormat; +use super::{FileFormat, FileFormatFactory}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::{ @@ -40,9 +40,11 @@ use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Fields, Schema}; -use datafusion_common::config::CsvOptions; +use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; -use datafusion_common::{exec_err, not_impl_err, DataFusionError}; +use datafusion_common::{ + exec_err, not_impl_err, DataFusionError, GetExt, DEFAULT_CSV_EXTENSION, +}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -53,6 +55,63 @@ use futures::stream::BoxStream; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; +#[derive(Default)] +/// Factory struct used to create [CsvFormatFactory] +pub struct CsvFormatFactory { + options: Option, +} + +impl CsvFormatFactory { + /// Creates an instance of [CsvFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// Creates an instance of [CsvFormatFactory] with customized default options + pub fn new_with_options(options: CsvOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl FileFormatFactory for CsvFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result> { + let csv_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::CSV); + table_options.alter_with_string_hash_map(format_options)?; + table_options.csv + } + Some(csv_options) => { + let mut csv_options = csv_options.clone(); + for (k, v) in format_options { + csv_options.set(k, v)?; + } + csv_options + } + }; + + Ok(Arc::new(CsvFormat::default().with_options(csv_options))) + } + + fn default(&self) -> Arc { + Arc::new(CsvFormat::default()) + } +} + +impl GetExt for CsvFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_CSV_EXTENSION[1..].to_string() + } +} + /// Character Separated Value `FileFormat` implementation. 
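CsvFormatFactory::create resolves options in two layers: it starts from the factory's own options if present, otherwise from the session's table options, and then applies any per-command string options on top. A sketch of the default path with no command-level overrides, assuming only a fresh SessionContext:

use std::collections::HashMap;
use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::datasource::file_format::{FileFormat, FileFormatFactory};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Sketch: build a CsvFormat from session defaults with no per-command options.
fn csv_from_session_defaults() -> Result<()> {
    let ctx = SessionContext::new();
    let state = ctx.state();
    let factory = CsvFormatFactory::new();
    let format = factory.create(&state, &HashMap::new())?;
    assert_eq!(format.get_ext(), "csv");
    Ok(())
}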
#[derive(Debug, Default)] pub struct CsvFormat { @@ -206,6 +265,18 @@ impl FileFormat for CsvFormat { self } + fn get_ext(&self) -> String { + CsvFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + Ok(format!("{}{}", ext, file_compression_type.get_ext())) + } + async fn infer_schema( &self, state: &SessionState, @@ -558,7 +629,6 @@ mod tests { use datafusion_common::cast::as_string_array; use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use datafusion_common::{FileType, GetExt}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::{col, lit}; @@ -1060,9 +1130,9 @@ mod tests { .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); let ctx = SessionContext::new_with_config(config); - let file_format = CsvFormat::default().with_has_header(false); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::CSV.get_ext()); + let file_format = Arc::new(CsvFormat::default().with_has_header(false)); + let listing_options = ListingOptions::new(file_format.clone()) + .with_file_extension(file_format.get_ext()); ctx.register_listing_table( "empty", "tests/data/empty_files/all_empty/", @@ -1113,9 +1183,9 @@ mod tests { .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); let ctx = SessionContext::new_with_config(config); - let file_format = CsvFormat::default().with_has_header(false); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::CSV.get_ext()); + let file_format = Arc::new(CsvFormat::default().with_has_header(false)); + let listing_options = ListingOptions::new(file_format.clone()) + .with_file_extension(file_format.get_ext()); ctx.register_listing_table( "empty", "tests/data/empty_files/some_empty", diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs index c1fbe352d37b..a054094822d0 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -22,7 +22,7 @@ use std::str::FromStr; use crate::error::{DataFusionError, Result}; use datafusion_common::parsers::CompressionTypeVariant::{self, *}; -use datafusion_common::{FileType, GetExt}; +use datafusion_common::GetExt; #[cfg(feature = "compression")] use async_compression::tokio::bufread::{ @@ -112,6 +112,11 @@ impl FileCompressionType { variant: UNCOMPRESSED, }; + /// Read only access to self.variant + pub fn get_variant(&self) -> &CompressionTypeVariant { + &self.variant + } + /// The file is compressed or not pub const fn is_compressed(&self) -> bool { self.variant.is_compressed() @@ -245,90 +250,16 @@ pub trait FileTypeExt { fn get_ext_with_compression(&self, c: FileCompressionType) -> Result; } -impl FileTypeExt for FileType { - fn get_ext_with_compression(&self, c: FileCompressionType) -> Result { - let ext = self.get_ext(); - - match self { - FileType::JSON | FileType::CSV => Ok(format!("{}{}", ext, c.get_ext())), - FileType::AVRO | FileType::ARROW => match c.variant { - UNCOMPRESSED => Ok(ext), - _ => Err(DataFusionError::Internal( - "FileCompressionType can be specified for CSV/JSON FileType.".into(), - )), - }, - #[cfg(feature = "parquet")] - FileType::PARQUET => match c.variant { - UNCOMPRESSED => Ok(ext), - _ => 
Err(DataFusionError::Internal( - "FileCompressionType can be specified for CSV/JSON FileType.".into(), - )), - }, - } - } -} - #[cfg(test)] mod tests { use std::str::FromStr; - use crate::datasource::file_format::file_compression_type::{ - FileCompressionType, FileTypeExt, - }; + use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::error::DataFusionError; - use datafusion_common::file_options::file_type::FileType; - use bytes::Bytes; use futures::StreamExt; - #[test] - fn get_ext_with_compression() { - for (file_type, compression, extension) in [ - (FileType::CSV, FileCompressionType::UNCOMPRESSED, ".csv"), - (FileType::CSV, FileCompressionType::GZIP, ".csv.gz"), - (FileType::CSV, FileCompressionType::XZ, ".csv.xz"), - (FileType::CSV, FileCompressionType::BZIP2, ".csv.bz2"), - (FileType::CSV, FileCompressionType::ZSTD, ".csv.zst"), - (FileType::JSON, FileCompressionType::UNCOMPRESSED, ".json"), - (FileType::JSON, FileCompressionType::GZIP, ".json.gz"), - (FileType::JSON, FileCompressionType::XZ, ".json.xz"), - (FileType::JSON, FileCompressionType::BZIP2, ".json.bz2"), - (FileType::JSON, FileCompressionType::ZSTD, ".json.zst"), - ] { - assert_eq!( - file_type.get_ext_with_compression(compression).unwrap(), - extension - ); - } - - let mut ty_ext_tuple = vec![]; - ty_ext_tuple.push((FileType::AVRO, ".avro")); - #[cfg(feature = "parquet")] - ty_ext_tuple.push((FileType::PARQUET, ".parquet")); - - // Cannot specify compression for these file types - for (file_type, extension) in ty_ext_tuple { - assert_eq!( - file_type - .get_ext_with_compression(FileCompressionType::UNCOMPRESSED) - .unwrap(), - extension - ); - for compression in [ - FileCompressionType::GZIP, - FileCompressionType::XZ, - FileCompressionType::BZIP2, - FileCompressionType::ZSTD, - ] { - assert!(matches!( - file_type.get_ext_with_compression(compression), - Err(DataFusionError::Internal(_)) - )); - } - } - } - #[test] fn from_str() { for (ext, compression_type) in [ diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index d5347c498c71..007b084f504d 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -18,13 +18,14 @@ //! 
[`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions use std::any::Any; +use std::collections::HashMap; use std::fmt; use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; use super::write::orchestration::stateless_multipart_put; -use super::{FileFormat, FileScanConfig}; +use super::{FileFormat, FileFormatFactory, FileScanConfig}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::FileGroupDisplay; @@ -41,9 +42,9 @@ use arrow::datatypes::SchemaRef; use arrow::json; use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow_array::RecordBatch; -use datafusion_common::config::JsonOptions; +use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; -use datafusion_common::not_impl_err; +use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -53,6 +54,63 @@ use async_trait::async_trait; use bytes::{Buf, Bytes}; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +#[derive(Default)] +/// Factory struct used to create [JsonFormat] +pub struct JsonFormatFactory { + options: Option, +} + +impl JsonFormatFactory { + /// Creates an instance of [JsonFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// Creates an instance of [JsonFormatFactory] with customized default options + pub fn new_with_options(options: JsonOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl FileFormatFactory for JsonFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result> { + let json_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::JSON); + table_options.alter_with_string_hash_map(format_options)?; + table_options.json + } + Some(json_options) => { + let mut json_options = json_options.clone(); + for (k, v) in format_options { + json_options.set(k, v)?; + } + json_options + } + }; + + Ok(Arc::new(JsonFormat::default().with_options(json_options))) + } + + fn default(&self) -> Arc { + Arc::new(JsonFormat::default()) + } +} + +impl GetExt for JsonFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_JSON_EXTENSION[1..].to_string() + } +} + /// New line delimited JSON `FileFormat` implementation. 
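JsonFormatFactory::new_with_options covers the other construction path: a factory seeded with non-default options, which create then uses as its starting point instead of the session table options. A small sketch; the GZIP setting is only an illustrative choice, and the JsonOptions field layout is assumed from the config definitions:

use datafusion::common::config::JsonOptions;
use datafusion::common::parsers::CompressionTypeVariant;
use datafusion::common::GetExt;
use datafusion::datasource::file_format::json::JsonFormatFactory;

// Sketch: a factory whose defaults request gzip-compressed output; any
// per-command options passed to `create` are applied on top of these.
fn gzip_json_factory() -> JsonFormatFactory {
    let options = JsonOptions {
        compression: CompressionTypeVariant::GZIP,
        ..Default::default()
    };
    let factory = JsonFormatFactory::new_with_options(options);
    assert_eq!(factory.get_ext(), "json");
    factory
}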
#[derive(Debug, Default)] pub struct JsonFormat { @@ -95,6 +153,18 @@ impl FileFormat for JsonFormat { self } + fn get_ext(&self) -> String { + JsonFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + Ok(format!("{}{}", ext, file_compression_type.get_ext())) + } + async fn infer_schema( &self, _state: &SessionState, diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 9462cde43610..1aa93a106aff 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -32,7 +32,8 @@ pub mod parquet; pub mod write; use std::any::Any; -use std::fmt; +use std::collections::HashMap; +use std::fmt::{self, Display}; use std::sync::Arc; use crate::arrow::datatypes::SchemaRef; @@ -41,12 +42,29 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use datafusion_common::not_impl_err; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::{internal_err, not_impl_err, GetExt}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use async_trait::async_trait; +use file_compression_type::FileCompressionType; use object_store::{ObjectMeta, ObjectStore}; +/// Factory for creating [`FileFormat`] instances based on session and command level options +/// +/// Users can provide their own `FileFormatFactory` to support arbitrary file formats +pub trait FileFormatFactory: Sync + Send + GetExt { + /// Initialize a [FileFormat] and configure based on session and command level options + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result>; + + /// Initialize a [FileFormat] with all options set to default values + fn default(&self) -> Arc; +} + /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across /// providers that support the same file formats. @@ -58,6 +76,15 @@ pub trait FileFormat: Send + Sync + fmt::Debug { /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; + /// Returns the extension for this FileFormat, e.g. "file.csv" -> csv + fn get_ext(&self) -> String; + + /// Returns the extension for this FileFormat when compressed, e.g. "file.csv.gz" -> csv + fn get_ext_with_compression( + &self, + _file_compression_type: &FileCompressionType, + ) -> Result; + /// Infer the common schema of the provided objects. The objects will usually /// be analysed up to a given number of records or files (as specified in the /// format config) then give the estimated common schema. This might fail if @@ -106,6 +133,67 @@ pub trait FileFormat: Send + Sync + fmt::Debug { } } +/// A container of [FileFormatFactory] which also implements [FileType]. +/// This enables converting a dyn FileFormat to a dyn FileType. +/// The former trait is a superset of the latter trait, which includes execution time +/// relevant methods. [FileType] is only used in logical planning and only implements +/// the subset of methods required during logical planning. 
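DefaultFileType, defined next, is the bridge described above: format_as_file_type wraps a factory as a planning-time FileType, and file_type_to_format downcasts it back when the physical sink is built. A round-trip sketch using the JSON factory:

use std::sync::Arc;
use datafusion::common::file_options::file_type::FileType;
use datafusion::common::GetExt;
use datafusion::datasource::file_format::json::JsonFormatFactory;
use datafusion::datasource::file_format::{
    file_type_to_format, format_as_file_type, FileFormatFactory,
};
use datafusion::error::Result;

// Sketch: wrap a factory for the logical plan, then recover it again.
fn round_trip() -> Result<()> {
    let factory: Arc<dyn FileFormatFactory> = Arc::new(JsonFormatFactory::new());
    let file_type: Arc<dyn FileType> = format_as_file_type(factory);
    let recovered = file_type_to_format(&file_type)?;
    assert_eq!(recovered.get_ext(), "json");
    Ok(())
}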
+pub struct DefaultFileType { + file_format_factory: Arc, +} + +impl DefaultFileType { + /// Constructs a [DefaultFileType] wrapper from a [FileFormatFactory] + pub fn new(file_format_factory: Arc) -> Self { + Self { + file_format_factory, + } + } +} + +impl FileType for DefaultFileType { + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Display for DefaultFileType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.file_format_factory.default().fmt(f) + } +} + +impl GetExt for DefaultFileType { + fn get_ext(&self) -> String { + self.file_format_factory.get_ext() + } +} + +/// Converts a [FileFormatFactory] to a [FileType] +pub fn format_as_file_type( + file_format_factory: Arc, +) -> Arc { + Arc::new(DefaultFileType { + file_format_factory, + }) +} + +/// Converts a [FileType] to a [FileFormatFactory]. +/// Returns an error if the [FileType] cannot be +/// downcasted to a [DefaultFileType]. +pub fn file_type_to_format( + file_type: &Arc, +) -> datafusion_common::Result> { + match file_type + .as_ref() + .as_any() + .downcast_ref::() + { + Some(source) => Ok(source.file_format_factory.clone()), + _ => internal_err!("FileType was not DefaultFileType"), + } +} + #[cfg(test)] pub(crate) mod test_util { use std::ops::Range; diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 572904254fd7..44c9cc4ec4a9 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -24,39 +24,39 @@ use std::sync::Arc; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; -use super::{FileFormat, FileScanConfig}; -use crate::arrow::array::{ - BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, -}; -use crate::arrow::datatypes::{DataType, Fields, Schema, SchemaRef}; +use super::{FileFormat, FileFormatFactory, FileScanConfig}; +use crate::arrow::array::RecordBatch; +use crate::arrow::datatypes::{Fields, Schema, SchemaRef}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::physical_plan::{FileGroupDisplay, FileSinkConfig}; -use crate::datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapterFactory, -}; use crate::datasource::statistics::{create_max_min_accs, get_col_stats}; use crate::error::Result; use crate::execution::context::SessionState; -use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::insert::{DataSink, DataSinkExec}; use crate::physical_plan::{ Accumulator, DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, Statistics, }; -use datafusion_common::config::TableParquetOptions; +use arrow::compute::sum; +use datafusion_common::config::{ConfigField, ConfigFileType, TableParquetOptions}; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_err, DataFusionError, + exec_err, internal_datafusion_err, not_impl_err, DataFusionError, GetExt, + DEFAULT_PARQUET_EXTENSION, }; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; use 
async_trait::async_trait; use bytes::{BufMut, BytesMut}; +use hashbrown::HashMap; +use log::debug; use object_store::buffered::BufWriter; use parquet::arrow::arrow_writer::{ compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, @@ -66,18 +66,18 @@ use parquet::arrow::{ arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter, }; use parquet::file::footer::{decode_footer, decode_metadata}; -use parquet::file::metadata::ParquetMetaData; +use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; use parquet::file::properties::WriterProperties; -use parquet::file::statistics::Statistics as ParquetStatistics; use parquet::file::writer::SerializedFileWriter; use parquet::format::FileMetaData; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::task::JoinSet; -use crate::datasource::physical_plan::parquet::ParquetExecBuilder; +use crate::datasource::physical_plan::parquet::{ + ParquetExecBuilder, StatisticsConverter, +}; use futures::{StreamExt, TryStreamExt}; -use hashbrown::HashMap; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; @@ -89,6 +89,65 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; +#[derive(Default)] +/// Factory struct used to create [ParquetFormat] +pub struct ParquetFormatFactory { + options: Option, +} + +impl ParquetFormatFactory { + /// Creates an instance of [ParquetFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// Creates an instance of [ParquetFormatFactory] with customized default options + pub fn new_with_options(options: TableParquetOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl FileFormatFactory for ParquetFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &std::collections::HashMap, + ) -> Result> { + let parquet_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::PARQUET); + table_options.alter_with_string_hash_map(format_options)?; + table_options.parquet + } + Some(parquet_options) => { + let mut parquet_options = parquet_options.clone(); + for (k, v) in format_options { + parquet_options.set(k, v)?; + } + parquet_options + } + }; + + Ok(Arc::new( + ParquetFormat::default().with_options(parquet_options), + )) + } + + fn default(&self) -> Arc { + Arc::new(ParquetFormat::default()) + } +} + +impl GetExt for ParquetFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. 
".parquet" -> "parquet" + DEFAULT_PARQUET_EXTENSION[1..].to_string() + } +} + /// The Apache Parquet `FileFormat` implementation #[derive(Debug, Default)] pub struct ParquetFormat { @@ -190,6 +249,23 @@ impl FileFormat for ParquetFormat { self } + fn get_ext(&self) -> String { + ParquetFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Parquet FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, state: &SessionState, @@ -295,86 +371,6 @@ impl FileFormat for ParquetFormat { } } -fn summarize_min_max( - max_values: &mut [Option], - min_values: &mut [Option], - fields: &Fields, - i: usize, - stat: &ParquetStatistics, -) { - if !stat.has_min_max_set() { - max_values[i] = None; - min_values[i] = None; - return; - } - match stat { - ParquetStatistics::Boolean(s) if DataType::Boolean == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(BooleanArray::from(vec![*s.max()]))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(BooleanArray::from(vec![*s.min()]))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - ParquetStatistics::Int32(s) if DataType::Int32 == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(Int32Array::from_value(*s.max(), 1))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(Int32Array::from_value(*s.min(), 1))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - ParquetStatistics::Int64(s) if DataType::Int64 == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(Int64Array::from_value(*s.max(), 1))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(Int64Array::from_value(*s.min(), 1))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - ParquetStatistics::Float(s) if DataType::Float32 == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(Float32Array::from(vec![*s.max()]))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(Float32Array::from(vec![*s.min()]))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - ParquetStatistics::Double(s) if DataType::Float64 == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(Float64Array::from(vec![*s.max()]))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(Float64Array::from(vec![*s.min()]))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - _ => { - max_values[i] = None; - min_values[i] = None; - } - } -} - /// Fetches parquet metadata from ObjectStore for given object /// /// This component is a subject to **change** in near future and is exposed for low level integrations @@ -467,7 +463,7 @@ async fn fetch_statistics( statistics_from_parquet_meta(&metadata, table_schema).await } 
-/// Convert statistics in [`ParquetMetaData`] into [`Statistics`] +/// Convert statistics in [`ParquetMetaData`] into [`Statistics`] using ['StatisticsConverter`] /// /// The statistics are calculated for each column in the table schema /// using the row group statistics in the parquet metadata. @@ -475,80 +471,107 @@ pub async fn statistics_from_parquet_meta( metadata: &ParquetMetaData, table_schema: SchemaRef, ) -> Result { - let file_metadata = metadata.file_metadata(); + let row_groups_metadata = metadata.row_groups(); + + let mut statistics = Statistics::new_unknown(&table_schema); + let mut has_statistics = false; + let mut num_rows = 0_usize; + let mut total_byte_size = 0_usize; + for row_group_meta in row_groups_metadata { + num_rows += row_group_meta.num_rows() as usize; + total_byte_size += row_group_meta.total_byte_size() as usize; + + if !has_statistics { + row_group_meta.columns().iter().for_each(|column| { + has_statistics = column.statistics().is_some(); + }); + } + } + statistics.num_rows = Precision::Exact(num_rows); + statistics.total_byte_size = Precision::Exact(total_byte_size); + let file_metadata = metadata.file_metadata(); let file_schema = parquet_to_arrow_schema( file_metadata.schema_descr(), file_metadata.key_value_metadata(), )?; - let num_fields = table_schema.fields().len(); - let fields = table_schema.fields(); - - let mut num_rows = 0; - let mut total_byte_size = 0; - let mut null_counts = vec![Precision::Exact(0); num_fields]; - let mut has_statistics = false; - - let schema_adapter = - DefaultSchemaAdapterFactory::default().create(table_schema.clone()); - - let (mut max_values, mut min_values) = create_max_min_accs(&table_schema); - - for row_group_meta in metadata.row_groups() { - num_rows += row_group_meta.num_rows(); - total_byte_size += row_group_meta.total_byte_size(); + statistics.column_statistics = if has_statistics { + let (mut max_accs, mut min_accs) = create_max_min_accs(&table_schema); + let mut null_counts_array = + vec![Precision::Exact(0); table_schema.fields().len()]; - let mut column_stats: HashMap = HashMap::new(); - - for (i, column) in row_group_meta.columns().iter().enumerate() { - if let Some(stat) = column.statistics() { - has_statistics = true; - column_stats.insert(i, (stat.null_count(), stat)); - } - } - - if has_statistics { - for (table_idx, null_cnt) in null_counts.iter_mut().enumerate() { - if let Some(file_idx) = - schema_adapter.map_column_index(table_idx, &file_schema) - { - if let Some((null_count, stats)) = column_stats.get(&file_idx) { - *null_cnt = null_cnt.add(&Precision::Exact(*null_count as usize)); - summarize_min_max( - &mut max_values, - &mut min_values, - fields, - table_idx, - stats, + table_schema + .fields() + .iter() + .enumerate() + .for_each(|(idx, field)| { + match StatisticsConverter::try_new( + field.name(), + &file_schema, + file_metadata.schema_descr(), + ) { + Ok(stats_converter) => { + summarize_min_max_null_counts( + &mut min_accs, + &mut max_accs, + &mut null_counts_array, + idx, + num_rows, + &stats_converter, + row_groups_metadata, ) - } else { - // If none statistics of current column exists, set the Max/Min Accumulator to None. 
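The rewritten statistics_from_parquet_meta above sums row counts and byte sizes across row groups and delegates per-column min/max/null-count extraction to StatisticsConverter. A hedged sketch of calling the public function, assuming the ParquetMetaData has already been fetched (for example with the fetch_parquet_metadata helper in this module):

use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::stats::Precision;
use datafusion::datasource::file_format::parquet::statistics_from_parquet_meta;
use datafusion::error::Result;
use datafusion::parquet::file::metadata::ParquetMetaData;

// Sketch: derive table-level statistics from already-fetched Parquet metadata.
// num_rows and total_byte_size are exact sums over row groups; per-column
// values are only filled in when the row groups carry statistics.
async fn show_stats(metadata: &ParquetMetaData, table_schema: SchemaRef) -> Result<()> {
    let stats = statistics_from_parquet_meta(metadata, table_schema).await?;
    if let Precision::Exact(rows) = stats.num_rows {
        println!("rows: {rows}");
    }
    for (i, col) in stats.column_statistics.iter().enumerate() {
        println!(
            "column {i}: nulls={:?} min={:?} max={:?}",
            col.null_count, col.min_value, col.max_value
        );
    }
    Ok(())
}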
- max_values[table_idx] = None; - min_values[table_idx] = None; + .ok(); + } + Err(e) => { + debug!("Failed to create statistics converter: {}", e); + null_counts_array[idx] = Precision::Exact(num_rows); } - } else { - *null_cnt = null_cnt.add(&Precision::Exact(num_rows as usize)); } - } - } - } + }); - let column_stats = if has_statistics { - get_col_stats(&table_schema, null_counts, &mut max_values, &mut min_values) + get_col_stats( + &table_schema, + null_counts_array, + &mut max_accs, + &mut min_accs, + ) } else { Statistics::unknown_column(&table_schema) }; - let statistics = Statistics { - num_rows: Precision::Exact(num_rows as usize), - total_byte_size: Precision::Exact(total_byte_size as usize), - column_statistics: column_stats, - }; - Ok(statistics) } +fn summarize_min_max_null_counts( + min_accs: &mut [Option], + max_accs: &mut [Option], + null_counts_array: &mut [Precision], + arrow_schema_index: usize, + num_rows: usize, + stats_converter: &StatisticsConverter, + row_groups_metadata: &[RowGroupMetaData], +) -> Result<()> { + let max_values = stats_converter.row_group_maxes(row_groups_metadata)?; + let min_values = stats_converter.row_group_mins(row_groups_metadata)?; + let null_counts = stats_converter.row_group_null_counts(row_groups_metadata)?; + + if let Some(max_acc) = &mut max_accs[arrow_schema_index] { + max_acc.update_batch(&[max_values])?; + } + + if let Some(min_acc) = &mut min_accs[arrow_schema_index] { + min_acc.update_batch(&[min_values])?; + } + + null_counts_array[arrow_schema_index] = Precision::Exact(match sum(&null_counts) { + Some(null_count) => null_count as usize, + None => num_rows, + }); + + Ok(()) +} + /// Implements [`DataSink`] for writing to a parquet file. pub struct ParquetSink { /// Config options for writing data @@ -1126,7 +1149,8 @@ mod tests { use crate::physical_plan::metrics::MetricValue; use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{Array, ArrayRef, StringArray}; - use arrow_schema::Field; + use arrow_array::Int64Array; + use arrow_schema::{DataType, Field}; use async_trait::async_trait; use datafusion_common::cast::{ as_binary_array, as_boolean_array, as_float32_array, as_float64_array, @@ -1449,8 +1473,14 @@ mod tests { // column c1 let c1_stats = &stats.column_statistics[0]; assert_eq!(c1_stats.null_count, Precision::Exact(1)); - assert_eq!(c1_stats.max_value, Precision::Absent); - assert_eq!(c1_stats.min_value, Precision::Absent); + assert_eq!( + c1_stats.max_value, + Precision::Exact(ScalarValue::Utf8(Some("bar".to_string()))) + ); + assert_eq!( + c1_stats.min_value, + Precision::Exact(ScalarValue::Utf8(Some("Foo".to_string()))) + ); // column c2: missing from the file so the table treats all 3 rows as null let c2_stats = &stats.column_statistics[1]; assert_eq!(c2_stats.null_count, Precision::Exact(3)); diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs index 2fd352ee4eb3..14bbc431f973 100644 --- a/datafusion/core/src/datasource/function.rs +++ b/datafusion/core/src/datasource/function.rs @@ -49,6 +49,11 @@ impl TableFunction { &self.name } + /// Get the implementation of the table function + pub fn function(&self) -> &Arc { + &self.fun + } + /// Get the function implementation and generate a table pub fn create_table_provider(&self, args: &[Expr]) -> Result> { self.fun.call(args) diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 7f5e80c4988a..74aca82b3ee6 100644 --- 
a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -24,20 +24,11 @@ use std::{any::Any, sync::Arc}; use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; use super::PartitionedFile; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{ create_ordering, get_statistics_with_limit, TableProvider, TableType, }; use crate::datasource::{ - file_format::{ - arrow::ArrowFormat, - avro::AvroFormat, - csv::CsvFormat, - file_compression_type::{FileCompressionType, FileTypeExt}, - json::JsonFormat, - FileFormat, - }, + file_format::{file_compression_type::FileCompressionType, FileFormat}, listing::ListingTableUrl, physical_plan::{FileScanConfig, FileSinkConfig}, }; @@ -51,7 +42,8 @@ use crate::{ use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; use arrow_schema::Schema; use datafusion_common::{ - internal_err, plan_err, project_schema, Constraints, FileType, SchemaExt, ToDFSchema, + config_datafusion_err, internal_err, plan_err, project_schema, Constraints, + SchemaExt, ToDFSchema, }; use datafusion_execution::cache::cache_manager::FileStatisticsCache; use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; @@ -119,9 +111,7 @@ impl ListingTableConfig { } } - fn infer_file_type(path: &str) -> Result<(FileType, String)> { - let err_msg = format!("Unable to infer file type from path: {path}"); - + fn infer_file_extension(path: &str) -> Result { let mut exts = path.rsplit('.'); let mut splitted = exts.next().unwrap_or(""); @@ -133,14 +123,7 @@ impl ListingTableConfig { splitted = exts.next().unwrap_or(""); } - let file_type = FileType::from_str(splitted) - .map_err(|_| DataFusionError::Internal(err_msg.to_owned()))?; - - let ext = file_type - .get_ext_with_compression(file_compression_type.to_owned()) - .map_err(|_| DataFusionError::Internal(err_msg))?; - - Ok((file_type, ext)) + Ok(splitted.to_string()) } /// Infer `ListingOptions` based on `table_path` suffix. @@ -161,25 +144,15 @@ impl ListingTableConfig { .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; - let (file_type, file_extension) = - ListingTableConfig::infer_file_type(file.location.as_ref())?; + let file_extension = + ListingTableConfig::infer_file_extension(file.location.as_ref())?; - let mut table_options = state.default_table_options(); - table_options.set_file_format(file_type.clone()); - let file_format: Arc = match file_type { - FileType::CSV => { - Arc::new(CsvFormat::default().with_options(table_options.csv)) - } - #[cfg(feature = "parquet")] - FileType::PARQUET => { - Arc::new(ParquetFormat::default().with_options(table_options.parquet)) - } - FileType::AVRO => Arc::new(AvroFormat), - FileType::JSON => { - Arc::new(JsonFormat::default().with_options(table_options.json)) - } - FileType::ARROW => Arc::new(ArrowFormat), - }; + let file_format = state + .get_file_format_factory(&file_extension) + .ok_or(config_datafusion_err!( + "No file_format found with extension {file_extension}" + ))? 
+ .create(state, &HashMap::new())?; let listing_options = ListingOptions::new(file_format) .with_file_extension(file_extension) @@ -1060,6 +1033,10 @@ impl ListingTable { #[cfg(test)] mod tests { use super::*; + use crate::datasource::file_format::avro::AvroFormat; + use crate::datasource::file_format::csv::CsvFormat; + use crate::datasource::file_format::json::JsonFormat; + use crate::datasource::file_format::parquet::ParquetFormat; #[cfg(feature = "parquet")] use crate::datasource::{provider_as_source, MemTable}; use crate::execution::options::ArrowReadOptions; @@ -1073,7 +1050,7 @@ mod tests { use arrow::record_batch::RecordBatch; use arrow_schema::SortOptions; use datafusion_common::stats::Precision; - use datafusion_common::{assert_contains, GetExt, ScalarValue}; + use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlanProperties; @@ -1104,6 +1081,8 @@ mod tests { #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_by_default() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); @@ -1128,6 +1107,8 @@ mod tests { #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_when_no_stats() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); @@ -1162,7 +1143,10 @@ mod tests { let options = ListingOptions::new(Arc::new(ParquetFormat::default())); let schema = options.infer_schema(&state, &table_path).await.unwrap(); - use crate::physical_plan::expressions::col as physical_col; + use crate::{ + datasource::file_format::parquet::ParquetFormat, + physical_plan::expressions::col as physical_col, + }; use std::ops::Add; // (file_sort_order, expected_result) @@ -1253,7 +1237,7 @@ mod tests { register_test_store(&ctx, &[(&path, 100)]); let opt = ListingOptions::new(Arc::new(AvroFormat {})) - .with_file_extension(FileType::AVRO.get_ext()) + .with_file_extension(AvroFormat.get_ext()) .with_table_partition_cols(vec![(String::from("p1"), DataType::Utf8)]) .with_target_partitions(4); @@ -1516,7 +1500,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::JSON, + JsonFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1534,7 +1518,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::CSV, + CsvFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1552,7 +1536,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1570,7 +1554,7 @@ mod tests { "20".into(), ); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 1, @@ -1756,7 +1740,7 @@ mod tests { ); config_map.insert("datafusion.execution.batch_size".into(), "1".into()); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, 
Some(config_map), 2, @@ -1774,7 +1758,7 @@ mod tests { "zstd".into(), ); let e = helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1787,7 +1771,7 @@ mod tests { } async fn helper_test_append_new_files_to_table( - file_type: FileType, + file_type_ext: String, file_compression_type: FileCompressionType, session_config_map: Option>, expected_n_files_per_insert: usize, @@ -1824,8 +1808,8 @@ mod tests { // Register appropriate table depending on file_type we want to test let tmp_dir = TempDir::new()?; - match file_type { - FileType::CSV => { + match file_type_ext.as_str() { + "csv" => { session_ctx .register_csv( "t", @@ -1836,7 +1820,7 @@ mod tests { ) .await?; } - FileType::JSON => { + "json" => { session_ctx .register_json( "t", @@ -1847,7 +1831,7 @@ mod tests { ) .await?; } - FileType::PARQUET => { + "parquet" => { session_ctx .register_parquet( "t", @@ -1856,7 +1840,7 @@ mod tests { ) .await?; } - FileType::AVRO => { + "avro" => { session_ctx .register_avro( "t", @@ -1865,7 +1849,7 @@ mod tests { ) .await?; } - FileType::ARROW => { + "arrow" => { session_ctx .register_arrow( "t", @@ -1874,6 +1858,7 @@ mod tests { ) .await?; } + _ => panic!("Unrecognized file extension {file_type_ext}"), } // Create and register the source table with the provided schema and inserted data diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 6e4749824395..1d4d08481895 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -18,14 +18,8 @@ //! Factory for creating ListingTables with default options use std::path::Path; -use std::str::FromStr; use std::sync::Arc; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::{ - arrow::ArrowFormat, avro::AvroFormat, csv::CsvFormat, json::JsonFormat, FileFormat, -}; use crate::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; @@ -34,8 +28,8 @@ use crate::datasource::TableProvider; use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::Result; -use datafusion_common::{arrow_datafusion_err, DataFusionError, FileType}; +use datafusion_common::{arrow_datafusion_err, DataFusionError}; +use datafusion_common::{config_datafusion_err, Result}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; @@ -58,28 +52,15 @@ impl TableProviderFactory for ListingTableFactory { state: &SessionState, cmd: &CreateExternalTable, ) -> Result> { - let file_type = FileType::from_str(cmd.file_type.as_str()).map_err(|_| { - DataFusionError::Execution(format!("Unknown FileType {}", cmd.file_type)) - })?; - let mut table_options = state.default_table_options(); - table_options.set_file_format(file_type.clone()); - table_options.alter_with_string_hash_map(&cmd.options)?; + let file_format = state + .get_file_format_factory(cmd.file_type.as_str()) + .ok_or(config_datafusion_err!( + "Unable to create table with format {}! Could not find FileFormat.", + cmd.file_type + ))? 
+ .create(state, &cmd.options)?; let file_extension = get_extension(cmd.location.as_str()); - let file_format: Arc = match file_type { - FileType::CSV => { - Arc::new(CsvFormat::default().with_options(table_options.csv)) - } - #[cfg(feature = "parquet")] - FileType::PARQUET => { - Arc::new(ParquetFormat::default().with_options(table_options.parquet)) - } - FileType::AVRO => Arc::new(AvroFormat), - FileType::JSON => { - Arc::new(JsonFormat::default().with_options(table_options.json)) - } - FileType::ARROW => Arc::new(ArrowFormat), - }; let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() { ( @@ -166,7 +147,9 @@ mod tests { use std::collections::HashMap; use super::*; - use crate::execution::context::SessionContext; + use crate::{ + datasource::file_format::csv::CsvFormat, execution::context::SessionContext, + }; use datafusion_common::{Constraints, DFSchema, TableReference}; diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index c06c630c45d1..327fbd976e87 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -534,13 +534,13 @@ mod tests { use super::*; use crate::dataframe::DataFrameWriteOptions; + use crate::datasource::file_format::csv::CsvFormat; use crate::prelude::*; use crate::test::{partitioned_csv_config, partitioned_file_groups}; use crate::{scalar::ScalarValue, test_util::aggr_test_schema}; use arrow::datatypes::*; use datafusion_common::test_util::arrow_test_data; - use datafusion_common::FileType; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; @@ -561,6 +561,8 @@ mod tests { async fn csv_exec_with_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); @@ -572,7 +574,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -627,6 +629,8 @@ mod tests { async fn csv_exec_with_mixed_order_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); @@ -639,7 +643,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -694,6 +698,8 @@ mod tests { async fn csv_exec_with_limit( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); @@ -706,7 +712,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -759,6 +765,8 @@ mod tests { async fn csv_exec_with_missing_column( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema_with_missing_col(); @@ -770,7 +778,7 @@ 
mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -813,6 +821,8 @@ mod tests { async fn csv_exec_with_partition( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); @@ -824,7 +834,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; @@ -929,7 +939,7 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), ) diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index e97554a791bd..c051b5d9b57d 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -384,7 +384,6 @@ mod tests { use super::*; use crate::dataframe::DataFrameWriteOptions; - use crate::datasource::file_format::file_compression_type::FileTypeExt; use crate::datasource::file_format::{json::JsonFormat, FileFormat}; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; @@ -397,7 +396,6 @@ mod tests { use arrow::array::Array; use arrow::datatypes::{Field, SchemaBuilder}; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; - use datafusion_common::FileType; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use rstest::*; @@ -419,7 +417,7 @@ mod tests { TEST_DATA_BASE, filename, 1, - FileType::JSON, + Arc::new(JsonFormat::default()), file_compression_type.to_owned(), work_dir, ) @@ -453,7 +451,7 @@ mod tests { TEST_DATA_BASE, filename, 1, - FileType::JSON, + Arc::new(JsonFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), ) @@ -472,8 +470,8 @@ mod tests { let path_buf = Path::new(url.path()).join(path); let path = path_buf.to_str().unwrap(); - let ext = FileType::JSON - .get_ext_with_compression(file_compression_type.to_owned()) + let ext = JsonFormat::default() + .get_ext_with_compression(&file_compression_type) .unwrap(); let read_options = NdJsonReadOptions::default() @@ -904,8 +902,8 @@ mod tests { let url: &Url = store_url.as_ref(); let path_buf = Path::new(url.path()).join(path); let path = path_buf.to_str().unwrap(); - let ext = FileType::JSON - .get_ext_with_compression(file_compression_type.to_owned()) + let ext = JsonFormat::default() + .get_ext_with_compression(&file_compression_type) .unwrap(); let read_option = NdJsonReadOptions::default() diff --git a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs index e15e907cd9b8..ea3030664b7b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs @@ -139,6 +139,11 @@ impl ParquetAccessPlan { self.set(idx, RowGroupAccess::Skip); } + /// scan the i-th row group + pub fn scan(&mut self, idx: usize) { + self.set(idx, RowGroupAccess::Scan); + } + /// Return true if the i-th row group should be scanned pub fn should_scan(&self, idx: usize) -> bool { self.row_groups[idx].should_scan() diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs 
b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index ec21c5504c69..9d5c64719e75 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -186,9 +186,9 @@ pub use writer::plan_to_parquet; /// let exec = ParquetExec::builder(file_scan_config).build(); /// ``` /// -/// For a complete example, see the [`parquet_index_advanced` example]). +/// For a complete example, see the [`advanced_parquet_index` example]). /// -/// [`parquet_index_advanced` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/parquet_index_advanced.rs +/// [`parquet_index_advanced` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_parquet_index.rs /// /// # Execution Overview /// @@ -796,17 +796,15 @@ mod tests { ArrayRef, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, StructArray, }; - use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; use arrow_schema::Fields; - use datafusion_common::{assert_contains, FileType, GetExt, ScalarValue, ToDFSchema}; - use datafusion_expr::execution_props::ExecutionProps; + use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; - use datafusion_physical_expr::create_physical_expr; + use datafusion_physical_expr::planner::logical2physical; + use datafusion_physical_plan::ExecutionPlanProperties; use chrono::{TimeZone, Utc}; - use datafusion_physical_plan::ExecutionPlanProperties; use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -1996,7 +1994,7 @@ mod tests { // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; @@ -2061,12 +2059,6 @@ mod tests { Ok(()) } - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() - } - #[tokio::test] async fn test_struct_filter_parquet() -> Result<()> { let tmp_dir = TempDir::new()?; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 18c6c51d2865..f9cce5f783ff 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -410,23 +410,20 @@ pub fn build_row_filter( #[cfg(test)] mod test { - use arrow::datatypes::Field; - use arrow_schema::TimeUnit::Nanosecond; - use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; - use parquet::arrow::parquet_to_arrow_schema; - use parquet::file::reader::{FileReader, SerializedFileReader}; - use rand::prelude::*; - + use super::*; use crate::datasource::schema_adapter::DefaultSchemaAdapterFactory; use crate::datasource::schema_adapter::SchemaAdapterFactory; - use datafusion_common::ToDFSchema; - use datafusion_expr::execution_props::ExecutionProps; + use arrow::datatypes::Field; + use arrow_schema::TimeUnit::Nanosecond; use datafusion_expr::{cast, col, lit, Expr}; - use 
datafusion_physical_expr::create_physical_expr; + use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::metrics::{Count, Time}; - use super::*; + use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; + use parquet::arrow::parquet_to_arrow_schema; + use parquet::file::reader::{FileReader, SerializedFileReader}; + use rand::prelude::*; // We should ignore predicate that read non-primitive columns #[test] @@ -590,10 +587,4 @@ mod test { assert_eq!(projection, remapped) } } - - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() - } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index e590f372253c..9bc79805746f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -54,7 +54,7 @@ impl RowGroupAccessPlanFilter { Self { access_plan } } - /// Return true if there are no row groups to scan + /// Return true if there are no row groups pub fn is_empty(&self) -> bool { self.access_plan.is_empty() } @@ -404,15 +404,19 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { #[cfg(test)] mod tests { + use std::ops::Rem; + use std::sync::Arc; + use super::*; use crate::datasource::physical_plan::parquet::reader::ParquetFileReader; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; + use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::{DataType, Field}; - use datafusion_common::{Result, ToDFSchema}; - use datafusion_expr::execution_props::ExecutionProps; + use datafusion_common::Result; use datafusion_expr::{cast, col, lit, Expr}; - use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; + use datafusion_physical_expr::planner::logical2physical; + use parquet::arrow::arrow_to_parquet_schema; use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; @@ -422,8 +426,6 @@ mod tests { basic::Type as PhysicalType, file::statistics::Statistics as ParquetStatistics, schema::types::SchemaDescPtr, }; - use std::ops::Rem; - use std::sync::Arc; struct PrimitiveTypeField { name: &'static str, @@ -1111,12 +1113,6 @@ mod tests { ParquetFileMetrics::new(0, "file.parquet", &metrics) } - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() - } - #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { BloomFilterTest::new_data_index_bloom_encoding_stats() diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 3be060ce6180..67c517ddbc4f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -33,7 +33,7 @@ use arrow_array::{ use arrow_schema::{Field, FieldRef, Schema, TimeUnit}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; use half::f16; -use parquet::data_type::FixedLenByteArray; +use parquet::data_type::{ByteArray, FixedLenByteArray}; use parquet::file::metadata::{ParquetColumnIndex, 
ParquetOffsetIndex, RowGroupMetaData}; use parquet::file::page_index::index::{Index, PageIndex}; use parquet::file::statistics::Statistics as ParquetStatistics; @@ -67,7 +67,7 @@ pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option { // Copy from arrow-rs // https://github.com/apache/arrow-rs/blob/198af7a3f4aa20f9bd003209d9f04b0f37bb120e/parquet/src/arrow/buffer/bit_util.rs#L54 // Convert the byte slice to fixed length byte array with the length of N. -pub fn sign_extend_be(b: &[u8]) -> [u8; N] { +fn sign_extend_be(b: &[u8]) -> [u8; N] { assert!(b.len() <= N, "Array too large, expected less than {N}"); let is_negative = (b[0] & 128u8) == 128u8; let mut result = if is_negative { [255u8; N] } else { [0u8; N] }; @@ -354,32 +354,11 @@ macro_rules! get_statistics { ))), DataType::Timestamp(unit, timezone) =>{ let iter = [<$stat_type_prefix Int64StatsIterator>]::new($iterator).map(|x| x.copied()); - Ok(match unit { - TimeUnit::Second => { - Arc::new(match timezone { - Some(tz) => TimestampSecondArray::from_iter(iter).with_timezone(tz.clone()), - None => TimestampSecondArray::from_iter(iter), - }) - } - TimeUnit::Millisecond => { - Arc::new(match timezone { - Some(tz) => TimestampMillisecondArray::from_iter(iter).with_timezone(tz.clone()), - None => TimestampMillisecondArray::from_iter(iter), - }) - } - TimeUnit::Microsecond => { - Arc::new(match timezone { - Some(tz) => TimestampMicrosecondArray::from_iter(iter).with_timezone(tz.clone()), - None => TimestampMicrosecondArray::from_iter(iter), - }) - } - TimeUnit::Nanosecond => { - Arc::new(match timezone { - Some(tz) => TimestampNanosecondArray::from_iter(iter).with_timezone(tz.clone()), - None => TimestampNanosecondArray::from_iter(iter), - }) - } + TimeUnit::Second => Arc::new(TimestampSecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), + TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), + TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), + TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), }) }, DataType::Time32(unit) => { @@ -549,6 +528,18 @@ macro_rules! make_data_page_stats_iterator { }; } +make_data_page_stats_iterator!( + MinBooleanDataPageStatsIterator, + |x: &PageIndex| { x.min }, + Index::BOOLEAN, + bool +); +make_data_page_stats_iterator!( + MaxBooleanDataPageStatsIterator, + |x: &PageIndex| { x.max }, + Index::BOOLEAN, + bool +); make_data_page_stats_iterator!( MinInt32DataPageStatsIterator, |x: &PageIndex| { x.min }, @@ -609,10 +600,145 @@ make_data_page_stats_iterator!( Index::DOUBLE, f64 ); + +macro_rules! 
get_decimal_page_stats_iterator { + ($iterator_type: ident, $func: ident, $stat_value_type: ident, $convert_func: ident) => { + struct $iterator_type<'a, I> + where + I: Iterator, + { + iter: I, + } + + impl<'a, I> $iterator_type<'a, I> + where + I: Iterator, + { + fn new(iter: I) -> Self { + Self { iter } + } + } + + impl<'a, I> Iterator for $iterator_type<'a, I> + where + I: Iterator, + { + type Item = Vec>; + + fn next(&mut self) -> Option { + let next = self.iter.next(); + match next { + Some((len, index)) => match index { + Index::INT32(native_index) => Some( + native_index + .indexes + .iter() + .map(|x| { + Some($stat_value_type::from( + x.$func.unwrap_or_default(), + )) + }) + .collect::>(), + ), + Index::INT64(native_index) => Some( + native_index + .indexes + .iter() + .map(|x| { + Some($stat_value_type::from( + x.$func.unwrap_or_default(), + )) + }) + .collect::>(), + ), + Index::BYTE_ARRAY(native_index) => Some( + native_index + .indexes + .iter() + .map(|x| { + Some($convert_func( + x.clone().$func.unwrap_or_default().data(), + )) + }) + .collect::>(), + ), + Index::FIXED_LEN_BYTE_ARRAY(native_index) => Some( + native_index + .indexes + .iter() + .map(|x| { + Some($convert_func( + x.clone().$func.unwrap_or_default().data(), + )) + }) + .collect::>(), + ), + _ => Some(vec![None; len]), + }, + _ => None, + } + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } + } + }; +} + +get_decimal_page_stats_iterator!( + MinDecimal128DataPageStatsIterator, + min, + i128, + from_bytes_to_i128 +); + +get_decimal_page_stats_iterator!( + MaxDecimal128DataPageStatsIterator, + max, + i128, + from_bytes_to_i128 +); + +get_decimal_page_stats_iterator!( + MinDecimal256DataPageStatsIterator, + min, + i256, + from_bytes_to_i256 +); + +get_decimal_page_stats_iterator!( + MaxDecimal256DataPageStatsIterator, + max, + i256, + from_bytes_to_i256 +); +make_data_page_stats_iterator!( + MinByteArrayDataPageStatsIterator, + |x: &PageIndex| { x.min.clone() }, + Index::BYTE_ARRAY, + ByteArray +); +make_data_page_stats_iterator!( + MaxByteArrayDataPageStatsIterator, + |x: &PageIndex| { x.max.clone() }, + Index::BYTE_ARRAY, + ByteArray +); + macro_rules! get_data_page_statistics { ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { paste! { match $data_type { + Some(DataType::Boolean) => Ok(Arc::new( + BooleanArray::from_iter( + [<$stat_type_prefix BooleanDataPageStatsIterator>]::new($iterator) + .flatten() + // BooleanArray::from_iter required a sized iterator, so collect into Vec first + .collect::>() + .into_iter() + ) + )), Some(DataType::UInt8) => Ok(Arc::new( UInt8Array::from_iter( [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) @@ -692,6 +818,61 @@ macro_rules! 
get_data_page_statistics { )), Some(DataType::Float32) => Ok(Arc::new(Float32Array::from_iter([<$stat_type_prefix Float32DataPageStatsIterator>]::new($iterator).flatten()))), Some(DataType::Float64) => Ok(Arc::new(Float64Array::from_iter([<$stat_type_prefix Float64DataPageStatsIterator>]::new($iterator).flatten()))), + Some(DataType::Binary) => Ok(Arc::new(BinaryArray::from_iter([<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).flatten()))), + Some(DataType::LargeBinary) => Ok(Arc::new(LargeBinaryArray::from_iter([<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).flatten()))), + Some(DataType::Utf8) => Ok(Arc::new(StringArray::from( + [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).map(|x| { + x.into_iter().filter_map(|x| { + x.and_then(|x| { + let res = std::str::from_utf8(x.data()).map(|s| s.to_string()).ok(); + if res.is_none() { + log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); + } + res + }) + }) + }).flatten().collect::>(), + ))), + Some(DataType::LargeUtf8) => Ok(Arc::new(LargeStringArray::from( + [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).map(|x| { + x.into_iter().filter_map(|x| { + x.and_then(|x| { + let res = std::str::from_utf8(x.data()).map(|s| s.to_string()).ok(); + if res.is_none() { + log::debug!("LargeUtf8 statistics is a non-UTF8 value, ignoring it."); + } + res + }) + }) + }).flatten().collect::>(), + ))), + Some(DataType::Timestamp(unit, timezone)) => { + let iter = [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten(); + Ok(match unit { + TimeUnit::Second => Arc::new(TimestampSecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), + TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), + TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), + TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from_iter(iter).with_timezone_opt(timezone.clone())), + }) + }, + Some(DataType::Date32) => Ok(Arc::new(Date32Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten()))), + Some(DataType::Date64) => Ok( + Arc::new( + Date64Array::from([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) + .map(|x| { + x.into_iter() + .filter_map(|x| { + x.and_then(|x| i64::try_from(x).ok()) + }) + .map(|x| x * 24 * 60 * 60 * 1000) + }).flatten().collect::>() + ) + ) + ), + Some(DataType::Decimal128(precision, scale)) => Ok(Arc::new( + Decimal128Array::from_iter([<$stat_type_prefix Decimal128DataPageStatsIterator>]::new($iterator).flatten()).with_precision_and_scale(*precision, *scale)?)), + Some(DataType::Decimal256(precision, scale)) => Ok(Arc::new( + Decimal256Array::from_iter([<$stat_type_prefix Decimal256DataPageStatsIterator>]::new($iterator).flatten()).with_precision_and_scale(*precision, *scale)?)), _ => unimplemented!() } } @@ -778,6 +959,11 @@ where { let iter = iterator.flat_map(|(len, index)| match index { Index::NONE => vec![None; len], + Index::BOOLEAN(native_index) => native_index + .indexes + .iter() + .map(|x| x.null_count.map(|x| x as u64)) + .collect::>(), Index::INT32(native_index) => native_index .indexes .iter() @@ -803,6 +989,11 @@ where .iter() .map(|x| x.null_count.map(|x| x as u64)) .collect::>(), + Index::BYTE_ARRAY(native_index) => native_index + .indexes + .iter() + .map(|x| x.null_count.map(|x| x as u64)) + .collect::>(), _ => unimplemented!(), }); diff --git 
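Reviewer sketch: both the row-group and the data-page timestamp arms above now build their statistics arrays with `with_timezone_opt`, collapsing the previous Some/None timezone matches. A minimal, self-contained illustration of that pattern, assuming only the arrow-array crate; the helper name `nanos_stats_to_array` is invented for the example and is not part of this change.

use std::sync::Arc;

use arrow_array::{Array, ArrayRef, TimestampNanosecondArray};

/// Turn per-row-group (or per-page) nanosecond min/max statistics into an
/// Arrow array; missing statistics become nulls.
fn nanos_stats_to_array(stats: Vec<Option<i64>>, timezone: Option<&str>) -> ArrayRef {
    let array = TimestampNanosecondArray::from_iter(stats)
        // `with_timezone_opt(None)` keeps the type timezone-less, so the
        // Some/None cases no longer need separate match arms.
        .with_timezone_opt(timezone.map(Arc::<str>::from));
    Arc::new(array)
}

fn main() {
    let array = nanos_stats_to_array(vec![Some(1_000), None, Some(3_000)], Some("+00:00"));
    assert_eq!(array.len(), 3);
    assert_eq!(array.null_count(), 1);
}
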
a/datafusion/core/src/execution/context/csv.rs b/datafusion/core/src/execution/context/csv.rs index aa1a8b512f7e..08e93cb61305 100644 --- a/datafusion/core/src/execution/context/csv.rs +++ b/datafusion/core/src/execution/context/csv.rs @@ -110,12 +110,12 @@ mod tests { ) .await?; let results = - plan_and_collect(&ctx, "SELECT sum(c1), sum(c2), COUNT(*) FROM test").await?; + plan_and_collect(&ctx, "SELECT sum(c1), sum(c2), count(*) FROM test").await?; assert_eq!(results.len(), 1); let expected = [ "+--------------+--------------+----------+", - "| sum(test.c1) | sum(test.c2) | COUNT(*) |", + "| sum(test.c1) | sum(test.c2) | count(*) |", "+--------------+--------------+----------+", "| 10 | 110 | 20 |", "+--------------+--------------+----------+", diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 6fa83d3d931e..9ec0148d9122 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -28,7 +28,7 @@ use crate::{ catalog::{CatalogProvider, CatalogProviderList, MemoryCatalogProvider}, dataframe::DataFrame, datasource::{ - function::TableFunctionImpl, + function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, provider::TableProviderFactory, }, @@ -52,7 +52,7 @@ use arrow::record_batch::RecordBatch; use arrow_schema::Schema; use datafusion_common::{ config::{ConfigExtension, TableOptions}, - exec_err, not_impl_err, plan_err, + exec_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, DFSchema, SchemaReference, TableReference, }; @@ -75,6 +75,7 @@ use url::Url; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; +use datafusion_optimizer::{AnalyzerRule, OptimizerRule}; mod avro; mod csv; @@ -331,6 +332,23 @@ impl SessionContext { self } + /// Adds an optimizer rule to the end of the existing rules. + /// + /// See [`SessionState`] for more control of when the rule is applied. + pub fn add_optimizer_rule( + &self, + optimizer_rule: Arc, + ) { + self.state.write().append_optimizer_rule(optimizer_rule); + } + + /// Adds an analyzer rule to the end of the existing rules. + /// + /// See [`SessionState`] for more control of when the rule is applied. + pub fn add_analyzer_rule(&self, analyzer_rule: Arc) { + self.state.write().add_analyzer_rule(analyzer_rule); + } + /// Registers an [`ObjectStore`] to be used with a specific URL prefix. /// /// See [`RuntimeEnv::register_object_store`] for more details. @@ -476,6 +494,32 @@ impl SessionContext { self.execute_logical_plan(plan).await } + /// Creates logical expresssions from SQL query text. + /// + /// # Example: Parsing SQL queries + /// + /// ``` + /// # use arrow::datatypes::{DataType, Field, Schema}; + /// # use datafusion::prelude::*; + /// # use datafusion_common::{DFSchema, Result}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// // datafusion will parse number as i64 first. 
+ /// let sql = "a > 10"; + /// let expected = col("a").gt(lit(10 as i64)); + /// // provide type information that `a` is an Int32 + /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + /// let df_schema = DFSchema::try_from(schema).unwrap(); + /// let expr = SessionContext::new() + /// .parse_sql_expr(sql, &df_schema)?; + /// assert_eq!(expected, expr); + /// # Ok(()) + /// # } + /// ``` + pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result { + self.state.read().create_logical_expr(sql, df_schema) + } + /// Execute the [`LogicalPlan`], return a [`DataFrame`]. This API /// is not featured limited (so all SQL such as `CREATE TABLE` and /// `COPY` will be run). @@ -892,6 +936,7 @@ impl SessionContext { dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some(); dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some(); dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some(); + dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some(); // DROP FUNCTION IF EXISTS drops the specified function only if that // function exists and in this way, it avoids error. While the DROP FUNCTION @@ -972,6 +1017,11 @@ impl SessionContext { self.state.write().deregister_udwf(name).ok(); } + /// Deregisters a UDTF within this context. + pub fn deregister_udtf(&self, name: &str) { + self.state.write().deregister_udtf(name).ok(); + } + /// Creates a [`DataFrame`] for reading a data source. /// /// For more control such as reading multiple files, you can use @@ -1230,6 +1280,20 @@ impl SessionContext { Ok(DataFrame::new(self.state(), plan)) } + /// Retrieves a [`TableFunction`] reference by name. + /// + /// Returns an error if no table function has been registered with the provided name. + /// + /// [`register_udtf`]: SessionContext::register_udtf + pub fn table_function(&self, name: &str) -> Result> { + self.state + .read() + .table_functions() + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Table function '{name}' not found")) + } + /// Return a [`TableProvider`] for the specified table. 
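Reviewer sketch: `table_function` and `deregister_udtf` above complete the UDTF lifecycle (registration already existed). The following sketch shows how the three calls fit together; `ConstantTable` and its single-column batch are invented for illustration, and the snippet assumes the `datafusion` crate at this revision.

use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::function::TableFunctionImpl;
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::Result;
use datafusion::logical_expr::Expr;
use datafusion::prelude::*;

/// A table function that ignores its arguments and returns a fixed table.
#[derive(Debug)]
struct ConstantTable;

impl TableFunctionImpl for ConstantTable {
    fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
    }
}

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_udtf("constant_table", Arc::new(ConstantTable));

    // New: fetch the registered function back by name ...
    assert!(ctx.table_function("constant_table").is_ok());

    // ... and drop it again, which is also what DROP FUNCTION now does for UDTFs.
    ctx.deregister_udtf("constant_table");
    assert!(ctx.table_function("constant_table").is_err());
    Ok(())
}
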
pub async fn table_provider<'a>( &self, @@ -1267,6 +1331,11 @@ impl SessionContext { state } + /// Get reference to [`SessionState`] + pub fn state_ref(&self) -> Arc> { + self.state.clone() + } + /// Get weak reference to [`SessionState`] pub fn state_weak_ref(&self) -> Weak> { Arc::downgrade(&self.state) @@ -1795,7 +1864,7 @@ mod tests { let catalog_list_weak = { let state = ctx.state.read(); - Arc::downgrade(&state.catalog_list()) + Arc::downgrade(state.catalog_list()) }; drop(ctx); diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 1df77a1f9e0b..0b880ddbf81b 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -25,6 +25,13 @@ use crate::catalog::{ MemoryCatalogProviderList, }; use crate::datasource::cte_worktable::CteWorkTable; +use crate::datasource::file_format::arrow::ArrowFormatFactory; +use crate::datasource::file_format::avro::AvroFormatFactory; +use crate::datasource::file_format::csv::CsvFormatFactory; +use crate::datasource::file_format::json::JsonFormatFactory; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormatFactory; +use crate::datasource::file_format::{format_as_file_type, FileFormatFactory}; use crate::datasource::function::{TableFunction, TableFunctionImpl}; use crate::datasource::provider::{DefaultTableFactory, TableProviderFactory}; use crate::datasource::provider_as_source; @@ -41,10 +48,11 @@ use chrono::{DateTime, Utc}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::tree_node::TreeNode; use datafusion_common::{ - not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, - TableReference, + config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, + ResolvedTableReference, TableReference, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::object_store::ObjectStoreUrl; @@ -67,7 +75,8 @@ use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use datafusion_sql::parser::{DFParser, Statement}; -use datafusion_sql::planner::{ContextProvider, ParserOptions, SqlToRel}; +use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; +use sqlparser::ast::Expr as SQLExpr; use sqlparser::dialect::dialect_from_str; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; @@ -108,6 +117,8 @@ pub struct SessionState { window_functions: HashMap>, /// Deserializer registry for extensions. 
serializer_registry: Arc, + /// Holds registered external FileFormat implementations + file_formats: HashMap>, /// Session configuration config: SessionConfig, /// Table options @@ -229,6 +240,7 @@ impl SessionState { aggregate_functions: HashMap::new(), window_functions: HashMap::new(), serializer_registry: Arc::new(EmptySerializerRegistry), + file_formats: HashMap::new(), table_options: TableOptions::default_from_session_config(config.options()), config, execution_props: ExecutionProps::new(), @@ -237,6 +249,37 @@ impl SessionState { function_factory: None, }; + #[cfg(feature = "parquet")] + if let Err(e) = + new_self.register_file_format(Arc::new(ParquetFormatFactory::new()), false) + { + log::info!("Unable to register default ParquetFormat: {e}") + }; + + if let Err(e) = + new_self.register_file_format(Arc::new(JsonFormatFactory::new()), false) + { + log::info!("Unable to register default JsonFormat: {e}") + }; + + if let Err(e) = + new_self.register_file_format(Arc::new(CsvFormatFactory::new()), false) + { + log::info!("Unable to register default CsvFormat: {e}") + }; + + if let Err(e) = + new_self.register_file_format(Arc::new(ArrowFormatFactory::new()), false) + { + log::info!("Unable to register default ArrowFormat: {e}") + }; + + if let Err(e) = + new_self.register_file_format(Arc::new(AvroFormatFactory::new()), false) + { + log::info!("Unable to register default AvroFormat: {e}") + }; + // register built in functions functions::register_all(&mut new_self) .expect("can not register built in functions"); @@ -384,9 +427,9 @@ impl SessionState { /// Add `analyzer_rule` to the end of the list of /// [`AnalyzerRule`]s used to rewrite queries. pub fn add_analyzer_rule( - mut self, + &mut self, analyzer_rule: Arc, - ) -> Self { + ) -> &Self { self.analyzer.rules.push(analyzer_rule); self } @@ -401,6 +444,16 @@ impl SessionState { self } + // the add_optimizer_rule takes an owned reference + // it should probably be renamed to `with_optimizer_rule` to follow builder style + // and `add_optimizer_rule` that takes &mut self added instead of this + pub(crate) fn append_optimizer_rule( + &mut self, + optimizer_rule: Arc, + ) { + self.optimizer.rules.push(optimizer_rule); + } + /// Add `physical_optimizer_rule` to the end of the list of /// [`PhysicalOptimizerRule`]s used to rewrite queries. pub fn add_physical_optimizer_rule( @@ -490,6 +543,27 @@ impl SessionState { Ok(statement) } + /// parse a sql string into a sqlparser-rs AST [`SQLExpr`]. + /// + /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + pub fn sql_to_expr( + &self, + sql: &str, + dialect: &str, + ) -> datafusion_common::Result { + let dialect = dialect_from_str(dialect).ok_or_else(|| { + plan_datafusion_err!( + "Unsupported SQL dialect: {dialect}. Available dialects: \ + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi." + ) + })?; + + let expr = DFParser::parse_sql_into_expr_with_dialect(sql, dialect.as_ref())?; + + Ok(expr) + } + /// Resolve all table references in the SQL statement. Does not include CTE references. /// /// See [`catalog::resolve_table_references`] for more information. 
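Reviewer sketch: the `sql_to_expr` helper added above parses a string into a bare sqlparser AST expression using whatever dialect name is passed, without needing a schema; `create_logical_expr` (just below) then plans such an expression against a `DFSchema`. A small sketch of the AST-only step, assuming this revision of the crate; the expression text and dialect names are arbitrary.

use datafusion::error::Result;
use datafusion::prelude::*;

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // `state()` returns an owned snapshot of the session state.
    let state = ctx.state();

    // Parse with an explicit dialect; an unknown dialect name yields a plan error.
    let ast = state.sql_to_expr("a + 1 > 10", "postgres")?;
    println!("parsed sqlparser expression: {ast}");

    assert!(state.sql_to_expr("a + 1", "not_a_dialect").is_err());
    Ok(())
}
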
@@ -520,10 +594,6 @@ impl SessionState { tables: HashMap::with_capacity(references.len()), }; - let enable_ident_normalization = - self.config.options().sql_parser.enable_ident_normalization; - let parse_float_as_decimal = - self.config.options().sql_parser.parse_float_as_decimal; for reference in references { let resolved = &self.resolve_table_ref(reference); if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) { @@ -535,16 +605,20 @@ impl SessionState { } } - let query = SqlToRel::new_with_options( - &provider, - ParserOptions { - parse_float_as_decimal, - enable_ident_normalization, - }, - ); + let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); query.statement_to_plan(statement) } + fn get_parser_options(&self) -> ParserOptions { + let sql_parser_options = &self.config.options().sql_parser; + + ParserOptions { + parse_float_as_decimal: sql_parser_options.parse_float_as_decimal, + enable_ident_normalization: sql_parser_options.enable_ident_normalization, + support_varchar_with_length: sql_parser_options.support_varchar_with_length, + } + } + /// Creates a [`LogicalPlan`] from the provided SQL string. This /// interface will plan any SQL DataFusion supports, including DML /// like `CREATE TABLE`, and `COPY` (which can write to local @@ -567,6 +641,28 @@ impl SessionState { Ok(plan) } + /// Creates a datafusion style AST [`Expr`] from a SQL string. + /// + /// See example on [SessionContext::parse_sql_expr](crate::execution::context::SessionContext::parse_sql_expr) + pub fn create_logical_expr( + &self, + sql: &str, + df_schema: &DFSchema, + ) -> datafusion_common::Result { + let dialect = self.config.options().sql_parser.dialect.as_str(); + + let sql_expr = self.sql_to_expr(sql, dialect)?; + + let provider = SessionContextProvider { + state: self, + tables: HashMap::new(), + }; + + let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); + + query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) + } + /// Optimizes the logical plan by applying optimizer rules. pub fn optimize(&self, plan: &LogicalPlan) -> datafusion_common::Result { if let LogicalPlan::Explain(e) = plan { @@ -758,14 +854,39 @@ impl SessionState { self.table_options.extensions.insert(extension) } + /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or CREATE EXTERNAL TABLE statements for reading + /// and writing files of custom formats. + pub fn register_file_format( + &mut self, + file_format: Arc, + overwrite: bool, + ) -> Result<(), DataFusionError> { + let ext = file_format.get_ext().to_lowercase(); + match (self.file_formats.entry(ext.clone()), overwrite){ + (Entry::Vacant(e), _) => {e.insert(file_format);}, + (Entry::Occupied(mut e), true) => {e.insert(file_format);}, + (Entry::Occupied(_), false) => return config_err!("File type already registered for extension {ext}. Set overwrite to true to replace this extension."), + }; + Ok(()) + } + + /// Retrieves a [FileFormatFactory] based on file extension which has been registered + /// via SessionContext::register_file_format. Extensions are not case sensitive. 
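Reviewer sketch: the registry semantics described here (a config error without `overwrite`, case-insensitive lookup) are easiest to see with the built-in `CsvFormatFactory` rather than a full custom `FileFormatFactory` implementation. A hedged sketch under that assumption:

use std::sync::Arc;

use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::error::Result;
use datafusion::prelude::*;

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // `state()` hands back a clone of the session state that we can mutate.
    let mut state = ctx.state();

    // "csv" is one of the factories registered by default, so a second
    // registration is rejected unless `overwrite` is set ...
    assert!(state
        .register_file_format(Arc::new(CsvFormatFactory::new()), false)
        .is_err());
    // ... while `overwrite = true` replaces the existing factory.
    state.register_file_format(Arc::new(CsvFormatFactory::new()), true)?;

    // Lookups are keyed by lowercased extension, so case does not matter.
    assert!(state.get_file_format_factory("CSV").is_some());
    Ok(())
}
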
+ pub fn get_file_format_factory( + &self, + ext: &str, + ) -> Option> { + self.file_formats.get(&ext.to_lowercase()).cloned() + } + /// Get a new TaskContext to run in this session pub fn task_ctx(&self) -> Arc { Arc::new(TaskContext::from(self)) } /// Return catalog list - pub fn catalog_list(&self) -> Arc { - self.catalog_list.clone() + pub fn catalog_list(&self) -> &Arc { + &self.catalog_list } /// set the catalog list @@ -791,9 +912,14 @@ impl SessionState { &self.window_functions } + /// Return reference to table_functions + pub fn table_functions(&self) -> &HashMap> { + &self.table_functions + } + /// Return [SerializerRegistry] for extensions - pub fn serializer_registry(&self) -> Arc { - self.serializer_registry.clone() + pub fn serializer_registry(&self) -> &Arc { + &self.serializer_registry } /// Return version of the cargo package that produced this query @@ -808,6 +934,15 @@ impl SessionState { Arc::new(TableFunction::new(name.to_owned(), fun)), ); } + + /// Deregsiter a user defined table function + pub fn deregister_udtf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udtf = self.table_functions.remove(name); + Ok(udtf.map(|x| x.function().clone())) + } } struct SessionContextProvider<'a> { @@ -900,6 +1035,16 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { fn udwf_names(&self) -> Vec { self.state.window_functions().keys().cloned().collect() } + + fn get_file_type(&self, ext: &str) -> datafusion_common::Result> { + self.state + .file_formats + .get(&ext.to_lowercase()) + .ok_or(plan_datafusion_err!( + "There is no registered file format with ext {ext}" + )) + .map(|file_type| format_as_file_type(file_type.clone())) + } } impl FunctionRegistry for SessionState { diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index a081a822a890..d81efaf68ca3 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -17,16 +17,16 @@ #![warn(missing_docs, clippy::needless_borrow)] //! [DataFusion] is an extensible query engine written in Rust that -//! uses [Apache Arrow] as its in-memory format. DataFusion's many [use -//! cases] help developers build very fast and feature rich database -//! and analytic systems, customized to particular workloads. +//! uses [Apache Arrow] as its in-memory format. DataFusion help developers +//! build fast and feature rich database and analytic systems, customized to +//! particular workloads. See [use cases] for examples //! //! "Out of the box," DataFusion quickly runs complex [SQL] and -//! [`DataFrame`] queries using a sophisticated query planner, a columnar, -//! multi-threaded, vectorized execution engine, and partitioned data +//! [`DataFrame`] queries using a full-featured query planner, a columnar, +//! streaming, multi-threaded, vectorized execution engine, and partitioned data //! sources (Parquet, CSV, JSON, and Avro). //! -//! DataFusion is designed for easy customization such as supporting +//! DataFusion is designed for easy customization such as //! additional data sources, query languages, functions, custom //! operators and more. See the [Architecture] section for more details. //! @@ -130,11 +130,51 @@ //! //! [datafusion-examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples //! +//! # Architecture +//! +//! +//! +//! You can find a formal description of DataFusion's architecture in our +//! [SIGMOD 2024 Paper]. +//! +//! [SIGMOD 2024 Paper]: https://dl.acm.org/doi/10.1145/3626246.3653368 +//! +//! ## Design Goals +//! 
DataFusion's Architecture Goals are: +//! +//! 1. Work “out of the box”: Provide a very fast, world class query engine with +//! minimal setup or required configuration. +//! +//! 2. Customizable everything: All behavior should be customizable by +//! implementing traits. +//! +//! 3. Architecturally boring 🥱: Follow industrial best practice rather than +//! trying cutting edge, but unproven, techniques. +//! +//! With these principles, users start with a basic, high-performance engine +//! and specialize it over time to suit their needs and available engineering +//! capacity. +//! +//! ## Overview Presentations +//! +//! The following presentations offer high level overviews of the +//! different components and how they interact together. +//! +//! - [Apr 2023]: The Apache DataFusion Architecture talks +//! - _Query Engine_: [recording](https://youtu.be/NVKujPxwSBA) and [slides](https://docs.google.com/presentation/d/1D3GDVas-8y0sA4c8EOgdCvEjVND4s2E7I6zfs67Y4j8/edit#slide=id.p) +//! - _Logical Plan and Expressions_: [recording](https://youtu.be/EzZTLiSJnhY) and [slides](https://docs.google.com/presentation/d/1ypylM3-w60kVDW7Q6S99AHzvlBgciTdjsAfqNP85K30) +//! - _Physical Plan and Execution_: [recording](https://youtu.be/2jkWU3_w6z0) and [slides](https://docs.google.com/presentation/d/1cA2WQJ2qg6tx6y4Wf8FH2WVSm9JQ5UgmBWATHdik0hg) +//! - [July 2022]: DataFusion and Arrow: Supercharge Your Data Analytical Tool with a Rusty Query Engine: [recording](https://www.youtube.com/watch?v=Rii1VTn3seQ) and [slides](https://docs.google.com/presentation/d/1q1bPibvu64k2b7LPi7Yyb0k3gA1BiUYiUbEklqW1Ckc/view#slide=id.g11054eeab4c_0_1165) +//! - [March 2021]: The DataFusion architecture is described in _Query Engine Design and the Rust-Based DataFusion in Apache Arrow_: [recording](https://www.youtube.com/watch?v=K6eCAVEk4kU) (DataFusion content starts [~ 15 minutes in](https://www.youtube.com/watch?v=K6eCAVEk4kU&t=875s)) and [slides](https://www.slideshare.net/influxdata/influxdb-iox-tech-talks-query-engine-design-and-the-rustbased-datafusion-in-apache-arrow-244161934) +//! - [February 2021]: How DataFusion is used within the Ballista Project is described in _Ballista: Distributed Compute with Rust and Apache Arrow_: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ) +//! //! ## Customization and Extension //! -//! DataFusion is a "disaggregated" query engine. This -//! means developers can start with a working, full featured engine, and then -//! extend the areas they need to specialize for their usecase. For example, +//! DataFusion is designed to be highly extensible, so you can +//! start with a working, full featured engine, and then +//! specialize any behavior for their usecase. For example, //! some projects may add custom [`ExecutionPlan`] operators, or create their own //! query language that directly creates [`LogicalPlan`] rather than using the //! built in SQL planner, [`SqlToRel`]. @@ -161,30 +201,6 @@ //! [`AnalyzerRule`]: datafusion_optimizer::analyzer::AnalyzerRule //! [`PhysicalOptimizerRule`]: crate::physical_optimizer::optimizer::PhysicalOptimizerRule //! -//! # Architecture -//! -//! -//! -//! You can find a formal description of DataFusion's architecture in our -//! [SIGMOD 2024 Paper]. -//! -//! [SIGMOD 2024 Paper]: https://dl.acm.org/doi/10.1145/3626246.3653368 -//! -//! ## Overview Presentations -//! -//! The following presentations offer high level overviews of the -//! different components and how they interact together. -//! -//! 
- [Apr 2023]: The Apache DataFusion Architecture talks -//! - _Query Engine_: [recording](https://youtu.be/NVKujPxwSBA) and [slides](https://docs.google.com/presentation/d/1D3GDVas-8y0sA4c8EOgdCvEjVND4s2E7I6zfs67Y4j8/edit#slide=id.p) -//! - _Logical Plan and Expressions_: [recording](https://youtu.be/EzZTLiSJnhY) and [slides](https://docs.google.com/presentation/d/1ypylM3-w60kVDW7Q6S99AHzvlBgciTdjsAfqNP85K30) -//! - _Physical Plan and Execution_: [recording](https://youtu.be/2jkWU3_w6z0) and [slides](https://docs.google.com/presentation/d/1cA2WQJ2qg6tx6y4Wf8FH2WVSm9JQ5UgmBWATHdik0hg) -//! - [July 2022]: DataFusion and Arrow: Supercharge Your Data Analytical Tool with a Rusty Query Engine: [recording](https://www.youtube.com/watch?v=Rii1VTn3seQ) and [slides](https://docs.google.com/presentation/d/1q1bPibvu64k2b7LPi7Yyb0k3gA1BiUYiUbEklqW1Ckc/view#slide=id.g11054eeab4c_0_1165) -//! - [March 2021]: The DataFusion architecture is described in _Query Engine Design and the Rust-Based DataFusion in Apache Arrow_: [recording](https://www.youtube.com/watch?v=K6eCAVEk4kU) (DataFusion content starts [~ 15 minutes in](https://www.youtube.com/watch?v=K6eCAVEk4kU&t=875s)) and [slides](https://www.slideshare.net/influxdata/influxdb-iox-tech-talks-query-engine-design-and-the-rustbased-datafusion-in-apache-arrow-244161934) -//! - [February 2021]: How DataFusion is used within the Ballista Project is described in _Ballista: Distributed Compute with Rust and Apache Arrow_: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ) -//! //! ## Query Planning and Execution Overview //! //! ### SQL diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index ca1582bcb34f..7e9aec9e5e4c 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -141,7 +141,7 @@ fn take_optimizable_column_and_table_count( ) -> Option<(ScalarValue, String)> { let col_stats = &stats.column_statistics; if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { - if agg_expr.fun().name() == "COUNT" && !agg_expr.is_distinct() { + if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() { if let Precision::Exact(num_rows) = stats.num_rows { let exprs = agg_expr.expressions(); if exprs.len() == 1 { diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 4dd62a894518..e8f2f34abda0 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -471,8 +471,10 @@ pub struct PruningPredicate { /// Original physical predicate from which this predicate expr is derived /// (required for serialization) orig_expr: Arc, - /// [`LiteralGuarantee`]s that are used to try and prove a predicate can not - /// possibly evaluate to `true`. + /// [`LiteralGuarantee`]s used to try and prove a predicate can not possibly + /// evaluate to `true`. + /// + /// See [`PruningPredicate::literal_guarantees`] for more details. literal_guarantees: Vec, } @@ -595,6 +597,10 @@ impl PruningPredicate { } /// Returns a reference to the literal guarantees + /// + /// Note that **All** `LiteralGuarantee`s must be satisfied for the + /// expression to possibly be `true`. If any is not satisfied, the + /// expression is guaranteed to be `null` or `false`. 
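Reviewer sketch: the note above is easiest to see end to end: build a `PruningPredicate` from a disjunction of equalities and inspect the guarantees it extracts. This sketch uses the long-standing `create_physical_expr` path rather than any test helper; the column name and literal values are arbitrary.

use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::ToDFSchema;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_expr::{col, lit};
use datafusion::physical_expr::create_physical_expr;
use datafusion::physical_optimizer::pruning::PruningPredicate;

fn main() -> Result<()> {
    // Schema of the files that would be pruned.
    let schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]);

    // `x = 5 OR x = 7` can only be true if x takes a value in {5, 7}, which is
    // exactly the kind of fact a `LiteralGuarantee` records.
    let expr = col("x").eq(lit(5)).or(col("x").eq(lit(7)));
    let physical = create_physical_expr(
        &expr,
        &schema.clone().to_dfschema()?,
        &ExecutionProps::new(),
    )?;

    let pruning = PruningPredicate::try_new(physical, Arc::new(schema))?;
    for guarantee in pruning.literal_guarantees() {
        // If any guarantee is violated by a container's statistics, the whole
        // predicate cannot be true for that container.
        println!("{guarantee:?}");
    }
    Ok(())
}
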
pub fn literal_guarantees(&self) -> &[LiteralGuarantee] { &self.literal_guarantees } @@ -981,8 +987,8 @@ impl<'a> PruningExpressionBuilder<'a> { }) } - fn op(&self) -> Operator { - self.op + fn op(&self) -> &Operator { + &self.op } fn scalar_expr(&self) -> &Arc { @@ -1058,7 +1064,7 @@ fn rewrite_expr_to_prunable( scalar_expr: &PhysicalExprRef, schema: DFSchema, ) -> Result<(PhysicalExprRef, Operator, PhysicalExprRef)> { - if !is_compare_op(op) { + if !is_compare_op(&op) { return plan_err!("rewrite_expr_to_prunable only support compare expression"); } @@ -1125,7 +1131,7 @@ fn rewrite_expr_to_prunable( } } -fn is_compare_op(op: Operator) -> bool { +fn is_compare_op(op: &Operator) -> bool { matches!( op, Operator::Eq @@ -1352,11 +1358,13 @@ fn build_predicate_expression( .map(|e| { Arc::new(phys_expr::BinaryExpr::new( in_list.expr().clone(), - eq_op, + eq_op.clone(), e.clone(), )) as _ }) - .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) + .reduce(|a, b| { + Arc::new(phys_expr::BinaryExpr::new(a, re_op.clone(), b)) as _ + }) .unwrap(); return build_predicate_expression(&change_expr, schema, required_columns); } else { @@ -1368,7 +1376,7 @@ fn build_predicate_expression( if let Some(bin_expr) = expr_any.downcast_ref::() { ( bin_expr.left().clone(), - *bin_expr.op(), + bin_expr.op().clone(), bin_expr.right().clone(), ) } else { @@ -1380,7 +1388,7 @@ fn build_predicate_expression( let left_expr = build_predicate_expression(&left, schema, required_columns); let right_expr = build_predicate_expression(&right, schema, required_columns); // simplify boolean expression if applicable - let expr = match (&left_expr, op, &right_expr) { + let expr = match (&left_expr, &op, &right_expr) { (left, Operator::And, _) if is_always_true(left) => right_expr, (_, Operator::And, right) if is_always_true(right) => left_expr, (left, Operator::Or, right) @@ -1388,7 +1396,11 @@ fn build_predicate_expression( { unhandled } - _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), + _ => Arc::new(phys_expr::BinaryExpr::new( + left_expr, + op.clone(), + right_expr, + )), }; return expr; } @@ -1543,22 +1555,22 @@ pub(crate) enum StatisticsType { #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::ops::{Not, Rem}; + use super::*; use crate::assert_batches_eq; use crate::logical_expr::{col, lit}; + use arrow::array::Decimal128Array; use arrow::{ array::{BinaryArray, Int32Array, Int64Array, StringArray}, datatypes::TimeUnit, }; use arrow_array::UInt64Array; - use datafusion_common::ToDFSchema; - use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; - use datafusion_physical_expr::create_physical_expr; - use std::collections::HashMap; - use std::ops::{Not, Rem}; + use datafusion_physical_expr::planner::logical2physical; #[derive(Debug, Default)] /// Mock statistic provider for tests @@ -3864,10 +3876,4 @@ mod tests { let expr = logical2physical(expr, schema); build_predicate_expression(&expr, schema, required_columns) } - - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() - } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 404bcbb2e7d4..5b8501baaad8 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ 
-22,13 +22,7 @@ use std::collections::HashMap; use std::fmt::Write; use std::sync::Arc; -use crate::datasource::file_format::arrow::ArrowFormat; -use crate::datasource::file_format::avro::AvroFormat; -use crate::datasource::file_format::csv::CsvFormat; -use crate::datasource::file_format::json::JsonFormat; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::FileFormat; +use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; use crate::datasource::source_as_provider; @@ -74,11 +68,10 @@ use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use arrow_array::builder::StringBuilder; use arrow_array::RecordBatch; -use datafusion_common::config::FormatOptions; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, - FileType, ScalarValue, + ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -156,13 +149,18 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::Case(case) => { let mut name = "CASE ".to_string(); if let Some(e) = &case.expr { - let _ = write!(name, "{e} "); + let _ = write!(name, "{} ", create_physical_name(e, false)?); } for (w, t) in &case.when_then_expr { - let _ = write!(name, "WHEN {w} THEN {t} "); + let _ = write!( + name, + "WHEN {} THEN {} ", + create_physical_name(w, false)?, + create_physical_name(t, false)? + ); } if let Some(e) = &case.else_expr { - let _ = write!(name, "ELSE {e} "); + let _ = write!(name, "ELSE {} ", create_physical_name(e, false)?); } name += "END"; Ok(name) @@ -759,7 +757,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::Copy(CopyTo { input, output_url, - format_options, + file_type, partition_by, options: source_option_tuples, }) => { @@ -786,32 +784,9 @@ impl DefaultPhysicalPlanner { table_partition_cols, overwrite: false, }; - let mut table_options = session_state.default_table_options(); - let sink_format: Arc = match format_options { - FormatOptions::CSV(options) => { - table_options.csv = options.clone(); - table_options.set_file_format(FileType::CSV); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new(CsvFormat::default().with_options(table_options.csv)) - } - FormatOptions::JSON(options) => { - table_options.json = options.clone(); - table_options.set_file_format(FileType::JSON); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new(JsonFormat::default().with_options(table_options.json)) - } - #[cfg(feature = "parquet")] - FormatOptions::PARQUET(options) => { - table_options.parquet = options.clone(); - table_options.set_file_format(FileType::PARQUET); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new( - ParquetFormat::default().with_options(table_options.parquet), - ) - } - FormatOptions::AVRO => Arc::new(AvroFormat {}), - FormatOptions::ARROW => Arc::new(ArrowFormat {}), - }; + + let sink_format = file_type_to_format(file_type)? 
+ .create(session_state, source_option_tuples)?; sink_format .create_writer_physical_plan(input_exec, session_state, config, None) @@ -1259,7 +1234,7 @@ impl DefaultPhysicalPlanner { let join_filter = match filter { Some(expr) => { // Extract columns from filter expression and saved in a HashSet - let cols = expr.to_columns()?; + let cols = expr.column_refs(); // Collect left & right field indices, the field indices are sorted in ascending order let left_field_indices = cols @@ -1918,6 +1893,7 @@ pub fn create_aggregate_expr_and_maybe_filter( // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" let (name, e) = match e { Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), + Expr::AggregateFunction(_) => (e.display_name().unwrap_or(physical_name(e)?), e), _ => (physical_name(e)?, e), }; diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index e91f83f1199b..e8550a79cb0e 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -24,9 +24,9 @@ use std::io::{BufReader, BufWriter}; use std::path::Path; use std::sync::Arc; -use crate::datasource::file_format::file_compression_type::{ - FileCompressionType, FileTypeExt, -}; +use crate::datasource::file_format::csv::CsvFormat; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::FileFormat; use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; @@ -40,7 +40,7 @@ use crate::test_util::{aggr_test_schema, arrow_test_data}; use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, FileType, Statistics}; +use datafusion_common::{DataFusionError, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalSortExpr}; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; @@ -87,7 +87,7 @@ pub fn scan_partitioned_csv(partitions: usize, work_dir: &Path) -> Result, file_compression_type: FileCompressionType, work_dir: &Path, ) -> Result>> { @@ -120,9 +120,8 @@ pub fn partitioned_file_groups( let filename = format!( "partition-{}{}", i, - file_type - .to_owned() - .get_ext_with_compression(file_compression_type.to_owned()) + file_format + .get_ext_with_compression(&file_compression_type) .unwrap() ); let filename = work_dir.join(filename); @@ -167,7 +166,7 @@ pub fn partitioned_file_groups( for (i, line) in f.lines().enumerate() { let line = line.unwrap(); - if i == 0 && file_type == FileType::CSV { + if i == 0 && file_format.get_ext() == CsvFormat::default().get_ext() { // write header to all partitions for w in writers.iter_mut() { w.write_all(line.as_bytes()).unwrap(); @@ -399,6 +398,10 @@ impl DisplayAs for StatisticsExec { } impl ExecutionPlan for StatisticsExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index e876cfe46547..059fa8fc6da7 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -285,6 +285,10 @@ impl DisplayAs for UnboundedExec { } impl ExecutionPlan for UnboundedExec { + fn name(&self) -> &'static str { + 
Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index e8ead01d2ee4..eebc946ccb68 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -140,6 +140,10 @@ impl DisplayAs for CustomExecutionPlan { } impl ExecutionPlan for CustomExecutionPlan { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -278,7 +282,7 @@ async fn optimizers_catch_all_statistics() { let expected = RecordBatch::try_new( Arc::new(Schema::new(vec![ - Field::new("COUNT(*)", DataType::Int64, false), + Field::new("count(*)", DataType::Int64, false), Field::new("MIN(test.c1)", DataType::Int32, false), Field::new("MAX(test.c1)", DataType::Int32, false), ])), diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 068383b20031..b5506b7c12f6 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -94,6 +94,10 @@ impl DisplayAs for CustomPlan { } impl ExecutionPlan for CustomPlan { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index c7be89533f1d..2d42b03bfed8 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -145,6 +145,10 @@ impl DisplayAs for StatisticsValidation { } impl ExecutionPlan for StatisticsValidation { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/tests/data/double_quote.csv b/datafusion/core/tests/data/double_quote.csv new file mode 100644 index 000000000000..95a6f0c4077a --- /dev/null +++ b/datafusion/core/tests/data/double_quote.csv @@ -0,0 +1,5 @@ +c1,c2 +id0,"""value0""" +id1,"""value1""" +id2,"""value2""" +id3,"""value3""" diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index fa364c5f2a65..c3bc2fcca2b5 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{count, sum}; +use datafusion_functions_aggregate::expr_fn::{avg, count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { @@ -170,7 +170,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("select COUNT(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") + 
.sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") .await? .explain(false, false)? .collect() @@ -211,7 +211,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { let sql_results = ctx .sql("select count(*) from t1") .await? - .select(vec![col("COUNT(*)")])? + .select(vec![col("count(*)")])? .explain(false, false)? .collect() .await?; @@ -604,7 +604,7 @@ async fn test_grouping_sets() -> Result<()> { let expected = vec![ "+-----------+-----+---------------+", - "| a | b | COUNT(test.a) |", + "| a | b | count(test.a) |", "+-----------+-----+---------------+", "| | 100 | 1 |", "| | 10 | 2 |", @@ -645,7 +645,7 @@ async fn test_grouping_sets_count() -> Result<()> { let expected = vec![ "+----+----+-----------------+", - "| c1 | c2 | COUNT(Int32(1)) |", + "| c1 | c2 | count(Int32(1)) |", "+----+----+-----------------+", "| | 5 | 14 |", "| | 4 | 23 |", @@ -1233,7 +1233,7 @@ async fn unnest_aggregate_columns() -> Result<()> { .await?; let expected = [ r#"+-------------+"#, - r#"| COUNT(tags) |"#, + r#"| count(tags) |"#, r#"+-------------+"#, r#"| 9 |"#, r#"+-------------+"#, @@ -1386,7 +1386,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let expected = vec![ "Projection: shapes.shape_id [shape_id:UInt32]", " Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]", - " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]", " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", ]; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 7085333bee03..f36f2d539845 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -30,6 +30,7 @@ use sqlparser::ast::NullTreatment; /// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan use std::sync::{Arc, OnceLock}; +mod parse_sql_expr; mod simplification; #[test] diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs new file mode 100644 index 000000000000..991579b5a350 --- /dev/null +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_schema::{DataType, Field, Schema}; +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use datafusion_common::{DFSchemaRef, Result, ToDFSchema}; +use datafusion_expr::Expr; +use datafusion_sql::unparser::Unparser; + +/// A schema like: +/// +/// a: Int32 (possibly with nulls) +/// b: Int32 +/// s: Float32 +fn schema() -> DFSchemaRef { + Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Float32, false), + ]) + .to_dfschema_ref() + .unwrap() +} + +#[tokio::test] +async fn round_trip_parse_sql_expr() -> Result<()> { + let tests = vec![ + "(a = 10)", + "((a = 10) AND (b <> 20))", + "((a = 10) OR (b <> 20))", + "(((a = 10) AND (b <> 20)) OR (c = a))", + "((a = 10) AND b IN (20, 30))", + "((a = 10) AND b NOT IN (20, 30))", + "sum(a)", + "(sum(a) + 1)", + "(MIN(a) + MAX(b))", + "(MIN(a) + (MAX(b) * sum(c)))", + "(MIN(a) + ((MAX(b) * sum(c)) / 10))", + ]; + + for test in tests { + round_trip_session_context(test)?; + round_trip_dataframe(test).await?; + } + + Ok(()) +} + +fn round_trip_session_context(sql: &str) -> Result<()> { + let ctx = SessionContext::new(); + let df_schema = schema(); + let expr = ctx.parse_sql_expr(sql, &df_schema)?; + let sql2 = unparse_sql_expr(&expr)?; + assert_eq!(sql, sql2); + + Ok(()) +} + +async fn round_trip_dataframe(sql: &str) -> Result<()> { + let ctx = SessionContext::new(); + let df = ctx + .read_csv( + &"tests/data/example.csv".to_string(), + CsvReadOptions::default(), + ) + .await?; + let expr = df.parse_sql_expr(sql)?; + let sql2 = unparse_sql_expr(&expr)?; + assert_eq!(sql, sql2); + + Ok(()) +} + +fn unparse_sql_expr(expr: &Expr) -> Result { + let unparser = Unparser::default(); + + let round_trip_sql = unparser.expr_to_sql(expr)?.to_string(); + Ok(round_trip_sql) +} diff --git a/datafusion/core/tests/fifo/mod.rs b/datafusion/core/tests/fifo/mod.rs index 2e21abffab87..1df97b1636c7 100644 --- a/datafusion/core/tests/fifo/mod.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -217,17 +217,6 @@ mod unix_test { .set_bool("datafusion.execution.coalesce_batches", false) .with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - // Tasks - let mut tasks: Vec> = vec![]; - - // Join filter - let a1_iter = 0..TEST_DATA_SIZE; - // Join key - let a2_iter = (0..TEST_DATA_SIZE).map(|x| x % 10); - let lines = a1_iter - .zip(a2_iter) - .map(|(a1, a2)| format!("{a1},{a2}\n")) - .collect::>(); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; @@ -238,22 +227,6 @@ mod unix_test { // Create a mutex for tracking if the right input source is waiting for data. 
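Reviewer sketch: the hunks around this test move the writer-thread creation to after the streaming plan has been built and `execute_stream` has been called, presumably so the FIFO producers only start once a consumer is already attached. A minimal std-only illustration of that ordering; channels stand in for the unix FIFOs and all names here are invented.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{mpsc, Arc};
use std::thread;

fn main() {
    // Tracks whether the consumer is still waiting for its first line.
    let waiting = Arc::new(AtomicBool::new(true));
    let (tx, rx) = mpsc::channel::<String>();

    // Consumer first: mirrors building the plan and calling `execute_stream`
    // before any writer thread exists.
    let consumer = thread::spawn({
        let waiting = Arc::clone(&waiting);
        move || {
            for _line in rx {
                waiting.store(false, Ordering::SeqCst);
            }
        }
    });

    // Only then are the producers spawned, mirroring `create_writing_thread`.
    let producer = thread::spawn(move || {
        for i in 0..10 {
            tx.send(format!("{i},{}\n", i % 10)).unwrap();
        }
        // Dropping `tx` here ends the consumer's loop.
    });

    producer.join().unwrap();
    consumer.join().unwrap();
    assert!(!waiting.load(Ordering::SeqCst));
}
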
let waiting = Arc::new(AtomicBool::new(true)); - // Create writing threads for the left and right FIFO files - tasks.push(create_writing_thread( - left_fifo.clone(), - "a1,a2\n".to_owned(), - lines.clone(), - waiting.clone(), - TEST_BATCH_SIZE, - )); - tasks.push(create_writing_thread( - right_fifo.clone(), - "a1,a2\n".to_owned(), - lines.clone(), - waiting.clone(), - TEST_BATCH_SIZE, - )); - // Create schema let schema = Arc::new(Schema::new(vec![ Field::new("a1", DataType::UInt32, false), @@ -264,10 +237,10 @@ mod unix_test { let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; // Set unbounded sorted files read configuration - let provider = fifo_table(schema.clone(), left_fifo, order.clone()); + let provider = fifo_table(schema.clone(), left_fifo.clone(), order.clone()); ctx.register_table("left", provider)?; - let provider = fifo_table(schema.clone(), right_fifo, order); + let provider = fifo_table(schema.clone(), right_fifo.clone(), order); ctx.register_table("right", provider)?; // Execute the query, with no matching rows. (since key is modulus 10) @@ -287,6 +260,34 @@ mod unix_test { .await?; let mut stream = df.execute_stream().await?; let mut operations = vec![]; + + // Tasks + let mut tasks: Vec> = vec![]; + + // Join filter + let a1_iter = 0..TEST_DATA_SIZE; + // Join key + let a2_iter = (0..TEST_DATA_SIZE).map(|x| x % 10); + let lines = a1_iter + .zip(a2_iter) + .map(|(a1, a2)| format!("{a1},{a2}\n")) + .collect::>(); + + // Create writing threads for the left and right FIFO files + tasks.push(create_writing_thread( + left_fifo, + "a1,a2\n".to_owned(), + lines.clone(), + waiting.clone(), + TEST_BATCH_SIZE, + )); + tasks.push(create_writing_thread( + right_fifo, + "a1,a2\n".to_owned(), + lines.clone(), + waiting.clone(), + TEST_BATCH_SIZE, + )); // Partial. while let Some(Ok(batch)) = stream.next().await { waiting.store(false, Ordering::SeqCst); diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 5fdf02079496..17dbf3a0ff28 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -215,10 +215,6 @@ async fn test_semi_join_1k() { .await } -// The test is flaky -// https://github.com/apache/datafusion/issues/10886 -// SMJ produces 1 more row in the output -#[ignore] #[tokio::test] async fn test_semi_join_1k_filtered() { JoinFuzzTestCase::new( @@ -442,18 +438,45 @@ impl JoinFuzzTestCase { if debug { println!("The debug is ON. 
Input data will be saved"); - let out_dir_name = &format!("fuzz_test_debug_batch_size_{batch_size}"); - Self::save_as_parquet(&self.input1, out_dir_name, "input1"); - Self::save_as_parquet(&self.input2, out_dir_name, "input2"); + let fuzz_debug = "fuzz_test_debug"; + std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); + std::fs::create_dir_all(fuzz_debug).unwrap(); + let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); + Self::save_partitioned_batches_as_parquet( + &self.input1, + out_dir_name, + "input1", + ); + Self::save_partitioned_batches_as_parquet( + &self.input2, + out_dir_name, + "input2", + ); if join_tests.contains(&JoinTestType::NljHj) { - Self::save_as_parquet(&nlj_collected, out_dir_name, "nlj"); - Self::save_as_parquet(&hj_collected, out_dir_name, "hj"); + Self::save_partitioned_batches_as_parquet( + &nlj_collected, + out_dir_name, + "nlj", + ); + Self::save_partitioned_batches_as_parquet( + &hj_collected, + out_dir_name, + "hj", + ); } if join_tests.contains(&JoinTestType::HjSmj) { - Self::save_as_parquet(&hj_collected, out_dir_name, "hj"); - Self::save_as_parquet(&smj_collected, out_dir_name, "smj"); + Self::save_partitioned_batches_as_parquet( + &hj_collected, + out_dir_name, + "hj", + ); + Self::save_partitioned_batches_as_parquet( + &smj_collected, + out_dir_name, + "smj", + ); } } @@ -527,11 +550,26 @@ impl JoinFuzzTestCase { /// as a parquet files preserving partitioning. /// Once the data is saved it is possible to run a custom test on top of the saved data and debug /// + /// #[tokio::test] + /// async fn test1() { + /// let left: Vec = JoinFuzzTestCase::load_partitioned_batches_from_parquet("fuzz_test_debug/batch_size_2/input1").await.unwrap(); + /// let right: Vec = JoinFuzzTestCase::load_partitioned_batches_from_parquet("fuzz_test_debug/batch_size_2/input2").await.unwrap(); + /// + /// JoinFuzzTestCase::new( + /// left, + /// right, + /// JoinType::LeftSemi, + /// Some(Box::new(col_lt_col_filter)), + /// ) + /// .run_test(&[JoinTestType::HjSmj], false) + /// .await + /// } + /// /// let ctx: SessionContext = SessionContext::new(); /// let df = ctx /// .read_parquet( /// "/tmp/input1/*.parquet", - /// ParquetReadOptions::default(), + /// datafusion::prelude::ParquetReadOptions::default(), /// ) /// .await /// .unwrap(); @@ -540,7 +578,7 @@ impl JoinFuzzTestCase { /// let df = ctx /// .read_parquet( /// "/tmp/input2/*.parquet", - /// ParquetReadOptions::default(), + /// datafusion::prelude::ParquetReadOptions::default(), /// ) /// .await /// .unwrap(); @@ -554,8 +592,11 @@ impl JoinFuzzTestCase { /// ) /// .run_test() /// .await - /// } - fn save_as_parquet(input: &[RecordBatch], output_dir: &str, out_name: &str) { + fn save_partitioned_batches_as_parquet( + input: &[RecordBatch], + output_dir: &str, + out_name: &str, + ) { let out_path = &format!("{output_dir}/{out_name}"); std::fs::remove_dir_all(out_path).unwrap_or(()); std::fs::create_dir_all(out_path).unwrap(); @@ -576,6 +617,39 @@ impl JoinFuzzTestCase { println!("The data {out_name} saved as parquet into {out_path}"); } + + /// Read parquet files preserving partitions, i.e. 
1 file -> 1 partition + /// Files can be of different sizes + /// The method can be useful to read partitions have been saved by `save_partitioned_batches_as_parquet` + /// for test debugging purposes + #[allow(dead_code)] + async fn load_partitioned_batches_from_parquet( + dir: &str, + ) -> std::io::Result> { + let ctx: SessionContext = SessionContext::new(); + let mut batches: Vec = vec![]; + + for entry in std::fs::read_dir(dir)? { + let entry = entry?; + let path = entry.path(); + + if path.is_file() { + let mut batch = ctx + .read_parquet( + path.to_str().unwrap(), + datafusion::prelude::ParquetReadOptions::default(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + batches.append(&mut batch); + } + } + Ok(batches) + } } /// Return randomly sized record batches with: diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index ddb39fce4076..47f079063d3c 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -747,7 +747,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "nanos", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -776,7 +776,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "nanos_timezoned", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -798,7 +798,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "micros", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -827,7 +827,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "micros_timezoned", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -849,7 +849,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "millis", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -878,7 +878,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "millis_timezoned", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -900,7 +900,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "seconds", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -929,7 +929,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "seconds_timezoned", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -975,7 +975,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), column_name: "nanos", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1002,7 +1002,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), column_name: "nanos_timezoned", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1022,7 +1022,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: 
UInt64Array::from(vec![1, 2, 1]), expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), column_name: "micros", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1049,7 +1049,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), column_name: "micros_timezoned", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1069,7 +1069,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), column_name: "millis", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1096,7 +1096,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), column_name: "millis_timezoned", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1116,7 +1116,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), column_name: "seconds", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1143,7 +1143,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: Some(UInt64Array::from(vec![8, 8, 4])), column_name: "seconds_timezoned", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1181,7 +1181,7 @@ async fn test_dates_32_diff_rg_sizes() { // row counts are [13, 7] expected_row_counts: Some(UInt64Array::from(vec![13, 7])), column_name: "date32", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1324,7 +1324,7 @@ async fn test_dates_64_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![2, 2]), expected_row_counts: Some(UInt64Array::from(vec![13, 7])), column_name: "date64", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1683,7 +1683,7 @@ async fn test_decimal() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "decimal_col", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1721,7 +1721,7 @@ async fn test_decimal_256() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "decimal256_col", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1801,7 +1801,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "name", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1821,7 +1821,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "service_string", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1840,7 +1840,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "service_binary", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1882,7 +1882,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "service_large_binary", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1953,7 +1953,7 @@ async fn test_boolean() { expected_null_counts: UInt64Array::from(vec![1, 0]), 
expected_row_counts: Some(UInt64Array::from(vec![5, 5])), column_name: "bool", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -2003,7 +2003,7 @@ async fn test_utf8() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5])), column_name: "utf8", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -2015,7 +2015,7 @@ async fn test_utf8() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5])), column_name: "large_utf8", - check: Check::RowGroup, + check: Check::Both, } .run(); } diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index bfc5b59f0952..7e7544bdb7c0 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -120,7 +120,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { //3. limit is not contained within a single partition //The id column is included to ensure that the parquet file is actually scanned. let results = ctx - .sql("SELECT COUNT(*) as num_rows_per_month, month, MAX(id) from t group by month order by num_rows_per_month desc") + .sql("SELECT count(*) as num_rows_per_month, month, MAX(id) from t group by month order by num_rows_per_month desc") .await? .collect() .await?; @@ -339,7 +339,7 @@ async fn csv_grouping_by_partition() -> Result<()> { let expected = [ "+------------+----------+----------------------+", - "| date | COUNT(*) | COUNT(DISTINCT t.c1) |", + "| date | count(*) | count(DISTINCT t.c1) |", "+------------+----------+----------------------+", "| 2021-10-26 | 100 | 5 |", "| 2021-10-27 | 100 | 5 |", diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 84b791a3de05..e503b74992c3 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -36,7 +36,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { *actual[0].schema(), Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", - Field::new("item", DataType::UInt32, true), + Field::new("item", DataType::UInt32, false), false ),]) ); @@ -69,12 +69,12 @@ async fn csv_query_array_agg_distinct() -> Result<()> { #[tokio::test] async fn count_partitioned() -> Result<()> { let results = - execute_with_partition("SELECT COUNT(c1), COUNT(c2) FROM test", 4).await?; + execute_with_partition("SELECT count(c1), count(c2) FROM test", 4).await?; assert_eq!(results.len(), 1); let expected = [ "+----------------+----------------+", - "| COUNT(test.c1) | COUNT(test.c2) |", + "| count(test.c1) | count(test.c2) |", "+----------------+----------------+", "| 40 | 40 |", "+----------------+----------------+", @@ -86,11 +86,11 @@ async fn count_partitioned() -> Result<()> { #[tokio::test] async fn count_aggregated() -> Result<()> { let results = - execute_with_partition("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4).await?; + execute_with_partition("SELECT c1, count(c2) FROM test GROUP BY c1", 4).await?; let expected = [ "+----+----------------+", - "| c1 | COUNT(test.c2) |", + "| c1 | count(test.c2) |", "+----+----------------+", "| 0 | 10 |", "| 1 | 10 |", @@ -105,14 +105,14 @@ async fn count_aggregated() -> Result<()> { #[tokio::test] async fn count_aggregated_cube() -> Result<()> { let results = execute_with_partition( - "SELECT c1, c2, COUNT(c3) FROM test GROUP BY CUBE (c1, c2) ORDER BY c1, c2", + "SELECT c1, c2, count(c3) FROM test GROUP BY CUBE (c1, c2) ORDER BY c1, c2", 4, ) 
.await?; let expected = vec![ "+----+----+----------------+", - "| c1 | c2 | COUNT(test.c3) |", + "| c1 | c2 | count(test.c3) |", "+----+----+----------------+", "| | | 40 |", "| | 1 | 4 |", @@ -222,15 +222,15 @@ async fn run_count_distinct_integers_aggregated_scenario( " SELECT c_group, - COUNT(c_uint64), - COUNT(DISTINCT c_int8), - COUNT(DISTINCT c_int16), - COUNT(DISTINCT c_int32), - COUNT(DISTINCT c_int64), - COUNT(DISTINCT c_uint8), - COUNT(DISTINCT c_uint16), - COUNT(DISTINCT c_uint32), - COUNT(DISTINCT c_uint64) + count(c_uint64), + count(DISTINCT c_int8), + count(DISTINCT c_int16), + count(DISTINCT c_int32), + count(DISTINCT c_int64), + count(DISTINCT c_uint8), + count(DISTINCT c_uint16), + count(DISTINCT c_uint32), + count(DISTINCT c_uint64) FROM test GROUP BY c_group ", @@ -260,7 +260,7 @@ async fn count_distinct_integers_aggregated_single_partition() -> Result<()> { let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; let expected = ["+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", + "| c_group | count(test.c_uint64) | count(DISTINCT test.c_int8) | count(DISTINCT test.c_int16) | count(DISTINCT test.c_int32) | count(DISTINCT test.c_int64) | count(DISTINCT test.c_uint8) | count(DISTINCT test.c_uint16) | count(DISTINCT test.c_uint32) | count(DISTINCT test.c_uint64) |", "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", "| a | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", "| b | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", @@ -284,7 +284,7 @@ async fn count_distinct_integers_aggregated_multiple_partitions() -> Result<()> let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; let expected = ["+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", + "| c_group | count(test.c_uint64) | count(DISTINCT test.c_int8) | count(DISTINCT test.c_int16) | count(DISTINCT test.c_int32) | count(DISTINCT test.c_int64) | count(DISTINCT test.c_uint8) | count(DISTINCT test.c_uint16) | count(DISTINCT test.c_uint32) | count(DISTINCT test.c_uint64) |", 
"+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", "| a | 5 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |", "| b | 5 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 |", @@ -301,7 +301,7 @@ async fn test_accumulator_row_accumulator() -> Result<()> { let ctx = SessionContext::new_with_config(config); register_aggregate_csv(&ctx).await?; - let sql = "SELECT c1, c2, MIN(c13) as min1, MIN(c9) as min2, MAX(c13) as max1, MAX(c9) as max2, AVG(c9) as avg1, MIN(c13) as min3, COUNT(C9) as cnt1, 0.5*SUM(c9-c8) as sum1 + let sql = "SELECT c1, c2, MIN(c13) as min1, MIN(c9) as min2, MAX(c13) as max1, MAX(c9) as max2, AVG(c9) as avg1, MIN(c13) as min3, count(C9) as cnt1, 0.5*SUM(c9-c8) as sum1 FROM aggregate_test_100 GROUP BY c1, c2 ORDER BY c1, c2 diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 3e5a0681589c..502590f9e2e2 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -81,7 +81,7 @@ async fn explain_analyze_baseline_metrics() { ); assert_metrics!( &formatted, - "ProjectionExec: expr=[COUNT(*)", + "ProjectionExec: expr=[count(*)", "metrics=[output_rows=1, elapsed_compute=" ); assert_metrics!( @@ -700,7 +700,7 @@ async fn csv_explain_analyze() { // Only test basic plumbing and try to avoid having to change too // many things. explain_analyze_baseline_metrics covers the values // in greater depth - let needle = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(*)], metrics=[output_rows=5"; + let needle = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[count(*)], metrics=[output_rows=5"; assert_contains!(&formatted, needle); let verbose_needle = "Output Rows"; @@ -793,7 +793,7 @@ async fn explain_logical_plan_only() { let expected = vec![ vec![ "logical_plan", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\ + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\ \n SubqueryAlias: t\ \n Projection: \ \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))" @@ -812,7 +812,7 @@ async fn explain_physical_plan_only() { let expected = vec![vec![ "physical_plan", - "ProjectionExec: expr=[2 as COUNT(*)]\ + "ProjectionExec: expr=[2 as count(*)]\ \n PlaceholderRowExec\ \n", ]]; diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index f2710e659240..d9ef462df26c 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -246,7 +246,7 @@ async fn test_parameter_invalid_types() -> Result<()> { .await; assert_eq!( results.unwrap_err().strip_backtrace(), - "Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })" + "type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32" ); Ok(()) } diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 44fb0afff319..b99bc2680044 100644 --- 
a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1044,7 +1044,10 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { for table in &tables { ctx.register_table( table.name.as_str(), - Arc::new(MemTable::try_new(Arc::new(table.schema.clone()), vec![])?), + Arc::new(MemTable::try_new( + Arc::new(table.schema.clone()), + vec![vec![]], + )?), )?; } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 66cdeb575a15..d591c662d877 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -48,7 +48,8 @@ use datafusion_expr::{ create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::AvgAccumulator; +use datafusion_functions_aggregate::average::AvgAccumulator; + /// Test to show the contents of the setup #[tokio::test] async fn test_setup() { diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index ebf907c5e2c0..38ed142cf922 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -92,8 +92,12 @@ use datafusion::{ }; use async_trait::async_trait; -use datafusion_common::tree_node::Transformed; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::ScalarValue; +use datafusion_expr::Projection; use datafusion_optimizer::optimizer::ApplyOrder; +use datafusion_optimizer::AnalyzerRule; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -132,11 +136,13 @@ async fn setup_table_without_schemas(mut ctx: SessionContext) -> Result Result<()> { @@ -164,6 +170,34 @@ async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Re Ok(()) } +// Run the query using the specified execution context and compare it +// to the known result +async fn run_and_compare_query_with_analyzer_rule( + mut ctx: SessionContext, + description: &str, +) -> Result<()> { + let expected = vec![ + "+------------+--------------------------+", + "| UInt64(42) | arrow_typeof(UInt64(42)) |", + "+------------+--------------------------+", + "| 42 | UInt64 |", + "+------------+--------------------------+", + ]; + + let s = exec_sql(&mut ctx, QUERY2).await?; + let actual = s.lines().collect::>(); + + assert_eq!( + expected, + actual, + "output mismatch for {}. 
Expectedn\n{}Actual:\n{}", + description, + expected.join("\n"), + s + ); + Ok(()) +} + // Run the query using the specified execution context and compare it // to the known result async fn run_and_compare_query_with_auto_schemas( @@ -208,6 +242,14 @@ async fn normal_query() -> Result<()> { run_and_compare_query(ctx, "Default context").await } +#[tokio::test] +// Run the query using default planners, optimizer and custom analyzer rule +async fn normal_query_with_analyzer() -> Result<()> { + let ctx = SessionContext::new(); + ctx.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); + run_and_compare_query_with_analyzer_rule(ctx, "MyAnalyzerRule").await +} + #[tokio::test] // Run the query using topk optimization async fn topk_query() -> Result<()> { @@ -248,9 +290,10 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime) + let mut state = SessionState::new_with_config_rt(config, runtime) .with_query_planner(Arc::new(TopKQueryPlanner {})) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); + state.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); SessionContext::new_with_state(state) } @@ -281,15 +324,6 @@ impl QueryPlanner for TopKQueryPlanner { struct TopKOptimizerRule {} impl OptimizerRule for TopKOptimizerRule { - // Example rewrite pass to insert a user defined LogicalPlanNode - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - unreachable!() - } - fn name(&self) -> &str { "topk" } @@ -302,6 +336,7 @@ impl OptimizerRule for TopKOptimizerRule { true } + // Example rewrite pass to insert a user defined LogicalPlanNode fn rewrite( &self, plan: LogicalPlan, @@ -473,6 +508,10 @@ impl DisplayAs for TopKExec { #[async_trait] impl ExecutionPlan for TopKExec { + fn name(&self) -> &'static str { + Self::static_name() + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -633,3 +672,52 @@ impl RecordBatchStream for TopKReader { self.input.schema() } } + +struct MyAnalyzerRule {} + +impl AnalyzerRule for MyAnalyzerRule { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + Self::analyze_plan(plan) + } + + fn name(&self) -> &str { + "my_analyzer_rule" + } +} + +impl MyAnalyzerRule { + fn analyze_plan(plan: LogicalPlan) -> Result { + plan.transform(|plan| { + Ok(match plan { + LogicalPlan::Projection(projection) => { + let expr = Self::analyze_expr(projection.expr.clone())?; + Transformed::yes(LogicalPlan::Projection(Projection::try_new( + expr, + projection.input, + )?)) + } + _ => Transformed::no(plan), + }) + }) + .data() + } + + fn analyze_expr(expr: Vec) -> Result> { + expr.into_iter() + .map(|e| { + e.transform(|e| { + Ok(match e { + Expr::Literal(ScalarValue::Int64(i)) => { + // transform to UInt64 + Transformed::yes(Expr::Literal(ScalarValue::UInt64( + i.map(|i| i as u64), + ))) + } + _ => Transformed::no(e), + }) + }) + .data() + }) + .collect() + } +} diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a81fc9159e52..5e3c44c039ab 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -51,7 +51,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let actual = 
plan_and_collect(&ctx, sql).await.unwrap(); let expected = [ "+------------------------------------------+", - "| AVG(custom_sqrt(aggregate_test_100.c11)) |", + "| avg(custom_sqrt(aggregate_test_100.c11)) |", "+------------------------------------------+", "| 0.6584408483418835 |", "+------------------------------------------+", @@ -69,7 +69,7 @@ async fn csv_query_avg_sqrt() -> Result<()> { let actual = plan_and_collect(&ctx, sql).await.unwrap(); let expected = [ "+------------------------------------------+", - "| AVG(custom_sqrt(aggregate_test_100.c12)) |", + "| avg(custom_sqrt(aggregate_test_100.c12)) |", "+------------------------------------------+", "| 0.6706002946036459 |", "+------------------------------------------+", diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index d3ddbed20d59..1e8d30cab638 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -90,6 +90,21 @@ async fn test_simple_read_csv_udtf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_deregister_udtf() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_udtf("read_csv", Arc::new(SimpleCsvTableFunc {})); + + assert!(ctx.state().table_functions().contains_key("read_csv")); + + ctx.deregister_udtf("read_csv"); + + assert!(!ctx.state().table_functions().contains_key("read_csv")); + + Ok(()) +} + struct SimpleCsvTable { schema: SchemaRef, exprs: Vec, diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 967ccc0b0866..b17e4294a1ef 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -37,14 +37,10 @@ pub enum AggregateFunction { Min, /// Maximum Max, - /// Average - Avg, /// Aggregation into an array ArrayAgg, /// N'th value in a group according to some ordering NthValue, - /// Correlation - Correlation, /// Grouping Grouping, } @@ -55,10 +51,8 @@ impl AggregateFunction { match self { Min => "MIN", Max => "MAX", - Avg => "AVG", ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", - Correlation => "CORR", Grouping => "GROUPING", } } @@ -75,14 +69,10 @@ impl FromStr for AggregateFunction { fn from_str(name: &str) -> Result { Ok(match name { // general - "avg" => AggregateFunction::Avg, "max" => AggregateFunction::Max, - "mean" => AggregateFunction::Avg, "min" => AggregateFunction::Min, "array_agg" => AggregateFunction::ArrayAgg, "nth_value" => AggregateFunction::NthValue, - // statistical - "corr" => AggregateFunction::Correlation, // other "grouping" => AggregateFunction::Grouping, _ => { @@ -96,7 +86,11 @@ impl AggregateFunction { /// Returns the datatype of the aggregate function given its argument types /// /// This is used to get the returned data type for aggregate expr. - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + pub fn return_type( + &self, + input_expr_types: &[DataType], + input_expr_nullable: &[bool], + ) -> Result { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. @@ -120,32 +114,26 @@ impl AggregateFunction { // The coerced_data_types is same with input_types. 
Ok(coerced_data_types[0].clone()) } - AggregateFunction::Correlation => { - correlation_return_type(&coerced_data_types[0]) - } - AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", coerced_data_types[0].clone(), - true, + input_expr_nullable[0], )))), AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), } } -} -/// Returns the internal sum datatype of the avg aggregate function. -pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - let fun = AggregateFunction::Avg; - let coerced_data_types = crate::type_coercion::aggregates::coerce_types( - &fun, - input_expr_types, - &fun.signature(), - )?; - avg_sum_type(&coerced_data_types[0]) + /// Returns if the return type of the aggregate function is nullable given its argument + /// nullability + pub fn nullable(&self) -> Result<bool> { + match self { + AggregateFunction::Max | AggregateFunction::Min => Ok(true), + AggregateFunction::ArrayAgg => Ok(false), + AggregateFunction::Grouping => Ok(true), + AggregateFunction::NthValue => Ok(true), + } + } } impl AggregateFunction { @@ -168,13 +156,7 @@ impl AggregateFunction { .collect::<Vec<_>>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::Avg => { - Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) - } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::Correlation => { - Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) - } } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9ba866a4c919..846b627b2242 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -19,7 +19,8 @@ use std::collections::HashSet; use std::fmt::{self, Display, Formatter, Write}; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; +use std::mem; use std::str::FromStr; use std::sync::Arc; @@ -33,7 +34,9 @@ use crate::{ use crate::{window_frame, Volatility}; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ internal_err, plan_err, Column, DFSchema, Result, ScalarValue, TableReference, }; @@ -305,7 +308,7 @@ pub enum Expr { /// /// This expr has to be resolved to a list of columns before translating logical /// plan into physical plan. - Wildcard { qualifier: Option<String> }, + Wildcard { qualifier: Option<TableReference> }, /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list GroupingSet(GroupingSet), @@ -706,10 +709,14 @@ pub enum WindowFunctionDefinition { impl WindowFunctionDefinition { /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> { + pub fn return_type( + &self, + input_expr_types: &[DataType], + input_expr_nullable: &[bool], + ) -> Result<DataType> { match self { WindowFunctionDefinition::AggregateFunction(fun) => { - fun.return_type(input_expr_types) + fun.return_type(input_expr_types, input_expr_nullable) } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { fun.return_type(input_expr_types) @@ -1326,6 +1333,7 @@ impl Expr { } /// Return all referenced columns of this expression.
+ #[deprecated(since = "40.0.0", note = "use Expr::column_refs instead")] pub fn to_columns(&self) -> Result> { let mut using_columns = HashSet::new(); expr_to_columns(self, &mut using_columns)?; @@ -1333,6 +1341,46 @@ impl Expr { Ok(using_columns) } + /// Return all references to columns in this expression. + /// + /// # Example + /// ``` + /// # use std::collections::HashSet; + /// # use datafusion_common::Column; + /// # use datafusion_expr::col; + /// // For an expression `a + (b * a)` + /// let expr = col("a") + (col("b") * col("a")); + /// let refs = expr.column_refs(); + /// // refs contains "a" and "b" + /// assert_eq!(refs.len(), 2); + /// assert!(refs.contains(&Column::new_unqualified("a"))); + /// assert!(refs.contains(&Column::new_unqualified("b"))); + /// ``` + pub fn column_refs(&self) -> HashSet<&Column> { + let mut using_columns = HashSet::new(); + self.add_column_refs(&mut using_columns); + using_columns + } + + /// Adds references to all columns in this expression to the set + /// + /// See [`Self::column_refs`] for details + pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { + self.apply(|expr| { + if let Expr::Column(col) = expr { + set.insert(col); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallable"); + } + + /// Returns true if there are any column references in this Expr + pub fn any_column_refs(&self) -> bool { + self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) + .unwrap() + } + /// Return true when the expression contains out reference(correlated) expressions. pub fn contains_outer(&self) -> bool { self.exists(|expr| Ok(matches!(expr, Expr::OuterReferenceColumn { .. }))) @@ -1419,6 +1467,176 @@ impl Expr { | Expr::Placeholder(..) => false, } } + + /// Hashes the direct content of an `Expr` without recursing into its children. + /// + /// This method is useful to incrementally compute hashes, such as in + /// `CommonSubexprEliminate` which builds a deep hash of a node and its descendants + /// during the bottom-up phase of the first traversal and so avoid computing the hash + /// of the node and then the hash of its descendants separately. + /// + /// If a node doesn't have any children then this method is similar to `.hash()`, but + /// not necessarily returns the same value. + /// + /// As it is pretty easy to forget changing this method when `Expr` changes the + /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes + /// compile time. 
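// ---------------------------------------------------------------------------
// Editor's illustration (not part of the upstream patch): a minimal sketch of
// how `Expr::hash_node` can be combined with a `TreeNode` traversal to build a
// deep hash of an expression, in the spirit of the `CommonSubexprEliminate`
// use case described in the doc comment above. The helper name `deep_hash` is
// hypothetical and only for illustration.
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_expr::{col, lit, Expr};

fn deep_hash(expr: &Expr) -> u64 {
    let mut hasher = DefaultHasher::new();
    // Visit every node once; `hash_node` hashes only the node's own content,
    // so the traversal order supplies the shape of the whole tree.
    expr.apply(|e| {
        e.hash_node(&mut hasher);
        Ok(TreeNodeRecursion::Continue)
    })
    .expect("infallible traversal");
    hasher.finish()
}

fn main() {
    // Structurally identical expressions produce the same deep hash.
    let e1 = col("a") + lit(1);
    let e2 = col("a") + lit(1);
    assert_eq!(deep_hash(&e1), deep_hash(&e2));
}
// ---------------------------------------------------------------------------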
+ pub fn hash_node(&self, hasher: &mut H) { + mem::discriminant(self).hash(hasher); + match self { + Expr::Alias(Alias { + expr: _expr, + relation, + name, + }) => { + relation.hash(hasher); + name.hash(hasher); + } + Expr::Column(column) => { + column.hash(hasher); + } + Expr::ScalarVariable(data_type, name) => { + data_type.hash(hasher); + name.hash(hasher); + } + Expr::Literal(scalar_value) => { + scalar_value.hash(hasher); + } + Expr::BinaryExpr(BinaryExpr { + left: _left, + op, + right: _right, + }) => { + op.hash(hasher); + } + Expr::Like(Like { + negated, + expr: _expr, + pattern: _pattern, + escape_char, + case_insensitive, + }) + | Expr::SimilarTo(Like { + negated, + expr: _expr, + pattern: _pattern, + escape_char, + case_insensitive, + }) => { + negated.hash(hasher); + escape_char.hash(hasher); + case_insensitive.hash(hasher); + } + Expr::Not(_expr) + | Expr::IsNotNull(_expr) + | Expr::IsNull(_expr) + | Expr::IsTrue(_expr) + | Expr::IsFalse(_expr) + | Expr::IsUnknown(_expr) + | Expr::IsNotTrue(_expr) + | Expr::IsNotFalse(_expr) + | Expr::IsNotUnknown(_expr) + | Expr::Negative(_expr) => {} + Expr::Between(Between { + expr: _expr, + negated, + low: _low, + high: _high, + }) => { + negated.hash(hasher); + } + Expr::Case(Case { + expr: _expr, + when_then_expr: _when_then_expr, + else_expr: _else_expr, + }) => {} + Expr::Cast(Cast { + expr: _expr, + data_type, + }) + | Expr::TryCast(TryCast { + expr: _expr, + data_type, + }) => { + data_type.hash(hasher); + } + Expr::Sort(Sort { + expr: _expr, + asc, + nulls_first, + }) => { + asc.hash(hasher); + nulls_first.hash(hasher); + } + Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { + func.hash(hasher); + } + Expr::AggregateFunction(AggregateFunction { + func_def, + args: _args, + distinct, + filter: _filter, + order_by: _order_by, + null_treatment, + }) => { + func_def.hash(hasher); + distinct.hash(hasher); + null_treatment.hash(hasher); + } + Expr::WindowFunction(WindowFunction { + fun, + args: _args, + partition_by: _partition_by, + order_by: _order_by, + window_frame, + null_treatment, + }) => { + fun.hash(hasher); + window_frame.hash(hasher); + null_treatment.hash(hasher); + } + Expr::InList(InList { + expr: _expr, + list: _list, + negated, + }) => { + negated.hash(hasher); + } + Expr::Exists(Exists { subquery, negated }) => { + subquery.hash(hasher); + negated.hash(hasher); + } + Expr::InSubquery(InSubquery { + expr: _expr, + subquery, + negated, + }) => { + subquery.hash(hasher); + negated.hash(hasher); + } + Expr::ScalarSubquery(subquery) => { + subquery.hash(hasher); + } + Expr::Wildcard { qualifier } => { + qualifier.hash(hasher); + } + Expr::GroupingSet(grouping_set) => { + mem::discriminant(grouping_set).hash(hasher); + match grouping_set { + GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => {} + GroupingSet::GroupingSets(_exprs) => {} + } + } + Expr::Placeholder(place_holder) => { + place_holder.hash(hasher); + } + Expr::OuterReferenceColumn(data_type, column) => { + data_type.hash(hasher); + column.hash(hasher); + } + Expr::Unnest(Unnest { expr: _expr }) => {} + }; + } } // modifies expr if it is a placeholder with datatype of right @@ -2038,7 +2256,7 @@ mod test { // single column { let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); - let columns = expr.to_columns()?; + let columns = expr.column_refs(); assert_eq!(1, columns.len()); assert!(columns.contains(&Column::from_name("a"))); } @@ -2046,7 +2264,7 @@ mod test { // multiple columns { let expr = col("a") + col("b") + lit(1); 
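// ---------------------------------------------------------------------------
// Editor's illustration (not part of the upstream patch): a small, hypothetical
// sketch of the column-reference helpers added above. `column_refs` replaces
// the now-deprecated `to_columns`, `add_column_refs` accumulates references
// from several expressions into one set, and `any_column_refs` is a cheap
// existence check.
use std::collections::HashSet;

use datafusion_common::Column;
use datafusion_expr::{col, lit, Expr};

fn column_ref_helpers_demo() {
    let exprs: Vec<Expr> = vec![col("a") + col("b"), col("b") * lit(2)];

    // Collect the distinct columns referenced by all expressions.
    let mut all_refs: HashSet<&Column> = HashSet::new();
    for e in &exprs {
        e.add_column_refs(&mut all_refs);
    }
    assert_eq!(all_refs.len(), 2); // "a" and "b"

    // A literal-only expression references no columns at all.
    assert!(!lit(1).any_column_refs());
}
// ---------------------------------------------------------------------------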
- let columns = expr.to_columns()?; + let columns = expr.column_refs(); assert_eq!(2, columns.len()); assert!(columns.contains(&Column::from_name("a"))); assert!(columns.contains(&Column::from_name("b"))); @@ -2138,10 +2356,10 @@ mod test { #[test] fn test_first_value_return_type() -> Result<()> { let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::UInt64])?; + let observed = fun.return_type(&[DataType::UInt64], &[true])?; assert_eq!(DataType::UInt64, observed); Ok(()) @@ -2150,10 +2368,10 @@ mod test { #[test] fn test_last_value_return_type() -> Result<()> { let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &[true])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2162,10 +2380,10 @@ mod test { #[test] fn test_lead_return_type() -> Result<()> { let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &[true])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2174,10 +2392,10 @@ mod test { #[test] fn test_lag_return_type() -> Result<()> { let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &[true])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2186,10 +2404,12 @@ mod test { #[test] fn test_nth_value_return_type() -> Result<()> { let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + let observed = + fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; + let observed = + fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2198,7 +2418,7 @@ mod test { #[test] fn test_percent_rank_return_type() -> Result<()> { let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; + let observed = fun.return_type(&[], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2207,7 +2427,7 @@ mod test { #[test] fn test_cume_dist_return_type() -> Result<()> { let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; + let observed = fun.return_type(&[], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2216,7 +2436,7 @@ mod test { #[test] fn test_ntile_return_type() -> Result<()> { let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16])?; + let observed = fun.return_type(&[DataType::Int16], &[true])?; assert_eq!(DataType::UInt64, observed); Ok(()) @@ -2238,7 +2458,6 @@ mod test { "nth_value", "min", "max", - "avg", ]; 
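// ---------------------------------------------------------------------------
// Editor's illustration (not part of the upstream patch): a hedged sketch of
// the new two-argument `return_type`, which threads input nullability through,
// so ARRAY_AGG can report a list whose "item" field matches the nullability of
// its argument. This assumes ArrayAgg coercion passes the input type through
// unchanged, as the built-in signature does at the time of this change.
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::aggregate_function::AggregateFunction;

fn array_agg_item_nullability_demo() -> Result<()> {
    let fun = AggregateFunction::ArrayAgg;
    // A non-nullable UInt32 input yields a List whose item field is non-nullable.
    let dt = fun.return_type(&[DataType::UInt32], &[false])?;
    assert!(matches!(dt, DataType::List(field) if !field.is_nullable()));
    Ok(())
}
// ---------------------------------------------------------------------------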
for name in names { let fun = find_df_window_func(name).unwrap(); @@ -2267,12 +2486,6 @@ mod test { aggregate_function::AggregateFunction::Min )) ); - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Avg - )) - ); assert_eq!( find_df_window_func("cume_dist"), Some(WindowFunctionDefinition::BuiltInWindowFunction( diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a87412ee6356..8b0213fd52fd 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -31,7 +31,9 @@ use crate::{ Signature, Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; -use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; +use arrow::compute::kernels::cast_utils::{ + parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, +}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{Column, Result, ScalarValue}; use std::any::Any; @@ -181,18 +183,6 @@ pub fn array_agg(expr: Expr) -> Expr { )) } -/// Create an expression to represent the avg() aggregate function -pub fn avg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Avg, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( @@ -671,6 +661,16 @@ impl WindowUDFImpl for SimpleWindowUDF { } } +pub fn interval_year_month_lit(value: &str) -> Expr { + let interval = parse_interval_year_month(value).ok(); + Expr::Literal(ScalarValue::IntervalYearMonth(interval)) +} + +pub fn interval_datetime_lit(value: &str) -> Expr { + let interval = parse_interval_day_time(value).ok(); + Expr::Literal(ScalarValue::IntervalDayTime(interval)) +} + pub fn interval_month_day_nano_lit(value: &str) -> Expr { let interval = parse_interval_month_day_nano(value).ok(); Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index eb38fee7cad0..4b56ca3d1c2e 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -156,8 +156,8 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast, - LogicalPlanBuilder, + cast, col, lit, logical_plan::builder::LogicalTableSource, min, + test::function_stub::avg, try_cast, LogicalPlanBuilder, }; use super::*; @@ -246,9 +246,9 @@ mod test { expected: sort(col("c1") + col("MIN(t.c2)")), }, TestCase { - desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#, + desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#, input: sort(avg(col("c3"))), - expected: sort(col("AVG(t.c3)").alias("average")), + expected: sort(col("avg(t.c3)").alias("average")), }, ]; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 986f85adebaa..d5a04ad4ae1f 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -160,6 +160,10 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; + let nullability = args + .iter() + .map(|e| e.nullable(schema)) + .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udf) => { let 
new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| { @@ -173,10 +177,10 @@ impl ExprSchemable for Expr { ) ) })?; - Ok(fun.return_type(&new_types)?) + Ok(fun.return_type(&new_types, &nullability)?) } _ => { - fun.return_type(&data_types) + fun.return_type(&data_types, &nullability) } } } @@ -185,9 +189,13 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; + let nullability = args + .iter() + .map(|e| e.nullable(schema)) + .collect::>>()?; match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - fun.return_type(&data_types) + fun.return_type(&data_types, &nullability) } AggregateFunctionDefinition::UDF(fun) => { let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| { @@ -314,11 +322,17 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), + Expr::AggregateFunction(AggregateFunction { func_def, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => fun.nullable(), + // TODO: UDF should be able to customize nullability + AggregateFunctionDefinition::UDF(_) => Ok(true), + } + } Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } - | Expr::AggregateFunction { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) @@ -343,9 +357,12 @@ impl ExprSchemable for Expr { | Expr::SimilarTo(Like { expr, pattern, .. }) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) } - Expr::Wildcard { .. } => internal_err!( - "Wildcard expressions are not valid in a logical query plan" - ), + Expr::Wildcard { qualifier } => match qualifier { + Some(_) => internal_err!( + "QualifiedWildcard expressions are not valid in a logical query plan" + ), + None => Ok(false), + }, Expr::GroupingSet(_) => { // grouping sets do not really have the concept of nullable and do not appear // in projections diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 169436145aae..73ab51494de6 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -23,6 +23,14 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use std::sync::Arc; +#[derive(Debug, Clone, Copy)] +pub enum Hint { + /// Indicates the argument needs to be padded if it is scalar + Pad, + /// Indicates the argument can be converted to an array of length 1 + AcceptsSingular, +} + /// Scalar function /// /// The Fn param is the wrapped function but be aware that the function will diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2f1ece32ab15..f87151efd88b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -39,7 +39,8 @@ use crate::logical_plan::{ use crate::type_coercion::binary::{comparison_coercion, values_coercion}; use crate::utils::{ can_hash, columnize_expr, compare_sort_expr, expand_qualified_wildcard, - expand_wildcard, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, + expand_wildcard, expr_to_columns, find_valid_equijoin_key_pair, + group_window_expr_by_sort_keys, }; use crate::{ and, binary_expr, logical_plan::tree_node::unwrap_arc, DmlStatement, Expr, @@ -48,8 +49,8 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; -use datafusion_common::config::FormatOptions; use datafusion_common::display::ToStringifiedPlan; +use 
datafusion_common::file_options::file_type::FileType; use datafusion_common::{ get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -271,14 +272,14 @@ impl LogicalPlanBuilder { pub fn copy_to( input: LogicalPlan, output_url: String, - format_options: FormatOptions, + file_type: Arc, options: HashMap, partition_by: Vec, ) -> Result { Ok(Self::from(LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url, - format_options, + file_type, options, partition_by, }))) @@ -534,11 +535,11 @@ impl LogicalPlanBuilder { .clone() .into_iter() .try_for_each::<_, Result<()>>(|expr| { - let columns = expr.to_columns()?; + let columns = expr.column_refs(); columns.into_iter().for_each(|c| { - if schema.field_from_column(&c).is_err() { - missing_cols.push(c); + if !schema.has_column(c) { + missing_cols.push(c.clone()); } }); @@ -1070,14 +1071,16 @@ impl LogicalPlanBuilder { let left_key = l.into(); let right_key = r.into(); - let left_using_columns = left_key.to_columns()?; + let mut left_using_columns = HashSet::new(); + expr_to_columns(&left_key, &mut left_using_columns)?; let normalized_left_key = normalize_col_with_schemas_and_ambiguity_check( left_key, &[&[self.plan.schema(), right.schema()]], &[left_using_columns], )?; - let right_using_columns = right_key.to_columns()?; + let mut right_using_columns = HashSet::new(); + expr_to_columns(&right_key, &mut right_using_columns)?; let normalized_right_key = normalize_col_with_schemas_and_ambiguity_check( right_key, &[&[self.plan.schema(), right.schema()]], diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 707cff8ab5f1..81fd03555abb 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -425,7 +425,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, partition_by: _, options, }) => { @@ -437,7 +437,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { json!({ "Node Type": "CopyTo", "Output URL": output_url, - "Format Options": format!("{}", format_options), + "File Type": format!("{}", file_type.get_ext()), "Options": op_str }) } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 13f3759ab8c0..c9eef9bd34cc 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -21,7 +21,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::config::FormatOptions; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; use crate::LogicalPlan; @@ -35,8 +35,8 @@ pub struct CopyTo { pub output_url: String, /// Determines which, if any, columns should be used for hive-style partitioned writes pub partition_by: Vec, - /// File format options. 
- pub format_options: FormatOptions, + /// File type trait + pub file_type: Arc, /// SQL Options that can affect the formats pub options: HashMap, } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6e7efaf39e3e..31f830a6a13d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -857,13 +857,13 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, options, partition_by, }) => Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(inputs.swap_remove(0)), output_url: output_url.clone(), - format_options: format_options.clone(), + file_type: file_type.clone(), options: options.clone(), partition_by: partition_by.clone(), })), @@ -1729,7 +1729,7 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, options, .. }) => { @@ -1739,7 +1739,7 @@ impl LogicalPlan { .collect::>() .join(", "); - write!(f, "CopyTo: format={format_options} output_url={output_url} options: ({op_str})") + write!(f, "CopyTo: format={} output_url={output_url} options: ({op_str})", file_type.get_ext()) } LogicalPlan::Ddl(ddl) => { write!(f, "{}", ddl.display()) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 86c0cffd80a1..a47906f20322 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -256,14 +256,14 @@ impl TreeNode for LogicalPlan { input, output_url, partition_by, - format_options, + file_type, options, }) => rewrite_arc(input, f)?.update_data(|input| { LogicalPlan::Copy(CopyTo { input, output_url, partition_by, - format_options, + file_type, options, }) }), diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index a10312e23446..742511822a0f 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -25,7 +25,7 @@ use std::ops; use std::ops::Not; /// Operators applied to expressions -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Operator { /// Expressions are equal Eq, diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index ac98ee9747cc..14a6522ebe91 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -21,6 +21,14 @@ use std::any::Any; +use arrow::datatypes::{ + DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; + +use datafusion_common::{exec_err, not_impl_err, Result}; + +use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS}; +use crate::Volatility::Immutable; use crate::{ expr::AggregateFunction, function::{AccumulatorArgs, StateFieldsArgs}, @@ -28,10 +36,6 @@ use crate::{ Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature, Volatility, }; -use arrow::datatypes::{ - DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, -}; -use datafusion_common::{exec_err, not_impl_err, Result}; macro_rules! 
create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { @@ -82,6 +86,19 @@ pub fn count(expr: Expr) -> Expr { )) } +create_func!(Avg, avg_udaf); + +pub fn avg(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + avg_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + /// Stub `sum` used for optimizer testing #[derive(Debug)] pub struct Sum { @@ -273,3 +290,58 @@ impl AggregateUDFImpl for Count { ReversedUDAF::Identical } } + +/// Testing stub implementation of avg aggregate +#[derive(Debug)] +pub struct Avg { + signature: Signature, + aliases: Vec, +} + +impl Avg { + pub fn new() -> Self { + Self { + aliases: vec![String::from("mean")], + signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable), + } + } +} + +impl Default for Avg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Avg { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "avg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + coerce_avg_type(self.name(), arg_types) + } +} diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 428fc99070d2..36a789d5b0ee 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -17,14 +17,15 @@ use std::ops::Deref; -use crate::{AggregateFunction, Signature, TypeSignature}; - use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; + use datafusion_common::{internal_err, plan_err, Result}; +use crate::{AggregateFunction, Signature, TypeSignature}; + pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; pub static SIGNED_INTEGERS: &[DataType] = &[ @@ -90,7 +91,6 @@ pub fn coerce_types( input_types: &[DataType], signature: &Signature, ) -> Result> { - use DataType::*; // Validate input_types matches (at least one of) the func signature. 
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; @@ -101,36 +101,6 @@ pub fn coerce_types( // unpack the dictionary to get the value get_min_max_result_type(input_types) } - AggregateFunction::Avg => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval - let v = match &input_types[0] { - Decimal128(p, s) => Decimal128(*p, *s), - Decimal256(p, s) => Decimal256(*p, *s), - d if d.is_numeric() => Float64, - Dictionary(_, v) => { - return coerce_types(agg_fun, &[v.as_ref().clone()], signature) - } - _ => { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ) - } - }; - Ok(vec![v]) - } - AggregateFunction::Correlation => { - if !is_correlation_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), } @@ -262,7 +232,7 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result { } /// function return type of an average -pub fn avg_return_type(arg_type: &DataType) -> Result { +pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result { match arg_type { DataType::Decimal128(precision, scale) => { // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). @@ -280,9 +250,9 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { } arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), DataType::Dictionary(_, dict_value_type) => { - avg_return_type(dict_value_type.as_ref()) + avg_return_type(func_name, dict_value_type.as_ref()) } - other => plan_err!("AVG does not support {other:?}"), + other => plan_err!("{func_name} does not support {other:?}"), } } @@ -358,10 +328,29 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool { arg_type.is_integer() } +pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result> { + // Supported types smallint, int, bigint, real, double precision, decimal, or interval + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + fn coerced_type(func_name: &str, data_type: &DataType) -> Result { + return match &data_type { + DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), + DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), + d if d.is_numeric() => Ok(DataType::Float64), + DataType::Dictionary(_, v) => return coerced_type(func_name, v.as_ref()), + _ => { + return plan_err!( + "The function {:?} does not support inputs of type {:?}.", + func_name, + data_type + ) + } + }; + } + Ok(vec![coerced_type(func_name, &arg_types[0])?]) +} #[cfg(test)] mod tests { use super::*; - #[test] fn test_aggregate_coerce_types() { // test input args with error number input types @@ -371,16 +360,6 @@ mod tests { let result = coerce_types(&fun, &input_types, &signature); assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace()); - let fun = AggregateFunction::Avg; - // test input args is invalid data type for avg - let input_types = vec![DataType::Utf8]; - let signature = fun.signature(); - let result = coerce_types(&fun, &input_types, &signature); - assert_eq!( - "Error during planning: The function Avg does not support inputs of type Utf8.", - 
result.unwrap_err().strip_backtrace() - ); - // test count, array_agg, approx_distinct, min, max. // the coerced types is same with input types let funs = vec![ @@ -401,30 +380,6 @@ mod tests { assert_eq!(*input_type, result.unwrap()); } } - - // test avg - let fun = AggregateFunction::Avg; - let signature = fun.signature(); - let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap(); - assert_eq!(r[0], DataType::Float64); - let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap(); - assert_eq!(r[0], DataType::Float64); - let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap(); - assert_eq!(r[0], DataType::Decimal128(20, 3)); - let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap(); - assert_eq!(r[0], DataType::Decimal256(20, 3)); - } - - #[test] - fn test_avg_return_data_type() -> Result<()> { - let data_type = DataType::Decimal128(10, 5); - let result_type = avg_return_type(&data_type)?; - assert_eq!(DataType::Decimal128(14, 9), result_type); - - let data_type = DataType::Decimal128(36, 10); - let result_type = avg_return_type(&data_type)?; - assert_eq!(DataType::Decimal128(38, 14), result_type); - Ok(()) } #[test] diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index ea9d0c2fe72e..5645a2a4dede 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -889,21 +889,18 @@ fn dictionary_coercion( /// 2. Data type of the other side should be able to cast to string type fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; - string_coercion(lhs_type, rhs_type) - .or_else(|| list_coercion(lhs_type, rhs_type)) - .or(match (lhs_type, rhs_type) { - (Utf8, from_type) | (from_type, Utf8) => { - string_concat_internal_coercion(from_type, &Utf8) - } - (LargeUtf8, from_type) | (from_type, LargeUtf8) => { - string_concat_internal_coercion(from_type, &LargeUtf8) - } - _ => None, - }) + string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { + (Utf8, from_type) | (from_type, Utf8) => { + string_concat_internal_coercion(from_type, &Utf8) + } + (LargeUtf8, from_type) | (from_type, LargeUtf8) => { + string_concat_internal_coercion(from_type, &LargeUtf8) + } + _ => None, + }) } fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - // TODO: cast between array elements (#6558) if lhs_type.equals_datatype(rhs_type) { Some(lhs_type.to_owned()) } else { @@ -952,10 +949,7 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - // TODO: cast between array elements (#6558) (List(_), List(_)) => Some(lhs_type.clone()), - (List(_), _) => Some(lhs_type.clone()), - (_, List(_)) => Some(rhs_type.clone()), _ => None, } } @@ -1158,8 +1152,8 @@ mod tests { ]; for (i, input_type) in input_types.iter().enumerate() { let expect_type = &result_types[i]; - for op in comparison_op_types { - let (lhs, rhs) = get_input_types(&input_decimal, &op, input_type)?; + for op in &comparison_op_types { + let (lhs, rhs) = get_input_types(&input_decimal, op, input_type)?; assert_eq!(expect_type, &lhs); assert_eq!(expect_type, &rhs); } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index a248518c2d94..c8362691452b 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -124,8 +124,8 @@ impl AggregateUDF { } /// Return the underlying 
[`AggregateUDFImpl`] trait object for this function - pub fn inner(&self) -> Arc { - self.inner.clone() + pub fn inner(&self) -> &Arc { + &self.inner } /// Adds additional names that can be used to invoke this function, in diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index e14b62f1c841..03650b1d4748 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -105,8 +105,8 @@ impl ScalarUDF { } /// Return the underlying [`ScalarUDFImpl`] trait object for this function - pub fn inner(&self) -> Arc { - self.inner.clone() + pub fn inner(&self) -> &Arc { + &self.inner } /// Adds additional names that can be used to invoke this function, in diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index ce28b444adbc..a17bb0ade8e3 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -108,8 +108,8 @@ impl WindowUDF { } /// Return the underlying [`WindowUDFImpl`] trait object for this function - pub fn inner(&self) -> Arc { - self.inner.clone() + pub fn inner(&self) -> &Arc { + &self.inner } /// Adds additional names that can be used to invoke this function, in diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 3ab0c180dcba..286f05309ea7 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -46,6 +46,7 @@ pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::Int64(Some(1)); /// Recursively walk a list of expression trees, collecting the unique set of columns /// referenced in the expression +#[deprecated(since = "40.0.0", note = "Expr::add_column_refs instead")] pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result<()> { for e in expr { expr_to_columns(e, accum)?; @@ -317,7 +318,7 @@ fn get_excluded_columns( opt_exclude: Option<&ExcludeSelectItem>, opt_except: Option<&ExceptSelectItem>, schema: &DFSchema, - qualifier: &Option, + qualifier: Option<&TableReference>, ) -> Result> { let mut idents = vec![]; if let Some(excepts) = opt_except { @@ -342,8 +343,7 @@ fn get_excluded_columns( let mut result = vec![]; for ident in unique_idents.into_iter() { let col_name = ident.value.as_str(); - let (qualifier, field) = - schema.qualified_field_with_name(qualifier.as_ref(), col_name)?; + let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?; result.push(Column::from((qualifier, field))); } Ok(result) @@ -405,7 +405,7 @@ pub fn expand_wildcard( .. }) = wildcard_options { - get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, &None)? + get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)? } else { vec![] }; @@ -416,12 +416,11 @@ pub fn expand_wildcard( /// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s. pub fn expand_qualified_wildcard( - qualifier: &str, + qualifier: &TableReference, schema: &DFSchema, wildcard_options: Option<&WildcardAdditionalOptions>, ) -> Result> { - let qualifier = TableReference::from(qualifier); - let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + let qualified_indices = schema.fields_indices_with_qualified(qualifier); let projected_func_dependencies = schema .functional_dependencies() .project_functional_dependencies(&qualified_indices, qualified_indices.len()); @@ -444,7 +443,7 @@ pub fn expand_qualified_wildcard( opt_exclude.as_ref(), opt_except.as_ref(), schema, - &Some(qualifier), + Some(qualifier), )? 
} else { vec![] @@ -871,7 +870,7 @@ pub fn can_hash(data_type: &DataType) -> bool { /// Check whether all columns are from the schema. pub fn check_all_columns_from_schema( - columns: &HashSet, + columns: &HashSet<&Column>, schema: &DFSchema, ) -> Result { for col in columns.iter() { @@ -899,8 +898,8 @@ pub fn find_valid_equijoin_key_pair( left_schema: &DFSchema, right_schema: &DFSchema, ) -> Result> { - let left_using_columns = left_key.to_columns()?; - let right_using_columns = right_key.to_columns()?; + let left_using_columns = left_key.column_refs(); + let right_using_columns = right_key.column_refs(); // Conditions like a = 10, will be added to non-equijoin. if left_using_columns.is_empty() || right_using_columns.is_empty() { @@ -997,7 +996,7 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& /// assert_eq!(split_conjunction_owned(expr), split); /// ``` pub fn split_conjunction_owned(expr: Expr) -> Vec { - split_binary_owned(expr, Operator::And) + split_binary_owned(expr, &Operator::And) } /// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` @@ -1020,19 +1019,19 @@ pub fn split_conjunction_owned(expr: Expr) -> Vec { /// ]; /// /// // use split_binary_owned to split them -/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// assert_eq!(split_binary_owned(expr, &Operator::Plus), split); /// ``` -pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { +pub fn split_binary_owned(expr: Expr, op: &Operator) -> Vec { split_binary_owned_impl(expr, op, vec![]) } fn split_binary_owned_impl( expr: Expr, - operator: Operator, + operator: &Operator, mut exprs: Vec, ) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if &op == operator => { let exprs = split_binary_owned_impl(*left, operator, exprs); split_binary_owned_impl(*right, operator, exprs) } @@ -1049,17 +1048,17 @@ fn split_binary_owned_impl( /// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` /// /// See [`split_binary_owned`] for more details and an example. 
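The call-site updates above exist because `split_binary_owned` and `split_binary` now borrow the operator instead of taking it by value. A hedged usage sketch, mirroring the updated doc example and assuming the `datafusion_expr` API as it stands in this diff:

```rust
use datafusion_expr::utils::split_binary_owned;
use datafusion_expr::{col, lit, Expr, Operator};

fn main() {
    // a = 5 AND b
    let expr = col("a").eq(lit(5)).and(col("b"));

    // After this change the operator is passed by reference.
    let parts: Vec<Expr> = split_binary_owned(expr, &Operator::And);
    assert_eq!(parts, vec![col("a").eq(lit(5)), col("b")]);
}
```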
-pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { +pub fn split_binary<'a>(expr: &'a Expr, op: &Operator) -> Vec<&'a Expr> { split_binary_impl(expr, op, vec![]) } fn split_binary_impl<'a>( expr: &'a Expr, - operator: Operator, + operator: &Operator, mut exprs: Vec<&'a Expr>, ) -> Vec<&'a Expr> { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { let exprs = split_binary_impl(left, operator, exprs); split_binary_impl(right, operator, exprs) } @@ -1613,13 +1612,13 @@ mod tests { #[test] fn test_split_binary_owned() { let expr = col("a"); - assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]); + assert_eq!(split_binary_owned(expr.clone(), &Operator::And), vec![expr]); } #[test] fn test_split_binary_owned_two() { assert_eq!( - split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And), + split_binary_owned(col("a").eq(lit(5)).and(col("b")), &Operator::And), vec![col("a").eq(lit(5)), col("b")] ); } @@ -1629,7 +1628,7 @@ mod tests { let expr = col("a").eq(lit(5)).or(col("b")); assert_eq!( // expr is connected by OR, but pass in AND - split_binary_owned(expr.clone(), Operator::And), + split_binary_owned(expr.clone(), &Operator::And), vec![expr] ); } diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/functions-aggregate/src/average.rs similarity index 76% rename from datafusion/physical-expr/src/aggregate/average.rs rename to datafusion/functions-aggregate/src/average.rs index 80fcc9b70c5f..1dc1f10afce6 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -15,77 +15,85 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! 
Defines `Avg` & `Mean` aggregate & accumulators -use arrow::array::{AsArray, PrimitiveBuilder}; +use arrow::array::{ + self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, + AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, +}; +use arrow::compute::sum; +use arrow::datatypes::{ + i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, + Float64Type, UInt64Type, +}; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, +}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; +use datafusion_physical_expr_common::aggregate::utils::DecimalAverager; use log::debug; - use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -use crate::aggregate::groups_accumulator::accumulate::NullState; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::compute::sum; -use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type}; -use arrow::{ - array::{ArrayRef, UInt64Array}, - datatypes::Field, -}; -use arrow_array::types::{Decimal256Type, DecimalType}; -use arrow_array::{ - Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray, -}; -use arrow_buffer::{i256, ArrowNativeType}; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_expr::type_coercion::aggregates::avg_return_type; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; +make_udaf_expr_and_func!( + Avg, + avg, + expression, + "Returns the avg of a group of values.", + avg_udaf +); -use super::utils::DecimalAverager; - -/// AVG aggregate expression -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Avg { - name: String, - expr: Arc, - input_data_type: DataType, - result_data_type: DataType, + signature: Signature, + aliases: Vec, } impl Avg { - /// Create a new AVG aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - let result_data_type = avg_return_type(&data_type).unwrap(); - + pub fn new() -> Self { Self { - name: name.into(), - expr, - input_data_type: data_type, - result_data_type, + signature: Signature::user_defined(Immutable), + aliases: vec![String::from("mean")], } } } -impl AggregateExpr for Avg { - /// Return a reference to Any that can be used for downcasting +impl Default for Avg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Avg { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.result_data_type.clone(), true)) + fn name(&self) -> &str { + "avg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return exec_err!("avg(DISTINCT) aggregations are not available"); + } use DataType::*; // instantiate specialized accumulator based for the type - match (&self.input_data_type, &self.result_data_type) { + match 
(acc_args.input_type, acc_args.data_type) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -110,59 +118,47 @@ impl AggregateExpr for Avg { target_precision: *target_precision, target_scale: *target_scale, })), - _ => not_impl_err!( + _ => exec_err!( "AvgAccumulator for ({} --> {})", - self.input_data_type, - self.result_data_type + acc_args.input_type, + acc_args.data_type ), } } - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( - format_state_name(&self.name, "count"), + format_state_name(args.name, "count"), DataType::UInt64, true, ), Field::new( - format_state_name(&self.name, "sum"), - self.input_data_type.clone(), + format_state_name(args.name, "sum"), + args.input_type.clone(), true, ), ]) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - self.create_accumulator() - } - - fn groups_accumulator_supported(&self) -> bool { - use DataType::*; - - matches!(&self.result_data_type, Float64 | Decimal128(_, _)) + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + matches!( + args.data_type, + DataType::Float64 | DataType::Decimal128(_, _) + ) } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { use DataType::*; // instantiate specialized accumulator based for the type - match (&self.input_data_type, &self.result_data_type) { + match (args.input_type, args.data_type) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - &self.input_data_type, - &self.result_data_type, + args.input_type, + args.data_type, |sum: f64, count: u64| Ok(sum / count as f64), ))) } @@ -180,8 +176,8 @@ impl AggregateExpr for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - &self.input_data_type, - &self.result_data_type, + args.input_type, + args.data_type, avg_fn, ))) } @@ -201,32 +197,40 @@ impl AggregateExpr for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - &self.input_data_type, - &self.result_data_type, + args.input_type, + args.data_type, avg_fn, ))) } _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", - self.input_data_type, - self.result_data_type + args.input_type, + args.data_type ), } } -} -impl PartialEq for Avg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.result_data_type == x.result_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.accumulator(args) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!("{} expects exactly one argument.", self.name()); + } + coerce_avg_type(self.name(), arg_types) } } @@ -238,13 +242,6 @@ pub struct AvgAccumulator { } impl Accumulator for AvgAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::Float64(self.sum), - ]) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let 
values = values[0].as_primitive::(); self.count += (values.len() - values.null_count()) as u64; @@ -255,13 +252,21 @@ impl Accumulator for AvgAccumulator { Ok(()) } - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - self.count -= (values.len() - values.null_count()) as u64; - if let Some(x) = sum(values) { - self.sum = Some(self.sum.unwrap() - x); - } - Ok(()) + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Float64( + self.sum.map(|f| f / self.count as f64), + )) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::Float64(self.sum), + ]) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -275,23 +280,23 @@ impl Accumulator for AvgAccumulator { } Ok(()) } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Float64( - self.sum.map(|f| f / self.count as f64), - )) + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count -= (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap() - x); + } + Ok(()) } + fn supports_retract_batch(&self) -> bool { true } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } } /// An accumulator to compute the average for decimals -struct DecimalAvgAccumulator { +#[derive(Debug)] +struct DecimalAvgAccumulator { sum: Option, count: u64, sum_scale: i8, @@ -300,56 +305,12 @@ struct DecimalAvgAccumulator { target_scale: i8, } -impl Debug for DecimalAvgAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("DecimalAvgAccumulator") - .field("sum", &self.sum) - .field("count", &self.count) - .field("sum_scale", &self.sum_scale) - .field("sum_precision", &self.sum_precision) - .field("target_precision", &self.target_precision) - .field("target_scale", &self.target_scale) - .finish() - } -} - -impl Accumulator for DecimalAvgAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::new_primitive::( - self.sum, - &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), - )?, - ]) - } - +impl Accumulator for DecimalAvgAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = values[0].as_primitive::(); - self.count += (values.len() - values.null_count()) as u64; - if let Some(x) = sum(values) { - let v = self.sum.get_or_insert(T::Native::default()); - self.sum = Some(v.add_wrapping(x)); - } - Ok(()) - } - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - self.count -= (values.len() - values.null_count()) as u64; if let Some(x) = sum(values) { - self.sum = Some(self.sum.unwrap().sub_wrapping(x)); - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // counts are summed - self.count += sum(states[0].as_primitive::()).unwrap_or_default(); - - // sums are summed - if let Some(x) = sum(states[1].as_primitive::()) { let v = self.sum.get_or_insert(T::Native::default()); self.sum = Some(v.add_wrapping(x)); } @@ -374,13 +335,44 @@ impl Accumulator for DecimalAvgAccumulator &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), ) } - fn supports_retract_batch(&self) -> bool { - true - } fn size(&self) -> usize { std::mem::size_of_val(self) } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + 
ScalarValue::new_primitive::( + self.sum, + &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), + )?, + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // counts are summed + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); + + // sums are summed + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(T::Native::default()); + self.sum = Some(v.add_wrapping(x)); + } + Ok(()) + } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count -= (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap().sub_wrapping(x)); + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } } /// An accumulator to compute the average of `[PrimitiveArray]`. @@ -444,7 +436,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, + opt_filter: Option<&array::BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -469,45 +461,6 @@ where Ok(()) } - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - let partial_sums = values[1].as_primitive::(); - // update counts with partial counts - self.counts.resize(total_num_groups, 0); - self.null_state.accumulate( - group_indices, - partial_counts, - opt_filter, - total_num_groups, - |group_index, partial_count| { - self.counts[group_index] += partial_count; - }, - ); - - // update sums - self.sums.resize(total_num_groups, T::default_value()); - self.null_state.accumulate( - group_indices, - partial_sums, - opt_filter, - total_num_groups, - |group_index, new_value: ::Native| { - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(new_value); - }, - ); - - Ok(()) - } - fn evaluate(&mut self, emit_to: EmitTo) -> Result { let counts = emit_to.take_needed(&mut self.counts); let sums = emit_to.take_needed(&mut self.sums); @@ -536,7 +489,7 @@ where .into_iter() .zip(counts.into_iter()) .map(|(sum, count)| (self.avg_fn)(sum, count)) - .collect::>>()?; + .collect::>>()?; PrimitiveArray::new(averages.into(), Some(nulls)) // no copy .with_data_type(self.return_data_type.clone()) }; @@ -562,6 +515,45 @@ where ]) } + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values[0].as_primitive::(); + let partial_sums = values[1].as_primitive::(); + // update counts with partial counts + self.counts.resize(total_num_groups, 0); + self.null_state.accumulate( + group_indices, + partial_counts, + opt_filter, + total_num_groups, + |group_index, partial_count| { + self.counts[group_index] += partial_count; + }, + ); + + // update sums + self.sums.resize(total_num_groups, T::default_value()); + self.null_state.accumulate( + group_indices, + partial_sums, + opt_filter, + total_num_groups, + |group_index, new_value: ::Native| { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); 
+ }, + ); + + Ok(()) + } + fn size(&self) -> usize { self.counts.capacity() * std::mem::size_of::() + self.sums.capacity() * std::mem::size_of::() diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 19e24f547d8a..ba9964270443 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -440,7 +440,7 @@ where .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) .collect::>>()?; - let arr = ScalarValue::new_list(&values, &T::DATA_TYPE); + let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE); vec![ScalarValue::List(arr)] }; Ok(state_out) diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs new file mode 100644 index 000000000000..10d556308615 --- /dev/null +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`Correlation`]: correlation sample aggregations. 
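The new `correlation.rs` module that follows assembles the Pearson correlation from a covariance accumulator and two standard-deviation accumulators, i.e. corr(x, y) = cov(x, y) / (σx · σy), returning 0 when either deviation is zero and NULL when a statistic is unavailable. This is a minimal plain-Rust sketch of that final combination step only (the helper name `combine_corr` is invented for illustration; it is not part of the patch):

```rust
// Sketch of the combination performed in CorrelationAccumulator::evaluate:
// corr = cov / (stddev_x * stddev_y), 0.0 when either deviation is zero,
// None when any intermediate statistic is missing.
fn combine_corr(cov: Option<f64>, stddev_x: Option<f64>, stddev_y: Option<f64>) -> Option<f64> {
    match (cov, stddev_x, stddev_y) {
        (Some(c), Some(s1), Some(s2)) => {
            if s1 == 0.0 || s2 == 0.0 {
                Some(0.0)
            } else {
                Some(c / s1 / s2)
            }
        }
        _ => None,
    }
}

fn main() {
    assert_eq!(combine_corr(Some(2.0), Some(2.0), Some(1.0)), Some(1.0));
    assert_eq!(combine_corr(Some(1.0), Some(0.0), Some(3.0)), Some(0.0));
    assert_eq!(combine_corr(None, Some(1.0), Some(1.0)), None);
}
```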
+ +use std::any::Any; +use std::fmt::Debug; + +use arrow::compute::{and, filter, is_not_null}; +use arrow::{ + array::ArrayRef, + datatypes::{DataType, Field}, +}; + +use crate::covariance::CovarianceAccumulator; +use crate::stddev::StddevAccumulator; +use datafusion_common::{plan_err, Result, ScalarValue}; +use datafusion_expr::{ + function::{AccumulatorArgs, StateFieldsArgs}, + type_coercion::aggregates::NUMERICS, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Signature, Volatility, +}; +use datafusion_physical_expr_common::aggregate::stats::StatsType; + +make_udaf_expr_and_func!( + Correlation, + corr, + y x, + "Correlation between two numeric values.", + corr_udaf +); + +#[derive(Debug)] +pub struct Correlation { + signature: Signature, +} + +impl Default for Correlation { + fn default() -> Self { + Self::new() + } +} + +impl Correlation { + /// Create a new COVAR_POP aggregate function + pub fn new() -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Correlation { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "corr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Correlation requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(CorrelationAccumulator::try_new()?)) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean1"), DataType::Float64, true), + Field::new(format_state_name(name, "m2_1"), DataType::Float64, true), + Field::new(format_state_name(name, "mean2"), DataType::Float64, true), + Field::new(format_state_name(name, "m2_2"), DataType::Float64, true), + Field::new( + format_state_name(name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } +} + +/// An accumulator to compute correlation +#[derive(Debug)] +pub struct CorrelationAccumulator { + covar: CovarianceAccumulator, + stddev1: StddevAccumulator, + stddev2: StddevAccumulator, +} + +impl CorrelationAccumulator { + /// Creates a new `CorrelationAccumulator` + pub fn try_new() -> Result { + Ok(Self { + covar: CovarianceAccumulator::try_new(StatsType::Population)?, + stddev1: StddevAccumulator::try_new(StatsType::Population)?, + stddev2: StddevAccumulator::try_new(StatsType::Population)?, + }) + } +} + +impl Accumulator for CorrelationAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // TODO: null input skipping logic duplicated across Correlation + // and its children accumulators. 
+ // This could be simplified by splitting up input filtering and + // calculation logic in children accumulators, and calling only + // calculation part from Correlation + let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + let values1 = filter(&values[0], &mask)?; + let values2 = filter(&values[1], &mask)?; + + vec![values1, values2] + } else { + values.to_vec() + }; + + self.covar.update_batch(&values)?; + self.stddev1.update_batch(&values[0..1])?; + self.stddev2.update_batch(&values[1..2])?; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let covar = self.covar.evaluate()?; + let stddev1 = self.stddev1.evaluate()?; + let stddev2 = self.stddev2.evaluate()?; + + if let ScalarValue::Float64(Some(c)) = covar { + if let ScalarValue::Float64(Some(s1)) = stddev1 { + if let ScalarValue::Float64(Some(s2)) = stddev2 { + if s1 == 0_f64 || s2 == 0_f64 { + return Ok(ScalarValue::Float64(Some(0_f64))); + } else { + return Ok(ScalarValue::Float64(Some(c / s1 / s2))); + } + } + } + } + + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + + self.covar.size() + - std::mem::size_of_val(&self.stddev1) + + self.stddev1.size() + - std::mem::size_of_val(&self.stddev2) + + self.stddev2.size() + } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.covar.get_count()), + ScalarValue::from(self.covar.get_mean1()), + ScalarValue::from(self.stddev1.get_m2()), + ScalarValue::from(self.covar.get_mean2()), + ScalarValue::from(self.stddev2.get_m2()), + ScalarValue::from(self.covar.get_algo_const()), + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let states_c = [ + states[0].clone(), + states[1].clone(), + states[3].clone(), + states[5].clone(), + ]; + let states_s1 = [states[0].clone(), states[1].clone(), states[2].clone()]; + let states_s2 = [states[0].clone(), states[3].clone(), states[4].clone()]; + + self.covar.merge_batch(&states_c)?; + self.stddev1.merge_batch(&states_s1)?; + self.stddev2.merge_batch(&states_s2)?; + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + let values1 = filter(&values[0], &mask)?; + let values2 = filter(&values[1], &mask)?; + + vec![values1, values2] + } else { + values.to_vec() + }; + + self.covar.retract_batch(&values)?; + self.stddev1.retract_batch(&values[0..1])?; + self.stddev2.retract_batch(&values[1..2])?; + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 062e148975bf..0fc8e32d7240 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -77,7 +77,6 @@ pub fn count_distinct(expr: Expr) -> datafusion_expr::Expr { pub struct Count { signature: Signature, - aliases: Vec, } impl Debug for Count { @@ -98,7 +97,6 @@ impl Default for Count { impl Count { pub fn new() -> Self { Self { - aliases: vec!["count".to_string()], signature: Signature::variadic_any(Volatility::Immutable), } } @@ -110,7 +108,7 @@ impl AggregateUDFImpl for Count { } fn name(&self) -> &str { - "COUNT" + "count" } fn signature(&self) -> &Signature { @@ -249,7 +247,7 @@ impl AggregateUDFImpl for Count { } fn aliases(&self) -> &[String] { - &self.aliases + &[] } fn 
groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { @@ -502,7 +500,8 @@ impl Accumulator for DistinctCountAccumulator { /// Returns the distinct values seen so far as (one element) ListArray. fn state(&mut self) -> Result> { let scalars = self.values.iter().cloned().collect::>(); - let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); + let arr = + ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type); Ok(vec![ScalarValue::List(arr)]) } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 260d6dab31b9..063e6000b4c9 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -56,6 +56,7 @@ pub mod macros; pub mod approx_distinct; +pub mod correlation; pub mod count; pub mod covariance; pub mod first_last; @@ -69,9 +70,11 @@ pub mod variance; pub mod approx_median; pub mod approx_percentile_cont; pub mod approx_percentile_cont_with_weight; +pub mod average; pub mod bit_and_or_xor; pub mod bool_and_or; pub mod string_agg; + use crate::approx_percentile_cont::approx_percentile_cont_udaf; use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; use datafusion_common::Result; @@ -86,11 +89,13 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::average::avg; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; pub use super::bool_and_or::bool_and; pub use super::bool_and_or::bool_or; + pub use super::correlation::corr; pub use super::count::count; pub use super::count::count_distinct; pub use super::covariance::covar_pop; @@ -120,8 +125,9 @@ pub fn all_default_aggregate_functions() -> Vec> { first_last::first_value_udaf(), first_last::last_value_udaf(), covariance::covar_samp_udaf(), - sum::sum_udaf(), covariance::covar_pop_udaf(), + correlation::corr_udaf(), + sum::sum_udaf(), median::median_udaf(), count::count_udaf(), regr::regr_slope_udaf(), @@ -147,6 +153,7 @@ pub fn all_default_aggregate_functions() -> Vec> { bit_and_or_xor::bit_xor_udaf(), bool_and_or::bool_and_udaf(), bool_and_or::bool_or_udaf(), + average::avg_udaf(), ] } @@ -176,7 +183,7 @@ mod tests { let mut names = HashSet::new(); for func in all_default_aggregate_functions() { // TODO: remove this - // These functions are in intermidiate migration state, skip them + // These functions are in intermediate migration state, skip them if func.name().to_lowercase() == "count" { continue; } diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index c8bc78ac2dcd..bb926b8da271 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -180,7 +180,7 @@ impl Accumulator for MedianAccumulator { .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) .collect::>>()?; - let arr = ScalarValue::new_list(&all_values, &self.data_type); + let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type); Ok(vec![ScalarValue::List(arr)]) } @@ -237,7 +237,7 @@ impl Accumulator for DistinctMedianAccumulator { .map(|x| ScalarValue::new_primitive::(Some(x.0), &self.data_type)) .collect::>>()?; - let arr = ScalarValue::new_list(&all_values, &self.data_type); + let arr = 
ScalarValue::new_list_nullable(&all_values, &self.data_type); Ok(vec![ScalarValue::List(arr)]) } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index b9293bc2ca28..a9f31dc05be9 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -384,7 +384,7 @@ impl Accumulator for DistinctSumAccumulator { }) .collect::>>()?; - vec![ScalarValue::List(ScalarValue::new_list( + vec![ScalarValue::List(ScalarValue::new_list_nullable( &distinct_values, &self.data_type, ))] diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-array/src/make_array.rs index 0159d4ac0829..79858041d3ca 100644 --- a/datafusion/functions-array/src/make_array.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -27,7 +27,7 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; use datafusion_common::internal_err; -use datafusion_common::{plan_err, utils::array_into_list_array, Result}; +use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::TypeSignature; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -155,7 +155,7 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { let length = arrays.iter().map(|a| a.len()).sum(); // By default Int64 let array = new_null_array(&DataType::Int64, length); - Ok(Arc::new(array_into_list_array(array))) + Ok(Arc::new(array_into_list_array_nullable(array))) } LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index d18f1f8a3cbb..28bc2d5e4373 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -18,12 +18,10 @@ //! 
Rewrites for using Array Functions use crate::array_has::array_has_all; -use crate::concat::{array_append, array_concat, array_prepend}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; -use datafusion_common::utils::list_ndims; +use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_common::{Column, DFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::{BinaryExpr, Expr, Operator}; @@ -39,7 +37,7 @@ impl FunctionRewrite for ArrayFunctionRewriter { fn rewrite( &self, expr: Expr, - schema: &DFSchema, + _schema: &DFSchema, _config: &ConfigOptions, ) -> Result> { let transformed = match expr { @@ -61,91 +59,6 @@ impl FunctionRewrite for ArrayFunctionRewriter { Transformed::yes(array_has_all(*right, *left)) } - // Column cases: - // 1) array_prepend/append/concat || column - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && is_one_of_func( - &left, - &["array_append", "array_prepend", "array_concat"], - ) - && as_col(&right).is_some() => - { - let c = as_col(&right).unwrap(); - let d = schema.field_from_column(c)?.data_type(); - let ndim = list_ndims(d); - match ndim { - 0 => Transformed::yes(array_append(*left, *right)), - _ => Transformed::yes(array_concat(vec![*left, *right])), - } - } - // 2) select column1 || column2 - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && as_col(&left).is_some() - && as_col(&right).is_some() => - { - let c1 = as_col(&left).unwrap(); - let c2 = as_col(&right).unwrap(); - let d1 = schema.field_from_column(c1)?.data_type(); - let d2 = schema.field_from_column(c2)?.data_type(); - let ndim1 = list_ndims(d1); - let ndim2 = list_ndims(d2); - match (ndim1, ndim2) { - (0, _) => Transformed::yes(array_prepend(*left, *right)), - (_, 0) => Transformed::yes(array_append(*left, *right)), - _ => Transformed::yes(array_concat(vec![*left, *right])), - } - } - - // Chain concat operator (a || b) || array, - // (array_concat, array_append, array_prepend) || array -> array concat - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && is_one_of_func( - &left, - &["array_append", "array_prepend", "array_concat"], - ) - && is_func(&right, "make_array") => - { - Transformed::yes(array_concat(vec![*left, *right])) - } - - // Chain concat operator (a || b) || scalar, - // (array_concat, array_append, array_prepend) || scalar -> array append - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && is_one_of_func( - &left, - &["array_append", "array_prepend", "array_concat"], - ) => - { - Transformed::yes(array_append(*left, *right)) - } - - // array || array -> array concat - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && is_func(&left, "make_array") - && is_func(&right, "make_array") => - { - Transformed::yes(array_concat(vec![*left, *right])) - } - - // array || scalar -> array append - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat && is_func(&left, "make_array") => - { - Transformed::yes(array_append(*left, *right)) - } - - // scalar || array -> array prepend - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat && is_func(&right, "make_array") => - { - Transformed::yes(array_prepend(*left, *right)) - } - _ => Transformed::no(expr), }; Ok(transformed) @@ -161,21 +74,3 @@ fn is_func(expr: &Expr, 
func_name: &str) -> bool { func.name() == func_name } - -/// Returns true if expr is a function call with one of the specified names -fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool { - let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else { - return false; - }; - - func_names.contains(&func.name()) -} - -/// returns Some(col) if this is Expr::Column -fn as_col(expr: &Expr) -> Option<&Column> { - if let Expr::Column(c) = expr { - Some(c) - } else { - None - } -} diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs index 04832b4b1259..d02c863db8b7 100644 --- a/datafusion/functions-array/src/string.rs +++ b/datafusion/functions-array/src/string.rs @@ -26,12 +26,15 @@ use arrow::array::{ use arrow::datatypes::{DataType, Field}; use datafusion_expr::TypeSignature; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use std::any::{type_name, Any}; use crate::utils::{downcast_arg, make_scalar_function}; -use arrow_schema::DataType::{FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8}; +use arrow::compute::cast; +use arrow_schema::DataType::{ + Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, +}; use datafusion_common::cast::{ as_generic_string_array, as_large_list_array, as_list_array, as_string_array, }; @@ -76,7 +79,7 @@ macro_rules! call_array_function { DataType::UInt16 => array_function!(UInt16Array), DataType::UInt32 => array_function!(UInt32Array), DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), + dt => not_impl_err!("Unsupported data type in array_to_string: {dt}"), } }; ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ @@ -95,7 +98,7 @@ macro_rules! call_array_function { DataType::UInt16 => array_function!(UInt16Array), DataType::UInt32 => array_function!(UInt32Array), DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), + dt => not_impl_err!("Unsupported data type in array_to_string: {dt}"), } }}; } @@ -245,6 +248,8 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { with_null_string = true; } + /// Creates a single string from single element of a ListArray (which is + /// itself another Array) fn compute_array_to_string( arg: &mut String, arr: ArrayRef, @@ -281,6 +286,22 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { Ok(arg) } + Dictionary(_key_type, value_type) => { + // Call cast to unwrap the dictionary. This could be optimized if we wanted + // to accept the overhead of extra code + let values = cast(&arr, value_type.as_ref()).map_err(|e| { + DataFusionError::from(e).context( + "Casting dictionary to values in compute_array_to_string", + ) + })?; + compute_array_to_string( + arg, + values, + delimiter, + null_string, + with_null_string, + ) + } Null => Ok(arg), data_type => { macro_rules! 
array_function { diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index 00a6a68f7aac..3ecccf3c8713 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -262,7 +262,7 @@ pub(super) fn get_arg_name(args: &[Expr], i: usize) -> String { mod tests { use super::*; use arrow::datatypes::Int64Type; - use datafusion_common::utils::array_into_list_array; + use datafusion_common::utils::array_into_list_array_nullable; /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] @@ -277,8 +277,10 @@ mod tests { Some(vec![Some(6), Some(7), Some(8)]), ])); - let array2d_1 = Arc::new(array_into_list_array(array1d_1.clone())) as ArrayRef; - let array2d_2 = Arc::new(array_into_list_array(array1d_2.clone())) as ArrayRef; + let array2d_1 = + Arc::new(array_into_list_array_nullable(array1d_1.clone())) as ArrayRef; + let array2d_2 = + Arc::new(array_into_list_array_nullable(array1d_2.clone())) as ArrayRef; let res = align_array_dimensions::(vec![ array1d_1.to_owned(), @@ -294,8 +296,8 @@ mod tests { expected_dim ); - let array3d_1 = Arc::new(array_into_list_array(array2d_1)) as ArrayRef; - let array3d_2 = array_into_list_array(array2d_2.to_owned()); + let array3d_1 = Arc::new(array_into_list_array_nullable(array2d_1)) as ArrayRef; + let array3d_2 = array_into_list_array_nullable(array2d_2.to_owned()); let res = align_array_dimensions::(vec![array1d_1, Arc::new(array3d_2.clone())]) .unwrap(); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 20d6cbc37459..884a66724c91 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -73,7 +73,6 @@ chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { workspace = true, default-features = true } hashbrown = { workspace = true, optional = true } hex = { version = "0.4", optional = true } itertools = { workspace = true } diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 7c6f2e42605a..4cb91447f386 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -20,8 +20,8 @@ use std::any::Any; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{ - ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, + ArrowTimestampType, DataType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use datafusion_common::{exec_err, Result, ScalarType}; @@ -143,8 +143,8 @@ impl ScalarUDFImpl for ToTimestampFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Nanosecond, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Nanosecond)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -167,6 +167,9 @@ impl ScalarUDFImpl for ToTimestampFunc { DataType::Null | DataType::Float64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } + DataType::Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) + } DataType::Utf8 => { to_timestamp_impl::(args, 
"to_timestamp") } @@ -193,8 +196,8 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Second, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Second)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -214,6 +217,9 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Second, None), None) } + DataType::Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Second, Some(tz)), None) + } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp_seconds") } @@ -240,8 +246,8 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Millisecond, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Millisecond)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -261,6 +267,9 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Millisecond, None), None) } + DataType::Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Millisecond, Some(tz)), None) + } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp_millis") } @@ -287,8 +296,8 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Microsecond, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Microsecond)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -308,6 +317,9 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Microsecond, None), None) } + DataType::Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Microsecond, Some(tz)), None) + } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp_micros") } @@ -334,8 +346,8 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Nanosecond, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Nanosecond)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -355,6 +367,9 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } + DataType::Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) + } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp_nanos") } @@ -368,6 +383,15 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } } +/// Returns the return type for the to_timestamp_* function, preserving +/// the timezone if it exists. 
+fn return_type_for(arg: &DataType, unit: TimeUnit) -> DataType { + match arg { + Timestamp(_, Some(tz)) => Timestamp(unit, Some(tz.clone())), + _ => Timestamp(unit, None), + } +} + fn to_timestamp_impl>( args: &[ColumnarValue], name: &str, @@ -740,6 +764,103 @@ mod tests { } } + #[test] + fn test_tz() { + let udfs: Vec> = vec![ + Box::new(ToTimestampFunc::new()), + Box::new(ToTimestampSecondsFunc::new()), + Box::new(ToTimestampMillisFunc::new()), + Box::new(ToTimestampNanosFunc::new()), + Box::new(ToTimestampSecondsFunc::new()), + ]; + + let mut nanos_builder = TimestampNanosecondArray::builder(2); + let mut millis_builder = TimestampMillisecondArray::builder(2); + let mut micros_builder = TimestampMicrosecondArray::builder(2); + let mut sec_builder = TimestampSecondArray::builder(2); + + nanos_builder.append_value(1599572549190850000); + millis_builder.append_value(1599572549190); + micros_builder.append_value(1599572549190850); + sec_builder.append_value(1599572549); + + let nanos_timestamps = + Arc::new(nanos_builder.finish().with_timezone("UTC")) as ArrayRef; + let millis_timestamps = + Arc::new(millis_builder.finish().with_timezone("UTC")) as ArrayRef; + let micros_timestamps = + Arc::new(micros_builder.finish().with_timezone("UTC")) as ArrayRef; + let sec_timestamps = + Arc::new(sec_builder.finish().with_timezone("UTC")) as ArrayRef; + + let arrays = &[ + ColumnarValue::Array(nanos_timestamps.clone()), + ColumnarValue::Array(millis_timestamps.clone()), + ColumnarValue::Array(micros_timestamps.clone()), + ColumnarValue::Array(sec_timestamps.clone()), + ]; + + for udf in &udfs { + for array in arrays { + let rt = udf.return_type(&[array.data_type()]).unwrap(); + assert!(matches!(rt, DataType::Timestamp(_, Some(_)))); + + let res = udf + .invoke(&[array.clone()]) + .expect("that to_timestamp parsed values without error"); + let array = match res { + ColumnarValue::Array(res) => res, + _ => panic!("Expected a columnar array"), + }; + let ty = array.data_type(); + assert!(matches!(ty, DataType::Timestamp(_, Some(_)))); + } + } + + let mut nanos_builder = TimestampNanosecondArray::builder(2); + let mut millis_builder = TimestampMillisecondArray::builder(2); + let mut micros_builder = TimestampMicrosecondArray::builder(2); + let mut sec_builder = TimestampSecondArray::builder(2); + let mut i64_builder = Int64Array::builder(2); + + nanos_builder.append_value(1599572549190850000); + millis_builder.append_value(1599572549190); + micros_builder.append_value(1599572549190850); + sec_builder.append_value(1599572549); + i64_builder.append_value(1599572549); + + let nanos_timestamps = Arc::new(nanos_builder.finish()) as ArrayRef; + let millis_timestamps = Arc::new(millis_builder.finish()) as ArrayRef; + let micros_timestamps = Arc::new(micros_builder.finish()) as ArrayRef; + let sec_timestamps = Arc::new(sec_builder.finish()) as ArrayRef; + let i64_timestamps = Arc::new(i64_builder.finish()) as ArrayRef; + + let arrays = &[ + ColumnarValue::Array(nanos_timestamps.clone()), + ColumnarValue::Array(millis_timestamps.clone()), + ColumnarValue::Array(micros_timestamps.clone()), + ColumnarValue::Array(sec_timestamps.clone()), + ColumnarValue::Array(i64_timestamps.clone()), + ]; + + for udf in &udfs { + for array in arrays { + let rt = udf.return_type(&[array.data_type()]).unwrap(); + assert!(matches!(rt, DataType::Timestamp(_, None))); + + let res = udf + .invoke(&[array.clone()]) + .expect("that to_timestamp parsed values without error"); + let array = match res { + ColumnarValue::Array(res) => res, + 
_ => panic!("Expected a columnar array"), + }; + let ty = array.data_type(); + assert!(matches!(ty, DataType::Timestamp(_, None))); + } + } + } + #[test] fn test_to_timestamp_arg_validation() { let mut date_string_builder = StringBuilder::with_capacity(2, 1024); @@ -811,6 +932,11 @@ mod tests { .expect("that to_timestamp with format args parsed values without error"); if let ColumnarValue::Array(parsed_array) = parsed_timestamps { assert_eq!(parsed_array.len(), 1); + assert!(matches!( + parsed_array.data_type(), + DataType::Timestamp(_, None) + )); + match time_unit { Nanosecond => { assert_eq!(nanos_expected_timestamps, parsed_array.as_ref()) diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index dc481da79069..74ad2c738a93 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, Int64Array}; +use arrow::{ + array::{ArrayRef, Int64Array}, + error::ArrowError, +}; use std::any::Any; use std::sync::Arc; @@ -23,7 +26,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] @@ -67,28 +70,27 @@ impl ScalarUDFImpl for FactorialFunc { } } -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - /// Factorial SQL function fn factorial(args: &[ArrayRef]) -> Result { match args[0].data_type() { - DataType::Int64 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Int64Array, - { |value: i64| { (1..=value).product() } } - )) as ArrayRef), + DataType::Int64 => { + let arg = downcast_arg!((&args[0]), "value", Int64Array); + Ok(arg + .iter() + .map(|a| match a { + Some(a) => (2..=a) + .try_fold(1i64, i64::checked_mul) + .ok_or_else(|| { + arrow_datafusion_err!(ArrowError::ComputeError(format!( + "Overflow happened on FACTORIAL({a})" + ))) + }) + .map(Some), + _ => Ok(None), + }) + .collect::>() + .map(Arc::new)? as ArrayRef) + } other => exec_err!("Unsupported data type {other:?} for function factorial."), } } diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index d0199f7a22c4..95a559c5d103 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -16,6 +16,7 @@ // under the License. 
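The `factorial` change above replaces the wrapping `(1..=value).product()` with a checked multiplication so that overflow surfaces as an error. A standalone sketch of that pattern (plain Rust, using a `String` error in place of the DataFusion error types; `checked_factorial` is an invented name):

```rust
// Multiply with i64::checked_mul and report overflow instead of wrapping.
fn checked_factorial(n: i64) -> Result<i64, String> {
    (2..=n)
        .try_fold(1i64, i64::checked_mul)
        .ok_or_else(|| format!("Overflow happened on FACTORIAL({n})"))
}

fn main() {
    assert_eq!(checked_factorial(0), Ok(1));
    assert_eq!(checked_factorial(20), Ok(2_432_902_008_176_640_000));
    // 21! does not fit in i64, so the checked version reports the overflow.
    assert!(checked_factorial(21).is_err());
}
```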
use arrow::array::{ArrayRef, Int64Array}; +use arrow::error::ArrowError; use std::any::Any; use std::mem::swap; use std::sync::Arc; @@ -24,7 +25,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -73,59 +74,73 @@ impl ScalarUDFImpl for GcdFunc { /// Gcd SQL function fn gcd(args: &[ArrayRef]) -> Result { match args[0].data_type() { - Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Int64Array, - Int64Array, - { compute_gcd } - )) as ArrayRef), + Int64 => { + let arg1 = downcast_arg!(&args[0], "x", Int64Array); + let arg2 = downcast_arg!(&args[1], "y", Int64Array); + + Ok(arg1 + .iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Ok(Some(compute_gcd(a1, a2)?)), + _ => Ok(None), + }) + .collect::>() + .map(Arc::new)? as ArrayRef) + } other => exec_err!("Unsupported data type {other:?} for function gcd"), } } -/// Computes greatest common divisor using Binary GCD algorithm. -pub fn compute_gcd(x: i64, y: i64) -> i64 { - if x == 0 { - return y; +/// Computes gcd of two unsigned integers using Binary GCD algorithm. +pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 { + if a == 0 { + return b; } - if y == 0 { - return x; + if b == 0 { + return a; } - let mut a = x.unsigned_abs(); - let mut b = y.unsigned_abs(); - let shift = (a | b).trailing_zeros(); a >>= shift; b >>= shift; a >>= a.trailing_zeros(); - loop { b >>= b.trailing_zeros(); if a > b { swap(&mut a, &mut b); } - b -= a; - if b == 0 { - // because the input values are i64, casting this back to i64 is safe - return (a << shift) as i64; + return a << shift; } } } +/// Computes greatest common divisor using Binary GCD algorithm. 
+pub fn compute_gcd(x: i64, y: i64) -> Result { + let a = x.unsigned_abs(); + let b = y.unsigned_abs(); + let r = unsigned_gcd(a, b); + // gcd(i64::MIN, i64::MIN) = i64::MIN.unsigned_abs() cannot fit into i64 + r.try_into().map_err(|_| { + arrow_datafusion_err!(ArrowError::ComputeError(format!( + "Signed integer overflow in GCD({x}, {y})" + ))) + }) +} + #[cfg(test)] mod test { use std::sync::Arc; - use arrow::array::{ArrayRef, Int64Array}; + use arrow::{ + array::{ArrayRef, Int64Array}, + error::ArrowError, + }; use crate::math::gcd::gcd; - use datafusion_common::cast::as_int64_array; + use datafusion_common::{cast::as_int64_array, DataFusionError}; #[test] fn test_gcd_i64() { @@ -143,4 +158,21 @@ mod test { assert_eq!(ints.value(2), 5); assert_eq!(ints.value(3), 8); } + + #[test] + fn overflow_on_both_param_i64_min() { + let args: Vec = vec![ + Arc::new(Int64Array::from(vec![i64::MIN])), // x + Arc::new(Int64Array::from(vec![i64::MIN])), // y + ]; + + match gcd(&args) { + // we expect a overflow + Err(DataFusionError::ArrowError(ArrowError::ComputeError(_), _)) => {} + Err(_) => { + panic!("failed to initialize function gcd") + } + Ok(_) => panic!("GCD({0}, {0}) should have overflown", i64::MIN), + }; + } } diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 3674f7371de2..21c201657e90 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -22,11 +22,12 @@ use arrow::array::{ArrayRef, Int64Array}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; -use datafusion_common::{exec_err, DataFusionError, Result}; +use arrow::error::ArrowError; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::math::gcd::compute_gcd; +use super::gcd::unsigned_gcd; use crate::utils::make_scalar_function; #[derive(Debug)] @@ -74,25 +75,40 @@ impl ScalarUDFImpl for LcmFunc { /// Lcm SQL function fn lcm(args: &[ArrayRef]) -> Result { let compute_lcm = |x: i64, y: i64| { - let a = x.wrapping_abs(); - let b = y.wrapping_abs(); - - if a == 0 || b == 0 { - return 0; + if x == 0 || y == 0 { + return Ok(0); } - a / compute_gcd(a, b) * b + + // lcm(x, y) = |x| * |y| / gcd(|x|, |y|) + let a = x.unsigned_abs(); + let b = y.unsigned_abs(); + let gcd = unsigned_gcd(a, b); + // gcd is not zero since both a and b are not zero, so the division is safe. + (a / gcd) + .checked_mul(b) + .and_then(|v| i64::try_from(v).ok()) + .ok_or_else(|| { + arrow_datafusion_err!(ArrowError::ComputeError(format!( + "Signed integer overflow in LCM({x}, {y})" + ))) + }) }; match args[0].data_type() { - Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Int64Array, - Int64Array, - { compute_lcm } - )) as ArrayRef), + Int64 => { + let arg1 = downcast_arg!(&args[0], "x", Int64Array); + let arg2 = downcast_arg!(&args[1], "y", Int64Array); + + Ok(arg1 + .iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Ok(Some(compute_lcm(a1, a2)?)), + _ => Ok(None), + }) + .collect::>() + .map(Arc::new)? as ArrayRef) + } other => exec_err!("Unsupported data type {other:?} for function lcm"), } } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 7677e8b2af95..5b790fb56ddf 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -17,9 +17,11 @@ //! 
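Taken together, the `gcd`/`lcm` changes above run the binary GCD on unsigned magnitudes and only convert back to `i64` (or multiply, for LCM) through checked operations, so the one problematic input pair `(i64::MIN, i64::MIN)` becomes an error rather than wrapping. A standalone sketch of the same scheme (`unsigned_gcd` mirrors the patch; the LCM wrapper is simplified to return `Option` instead of a `DataFusionError`):

```rust
use std::mem::swap;

// Binary GCD on unsigned magnitudes, as in the patch above.
fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
    if a == 0 {
        return b;
    }
    if b == 0 {
        return a;
    }
    let shift = (a | b).trailing_zeros();
    a >>= shift;
    b >>= shift;
    a >>= a.trailing_zeros();
    loop {
        b >>= b.trailing_zeros();
        if a > b {
            swap(&mut a, &mut b);
        }
        b -= a;
        if b == 0 {
            return a << shift;
        }
    }
}

// LCM built on the unsigned GCD; None stands in for the ComputeError in the patch.
fn checked_lcm(x: i64, y: i64) -> Option<i64> {
    if x == 0 || y == 0 {
        return Some(0);
    }
    let (a, b) = (x.unsigned_abs(), y.unsigned_abs());
    let gcd = unsigned_gcd(a, b);
    // gcd != 0 because a and b are non-zero, so the division is safe.
    (a / gcd).checked_mul(b).and_then(|v| i64::try_from(v).ok())
}

fn main() {
    assert_eq!(unsigned_gcd(12, 18), 6);
    assert_eq!(checked_lcm(4, 6), Some(12));
    // |i64::MIN| = 2^63 does not fit back into an i64.
    assert_eq!(checked_lcm(i64::MIN, i64::MIN), None);
}
```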
Math function: `power()`. -use arrow::datatypes::DataType; +use arrow::datatypes::{ArrowNativeTypeOp, DataType}; + use datafusion_common::{ - exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, + arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, + DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; @@ -94,14 +96,25 @@ impl ScalarUDFImpl for PowerFunc { { f64::powf } )), - DataType::Int64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Int64Array, - { i64::pow } - )), + DataType::Int64 => { + let bases = downcast_arg!(&args[0], "base", Int64Array); + let exponents = downcast_arg!(&args[1], "exponent", Int64Array); + bases + .iter() + .zip(exponents.iter()) + .map(|(base, exp)| match (base, exp) { + (Some(base), Some(exp)) => Ok(Some(base.pow_checked( + exp.try_into().map_err(|_| { + exec_datafusion_err!( + "Can't use negative exponents: {exp} in integer computation, please use Float." + ) + })?, + ).map_err(|e| arrow_datafusion_err!(e))?)), + _ => Ok(None), + }) + .collect::>() + .map(Arc::new)? as ArrayRef + } other => { return exec_err!( diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 1bab2953e4f6..71ab7c1b4350 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -20,10 +20,13 @@ use std::sync::Arc; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; +use arrow::array::{ArrayRef, Float32Array, Float64Array, Int32Array}; +use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use arrow::datatypes::DataType::{Float32, Float64, Int32}; +use datafusion_common::{ + exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -114,7 +117,11 @@ pub fn round(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Float64 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; Ok(Arc::new(make_function_scalar_inputs!( &args[0], @@ -128,21 +135,30 @@ pub fn round(args: &[ArrayRef]) -> Result { } )) as ArrayRef) } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float64Array, - Int64Array, - { - |value: f64, decimal_places: i64| { - (value * 10.0_f64.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f64.powi(decimal_places.try_into().unwrap()) + ColumnarValue::Array(decimal_places) => { + let options = CastOptions { + safe: false, // raise error if the cast is not possible + ..Default::default() + }; + let decimal_places = cast_with_options(&decimal_places, &Int32, &options) + .map_err(|e| { + exec_datafusion_err!("Invalid values for decimal places: {e}") + })?; + Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + 
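The integer `power()` path above now converts the exponent to `u32` up front (rejecting negative exponents with an error that points users at `Float`) and multiplies through `pow_checked`, so overflow no longer wraps silently. A minimal sketch of the same checked shape, using std's `i64::checked_pow` in place of `ArrowNativeTypeOp::pow_checked` (names and error messages are illustrative):

```rust
// The error type is just a String here; the patch returns a DataFusionError instead.
fn checked_int_power(base: i64, exp: i64) -> Result<i64, String> {
    let exp: u32 = exp.try_into().map_err(|_| {
        format!("Can't use negative exponents: {exp} in integer computation, please use Float.")
    })?;
    base.checked_pow(exp)
        .ok_or_else(|| format!("Overflow happened on POWER({base}, {exp})"))
}

fn main() {
    assert_eq!(checked_int_power(2, 10), Ok(1024));
    assert!(checked_int_power(2, -1).is_err()); // negative exponent rejected
    assert!(checked_int_power(i64::MAX, 2).is_err()); // overflow detected
}
```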
"decimal_places", + Float64Array, + Int32Array, + { + |value: f64, decimal_places: i32| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + } } - } - )) as ArrayRef), + )) as ArrayRef) + } _ => { exec_err!("round function requires a scalar or array for decimal_places") } @@ -150,7 +166,11 @@ pub fn round(args: &[ArrayRef]) -> Result { DataType::Float32 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; Ok(Arc::new(make_function_scalar_inputs!( &args[0], @@ -164,21 +184,30 @@ pub fn round(args: &[ArrayRef]) -> Result { } )) as ArrayRef) } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float32Array, - Int64Array, - { - |value: f32, decimal_places: i64| { - (value * 10.0_f32.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f32.powi(decimal_places.try_into().unwrap()) + ColumnarValue::Array(_) => { + let ColumnarValue::Array(decimal_places) = + decimal_places.cast_to(&Int32, None).map_err(|e| { + exec_datafusion_err!("Invalid values for decimal places: {e}") + })? + else { + panic!("Unexpected result of ColumnarValue::Array.cast") + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float32Array, + Int32Array, + { + |value: f32, decimal_places: i32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + } } - } - )) as ArrayRef), + )) as ArrayRef) + } _ => { exec_err!("round function requires a scalar or array for decimal_places") } @@ -196,6 +225,7 @@ mod test { use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_common::DataFusionError; #[test] fn test_round_f32() { @@ -262,4 +292,17 @@ mod test { assert_eq!(floats, &expected); } + + #[test] + fn test_round_f32_cast_fail() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345])), // input + Arc::new(Int64Array::from(vec![2147483648])), // decimal_places + ]; + + let result = round(&args); + + assert!(result.is_err()); + assert!(matches!(result, Err(DataFusionError::Execution { .. 
}))); + } } diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 4e21883c9752..201eebde22bb 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -28,10 +28,10 @@ use datafusion_common::ScalarValue; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; +use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use datafusion_physical_expr::functions::Hint; use regex::Regex; use std::any::Any; use std::collections::HashMap; diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 97b54a194a27..349928d09664 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -21,10 +21,10 @@ use std::any::Any; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; +use datafusion_expr::function::Hint; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_physical_expr::functions::Hint; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index ef05a2cb2a13..de14bbaa2bcf 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -21,10 +21,10 @@ use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; +use datafusion_expr::function::Hint; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_physical_expr::functions::Hint; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 2e39080e226b..2d29b50cb173 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -21,10 +21,10 @@ use std::any::Any; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; +use datafusion_expr::function::Hint; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_physical_expr::functions::Hint; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index d14844c4a445..393dcc456a88 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -18,8 +18,8 @@ use arrow::array::ArrayRef; use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::function::Hint; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; -use datafusion_physical_expr::functions::Hint; use std::sync::Arc; /// Creates a function to identify the optimal return type of a string function given diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index de2af520053a..34f9802b1fd9 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ 
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -59,14 +59,14 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { func_def: AggregateFunctionDefinition::UDF(udf), args, .. - } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) + } if udf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { let args = &window_function.args; matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) + if udaf.name() == "count" && args.len() == 1 && is_wildcard(&args[0])) } fn analyze_internal(plan: LogicalPlan) -> Result> { @@ -127,9 +127,9 @@ mod tests { .project(vec![count(wildcard())])? .sort(vec![count(wildcard()).sort(true, false)])? .build()?; - let expected = "Sort: COUNT(*) ASC NULLS LAST [COUNT(*):Int64;N]\ - \n Projection: COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[test.b]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] [b:UInt32, COUNT(*):Int64;N]\ + let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64;N]\ + \n Projection: count(*) [count(*):Int64;N]\ + \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) } @@ -152,9 +152,9 @@ mod tests { .build()?; let expected = "Filter: t1.a IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [COUNT(*):Int64;N]\ - \n Projection: COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + \n Subquery: [count(*):Int64;N]\ + \n Projection: count(*) [count(*):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) @@ -175,9 +175,9 @@ mod tests { .build()?; let expected = "Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [COUNT(*):Int64;N]\ - \n Projection: COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + \n Subquery: [count(*):Int64;N]\ + \n Projection: count(*) [count(*):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) @@ -207,9 +207,9 @@ mod tests { let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\ \n Filter: () > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [COUNT(Int64(1)):Int64;N]\ - \n Projection: COUNT(Int64(1)) [COUNT(Int64(1)):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] [COUNT(Int64(1)):Int64;N]\ + \n Subquery: [count(Int64(1)):Int64;N]\ + \n Projection: count(Int64(1)) [count(Int64(1)):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64;N]\ \n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; @@ -235,8 +235,8 @@ mod tests { .project(vec![count(wildcard())])? 
.build()?; - let expected = "Projection: COUNT(Int64(1)) AS COUNT(*) [COUNT(*):Int64;N]\ - \n WindowAggr: windowExpr=[[COUNT(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS COUNT(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, COUNT(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64;N]\ + \n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) } @@ -249,8 +249,8 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + let expected = "Projection: count(*) [count(*):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) } @@ -272,8 +272,8 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: COUNT(Int64(1)) AS COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(COUNT(Int64(1))) AS MAX(COUNT(*))]] [MAX(COUNT(*)):Int64;N]\ + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(count(Int64(1))) AS MAX(count(*))]] [MAX(count(*)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 444ee94c4292..5725a725e64a 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -249,7 +249,7 @@ fn check_aggregation_in_scalar_subquery( let mut group_columns = agg .group_expr .iter() - .map(|group| Ok(group.to_columns()?.into_iter().collect::>())) + .map(|group| Ok(group.column_refs().into_iter().cloned().collect::>())) .collect::>>()? .into_iter() .flatten(); @@ -300,19 +300,17 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool { }) = expr { match (left.deref(), right.deref()) { - (Expr::Column(_), right) if right.to_columns().unwrap().is_empty() => true, - (left, Expr::Column(_)) if left.to_columns().unwrap().is_empty() => true, + (Expr::Column(_), right) => !right.any_column_refs(), + (left, Expr::Column(_)) => !left.any_column_refs(), (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) - && right.to_columns().unwrap().is_empty() => + if matches!(expr.deref(), Expr::Column(_)) => { - true + !right.any_column_refs() } (left, Expr::Cast(Cast { expr, .. 
})) - if matches!(expr.deref(), Expr::Column(_)) - && left.to_columns().unwrap().is_empty() => + if matches!(expr.deref(), Expr::Column(_)) => { - true + !left.any_column_refs() } (_, _) => false, } @@ -323,9 +321,10 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool { /// Check whether the window expressions contain a mixture of out reference columns and inner columns fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { - let mixed = window.window_expr.iter().any(|win_expr| { - win_expr.contains_outer() && !win_expr.to_columns().unwrap().is_empty() - }); + let mixed = window + .window_expr + .iter() + .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs()); if mixed { plan_err!( "Window expressions should not contain a mixed of outer references and inner columns" diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index acc21f14f44d..51ec8d8af1d3 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -146,7 +146,7 @@ impl<'a> TypeCoercionRewriter<'a> { .map(|(lhs, rhs)| { // coerce the arguments as though they were a single binary equality // expression - let (lhs, rhs) = self.coerce_binary_op(lhs, Operator::Eq, rhs)?; + let (lhs, rhs) = self.coerce_binary_op(lhs, &Operator::Eq, rhs)?; Ok((lhs, rhs)) }) .collect::>>()?; @@ -157,12 +157,12 @@ impl<'a> TypeCoercionRewriter<'a> { fn coerce_binary_op( &self, left: Expr, - op: Operator, + op: &Operator, right: Expr, ) -> Result<(Expr, Expr)> { let (left_type, right_type) = get_input_types( &left.get_type(self.schema)?, - &op, + op, &right.get_type(self.schema)?, )?; Ok(( @@ -265,7 +265,10 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?); + let expr = match left_type { + DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr, + _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), + }; let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, @@ -276,7 +279,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (left, right) = self.coerce_binary_op(*left, op, *right)?; + let (left, right) = self.coerce_binary_op(*left, &op, *right)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, @@ -815,13 +818,14 @@ mod test { use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection}; + use datafusion_expr::test::function_stub::avg_udaf; use datafusion_expr::{ - cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, - AggregateFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr, - ExprSchemable, Filter, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, - Signature, SimpleAggregateUDF, Subquery, Volatility, + cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF, + BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan, + Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, + Volatility, }; - use datafusion_physical_expr::expressions::AvgAccumulator; + use datafusion_functions_aggregate::average::AvgAccumulator; use 
crate::analyzer::type_coercion::{ coerce_case_expression, TypeCoercion, TypeCoercionRewriter, @@ -1003,31 +1007,29 @@ mod test { #[test] fn agg_function_case() -> Result<()> { let empty = empty(); - let fun: AggregateFunction = AggregateFunction::Avg; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - vec![lit(12i64)], + let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + avg_udaf(), + vec![lit(12f64)], false, None, None, None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation"; + let expected = "Projection: avg(Float64(12))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(DataType::Int32); - let fun: AggregateFunction = AggregateFunction::Avg; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - vec![col("a")], + let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + avg_udaf(), + vec![cast(col("a"), DataType::Float64)], false, None, None, None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation"; + let expected = "Projection: avg(CAST(a AS Float64))\n EmptyRelation"; assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } @@ -1035,9 +1037,8 @@ mod test { #[test] fn agg_function_invalid_input_avg() -> Result<()> { let empty = empty(); - let fun: AggregateFunction = AggregateFunction::Avg; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, + let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + avg_udaf(), vec![lit("1")], false, None, @@ -1048,10 +1049,7 @@ mod test { .err() .unwrap() .strip_backtrace(); - assert_eq!( - "Error during planning: No function matches the given name and argument types 'AVG(Utf8)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tAVG(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64)", - err - ); + assert!(err.starts_with("Error during planning: Error during planning: Coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed.")); Ok(()) } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 7f4093ba110e..e760845e043a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -18,6 +18,7 @@ //! 
[`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions use std::collections::{BTreeSet, HashMap}; +use std::hash::{BuildHasher, Hash, Hasher, RandomState}; use std::sync::Arc; use crate::{OptimizerConfig, OptimizerRule}; @@ -25,11 +26,12 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; use datafusion_common::alias::AliasGenerator; +use datafusion_common::hash_utils::combine_hashes; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ - internal_datafusion_err, internal_err, qualified_name, Column, DFSchema, Result, + internal_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::tree_node::unwrap_arc; @@ -43,18 +45,37 @@ const CSE_PREFIX: &str = "__common_expr"; /// Identifier that represents a subexpression tree. /// -/// Note that the current implementation contains: -/// - the `Display` of an expression (a `String`) and -/// - the identifiers of the childrens of the expression -/// concatenated. -/// -/// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no -/// collision (as low as possible)" -/// -/// Since an identifier is likely to be copied many times, it is better that an identifier -/// is small or "copy". otherwise some kinds of reference count is needed. String -/// description here is not such a good choose. -type Identifier = String; +/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and +/// "have no collision (as low as possible)" +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +struct Identifier<'n> { + // Hash of `expr` built up incrementally during the first, visiting traversal, but its + // value is not necessarily equal to `expr.hash()`. + hash: u64, + expr: &'n Expr, +} + +impl<'n> Identifier<'n> { + fn new(expr: &'n Expr, random_state: &RandomState) -> Self { + let mut hasher = random_state.build_hasher(); + expr.hash_node(&mut hasher); + let hash = hasher.finish(); + Self { hash, expr } + } + + fn combine(mut self, other: Option) -> Self { + other.map_or(self, |other_id| { + self.hash = combine_hashes(self.hash, other_id.hash); + self + }) + } +} + +impl Hash for Identifier<'_> { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); + } +} /// A cache that contains the postorder index and the identifier of expression tree nodes /// by the preorder index of the nodes. @@ -83,14 +104,14 @@ type Identifier = String; /// (0, "b") /// ] /// ``` -type IdArray = Vec<(usize, Identifier)>; +type IdArray<'n> = Vec<(usize, Option>)>; /// A map that contains the number of occurrences of expressions by their identifiers. -type ExprStats = HashMap; +type ExprStats<'n> = HashMap, usize>; /// A map that contains the common expressions and their alias extracted during the /// second, rewriting traversal. -type CommonExprs = IndexMap; +type CommonExprs<'n> = IndexMap, (Expr, String)>; /// Performs Common Sub-expression Elimination optimization. 
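The heart of the `CommonSubexprEliminate` rework above is replacing the `String` identifiers (the `Display` of a subtree concatenated with its children's identifiers) with a small `Copy` struct: a hash built bottom-up via `combine`, plus a reference to the original `Expr`. A toy sketch of the idea, with `&str` standing in for `&Expr` and a simple mixer standing in for `datafusion_common::hash_utils::combine_hashes` (all names are illustrative):

```rust
use std::hash::{BuildHasher, Hash, Hasher, RandomState};

// Stand-in for datafusion_common::hash_utils::combine_hashes.
fn combine_hashes(l: u64, r: u64) -> u64 {
    l.rotate_left(5).wrapping_add(r) ^ l
}

// A cheap, Copy identifier: the node's hash folded with its child identifiers.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Id<'n> {
    hash: u64,
    node: &'n str, // the real code stores &'n Expr
}

impl<'n> Id<'n> {
    fn new(node: &'n str, random_state: &RandomState) -> Self {
        let mut hasher = random_state.build_hasher();
        node.hash(&mut hasher);
        Self { hash: hasher.finish(), node }
    }

    fn combine(mut self, child: Option<Id<'n>>) -> Self {
        if let Some(child) = child {
            self.hash = combine_hashes(self.hash, child.hash);
        }
        self
    }
}

fn main() {
    // One RandomState per optimizer instance, as in the patch.
    let rs = RandomState::new();

    let id_child = Id::new("a + 1", &rs).combine(None);
    let id_parent = Id::new("sum(a + 1)", &rs).combine(Some(id_child));

    // Structurally equal subtrees hashed with the same RandomState get equal
    // identifiers, so occurrences can be counted without allocating Strings.
    let id_parent_again =
        Id::new("sum(a + 1)", &rs).combine(Some(Id::new("a + 1", &rs).combine(None)));
    assert_eq!(id_parent, id_parent_again);
}
```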
/// @@ -118,21 +139,86 @@ type CommonExprs = IndexMap; /// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here /// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once /// ``` -pub struct CommonSubexprEliminate {} +pub struct CommonSubexprEliminate { + random_state: RandomState, +} impl CommonSubexprEliminate { + pub fn new() -> Self { + Self { + random_state: RandomState::new(), + } + } + + /// Returns the identifier list for each element in `exprs` and a flag to indicate if + /// rewrite phase of CSE make sense. + /// + /// Returns and array with 1 element for each input expr in `exprs` + /// + /// Each element is itself the result of [`CommonSubexprEliminate::expr_to_identifier`] for that expr + /// (e.g. the identifiers for each node in the tree) + fn to_arrays<'n>( + &self, + exprs: &'n [Expr], + expr_stats: &mut ExprStats<'n>, + expr_mask: ExprMask, + ) -> Result<(bool, Vec>)> { + let mut found_common = false; + exprs + .iter() + .map(|e| { + let mut id_array = vec![]; + self.expr_to_identifier(e, expr_stats, &mut id_array, expr_mask) + .map(|fc| { + found_common |= fc; + + id_array + }) + }) + .collect::>>() + .map(|id_arrays| (found_common, id_arrays)) + } + + /// Add an identifier to `id_array` for every subexpression in this tree. + fn expr_to_identifier<'n>( + &self, + expr: &'n Expr, + expr_stats: &mut ExprStats<'n>, + id_array: &mut IdArray<'n>, + expr_mask: ExprMask, + ) -> Result { + // Don't consider volatile expressions for CSE. + Ok(if expr.is_volatile()? { + false + } else { + let mut visitor = ExprIdentifierVisitor { + expr_stats, + id_array, + visit_stack: vec![], + down_index: 0, + up_index: 0, + expr_mask, + random_state: &self.random_state, + found_common: false, + }; + expr.visit(&mut visitor)?; + + visitor.found_common + }) + } + /// Rewrites `exprs_list` with common sub-expressions replaced with a new /// column. /// /// `common_exprs` is updated with any sub expressions that were replaced. 
/// /// Returns the rewritten expressions - fn rewrite_exprs_list( + fn rewrite_exprs_list<'n>( &self, exprs_list: Vec>, - arrays_list: &[&[IdArray]], - expr_stats: &ExprStats, - common_exprs: &mut CommonExprs, + arrays_list: Vec>>, + expr_stats: &ExprStats<'n>, + common_exprs: &mut CommonExprs<'n>, alias_generator: &AliasGenerator, ) -> Result>>> { let mut transformed = false; @@ -175,7 +261,7 @@ impl CommonSubexprEliminate { fn rewrite_expr( &self, exprs_list: Vec>, - arrays_list: &[&[IdArray]], + arrays_list: Vec>, input: LogicalPlan, expr_stats: &ExprStats, config: &dyn OptimizerConfig, @@ -275,68 +361,95 @@ impl CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result> { // collect all window expressions from any number of LogicalPlanWindow - let ConsecutiveWindowExprs { - window_exprs, - arrays_per_window, - expr_stats, - plan, - } = ConsecutiveWindowExprs::try_new(window)?; + let (mut window_exprs, mut window_schemas, mut plan) = + get_consecutive_window_exprs(window); - let arrays_per_window = arrays_per_window + let mut found_common = false; + let mut expr_stats = ExprStats::new(); + let arrays_per_window = window_exprs .iter() - .map(|arrays| arrays.as_slice()) - .collect::>(); + .map(|window_expr| { + self.to_arrays(window_expr, &mut expr_stats, ExprMask::Normal) + .map(|(fc, id_arrays)| { + found_common |= fc; - // save the original names - let name_preserver = NamePreserver::new(&plan); - let mut saved_names = window_exprs - .iter() - .map(|exprs| { - exprs - .iter() - .map(|expr| name_preserver.save(expr)) - .collect::>>() + id_arrays + }) }) .collect::>>()?; - assert_eq!(window_exprs.len(), arrays_per_window.len()); - let num_window_exprs = window_exprs.len(); - let rewritten_window_exprs = self.rewrite_expr( - window_exprs, - &arrays_per_window, - plan, - &expr_stats, - config, - )?; - let transformed = rewritten_window_exprs.transformed; + if found_common { + // save the original names + let name_preserver = NamePreserver::new(&plan); + let mut saved_names = window_exprs + .iter() + .map(|exprs| { + exprs + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>>() + }) + .collect::>>()?; - let (mut new_expr, new_input) = rewritten_window_exprs.data; + assert_eq!(window_exprs.len(), arrays_per_window.len()); + let num_window_exprs = window_exprs.len(); + let rewritten_window_exprs = self.rewrite_expr( + // Must clone as Identifiers use references to original expressions so we + // have to keep the original expressions intact. + window_exprs.clone(), + arrays_per_window, + plan, + &expr_stats, + config, + )?; + let transformed = rewritten_window_exprs.transformed; + assert!(transformed); - let mut plan = new_input; + let (mut new_expr, new_input) = rewritten_window_exprs.data; - // Construct consecutive window operator, with their corresponding new - // window expressions. 
- // - // Note this iterates over, `new_expr` and `saved_names` which are the - // same length, in reverse order - assert_eq!(num_window_exprs, new_expr.len()); - assert_eq!(num_window_exprs, saved_names.len()); - while let (Some(new_window_expr), Some(saved_names)) = - (new_expr.pop(), saved_names.pop()) - { - assert_eq!(new_window_expr.len(), saved_names.len()); + let mut plan = new_input; - // Rename re-written window expressions with original name, to - // preserve the output schema - let new_window_expr = new_window_expr - .into_iter() - .zip(saved_names.into_iter()) - .map(|(new_window_expr, saved_name)| saved_name.restore(new_window_expr)) - .collect::>>()?; - plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?); - } + // Construct consecutive window operator, with their corresponding new + // window expressions. + // + // Note this iterates over, `new_expr` and `saved_names` which are the + // same length, in reverse order + assert_eq!(num_window_exprs, new_expr.len()); + assert_eq!(num_window_exprs, saved_names.len()); + while let (Some(new_window_expr), Some(saved_names)) = + (new_expr.pop(), saved_names.pop()) + { + assert_eq!(new_window_expr.len(), saved_names.len()); + + // Rename re-written window expressions with original name, to + // preserve the output schema + let new_window_expr = new_window_expr + .into_iter() + .zip(saved_names.into_iter()) + .map(|(new_window_expr, saved_name)| { + saved_name.restore(new_window_expr) + }) + .collect::>>()?; + plan = LogicalPlan::Window(Window::try_new( + new_window_expr, + Arc::new(plan), + )?); + } + + Ok(Transformed::new_transformed(plan, transformed)) + } else { + while let (Some(window_expr), Some(schema)) = + (window_exprs.pop(), window_schemas.pop()) + { + plan = LogicalPlan::Window(Window { + input: Arc::new(plan), + window_expr, + schema, + }); + } - Ok(Transformed::new_transformed(plan, transformed)) + Ok(Transformed::no(plan)) + } } fn try_optimize_aggregate( @@ -351,56 +464,112 @@ impl CommonSubexprEliminate { schema: orig_schema, .. } = aggregate; - let mut expr_stats = ExprStats::new(); - // track transformed information let mut transformed = false; - // rewrite inputs - let group_arrays = to_arrays(&group_expr, &mut expr_stats, ExprMask::Normal)?; - let aggr_arrays = to_arrays(&aggr_expr, &mut expr_stats, ExprMask::Normal)?; - let name_perserver = NamePreserver::new_for_projection(); let saved_names = aggr_expr .iter() .map(|expr| name_perserver.save(expr)) .collect::>>()?; - // rewrite both group exprs and aggr_expr - let rewritten = self.rewrite_expr( - vec![group_expr, aggr_expr], - &[&group_arrays, &aggr_arrays], - unwrap_arc(input), - &expr_stats, - config, - )?; - transformed |= rewritten.transformed; - let (mut new_expr, new_input) = rewritten.data; - - // note the reversed pop order. - let new_aggr_expr = pop_expr(&mut new_expr)?; - let new_group_expr = pop_expr(&mut new_expr)?; + let mut expr_stats = ExprStats::new(); + // rewrite inputs + let (group_found_common, group_arrays) = + self.to_arrays(&group_expr, &mut expr_stats, ExprMask::Normal)?; + let (aggr_found_common, aggr_arrays) = + self.to_arrays(&aggr_expr, &mut expr_stats, ExprMask::Normal)?; + let (new_aggr_expr, new_group_expr, new_input) = + if group_found_common || aggr_found_common { + // rewrite both group exprs and aggr_expr + let rewritten = self.rewrite_expr( + // Must clone as Identifiers use references to original expressions so + // we have to keep the original expressions intact. 
+ vec![group_expr.clone(), aggr_expr.clone()], + vec![group_arrays, aggr_arrays], + unwrap_arc(input), + &expr_stats, + config, + )?; + assert!(rewritten.transformed); + transformed |= rewritten.transformed; + let (mut new_expr, new_input) = rewritten.data; + + // note the reversed pop order. + let new_aggr_expr = pop_expr(&mut new_expr)?; + let new_group_expr = pop_expr(&mut new_expr)?; + + (new_aggr_expr, new_group_expr, Arc::new(new_input)) + } else { + (aggr_expr, group_expr, input) + }; // create potential projection on top let mut expr_stats = ExprStats::new(); - let new_input_schema = Arc::clone(new_input.schema()); - let aggr_arrays = to_arrays( + let (aggr_found_common, aggr_arrays) = self.to_arrays( &new_aggr_expr, &mut expr_stats, ExprMask::NormalAndAggregates, )?; - let mut common_exprs = IndexMap::new(); - let mut rewritten_exprs = self.rewrite_exprs_list( - vec![new_aggr_expr.clone()], - &[&aggr_arrays], - &expr_stats, - &mut common_exprs, - &config.alias_generator(), - )?; - transformed |= rewritten_exprs.transformed; - let rewritten = pop_expr(&mut rewritten_exprs.data)?; + if aggr_found_common { + let mut common_exprs = CommonExprs::new(); + let mut rewritten_exprs = self.rewrite_exprs_list( + // Must clone as Identifiers use references to original expressions so we + // have to keep the original expressions intact. + vec![new_aggr_expr.clone()], + vec![aggr_arrays], + &expr_stats, + &mut common_exprs, + &config.alias_generator(), + )?; + assert!(rewritten_exprs.transformed); + let rewritten = pop_expr(&mut rewritten_exprs.data)?; + + assert!(!common_exprs.is_empty()); + let mut agg_exprs = common_exprs + .into_values() + .map(|(expr, expr_alias)| expr.alias(expr_alias)) + .collect::>(); + + let new_input_schema = Arc::clone(new_input.schema()); + let mut proj_exprs = vec![]; + for expr in &new_group_expr { + extract_expressions(expr, &new_input_schema, &mut proj_exprs)? + } + for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { + if expr_rewritten == expr_orig { + if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { + agg_exprs.push(expr.alias(&name)); + proj_exprs.push(Expr::Column(Column::from_name(name))); + } else { + let expr_alias = config.alias_generator().next(CSE_PREFIX); + let (qualifier, field) = + expr_rewritten.to_field(&new_input_schema)?; + let out_name = qualified_name(qualifier.as_ref(), field.name()); + + agg_exprs.push(expr_rewritten.alias(&expr_alias)); + proj_exprs.push( + Expr::Column(Column::from_name(expr_alias)).alias(out_name), + ); + } + } else { + proj_exprs.push(expr_rewritten); + } + } + + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + new_input, + new_group_expr, + agg_exprs, + )?); - if common_exprs.is_empty() { + Projection::try_new(proj_exprs, Arc::new(agg)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) + } else { + // TODO: How exactly can the name or the schema change in this case? + // In theory `new_aggr_expr` and `new_group_expr` are either the original expressions or they were crafted via `rewrite_expr()`, that keeps the original expression names. + // If this is really needed can we have UT for it? // Alias aggregation expressions if they have changed let new_aggr_expr = new_aggr_expr .into_iter() @@ -409,57 +578,19 @@ impl CommonSubexprEliminate { .collect::>>()?; // Since group_expr may have changed, schema may also. Use try_new method. let new_agg = if transformed { - Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr)? 
+ Aggregate::try_new(new_input, new_group_expr, new_aggr_expr)? } else { Aggregate::try_new_with_schema( - Arc::new(new_input), + new_input, new_group_expr, new_aggr_expr, orig_schema, )? }; let new_agg = LogicalPlan::Aggregate(new_agg); - return Ok(Transformed::new_transformed(new_agg, transformed)); - } - let mut agg_exprs = common_exprs - .into_values() - .map(|(expr, expr_alias)| expr.alias(expr_alias)) - .collect::>(); - - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &new_input_schema, &mut proj_exprs)? - } - for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { - agg_exprs.push(expr.alias(&name)); - proj_exprs.push(Expr::Column(Column::from_name(name))); - } else { - let expr_alias = config.alias_generator().next(CSE_PREFIX); - let (qualifier, field) = - expr_rewritten.to_field(&new_input_schema)?; - let out_name = qualified_name(qualifier.as_ref(), field.name()); - - agg_exprs.push(expr_rewritten.alias(&expr_alias)); - proj_exprs.push( - Expr::Column(Column::from_name(expr_alias)).alias(out_name), - ); - } - } else { - proj_exprs.push(expr_rewritten); - } - } - - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(new_input), - new_group_expr, - agg_exprs, - )?); - Projection::try_new(proj_exprs, Arc::new(agg)) - .map(LogicalPlan::Projection) - .map(Transformed::yes) + Ok(Transformed::new_transformed(new_agg, transformed)) + } } /// Rewrites the expr list and input to remove common subexpressions @@ -483,13 +614,27 @@ impl CommonSubexprEliminate { config: &dyn OptimizerConfig, ) -> Result, LogicalPlan)>> { let mut expr_stats = ExprStats::new(); - let arrays = to_arrays(&expr, &mut expr_stats, ExprMask::Normal)?; - - self.rewrite_expr(vec![expr], &[&arrays], input, &expr_stats, config)? - .map_data(|(mut new_expr, new_input)| { + let (found_common, id_arrays) = + self.to_arrays(&expr, &mut expr_stats, ExprMask::Normal)?; + + if found_common { + let rewritten = self.rewrite_expr( + // Must clone as Identifiers use references to original expressions so we + // have to keep the original expressions intact. + vec![expr.clone()], + vec![id_arrays], + input, + &expr_stats, + config, + )?; + assert!(rewritten.transformed); + rewritten.map_data(|(mut new_expr, new_input)| { assert_eq!(new_expr.len(), 1); Ok((new_expr.pop().unwrap(), new_input)) }) + } else { + Ok(Transformed::no((expr, input))) + } } } @@ -507,7 +652,7 @@ impl CommonSubexprEliminate { /// ``` /// /// Returns: -/// * `window_exprs`: `[a, b, c, d]` +/// * `window_exprs`: `[[a, b, c], [d]]` /// * InputPlan /// /// Consecutive window expressions may refer to same complex expression. @@ -524,52 +669,27 @@ impl CommonSubexprEliminate { /// ``` /// /// where, it is referred once by each `WindowAggr` (total of 2) in the plan. -struct ConsecutiveWindowExprs { - window_exprs: Vec>, - /// result of calling `to_arrays` on each set of window exprs - arrays_per_window: Vec>>, - expr_stats: ExprStats, - /// input plan to the window - plan: LogicalPlan, -} - -impl ConsecutiveWindowExprs { - fn try_new(window: Window) -> Result { - let mut window_exprs = vec![]; - let mut arrays_per_window = vec![]; - let mut expr_stats = ExprStats::new(); - - let mut plan = LogicalPlan::Window(window); - while let LogicalPlan::Window(Window { - input, window_expr, .. 
- }) = plan - { - plan = unwrap_arc(input); - - let arrays = to_arrays(&window_expr, &mut expr_stats, ExprMask::Normal)?; - - window_exprs.push(window_expr); - arrays_per_window.push(arrays); - } - - Ok(Self { - window_exprs, - arrays_per_window, - expr_stats, - plan, - }) +fn get_consecutive_window_exprs( + window: Window, +) -> (Vec>, Vec, LogicalPlan) { + let mut window_exprs = vec![]; + let mut window_schemas = vec![]; + let mut plan = LogicalPlan::Window(window); + while let LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) = plan + { + window_exprs.push(window_expr); + window_schemas.push(schema); + + plan = unwrap_arc(input); } + (window_exprs, window_schemas, plan) } impl OptimizerRule for CommonSubexprEliminate { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called CommonSubexprEliminate::rewrite") - } - fn supports_rewrite(&self) -> bool { true } @@ -640,41 +760,12 @@ impl Default for CommonSubexprEliminate { } } -impl CommonSubexprEliminate { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - fn pop_expr(new_expr: &mut Vec>) -> Result> { new_expr .pop() .ok_or_else(|| internal_datafusion_err!("Failed to pop expression")) } -/// Returns the identifier list for each element in `exprs` -/// -/// Returns and array with 1 element for each input expr in `exprs` -/// -/// Each element is itself the result of [`expr_to_identifier`] for that expr -/// (e.g. the identifiers for each node in the tree) -fn to_arrays( - exprs: &[Expr], - expr_stats: &mut ExprStats, - expr_mask: ExprMask, -) -> Result> { - exprs - .iter() - .map(|e| { - let mut id_array = vec![]; - expr_to_identifier(e, expr_stats, &mut id_array, expr_mask)?; - - Ok(id_array) - }) - .collect() -} - /// Build the "intermediate" projection plan that evaluates the extracted common /// expressions. /// @@ -798,45 +889,48 @@ impl ExprMask { /// /// `Expr` without sub-expr (column, literal etc.) will not have identifier /// because they should not be recognized as common sub-expr. -struct ExprIdentifierVisitor<'a> { +struct ExprIdentifierVisitor<'a, 'n> { // statistics of expressions - expr_stats: &'a mut ExprStats, + expr_stats: &'a mut ExprStats<'n>, // cache to speed up second traversal - id_array: &'a mut IdArray, + id_array: &'a mut IdArray<'n>, // inner states - visit_stack: Vec, + visit_stack: Vec>, // preorder index, start from 0. down_index: usize, // postorder index, start from 0. up_index: usize, // which expression should be skipped? expr_mask: ExprMask, + // a `RandomState` to generate hashes during the first traversal + random_state: &'a RandomState, + // a flag to indicate that common expression found + found_common: bool, } /// Record item that used when traversing a expression tree. -enum VisitRecord { +enum VisitRecord<'n> { /// `usize` postorder index assigned in `f-down`(). Starts from 0. EnterMark(usize), /// the node's children were skipped => jump to f_up on same node JumpMark, /// Accumulated identifier of sub expression. - ExprItem(Identifier), + ExprItem(Identifier<'n>), } -impl ExprIdentifierVisitor<'_> { +impl<'n> ExprIdentifierVisitor<'_, 'n> { /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. 
- fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> { - let mut desc = String::new(); + fn pop_enter_mark(&mut self) -> Option<(usize, Option>)> { + let mut expr_id = None; while let Some(item) = self.visit_stack.pop() { match item { VisitRecord::EnterMark(idx) => { - return Some((idx, desc)); + return Some((idx, expr_id)); } VisitRecord::ExprItem(id) => { - desc.push('|'); - desc.push_str(&id); + expr_id = Some(id.combine(expr_id)); } VisitRecord::JumpMark => return None, } @@ -845,21 +939,22 @@ impl ExprIdentifierVisitor<'_> { } } -impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_> { +impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { type Node = Expr; fn f_down(&mut self, expr: &'n Expr) -> Result { - // related to https://github.com/apache/arrow-datafusion/issues/8814 - // If the expr contain volatile expression or is a short-circuit expression, skip it. - // TODO: propagate is_volatile state bottom-up + consider non-volatile sub-expressions for CSE + // TODO: consider non-volatile sub-expressions for CSE // TODO: consider surely executed children of "short circuited"s for CSE - if expr.short_circuits() || expr.is_volatile()? { + + // If an expression can short circuit its children then don't consider it for CSE + // (https://github.com/apache/arrow-datafusion/issues/8814). + if expr.short_circuits() { self.visit_stack.push(VisitRecord::JumpMark); return Ok(TreeNodeRecursion::Jump); } - self.id_array.push((0, "".to_string())); + self.id_array.push((0, None)); self.visit_stack .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; @@ -872,13 +967,16 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_> { return Ok(TreeNodeRecursion::Continue); }; - let expr_id = expr_identifier(expr, sub_expr_id); + let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id); self.id_array[down_index].0 = self.up_index; if !self.expr_mask.ignores(expr) { - self.id_array[down_index].1.clone_from(&expr_id); - let count = self.expr_stats.entry(expr_id.clone()).or_insert(0); + self.id_array[down_index].1 = Some(expr_id); + let count = self.expr_stats.entry(expr_id).or_insert(0); *count += 1; + if *count > 1 { + self.found_common = true; + } } self.visit_stack.push(VisitRecord::ExprItem(expr_id)); self.up_index += 1; @@ -887,40 +985,17 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_> { } } -fn expr_identifier(expr: &Expr, sub_expr_identifier: Identifier) -> Identifier { - format!("{{{expr}{sub_expr_identifier}}}") -} - -/// Go through an expression tree and generate identifier for every node in this tree. -fn expr_to_identifier( - expr: &Expr, - expr_stats: &mut ExprStats, - id_array: &mut IdArray, - expr_mask: ExprMask, -) -> Result<()> { - expr.visit(&mut ExprIdentifierVisitor { - expr_stats, - id_array, - visit_stack: vec![], - down_index: 0, - up_index: 0, - expr_mask, - })?; - - Ok(()) -} - /// Rewrite expression by replacing detected common sub-expression with /// the corresponding temporary column name. That column contains the /// evaluate result of replaced expression. 
-struct CommonSubexprRewriter<'a> { +struct CommonSubexprRewriter<'a, 'n> { // statistics of expressions - expr_stats: &'a ExprStats, + expr_stats: &'a ExprStats<'n>, // cache to speed up second traversal - id_array: &'a IdArray, + id_array: &'a IdArray<'n>, // common expression, that are replaced during the second traversal, are collected to // this map - common_exprs: &'a mut CommonExprs, + common_exprs: &'a mut CommonExprs<'n>, // preorder index, starts from 0. down_index: usize, // how many aliases have we seen so far @@ -929,17 +1004,9 @@ struct CommonSubexprRewriter<'a> { alias_generator: &'a AliasGenerator, } -impl TreeNodeRewriter for CommonSubexprRewriter<'_> { +impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result> { - if matches!(expr, Expr::Alias(_)) { - self.alias_counter -= 1 - } - - Ok(Transformed::no(expr)) - } - fn f_down(&mut self, expr: Expr) -> Result> { if matches!(expr, Expr::Alias(_)) { self.alias_counter += 1; @@ -948,33 +1015,32 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. - if expr.short_circuits() || expr.is_volatile()? { + if expr.short_circuits() { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } - let (up_index, expr_id) = &self.id_array[self.down_index]; + let (up_index, expr_id) = self.id_array[self.down_index]; self.down_index += 1; // skip `Expr`s without identifier (empty identifier). - if expr_id.is_empty() { + let Some(expr_id) = expr_id else { return Ok(Transformed::no(expr)); - } + }; - let count = self.expr_stats.get(expr_id).unwrap(); + let count = self.expr_stats.get(&expr_id).unwrap(); if *count > 1 { // step index to skip all sub-node (which has smaller series number). 
while self.down_index < self.id_array.len() - && self.id_array[self.down_index].0 < *up_index + && self.id_array[self.down_index].0 < up_index { self.down_index += 1; } let expr_name = expr.display_name()?; - let (_, expr_alias) = - self.common_exprs.entry(expr_id.clone()).or_insert_with(|| { - let expr_alias = self.alias_generator.next(CSE_PREFIX); - (expr, expr_alias) - }); + let (_, expr_alias) = self.common_exprs.entry(expr_id).or_insert_with(|| { + let expr_alias = self.alias_generator.next(CSE_PREFIX); + (expr, expr_alias) + }); // alias the expressions without an `Alias` ancestor node let rewritten = if self.alias_counter > 0 { @@ -989,44 +1055,56 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { Ok(Transformed::no(expr)) } } + + fn f_up(&mut self, expr: Expr) -> Result> { + if matches!(expr, Expr::Alias(_)) { + self.alias_counter -= 1 + } + + Ok(Transformed::no(expr)) + } } /// Replace common sub-expression in `expr` with the corresponding temporary /// column name, updating `common_exprs` with any replaced expressions -fn replace_common_expr( +fn replace_common_expr<'n>( expr: Expr, - id_array: &IdArray, - expr_stats: &ExprStats, - common_exprs: &mut CommonExprs, + id_array: &IdArray<'n>, + expr_stats: &ExprStats<'n>, + common_exprs: &mut CommonExprs<'n>, alias_generator: &AliasGenerator, ) -> Result> { - expr.rewrite(&mut CommonSubexprRewriter { - expr_stats, - id_array, - common_exprs, - down_index: 0, - alias_counter: 0, - alias_generator, - }) + if id_array.is_empty() { + Ok(Transformed::no(expr)) + } else { + expr.rewrite(&mut CommonSubexprRewriter { + expr_stats, + id_array, + common_exprs, + down_index: 0, + alias_counter: 0, + alias_generator, + }) + } } #[cfg(test)] mod test { + use std::collections::HashSet; use std::iter; use arrow::datatypes::{DataType, Field, Schema}; - + use datafusion_expr::expr::AggregateFunction; use datafusion_expr::logical_plan::{table_scan, JoinType}; - - use datafusion_expr::{avg, lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, Signature, + grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Signature, SimpleAggregateUDF, Volatility, }; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use crate::optimizer::OptimizerContext; use crate::test::*; - use datafusion_expr::test::function_stub::sum; + use datafusion_expr::test::function_stub::{avg, sum}; use super::*; @@ -1036,7 +1114,7 @@ mod test { config: Option<&dyn OptimizerConfig>, ) { assert_eq!(expected, format!("{plan:?}"), "Unexpected starting plan"); - let optimizer = CommonSubexprEliminate {}; + let optimizer = CommonSubexprEliminate::new(); let default_config = OptimizerContext::new(); let config = config.unwrap_or(&default_config); let optimized_plan = optimizer.rewrite(plan, config).unwrap(); @@ -1054,7 +1132,7 @@ mod test { plan: LogicalPlan, config: Option<&dyn OptimizerConfig>, ) { - let optimizer = CommonSubexprEliminate {}; + let optimizer = CommonSubexprEliminate::new(); let default_config = OptimizerContext::new(); let config = config.unwrap_or(&default_config); let optimized_plan = optimizer.rewrite(plan, config).unwrap(); @@ -1066,51 +1144,147 @@ mod test { #[test] fn id_array_visitor() -> Result<()> { - let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2); + let optimizer = CommonSubexprEliminate::new(); + + let a_plus_1 = col("a") + lit(1); + let avg_c = avg(col("c")); + let sum_a_plus_1 = sum(a_plus_1); + let sum_a_plus_1_minus_avg_c = 
sum_a_plus_1 - avg_c; + let expr = sum_a_plus_1_minus_avg_c * lit(2); + + let Expr::BinaryExpr(BinaryExpr { + left: sum_a_plus_1_minus_avg_c, + .. + }) = &expr + else { + panic!("Cannot extract subexpression reference") + }; + let Expr::BinaryExpr(BinaryExpr { + left: sum_a_plus_1, + right: avg_c, + .. + }) = sum_a_plus_1_minus_avg_c.as_ref() + else { + panic!("Cannot extract subexpression reference") + }; + let Expr::AggregateFunction(AggregateFunction { + args: a_plus_1_vec, .. + }) = sum_a_plus_1.as_ref() + else { + panic!("Cannot extract subexpression reference") + }; + let a_plus_1 = &a_plus_1_vec.as_slice()[0]; // skip aggregates let mut id_array = vec![]; - expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, ExprMask::Normal)?; + optimizer.expr_to_identifier( + &expr, + &mut ExprStats::new(), + &mut id_array, + ExprMask::Normal, + )?; + + // Collect distinct hashes and set them to 0 in `id_array` + fn collect_hashes(id_array: &mut IdArray) -> HashSet { + id_array + .iter_mut() + .flat_map(|(_, expr_id_option)| { + expr_id_option.as_mut().map(|expr_id| { + let hash = expr_id.hash; + expr_id.hash = 0; + hash + }) + }) + .collect::>() + } + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 3); let expected = vec![ - (8, "{(sum(a + Int32(1)) - AVG(c)) * Int32(2)|{Int32(2)}|{sum(a + Int32(1)) - AVG(c)|{AVG(c)|{c}}|{sum(a + Int32(1))|{a + Int32(1)|{Int32(1)}|{a}}}}}"), - (6, "{sum(a + Int32(1)) - AVG(c)|{AVG(c)|{c}}|{sum(a + Int32(1))|{a + Int32(1)|{Int32(1)}|{a}}}}"), - (3, ""), - (2, "{a + Int32(1)|{Int32(1)}|{a}}"), - (0, ""), - (1, ""), - (5, ""), - (4, ""), - (7, "") - ] - .into_iter() - .map(|(number, id)| (number, id.into())) - .collect::>(); + ( + 8, + Some(Identifier { + hash: 0, + expr: &expr, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + expr: sum_a_plus_1_minus_avg_c, + }), + ), + (3, None), + ( + 2, + Some(Identifier { + hash: 0, + expr: a_plus_1, + }), + ), + (0, None), + (1, None), + (5, None), + (4, None), + (7, None), + ]; assert_eq!(expected, id_array); // include aggregates let mut id_array = vec![]; - expr_to_identifier( + optimizer.expr_to_identifier( &expr, - &mut HashMap::new(), + &mut ExprStats::new(), &mut id_array, ExprMask::NormalAndAggregates, )?; + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 5); + let expected = vec![ - (8, "{(sum(a + Int32(1)) - AVG(c)) * Int32(2)|{Int32(2)}|{sum(a + Int32(1)) - AVG(c)|{AVG(c)|{c}}|{sum(a + Int32(1))|{a + Int32(1)|{Int32(1)}|{a}}}}}"), - (6, "{sum(a + Int32(1)) - AVG(c)|{AVG(c)|{c}}|{sum(a + Int32(1))|{a + Int32(1)|{Int32(1)}|{a}}}}"), - (3, "{sum(a + Int32(1))|{a + Int32(1)|{Int32(1)}|{a}}}"), - (2, "{a + Int32(1)|{Int32(1)}|{a}}"), - (0, ""), - (1, ""), - (5, "{AVG(c)|{c}}"), - (4, ""), - (7, "") - ] - .into_iter() - .map(|(number, id)| (number, id.into())) - .collect::>(); + ( + 8, + Some(Identifier { + hash: 0, + expr: &expr, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + expr: sum_a_plus_1_minus_avg_c, + }), + ), + ( + 3, + Some(Identifier { + hash: 0, + expr: sum_a_plus_1, + }), + ), + ( + 2, + Some(Identifier { + hash: 0, + expr: a_plus_1, + }), + ), + (0, None), + (1, None), + ( + 5, + Some(Identifier { + hash: 0, + expr: avg_c, + }), + ), + (4, None), + (7, None), + ]; assert_eq!(expected, id_array); Ok(()) @@ -1211,8 +1385,8 @@ mod test { )? 
.build()?; - let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS AVG(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, AVG(test.b) AS col3, AVG(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\ + let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\ + \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\ \n TableScan: test"; assert_optimized_plan_eq(expected, plan, None); @@ -1230,8 +1404,8 @@ mod test { )? .build()?; - let expected = "Projection: Int32(1) + __common_expr_1 AS AVG(test.a), Int32(1) - __common_expr_1 AS AVG(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\ + let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\ + \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\ \n TableScan: test"; assert_optimized_plan_eq(expected, plan, None); @@ -1247,7 +1421,7 @@ mod test { )? .build()?; - let expected ="Aggregate: groupBy=[[]], aggr=[[AVG(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ + let expected ="Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; @@ -1264,7 +1438,7 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[AVG(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ + let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; @@ -1285,8 +1459,8 @@ mod test { )? 
.build()?; - let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS AVG(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\ - \n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[AVG(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, AVG(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\ + let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\ + \n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\ \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; @@ -1312,8 +1486,8 @@ mod test { )? .build()?; - let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS AVG(UInt32(1) + table.test.col.a), __common_expr_2 AS AVG(UInt32(1) + table.test.col.a)\ - \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[AVG(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\ + let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\ + \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\ \n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\ \n TableScan: table.test"; @@ -1375,27 +1549,35 @@ mod test { Ok(()) } + fn test_identifier(hash: u64, expr: &Expr) -> Identifier { + Identifier { hash, expr } + } + #[test] fn redundant_project_fields() { let table_scan = test_table_scan().unwrap(); + let c_plus_a = col("c") + col("a"); + let b_plus_a = col("b") + col("a"); let common_exprs_1 = CommonExprs::from([ ( - "c+a".to_string(), - (col("c") + col("a"), format!("{CSE_PREFIX}_1")), + test_identifier(0, &c_plus_a), + (c_plus_a.clone(), format!("{CSE_PREFIX}_1")), ), ( - "b+a".to_string(), - (col("b") + col("a"), format!("{CSE_PREFIX}_2")), + test_identifier(1, &b_plus_a), + (b_plus_a.clone(), format!("{CSE_PREFIX}_2")), ), ]); + let c_plus_a_2 = col(format!("{CSE_PREFIX}_1")); + let b_plus_a_2 = col(format!("{CSE_PREFIX}_2")); let common_exprs_2 = CommonExprs::from([ ( - "c+a".to_string(), - (col(format!("{CSE_PREFIX}_1")), format!("{CSE_PREFIX}_3")), + test_identifier(3, &c_plus_a_2), + (c_plus_a_2.clone(), format!("{CSE_PREFIX}_3")), ), ( - "b+a".to_string(), - (col(format!("{CSE_PREFIX}_2")), format!("{CSE_PREFIX}_4")), + test_identifier(4, &b_plus_a_2), + (b_plus_a_2.clone(), format!("{CSE_PREFIX}_4")), ), ]); let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap(); @@ -1416,24 +1598,28 @@ mod test { .unwrap() .build() .unwrap(); + let c_plus_a = col("test1.c") + col("test1.a"); + let b_plus_a = col("test1.b") + col("test1.a"); let common_exprs_1 = CommonExprs::from([ ( - "test1.c+test1.a".to_string(), - (col("test1.c") + col("test1.a"), 
format!("{CSE_PREFIX}_1")), + test_identifier(0, &c_plus_a), + (c_plus_a.clone(), format!("{CSE_PREFIX}_1")), ), ( - "test1.b+test1.a".to_string(), - (col("test1.b") + col("test1.a"), format!("{CSE_PREFIX}_2")), + test_identifier(1, &b_plus_a), + (b_plus_a.clone(), format!("{CSE_PREFIX}_2")), ), ]); + let c_plus_a_2 = col(format!("{CSE_PREFIX}_1")); + let b_plus_a_2 = col(format!("{CSE_PREFIX}_2")); let common_exprs_2 = CommonExprs::from([ ( - "test1.c+test1.a".to_string(), - (col(format!("{CSE_PREFIX}_1")), format!("{CSE_PREFIX}_3")), + test_identifier(3, &c_plus_a_2), + (c_plus_a_2.clone(), format!("{CSE_PREFIX}_3")), ), ( - "test1.b+test1.a".to_string(), - (col(format!("{CSE_PREFIX}_2")), format!("{CSE_PREFIX}_4")), + test_identifier(4, &b_plus_a_2), + (b_plus_a_2.clone(), format!("{CSE_PREFIX}_4")), ), ]); let project = build_common_expr_project_plan(join, common_exprs_1).unwrap(); @@ -1465,7 +1651,7 @@ mod test { .unwrap() .build() .unwrap(); - let rule = CommonSubexprEliminate {}; + let rule = CommonSubexprEliminate::new(); let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); assert!(!optimized_plan.transformed); let optimized_plan = optimized_plan.data; diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index e949e1921b97..5f8e0a85215a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -370,11 +370,14 @@ impl PullUpCorrelatedExpr { } } if let Some(pull_up_having) = &self.pull_up_having_expr { - let filter_apply_columns = pull_up_having.to_columns()?; + let filter_apply_columns = pull_up_having.column_refs(); for col in filter_apply_columns { - let col_expr = Expr::Column(col); - if !missing_exprs.contains(&col_expr) { - missing_exprs.push(col_expr) + // add to missing_exprs if not already there + let contains = missing_exprs + .iter() + .any(|expr| matches!(expr, Expr::Column(c) if c == col)); + if !contains { + missing_exprs.push(Expr::Column(col.clone())) } } } @@ -436,7 +439,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::UDF(fun) => { - if fun.name() == "COUNT" { + if fun.name() == "count" { Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( 0, )))) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index e5e97b693c6a..81d6dc863af6 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -123,14 +123,6 @@ impl DecorrelatePredicateSubquery { } impl OptimizerRule for DecorrelatePredicateSubquery { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called DecorrelatePredicateSubquery::rewrite") - } - fn supports_rewrite(&self) -> bool { true } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 9d871c50ad99..6d6f84373a36 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -77,14 +77,6 @@ impl EliminateCrossJoin { /// /// This fix helps to improve the performance of TPCH Q19. 
issue#78 impl OptimizerRule for EliminateCrossJoin { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateCrossJoin::rewrite") - } - fn supports_rewrite(&self) -> bool { true } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index 3dbfc750e899..e9d091d52b00 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -20,7 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{Aggregate, Expr, Sort}; use indexmap::IndexSet; @@ -63,14 +63,6 @@ impl Hash for SortExprWrapper { } } impl OptimizerRule for EliminateDuplicatedExpr { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateDuplicatedExpr::rewrite") - } - fn apply_order(&self) -> Option { Some(ApplyOrder::TopDown) } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 97f234942182..7c873b411d59 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -18,7 +18,7 @@ //! [`EliminateFilter`] replaces `where false` or `where null` with an empty relation. use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan}; @@ -41,14 +41,6 @@ impl EliminateFilter { } impl OptimizerRule for EliminateFilter { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateFilter::rewrite") - } - fn name(&self) -> &str { "eliminate_filter" } diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 7a8dd7aac249..c7869d9e4dd7 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -20,7 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility}; /// Optimizer rule that removes constant expressions from `GROUP BY` clause @@ -82,14 +82,6 @@ impl OptimizerRule for EliminateGroupByConstant { } } - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateGroupByConstant::rewrite") - } - fn name(&self) -> &str { "eliminate_group_by_constant" } @@ -176,8 +168,8 @@ mod tests { .build()?; let expected = "\ - Projection: test.a, UInt32(1), COUNT(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.c)]]\ + Projection: test.a, UInt32(1), count(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ \n TableScan: test\ "; @@ -196,8 +188,8 @@ mod tests { .build()?; let expected = "\ - Projection: Utf8(\"test\"), 
UInt32(123), COUNT(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(test.c)]]\ + Projection: Utf8(\"test\"), UInt32(123), count(test.c)\ + \n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\ \n TableScan: test\ "; @@ -216,7 +208,7 @@ mod tests { .build()?; let expected = "\ - Aggregate: groupBy=[[test.a, test.b]], aggr=[[COUNT(test.c)]]\ + Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\ \n TableScan: test\ "; @@ -257,8 +249,8 @@ mod tests { .build()?; let expected = "\ - Projection: UInt32(123) AS const, test.a, COUNT(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.c)]]\ + Projection: UInt32(123) AS const, test.a, count(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ \n TableScan: test\ "; @@ -282,8 +274,8 @@ mod tests { .build()?; let expected = "\ - Projection: scalar_fn_mock(UInt32(123)), test.a, COUNT(test.c)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.c)]]\ + Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ \n TableScan: test\ "; @@ -307,7 +299,7 @@ mod tests { .build()?; let expected = "\ - Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[COUNT(test.c)]]\ + Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\ \n TableScan: test\ "; diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index fea87e758790..c5115c87a0ed 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -19,7 +19,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, @@ -38,14 +38,6 @@ impl EliminateJoin { } impl OptimizerRule for EliminateJoin { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateJoin::rewrite") - } - fn name(&self) -> &str { "eliminate_join" } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index bdb821bf3a1f..b0a75fa47c27 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -19,7 +19,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::logical_plan::{tree_node::unwrap_arc, EmptyRelation, LogicalPlan}; /// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is @@ -40,14 +40,6 @@ impl EliminateLimit { } impl OptimizerRule for EliminateLimit { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateLimit::rewrite") - } - fn name(&self) -> &str { "eliminate_limit" } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index aa6f2b497531..09407aed53cd 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -19,7 +19,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use 
datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::{Distinct, LogicalPlan, Union}; use std::sync::Arc; @@ -36,14 +36,6 @@ impl EliminateNestedUnion { } impl OptimizerRule for EliminateNestedUnion { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateNestedUnion::rewrite") - } - fn name(&self) -> &str { "eliminate_nested_union" } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 7763e7d3b796..edf6b72d7e17 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -17,7 +17,7 @@ //! [`EliminateOneUnion`] eliminates single element `Union` use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{internal_err, tree_node::Transformed, Result}; +use datafusion_common::{tree_node::Transformed, Result}; use datafusion_expr::logical_plan::{tree_node::unwrap_arc, LogicalPlan, Union}; use crate::optimizer::ApplyOrder; @@ -34,14 +34,6 @@ impl EliminateOneUnion { } impl OptimizerRule for EliminateOneUnion { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateOneUnion::rewrite") - } - fn name(&self) -> &str { "eliminate_one_union" } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index c3c5d80922f9..ccc637a0eb01 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -17,7 +17,7 @@ //! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{internal_err, Column, DFSchema, Result}; +use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; use datafusion_expr::{Expr, Filter, Operator}; @@ -60,14 +60,6 @@ impl EliminateOuterJoin { /// Attempt to eliminate outer joins. 
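// Illustrative sketch (editor's note, not part of the applied change; `MyRule`
// is a hypothetical type): the `try_optimize` stubs removed from each rule in
// this diff are redundant because `OptimizerRule::try_optimize` now has a
// default body and `supports_rewrite` defaults to `true` (see the
// `optimizer.rs` hunk further down), so a rule only needs to implement
// `rewrite`:
//
//     impl OptimizerRule for MyRule {
//         fn name(&self) -> &str {
//             "my_rule"
//         }
//
//         fn apply_order(&self) -> Option<ApplyOrder> {
//             Some(ApplyOrder::TopDown)
//         }
//
//         fn rewrite(
//             &self,
//             plan: LogicalPlan,
//             _config: &dyn OptimizerConfig,
//         ) -> Result<Transformed<LogicalPlan>> {
//             // Return Transformed::yes(new_plan) only when the plan changed.
//             Ok(Transformed::no(plan))
//         }
//     }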
impl OptimizerRule for EliminateOuterJoin { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateOuterJoin::rewrite") - } - fn name(&self) -> &str { "eliminate_outer_join" } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 237c00352419..87d205139e8e 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -19,8 +19,8 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; +use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_common::{internal_err, DFSchema}; use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; @@ -49,13 +49,6 @@ impl ExtractEquijoinPredicate { } impl OptimizerRule for ExtractEquijoinPredicate { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called ExtractEquijoinPredicate::rewrite") - } fn supports_rewrite(&self) -> bool { true } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index ecd1901abe58..381713662f10 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -20,7 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::utils::conjunction; use datafusion_expr::{ logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan, @@ -35,14 +35,6 @@ use std::sync::Arc; pub struct FilterNullJoinKeys {} impl OptimizerRule for FilterNullJoinKeys { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called FilterNullJoinKeys::rewrite") - } - fn supports_rewrite(&self) -> bool { true } diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 11540d3e162e..2fbf77523bd1 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -70,14 +70,6 @@ impl OptimizeProjections { } impl OptimizerRule for OptimizeProjections { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called OptimizeProjections::rewrite") - } - fn name(&self) -> &str { "optimize_projections" } @@ -171,7 +163,7 @@ fn optimize_projections( // still need to create a correct aggregate, which may be optimized // out later. As an example, consider the following query: // - // SELECT COUNT(*) FROM (SELECT COUNT(*) FROM [...]) + // SELECT count(*) FROM (SELECT count(*) FROM [...]) // // which always returns 1. 
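// Editor's note (illustration only, not part of the applied change; `t` is a
// hypothetical table): an aggregate without GROUP BY produces exactly one row
// even over empty input, so
//
//     SELECT count(*) FROM (SELECT count(*) FROM t WHERE false)
//
// still returns 1. That is why a placeholder aggregate expression must be kept
// here rather than dropping the aggregate when none of its outputs are used.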
if new_aggr_expr.is_empty() @@ -479,10 +471,10 @@ fn merge_consecutive_projections(proj: Projection) -> Result::new(); - for columns in expr.iter().flat_map(|expr| expr.to_columns()) { + let mut column_referral_map = HashMap::<&Column, usize>::new(); + for columns in expr.iter().map(|expr| expr.column_refs()) { for col in columns.into_iter() { - *column_referral_map.entry(col.clone()).or_default() += 1; + *column_referral_map.entry(col).or_default() += 1; } } @@ -493,7 +485,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result 1 && !is_expr_trivial( &prev_projection.expr - [prev_projection.schema.index_of_column(&col).unwrap()], + [prev_projection.schema.index_of_column(col).unwrap()], ) }) { // no change @@ -625,12 +617,12 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { /// * `expr` - The expression to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -fn outer_columns(expr: &Expr, columns: &mut HashSet) { +fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) { // inspect_expr_pre doesn't handle subquery references, so find them explicitly expr.apply(|expr| { match expr { Expr::OuterReferenceColumn(_, col) => { - columns.insert(col.clone()); + columns.insert(col); } Expr::ScalarSubquery(subquery) => { outer_columns_helper_multi(&subquery.outer_ref_columns, columns); @@ -660,9 +652,9 @@ fn outer_columns(expr: &Expr, columns: &mut HashSet) { /// * `exprs` - The expressions to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -fn outer_columns_helper_multi<'a>( +fn outer_columns_helper_multi<'a, 'b>( exprs: impl IntoIterator, - columns: &mut HashSet, + columns: &'b mut HashSet<&'a Column>, ) { exprs.into_iter().for_each(|e| outer_columns(e, columns)); } @@ -1049,9 +1041,9 @@ mod tests { .build() .unwrap(); - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ \n Projection: \ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ \n TableScan: ?table? projection=[]"; assert_optimized_plan_equal(plan, expected) } @@ -1901,7 +1893,7 @@ mod tests { )? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ \n TableScan: test projection=[a, b, c]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs index 113c100bbd9b..3f32a0c36a9a 100644 --- a/datafusion/optimizer/src/optimize_projections/required_indices.rs +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -113,12 +113,12 @@ impl RequiredIndicies { /// * `expr`: An expression for which we want to find necessary field indices. 
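// Editor's sketch (not part of the applied change): like the other call sites
// in this diff, `Expr::to_columns()`, which allocated a `HashSet<Column>` of
// owned columns, is replaced by `Expr::column_refs()`, which borrows
// `&Column` from the expression:
//
//     let cols: HashSet<&Column> = expr.column_refs();
//     for col in cols {
//         if let Some(idx) = input_schema.maybe_index_of_column(col) {
//             indices.push(idx);
//         }
//     }
//
// Schema lookups such as `maybe_index_of_column` take `&Column`, so the
// per-expression clones are no longer needed.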
fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) -> Result<()> { // TODO could remove these clones (and visit the expression directly) - let mut cols = expr.to_columns()?; + let mut cols = expr.column_refs(); // Get outer-referenced (subquery) columns: outer_columns(expr, &mut cols); self.indices.reserve(cols.len()); for col in cols { - if let Some(idx) = input_schema.maybe_index_of_column(&col) { + if let Some(idx) = input_schema.maybe_index_of_column(col) { self.indices.push(idx); } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 7a3ea6ed4cc1..14e5ac141eeb 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -83,9 +83,11 @@ pub trait OptimizerRule { )] fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result>; + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("Should have called rewrite") + } /// A human readable name for this optimizer rule fn name(&self) -> &str; @@ -100,7 +102,7 @@ pub trait OptimizerRule { /// Does this rule support rewriting owned plans (rather than by reference)? fn supports_rewrite(&self) -> bool { - false + true } /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes` @@ -667,14 +669,6 @@ mod tests { struct BadRule {} impl OptimizerRule for BadRule { - fn try_optimize( - &self, - _: &LogicalPlan, - _: &dyn OptimizerConfig, - ) -> Result> { - unreachable!() - } - fn name(&self) -> &str { "bad rule" } @@ -696,14 +690,6 @@ mod tests { struct GetTableScanRule {} impl OptimizerRule for GetTableScanRule { - fn try_optimize( - &self, - _: &LogicalPlan, - _: &dyn OptimizerConfig, - ) -> Result> { - unreachable!() - } - fn name(&self) -> &str { "get table_scan rule" } @@ -741,14 +727,6 @@ mod tests { } impl OptimizerRule for RotateProjectionRule { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _: &dyn OptimizerConfig, - ) -> Result> { - unreachable!() - } - fn name(&self) -> &str { "rotate_projection" } diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index dfcfc313efcc..63b357510f2f 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use datafusion_common::tree_node::Transformed; use datafusion_common::JoinType; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{plan_err, Result}; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{EmptyRelation, Projection, Union}; @@ -41,14 +41,6 @@ impl PropagateEmptyRelation { } impl OptimizerRule for PropagateEmptyRelation { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called PropagateEmptyRelation::rewrite") - } - fn name(&self) -> &str { "propagate_empty_relation" } @@ -96,9 +88,6 @@ impl OptimizerRule for PropagateEmptyRelation { LogicalPlan::Join(ref join) => { // TODO: For Join, more join type need to be careful: - // For LeftAnti Join, if the right side is empty, the Join result is left side(should exclude null ??). - // For RightAnti Join, if the left side is empty, the Join result is right side(should exclude null ??). - // For Full Join, only both sides are empty, the Join result is empty. 
// For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side // columns + right side columns replaced with null values. // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side @@ -106,6 +95,13 @@ impl OptimizerRule for PropagateEmptyRelation { let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; match join.join_type { + // For Full Join, only both sides are empty, the Join result is empty. + JoinType::Full if left_empty && right_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: join.schema.clone(), + }), + )), JoinType::Inner if left_empty || right_empty => Ok(Transformed::yes( LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -142,13 +138,19 @@ impl OptimizerRule for PropagateEmptyRelation { schema: join.schema.clone(), }), )), + JoinType::LeftAnti if right_empty => { + Ok(Transformed::yes((*join.left).clone())) + } + JoinType::RightAnti if left_empty => { + Ok(Transformed::yes((*join.right).clone())) + } JoinType::RightAnti if right_empty => Ok(Transformed::yes( LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: join.schema.clone(), }), )), - _ => Ok(Transformed::no(LogicalPlan::Join(join.clone()))), + _ => Ok(Transformed::no(plan)), } } LogicalPlan::Aggregate(ref agg) => { @@ -475,8 +477,39 @@ mod tests { assert_together_optimized_plan(plan, expected, eq) } + // TODO: fix this long name + fn assert_anti_join_empty_join_table_is_base_table( + anti_left_join: bool, + ) -> Result<()> { + // if we have an anti join with an empty join table, then the result is the base_table + let (left, right, join_type, expected) = if anti_left_join { + let left = test_table_scan()?; + let right = LogicalPlanBuilder::from(test_table_scan()?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + let expected = left.display_indent().to_string(); + (left, right, JoinType::LeftAnti, expected) + } else { + let right = test_table_scan()?; + let left = LogicalPlanBuilder::from(test_table_scan()?) + .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .build()?; + let expected = right.display_indent().to_string(); + (left, right, JoinType::RightAnti, expected) + }; + + let plan = LogicalPlanBuilder::from(left) + .join_using(right, join_type, vec![Column::from_name("a".to_string())])? 
+ .build()?; + + assert_together_optimized_plan(plan, &expected, true) + } + #[test] fn test_join_empty_propagation_rules() -> Result<()> { + // test full join with empty left and empty right + assert_empty_left_empty_right_lp(true, true, JoinType::Full, true)?; + // test left join with empty left assert_empty_left_empty_right_lp(true, false, JoinType::Left, true)?; @@ -499,7 +532,13 @@ mod tests { assert_empty_left_empty_right_lp(true, false, JoinType::LeftAnti, true)?; // test right anti join empty right - assert_empty_left_empty_right_lp(false, true, JoinType::RightAnti, true) + assert_empty_left_empty_right_lp(false, true, JoinType::RightAnti, true)?; + + // test left anti join empty right + assert_anti_join_empty_join_table_is_base_table(true)?; + + // test right anti join empty left + assert_anti_join_empty_join_table_is_base_table(false) } #[test] diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 68339a84649d..fa432ad76de5 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -41,6 +41,7 @@ use datafusion_expr::{ }; use crate::optimizer::ApplyOrder; +use crate::utils::has_all_column_refs; use crate::{OptimizerConfig, OptimizerRule}; /// Optimizer rule for pushing (moving) filter expressions down in a plan so @@ -199,13 +200,7 @@ fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result>(); - let columns = predicate.to_columns()?; - - Ok(schema_columns - .intersection(&columns) - .collect::>() - .len() - == columns.len()) + Ok(has_all_column_refs(predicate, &schema_columns)) } /// Determine whether the predicate can evaluate as the join conditions @@ -372,14 +367,7 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { - let columns = expr.to_columns().ok().unwrap(); - - if schema_columns - .intersection(&columns) - .collect::>() - .len() - == columns.len() - { + if has_all_column_refs(expr, schema_columns) { predicate = Some(expr.clone()); } } @@ -561,12 +549,9 @@ fn infer_join_predicates( .filter_map(|predicate| { let mut join_cols_to_replace = HashMap::new(); - let columns = match predicate.to_columns() { - Ok(columns) => columns, - Err(e) => return Some(Err(e)), - }; + let columns = predicate.column_refs(); - for col in columns.iter() { + for &col in columns.iter() { for (l, r) in join_col_keys.iter() { if col == *l { join_cols_to_replace.insert(col, *r); @@ -596,14 +581,6 @@ fn infer_join_predicates( } impl OptimizerRule for PushDownFilter { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called PushDownFilter::rewrite") - } - fn name(&self) -> &str { "push_down_filter" } @@ -798,7 +775,7 @@ impl OptimizerRule for PushDownFilter { let mut keep_predicates = vec![]; let mut push_predicates = vec![]; for expr in predicates { - let cols = expr.to_columns()?; + let cols = expr.column_refs(); if cols.iter().all(|c| group_expr_columns.contains(c)) { push_predicates.push(expr); } else { @@ -899,7 +876,7 @@ impl OptimizerRule for PushDownFilter { let predicate_push_or_keep = split_conjunction(&filter.predicate) .iter() .map(|expr| { - let cols = expr.to_columns()?; + let cols = expr.column_refs(); if cols.iter().any(|c| prevent_cols.contains(&c.name)) { Ok(false) // No push (keep) } else { diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 6723672ff498..cd2e0b6f5ba2 100644 --- 
a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -24,7 +24,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; @@ -43,14 +43,6 @@ impl PushDownLimit { /// Push down Limit. impl OptimizerRule for PushDownLimit { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called PushDownLimit::rewrite") - } - fn supports_rewrite(&self) -> bool { true } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index b32a88635395..fcd33be618f7 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -20,7 +20,7 @@ use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Column, Result}; +use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder}; @@ -157,14 +157,6 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { } } - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called ReplaceDistinctWithAggregate::rewrite") - } - fn name(&self) -> &str { "replace_distinct_aggregate" } diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index ba865fa1e944..897afda267dc 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -19,7 +19,6 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::internal_err; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::expr::BinaryExpr; @@ -133,14 +132,6 @@ impl RewriteDisjunctivePredicate { } impl OptimizerRule for RewriteDisjunctivePredicate { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called RewriteDisjunctivePredicate::rewrite") - } - fn name(&self) -> &str { "rewrite_disjunctive_predicate" } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 279eca9c912b..0333cc8dde36 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -69,14 +69,6 @@ impl ScalarSubqueryToJoin { } impl OptimizerRule for ScalarSubqueryToJoin { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called ScalarSubqueryToJoin::rewrite") - } - fn supports_rewrite(&self) -> bool { true } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index d15d12b690da..e650d4c09c23 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ 
b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, DFSchema, DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::simplify::SimplifyContext; @@ -48,14 +48,6 @@ use super::ExprSimplifier; pub struct SimplifyExpressions {} impl OptimizerRule for SimplifyExpressions { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called SimplifyExpressions::rewrite") - } - fn name(&self) -> &str { "simplify_expressions" } diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 5da727cb5990..ed3fd75f3efd 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -69,8 +69,8 @@ pub static POWS_OF_TEN: [i128; 38] = [ /// expressions. Such as: (A AND B) AND C pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { - expr_contains(left, needle, search_op) + Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &search_op => { + expr_contains(left, needle, search_op.clone()) || expr_contains(right, needle, search_op) } _ => expr == needle, @@ -88,7 +88,7 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> ) -> Expr { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) - if *op == Operator::BitwiseXor => + if op == &Operator::BitwiseXor => { let left_expr = recursive_delete_xor_in_expr(left, needle, xor_counter); let right_expr = recursive_delete_xor_in_expr(right, needle, xor_counter); @@ -102,7 +102,7 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr::BinaryExpr(BinaryExpr::new( Box::new(left_expr), - *op, + op.clone(), Box::new(right_expr), )) } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index d3d22eb53f39..b3562b7065e1 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -40,12 +40,12 @@ use hashbrown::HashSet; /// single distinct to group by optimizer rule /// ```text /// Before: -/// SELECT a, COUNT(DINSTINCT b), sum(c) +/// SELECT a, count(DINSTINCT b), sum(c) /// FROM t /// GROUP BY a /// /// After: -/// SELECT a, COUNT(alias1), sum(alias2) +/// SELECT a, count(alias1), sum(alias2) /// FROM ( /// SELECT a, b as alias1, sum(c) as alias2 /// FROM t @@ -123,14 +123,6 @@ fn contains_grouping_set(expr: &[Expr]) -> bool { } impl OptimizerRule for SingleDistinctToGroupBy { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called SingleDistinctToGroupBy::rewrite") - } - fn name(&self) -> &str { "single_distinct_aggregation_to_group_by" } @@ -175,7 +167,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // if from parent operators successfully. // Consider plan below. 
// - // Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] // @@ -183,7 +175,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. // If we were to write plan above as below without alias // - // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] // @@ -404,8 +396,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\ + let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\ \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -427,7 +419,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -445,7 +437,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -464,7 +456,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -478,8 +470,8 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? 
.build()?; - let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\ + let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\ \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -495,8 +487,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ + let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -515,7 +507,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -542,8 +534,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64;N, MAX(alias1):UInt32;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -562,7 +554,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -577,8 +569,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64;N]\ + \n Aggregate: 
groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -607,8 +599,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64;N, MAX(alias1):UInt32;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -626,8 +618,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\ + let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -645,8 +637,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\ + let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64;N]\ \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -670,7 +662,7 @@ mod tests { .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? 
.build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -680,7 +672,7 @@ mod tests { fn distinct_with_filter() -> Result<()> { let table_scan = test_table_scan()?; - // COUNT(DISTINCT a) FILTER (WHERE a > 5) + // count(DISTINCT a) FILTER (WHERE a > 5) let expr = count_udaf() .call(vec![col("a")]) .distinct() @@ -690,7 +682,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -713,7 +705,7 @@ mod tests { .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -723,7 +715,7 @@ mod tests { fn distinct_with_order_by() -> Result<()> { let table_scan = test_table_scan()?; - // COUNT(DISTINCT a ORDER BY a) + // count(DISTINCT a ORDER BY a) let expr = count_udaf() .call(vec![col("a")]) .distinct() @@ -733,7 +725,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -743,7 +735,7 @@ mod tests { fn aggregate_with_filter_and_order_by() -> Result<()> { let table_scan = test_table_scan()?; - // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) + // count(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) let expr = count_udaf() .call(vec![col("a")]) .distinct() @@ -754,7 +746,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? 
.build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 07a946c1add9..fb18518fd226 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -82,14 +82,6 @@ impl UnwrapCastInComparison { } impl OptimizerRule for UnwrapCastInComparison { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called UnwrapCastInComparison::rewrite") - } - fn name(&self) -> &str { "unwrap_cast_in_comparison" } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 6218140409b5..0549845430a6 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,7 +17,7 @@ //! Utility functions leveraged by the query optimizer rules -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use crate::{OptimizerConfig, OptimizerRule}; @@ -66,15 +66,26 @@ pub fn optimize_children( } } +/// Returns true if `expr` contains all columns in `schema_cols` +pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { + let column_refs = expr.column_refs(); + // note can't use HashSet::intersect because of different types (owned vs References) + schema_cols + .iter() + .filter(|c| column_refs.contains(c)) + .count() + == column_refs.len() +} + pub(crate) fn collect_subquery_cols( exprs: &[Expr], subquery_schema: DFSchemaRef, ) -> Result> { exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { let mut using_cols: Vec = vec![]; - for col in expr.to_columns()?.into_iter() { - if subquery_schema.has_column(&col) { - using_cols.push(col); + for col in expr.column_refs().into_iter() { + if subquery_schema.has_column(col) { + using_cols.push(col.clone()); } } @@ -166,13 +177,13 @@ pub fn split_conjunction_owned(expr: Expr) -> Vec { /// ]; /// /// // use split_binary_owned to split them -/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); +/// assert_eq!(split_binary_owned(expr, &Operator::Plus), split); /// ``` #[deprecated( since = "34.0.0", note = "use `datafusion_expr::utils::split_binary_owned` instead" )] -pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { +pub fn split_binary_owned(expr: Expr, op: &Operator) -> Vec { expr_utils::split_binary_owned(expr, op) } @@ -183,7 +194,7 @@ pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { since = "34.0.0", note = "use `datafusion_expr::utils::split_binary` instead" )] -pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { +pub fn split_binary<'a>(expr: &'a Expr, op: &Operator) -> Vec<&'a Expr> { expr_utils::split_binary(expr, op) } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs 
b/datafusion/optimizer/tests/optimizer_integration.rs index f60bf6609005..c501d5aaa4bf 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; @@ -64,16 +65,16 @@ fn subquery_filter_with_cast() -> Result<()> { // regression test for https://github.com/apache/datafusion/issues/3760 let sql = "SELECT col_int32 FROM test \ WHERE col_int32 > (\ - SELECT AVG(col_int32) FROM test \ + SELECT avg(col_int32) FROM test \ WHERE col_utf8 BETWEEN '2002-05-08' \ AND (cast('2002-05-08' as date) + interval '5 days')\ )"; let plan = test_sql(sql)?; let expected = "Projection: test.col_int32\ - \n Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.AVG(test.col_int32)\ + \n Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.avg(test.col_int32)\ \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: __scalar_sq_1\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(CAST(test.col_int32 AS Float64))]]\ + \n Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]]\ \n Projection: test.col_int32\ \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ \n TableScan: test projection=[col_int32, col_utf8]"; @@ -187,7 +188,7 @@ fn between_date32_plus_interval() -> Result<()> { WHERE col_date32 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; let plan = test_sql(sql)?; let expected = - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ \n Projection: \ \n Filter: test.col_date32 >= Date32(\"1998-03-18\") AND test.col_date32 <= Date32(\"1998-06-16\")\ \n TableScan: test projection=[col_date32]"; @@ -201,7 +202,7 @@ fn between_date64_plus_interval() -> Result<()> { WHERE col_date64 between '1998-03-18T00:00:00' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; let plan = test_sql(sql)?; let expected = - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ \n Projection: \ \n Filter: test.col_date64 >= Date64(\"1998-03-18\") AND test.col_date64 <= Date64(\"1998-06-16\")\ \n TableScan: test projection=[col_date64]"; @@ -257,8 +258,8 @@ fn join_keys_in_subquery_alias_1() { fn push_down_filter_groupby_expr_contains_alias() { let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3"; let plan = test_sql(sql).unwrap(); - let expected = "Projection: test.col_int32 + test.col_uint32 AS c, COUNT(*)\ - \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\ + let expected = "Projection: test.col_int32 + test.col_uint32 AS c, count(*)\ + \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[count(Int64(1)) AS count(*)]]\ \n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\ \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); @@ -326,7 +327,8 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; let 
context_provider = MyContextProvider::default() .with_udaf(sum_udaf()) - .with_udaf(count_udaf()); + .with_udaf(count_udaf()) + .with_udaf(avg_udaf()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs index 5c888ca66caa..27094b0c819a 100644 --- a/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs @@ -20,7 +20,7 @@ use crate::binary_map::{ArrowBytesSet, OutputType}; use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; use std::fmt::Debug; @@ -47,7 +47,7 @@ impl Accumulator for BytesDistinctCountAccumulator { fn state(&mut self) -> datafusion_common::Result> { let set = self.0.take(); let arr = set.into_state(); - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs index 72b83676e81d..e525118b9a17 100644 --- a/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs @@ -32,7 +32,7 @@ use arrow::array::PrimitiveArray; use arrow::datatypes::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; @@ -72,7 +72,7 @@ where PrimitiveArray::::from_iter_values(self.values.iter().cloned()) .with_data_type(self.data_type.clone()), ); - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } @@ -160,7 +160,7 @@ where let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().map(|v| v.0), )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 432267e045b2..336e28b4d28e 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -211,6 +211,9 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { /// Rewrites [`AggregateExpr`], with new expressions given. The argument should be consistent /// with the return value of the [`AggregateExpr::all_expressions`] method. /// Returns `Some(Arc)` if re-write is supported, otherwise returns `None`. + /// TODO: This method only rewrites the [`PhysicalExpr`]s and does not handle [`Expr`]s. + /// This can cause silent bugs and should be fixed in the future (possibly with physical-to-logical + /// conversions). 
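As a side note on the count_distinct changes above, a minimal sketch of what the switch to array_into_list_array_nullable means for accumulator state: the accumulator's distinct values are wrapped into a single-row nullable ListArray and shipped as one ScalarValue::List state column. array_into_list_array_nullable is the helper used in this diff; the Int32 values and the wrapper function are illustrative assumptions.

    use std::sync::Arc;
    use arrow::array::{ArrayRef, Int32Array};
    use datafusion_common::utils::array_into_list_array_nullable;
    use datafusion_common::ScalarValue;

    // illustrative: how a distinct-count accumulator serializes its values
    fn sketch_state_column() -> ScalarValue {
        let distinct_values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        // one List row per partial accumulator; merge_batch later unpacks it with as_list_array
        ScalarValue::List(Arc::new(array_into_list_array_nullable(distinct_values)))
    }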
fn with_new_expressions( &self, _args: Vec<Arc<dyn PhysicalExpr>>, diff --git a/datafusion/physical-expr-common/src/aggregate/tdigest.rs b/datafusion/physical-expr-common/src/aggregate/tdigest.rs index 5107d0ab8e52..1da3d7180d84 100644 --- a/datafusion/physical-expr-common/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs @@ -576,7 +576,7 @@ impl TDigest { .map(|v| ScalarValue::Float64(Some(v))) .collect(); - let arr = ScalarValue::new_list(&centroids, &DataType::Float64); + let arr = ScalarValue::new_list_nullable(&centroids, &DataType::Float64); vec![ ScalarValue::UInt64(Some(self.max_size as u64)), diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs new file mode 100644 index 000000000000..96c903180ed9 --- /dev/null +++ b/datafusion/physical-expr-common/src/datum.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::BooleanArray; +use arrow::array::{make_comparator, ArrayRef, Datum}; +use arrow::buffer::NullBuffer; +use arrow::compute::SortOptions; +use arrow::error::ArrowError; +use datafusion_common::internal_err; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Operator}; +use std::sync::Arc; + +/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` +/// +/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction +pub fn apply( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>, +) -> Result<ColumnarValue> { + match (&lhs, &rhs) { + (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { + Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) + } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( + ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ), + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( + ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), + ), + (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { + let array = f(&left.to_scalar()?, &right.to_scalar()?)?; + let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } + } +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` +pub fn apply_cmp( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result<BooleanArray, ArrowError>, +) -> Result<ColumnarValue> { + apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like +/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type +pub fn apply_cmp_for_nested( + op: &Operator, + lhs: &ColumnarValue, + rhs: &ColumnarValue, +) -> Result<ColumnarValue> { + if matches!( + op, + Operator::Eq + | 
Operator::NotEq + | Operator::Lt + | Operator::Gt + | Operator::LtEq + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + ) { + apply(lhs, rhs, |l, r| { + Ok(Arc::new(compare_op_for_nested(op, l, r)?)) + }) + } else { + internal_err!("invalid operator for nested") + } +} + +/// Compare on nested type List, Struct, and so on +pub fn compare_op_for_nested( + op: &Operator, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result<BooleanArray> { + let (l, is_l_scalar) = lhs.get(); + let (r, is_r_scalar) = rhs.get(); + let l_len = l.len(); + let r_len = r.len(); + + if l_len != r_len && !is_l_scalar && !is_r_scalar { + return internal_err!("len mismatch"); + } + + let len = match is_l_scalar { + true => r_len, + false => l_len, + }; + + // fast path: if one side is a NULL scalar and the operator is not a 'distinct' variant, we can return a NULL array directly + if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) + && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1) + { + return Ok(BooleanArray::new_null(len)); + } + + // TODO: make SortOptions configurable + // we choose the default behaviour from arrow-rs, which sorts nulls first, following Spark's behaviour + let cmp = make_comparator(l, r, SortOptions::default())?; + + let cmp_with_op = |i, j| match op { + Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(), + Operator::Lt => cmp(i, j).is_lt(), + Operator::Gt => cmp(i, j).is_gt(), + Operator::LtEq => !cmp(i, j).is_gt(), + Operator::GtEq => !cmp(i, j).is_lt(), + Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(), + _ => unreachable!("unexpected operator found"), + }; + + let values = match (is_l_scalar, is_r_scalar) { + (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(), + (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(), + (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(), + (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(), + }; + + // The 'distinct' operators understand how to compare with NULL + // i.e. NULL is distinct from NULL -> false + if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) { + Ok(BooleanArray::new(values, None)) + } else { + // If one of the sides is NULL, we return NULL + // i.e. 
NULL eq NULL -> NULL + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + Ok(BooleanArray::new(values, nulls)) + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{make_comparator, Array, BooleanArray, ListArray}, + buffer::NullBuffer, + compute::SortOptions, + datatypes::Int32Type, + }; + + #[test] + fn test123() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let a = ListArray::from_iter_primitive::(data); + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let b = ListArray::from_iter_primitive::(data); + let cmp = make_comparator(&a, &b, SortOptions::default()).unwrap(); + let len = a.len().min(b.len()); + let values = (0..len).map(|i| cmp(i, i).is_eq()).collect(); + let nulls = NullBuffer::union(a.nulls(), b.nulls()); + println!("res: {:?}", BooleanArray::new(values, nulls)); + } +} diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 0ddb84141a07..8d50e0b964e5 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -17,6 +17,7 @@ pub mod aggregate; pub mod binary_map; +pub mod datum; pub mod expressions; pub mod physical_expr; pub mod sort_expr; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 364940bdee8b..d8dbe636d90c 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -56,7 +56,6 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index a23ba07de44a..c5a0662a2283 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -70,22 +70,23 @@ impl AggregateExpr for ArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )) } fn create_accumulator(&self) -> Result> { Ok(Box::new(ArrayAggAccumulator::try_new( &self.input_data_type, + self.nullable, )?)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )]) } @@ -115,14 +116,16 @@ impl PartialEq for ArrayAgg { pub(crate) struct ArrayAggAccumulator { values: Vec, datatype: DataType, + nullable: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new(datatype: &DataType, nullable: bool) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), + nullable, }) } } @@ -164,12 +167,12 @@ impl Accumulator for ArrayAggAccumulator { self.values.iter().map(|a| a.as_ref()).collect(); if element_arrays.is_empty() { - let arr = ScalarValue::new_list(&[], 
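A hedged usage sketch for the datum.rs helpers above, assuming it sits in the same module as apply_cmp; arrow's cmp::eq kernel already has the &dyn Datum signature these helpers expect, and the sample arrays are illustrative only.

    use std::sync::Arc;
    use arrow::array::Int32Array;
    use arrow::compute::kernels::cmp::eq;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::ColumnarValue;

    // illustrative: compare an array column against a scalar via a Datum kernel
    fn demo_eq() -> Result<ColumnarValue> {
        let lhs = ColumnarValue::Array(Arc::new(Int32Array::from(vec![1, 2, 3])));
        let rhs = ColumnarValue::Scalar(ScalarValue::Int32(Some(2)));
        // the scalar side is converted with to_scalar(), so the kernel broadcasts it
        // against the array side and yields [false, true, false]
        apply_cmp(&lhs, &rhs, eq)
    }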
&self.datatype); + let arr = ScalarValue::new_list(&[], &self.datatype, self.nullable); return Ok(ScalarValue::List(arr)); } let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array(concated_array); + let list_array = array_into_list_array(concated_array, self.nullable); Ok(ScalarValue::List(Arc::new(list_array))) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 244a44acdcb5..fc838196de20 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -74,22 +74,23 @@ impl AggregateExpr for DistinctArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )) } fn create_accumulator(&self) -> Result> { Ok(Box::new(DistinctArrayAggAccumulator::try_new( &self.input_data_type, + self.nullable, )?)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )]) } @@ -119,13 +120,15 @@ impl PartialEq for DistinctArrayAgg { struct DistinctArrayAggAccumulator { values: HashSet, datatype: DataType, + nullable: bool, } impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new(datatype: &DataType, nullable: bool) -> Result { Ok(Self { values: HashSet::new(), datatype: datatype.clone(), + nullable, }) } } @@ -162,7 +165,7 @@ impl Accumulator for DistinctArrayAggAccumulator { fn evaluate(&mut self) -> Result { let values: Vec = self.values.iter().cloned().collect(); - let arr = ScalarValue::new_list(&values, &self.datatype); + let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable); Ok(ScalarValue::List(arr)) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 837a9d551153..1234ab40c188 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -91,8 +91,8 @@ impl AggregateExpr for OrderSensitiveArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )) } @@ -102,6 +102,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { &self.order_by_data_types, self.ordering_req.clone(), self.reverse, + self.nullable, ) .map(|acc| Box::new(acc) as _) } @@ -109,14 +110,18 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn state_fields(&self) -> Result> { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, // This should be the same as field() + Field::new("item", self.input_data_type.clone(), self.nullable), + false, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), - Field::new("item", 
DataType::Struct(Fields::from(orderings)), true), - self.nullable, + Field::new( + "item", + DataType::Struct(Fields::from(orderings)), + self.nullable, + ), + false, )); Ok(fields) } @@ -181,6 +186,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator { ordering_req: LexOrdering, /// Whether the aggregation is running in reverse. reverse: bool, + /// Whether the input expr is nullable + nullable: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -191,6 +198,7 @@ impl OrderSensitiveArrayAggAccumulator { ordering_dtypes: &[DataType], ordering_req: LexOrdering, reverse: bool, + nullable: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -200,6 +208,7 @@ impl OrderSensitiveArrayAggAccumulator { datatypes, ordering_req, reverse, + nullable, }) } } @@ -302,9 +311,17 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { fn evaluate(&mut self) -> Result { let values = self.values.clone(); let array = if self.reverse { - ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0]) + ScalarValue::new_list_from_iter( + values.into_iter().rev(), + &self.datatypes[0], + self.nullable, + ) } else { - ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0]) + ScalarValue::new_list_from_iter( + values.into_iter(), + &self.datatypes[0], + self.nullable, + ) }; Ok(ScalarValue::List(array)) } @@ -362,6 +379,7 @@ impl OrderSensitiveArrayAggAccumulator { )?; Ok(ScalarValue::List(Arc::new(array_into_list_array( Arc::new(ordering_array), + self.nullable, )))) } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 53cfcfb033a1..169418d2daa0 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -33,7 +33,6 @@ use arrow::datatypes::Schema; use datafusion_common::{exec_err, not_impl_err, Result}; use datafusion_expr::AggregateFunction; -use crate::aggregate::average::Avg; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; @@ -108,23 +107,6 @@ pub fn create_aggregate_expr( name, data_type, )), - (AggregateFunction::Avg, false) => { - Arc::new(Avg::new(input_phy_exprs[0].clone(), name, data_type)) - } - (AggregateFunction::Avg, true) => { - return not_impl_err!("AVG(DISTINCT) aggregations are not available"); - } - (AggregateFunction::Correlation, false) => { - Arc::new(expressions::Correlation::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )) - } - (AggregateFunction::Correlation, true) => { - return not_impl_err!("CORR(DISTINCT) aggregations are not available"); - } (AggregateFunction::NthValue, _) => { let expr = &input_phy_exprs[0]; let Some(n) = input_phy_exprs[1] @@ -155,7 +137,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{try_cast, ArrayAgg, Avg, DistinctArrayAgg, Max, Min}; + use crate::expressions::{try_cast, ArrayAgg, DistinctArrayAgg, Max, Min}; use super::*; #[test] @@ -190,7 +172,7 @@ mod tests { Field::new_list( "c1", Field::new("item", data_type.clone(), true), - true, + false, ), result_agg_phy_exprs.field().unwrap() ); @@ -210,7 +192,7 @@ mod tests { Field::new_list( "c1", Field::new("item", data_type.clone(), true), - true, + false, ), result_agg_phy_exprs.field().unwrap() ); @@ -269,92 +251,27 @@ mod tests { Ok(()) } - #[test] - fn test_sum_avg_expr() -> Result<()> { - let funcs = 
vec![AggregateFunction::Avg]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Avg { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ); - }; - } - } - Ok(()) - } - #[test] fn test_min_max() -> Result<()> { - let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?; + let observed = AggregateFunction::Min.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = AggregateFunction::Max.return_type(&[DataType::Int32])?; + let observed = AggregateFunction::Max.return_type(&[DataType::Int32], &[true])?; assert_eq!(DataType::Int32, observed); // test decimal for min - let observed = - AggregateFunction::Min.return_type(&[DataType::Decimal128(10, 6)])?; + let observed = AggregateFunction::Min + .return_type(&[DataType::Decimal128(10, 6)], &[true])?; assert_eq!(DataType::Decimal128(10, 6), observed); // test decimal for max - let observed = - AggregateFunction::Max.return_type(&[DataType::Decimal128(28, 13)])?; + let observed = AggregateFunction::Max + .return_type(&[DataType::Decimal128(28, 13)], &[true])?; assert_eq!(DataType::Decimal128(28, 13), observed); Ok(()) } - #[test] - fn test_avg_return_type() -> Result<()> { - let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Avg.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Avg.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Float64, observed); - - let observed = - AggregateFunction::Avg.return_type(&[DataType::Decimal128(10, 6)])?; - assert_eq!(DataType::Decimal128(14, 10), observed); - - let observed = - AggregateFunction::Avg.return_type(&[DataType::Decimal128(36, 6)])?; - assert_eq!(DataType::Decimal128(38, 10), observed); - Ok(()) - } - - #[test] - fn test_avg_no_utf8() { - let observed = AggregateFunction::Avg.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - // Helper function // Create aggregate expr with type coercion fn create_physical_agg_expr_for_test( diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs deleted file mode 100644 index a47d35053208..000000000000 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ /dev/null @@ -1,524 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::covariance::CovarianceAccumulator; -use crate::aggregate::stats::StatsType; -use crate::aggregate::stddev::StddevAccumulator; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::{ - array::ArrayRef, - compute::{and, filter, is_not_null}, - datatypes::{DataType, Field}, -}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// CORR aggregate expression -#[derive(Debug)] -pub struct Correlation { - name: String, - expr1: Arc, - expr2: Arc, -} - -impl Correlation { - /// Create a new COVAR_POP aggregate function - pub fn new( - expr1: Arc, - expr2: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of correlation just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr1, - expr2, - } - } -} - -impl AggregateExpr for Correlation { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CorrelationAccumulator::try_new()?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "m2_1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "mean2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "m2_2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "algo_const"), - DataType::Float64, - true, - ), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr1.clone(), self.expr2.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Correlation { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2) - }) - .unwrap_or(false) - } -} - -/// An accumulator to compute correlation -#[derive(Debug)] -pub struct CorrelationAccumulator { - covar: CovarianceAccumulator, - stddev1: StddevAccumulator, - stddev2: StddevAccumulator, -} - -impl CorrelationAccumulator { - /// Creates a new `CorrelationAccumulator` - pub fn try_new() -> Result { - Ok(Self { - covar: CovarianceAccumulator::try_new(StatsType::Population)?, - stddev1: StddevAccumulator::try_new(StatsType::Population)?, - stddev2: StddevAccumulator::try_new(StatsType::Population)?, - }) - } -} - -impl Accumulator for CorrelationAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.covar.get_count()), - ScalarValue::from(self.covar.get_mean1()), - 
ScalarValue::from(self.stddev1.get_m2()), - ScalarValue::from(self.covar.get_mean2()), - ScalarValue::from(self.stddev2.get_m2()), - ScalarValue::from(self.covar.get_algo_const()), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // TODO: null input skipping logic duplicated across Correlation - // and its children accumulators. - // This could be simplified by splitting up input filtering and - // calculation logic in children accumulators, and calling only - // calculation part from Correlation - let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { - let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; - let values1 = filter(&values[0], &mask)?; - let values2 = filter(&values[1], &mask)?; - - vec![values1, values2] - } else { - values.to_vec() - }; - - self.covar.update_batch(&values)?; - self.stddev1.update_batch(&values[0..1])?; - self.stddev2.update_batch(&values[1..2])?; - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { - let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; - let values1 = filter(&values[0], &mask)?; - let values2 = filter(&values[1], &mask)?; - - vec![values1, values2] - } else { - values.to_vec() - }; - - self.covar.retract_batch(&values)?; - self.stddev1.retract_batch(&values[0..1])?; - self.stddev2.retract_batch(&values[1..2])?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let states_c = [ - states[0].clone(), - states[1].clone(), - states[3].clone(), - states[5].clone(), - ]; - let states_s1 = [states[0].clone(), states[1].clone(), states[2].clone()]; - let states_s2 = [states[0].clone(), states[3].clone(), states[4].clone()]; - - self.covar.merge_batch(&states_c)?; - self.stddev1.merge_batch(&states_s1)?; - self.stddev2.merge_batch(&states_s2)?; - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let covar = self.covar.evaluate()?; - let stddev1 = self.stddev1.evaluate()?; - let stddev2 = self.stddev2.evaluate()?; - - if let ScalarValue::Float64(Some(c)) = covar { - if let ScalarValue::Float64(Some(s1)) = stddev1 { - if let ScalarValue::Float64(Some(s2)) = stddev2 { - if s1 == 0_f64 || s2 == 0_f64 { - return Ok(ScalarValue::Float64(Some(0_f64))); - } else { - return Ok(ScalarValue::Float64(Some(c / s1 / s2))); - } - } - } - } - - Ok(ScalarValue::Float64(None)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) - + self.covar.size() - - std::mem::size_of_val(&self.stddev1) - + self.stddev1.size() - - std::mem::size_of_val(&self.stddev2) - + self.stddev2.size() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op2; - use arrow::{array::*, datatypes::*}; - - #[test] - fn correlation_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 7_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Correlation, - ScalarValue::from(0.9819805060619659_f64) - ) - } - - #[test] - fn correlation_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, -5_f64, 6_f64])); - - 
generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Correlation, - ScalarValue::from(0.17066403719657236_f64) - ) - } - - #[test] - fn correlation_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_f64_6() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![ - 1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64, - ])); - let b = Arc::new(Float64Array::from(vec![ - 4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64, - ])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Correlation, - ScalarValue::from(0.9860135594710389_f64) - ) - } - - #[test] - fn correlation_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32])); - let b: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 5_u32, 6_u32])); - generic_test_op2!( - a, - b, - DataType::UInt32, - DataType::UInt32, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_f32() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32])); - let b: ArrayRef = Arc::new(Float32Array::from(vec![4_f32, 5_f32, 6_f32])); - generic_test_op2!( - a, - b, - DataType::Float32, - DataType::Float32, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_i32_with_nulls_1() -> Result<()> { - let a: ArrayRef = - Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(3)])); - let b: ArrayRef = - Arc::new(Int32Array::from(vec![Some(4), None, Some(6), Some(3)])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Correlation, - ScalarValue::from(0.1889822365046137_f64) - ) - } - - #[test] - fn correlation_i32_with_nulls_2() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(2), - Some(9), - Some(3), - ])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(4), - Some(5), - Some(5), - None, - Some(6), - ])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Correlation, - ScalarValue::Float64(None) - ) - } - - #[test] - fn correlation_f64_merge_1() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - let c = Arc::new(Float64Array::from(vec![1.1_f64, 2.2_f64, 3.3_f64])); - let d = Arc::new(Float64Array::from(vec![4.4_f64, 5.5_f64, 9.9_f64])); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - ]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; - - let 
agg1 = Arc::new(Correlation::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(Correlation::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(0.8443707186481967)); - - Ok(()) - } - - #[test] - fn correlation_f64_merge_2() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - let c = Arc::new(Float64Array::from(vec![None])); - let d = Arc::new(Float64Array::from(vec![None])); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - ]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; - - let agg1 = Arc::new(Correlation::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(Correlation::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(1_f64)); - - Ok(()) - } - - fn merge( - batch1: &RecordBatch, - batch2: &RecordBatch, - agg1: Arc, - agg2: Arc, - ) -> Result { - let mut accum1 = agg1.create_accumulator()?; - let mut accum2 = agg2.create_accumulator()?; - let expr1 = agg1.expressions(); - let expr2 = agg2.expressions(); - - let values1 = expr1 - .iter() - .map(|e| { - e.evaluate(batch1) - .and_then(|v| v.into_array(batch1.num_rows())) - }) - .collect::>>()?; - let values2 = expr2 - .iter() - .map(|e| { - e.evaluate(batch2) - .and_then(|v| v.into_array(batch2.num_rows())) - }) - .collect::>>()?; - accum1.update_batch(&values1)?; - accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; - accum1.merge_batch(&state2)?; - accum1.evaluate() - } -} diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs deleted file mode 100644 index 639d8a098c01..000000000000 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ /dev/null @@ -1,227 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Defines physical expressions that can evaluated at runtime during query execution - -use arrow::array::Float64Array; -use arrow::{ - array::{ArrayRef, UInt64Array}, - compute::cast, - datatypes::DataType, -}; -use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; - -use crate::aggregate::stats::StatsType; - -/// An accumulator to compute covariance -/// The algrithm used is an online implementation and numerically stable. It is derived from the following paper -/// for calculating variance: -/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". -/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. -/// -/// The algorithm has been analyzed here: -/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". -/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. -/// -/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online, -/// parallelizable and numerically stable. - -#[derive(Debug)] -pub struct CovarianceAccumulator { - algo_const: f64, - mean1: f64, - mean2: f64, - count: u64, - stats_type: StatsType, -} - -impl CovarianceAccumulator { - /// Creates a new `CovarianceAccumulator` - pub fn try_new(s_type: StatsType) -> Result { - Ok(Self { - algo_const: 0_f64, - mean1: 0_f64, - mean2: 0_f64, - count: 0_u64, - stats_type: s_type, - }) - } - - pub fn get_count(&self) -> u64 { - self.count - } - - pub fn get_mean1(&self) -> f64 { - self.mean1 - } - - pub fn get_mean2(&self) -> f64 { - self.mean2 - } - - pub fn get_algo_const(&self) -> f64 { - self.algo_const - } -} - -impl Accumulator for CovarianceAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::from(self.mean1), - ScalarValue::from(self.mean2), - ScalarValue::from(self.algo_const), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); - let new_count = self.count + 1; - let delta1 = value1 - self.mean1; - let new_mean1 = delta1 / new_count as f64 + self.mean1; - let delta2 = value2 - self.mean2; - let new_mean2 = delta2 / new_count as f64 + self.mean2; - let new_c = delta1 * (value2 - new_mean2) + self.algo_const; - - self.count += 1; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - 
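Restated as a compact sketch, the per-pair update implemented in update_batch above; the free function name is illustrative, but the arithmetic mirrors the accumulator being removed here.

    // Welford-style single-pair update over the covariance state
    // (count, mean1, mean2, algo_const)
    fn covar_update(count: &mut u64, mean1: &mut f64, mean2: &mut f64, c: &mut f64, x: f64, y: f64) {
        let new_count = *count + 1;
        let delta1 = x - *mean1;
        let new_mean1 = delta1 / new_count as f64 + *mean1;
        let delta2 = y - *mean2;
        let new_mean2 = delta2 / new_count as f64 + *mean2;
        // the *updated* mean2 is used here, which keeps the running sum numerically stable
        *c += delta1 * (y - new_mean2);
        *count = new_count;
        *mean1 = new_mean1;
        *mean2 = new_mean2;
    }
    // population covariance = c / count; sample covariance = c / (count - 1)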
arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); - - let new_count = self.count - 1; - let delta1 = self.mean1 - value1; - let new_mean1 = delta1 / new_count as f64 + self.mean1; - let delta2 = self.mean2 - value2; - let new_mean2 = delta2 / new_count as f64 + self.mean2; - let new_c = self.algo_const - delta1 * (new_mean2 - value2); - - self.count -= 1; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means1 = downcast_value!(states[1], Float64Array); - let means2 = downcast_value!(states[2], Float64Array); - let cs = downcast_value!(states[3], Float64Array); - - for i in 0..counts.len() { - let c = counts.value(i); - if c == 0_u64 { - continue; - } - let new_count = self.count + c; - let new_mean1 = self.mean1 * self.count as f64 / new_count as f64 - + means1.value(i) * c as f64 / new_count as f64; - let new_mean2 = self.mean2 * self.count as f64 / new_count as f64 - + means2.value(i) * c as f64 / new_count as f64; - let delta1 = self.mean1 - means1.value(i); - let delta2 = self.mean2 - means2.value(i); - let new_c = self.algo_const - + cs.value(i) - + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64; - - self.count = new_count; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let count = match self.stats_type { - StatsType::Population => self.count, - StatsType::Sample => { - if self.count > 0 { - self.count - 1 - } else { - self.count - } - } - }; - - if count == 0 { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(Some(self.algo_const / count as f64))) - } - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index 73d810ec056d..1944e2b2d415 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -19,6 +19,7 @@ mod adapter; pub use adapter::GroupsAccumulatorAdapter; // Backward compatibility +#[allow(unused_imports)] pub(crate) mod accumulate { pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index a6d5054ec170..8d07f0df0742 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -24,18 +24,20 @@ use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use crate::{AggregateExpr, PhysicalExpr}; use arrow::compute; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + DataType, Date32Type, Date64Type, IntervalUnit, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, + TimestampMicrosecondType, TimestampMillisecondType, 
TimestampNanosecondType, + TimestampSecondType, }; use arrow::{ array::{ ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, - LargeStringArray, StringArray, Time32MillisecondArray, Time32SecondArray, - Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + LargeBinaryArray, LargeStringArray, StringArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, datatypes::Field, }; @@ -408,6 +410,25 @@ macro_rules! min_max_batch { $OP ) } + DataType::Interval(IntervalUnit::YearMonth) => { + typed_min_max_batch!( + $VALUES, + IntervalYearMonthArray, + IntervalYearMonth, + $OP + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + typed_min_max_batch!( + $VALUES, + IntervalMonthDayNanoArray, + IntervalMonthDayNano, + $OP + ) + } other => { // This should have been handled before return internal_err!( @@ -1121,6 +1142,108 @@ impl Accumulator for SlidingMinAccumulator { #[cfg(test)] mod tests { use super::*; + use arrow::datatypes::{ + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + }; + + #[test] + fn interval_min_max() { + // IntervalYearMonth + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(0, 1), + IntervalYearMonthType::make_value(5, 34), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(7, -4), + IntervalYearMonthType::make_value(0, 1), + ]); + let b: ArrayRef = Arc::new(b); + + let mut min = + MinAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth)) + .unwrap(); + min.update_batch(&[b.clone()]).unwrap(); + let min_res = min.evaluate().unwrap(); + assert_eq!( + min_res, + ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value( + -2, 4 + ))) + ); + + let mut max = + MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth)) + .unwrap(); + max.update_batch(&[b.clone()]).unwrap(); + let max_res = max.evaluate().unwrap(); + assert_eq!( + max_res, + ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value( + 5, 34 + ))) + ); + + // IntervalDayTime + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(0, 0), + IntervalDayTimeType::make_value(5, 454000), + IntervalDayTimeType::make_value(-34, 0), + IntervalDayTimeType::make_value(7, -4000), + IntervalDayTimeType::make_value(1, 0), + ]); + let b: ArrayRef = Arc::new(b); + + let mut min = + MinAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap(); + min.update_batch(&[b.clone()]).unwrap(); + let min_res = min.evaluate().unwrap(); + assert_eq!( + min_res, + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(-34, 0))) + ); + + let mut max = + MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap(); + max.update_batch(&[b.clone()]).unwrap(); + let max_res = max.evaluate().unwrap(); + assert_eq!( + max_res, + 
ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(7, -4000))) + ); + + // IntervalMonthDayNano + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(1, 0, 0), + IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000), + IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000), + IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000), + IntervalMonthDayNanoType::make_value(1, 0, 0), + ]); + let b: ArrayRef = Arc::new(b); + + let mut min = + MinAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano)) + .unwrap(); + min.update_batch(&[b.clone()]).unwrap(); + let min_res = min.evaluate().unwrap(); + assert_eq!( + min_res, + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000) + )) + ); + + let mut max = + MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano)) + .unwrap(); + max.update_batch(&[b.clone()]).unwrap(); + let max_res = max.evaluate().unwrap(); + assert_eq!( + max_res, + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000) + )) + ); + } #[test] fn float_min_max_with_nans() { diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index f64c5b1fb260..ca5bf3293442 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -20,17 +20,12 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; -pub(crate) mod average; -pub(crate) mod correlation; -pub(crate) mod covariance; pub(crate) mod grouping; pub(crate) mod nth_value; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; pub(crate) mod stats; -pub(crate) mod stddev; -pub(crate) mod variance; pub mod build_in; pub mod moving_min_max; diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/physical-expr/src/aggregate/nth_value.rs index ee7426a897b3..f6d25348f222 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/physical-expr/src/aggregate/nth_value.rs @@ -32,7 +32,7 @@ use crate::{ use arrow_array::cast::AsArray; use arrow_array::{new_empty_array, ArrayRef, StructArray}; use arrow_schema::{DataType, Field, Fields}; -use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; +use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::Accumulator; @@ -393,7 +393,7 @@ impl NthValueAccumulator { None, )?; - Ok(ScalarValue::List(Arc::new(array_into_list_array( + Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( Arc::new(ordering_array), )))) } @@ -401,7 +401,10 @@ impl NthValueAccumulator { fn evaluate_values(&self) -> ScalarValue { let mut values_cloned = self.values.clone(); let values_slice = values_cloned.make_contiguous(); - ScalarValue::List(ScalarValue::new_list(values_slice, &self.datatypes[0])) + ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatypes[0], + )) } /// Updates state, with the `values`. 
Fetch contains missing number of entries for state to be complete diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs deleted file mode 100644 index 3ade67b51905..000000000000 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ /dev/null @@ -1,87 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use arrow::array::ArrayRef; - -use datafusion_common::ScalarValue; -use datafusion_common::{internal_err, Result}; -use datafusion_expr::Accumulator; - -use crate::aggregate::stats::StatsType; -use crate::aggregate::variance::VarianceAccumulator; - -/// An accumulator to compute the average -#[derive(Debug)] -pub struct StddevAccumulator { - variance: VarianceAccumulator, -} - -impl StddevAccumulator { - /// Creates a new `StddevAccumulator` - pub fn try_new(s_type: StatsType) -> Result { - Ok(Self { - variance: VarianceAccumulator::try_new(s_type)?, - }) - } - - pub fn get_m2(&self) -> f64 { - self.variance.get_m2() - } -} - -impl Accumulator for StddevAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.variance.get_count()), - ScalarValue::from(self.variance.get_mean()), - ScalarValue::from(self.variance.get_m2()), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.variance.update_batch(values) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.variance.retract_batch(values) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.variance.merge_batch(states) - } - - fn evaluate(&mut self) -> Result { - let variance = self.variance.evaluate()?; - match variance { - ScalarValue::Float64(e) => { - if e.is_none() { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) - } - } - _ => internal_err!("Variance should be f64"), - } - } - - fn size(&self) -> usize { - std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) - + self.variance.size() - } -} diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs deleted file mode 100644 index 27c67a2f9c7c..000000000000 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ /dev/null @@ -1,176 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::stats::StatsType; -use arrow::array::Float64Array; -use arrow::{ - array::{ArrayRef, UInt64Array}, - compute::cast, - datatypes::DataType, -}; -use datafusion_common::downcast_value; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::Accumulator; - -// TODO only holds the definition of `VarianceAccumulator` for use by `StddevAccumulator` in `physical-expr`, -// which in turn only has it there for legacy `CorrelationAccumulator`, but this whole file should go away -// once the latter is moved to `functions-aggregate`. - -/// An accumulator to compute variance -/// The algrithm used is an online implementation and numerically stable. It is based on this paper: -/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". -/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. -/// -/// The algorithm has been analyzed here: -/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". -/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. - -#[derive(Debug)] -pub struct VarianceAccumulator { - m2: f64, - mean: f64, - count: u64, - stats_type: StatsType, -} - -impl VarianceAccumulator { - /// Creates a new `VarianceAccumulator` - pub fn try_new(s_type: StatsType) -> Result { - Ok(Self { - m2: 0_f64, - mean: 0_f64, - count: 0_u64, - stats_type: s_type, - }) - } - - pub fn get_count(&self) -> u64 { - self.count - } - - pub fn get_mean(&self) -> f64 { - self.mean - } - - pub fn get_m2(&self) -> f64 { - self.m2 - } -} - -impl Accumulator for VarianceAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::from(self.mean), - ScalarValue::from(self.m2), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { - let new_count = self.count + 1; - let delta1 = value - self.mean; - let new_mean = delta1 / new_count as f64 + self.mean; - let delta2 = value - new_mean; - let new_m2 = self.m2 + delta1 * delta2; - - self.count += 1; - self.mean = new_mean; - self.m2 = new_m2; - } - - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { - let new_count = self.count - 1; - let delta1 = self.mean - value; - let new_mean = delta1 / new_count as f64 + self.mean; - let delta2 = new_mean - value; - let new_m2 = self.m2 - delta1 * delta2; - - self.count -= 1; - self.mean = new_mean; - self.m2 = new_m2; - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means = downcast_value!(states[1], Float64Array); - let m2s = 
downcast_value!(states[2], Float64Array); - - for i in 0..counts.len() { - let c = counts.value(i); - if c == 0_u64 { - continue; - } - let new_count = self.count + c; - let new_mean = self.mean * self.count as f64 / new_count as f64 - + means.value(i) * c as f64 / new_count as f64; - let delta = self.mean - means.value(i); - let new_m2 = self.m2 - + m2s.value(i) - + delta * delta * self.count as f64 * c as f64 / new_count as f64; - - self.count = new_count; - self.mean = new_mean; - self.m2 = new_m2; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let count = match self.stats_type { - StatsType::Population => self.count, - StatsType::Sample => { - if self.count > 0 { - self.count - 1 - } else { - self.count - } - } - }; - - Ok(ScalarValue::Float64(match self.count { - 0 => None, - 1 => { - if let StatsType::Population = self.stats_type { - Some(0.0) - } else { - None - } - } - _ => Some(self.m2 / count as f64), - })) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 98df0cba9f3e..d19279c20d10 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,7 +20,6 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; @@ -40,6 +39,7 @@ use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; +use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, @@ -265,6 +265,13 @@ impl PhysicalExpr for BinaryExpr { let schema = batch.schema(); let input_schema = schema.as_ref(); + if left_data_type.is_nested() { + if right_data_type != left_data_type { + return internal_err!("type mismatch"); + } + return apply_cmp_for_nested(&self.op, &lhs, &rhs); + } + match self.op { Operator::Plus => return apply(&lhs, &rhs, add_wrapping), Operator::Minus => return apply(&lhs, &rhs, sub_wrapping), @@ -322,7 +329,7 @@ impl PhysicalExpr for BinaryExpr { ) -> Result> { Ok(Arc::new(BinaryExpr::new( children[0].clone(), - self.op, + self.op.clone(), children[1].clone(), ))) } diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs deleted file mode 100644 index 2bb79922cfec..000000000000 --- a/datafusion/physical-expr/src/expressions/datum.rs +++ /dev/null @@ -1,58 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::{ArrayRef, Datum}; -use arrow::error::ArrowError; -use arrow_array::BooleanArray; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::ColumnarValue; -use std::sync::Arc; - -/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` -/// -/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction -pub(crate) fn apply( - lhs: &ColumnarValue, - rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - match (&lhs, &rhs) { - (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { - Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) - } - (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( - ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), - ), - (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( - ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), - ), - (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { - let array = f(&left.to_scalar()?, &right.to_scalar()?)?; - let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; - Ok(ColumnarValue::Scalar(scalar)) - } - } -} - -/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` -pub(crate) fn apply_cmp( - lhs: &ColumnarValue, - rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) -} diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index eec347db8ed8..e0c02b0a90e9 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -20,11 +20,11 @@ use std::{any::Any, sync::Arc}; use crate::{physical_expr::down_cast_any_ref, PhysicalExpr}; -use crate::expressions::datum::apply_cmp; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::datum::apply_cmp; // Like expression #[derive(Debug, Hash)] @@ -148,6 +148,14 @@ impl PartialEq for LikeExpr { } } +/// used for optimize Dictionary like +fn can_like_type(from_type: &DataType) -> bool { + match from_type { + DataType::Dictionary(_, inner_type_from) => **inner_type_from == DataType::Utf8, + _ => false, + } +} + /// Create a like expression, erroring if the argument types are not compatible. 
pub fn like( negated: bool, @@ -158,7 +166,7 @@ pub fn like( ) -> Result> { let expr_type = &expr.data_type(input_schema)?; let pattern_type = &pattern.data_type(input_schema)?; - if !expr_type.eq(pattern_type) { + if !expr_type.eq(pattern_type) && !can_like_type(expr_type) { return internal_err!( "The type of {expr_type} AND {pattern_type} of like physical should be same" ); diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 0020aa5f55b2..b87b6daa64c7 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -21,7 +21,6 @@ mod binary; mod case; mod column; -mod datum; mod in_list; mod is_not_null; mod is_null; @@ -38,10 +37,7 @@ pub mod helpers { pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; -pub use crate::aggregate::average::Avg; -pub use crate::aggregate::average::AvgAccumulator; pub use crate::aggregate::build_in::create_aggregate_expr; -pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; @@ -58,7 +54,6 @@ pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; pub use column::UnKnownColumn; pub use datafusion_expr::utils::format_state_name; -pub use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 9c7d6d09349d..e33c28df1988 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -38,14 +38,8 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; pub use crate::scalar_function::create_physical_expr; - -#[derive(Debug, Clone, Copy)] -pub enum Hint { - /// Indicates the argument needs to be padded if it is scalar - Pad, - /// Indicates the argument can be converted to an array of length 1 - AcceptsSingular, -} +// For backward compatibility +pub use datafusion_expr::function::Hint; #[deprecated(since = "36.0.0", note = "Use ColumarValue::values_to_arrays instead")] pub fn columnar_values_to_array(args: &[ColumnarValue]) -> Result> { diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 5ba628e7ce40..6fbcd461af66 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -222,7 +222,7 @@ pub fn propagate_arithmetic( left_child: &Interval, right_child: &Interval, ) -> Result> { - let inverse_op = get_inverse_op(*op)?; + let inverse_op = get_inverse_op(op)?; match (left_child.data_type(), right_child.data_type()) { // If we have a child whose type is a time interval (i.e. 
DataType::Interval), // we need special handling since timestamp differencing results in a diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index b426a656fba9..37527802f84d 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -63,7 +63,7 @@ pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { } // This function returns the inverse operator of the given operator. -pub fn get_inverse_op(op: Operator) -> Result { +pub fn get_inverse_op(op: &Operator) -> Result { match op { Operator::Plus => Ok(Operator::Minus), Operator::Minus => Ok(Operator::Plus), diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index fcb3278b6022..273c77fb1d5e 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -152,6 +152,8 @@ impl Partitioning { match required { Distribution::UnspecifiedDistribution => true, Distribution::SinglePartition if self.partition_count() == 1 => true, + // When partition count is 1, hash requirement is satisfied. + Distribution::HashPartitioned(_) if self.partition_count() == 1 => true, Distribution::HashPartitioned(required_exprs) => { match self { // Here we do not check the partition count for hash partitioning and assumes the partition count @@ -290,7 +292,7 @@ mod tests { assert_eq!(result, (true, false, false, false, false)) } Distribution::HashPartitioned(_) => { - assert_eq!(result, (false, false, false, true, false)) + assert_eq!(result, (true, false, false, true, false)) } } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 9e8561eb68c5..8fe99cdca591 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,10 +17,15 @@ use std::sync::Arc; -use arrow::datatypes::Schema; +use crate::scalar_function; +use crate::{ + expressions::{self, binary, like, Column, Literal}, + PhysicalExpr, +}; +use arrow::datatypes::Schema; use datafusion_common::{ - exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, + exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; @@ -28,12 +33,6 @@ use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{binary_expr, Between, BinaryExpr, Expr, Like, Operator, TryCast}; -use crate::scalar_function; -use crate::{ - expressions::{self, binary, like, Column, Literal}, - PhysicalExpr, -}; - /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 /// AS int)`. /// @@ -196,7 +195,7 @@ pub fn create_physical_expr( // // There should be no coercion during physical // planning. 
- binary(lhs, *op, rhs, input_schema) + binary(lhs, op.clone(), rhs, input_schema) } Expr::Like(Like { negated, @@ -358,6 +357,13 @@ where .collect::>>() } +/// Convert a logical expression to a physical expression (without any simplification, etc) +pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, &execution_props).unwrap() +} + #[cfg(test)] mod tests { use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index deaff5453848..070034116fb4 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -419,15 +419,16 @@ impl<'a> ColOpLit<'a> { #[cfg(test)] mod test { + use std::sync::OnceLock; + use super::*; - use crate::create_physical_expr; + use crate::planner::logical2physical; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::ToDFSchema; - use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_fn::*; use datafusion_expr::{lit, Expr}; + use itertools::Itertools; - use std::sync::OnceLock; #[test] fn test_literal() { @@ -867,13 +868,6 @@ mod test { LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap() } - /// Convert a logical expression to a physical expression (without any simplification, etc) - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() - } - // Schema for testing fn schema() -> SchemaRef { SCHEMA diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 005d834552f9..492cb02941df 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -44,7 +44,7 @@ use petgraph::stable_graph::StableGraph; pub fn split_conjunction( predicate: &Arc, ) -> Vec<&Arc> { - split_impl(Operator::And, predicate, vec![]) + split_impl(&Operator::And, predicate, vec![]) } /// Assume the predicate is in the form of DNF, split the predicate to a Vec of PhysicalExprs. 
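(Aside, not part of the diff: the two hunks around this point change `split_conjunction`/`split_disjunction` to pass the `Operator` by reference into `split_impl`. For readers unfamiliar with these helpers, the sketch below shows the recursion `split_impl` performs, using a simplified toy `Expr` enum and `Op` type in place of DataFusion's `PhysicalExpr`/`Operator`; all names in it are illustrative only. The idea is that splitting on `And` flattens `a AND b AND c` into `[a, b, c]`.)

```rust
// Toy model of split_impl: collect the operands of a given operator by
// recursing while the current node uses that operator.
#[allow(dead_code)]
#[derive(PartialEq)]
enum Op {
    And,
    Or,
}

enum Expr {
    Binary { op: Op, left: Box<Expr>, right: Box<Expr> },
    Column(&'static str),
}

fn split<'a>(operator: &Op, predicate: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
    match predicate {
        // Same operator: keep descending into both children.
        Expr::Binary { op, left, right } if op == operator => {
            let exprs = split(operator, left, exprs);
            split(operator, right, exprs)
        }
        // Anything else is a leaf of the conjunction/disjunction.
        other => {
            exprs.push(other);
            exprs
        }
    }
}

fn main() {
    // (a AND b) AND c  ->  three conjuncts
    let e = Expr::Binary {
        op: Op::And,
        left: Box::new(Expr::Binary {
            op: Op::And,
            left: Box::new(Expr::Column("a")),
            right: Box::new(Expr::Column("b")),
        }),
        right: Box::new(Expr::Column("c")),
    };
    assert_eq!(split(&Op::And, &e, vec![]).len(), 3);
}
```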
@@ -53,16 +53,16 @@ pub fn split_conjunction( pub fn split_disjunction( predicate: &Arc, ) -> Vec<&Arc> { - split_impl(Operator::Or, predicate, vec![]) + split_impl(&Operator::Or, predicate, vec![]) } fn split_impl<'a>( - operator: Operator, + operator: &Operator, predicate: &'a Arc, mut exprs: Vec<&'a Arc>, ) -> Vec<&'a Arc> { match predicate.as_any().downcast_ref::() { - Some(binary) if binary.op() == &operator => { + Some(binary) if binary.op() == operator => { let exprs = split_impl(operator, binary.left(), exprs); split_impl(operator, binary.right(), exprs) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index b7d8d60f4f35..2bf32e8d7084 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -675,7 +675,7 @@ impl ExecutionPlan for AggregateExec { vec![Distribution::UnspecifiedDistribution] } AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { - vec![Distribution::HashPartitioned(self.output_group_expr())] + vec![Distribution::HashPartitioned(self.group_by.input_exprs())] } AggregateMode::Final | AggregateMode::Single => { vec![Distribution::SinglePartition] @@ -1177,7 +1177,7 @@ mod tests { use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; - use crate::expressions::{col, Avg}; + use crate::expressions::col; use crate::memory::MemoryExec; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; @@ -1194,11 +1194,11 @@ mod tests { use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::expr::Sort; + use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; use datafusion_functions_aggregate::median::median_udaf; - use datafusion_physical_expr::expressions::{ - lit, FirstValue, LastValue, OrderSensitiveArrayAgg, - }; + use datafusion_physical_expr::expressions::{lit, OrderSensitiveArrayAgg}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_expr_common::aggregate::create_aggregate_expr; @@ -1485,11 +1485,17 @@ mod tests { groups: vec![vec![false]], }; - let aggregates: Vec> = vec![Arc::new(Avg::new( - col("b", &input_schema)?, - "AVG(b)".to_string(), - DataType::Float64, - ))]; + let aggregates: Vec> = vec![create_aggregate_expr( + &avg_udaf(), + &[col("b", &input_schema)?], + &[datafusion_expr::col("b")], + &[], + &[], + &input_schema, + "AVG(b)", + false, + false, + )?]; let task_ctx = if spill { // set to an appropriate value to trigger spill @@ -1819,11 +1825,17 @@ mod tests { vec![test_median_agg_expr(&input_schema)?]; // use fast-path in `row_hash.rs`. 
- let aggregates_v2: Vec> = vec![Arc::new(Avg::new( - col("b", &input_schema)?, - "AVG(b)".to_string(), - DataType::Float64, - ))]; + let aggregates_v2: Vec> = vec![create_aggregate_expr( + &avg_udaf(), + &[col("b", &input_schema)?], + &[datafusion_expr::col("b")], + &[], + &[], + &input_schema, + "AVG(b)", + false, + false, + )?]; for (version, groups, aggregates) in [ (0, groups_none, aggregates_v0), @@ -1873,15 +1885,21 @@ mod tests { async fn test_drop_cancel_without_groups() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)])); let groups = PhysicalGroupBy::default(); - let aggregates: Vec> = vec![Arc::new(Avg::new( - col("a", &schema)?, - "AVG(a)".to_string(), - DataType::Float64, - ))]; + let aggregates: Vec> = vec![create_aggregate_expr( + &avg_udaf(), + &[col("a", &schema)?], + &[datafusion_expr::col("a")], + &[], + &[], + &schema, + "AVG(a)", + false, + false, + )?]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -1908,18 +1926,24 @@ mod tests { async fn test_drop_cancel_with_groups() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, true), - Field::new("b", DataType::Float32, true), + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Float64, true), ])); let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let aggregates: Vec> = vec![Arc::new(Avg::new( - col("b", &schema)?, - "AVG(b)".to_string(), - DataType::Float64, - ))]; + let aggregates: Vec> = vec![create_aggregate_expr( + &avg_udaf(), + &[col("b", &schema)?], + &[datafusion_expr::col("b")], + &[], + &[], + &schema, + "AVG(b)", + false, + false, + )?]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 5353092d5c45..b2f9ef560745 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -73,6 +73,8 @@ use datafusion_physical_expr::expressions::UnKnownColumn; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; use ahash::RandomState; +use datafusion_expr::Operator; +use datafusion_physical_expr_common::datum::compare_op_for_nested; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; @@ -1210,6 +1212,12 @@ fn eq_dyn_null( right: &dyn Array, null_equals_null: bool, ) -> Result { + // Nested datatypes cannot use the underlying not_distinct function and must use a special + // implementation + // + if left.data_type().is_nested() && null_equals_null { + return Ok(compare_op_for_nested(&Operator::Eq, &left, &right)?); + } match (left.data_type(), right.data_type()) { _ if null_equals_null => not_distinct(&left, &right), _ => eq(&left, &right), diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 420fab51da39..91b2151d32e7 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1532,17 +1532,21 @@ fn get_filtered_join_mask( for i in 0..streamed_indices_length { // LeftSemi respects only first true values for specific streaming index, // others 
true values for the same index must be false - if mask.value(i) && !seen_as_true { + let streamed_idx = streamed_indices.value(i); + if mask.value(i) + && !seen_as_true + && !matched_indices.contains(&streamed_idx) + { seen_as_true = true; corrected_mask.append_value(true); - filter_matched_indices.push(streamed_indices.value(i)); + filter_matched_indices.push(streamed_idx); } else { corrected_mask.append_value(false); } // if switched to next streaming index(e.g. from 0 to 1, or from 1 to 2), we reset seen_as_true flag if i < streamed_indices_length - 1 - && streamed_indices.value(i) != streamed_indices.value(i + 1) + && streamed_idx != streamed_indices.value(i + 1) { seen_as_true = false; } @@ -2940,6 +2944,20 @@ mod tests { )) ); + assert_eq!( + get_filtered_join_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &BooleanArray::from(vec![true, false, false, false, false, true]), + &HashSet::from_iter(vec![1]), + &0, + ), + Some(( + BooleanArray::from(vec![true, false, false, false, false, false]), + vec![0] + )) + ); + Ok(()) } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index bd77814bbbe4..c648547c98b1 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -116,12 +116,13 @@ pub mod udaf { /// [`required_input_ordering`]: ExecutionPlan::required_input_ordering pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// Short name for the ExecutionPlan, such as 'ParquetExec'. - fn name(&self) -> &'static str - where - Self: Sized, - { - Self::static_name() - } + /// + /// Implementation note: this method can just proxy to + /// [`static_name`](ExecutionPlan::static_name) if no special action is + /// needed. It doesn't provide a default implementation like that because + /// this method doesn't require the `Sized` constrain to allow a wilder + /// range of use cases. + fn name(&self) -> &str; /// Short name for the ExecutionPlan, such as 'ParquetExec'. /// Like [`name`](ExecutionPlan::name) but can be called without an instance. @@ -829,6 +830,10 @@ mod tests { } impl ExecutionPlan for EmptyExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -881,6 +886,10 @@ mod tests { } impl ExecutionPlan for RenamedEmptyExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn static_name() -> &'static str where Self: Sized, @@ -931,6 +940,14 @@ mod tests { assert_eq!(renamed_exec.name(), "MyRenamedEmptyExec"); assert_eq!(RenamedEmptyExec::static_name(), "MyRenamedEmptyExec"); } + + /// A compilation test to ensure that the `ExecutionPlan::name()` method can + /// be called from a trait object. 
+ /// Related ticket: https://github.com/apache/datafusion/pull/11047 + #[allow(dead_code)] + fn use_execution_plan_as_trait_object(plan: &dyn ExecutionPlan) { + let _ = plan.name(); + } } pub mod test; diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index d5ad9292b49d..ad47a484c9f7 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -177,6 +177,10 @@ impl DisplayAs for MockExec { } impl ExecutionPlan for MockExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -335,6 +339,10 @@ impl DisplayAs for BarrierExec { } impl ExecutionPlan for BarrierExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -444,6 +452,10 @@ impl DisplayAs for ErrorExec { } impl ExecutionPlan for ErrorExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -527,6 +539,10 @@ impl DisplayAs for StatisticsExec { } impl ExecutionPlan for StatisticsExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -619,6 +635,10 @@ impl DisplayAs for BlockingExec { } impl ExecutionPlan for BlockingExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -760,6 +780,10 @@ impl DisplayAs for PanicExec { } impl ExecutionPlan for PanicExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index fc60ab997375..9eb29891703e 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1653,7 +1653,7 @@ mod tests { // // Effectively following query is run on this data // - // SELECT *, COUNT(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) + // SELECT *, count(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) // FROM test; // // partition `duplicated_hash=2` receives following data from the input @@ -1727,8 +1727,8 @@ mod tests { let plan = projection_exec(window)?; let expected_plan = vec![ - "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, COUNT([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]", - " BoundedWindowAggExec: wdw=[COUNT([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Ok(Field { name: \"COUNT([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", + "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: 
\"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]", + " BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Ok(Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", " StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]", ]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index ecfe123a43af..181c30800434 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -65,7 +65,11 @@ pub fn schema_add_window_field( .iter() .map(|e| e.clone().as_ref().data_type(schema)) .collect::>>()?; - let window_expr_return_type = window_fn.return_type(&data_types)?; + let nullability = args + .iter() + .map(|e| e.clone().as_ref().nullable(schema)) + .collect::>>()?; + let window_expr_return_type = window_fn.return_type(&data_types, &nullability)?; let mut window_fields = schema .fields() .iter() diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 66ce7cbd838f..e5d65827cdec 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.75" +rust-version = "1.76" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 9f8f03de6dc9..54ec0e44694b 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen-common" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.75" +rust-version = "1.76" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index e523ef1a5e93..225bb9ddf661 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -385,6 +385,12 @@ message CsvWriterOptions { string time_format = 7; // Optional value to represent null string null_value = 8; + // Optional quote. Defaults to `b'"'` + string quote = 9; + // Optional escape. Defaults to `'\\'` + string escape = 10; + // Optional flag whether to double quotes, instead of escaping. 
Defaults to `true` + bool double_quote = 11; } // Options controlling CSV format @@ -402,6 +408,7 @@ message CsvOptions { string time_format = 11; // Optional time format string null_value = 12; // Optional representation of null value bytes comment = 13; // Optional comment character as a byte + bytes double_quote = 14; // Indicates if quotes are doubled } // Options controlling CSV format diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index be87123fb13f..de9fede9ee86 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -857,6 +857,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions { delimiter: proto_opts.delimiter[0], quote: proto_opts.quote[0], escape: proto_opts.escape.first().copied(), + double_quote: proto_opts.has_header.first().map(|h| *h != 0), compression: proto_opts.compression().into(), schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, date_format: (!proto_opts.date_format.is_empty()) @@ -1091,11 +1092,34 @@ pub(crate) fn csv_writer_options_from_proto( return Err(proto_error("Error parsing CSV Delimiter")); } } + if !writer_options.quote.is_empty() { + if let Some(quote) = writer_options.quote.chars().next() { + if quote.is_ascii() { + builder = builder.with_quote(quote as u8); + } else { + return Err(proto_error("CSV Quote is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Quote")); + } + } + if !writer_options.escape.is_empty() { + if let Some(escape) = writer_options.escape.chars().next() { + if escape.is_ascii() { + builder = builder.with_escape(escape as u8); + } else { + return Err(proto_error("CSV Escape is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Escape")); + } + } Ok(builder .with_header(writer_options.has_header) .with_date_format(writer_options.date_format.clone()) .with_datetime_format(writer_options.datetime_format.clone()) .with_timestamp_format(writer_options.timestamp_format.clone()) .with_time_format(writer_options.time_format.clone()) - .with_null(writer_options.null_value.clone())) + .with_null(writer_options.null_value.clone()) + .with_double_quote(writer_options.double_quote)) } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index ead29d9b92e0..3cf34aeb6d01 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -1881,6 +1881,9 @@ impl serde::Serialize for CsvOptions { if !self.comment.is_empty() { len += 1; } + if !self.double_quote.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; if !self.has_header.is_empty() { #[allow(clippy::needless_borrow)] @@ -1929,6 +1932,10 @@ impl serde::Serialize for CsvOptions { #[allow(clippy::needless_borrow)] struct_ser.serialize_field("comment", pbjson::private::base64::encode(&self.comment).as_str())?; } + if !self.double_quote.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?; + } struct_ser.end() } } @@ -1960,6 +1967,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "null_value", "nullValue", "comment", + "double_quote", + "doubleQuote", ]; #[allow(clippy::enum_variant_names)] @@ -1977,6 +1986,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { TimeFormat, NullValue, Comment, + DoubleQuote, } impl<'de> 
serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2011,6 +2021,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), "nullValue" | "null_value" => Ok(GeneratedField::NullValue), "comment" => Ok(GeneratedField::Comment), + "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2043,6 +2054,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { let mut time_format__ = None; let mut null_value__ = None; let mut comment__ = None; + let mut double_quote__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::HasHeader => { @@ -2135,6 +2147,14 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } + GeneratedField::DoubleQuote => { + if double_quote__.is_some() { + return Err(serde::de::Error::duplicate_field("doubleQuote")); + } + double_quote__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } } } Ok(CsvOptions { @@ -2151,6 +2171,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { time_format: time_format__.unwrap_or_default(), null_value: null_value__.unwrap_or_default(), comment: comment__.unwrap_or_default(), + double_quote: double_quote__.unwrap_or_default(), }) } } @@ -2189,6 +2210,15 @@ impl serde::Serialize for CsvWriterOptions { if !self.null_value.is_empty() { len += 1; } + if !self.quote.is_empty() { + len += 1; + } + if !self.escape.is_empty() { + len += 1; + } + if self.double_quote { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvWriterOptions", len)?; if self.compression != 0 { let v = CompressionTypeVariant::try_from(self.compression) @@ -2216,6 +2246,15 @@ impl serde::Serialize for CsvWriterOptions { if !self.null_value.is_empty() { struct_ser.serialize_field("nullValue", &self.null_value)?; } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if !self.escape.is_empty() { + struct_ser.serialize_field("escape", &self.escape)?; + } + if self.double_quote { + struct_ser.serialize_field("doubleQuote", &self.double_quote)?; + } struct_ser.end() } } @@ -2240,6 +2279,10 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { "timeFormat", "null_value", "nullValue", + "quote", + "escape", + "double_quote", + "doubleQuote", ]; #[allow(clippy::enum_variant_names)] @@ -2252,6 +2295,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { TimestampFormat, TimeFormat, NullValue, + Quote, + Escape, + DoubleQuote, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2281,6 +2327,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), + "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2308,6 +2357,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { let mut timestamp_format__ = None; let mut time_format__ = None; let mut null_value__ = None; + let mut quote__ = None; + let mut escape__ = None; + let mut double_quote__ = None; while let Some(k) = 
map_.next_key()? { match k { GeneratedField::Compression => { @@ -2358,6 +2410,24 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { } null_value__ = Some(map_.next_value()?); } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map_.next_value()?); + } + GeneratedField::Escape => { + if escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + escape__ = Some(map_.next_value()?); + } + GeneratedField::DoubleQuote => { + if double_quote__.is_some() { + return Err(serde::de::Error::duplicate_field("doubleQuote")); + } + double_quote__ = Some(map_.next_value()?); + } } } Ok(CsvWriterOptions { @@ -2369,6 +2439,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { timestamp_format: timestamp_format__.unwrap_or_default(), time_format: time_format__.unwrap_or_default(), null_value: null_value__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + escape: escape__.unwrap_or_default(), + double_quote: double_quote__.unwrap_or_default(), }) } } diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index b306f3212a2f..57893321e665 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -575,6 +575,15 @@ pub struct CsvWriterOptions { /// Optional value to represent null #[prost(string, tag = "8")] pub null_value: ::prost::alloc::string::String, + /// Optional quote. Defaults to `b'"'` + #[prost(string, tag = "9")] + pub quote: ::prost::alloc::string::String, + /// Optional escape. Defaults to `'\\'` + #[prost(string, tag = "10")] + pub escape: ::prost::alloc::string::String, + /// Optional flag whether to double quote instead of escaping. 
Defaults to `true` + #[prost(bool, tag = "11")] + pub double_quote: bool, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] @@ -619,6 +628,9 @@ pub struct CsvOptions { /// Optional comment character as a byte #[prost(bytes = "vec", tag = "13")] pub comment: ::prost::alloc::vec::Vec, + /// Indicates if quotes are doubled + #[prost(bytes = "vec", tag = "14")] + pub double_quote: ::prost::alloc::vec::Vec, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index a3dc826a79ca..877043f66809 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -896,6 +896,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions { delimiter: vec![opts.delimiter], quote: vec![opts.quote], escape: opts.escape.map_or_else(Vec::new, |e| vec![e]), + double_quote: opts.double_quote.map_or_else(Vec::new, |h| vec![h as u8]), compression: compression.into(), schema_infer_max_rec: opts.schema_infer_max_rec as u64, date_format: opts.date_format.clone().unwrap_or_default(), @@ -1022,5 +1023,8 @@ pub(crate) fn csv_writer_options_to_proto( timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), time_format: csv_options.time_format().unwrap_or("").to_owned(), null_value: csv_options.null().to_owned(), + quote: (csv_options.quote() as char).to_string(), + escape: (csv_options.escape() as char).to_string(), + double_quote: csv_options.double_quote(), } } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index aa8d0e55b68f..95d9e6700a50 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.75" +rust-version = "1.76" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index eabaf7ba8e14..401c51c94563 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.75" +rust-version = "1.76" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/gen/src/main.rs b/datafusion/proto/gen/src/main.rs index 22c16eacb093..d38a41a01ac2 100644 --- a/datafusion/proto/gen/src/main.rs +++ b/datafusion/proto/gen/src/main.rs @@ -29,6 +29,7 @@ fn main() -> Result<(), String> { let descriptor_path = proto_dir.join("proto/proto_descriptor.bin"); prost_build::Config::new() + .protoc_arg("--experimental_allow_proto3_optional") .file_descriptor_set_path(&descriptor_path) .out_dir(out_dir) .compile_well_known_types() diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 50356d5b6052..f2594ba10340 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -251,13 +251,7 @@ message DistinctOnNode { message CopyToNode { LogicalPlanNode input = 1; string output_url = 2; - oneof format_options { - datafusion_common.CsvOptions csv = 8; - datafusion_common.JsonOptions json = 9; - datafusion_common.TableParquetOptions parquet = 10; - 
datafusion_common.AvroOptions avro = 11; - datafusion_common.ArrowOptions arrow = 12; - } + bytes file_type = 3; repeated string partition_by = 7; } @@ -369,7 +363,7 @@ message LogicalExprNode { } message Wildcard { - string qualifier = 1; + TableReference qualifier = 1; } message PlaceholderNode { @@ -475,7 +469,7 @@ enum AggregateFunction { MIN = 0; MAX = 1; // SUM = 2; - AVG = 3; + // AVG = 3; // COUNT = 4; // APPROX_DISTINCT = 5; ARRAY_AGG = 6; @@ -485,7 +479,7 @@ enum AggregateFunction { // COVARIANCE_POP = 10; // STDDEV = 11; // STDDEV_POP = 12; - CORRELATION = 13; + // CORRELATION = 13; // APPROX_PERCENTILE_CONT = 14; // APPROX_MEDIAN = 15; // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index b306f3212a2f..875fe8992e90 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -575,6 +575,15 @@ pub struct CsvWriterOptions { /// Optional value to represent null #[prost(string, tag = "8")] pub null_value: ::prost::alloc::string::String, + /// Optional quote. Defaults to `b'"'` + #[prost(string, tag = "9")] + pub quote: ::prost::alloc::string::String, + /// Optional escape. Defaults to `'\\'` + #[prost(string, tag = "10")] + pub escape: ::prost::alloc::string::String, + /// Optional flag whether to double quotes, instead of escaping. Defaults to `true` + #[prost(bool, tag = "11")] + pub double_quote: bool, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] @@ -619,6 +628,9 @@ pub struct CsvOptions { /// Optional comment character as a byte #[prost(bytes = "vec", tag = "13")] pub comment: ::prost::alloc::vec::Vec, + /// Indicates if quotes are doubled + #[prost(bytes = "vec", tag = "14")] + pub double_quote: ::prost::alloc::vec::Vec, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8cca0fe4a876..e8fbe954428a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -534,9 +534,7 @@ impl serde::Serialize for AggregateFunction { let variant = match self { Self::Min => "MIN", Self::Max => "MAX", - Self::Avg => "AVG", Self::ArrayAgg => "ARRAY_AGG", - Self::Correlation => "CORRELATION", Self::Grouping => "GROUPING", Self::NthValueAgg => "NTH_VALUE_AGG", }; @@ -552,9 +550,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { const FIELDS: &[&str] = &[ "MIN", "MAX", - "AVG", "ARRAY_AGG", - "CORRELATION", "GROUPING", "NTH_VALUE_AGG", ]; @@ -599,9 +595,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { match value { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), - "AVG" => Ok(AggregateFunction::Avg), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), - "CORRELATION" => Ok(AggregateFunction::Correlation), "GROUPING" => Ok(AggregateFunction::Grouping), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), @@ -2542,10 +2536,10 @@ impl serde::Serialize for CopyToNode { if !self.output_url.is_empty() { len += 1; } - if !self.partition_by.is_empty() { + if !self.file_type.is_empty() { len += 1; } - if self.format_options.is_some() { + if !self.partition_by.is_empty() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CopyToNode", len)?; @@ -2555,28 +2549,13 
@@ impl serde::Serialize for CopyToNode { if !self.output_url.is_empty() { struct_ser.serialize_field("outputUrl", &self.output_url)?; } + if !self.file_type.is_empty() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("fileType", pbjson::private::base64::encode(&self.file_type).as_str())?; + } if !self.partition_by.is_empty() { struct_ser.serialize_field("partitionBy", &self.partition_by)?; } - if let Some(v) = self.format_options.as_ref() { - match v { - copy_to_node::FormatOptions::Csv(v) => { - struct_ser.serialize_field("csv", v)?; - } - copy_to_node::FormatOptions::Json(v) => { - struct_ser.serialize_field("json", v)?; - } - copy_to_node::FormatOptions::Parquet(v) => { - struct_ser.serialize_field("parquet", v)?; - } - copy_to_node::FormatOptions::Avro(v) => { - struct_ser.serialize_field("avro", v)?; - } - copy_to_node::FormatOptions::Arrow(v) => { - struct_ser.serialize_field("arrow", v)?; - } - } - } struct_ser.end() } } @@ -2590,25 +2569,18 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { "input", "output_url", "outputUrl", + "file_type", + "fileType", "partition_by", "partitionBy", - "csv", - "json", - "parquet", - "avro", - "arrow", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, OutputUrl, + FileType, PartitionBy, - Csv, - Json, - Parquet, - Avro, - Arrow, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2632,12 +2604,8 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { match value { "input" => Ok(GeneratedField::Input), "outputUrl" | "output_url" => Ok(GeneratedField::OutputUrl), + "fileType" | "file_type" => Ok(GeneratedField::FileType), "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), - "csv" => Ok(GeneratedField::Csv), - "json" => Ok(GeneratedField::Json), - "parquet" => Ok(GeneratedField::Parquet), - "avro" => Ok(GeneratedField::Avro), - "arrow" => Ok(GeneratedField::Arrow), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2659,8 +2627,8 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { { let mut input__ = None; let mut output_url__ = None; + let mut file_type__ = None; let mut partition_by__ = None; - let mut format_options__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Input => { @@ -2675,54 +2643,27 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { } output_url__ = Some(map_.next_value()?); } + GeneratedField::FileType => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileType")); + } + file_type__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } GeneratedField::PartitionBy => { if partition_by__.is_some() { return Err(serde::de::Error::duplicate_field("partitionBy")); } partition_by__ = Some(map_.next_value()?); } - GeneratedField::Csv => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("csv")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Csv) -; - } - GeneratedField::Json => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("json")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Json) -; - } - GeneratedField::Parquet => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("parquet")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Parquet) -; - } - GeneratedField::Avro => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("avro")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Avro) -; - } - GeneratedField::Arrow => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("arrow")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Arrow) -; - } } } Ok(CopyToNode { input: input__, output_url: output_url__.unwrap_or_default(), + file_type: file_type__.unwrap_or_default(), partition_by: partition_by__.unwrap_or_default(), - format_options: format_options__, }) } } @@ -19967,12 +19908,12 @@ impl serde::Serialize for Wildcard { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.qualifier.is_empty() { + if self.qualifier.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.Wildcard", len)?; - if !self.qualifier.is_empty() { - struct_ser.serialize_field("qualifier", &self.qualifier)?; + if let Some(v) = self.qualifier.as_ref() { + struct_ser.serialize_field("qualifier", v)?; } struct_ser.end() } @@ -20038,12 +19979,12 @@ impl<'de> serde::Deserialize<'de> for Wildcard { if qualifier__.is_some() { return Err(serde::de::Error::duplicate_field("qualifier")); } - qualifier__ = Some(map_.next_value()?); + qualifier__ = map_.next_value()?; } } } Ok(Wildcard { - qualifier: qualifier__.unwrap_or_default(), + qualifier: qualifier__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 56f14982923d..93bf6c060227 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -411,27 +411,10 @@ pub struct CopyToNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(string, tag = "2")] pub output_url: ::prost::alloc::string::String, + #[prost(bytes = "vec", tag = "3")] + pub file_type: ::prost::alloc::vec::Vec, #[prost(string, repeated, tag = "7")] pub partition_by: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - #[prost(oneof = "copy_to_node::FormatOptions", tags = "8, 9, 10, 11, 12")] - pub format_options: ::core::option::Option, -} -/// Nested message and enum types 
in `CopyToNode`. -pub mod copy_to_node { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum FormatOptions { - #[prost(message, tag = "8")] - Csv(super::super::datafusion_common::CsvOptions), - #[prost(message, tag = "9")] - Json(super::super::datafusion_common::JsonOptions), - #[prost(message, tag = "10")] - Parquet(super::super::datafusion_common::TableParquetOptions), - #[prost(message, tag = "11")] - Avro(super::super::datafusion_common::AvroOptions), - #[prost(message, tag = "12")] - Arrow(super::super::datafusion_common::ArrowOptions), - } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -592,8 +575,8 @@ pub mod logical_expr_node { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Wildcard { - #[prost(string, tag = "1")] - pub qualifier: ::prost::alloc::string::String, + #[prost(message, optional, tag = "1")] + pub qualifier: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -1929,7 +1912,7 @@ pub enum AggregateFunction { Min = 0, Max = 1, /// SUM = 2; - Avg = 3, + /// AVG = 3; /// COUNT = 4; /// APPROX_DISTINCT = 5; ArrayAgg = 6, @@ -1939,7 +1922,7 @@ pub enum AggregateFunction { /// COVARIANCE_POP = 10; /// STDDEV = 11; /// STDDEV_POP = 12; - Correlation = 13, + /// CORRELATION = 13; /// APPROX_PERCENTILE_CONT = 14; /// APPROX_MEDIAN = 15; /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; @@ -1971,9 +1954,7 @@ impl AggregateFunction { match self { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", - AggregateFunction::Avg => "AVG", AggregateFunction::ArrayAgg => "ARRAY_AGG", - AggregateFunction::Correlation => "CORRELATION", AggregateFunction::Grouping => "GROUPING", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } @@ -1983,9 +1964,7 @@ impl AggregateFunction { match value { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), - "AVG" => Some(Self::Avg), "ARRAY_AGG" => Some(Self::ArrayAgg), - "CORRELATION" => Some(Self::Correlation), "GROUPING" => Some(Self::Grouping), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs new file mode 100644 index 000000000000..106d5639489e --- /dev/null +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -0,0 +1,409 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
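
The format-specific codecs added in this new `file_formats.rs` module are what turn the opaque `file_type` bytes now stored in `CopyToNode` back into a `FileFormatFactory`, replacing the removed `FormatOptions` oneof. A minimal sketch of how a caller might round-trip a `COPY TO` plan with one of them, mirroring the roundtrip tests further down in this diff (the `roundtrip_copy_to` wrapper and its arguments are illustrative, not part of the change):

```rust
use std::sync::Arc;

use datafusion::datasource::file_format::csv::CsvFormatFactory;
use datafusion::datasource::file_format::format_as_file_type;
use datafusion::prelude::SessionContext;
use datafusion_common::Result;
use datafusion_expr::{dml::CopyTo, LogicalPlan};
use datafusion_proto::bytes::{
    logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes_with_extension_codec,
};
use datafusion_proto::logical_plan::file_formats::CsvLogicalExtensionCodec;

/// Serialize and deserialize a COPY TO plan whose output format is carried as
/// opaque `file_type` bytes instead of the removed `FormatOptions` oneof.
fn roundtrip_copy_to(input: LogicalPlan, ctx: &SessionContext) -> Result<LogicalPlan> {
    let plan = LogicalPlan::Copy(CopyTo {
        input: Arc::new(input),
        output_url: "out.csv".to_string(),
        partition_by: vec![],
        // A FileFormatFactory is wrapped into the plan's `file_type` field ...
        file_type: format_as_file_type(Arc::new(CsvFormatFactory::new())),
        options: Default::default(),
    });

    // ... and the format-specific codec turns the serialized bytes back into one.
    let codec = CsvLogicalExtensionCodec {};
    let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?;
    logical_plan_from_bytes_with_extension_codec(&bytes, ctx, &codec)
}
```

The same pattern applies to the JSON, Parquet, Arrow and Avro codecs defined below; as their TODO comments note, each placeholder currently only reconstructs its format factory and does not yet serialize per-format options.
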
+ +use std::sync::Arc; + +use datafusion::{ + datasource::file_format::{ + arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory, + parquet::ParquetFormatFactory, FileFormatFactory, + }, + prelude::SessionContext, +}; +use datafusion_common::{not_impl_err, TableReference}; + +use super::LogicalExtensionCodec; + +#[derive(Debug)] +pub struct CsvLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for CsvLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(CsvFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct JsonLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. 
+impl LogicalExtensionCodec for JsonLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(JsonFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct ParquetLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(ParquetFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct ArrowLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. 
+impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(ArrowFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct AvroLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for AvroLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _cts: &datafusion::prelude::SessionContext, + ) -> datafusion_common::Result< + std::sync::Arc, + > { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: std::sync::Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(ArrowFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &[u8], + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } + + fn try_decode_udf( + &self, + name: &str, + __buf: &[u8], + ) -> datafusion_common::Result> { + not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") + } + + fn try_encode_udf( + &self, + __node: &datafusion_expr::ScalarUDF, + __buf: &mut Vec, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ba0e708218cf..21331a94c18c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -139,9 +139,7 @@ impl From for AggregateFunction { match agg_fun { 
protobuf::AggregateFunction::Min => Self::Min, protobuf::AggregateFunction::Max => Self::Max, - protobuf::AggregateFunction::Avg => Self::Avg, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, - protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, } @@ -268,7 +266,11 @@ pub fn parse_expr( Ok(operands .into_iter() .reduce(|left, right| { - Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + op.clone(), + Box::new(right), + )) }) .expect("Binary expression could not be reduced to a single expression.")) } @@ -593,13 +595,10 @@ pub fn parse_expr( parse_exprs(&in_list.list, registry, codec)?, in_list.negated, ))), - ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { - qualifier: if qualifier.is_empty() { - None - } else { - Some(qualifier.clone()) - }, - }), + ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { + let qualifier = qualifier.to_owned().map(|x| x.try_into()).transpose()?; + Ok(Expr::Wildcard { qualifier }) + } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index ef37150a35db..664cd7e11555 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -33,6 +33,9 @@ use crate::protobuf::{proto_error, FromProtoError, ToProtoError}; use arrow::datatypes::{DataType, Schema, SchemaRef}; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::{ + file_type_to_format, format_as_file_type, FileFormatFactory, +}; use datafusion::{ datasource::{ file_format::{avro::AvroFormat, csv::CsvFormat, FileFormat}, @@ -43,6 +46,7 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ context, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, Result, TableReference, @@ -64,6 +68,7 @@ use prost::Message; use self::to_proto::serialize_expr; +pub mod file_formats; pub mod from_proto; pub mod to_proto; @@ -104,16 +109,34 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_decode_table_provider( &self, buf: &[u8], + table_ref: &TableReference, schema: SchemaRef, ctx: &SessionContext, ) -> Result>; fn try_encode_table_provider( &self, + table_ref: &TableReference, node: Arc, buf: &mut Vec, ) -> Result<()>; + fn try_decode_file_format( + &self, + _buf: &[u8], + _ctx: &SessionContext, + ) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for file format") + } + + fn try_encode_file_format( + &self, + _buf: &[u8], + _node: Arc, + ) -> Result<()> { + Ok(()) + } + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") } @@ -143,6 +166,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { fn try_decode_table_provider( &self, _buf: &[u8], + _table_ref: &TableReference, _schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { @@ -151,6 +175,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { fn try_encode_table_provider( &self, + _table_ref: &TableReference, _node: Arc, _buf: &mut Vec, ) -> Result<()> { @@ -424,15 +449,17 @@ impl AsLogicalPlan for 
LogicalPlanNode { .iter() .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .collect::, _>>()?; + + let table_name = + from_table_reference(scan.table_name.as_ref(), "CustomScan")?; + let provider = extension_codec.try_decode_table_provider( &scan.custom_table_data, + &table_name, schema, ctx, )?; - let table_name = - from_table_reference(scan.table_name.as_ref(), "CustomScan")?; - LogicalPlanBuilder::scan_with_filters( table_name, provider_as_source(provider), @@ -829,12 +856,16 @@ impl AsLogicalPlan for LogicalPlanNode { let input: LogicalPlan = into_logical_plan!(copy.input, ctx, extension_codec)?; + let file_type: Arc = format_as_file_type( + extension_codec.try_decode_file_format(©.file_type, ctx)?, + ); + Ok(datafusion_expr::LogicalPlan::Copy( datafusion_expr::dml::CopyTo { input: Arc::new(input), output_url: copy.output_url.clone(), partition_by: copy.partition_by.clone(), - format_options: convert_required!(copy.format_options)?, + file_type, options: Default::default(), }, )) @@ -1023,7 +1054,7 @@ impl AsLogicalPlan for LogicalPlanNode { } else { let mut bytes = vec![]; extension_codec - .try_encode_table_provider(provider, &mut bytes) + .try_encode_table_provider(table_name, provider, &mut bytes) .map_err(|e| context!("Error serializing custom table", e))?; let scan = CustomScan(CustomTableScanNode { table_name: Some(table_name.clone().into()), @@ -1609,7 +1640,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Copy(dml::CopyTo { input, output_url, - format_options, + file_type, partition_by, .. }) => { @@ -1618,12 +1649,16 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec, )?; + let buf = Vec::new(); + extension_codec + .try_encode_file_format(&buf, file_type_to_format(file_type)?)?; + Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( protobuf::CopyToNode { input: Some(Box::new(input)), output_url: output_url.to_string(), - format_options: Some(format_options.try_into()?), + file_type: buf, partition_by: partition_by.clone(), }, ))), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 08999effa4b1..3a1db1defdd9 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -110,9 +110,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { match value { AggregateFunction::Min => Self::Min, AggregateFunction::Max => Self::Max, - AggregateFunction::Avg => Self::Avg, AggregateFunction::ArrayAgg => Self::ArrayAgg, - AggregateFunction::Correlation => Self::Correlation, AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, } @@ -374,10 +372,6 @@ pub fn serialize_expr( AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg @@ -618,7 +612,7 @@ pub fn serialize_expr( } Expr::Wildcard { qualifier } => protobuf::LogicalExprNode { expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { - qualifier: qualifier.clone().unwrap_or("".to_string()), + qualifier: qualifier.to_owned().map(|x| x.into()), })), }, Expr::ScalarSubquery(_) diff --git 
a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index b636c77641c7..7783c1561185 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -41,14 +41,13 @@ use datafusion::physical_plan::expressions::{ }; use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; -use datafusion_common::config::FormatOptions; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_proto_common::common::proto_error; use crate::convert_required; use crate::logical_plan::{self}; +use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::{self, copy_to_node}; use super::PhysicalExtensionCodec; @@ -653,22 +652,3 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { }) } } - -impl TryFrom<©_to_node::FormatOptions> for FormatOptions { - type Error = DataFusionError; - fn try_from(value: ©_to_node::FormatOptions) -> Result { - Ok(match value { - copy_to_node::FormatOptions::Csv(options) => { - FormatOptions::CSV(options.try_into()?) - } - copy_to_node::FormatOptions::Json(options) => { - FormatOptions::JSON(options.try_into()?) - } - copy_to_node::FormatOptions::Parquet(options) => { - FormatOptions::PARQUET(options.try_into()?) - } - copy_to_node::FormatOptions::Avro(_) => FormatOptions::AVRO, - copy_to_node::FormatOptions::Arrow(_) => FormatOptions::ARROW, - }) - } -} diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 8a488d30cf24..56e702704798 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1010,7 +1010,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1027,7 +1027,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } @@ -1040,7 +1040,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1057,7 +1057,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } @@ -1070,7 +1070,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? 
.try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1087,7 +1087,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index a9d3736dee08..8583900e9fa7 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,10 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, - DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, - NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, RowNumber, TryCastExpr, WindowShift, + ArrayAgg, BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, + Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, + NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, + RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -39,12 +39,11 @@ use datafusion::{ }, physical_plan::expressions::LikeExpr, }; -use datafusion_common::config::FormatOptions; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use crate::protobuf::{ - self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, - PhysicalSortExprNode, PhysicalSortExprNodeCollection, + self, physical_aggregate_expr_node, physical_window_expr_node, PhysicalSortExprNode, + PhysicalSortExprNodeCollection, }; use super::PhysicalExtensionCodec; @@ -165,21 +164,28 @@ pub fn serialize_physical_window_expr( } else if let Some(plain_aggr_window_expr) = expr.downcast_ref::() { - let AggrFn { inner, distinct } = - aggr_expr_to_aggr_fn(plain_aggr_window_expr.get_aggregate_expr().as_ref())?; + let aggr_expr = plain_aggr_window_expr.get_aggregate_expr(); + if let Some(a) = aggr_expr.as_any().downcast_ref::() { + physical_window_expr_node::WindowFunction::UserDefinedAggrFunction( + a.fun().name().to_string(), + ) + } else { + let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn( + plain_aggr_window_expr.get_aggregate_expr().as_ref(), + )?; - if distinct { - // TODO - return not_impl_err!( - "Distinct aggregate functions not supported in window expressions" - ); - } + if distinct { + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } - if !window_frame.start_bound.is_unbounded() { - return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); - } + if !window_frame.start_bound.is_unbounded() { + return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } - physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + } } else if let 
Some(sliding_aggr_window_expr) = expr.downcast_ref::() { @@ -251,10 +257,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Min } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Max - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Avg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Correlation } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::NthValueAgg } else { @@ -725,26 +727,3 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { }) } } - -impl TryFrom<&FormatOptions> for copy_to_node::FormatOptions { - type Error = DataFusionError; - fn try_from(value: &FormatOptions) -> std::result::Result { - Ok(match value { - FormatOptions::CSV(options) => { - copy_to_node::FormatOptions::Csv(options.try_into()?) - } - FormatOptions::JSON(options) => { - copy_to_node::FormatOptions::Json(options.try_into()?) - } - FormatOptions::PARQUET(options) => { - copy_to_node::FormatOptions::Parquet(options.try_into()?) - } - FormatOptions::AVRO => { - copy_to_node::FormatOptions::Avro(protobuf::AvroOptions {}) - } - FormatOptions::ARROW => { - copy_to_node::FormatOptions::Arrow(protobuf::ArrowOptions {}) - } - }) - } -} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b3966c3f0204..fe3da3d05854 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,6 +26,13 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use datafusion::datasource::file_format::arrow::ArrowFormatFactory; +use datafusion::datasource::file_format::csv::CsvFormatFactory; +use datafusion::datasource::file_format::format_as_file_type; +use datafusion::datasource::file_format::parquet::ParquetFormatFactory; +use datafusion_proto::logical_plan::file_formats::{ + ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, ParquetLogicalExtensionCodec, +}; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -41,11 +48,11 @@ use datafusion::functions_aggregate::expr_fn::{ }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; -use datafusion_common::config::{FormatOptions, TableOptions}; +use datafusion_common::config::TableOptions; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, + DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -59,8 +66,9 @@ use datafusion_expr::{ TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; +use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::expr_fn::{ - bit_and, bit_or, bit_xor, bool_and, bool_or, + avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -126,6 +134,9 @@ pub struct TestTableProto { /// URL of the table root #[prost(string, tag = "1")] pub url: String, + /// Qualified table name + #[prost(string, tag = "2")] + pub table_name: String, } 
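
Because `try_encode_table_provider` and `try_decode_table_provider` now receive the `TableReference`, a custom codec can key or validate its serialized providers by table name, which is what the new `table_name` field on `TestTableProto` above exercises. A rough sketch of a codec using the new signatures (`NamespacedCodec` is a hypothetical name, not part of this change; only the provider methods do anything):

```rust
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use datafusion::datasource::TableProvider;
use datafusion::prelude::SessionContext;
use datafusion_common::{internal_err, not_impl_err, Result, TableReference};
use datafusion_expr::{Extension, LogicalPlan};
use datafusion_proto::logical_plan::LogicalExtensionCodec;

/// Hypothetical codec that tags serialized providers with their table name,
/// something the old signatures (without `TableReference`) could not do.
#[derive(Debug)]
struct NamespacedCodec;

impl LogicalExtensionCodec for NamespacedCodec {
    fn try_decode(
        &self,
        _buf: &[u8],
        _inputs: &[LogicalPlan],
        _ctx: &SessionContext,
    ) -> Result<Extension> {
        not_impl_err!("extension nodes are not handled by this sketch")
    }

    fn try_encode(&self, _node: &Extension, _buf: &mut Vec<u8>) -> Result<()> {
        not_impl_err!("extension nodes are not handled by this sketch")
    }

    fn try_decode_table_provider(
        &self,
        buf: &[u8],
        table_ref: &TableReference,
        _schema: SchemaRef,
        _ctx: &SessionContext,
    ) -> Result<Arc<dyn TableProvider>> {
        // The table name written during encoding can now be validated here.
        let expected = table_ref.to_string();
        if !buf.starts_with(expected.as_bytes()) {
            return internal_err!("provider bytes do not belong to table {expected}");
        }
        not_impl_err!("reconstructing the provider itself is out of scope for this sketch")
    }

    fn try_encode_table_provider(
        &self,
        table_ref: &TableReference,
        _node: Arc<dyn TableProvider>,
        buf: &mut Vec<u8>,
    ) -> Result<()> {
        // Prefix the provider payload with the fully qualified table name.
        buf.extend_from_slice(table_ref.to_string().as_bytes());
        Ok(())
    }
}
```
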
#[derive(Debug)] @@ -148,12 +159,14 @@ impl LogicalExtensionCodec for TestTableProviderCodec { fn try_decode_table_provider( &self, buf: &[u8], + table_ref: &TableReference, schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { let msg = TestTableProto::decode(buf).map_err(|_| { DataFusionError::Internal("Error decoding test table".to_string()) })?; + assert_eq!(msg.table_name, table_ref.to_string()); let provider = TestTableProvider { url: msg.url, schema, @@ -163,6 +176,7 @@ impl LogicalExtensionCodec for TestTableProviderCodec { fn try_encode_table_provider( &self, + table_ref: &TableReference, node: Arc, buf: &mut Vec, ) -> Result<()> { @@ -173,6 +187,7 @@ impl LogicalExtensionCodec for TestTableProviderCodec { .expect("Can't encode non-test tables"); let msg = TestTableProto { url: table.url.clone(), + table_name: table_ref.to_string(), }; msg.encode(buf).map_err(|_| { DataFusionError::Internal("Error encoding test table".to_string()) @@ -325,20 +340,20 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let ctx = SessionContext::new(); let input = create_csv_scan(&ctx).await?; - let mut table_options = ctx.copied_table_options(); - table_options.set_file_format(FileType::CSV); - table_options.set("format.delimiter", ";")?; + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::CSV(table_options.csv.clone()), + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = CsvLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); Ok(()) @@ -363,26 +378,27 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { parquet_format.global.dictionary_page_size_limit = 444; parquet_format.global.max_row_group_size = 555; + let file_type = format_as_file_type(Arc::new(ParquetFormatFactory::new())); + let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.parquet".to_string(), - format_options: FormatOptions::PARQUET(parquet_format.clone()), + file_type, partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = ParquetLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.parquet", copy_to.output_url); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); - assert_eq!( - copy_to.format_options, - FormatOptions::PARQUET(parquet_format) - ); + assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); } _ => panic!(), } @@ -395,22 +411,26 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { let input = create_csv_scan(&ctx).await?; + let file_type = format_as_file_type(Arc::new(ArrowFormatFactory::new())); + let plan = 
LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.arrow".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::ARROW, + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = ArrowLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.arrow", copy_to.output_url); - assert_eq!(FormatOptions::ARROW, copy_to.format_options); + assert_eq!("arrow".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); } _ => panic!(), @@ -436,22 +456,26 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { csv_format.time_format = Some("HH:mm:ss".to_string()); csv_format.null_value = Some("NIL".to_string()); + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); + let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::CSV(csv_format.clone()), + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = CsvLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.csv", copy_to.output_url); - assert_eq!(FormatOptions::CSV(csv_format), copy_to.format_options); + assert_eq!("csv".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); } _ => panic!(), @@ -658,8 +682,10 @@ async fn roundtrip_expr_api() -> Result<()> { count_distinct(lit(1)), first_value(lit(1), None), first_value(lit(1), Some(vec![lit(2).sort(true, true)])), + avg(lit(1.5)), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), + corr(lit(1.5), lit(2.2)), sum(lit(1)), median(lit(2)), var_sample(lit(2.2)), @@ -847,6 +873,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { fn try_decode_table_provider( &self, _buf: &[u8], + _table_ref: &TableReference, _schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { @@ -855,6 +882,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { fn try_encode_table_provider( &self, + _table_ref: &TableReference, _node: Arc, _buf: &mut Vec, ) -> Result<()> { @@ -924,6 +952,7 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec { fn try_decode_table_provider( &self, _buf: &[u8], + _table_ref: &TableReference, _schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { @@ -932,6 +961,7 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec { fn try_encode_table_provider( &self, + _table_ref: &TableReference, _node: Arc, _buf: &mut Vec, ) -> Result<()> { @@ -981,7 +1011,7 @@ fn round_trip_scalar_values() { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), + 
ScalarValue::List(ScalarValue::new_list_nullable(&[], &DataType::Boolean)), ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), @@ -1073,7 +1103,7 @@ fn round_trip_scalar_values() { i64::MAX, ))), ScalarValue::IntervalMonthDayNano(None), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), @@ -1093,10 +1123,13 @@ fn round_trip_scalar_values() { ], &DataType::Float32, )), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( &[ - ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( + &[], + &DataType::Float32, + )), + ScalarValue::List(ScalarValue::new_list_nullable( &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), @@ -2163,7 +2196,16 @@ fn roundtrip_window() { vec![col("col1")], vec![col("col1")], vec![col("col2")], - row_number_frame, + row_number_frame.clone(), + None, + )); + + let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(avg_udaf()), + vec![col("col1")], + vec![], + vec![], + row_number_frame.clone(), None, )); @@ -2174,5 +2216,6 @@ fn roundtrip_window() { roundtrip_expr_test(test_expr3, ctx.clone()); roundtrip_expr_test(test_expr4, ctx.clone()); roundtrip_expr_test(test_expr5, ctx.clone()); - roundtrip_expr_test(test_expr6, ctx); + roundtrip_expr_test(test_expr6, ctx.clone()); + roundtrip_expr_test(text_expr7, ctx); } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index eb3313239544..03c72cfc32b1 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -47,7 +47,7 @@ use datafusion::physical_plan::aggregates::{ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ - binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr, NthValue, + binary, cast, col, in_list, like, lit, BinaryExpr, Column, NotExpr, NthValue, PhysicalSortExpr, }; use datafusion::physical_plan::filter::FilterExec; @@ -60,6 +60,7 @@ use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::udaf::create_aggregate_expr; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, @@ -79,6 +80,7 @@ use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; +use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::string_agg::StringAgg; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, @@ -282,11 +284,17 @@ fn roundtrip_window() -> Result<()> { )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( - Arc::new(Avg::new( - cast(col("b", &schema)?, &schema, DataType::Float64)?, - 
"AVG(b)".to_string(), - DataType::Float64, - )), + create_aggregate_expr( + &avg_udaf(), + &[cast(col("b", &schema)?, &schema, DataType::Float64)?], + &[], + &[], + &[], + &schema, + "avg(b)", + false, + false, + )?, &[], &[], Arc::new(WindowFrame::new(None)), @@ -342,11 +350,17 @@ fn rountrip_aggregate() -> Result<()> { let test_cases: Vec>> = vec![ // AVG - vec![Arc::new(Avg::new( - cast(col("b", &schema)?, &schema, DataType::Float64)?, - "AVG(b)".to_string(), - DataType::Float64, - ))], + vec![create_aggregate_expr( + &avg_udaf(), + &[col("b", &schema)?], + &[], + &[], + &[], + &schema, + "AVG(b)", + false, + false, + )?], // NTH_VALUE vec![Arc::new(NthValueAgg::new( col("b", &schema)?, @@ -398,11 +412,17 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec> = vec![Arc::new(Avg::new( - cast(col("b", &schema)?, &schema, DataType::Float64)?, - "AVG(b)".to_string(), - DataType::Float64, - ))]; + let aggregates: Vec> = vec![create_aggregate_expr( + &avg_udaf(), + &[col("b", &schema)?], + &[], + &[], + &[], + &schema, + "AVG(b)", + false, + false, + )?]; let agg = AggregateExec::try_new( AggregateMode::Final, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 8b64ccfb52cb..a8af37ee6a37 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,6 +17,7 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; +use datafusion_common::utils::list_ndims; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; use datafusion_common::{ @@ -86,13 +87,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { StackEntry::Operator(op) => { let right = eval_stack.pop().unwrap(); let left = eval_stack.pop().unwrap(); - - let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - op, - Box::new(right), - )); - + let expr = self.build_logical_expr(op, left, right, schema)?; eval_stack.push(expr); } } @@ -103,6 +98,69 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(expr) } + fn build_logical_expr( + &self, + op: Operator, + left: Expr, + right: Expr, + schema: &DFSchema, + ) -> Result { + // Rewrite string concat operator to function based on types + // if we get list || list then we rewrite it to array_concat() + // if we get list || non-list then we rewrite it to array_append() + // if we get non-list || list then we rewrite it to array_prepend() + // if we get string || string then we rewrite it to concat() + if op == Operator::StringConcat { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + + // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. + // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. 
+ if left_list_ndims + right_list_ndims == 0 { + // TODO: concat function ignore null, but string concat takes null into consideration + // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` + } else if left_list_ndims == right_list_ndims { + if let Some(udf) = self.context_provider.get_function_meta("array_concat") + { + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![left, right], + ))); + } else { + return internal_err!("array_concat not found"); + } + } else if left_list_ndims > right_list_ndims { + if let Some(udf) = self.context_provider.get_function_meta("array_append") + { + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![left, right], + ))); + } else { + return internal_err!("array_append not found"); + } + } else if left_list_ndims < right_list_ndims { + if let Some(udf) = + self.context_provider.get_function_meta("array_prepend") + { + return Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![left, right], + ))); + } else { + return internal_err!("array_append not found"); + } + } + } + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + op, + Box::new(right), + ))) + } + /// Generate a relational expression from a SQL expression pub fn sql_to_expr( &self, diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index bbc3a52f07ea..d2f3f508a316 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -22,7 +22,7 @@ use std::fmt; use sqlparser::{ ast::{ - ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, + ColumnDef, ColumnOptionDef, Expr, ObjectName, OrderByExpr, Query, Statement as SQLStatement, TableConstraint, Value, }, dialect::{keywords::Keyword, Dialect, GenericDialect}, @@ -323,6 +323,14 @@ impl<'a> DFParser<'a> { Ok(stmts) } + pub fn parse_sql_into_expr_with_dialect( + sql: &str, + dialect: &dyn Dialect, + ) -> Result { + let mut parser = DFParser::new_with_dialect(sql, dialect)?; + parser.parse_expr() + } + /// Report an unexpected token fn expected( &self, @@ -367,6 +375,19 @@ impl<'a> DFParser<'a> { } } + pub fn parse_expr(&mut self) -> Result { + if let Token::Word(w) = self.parser.peek_token().token { + match w.keyword { + Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => { + return parser_err!("Unsupported command in expression"); + } + _ => {} + } + } + + self.parser.parse_expr() + } + /// Parse a SQL `COPY TO` statement pub fn parse_copy(&mut self) -> Result { // parse as a query diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 30f95170a34f..00f221200624 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; @@ -48,6 +49,11 @@ use crate::utils::make_decimal_type; pub trait ContextProvider { /// Getter for a datasource fn get_table_source(&self, name: TableReference) -> Result>; + + fn get_file_type(&self, _ext: &str) -> Result> { + not_impl_err!("Registered file types are not supported") + } + /// Getter for a table function fn get_table_function_source( &self, @@ -97,6 +103,7 @@ pub trait ContextProvider { pub struct ParserOptions { pub parse_float_as_decimal: bool, pub enable_ident_normalization: bool, + pub support_varchar_with_length: bool, } impl Default for ParserOptions { @@ -104,6 
+111,7 @@ impl Default for ParserOptions { Self { parse_float_as_decimal: false, enable_ident_normalization: true, + support_varchar_with_length: true, } } } @@ -398,12 +406,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) | SQLDataType::UnsignedInt4(_) => { Ok(DataType::UInt32) } + SQLDataType::Varchar(length) => { + match (length, self.options.support_varchar_with_length) { + (Some(_), false) => plan_err!("does not support Varchar with length, please set `support_varchar_with_length` to be true"), + _ => Ok(DataType::Utf8), + } + } SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(DataType::UInt64), SQLDataType::Float(_) => Ok(DataType::Float32), SQLDataType::Real | SQLDataType::Float4 => Ok(DataType::Float32), SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(DataType::Float64), SQLDataType::Char(_) - | SQLDataType::Varchar(_) | SQLDataType::Text | SQLDataType::String(_) => Ok(DataType::Utf8), SQLDataType::Timestamp(None, tz_info) => { diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 0fa266e4e01d..102b47216e7e 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -18,7 +18,9 @@ use std::collections::HashSet; use std::sync::Arc; -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use crate::planner::{ + idents_to_table_reference, ContextProvider, PlannerContext, SqlToRel, +}; use crate::utils::{ check_columns_satisfy_exprs, extract_aliases, rebase_expr, recursive_transform_unnest, resolve_aliases_to_exprs, resolve_columns, @@ -475,9 +477,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(expanded_exprs) } } - SelectItem::QualifiedWildcard(ref object_name, options) => { + SelectItem::QualifiedWildcard(object_name, options) => { Self::check_wildcard_options(&options)?; - let qualifier = format!("{object_name}"); + let qualifier = idents_to_table_reference(object_name.0, false)?; // do not expand from outer schema let expanded_exprs = expand_qualified_wildcard( &qualifier, diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index d10956efb66c..518972545a48 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -34,8 +34,8 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ exec_err, not_impl_err, plan_datafusion_err, plan_err, schema_err, unqualified_field_not_found, Column, Constraints, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, SchemaError, SchemaReference, - TableReference, ToDFSchema, + DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, + ToDFSchema, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; @@ -899,31 +899,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - let file_type = if let Some(file_type) = statement.stored_as { - FileType::from_str(&file_type).map_err(|_| { - DataFusionError::Configuration(format!("Unknown FileType {}", file_type)) - })? + let maybe_file_type = if let Some(stored_as) = &statement.stored_as { + if let Ok(ext_file_type) = self.context_provider.get_file_type(stored_as) { + Some(ext_file_type) + } else { + None + } } else { - let e = || { - DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." 
- .to_string(), - ) - }; - // try to infer file format from file extension - let extension: &str = &Path::new(&statement.target) - .extension() - .ok_or_else(e)? - .to_str() - .ok_or_else(e)? - .to_lowercase(); - - FileType::from_str(extension).map_err(|e| { - DataFusionError::Configuration(format!( - "{}. Use STORED AS to define file format.", - e - )) - })? + None + }; + + let file_type = match maybe_file_type { + Some(ft) => ft, + None => { + let e = || { + DataFusionError::Configuration( + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." + .to_string(), + ) + }; + // try to infer file format from file extension + let extension: &str = &Path::new(&statement.target) + .extension() + .ok_or_else(e)? + .to_str() + .ok_or_else(e)? + .to_lowercase(); + + self.context_provider.get_file_type(extension)? + } }; let partition_by = statement @@ -938,7 +942,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: statement.target, - format_options: file_type.into(), + file_type, partition_by, options, })) @@ -964,7 +968,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.order_by_to_sort_expr(&expr, schema, planner_context, true, None)?; // Verify that columns of all SortExprs exist in the schema: for expr in expr_vec.iter() { - for column in expr.to_columns()?.iter() { + for column in expr.column_refs().iter() { if !schema.has_column(column) { // Return an error if any column is not in the schema: return plan_err!("Column {column} is not in schema"); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 65481aed64f9..ad898de5987a 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,11 +15,18 @@ // specific language governing permissions and limitations // under the License. -use arrow::util::display::array_value_to_string; use core::fmt; +use std::sync::Arc; use std::{fmt::Display, vec}; -use arrow_array::{Date32Array, Date64Array, TimestampNanosecondArray}; +use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType}; +use arrow::util::display::array_value_to_string; +use arrow_array::types::{ + ArrowTemporalType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, +}; +use arrow_array::{Date32Array, Date64Array, PrimitiveArray}; use arrow_schema::DataType; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ @@ -647,6 +654,69 @@ impl Unparser<'_> { } } + fn handle_timestamp( + &self, + v: &ScalarValue, + tz: &Option>, + ) -> Result + where + i64: From, + { + let ts = if let Some(tz) = tz { + v.to_array()? + .as_any() + .downcast_ref::>() + .ok_or(internal_datafusion_err!( + "Failed to downcast type {v:?} to arrow array" + ))? + .value_as_datetime_with_tz(0, tz.parse()?) + .ok_or(internal_datafusion_err!( + "Unable to convert {v:?} to DateTime" + ))? + .to_string() + } else { + v.to_array()? + .as_any() + .downcast_ref::>() + .ok_or(internal_datafusion_err!( + "Failed to downcast type {v:?} to arrow array" + ))? + .value_as_datetime(0) + .ok_or(internal_datafusion_err!( + "Unable to convert {v:?} to DateTime" + ))? 
+ .to_string() + }; + Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(ast::Expr::Value(SingleQuotedString(ts))), + data_type: ast::DataType::Timestamp(None, TimezoneInfo::None), + format: None, + }) + } + + fn handle_time(&self, v: &ScalarValue) -> Result + where + i64: From, + { + let time = v + .to_array()? + .as_any() + .downcast_ref::>() + .ok_or(internal_datafusion_err!( + "Failed to downcast type {v:?} to arrow array" + ))? + .value_as_time(0) + .ok_or(internal_datafusion_err!("Unable to convert {v:?} to Time"))? + .to_string(); + Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(ast::Expr::Value(SingleQuotedString(time))), + data_type: ast::DataType::Time(None, TimezoneInfo::None), + format: None, + }) + } + /// DataFusion ScalarValues sometimes require a ast::Expr to construct. /// For example ScalarValue::Date32(d) corresponds to the ast::Expr CAST('datestr' as DATE) fn scalar_to_sql(&self, v: &ScalarValue) -> Result { @@ -668,12 +738,18 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) } ScalarValue::Float64(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Decimal128(Some(_), ..) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::Decimal128(Some(value), precision, scale) => { + Ok(ast::Expr::Value(ast::Value::Number( + Decimal128Type::format_decimal(*value, *precision, *scale), + false, + ))) } ScalarValue::Decimal128(None, ..) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Decimal256(Some(_), ..) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::Decimal256(Some(value), precision, scale) => { + Ok(ast::Expr::Value(ast::Value::Number( + Decimal256Type::format_decimal(*value, *precision, *scale), + false, + ))) } ScalarValue::Decimal256(None, ..) 
=> Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Int8(Some(i)) => { @@ -783,92 +859,56 @@ impl Unparser<'_> { } ScalarValue::Date64(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Time32Second(Some(_t)) => { - not_impl_err!("Unsupported scalar: {v:?}") + self.handle_time::(v) } ScalarValue::Time32Second(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Time32Millisecond(Some(_t)) => { - not_impl_err!("Unsupported scalar: {v:?}") + self.handle_time::(v) } ScalarValue::Time32Millisecond(None) => { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Time64Microsecond(Some(_t)) => { - not_impl_err!("Unsupported scalar: {v:?}") + self.handle_time::(v) } ScalarValue::Time64Microsecond(None) => { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Time64Nanosecond(Some(_t)) => { - not_impl_err!("Unsupported scalar: {v:?}") + self.handle_time::(v) } ScalarValue::Time64Nanosecond(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::TimestampSecond(Some(_ts), _) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::TimestampSecond(Some(_ts), tz) => { + self.handle_timestamp::(v, tz) } ScalarValue::TimestampSecond(None, _) => { Ok(ast::Expr::Value(ast::Value::Null)) } - ScalarValue::TimestampMillisecond(Some(_ts), _) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::TimestampMillisecond(Some(_ts), tz) => { + self.handle_timestamp::(v, tz) } ScalarValue::TimestampMillisecond(None, _) => { Ok(ast::Expr::Value(ast::Value::Null)) } - ScalarValue::TimestampMicrosecond(Some(_ts), _) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::TimestampMicrosecond(Some(_ts), tz) => { + self.handle_timestamp::(v, tz) } ScalarValue::TimestampMicrosecond(None, _) => { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::TimestampNanosecond(Some(_ts), tz) => { - let result = if let Some(tz) = tz { - v.to_array()? - .as_any() - .downcast_ref::() - .ok_or(internal_datafusion_err!( - "Unable to downcast to TimestampNanosecond from TimestampNanosecond scalar" - ))? - .value_as_datetime_with_tz(0, tz.parse()?) - .ok_or(internal_datafusion_err!( - "Unable to convert TimestampNanosecond to DateTime" - ))?.to_string() - } else { - v.to_array()? - .as_any() - .downcast_ref::() - .ok_or(internal_datafusion_err!( - "Unable to downcast to TimestampNanosecond from TimestampNanosecond scalar" - ))? 
- .value_as_datetime(0) - .ok_or(internal_datafusion_err!( - "Unable to convert TimestampNanosecond to NaiveDateTime" - ))?.to_string() - }; - Ok(ast::Expr::Cast { - kind: ast::CastKind::Cast, - expr: Box::new(ast::Expr::Value(SingleQuotedString(result))), - data_type: ast::DataType::Timestamp(None, TimezoneInfo::None), - format: None, - }) + self.handle_timestamp::<TimestampNanosecondType>(v, tz) } ScalarValue::TimestampNanosecond(None, _) => { Ok(ast::Expr::Value(ast::Value::Null)) } - ScalarValue::IntervalYearMonth(Some(_i)) => { - not_impl_err!("Unsupported scalar: {v:?}") - } - ScalarValue::IntervalYearMonth(None) => { - Ok(ast::Expr::Value(ast::Value::Null)) - } - ScalarValue::IntervalDayTime(Some(_i)) => { - not_impl_err!("Unsupported scalar: {v:?}") - } - ScalarValue::IntervalDayTime(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::IntervalMonthDayNano(Some(_i)) => { + ScalarValue::IntervalYearMonth(Some(_)) + | ScalarValue::IntervalDayTime(Some(_)) + | ScalarValue::IntervalMonthDayNano(Some(_)) => { let wrap_array = v.to_array()?; let Some(result) = array_value_to_string(&wrap_array, 0).ok() else { return internal_err!( - "Unable to convert IntervalMonthDayNano to string" + "Unable to convert interval scalar value to string" ); }; let interval = Interval { @@ -882,6 +922,10 @@ impl Unparser<'_> { }; Ok(ast::Expr::Interval(interval)) } + ScalarValue::IntervalYearMonth(None) => { + Ok(ast::Expr::Value(ast::Value::Null)) + } + ScalarValue::IntervalDayTime(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::IntervalMonthDayNano(None) => { Ok(ast::Expr::Value(ast::Value::Null)) } @@ -990,11 +1034,18 @@ impl Unparser<'_> { DataType::Dictionary(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } - DataType::Decimal128(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") - } - DataType::Decimal256(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + DataType::Decimal128(precision, scale) + | DataType::Decimal256(precision, scale) => { + let mut new_precision = *precision as u64; + let mut new_scale = *scale as u64; + if *scale < 0 { + new_precision = (*precision as i16 - *scale as i16) as u64; + new_scale = 0 + } + + Ok(ast::DataType::Decimal( + ast::ExactNumberInfo::PrecisionAndScale(new_precision, new_scale), + )) } DataType::Map(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") @@ -1013,12 +1064,13 @@ mod tests { use arrow::datatypes::{Field, Schema}; use arrow_schema::DataType::Int8; + use datafusion_common::TableReference; use datafusion_expr::{ - case, col, cube, exists, grouping_set, lit, not, not_exists, out_ref_col, - placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue, - ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, - WindowFunctionDefinition, + case, col, cube, exists, grouping_set, interval_datetime_lit, + interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, + table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{interval_month_day_nano_lit, AggregateExt}; use datafusion_functions_aggregate::count::count_udaf; @@ -1180,6 +1232,39 @@ mod tests { Expr::Literal(ScalarValue::Date32(Some(-1))), r#"CAST('1969-12-31' AS DATE)"#, ), + ( + Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)), + r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampSecond(
Some(10001), + Some("+08:00".into()), + )), + r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)), + r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampMillisecond( + Some(10001), + Some("+08:00".into()), + )), + r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), + r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampMicrosecond( + Some(10001), + Some("+08:00".into()), + )), + r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#, + ), ( Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)), r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#, @@ -1191,6 +1276,22 @@ mod tests { )), r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#, ), + ( + Expr::Literal(ScalarValue::Time32Second(Some(10001))), + r#"CAST('02:46:41' AS TIME)"#, + ), + ( + Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))), + r#"CAST('00:00:10.001' AS TIME)"#, + ), + ( + Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))), + r#"CAST('00:00:00.010001' AS TIME)"#, + ), + ( + Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))), + r#"CAST('00:00:00.000010001' AS TIME)"#, + ), (sum(col("a")), r#"sum(a)"#), ( count_udaf() @@ -1198,7 +1299,7 @@ mod tests { .distinct() .build() .unwrap(), - "COUNT(DISTINCT *)", + "count(DISTINCT *)", ), ( count_udaf() @@ -1206,7 +1307,7 @@ mod tests { .filter(lit(true)) .build() .unwrap(), - "COUNT(*) FILTER (WHERE true)", + "count(*) FILTER (WHERE true)", ), ( Expr::WindowFunction(WindowFunction { @@ -1242,7 +1343,7 @@ mod tests { ), null_treatment: None, }), - r#"COUNT(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, + r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, ), (col("a").is_not_null(), r#"a IS NOT NULL"#), (col("a").is_null(), r#"a IS NULL"#), @@ -1345,6 +1446,45 @@ mod tests { .sub(interval_month_day_nano_lit("1 DAY")), r#"(INTERVAL '0 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.000000000 SECS' - INTERVAL '0 YEARS 0 MONS 1 DAYS 0 HOURS 0 MINS 0.000000000 SECS')"#, ), + ( + interval_datetime_lit("10 DAY 1 HOUR 10 MINUTE 20 SECOND"), + r#"INTERVAL '0 YEARS 0 MONS 10 DAYS 1 HOURS 10 MINS 20.000 SECS'"#, + ), + ( + interval_datetime_lit("10 DAY 1.5 HOUR 10 MINUTE 20 SECOND"), + r#"INTERVAL '0 YEARS 0 MONS 10 DAYS 1 HOURS 40 MINS 20.000 SECS'"#, + ), + ( + interval_year_month_lit("1 YEAR 1 MONTH"), + r#"INTERVAL '1 YEARS 1 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#, + ), + ( + interval_year_month_lit("1.5 YEAR 1 MONTH"), + r#"INTERVAL '1 YEARS 7 MONS 0 DAYS 0 HOURS 0 MINS 0.00 SECS'"#, + ), + ( + (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( + Some(100123), + 28, + 3, + ))), + r#"((a + b) > 100.123)"#, + ), + ( + (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( + Some(100123.into()), + 28, + 3, + ))), + r#"((a + b) > 100.123)"#, + ), + ( + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Decimal128(10, -2), + }), + r#"CAST(a AS DECIMAL(12,0))"#, + ), ]; for (expr, expected) in tests { diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index fb0285901c3f..fbbed4972b17 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -18,6 +18,7 @@ mod ast; mod expr; mod plan; +mod rewrite; mod utils; 
pub use expr::expr_to_sql; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index a4a457f51dc9..15137403c582 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -28,6 +28,7 @@ use super::{ BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, + rewrite::normalize_union_schema, utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, Unparser, }; @@ -63,6 +64,8 @@ pub fn plan_to_sql(plan: &LogicalPlan) -> Result<ast::Statement> { impl Unparser<'_> { pub fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<ast::Statement> { + let plan = normalize_union_schema(plan)?; + match plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ -80,8 +83,8 @@ impl Unparser<'_> { | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) - | LogicalPlan::Distinct(_) => self.select_to_sql_statement(plan), - LogicalPlan::Dml(_) => self.dml_to_sql(plan), + | LogicalPlan::Distinct(_) => self.select_to_sql_statement(&plan), + LogicalPlan::Dml(_) => self.dml_to_sql(&plan), LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::Extension(_) diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs new file mode 100644 index 000000000000..a73fce30ced3 --- /dev/null +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_common::{ + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator}, + Result, +}; +use datafusion_expr::{Expr, LogicalPlan, Sort}; + +/// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. +/// +/// DataFusion will return an error if two columns in the schema have the same name with no table qualifiers. +/// There are certain types of UNION queries that can result in having two columns with the same name, and the +/// solution was to add table qualifiers to the schema fields. +/// See for more context on this decision. +/// +/// However, this causes a problem when unparsing these queries back to SQL - as the table qualifier has +/// logically been erased and is no longer a valid reference. +/// +/// The following input SQL: +/// ```sql +/// SELECT table1.foo FROM table1 +/// UNION ALL +/// SELECT table2.foo FROM table2 +/// ORDER BY foo +/// ``` +/// +/// Would be unparsed into the following invalid SQL without this transformation: +/// ```sql +/// SELECT table1.foo FROM table1 +/// UNION ALL +/// SELECT table2.foo FROM table2 +/// ORDER BY table1.foo +/// ``` +/// +/// Which would result in a SQL error, as `table1.foo` is not a valid reference in the context of the UNION.
+pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result<LogicalPlan> { + let plan = plan.clone(); + + let transformed_plan = plan.transform_up(|plan| match plan { + LogicalPlan::Union(mut union) => { + let schema = match Arc::try_unwrap(union.schema) { + Ok(inner) => inner, + Err(schema) => (*schema).clone(), + }; + let schema = schema.strip_qualifiers(); + + union.schema = Arc::new(schema); + Ok(Transformed::yes(LogicalPlan::Union(union))) + } + LogicalPlan::Sort(sort) => { + // Only rewrite Sort expressions that have a UNION as their input + if !matches!(&*sort.input, LogicalPlan::Union(_)) { + return Ok(Transformed::no(LogicalPlan::Sort(sort))); + } + + Ok(Transformed::yes(LogicalPlan::Sort(Sort { + expr: rewrite_sort_expr_for_union(sort.expr)?, + input: sort.input, + fetch: sort.fetch, + }))) + } + _ => Ok(Transformed::no(plan)), + }); + transformed_plan.data() +} + +/// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. +fn rewrite_sort_expr_for_union(exprs: Vec<Expr>) -> Result<Vec<Expr>> { + let sort_exprs: Vec<Expr> = exprs + .into_iter() + .map_until_stop_and_collect(|expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } + }) + }) + .data()?; + + Ok(sort_exprs) +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 33e28e7056b9..374403d853f9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -230,6 +230,16 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + TestStatementWithDialect { + sql: "SELECT j1_id FROM j1 + UNION ALL + SELECT tb.j2_id as j1_id FROM j2 tb + ORDER BY j1_id + LIMIT 10;", + expected: r#"SELECT j1.j1_id FROM j1 UNION ALL SELECT tb.j2_id AS j1_id FROM j2 AS tb ORDER BY j1_id ASC NULLS LAST LIMIT 10"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, ]; for query in tests { @@ -239,7 +249,9 @@ fn roundtrip_statement_with_dialect() -> Result<()> { let context = MockContextProvider::default(); let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + let plan = sql_to_rel + .sql_statement_to_plan(statement) + .unwrap_or_else(|e| panic!("Failed to parse sql: {}\n{e}", query.sql)); let unparser = Unparser::new(&*query.unparser_dialect); let roundtrip_statement = unparser.plan_to_sql(&plan)?; diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index 893678d6b374..f5caaefb3ea0 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -15,16 +15,39 @@ // specific language governing permissions and limitations // under the License.
+use std::any::Any; #[cfg(test)] use std::collections::HashMap; +use std::fmt::Display; use std::{sync::Arc, vec}; use arrow_schema::*; use datafusion_common::config::ConfigOptions; -use datafusion_common::{plan_err, Result, TableReference}; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::{plan_err, GetExt, Result, TableReference}; use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; use datafusion_sql::planner::ContextProvider; +struct MockCsvType {} + +impl GetExt for MockCsvType { + fn get_ext(&self) -> String { + "csv".to_string() + } +} + +impl FileType for MockCsvType { + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Display for MockCsvType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.get_ext()) + } +} + #[derive(Default)] pub(crate) struct MockContextProvider { options: ConfigOptions, @@ -191,6 +214,13 @@ impl ContextProvider for MockContextProvider { &self.options } + fn get_file_type( + &self, + _ext: &str, + ) -> Result<Arc<dyn FileType>> { + Ok(Arc::new(MockCsvType {})) + } + fn create_cte_work_table( &self, _name: &str, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 8eb2a2b609e7..e72a439b323b 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -37,6 +37,7 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; +use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::{ approx_median::approx_median_udaf, count::count_udaf, }; @@ -83,6 +84,7 @@ fn parse_decimals() { ParserOptions { parse_float_as_decimal: true, enable_ident_normalization: false, + support_varchar_with_length: false, }, ); } @@ -136,6 +138,7 @@ fn parse_ident_normalization() { ParserOptions { parse_float_as_decimal: false, enable_ident_normalization, + support_varchar_with_length: false, }, ); if plan.is_ok() { @@ -996,12 +999,12 @@ fn select_aggregate_with_having_with_aggregate_not_in_select() { #[test] fn select_aggregate_with_having_referencing_column_not_in_select() { - let sql = "SELECT COUNT(*) + let sql = "SELECT count(*) FROM person HAVING first_name = 'M'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: COUNT(*)", + "Error during planning: HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: count(*)", err.strip_backtrace() ); } @@ -1200,10 +1203,10 @@ fn select_aggregate_with_group_by_with_having_using_count_star_not_in_select() { let sql = "SELECT first_name, MAX(age) FROM person GROUP BY first_name - HAVING MAX(age) > 100 AND COUNT(*) < 50"; + HAVING MAX(age) > 100 AND count(*) < 50"; let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(100) AND COUNT(*) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), COUNT(*)]]\ + \n Filter: MAX(person.age) > Int64(100) AND count(*) < Int64(50)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1464,15 +1467,15 @@ fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() #[test] fn select_simple_aggregate_with_groupby_can_use_positions() { quick_test( - "SELECT state, age AS b, COUNT(1) FROM
person GROUP BY 1, 2", - "Projection: person.state, person.age AS b, COUNT(Int64(1))\ - \n Aggregate: groupBy=[[person.state, person.age]], aggr=[[COUNT(Int64(1))]]\ + "SELECT state, age AS b, count(1) FROM person GROUP BY 1, 2", + "Projection: person.state, person.age AS b, count(Int64(1))\ + \n Aggregate: groupBy=[[person.state, person.age]], aggr=[[count(Int64(1))]]\ \n TableScan: person", ); quick_test( - "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 2, 1", - "Projection: person.state, person.age AS b, COUNT(Int64(1))\ - \n Aggregate: groupBy=[[person.age, person.state]], aggr=[[COUNT(Int64(1))]]\ + "SELECT state, age AS b, count(1) FROM person GROUP BY 2, 1", + "Projection: person.state, person.age AS b, count(Int64(1))\ + \n Aggregate: groupBy=[[person.age, person.state]], aggr=[[count(Int64(1))]]\ \n TableScan: person", ); } @@ -1633,18 +1636,18 @@ fn test_wildcard() { #[test] fn select_count_one() { - let sql = "SELECT COUNT(1) FROM person"; - let expected = "Projection: COUNT(Int64(1))\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + let sql = "SELECT count(1) FROM person"; + let expected = "Projection: count(Int64(1))\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] fn select_count_column() { - let sql = "SELECT COUNT(id) FROM person"; - let expected = "Projection: COUNT(person.id)\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(person.id)]]\ + let sql = "SELECT count(id) FROM person"; + let expected = "Projection: count(person.id)\ + \n Aggregate: groupBy=[[]], aggr=[[count(person.id)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1814,9 +1817,9 @@ fn select_group_by_columns_not_in_select() { #[test] fn select_group_by_count_star() { - let sql = "SELECT state, COUNT(*) FROM person GROUP BY state"; - let expected = "Projection: person.state, COUNT(*)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(*)]]\ + let sql = "SELECT state, count(*) FROM person GROUP BY state"; + let expected = "Projection: person.state, count(*)\ + \n Aggregate: groupBy=[[person.state]], aggr=[[count(*)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -1824,10 +1827,10 @@ fn select_group_by_count_star() { #[test] fn select_group_by_needs_projection() { - let sql = "SELECT COUNT(state), state FROM person GROUP BY state"; + let sql = "SELECT count(state), state FROM person GROUP BY state"; let expected = "\ - Projection: COUNT(person.state), person.state\ - \n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(person.state)]]\ + Projection: count(person.state), person.state\ + \n Aggregate: groupBy=[[person.state]], aggr=[[count(person.state)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -2309,10 +2312,10 @@ fn empty_over_plus() { #[test] fn empty_over_multiple() { - let sql = "SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders"; + let sql = "SELECT order_id, MAX(qty) OVER (), min(qty) over (), avg(qty) OVER () from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, AVG(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, AVG(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: 
orders.order_id, MAX(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, avg(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[MAX(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, avg(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2627,8 +2630,8 @@ fn select_groupby_orderby() { // expect that this is not an ambiguous reference let expected = "Sort: birth_date ASC NULLS LAST\ - \n Projection: AVG(person.age) AS value, date_trunc(Utf8(\"month\"), person.birth_date) AS birth_date\ - \n Aggregate: groupBy=[[person.birth_date]], aggr=[[AVG(person.age)]]\ + \n Projection: avg(person.age) AS value, date_trunc(Utf8(\"month\"), person.birth_date) AS birth_date\ + \n Aggregate: groupBy=[[person.birth_date]], aggr=[[avg(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -2705,7 +2708,8 @@ fn logical_plan_with_dialect_and_options( .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64)) .with_udaf(sum_udaf()) .with_udaf(approx_median_udaf()) - .with_udaf(count_udaf()); + .with_udaf(count_udaf()) + .with_udaf(avg_udaf()); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); @@ -3000,8 +3004,8 @@ fn scalar_subquery_reference_outer_field() { let expected = "Projection: j1.j1_string, j2.j2_string\ \n Filter: j1.j1_id = j2.j2_id - Int64(1) AND j2.j2_id < ()\ \n Subquery:\ - \n Projection: COUNT(*)\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(*)]]\ + \n Projection: count(*)\ + \n Aggregate: groupBy=[[]], aggr=[[count(*)]]\ \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ \n CrossJoin:\ \n TableScan: j1\ @@ -3098,19 +3102,19 @@ fn cte_unbalanced_number_of_columns() { #[test] fn aggregate_with_rollup() { let sql = - "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\ + "SELECT id, state, age, count(*) FROM person GROUP BY id, ROLLUP (state, age)"; + let expected = "Projection: person.id, person.state, person.age, count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] fn aggregate_with_rollup_with_grouping() { - let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \ + let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), count(*) \ FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(*)]]\ + let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), 
count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3140,9 +3144,9 @@ fn rank_partition_grouping() { #[test] fn aggregate_with_cube() { let sql = - "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)"; - let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\ + "SELECT id, state, age, count(*) FROM person GROUP BY id, CUBE (state, age)"; + let expected = "Projection: person.id, person.state, person.age, count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3157,9 +3161,9 @@ fn round_decimal() { #[test] fn aggregate_with_grouping_sets() { - let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; - let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(*)]]\ + let sql = "SELECT id, state, age, count(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; + let expected = "Projection: person.id, person.state, person.age, count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 378cab206240..552ad6a2a756 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1528,154 +1528,154 @@ e e 1323 query TTI SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY CUBE (a.c1, b.c1) ORDER BY a.c1, b.c1 ---- -a a 1260 -a b 1140 -a c 1260 -a d 1080 -a e 1260 -a NULL 6000 -b a 1302 -b b 1178 -b c 1302 -b d 1116 -b e 1302 -b NULL 6200 -c a 1176 -c b 1064 -c c 1176 -c d 1008 -c e 1176 -c NULL 5600 -d a 924 -d b 836 -d c 924 -d d 792 -d e 924 -d NULL 4400 -e a 1323 -e b 1197 -e c 1323 -e d 1134 -e e 1323 -e NULL 6300 -NULL a 5985 -NULL b 5415 -NULL c 5985 -NULL d 5130 -NULL e 5985 +a a 1260 +a b 1140 +a c 1260 +a d 1080 +a e 1260 +a NULL 6000 +b a 1302 +b b 1178 +b c 1302 +b d 1116 +b e 1302 +b NULL 6200 +c a 1176 +c b 1064 +c c 1176 +c d 1008 +c e 1176 +c NULL 5600 +d a 924 +d b 836 +d c 924 +d d 792 +d e 924 +d NULL 4400 +e a 1323 +e b 1197 +e c 1323 +e d 1134 +e e 1323 +e NULL 6300 +NULL a 5985 +NULL b 5415 +NULL c 5985 +NULL d 5130 +NULL e 5985 NULL NULL 28500 # csv_query_cube_distinct_count query TII SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY CUBE (c1,c2) ORDER BY c1,c2 ---- -a 1 5 -a 2 3 -a 3 5 -a 4 4 -a 5 3 -a NULL 19 -b 1 3 -b 2 4 -b 3 2 -b 4 5 -b 5 5 -b NULL 17 -c 1 4 -c 2 7 -c 3 4 -c 4 4 -c 5 2 -c NULL 21 -d 1 7 -d 2 3 -d 3 3 -d 4 3 -d 5 2 -d NULL 18 -e 1 3 -e 2 4 -e 3 4 -e 4 7 -e 5 2 -e NULL 18 -NULL 1 22 -NULL 2 20 -NULL 3 17 -NULL 4 
23 -NULL 5 14 +a 1 5 +a 2 3 +a 3 5 +a 4 4 +a 5 3 +a NULL 19 +b 1 3 +b 2 4 +b 3 2 +b 4 5 +b 5 5 +b NULL 17 +c 1 4 +c 2 7 +c 3 4 +c 4 4 +c 5 2 +c NULL 21 +d 1 7 +d 2 3 +d 3 3 +d 4 3 +d 5 2 +d NULL 18 +e 1 3 +e 2 4 +e 3 4 +e 4 7 +e 5 2 +e NULL 18 +NULL 1 22 +NULL 2 20 +NULL 3 17 +NULL 4 23 +NULL 5 14 NULL NULL 80 # csv_query_rollup_distinct_count query TII SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY ROLLUP (c1,c2) ORDER BY c1,c2 ---- -a 1 5 -a 2 3 -a 3 5 -a 4 4 -a 5 3 -a NULL 19 -b 1 3 -b 2 4 -b 3 2 -b 4 5 -b 5 5 -b NULL 17 -c 1 4 -c 2 7 -c 3 4 -c 4 4 -c 5 2 -c NULL 21 -d 1 7 -d 2 3 -d 3 3 -d 4 3 -d 5 2 -d NULL 18 -e 1 3 -e 2 4 -e 3 4 -e 4 7 -e 5 2 -e NULL 18 +a 1 5 +a 2 3 +a 3 5 +a 4 4 +a 5 3 +a NULL 19 +b 1 3 +b 2 4 +b 3 2 +b 4 5 +b 5 5 +b NULL 17 +c 1 4 +c 2 7 +c 3 4 +c 4 4 +c 5 2 +c NULL 21 +d 1 7 +d 2 3 +d 3 3 +d 4 3 +d 5 2 +d NULL 18 +e 1 3 +e 2 4 +e 3 4 +e 4 7 +e 5 2 +e NULL 18 NULL NULL 80 # csv_query_rollup_sum_crossjoin query TTI SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY ROLLUP (a.c1, b.c1) ORDER BY a.c1, b.c1 ---- -a a 1260 -a b 1140 -a c 1260 -a d 1080 -a e 1260 -a NULL 6000 -b a 1302 -b b 1178 -b c 1302 -b d 1116 -b e 1302 -b NULL 6200 -c a 1176 -c b 1064 -c c 1176 -c d 1008 -c e 1176 -c NULL 5600 -d a 924 -d b 836 -d c 924 -d d 792 -d e 924 -d NULL 4400 -e a 1323 -e b 1197 -e c 1323 -e d 1134 -e e 1323 -e NULL 6300 +a a 1260 +a b 1140 +a c 1260 +a d 1080 +a e 1260 +a NULL 6000 +b a 1302 +b b 1178 +b c 1302 +b d 1116 +b e 1302 +b NULL 6200 +c a 1176 +c b 1064 +c c 1176 +c d 1008 +c e 1176 +c NULL 5600 +d a 924 +d b 836 +d c 924 +d d 792 +d e 924 +d NULL 4400 +e a 1323 +e b 1197 +e c 1323 +e d 1134 +e e 1323 +e NULL 6300 NULL NULL 28500 # query_count_without_from @@ -1785,30 +1785,39 @@ select min(t), max(t) from (select '00:00:00' as t union select '00:00:01' unio ---- 00:00:00 00:00:02 -# aggregate_decimal_min -query RT -select min(c1), arrow_typeof(min(c1)) from d_table ----- --100.009 Decimal128(10, 3) - -# aggregate_decimal_max -query RT -select max(c1), arrow_typeof(max(c1)) from d_table +# aggregate Interval(MonthDayNano) min/max +query T?? +select + arrow_typeof(min(column1)), min(column1), max(column1) +from values + (interval '1 month'), + (interval '2 months'), + (interval '2 month 15 days'), + (interval '-2 month') ---- -110.009 Decimal128(10, 3) +Interval(MonthDayNano) 0 years -2 mons 0 days 0 hours 0 mins 0.000000000 secs 0 years 2 mons 15 days 0 hours 0 mins 0.000000000 secs -# aggregate_decimal_sum -query RT -select sum(c1), arrow_typeof(sum(c1)) from d_table +# aggregate Interval(DayTime) min/max +query T?? +select + arrow_typeof(min(column1)), min(column1), max(column1) +from values + (arrow_cast('60 minutes', 'Interval(DayTime)')), + (arrow_cast('-3 minutes', 'Interval(DayTime)')), + (arrow_cast('30 minutes', 'Interval(DayTime)')); ---- -100 Decimal128(20, 3) +Interval(DayTime) 0 years 0 mons 0 days 0 hours -3 mins 0.000 secs 0 years 0 mons 0 days 1 hours 0 mins 0.000 secs -# aggregate_decimal_avg -query RT -select avg(c1), arrow_typeof(avg(c1)) from d_table +# aggregate Interval(YearMonth) min/max +query T?? 
+select + arrow_typeof(min(column1)), min(column1), max(column1) +from values + (arrow_cast('-1 year', 'Interval(YearMonth)')), + (arrow_cast('13 months', 'Interval(YearMonth)')), + (arrow_cast('1 year', 'Interval(YearMonth)')); ---- -5 Decimal128(14, 7) - +Interval(YearMonth) -1 years 0 mons 0 days 0 hours 0 mins 0.00 secs 1 years 1 mons 0 days 0 hours 0 mins 0.00 secs # aggregate query II @@ -3681,10 +3690,10 @@ X 2 2 2 2 Y 1 1 1 1 # aggregate_timestamps_avg -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +query error SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +query error SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag ORDER BY tag; # aggregate_duration_array_agg @@ -3781,10 +3790,10 @@ Y 2021-01-01 2021-01-01T00:00:00 # aggregate_timestamps_avg -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Date32\)'\. You might need to add explicit type casts\. +query error SELECT avg(date32), avg(date64) FROM t -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Date32\)'\. You might need to add explicit type casts\. +query error SELECT tag, avg(date32), avg(date64) FROM t GROUP BY tag ORDER BY tag; @@ -3879,10 +3888,10 @@ B 21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 # aggregate_times_avg -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +query error SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. 
+query error SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag ORDER BY tag; statement ok @@ -4316,7 +4325,7 @@ select avg(distinct x_dict) from value_dict; ---- 3 -statement error DataFusion error: This feature is not implemented: AVG\(DISTINCT\) aggregations are not available +query error select avg(x_dict), avg(distinct x_dict) from value_dict; query I diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 55a430767c76..77d1a9da1f55 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -3769,6 +3769,54 @@ select array_to_string(make_array(), ',') ---- (empty) +# array to string dictionary +statement ok +CREATE TABLE table1 AS VALUES + (1, 'foo'), + (3, 'bar'), + (1, 'foo'), + (2, NULL), + (NULL, 'baz') + ; + +# expect 1-3-1-2 (dictionary values should be repeated) +query T +SELECT array_to_string(array_agg(column1),'-') +FROM ( + SELECT arrow_cast(column1, 'Dictionary(Int32, Int32)') as column1 + FROM table1 +); +---- +1-3-1-2 + +# expect foo,bar,foo,baz (dictionary values should be repeated) +query T +SELECT array_to_string(array_agg(column2),',') +FROM ( + SELECT arrow_cast(column2, 'Dictionary(Int64, Utf8)') as column2 + FROM table1 +); +---- +foo,bar,foo,baz + +# Expect only values that are in the group +query I?T +SELECT column1, array_agg(column2), array_to_string(array_agg(column2),',') +FROM ( + SELECT column1, arrow_cast(column2, 'Dictionary(Int32, Utf8)') as column2 + FROM table1 +) +GROUP BY column1 +ORDER BY column1; +---- +1 [foo, foo] foo,foo +2 [] (empty) +3 [bar] bar +NULL [baz] baz + +statement ok +drop table table1; + ## array_union (aliases: `list_union`) diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 24c99fc849b6..8fde295e6051 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -41,17 +41,68 @@ SELECT * FROM data; # Filtering ########### -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I rowsort SELECT * FROM data WHERE column1 = [1,2,3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) -SELECT * FROM data WHERE column1 = column2 - -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 != [1,2,3]; +---- +[2, 3] [2, 3] 1 -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != 
List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 != column2 +---- +[1, 2, 3] [4, 5] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 < [1,2,3,4]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 <= [2, 3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 > [1,2]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 >= [1, 2, 3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +# test with scalar null +query ??I +SELECT * FROM data WHERE column2 = null; +---- + +query ??I +SELECT * FROM data WHERE null = column2; +---- + +query ??I rowsort +SELECT * FROM data WHERE column2 is distinct from null; +---- +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I +SELECT * FROM data WHERE column2 is not distinct from null; +---- +[1, 2, 3] NULL 1 ########### # Aggregates @@ -158,3 +209,68 @@ SELECT * FROM data ORDER BY column1, column3, column2; statement ok drop table data + + +# test filter column with all nulls +statement ok +create table data (a int) as values (null), (null), (null); + +query I +select * from data where a = null; +---- + +query I +select * from data where a is not distinct from null; +---- +NULL +NULL +NULL + +statement ok +drop table data; + +statement ok +create table data (a int[][], b int) as values ([[1,2,3]], 1), ([[2,3], [4,5]], 2), (null, 3); + +query ?I +select * from data; +---- +[[1, 2, 3]] 1 +[[2, 3], [4, 5]] 2 +NULL 3 + +query ?I +select * from data where a = [[1,2,3]]; +---- +[[1, 2, 3]] 1 + +query ?I +select * from data where a > [[1,2,3]]; +---- +[[2, 3], [4, 5]] 2 + +query ?I +select * from data where a > [[1,2]]; +---- +[[1, 2, 3]] 1 +[[2, 3], [4, 5]] 2 + +query ?I +select * from data where a < [[2, 3]]; +---- +[[1, 2, 3]] 1 + +# compare with null with eq results in null +query ?I +select * from data where a = null; +---- + +query ?I +select * from data where a != null; +---- + +# compare with null with distinct results in true/false +query ?I +select * from data where a is not distinct from null; +---- +NULL 3 diff --git a/datafusion/sqllogictest/test_files/avro.slt b/datafusion/sqllogictest/test_files/avro.slt index fced1924ced9..f8ef81a8ba2b 100644 --- a/datafusion/sqllogictest/test_files/avro.slt +++ b/datafusion/sqllogictest/test_files/avro.slt @@ -243,11 +243,11 @@ query TT EXPLAIN SELECT count(*) from alltypes_plain ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--TableScan: alltypes_plain projection=[] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +01)AggregateExec: mode=Final, gby=[], aggr=[count(*)] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +03)----AggregateExec: mode=Partial, gby=[], aggr=[count(*)] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]} diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index 8902b3eebf24..a8a689cbb8b5 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -226,3 +226,70 @@ SELECT * from 
stored_table_with_comments; ---- column1 column2 2 3 + +# read csv with double quote +statement ok +CREATE EXTERNAL TABLE csv_with_double_quote ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +OPTIONS ('format.delimiter' ',', + 'format.has_header' 'true', + 'format.double_quote' 'true') +LOCATION '../core/tests/data/double_quote.csv'; + +query TT +select * from csv_with_double_quote +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3" + +# ensure that double quote option is used when writing to csv +query TT +COPY csv_with_double_quote TO 'test_files/scratch/csv_files/table_with_double_quotes.csv' +STORED AS csv +OPTIONS ('format.double_quote' 'true'); +---- +4 + +statement ok +CREATE EXTERNAL TABLE stored_table_with_double_quotes ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION 'test_files/scratch/csv_files/table_with_double_quotes.csv' +OPTIONS ('format.double_quote' 'true'); + +query TT +select * from stored_table_with_double_quotes; +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3" + +# ensure when double quote option is disabled that quotes are escaped instead +query TT +COPY csv_with_double_quote TO 'test_files/scratch/csv_files/table_with_escaped_quotes.csv' +STORED AS csv +OPTIONS ('format.double_quote' 'false', 'format.escape' '#'); +---- +4 + +statement ok +CREATE EXTERNAL TABLE stored_table_with_escaped_quotes ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION 'test_files/scratch/csv_files/table_with_escaped_quotes.csv' +OPTIONS ('format.double_quote' 'false', 'format.escape' '#'); + +query TT +select * from stored_table_with_escaped_quotes; +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3" diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index d51c69496d46..fa25f00974a9 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -108,7 +108,7 @@ query error select count(); # AggregateFunction with wrong number of arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Utf8, Float64\)'\. 
You might need to add explicit type casts\.\n\tCandidate functions:\n\tAVG\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +query error select avg(c1, c12) from aggregate_test_100; # AggregateFunction with wrong argument type diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 3c5f8c7f7ad6..96e73a591678 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -93,7 +93,7 @@ query TT EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3) ---- physical_plan -01)ProjectionExec: expr=[2 as COUNT(*)] +01)ProjectionExec: expr=[2 as count(*)] 02)--PlaceholderRowExec statement ok diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index d274d7d4390c..4e8f3b59a650 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2473,7 +2473,7 @@ host2 202 host3 303 # TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 -query error +query TR select t2.server['c3'] as host, sum(( @@ -2488,6 +2488,10 @@ select ) t2 where t2.server['c3'] IS NOT NULL group by t2.server['c3'] order by host; +---- +host1 101 +host2 202 +host3 303 # can have 2 projections with aggr(short_circuited), with different short-circuited expr query TRR @@ -2559,7 +2563,7 @@ host2 2.2 202 host3 3.3 303 # TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 -query error +query TRR select t2.server['c3'] as host, sum(( @@ -2579,6 +2583,10 @@ select ) t2 where t2.server['c3'] IS NOT NULL group by t2.server['c3'] order by host; +---- +host1 1.1 101 +host2 2.2 202 +host3 3.3 303 # can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. 
coalesce) query TRR @@ -2587,3 +2595,28 @@ select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesc host1 1.1 101 host2 2.2 202 host3 3.3 303 + +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +statement ok +create table t (a float) as values (1), (2), (3); + +query TT +explain select min(a) filter (where a > 1) as x from t; +---- +logical_plan +01)Projection: MIN(t.a) FILTER (WHERE t.a > Int64(1)) AS x +02)--Aggregate: groupBy=[[]], aggr=[[MIN(t.a) FILTER (WHERE t.a > Float32(1)) AS MIN(t.a) FILTER (WHERE t.a > Int64(1))]] +03)----TableScan: t projection=[a] +physical_plan +01)ProjectionExec: expr=[MIN(t.a) FILTER (WHERE t.a > Int64(1))@0 as x] +02)--AggregateExec: mode=Single, gby=[], aggr=[MIN(t.a) FILTER (WHERE t.a > Int64(1))] +03)----MemoryExec: partitions=1, partition_sizes=[1] + + +statement ok +drop table t; + +statement ok +set datafusion.sql_parser.dialect = 'Generic'; diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index fff3977fe1e6..04a1fcc78fe7 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -1962,9 +1962,9 @@ GROUP BY ALL; 2 0 13 query IIR rowsort -SELECT sub.col1, sub.col0, sub."AVG(tab3.col2)" AS avg_col2 +SELECT sub.col1, sub.col0, sub."avg(tab3.col2)" AS avg_col2 FROM ( - SELECT col1, AVG(col2), col0 FROM tab3 GROUP BY ALL + SELECT col1, avg(col2), col0 FROM tab3 GROUP BY ALL ) AS sub GROUP BY ALL; ---- @@ -4394,18 +4394,18 @@ EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM a ---- logical_plan 01)Sort: aggregate_test_100.c1 ASC NULLS LAST -02)--Projection: aggregate_test_100.c1, COUNT(alias1) AS COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), sum(alias2) AS sum(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4) -03)----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1), MIN(alias1), sum(alias2), MAX(alias3)]] +02)--Projection: aggregate_test_100.c1, count(alias1) AS count(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), sum(alias2) AS sum(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4) +03)----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[count(alias1), MIN(alias1), sum(alias2), MAX(alias3)]] 04)------Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[sum(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, MAX(aggregate_test_100.c4) AS alias3]] 05)--------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4] physical_plan 01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST] 02)--SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[c1@0 as c1, COUNT(alias1)@1 as COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1)@2 as MIN(DISTINCT aggregate_test_100.c2), sum(alias2)@3 as sum(aggregate_test_100.c3), MAX(alias3)@4 as MAX(aggregate_test_100.c4)] -04)------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), sum(alias2), MAX(alias3)] +03)----ProjectionExec: expr=[c1@0 as c1, count(alias1)@1 as count(DISTINCT aggregate_test_100.c2), MIN(alias1)@2 as MIN(DISTINCT aggregate_test_100.c2), sum(alias2)@3 as sum(aggregate_test_100.c3), MAX(alias3)@4 as MAX(aggregate_test_100.c4)] +04)------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[count(alias1), MIN(alias1), sum(alias2), MAX(alias3)] 05)--------CoalesceBatchesExec: 
target_batch_size=2 06)----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 -07)------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), sum(alias2), MAX(alias3)] +07)------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[count(alias1), MIN(alias1), sum(alias2), MAX(alias3)] 08)--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1 as alias1], aggr=[alias2, alias3] 09)----------------CoalesceBatchesExec: target_batch_size=2 10)------------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8), input_partitions=8 @@ -5109,15 +5109,15 @@ GROUP BY ts_chunk; ---- logical_plan -01)Projection: date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01")) AS ts_chunk, COUNT(keywords_stream.keyword) AS alert_keyword_count -02)--Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"), keywords_stream.ts, TimestampNanosecond(946684800000000000, None)) AS date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))]], aggr=[[COUNT(keywords_stream.keyword)]] +01)Projection: date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01")) AS ts_chunk, count(keywords_stream.keyword) AS alert_keyword_count +02)--Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"), keywords_stream.ts, TimestampNanosecond(946684800000000000, None)) AS date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))]], aggr=[[count(keywords_stream.keyword)]] 03)----LeftSemi Join: keywords_stream.keyword = __correlated_sq_1.keyword 04)------TableScan: keywords_stream projection=[ts, keyword] 05)------SubqueryAlias: __correlated_sq_1 06)--------TableScan: alert_keywords projection=[keyword] physical_plan -01)ProjectionExec: expr=[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))@0 as ts_chunk, COUNT(keywords_stream.keyword)@1 as alert_keyword_count] -02)--AggregateExec: mode=Single, gby=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }, ts@0, 946684800000000000) as date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))], aggr=[COUNT(keywords_stream.keyword)] +01)ProjectionExec: expr=[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))@0 as ts_chunk, count(keywords_stream.keyword)@1 as alert_keyword_count] +02)--AggregateExec: mode=Single, gby=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }, ts@0, 946684800000000000) as date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))], aggr=[count(keywords_stream.keyword)] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(keyword@0, keyword@1)] 05)--------MemoryExec: partitions=1, partition_sizes=[1] @@ -5135,3 +5135,22 @@ GROUP BY ts_chunk; ---- 
2024-01-01T00:00:00 4 + +# Issue: https://github.com/apache/datafusion/issues/11118 +statement ok +CREATE TABLE test_case_expr(a INT, b TEXT) AS VALUES (1,'hello'), (2,'world') + +query T +SELECT (CASE WHEN CONCAT(b, 'hello') = 'test' THEN 'good' ELSE 'bad' END) AS c + FROM test_case_expr GROUP BY c; +---- +bad + +query I rowsort +SELECT (CASE a::BIGINT WHEN 1 THEN 1 END) AS c FROM test_case_expr GROUP BY c; +---- +1 +NULL + +statement ok +drop table test_case_expr diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 6f31973fdb6f..3cc837aa8ee9 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -237,6 +237,7 @@ datafusion.optimizer.top_down_join_key_reordering true datafusion.sql_parser.dialect generic datafusion.sql_parser.enable_ident_normalization true datafusion.sql_parser.parse_float_as_decimal false +datafusion.sql_parser.support_varchar_with_length true # show all variables with verbose query TTT rowsort @@ -318,6 +319,7 @@ datafusion.optimizer.top_down_join_key_reordering true When set to true, the phy datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. datafusion.sql_parser.enable_ident_normalization true When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) datafusion.sql_parser.parse_float_as_decimal false When set to true, SQL parser will parse float as decimal type +datafusion.sql_parser.support_varchar_with_length true If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. 
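For context on the `support_varchar_with_length` option surfaced above, a minimal sketch of setting it through `ParserOptions` — illustrative only, built from the three public fields the updated sql_integration tests construct explicitly; the field set may differ in other releases:

```rust
use datafusion_sql::planner::ParserOptions;

fn main() {
    let options = ParserOptions {
        parse_float_as_decimal: false,
        enable_ident_normalization: true,
        // When true (the default), `VARCHAR(20)` is accepted but the length
        // is ignored; when false, a VARCHAR with a length is a planning error.
        support_varchar_with_length: false,
    };
    assert!(!options.support_varchar_with_length);
}
```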
# show_variable_in_config_options query TT diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index 1fa319111c45..9115cb532540 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -58,17 +58,17 @@ ORDER by c1 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST -04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 -05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 +05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 06)----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=MemoryTable (partitions=1) -02)--ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] +02)--ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as 
field2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 06)----------SortExec: expr=[c1@0 ASC NULLS 
LAST,c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 @@ -121,14 +121,14 @@ FROM aggregate_test_100 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 -03)----WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +03)----WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 04)------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=MemoryTable (partitions=1) 02)--CoalescePartitionsExec -03)----ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +03)----ProjectionExec: 
expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST], preserve_partitioning=[true] 06)----------CoalesceBatchesExec: target_batch_size=8192 07)------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 @@ -171,15 +171,15 @@ logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] 02)--Projection: a1 AS a1, a2 AS a2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST -04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 -05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 +05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 06)----------TableScan: aggregate_test_100 projection=[c1, c4, c9] 
physical_plan 01)DataSinkExec: sink=MemoryTable (partitions=8) 02)--ProjectionExec: expr=[a1@0 as a1, a2@1 as a2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 06)----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 9f930defbbf9..8f6bafd92e41 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ 
b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -347,17 +347,17 @@ ORDER by c1 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST -04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 -05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 +05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 06)----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=ParquetSink(file_groups=[]) -02)--ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] +02)--ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING 
AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 06)----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 @@ -411,14 +411,14 @@ FROM 
aggregate_test_100 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 -03)----WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +03)----WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 04)------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=ParquetSink(file_groups=[]) 02)--CoalescePartitionsExec -03)----ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +03)----ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY 
[aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 05)--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST], preserve_partitioning=[true] 06)----------CoalesceBatchesExec: target_batch_size=8192 07)------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 4b62f2561260..501ae497745b 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -1343,15 +1343,15 @@ from (select * from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id) group by t1_id ---- logical_plan -01)Projection: COUNT(*) -02)--Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Projection: count(*) +02)--Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[count(Int64(1)) AS count(*)]] 03)----Projection: join_t1.t1_id 04)------Inner Join: join_t1.t1_id = join_t2.t2_id 05)--------TableScan: join_t1 projection=[t1_id] 06)--------TableScan: join_t2 projection=[t2_id] physical_plan -01)ProjectionExec: expr=[COUNT(*)@1 as COUNT(*)] -02)--AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as t1_id], aggr=[COUNT(*)] +01)ProjectionExec: expr=[count(*)@1 as count(*)] +02)--AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as t1_id], aggr=[count(*)] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] 05)--------CoalesceBatchesExec: target_batch_size=2 @@ -1370,18 +1370,18 @@ from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id ---- logical_plan -01)Projection: COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id) -02)--Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] +01)Projection: count(alias1) AS count(DISTINCT join_t1.t1_id) +02)--Aggregate: groupBy=[[]], aggr=[[count(alias1)]] 03)----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]] 04)------Projection: join_t1.t1_id 05)--------Inner Join: join_t1.t1_id = join_t2.t2_id 06)----------TableScan: join_t1 projection=[t1_id] 07)----------TableScan: join_t2 projection=[t2_id] physical_plan -01)ProjectionExec: expr=[COUNT(alias1)@0 as COUNT(DISTINCT join_t1.t1_id)] -02)--AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)] +01)ProjectionExec: expr=[count(alias1)@0 as count(DISTINCT join_t1.t1_id)] +02)--AggregateExec: mode=Final, gby=[], 
aggr=[count(alias1)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)] +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(alias1)] 05)--------AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as alias1], aggr=[] 06)----------CoalesceBatchesExec: target_batch_size=2 07)------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] @@ -3781,3 +3781,33 @@ EXPLAIN SELECT * FROM ( ) AS a RIGHT ANTI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; ---- logical_plan EmptyRelation + +# FULL OUTER join with empty left and empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a FULL JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left ANTI join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a +) AS a LEFT ANTI JOIN (SELECT 1 AS a WHERE 1=0) as b ON a.a=b.a; +---- +logical_plan +01)SubqueryAlias: a +02)--Projection: Int64(1) AS a +03)----EmptyRelation + +# Right ANTI join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a RIGHT ANTI JOIN (SELECT 1 AS a) as b ON a.a=b.a; +---- +logical_plan +01)SubqueryAlias: b +02)--Projection: Int64(1) AS a +03)----EmptyRelation diff --git a/datafusion/sqllogictest/test_files/json.slt b/datafusion/sqllogictest/test_files/json.slt index 5d3c23d5130b..0b9508310b00 100644 --- a/datafusion/sqllogictest/test_files/json.slt +++ b/datafusion/sqllogictest/test_files/json.slt @@ -49,12 +49,12 @@ query TT EXPLAIN SELECT count(*) from json_test ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--TableScan: json_test projection=[] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +01)AggregateExec: mode=Final, gby=[], aggr=[count(*)] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +03)----AggregateExec: mode=Partial, gby=[], aggr=[count(*)] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/2.json]]} diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 2c65b1da4474..094017c383a6 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -307,11 +307,11 @@ query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Limit: skip=11, fetch=3 03)----TableScan: t1 projection=[], fetch=14 physical_plan -01)ProjectionExec: expr=[0 as COUNT(*)] +01)ProjectionExec: expr=[0 as count(*)] 02)--PlaceholderRowExec query I @@ -325,11 +325,11 @@ query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Limit: skip=8, fetch=3 03)----TableScan: t1 projection=[], fetch=11 physical_plan -01)ProjectionExec: expr=[2 as COUNT(*)] +01)ProjectionExec: expr=[2 as count(*)] 02)--PlaceholderRowExec query I @@ -343,11 +343,11 @@ query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8); ---- logical_plan -01)Aggregate: groupBy=[[]], 
aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Limit: skip=8, fetch=None 03)----TableScan: t1 projection=[] physical_plan -01)ProjectionExec: expr=[2 as COUNT(*)] +01)ProjectionExec: expr=[2 as count(*)] 02)--PlaceholderRowExec query I @@ -360,15 +360,15 @@ query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Projection: 03)----Limit: skip=6, fetch=3 04)------Filter: t1.a > Int32(3) 05)--------TableScan: t1 projection=[a] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +01)AggregateExec: mode=Final, gby=[], aggr=[count(*)] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +03)----AggregateExec: mode=Partial, gby=[], aggr=[count(*)] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------ProjectionExec: expr=[] 06)----------GlobalLimitExec: skip=6, fetch=3 diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 78efbb3f564b..573441ab4401 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -564,3 +564,121 @@ SELECT c1%0 FROM test_non_nullable_decimal statement ok drop table test_non_nullable_decimal + +statement ok +CREATE TABLE signed_integers( + a INT, + b INT, + c INT, + d INT, + e INT, + f INT +) as VALUES + (-1, 100, -567, 1024, -4, 10), + (2, -1000, 123, -256, 5, -11), + (-3, 10000, -978, 2048, -6, 12), + (4, NULL, NULL, -512, NULL, NULL) +; + + +## gcd + +# gcd scalar function +query IIIIII rowsort +select gcd(0, 0), gcd(2, 0), gcd(0, 2), gcd(2, 3), gcd(15, 10), gcd(20, 1000) +---- +0 2 2 1 5 20 + +# gcd with negative values +query IIIII +select gcd(-100, 0), gcd(0, -100), gcd(-2, 3), gcd(15, -10), gcd(-20, -1000) +---- +100 100 1 5 20 + +# gcd scalar nulls +query III +select gcd(null, 64), gcd(2, null), gcd(null, null); +---- +NULL NULL NULL + +# scalar maxes and/or negative 1 +query III +select + gcd(9223372036854775807, -9223372036854775808), -- i64::MAX, i64::MIN + gcd(9223372036854775807, -1), -- i64::MAX, -1 + gcd(-9223372036854775808, -1); -- i64::MIN, -1 +---- +1 1 1 + +# gcd with columns and expressions +query II rowsort +select gcd(a, b), gcd(c*d + 1, abs(e)) + f from signed_integers; +---- +1 11 +1 13 +2 -10 +NULL NULL + +# gcd(i64::MIN, i64::MIN) +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in GCD\(\-9223372036854775808, \-9223372036854775808\) +select gcd(-9223372036854775808, -9223372036854775808); + +# gcd(i64::MIN, 0) +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in GCD\(\-9223372036854775808, 0\) +select gcd(-9223372036854775808, 0); + +# gcd(0, i64::MIN) +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in GCD\(0, \-9223372036854775808\) +select gcd(0, -9223372036854775808); + + +## lcm + +# Basic cases +query IIIIII +select lcm(0, 0), lcm(0, 2), lcm(3, 0), lcm(2, 3), lcm(15, 10), lcm(20, 1000) +---- +0 0 0 6 30 1000 + +# Test lcm with negative numbers +query IIIII +select lcm(0, -2), lcm(-3, 0), lcm(-2, 3), lcm(15, -10), lcm(-15, -10) +---- +0 0 6 30 30 + +# Test lcm with Nulls +query III +select lcm(null, 64), lcm(16, null), lcm(null, null) +---- +NULL NULL NULL + +# Test lcm with columns +query III rowsort +select 
lcm(a, b), lcm(c, d), lcm(e, f) from signed_integers; +---- +100 580608 20 +1000 31488 55 +30000 1001472 12 +NULL NULL NULL + +# Result cannot fit in i64 +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in LCM\(\-9223372036854775808, \-9223372036854775808\) +select lcm(-9223372036854775808, -9223372036854775808); + +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in LCM\(1, \-9223372036854775808\) +select lcm(1, -9223372036854775808); + +# Overflow on multiplication +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in LCM\(2, 9223372036854775803\) +select lcm(2, 9223372036854775803); + + +query error DataFusion error: Arrow error: Compute error: Overflow happened on: 2107754225 \^ 1221660777 +select power(2107754225, 1221660777); + +# factorial overflow +query error DataFusion error: Arrow error: Compute error: Overflow happened on FACTORIAL\(350943270\) +select FACTORIAL(350943270); + +statement ok +drop table signed_integers diff --git a/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt b/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt index f578b08482ac..de6a153f58d9 100644 --- a/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt +++ b/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt @@ -48,8 +48,8 @@ FROM test_table t GROUP BY 1, 2, 3, 4 ---- logical_plan -01)Projection: t.c1, Int64(99999), t.c5 + t.c8, Utf8("test"), COUNT(Int64(1)) -02)--Aggregate: groupBy=[[t.c1, t.c5 + t.c8]], aggr=[[COUNT(Int64(1))]] +01)Projection: t.c1, Int64(99999), t.c5 + t.c8, Utf8("test"), count(Int64(1)) +02)--Aggregate: groupBy=[[t.c1, t.c5 + t.c8]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: t 04)------TableScan: test_table projection=[c1, c5, c8] @@ -60,8 +60,8 @@ FROM test_table t group by 1, 2, 3 ---- logical_plan -01)Projection: Int64(123), Int64(456), Int64(789), COUNT(Int64(1)), AVG(t.c12) -02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)), AVG(t.c12)]] +01)Projection: Int64(123), Int64(456), Int64(789), count(Int64(1)), avg(t.c12) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1)), avg(t.c12)]] 03)----SubqueryAlias: t 04)------TableScan: test_table projection=[c12] @@ -72,8 +72,8 @@ FROM test_table t GROUP BY 1, 2 ---- logical_plan -01)Projection: Date32("2023-05-04") AS dt, Boolean(true) AS today_filter, COUNT(Int64(1)) -02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +01)Projection: Date32("2023-05-04") AS dt, Boolean(true) AS today_filter, count(Int64(1)) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: t 04)------TableScan: test_table projection=[] @@ -90,8 +90,8 @@ FROM test_table t GROUP BY 1 ---- logical_plan -01)Projection: Boolean(true) AS NOT date_part(Utf8("MONTH"),now()) BETWEEN Int64(50) AND Int64(60), COUNT(Int64(1)) -02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +01)Projection: Boolean(true) AS NOT date_part(Utf8("MONTH"),now()) BETWEEN Int64(50) AND Int64(60), count(Int64(1)) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 03)----SubqueryAlias: t 04)------TableScan: test_table projection=[] diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index ac0dc3018879..ffaae7204eca 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -748,7 +748,7 @@ OR GROUP BY p_partkey; ---- logical_plan -01)Aggregate: groupBy=[[part.p_partkey]], 
aggr=[[sum(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(DISTINCT partsupp.ps_suppkey)]] +01)Aggregate: groupBy=[[part.p_partkey]], aggr=[[sum(lineitem.l_extendedprice), avg(lineitem.l_discount), count(DISTINCT partsupp.ps_suppkey)]] 02)--Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_partkey, partsupp.ps_suppkey 03)----Inner Join: part.p_partkey = partsupp.ps_partkey 04)------Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_partkey @@ -759,7 +759,7 @@ logical_plan 09)--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23")] 10)------TableScan: partsupp projection=[ps_partkey, ps_suppkey] physical_plan -01)AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[sum(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(DISTINCT partsupp.ps_suppkey)] +01)AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[sum(lineitem.l_extendedprice), avg(lineitem.l_discount), count(DISTINCT partsupp.ps_suppkey)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, ps_partkey@0)], projection=[l_extendedprice@0, l_discount@1, p_partkey@2, ps_suppkey@4] 04)------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index a45ce3718bc4..fed7ac31712c 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -322,3 +322,84 @@ true statement ok drop table t; + +statement ok +create or replace table strings as values + ('FooBar'), + ('Foo'), + ('Foo'), + ('Bar'), + ('FooBar'), + ('Bar'), + ('Baz'); + +statement ok +create or replace table dict_table as +select arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1 +from strings; + +query ? +select column1 from dict_table where column1 LIKE '%oo%'; +---- +FooBar +Foo +Foo +FooBar + +query ? +select column1 from dict_table where column1 NOT LIKE '%oo%'; +---- +Bar +Bar +Baz + +query ? +select column1 from dict_table where column1 ILIKE '%oO%'; +---- +FooBar +Foo +Foo +FooBar + +query ? 
+select column1 from dict_table where column1 NOT ILIKE '%oO%'; +---- +Bar +Bar +Baz + + +# plan should not cast the column, instead it should use the dictionary directly +query TT +explain select column1 from dict_table where column1 LIKE '%oo%'; +---- +logical_plan +01)Filter: dict_table.column1 LIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column1@0 LIKE %oo% +03)----MemoryExec: partitions=1, partition_sizes=[1] + +# Ensure casting / coercion works for all operators +# (there should be no casts to Utf8) +query TT +explain select + column1 LIKE '%oo%', + column1 NOT LIKE '%oo%', + column1 ILIKE '%oo%', + column1 NOT ILIKE '%oo%' +from dict_table; +---- +logical_plan +01)Projection: dict_table.column1 LIKE Utf8("%oo%"), dict_table.column1 NOT LIKE Utf8("%oo%"), dict_table.column1 ILIKE Utf8("%oo%"), dict_table.column1 NOT ILIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 LIKE %oo% as dict_table.column1 LIKE Utf8("%oo%"), column1@0 NOT LIKE %oo% as dict_table.column1 NOT LIKE Utf8("%oo%"), column1@0 ILIKE %oo% as dict_table.column1 ILIKE Utf8("%oo%"), column1@0 NOT ILIKE %oo% as dict_table.column1 NOT ILIKE Utf8("%oo%")] +02)--MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +drop table strings + +statement ok +drop table dict_table diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 551c50e0a17b..85ac5b0c242d 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -448,86 +448,6 @@ select floor(a), floor(b), floor(c) from signed_integers; 2 -1000 123 4 NULL NULL -## gcd - -# gcd scalar function -query III rowsort -select gcd(0, 0), gcd(2, 3), gcd(15, 10); ----- -0 1 5 - -# gcd scalar nulls -query I rowsort -select gcd(null, 64); ----- -NULL - -# gcd scalar nulls #1 -query I rowsort -select gcd(2, null); ----- -NULL - -# gcd scalar nulls #2 -query I rowsort -select gcd(null, null); ----- -NULL - -# scalar maxes and/or negative 1 -query III rowsort -select - gcd(9223372036854775807, -9223372036854775808), -- i64::MIN, i64::MAX - -- wait till fix, cause it fails gcd(-9223372036854775808, -9223372036854775808), -- -i64::MIN, i64::MIN - gcd(9223372036854775807, -1), -- i64::MAX, -1 - gcd(-9223372036854775808, -1); -- i64::MIN, -1 ----- -1 1 1 - -# gcd with columns -query III rowsort -select gcd(a, b), gcd(c, d), gcd(e, f) from signed_integers; ----- -1 1 2 -1 2 6 -2 1 1 -NULL NULL NULL - -## lcm - -# lcm scalar function -query III rowsort -select lcm(0, 0), lcm(2, 3), lcm(15, 10); ----- -0 6 30 - -# lcm scalar nulls -query I rowsort -select lcm(null, 64); ----- -NULL - -# lcm scalar nulls #1 -query I rowsort -select lcm(2, null); ----- -NULL - -# lcm scalar nulls #2 -query I rowsort -select lcm(null, null); ----- -NULL - -# lcm with columns -query III rowsort -select lcm(a, b), lcm(c, d), lcm(e, f) from signed_integers; ----- -100 580608 20 -1000 31488 55 -30000 1001472 12 -NULL NULL NULL - ## ln # ln scalar function @@ -858,6 +778,16 @@ select round(a), round(b), round(c) from small_floats; 0 0 1 1 0 0 +# round with too large +# max Int32 is 2147483647 +query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483648 to type Int32 +select round(3.14, 2147483648); + +# with array +query error DataFusion error: Execution error: Invalid 
values for decimal places: Cast error: Can't cast value 2147483649 to type Int32 +select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649); + + ## signum # signum scalar function diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index c8ef2b7f5e0b..f9baf8db69d5 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1383,16 +1383,16 @@ ORDER BY c1, c2) GROUP BY c2; ---- logical_plan -01)Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Projection: aggregate_test_100.c2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST 04)------Projection: aggregate_test_100.c2, aggregate_test_100.c1 05)--------TableScan: aggregate_test_100 projection=[c1, c2] physical_plan -01)AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)] +01)AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[count(*)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----RepartitionExec: partitioning=Hash([c2@0], 2), input_partitions=2 -04)------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[COUNT(*)] +04)------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[count(*)] 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true diff --git a/datafusion/sqllogictest/test_files/strings.slt b/datafusion/sqllogictest/test_files/strings.slt index 27ed0e2d0983..3cd6c339b44f 100644 --- a/datafusion/sqllogictest/test_files/strings.slt +++ b/datafusion/sqllogictest/test_files/strings.slt @@ -78,3 +78,52 @@ e1 p2 p2e1 p2m1e1 + +## VARCHAR with length support + +# Lengths can be used by default +query T +SELECT '12345'::VARCHAR(2); +---- +12345 + +# Lengths can not be used when the config setting is disabled + +statement ok +set datafusion.sql_parser.support_varchar_with_length = false; + +query error +SELECT '12345'::VARCHAR(2); + +query error +SELECT s::VARCHAR(2) FROM (VALUES ('12345')) t(s); + +statement ok +create table vals(s char) as values('abc'), ('def'); + +query error +SELECT s::VARCHAR(2) FROM vals + +# Lengths can be used when the config setting is enabled + +statement ok +set datafusion.sql_parser.support_varchar_with_length = true; + +query T +SELECT '12345'::VARCHAR(2) +---- +12345 + +query T +SELECT s::VARCHAR(2) FROM (VALUES ('12345')) t(s) +---- +12345 + +query T +SELECT s::VARCHAR(2) FROM vals +---- +abc +def + +statement ok +drop table vals; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index dbdb7fc76b8b..30b3631681e7 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -530,9 +530,9 @@ logical_plan 01)Projection: t1.t1_id, t1.t1_name 02)--Filter: EXISTS () 03)----Subquery: -04)------Projection: COUNT(*) +04)------Projection: count(*) 05)--------Filter: sum(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) -06)----------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*), sum(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +06)----------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*), sum(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] 07)------------Filter: outer_ref(t1.t1_name) 
= t2.t2_name 08)--------------TableScan: t2 09)----TableScan: t1 projection=[t1_id, t1_name, t1_int] @@ -638,10 +638,7 @@ SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id query TT explain SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) ---- -logical_plan -01)LeftAnti Join: t1.t1_id = __correlated_sq_1.t2_id -02)--TableScan: t1 projection=[t1_id, t1_name] -03)--EmptyRelation +logical_plan TableScan: t1 projection=[t1_id, t1_name] query IT rowsort SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) @@ -717,9 +714,9 @@ query TT explain select (select count(*) from t1) as b ---- logical_plan -01)Projection: __scalar_sq_1.COUNT(*) AS b +01)Projection: __scalar_sq_1.count(*) AS b 02)--SubqueryAlias: __scalar_sq_1 -03)----Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +03)----Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 04)------TableScan: t1 projection=[] #simple_uncorrelated_scalar_subquery2 @@ -727,13 +724,13 @@ query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.COUNT(*) AS b, __scalar_sq_2.COUNT(Int64(1)) AS COUNT(Int64(1)) +01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) 02)--Left Join: 03)----SubqueryAlias: __scalar_sq_1 -04)------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 05)--------TableScan: t1 projection=[] 06)----SubqueryAlias: __scalar_sq_2 -07)------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +07)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 08)--------TableScan: t2 projection=[] statement ok @@ -743,20 +740,20 @@ query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.COUNT(*) AS b, __scalar_sq_2.COUNT(Int64(1)) AS COUNT(Int64(1)) +01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) 02)--Left Join: 03)----SubqueryAlias: __scalar_sq_1 -04)------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 05)--------TableScan: t1 projection=[] 06)----SubqueryAlias: __scalar_sq_2 -07)------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +07)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 08)--------TableScan: t2 projection=[] physical_plan -01)ProjectionExec: expr=[COUNT(*)@0 as b, COUNT(Int64(1))@1 as COUNT(Int64(1))] +01)ProjectionExec: expr=[count(*)@0 as b, count(Int64(1))@1 as count(Int64(1))] 02)--NestedLoopJoinExec: join_type=Left -03)----ProjectionExec: expr=[4 as COUNT(*)] +03)----ProjectionExec: expr=[4 as count(*)] 04)------PlaceholderRowExec -05)----ProjectionExec: expr=[4 as COUNT(Int64(1))] +05)----ProjectionExec: expr=[4 as count(Int64(1))] 06)------PlaceholderRowExec statement ok @@ -772,12 +769,12 @@ query TT explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END AS COUNT(*) +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 
-05)------Projection: COUNT(*), t2.t2_int, Boolean(true) AS __always_true -06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*), t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -794,12 +791,12 @@ query TT explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END AS cnt +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS cnt 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*), t2.t2_int, Boolean(true) AS __always_true -06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*), t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -819,8 +816,8 @@ logical_plan 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*) AS _cnt, t2.t2_int, Boolean(true) AS __always_true -06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*) AS _cnt, t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -840,8 +837,8 @@ logical_plan 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*) + Int64(2) AS _cnt, t2.t2_int, Boolean(true) AS __always_true -06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*) + Int64(2) AS _cnt, t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -858,13 +855,13 @@ explain select t1.t1_int from t1 where (select count(*) from t2 where t1.t1_id = ---- logical_plan 01)Projection: t1.t1_int -02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END < CAST(t1.t1_int AS Int64) -03)----Projection: t1.t1_int, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END < CAST(t1.t1_int AS Int64) +03)----Projection: t1.t1_int, __scalar_sq_1.count(*), __scalar_sq_1.__always_true 04)------Left Join: t1.t1_id = __scalar_sq_1.t2_id 05)--------TableScan: t1 projection=[t1_id, t1_int] 06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: COUNT(*), t2.t2_id, Boolean(true) AS __always_true -08)------------Aggregate: groupBy=[[t2.t2_id]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +07)----------Projection: count(*), t2.t2_id, Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_id]], aggr=[[count(Int64(1)) AS count(*)]] 09)--------------TableScan: t2 projection=[t2_id] query I rowsort @@ -884,9 +881,9 @@ 
logical_plan 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*) + Int64(2) AS cnt_plus_2, t2.t2_int -06)--------Filter: COUNT(*) > Int64(1) -07)----------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*) + Int64(2) AS cnt_plus_2, t2.t2_int +06)--------Filter: count(*) > Int64(1) +07)----------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 08)------------TableScan: t2 projection=[t2_int] query II rowsort @@ -903,12 +900,12 @@ query TT explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.COUNT(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*) + Int64(2) AS cnt_plus_2, t2.t2_int, COUNT(*), Boolean(true) AS __always_true -06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*) + Int64(2) AS cnt_plus_2, t2.t2_int, count(*), Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -925,14 +922,14 @@ explain select t1.t1_int from t1 group by t1.t1_int having (select count(*) from ---- logical_plan 01)Projection: t1.t1_int -02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END = Int64(0) -03)----Projection: t1.t1_int, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END = Int64(0) +03)----Projection: t1.t1_int, __scalar_sq_1.count(*), __scalar_sq_1.__always_true 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------Aggregate: groupBy=[[t1.t1_int]], aggr=[[]] 06)----------TableScan: t1 projection=[t1_int] 07)--------SubqueryAlias: __scalar_sq_1 -08)----------Projection: COUNT(*), t2.t2_int, Boolean(true) AS __always_true -09)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +08)----------Projection: count(*), t2.t2_int, Boolean(true) AS __always_true +09)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 10)--------------TableScan: t2 projection=[t2_int] query I rowsort @@ -952,8 +949,8 @@ logical_plan 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------TableScan: t1 projection=[t1_int] 06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: COUNT(*) AS cnt, t2.t2_int, Boolean(true) AS __always_true -08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +07)----------Projection: count(*) AS cnt, t2.t2_int, Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 09)--------------TableScan: t2 projection=[t2_int] @@ -977,13 +974,13 @@ select t1.t1_int from t1 where ( ---- logical_plan 01)Projection: t1.t1_int 
-02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.COUNT(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_two END = Int64(2) -03)----Projection: t1.t1_int, __scalar_sq_1.cnt_plus_two, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_two END = Int64(2) +03)----Projection: t1.t1_int, __scalar_sq_1.cnt_plus_two, __scalar_sq_1.count(*), __scalar_sq_1.__always_true 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------TableScan: t1 projection=[t1_int] 06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: COUNT(*) + Int64(1) + Int64(1) AS cnt_plus_two, t2.t2_int, COUNT(*), Boolean(true) AS __always_true -08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +07)----------Projection: count(*) + Int64(1) + Int64(1) AS cnt_plus_two, t2.t2_int, count(*), Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 09)--------------TableScan: t2 projection=[t2_int] query I rowsort @@ -1011,8 +1008,8 @@ logical_plan 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------TableScan: t1 projection=[t1_int] 06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: CASE WHEN COUNT(*) = Int64(1) THEN Int64(NULL) ELSE COUNT(*) END AS cnt, t2.t2_int, Boolean(true) AS __always_true -08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +07)----------Projection: CASE WHEN count(*) = Int64(1) THEN Int64(NULL) ELSE count(*) END AS cnt, t2.t2_int, Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 09)--------------TableScan: t2 projection=[t2_int] diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 96d846d449e1..2216dbfa5fd5 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2774,6 +2774,26 @@ SELECT '2000-12-01 04:04:12' AT TIME ZONE 'America/New_York'; ---- 2000-12-01T04:04:12-05:00 +query P +SELECT '2024-03-30 00:00:20' AT TIME ZONE 'Europe/Brussels'; +---- +2024-03-30T00:00:20+01:00 + +query P +SELECT '2024-03-30 00:00:20'::timestamp AT TIME ZONE 'Europe/Brussels'; +---- +2024-03-30T00:00:20+01:00 + +query P +SELECT '2024-03-30 00:00:20Z' AT TIME ZONE 'Europe/Brussels'; +---- +2024-03-30T01:00:20+01:00 + +query P +SELECT '2024-03-30 00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'; +---- +2024-03-30T00:00:20+01:00 + ## date-time strings that already have a explicit timezone can be used with AT TIME ZONE # same time zone as provided date-time diff --git a/datafusion/sqllogictest/test_files/tpch/q1.slt.part b/datafusion/sqllogictest/test_files/tpch/q1.slt.part index 5a21bdf276e3..6c3e7dd3618a 100644 --- a/datafusion/sqllogictest/test_files/tpch/q1.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q1.slt.part @@ -41,19 +41,19 @@ explain select ---- logical_plan 01)Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST -02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + 
lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order -03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(__common_expr_1 * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]] +02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, avg(lineitem.l_quantity) AS avg_qty, avg(lineitem.l_extendedprice) AS avg_price, avg(lineitem.l_discount) AS avg_disc, count(*) AS count_order +03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(__common_expr_1 * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), avg(lineitem.l_quantity), avg(lineitem.l_extendedprice), avg(lineitem.l_discount), count(Int64(1)) AS count(*)]] 04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS __common_expr_1, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus 05)--------Filter: lineitem.l_shipdate <= Date32("1998-09-02") 06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")] physical_plan 01)SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS LAST] 02)--SortExec: expr=[l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus, sum(lineitem.l_quantity)@2 as sum_qty, sum(lineitem.l_extendedprice)@3 as sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@4 as sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax)@5 as sum_charge, AVG(lineitem.l_quantity)@6 as avg_qty, AVG(lineitem.l_extendedprice)@7 as avg_price, AVG(lineitem.l_discount)@8 as avg_disc, COUNT(*)@9 as count_order] -04)------AggregateExec: mode=FinalPartitioned, gby=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus], aggr=[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(*)] +03)----ProjectionExec: expr=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus, sum(lineitem.l_quantity)@2 as sum_qty, sum(lineitem.l_extendedprice)@3 as sum_base_price, sum(lineitem.l_extendedprice 
* Int64(1) - lineitem.l_discount)@4 as sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax)@5 as sum_charge, avg(lineitem.l_quantity)@6 as avg_qty, avg(lineitem.l_extendedprice)@7 as avg_price, avg(lineitem.l_discount)@8 as avg_disc, count(*)@9 as count_order] +04)------AggregateExec: mode=FinalPartitioned, gby=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus], aggr=[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), avg(lineitem.l_quantity), avg(lineitem.l_extendedprice), avg(lineitem.l_discount), count(*)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([l_returnflag@0, l_linestatus@1], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[l_returnflag@5 as l_returnflag, l_linestatus@6 as l_linestatus], aggr=[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(*)] +07)------------AggregateExec: mode=Partial, gby=[l_returnflag@5 as l_returnflag, l_linestatus@6 as l_linestatus], aggr=[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), avg(lineitem.l_quantity), avg(lineitem.l_extendedprice), avg(lineitem.l_discount), count(*)] 08)--------------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as __common_expr_1, l_quantity@0 as l_quantity, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, l_tax@3 as l_tax, l_returnflag@4 as l_returnflag, l_linestatus@5 as l_linestatus] 09)----------------CoalesceBatchesExec: target_batch_size=8192 10)------------------FilterExec: l_shipdate@6 <= 1998-09-02 diff --git a/datafusion/sqllogictest/test_files/tpch/q13.slt.part b/datafusion/sqllogictest/test_files/tpch/q13.slt.part index f19db720fb2c..f25f23de8817 100644 --- a/datafusion/sqllogictest/test_files/tpch/q13.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q13.slt.part @@ -42,11 +42,11 @@ limit 10; logical_plan 01)Limit: skip=0, fetch=10 02)--Sort: custdist DESC NULLS FIRST, c_orders.c_count DESC NULLS FIRST, fetch=10 -03)----Projection: c_orders.c_count, COUNT(*) AS custdist -04)------Aggregate: groupBy=[[c_orders.c_count]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +03)----Projection: c_orders.c_count, count(*) AS custdist +04)------Aggregate: groupBy=[[c_orders.c_count]], aggr=[[count(Int64(1)) AS count(*)]] 05)--------SubqueryAlias: c_orders -06)----------Projection: COUNT(orders.o_orderkey) AS c_count -07)------------Aggregate: groupBy=[[customer.c_custkey]], aggr=[[COUNT(orders.o_orderkey)]] +06)----------Projection: count(orders.o_orderkey) AS c_count +07)------------Aggregate: groupBy=[[customer.c_custkey]], aggr=[[count(orders.o_orderkey)]] 08)--------------Projection: customer.c_custkey, orders.o_orderkey 09)----------------Left Join: customer.c_custkey = orders.o_custkey 10)------------------TableScan: customer projection=[c_custkey] @@ -57,13 +57,13 @@ physical_plan 01)GlobalLimitExec: skip=0, fetch=10 02)--SortPreservingMergeExec: [custdist@1 
DESC,c_count@0 DESC], fetch=10 03)----SortExec: TopK(fetch=10), expr=[custdist@1 DESC,c_count@0 DESC], preserve_partitioning=[true] -04)------ProjectionExec: expr=[c_count@0 as c_count, COUNT(*)@1 as custdist] -05)--------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[COUNT(*)] +04)------ProjectionExec: expr=[c_count@0 as c_count, count(*)@1 as custdist] +05)--------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[count(*)] 06)----------CoalesceBatchesExec: target_batch_size=8192 07)------------RepartitionExec: partitioning=Hash([c_count@0], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[c_count@0 as c_count], aggr=[COUNT(*)] -09)----------------ProjectionExec: expr=[COUNT(orders.o_orderkey)@1 as c_count] -10)------------------AggregateExec: mode=SinglePartitioned, gby=[c_custkey@0 as c_custkey], aggr=[COUNT(orders.o_orderkey)] +08)--------------AggregateExec: mode=Partial, gby=[c_count@0 as c_count], aggr=[count(*)] +09)----------------ProjectionExec: expr=[count(orders.o_orderkey)@1 as c_count] +10)------------------AggregateExec: mode=SinglePartitioned, gby=[c_custkey@0 as c_custkey], aggr=[count(orders.o_orderkey)] 11)--------------------CoalesceBatchesExec: target_batch_size=8192 12)----------------------HashJoinExec: mode=Partitioned, join_type=Left, on=[(c_custkey@0, o_custkey@1)], projection=[c_custkey@0, o_orderkey@1] 13)------------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/q16.slt.part index 2b01980f0e6f..d568b2ca69e6 100644 --- a/datafusion/sqllogictest/test_files/tpch/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q16.slt.part @@ -52,8 +52,8 @@ limit 10; logical_plan 01)Limit: skip=0, fetch=10 02)--Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 -03)----Projection: part.p_brand, part.p_type, part.p_size, COUNT(alias1) AS supplier_cnt -04)------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1)]] +03)----Projection: part.p_brand, part.p_type, part.p_size, count(alias1) AS supplier_cnt +04)------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[count(alias1)]] 05)--------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] 06)----------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey 07)------------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size @@ -69,11 +69,11 @@ physical_plan 01)GlobalLimitExec: skip=0, fetch=10 02)--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10 03)----SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], preserve_partitioning=[true] -04)------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(alias1)@3 as supplier_cnt] -05)--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] +04)------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, count(alias1)@3 as supplier_cnt] +05)--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] 
06)----------CoalesceBatchesExec: target_batch_size=8192 07)------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] +08)--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] 09)----------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] 10)------------------CoalesceBatchesExec: target_batch_size=8192 11)--------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 diff --git a/datafusion/sqllogictest/test_files/tpch/q17.slt.part b/datafusion/sqllogictest/test_files/tpch/q17.slt.part index b1562301d9d9..ecb54e97b910 100644 --- a/datafusion/sqllogictest/test_files/tpch/q17.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q17.slt.part @@ -39,7 +39,7 @@ logical_plan 01)Projection: CAST(sum(lineitem.l_extendedprice) AS Float64) / Float64(7) AS avg_yearly 02)--Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] 03)----Projection: lineitem.l_extendedprice -04)------Inner Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_1.Float64(0.2) * AVG(lineitem.l_quantity) +04)------Inner Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_1.Float64(0.2) * avg(lineitem.l_quantity) 05)--------Projection: lineitem.l_quantity, lineitem.l_extendedprice, part.p_partkey 06)----------Inner Join: lineitem.l_partkey = part.p_partkey 07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] @@ -47,8 +47,8 @@ logical_plan 09)--------------Filter: part.p_brand = Utf8("Brand#23") AND part.p_container = Utf8("MED BOX") 10)----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8("Brand#23"), part.p_container = Utf8("MED BOX")] 11)--------SubqueryAlias: __scalar_sq_1 -12)----------Projection: CAST(Float64(0.2) * CAST(AVG(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)), lineitem.l_partkey -13)------------Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[AVG(lineitem.l_quantity)]] +12)----------Projection: CAST(Float64(0.2) * CAST(avg(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)), lineitem.l_partkey +13)------------Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[avg(lineitem.l_quantity)]] 14)--------------TableScan: lineitem projection=[l_partkey, l_quantity] physical_plan 01)ProjectionExec: expr=[CAST(sum(lineitem.l_extendedprice)@0 AS Float64) / 7 as avg_yearly] @@ -56,7 +56,7 @@ physical_plan 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice)] 05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, l_partkey@1)], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < Float64(0.2) * AVG(lineitem.l_quantity)@1, projection=[l_extendedprice@1] +06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, l_partkey@1)], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < Float64(0.2) * avg(lineitem.l_quantity)@1, projection=[l_extendedprice@1] 07)------------CoalesceBatchesExec: target_batch_size=8192 
08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], projection=[l_quantity@1, l_extendedprice@2, p_partkey@3] 09)----------------CoalesceBatchesExec: target_batch_size=8192 @@ -69,11 +69,11 @@ physical_plan 16)------------------------FilterExec: p_brand@1 = Brand#23 AND p_container@2 = MED BOX 17)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 18)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_container], has_header=false -19)------------ProjectionExec: expr=[CAST(0.2 * CAST(AVG(lineitem.l_quantity)@1 AS Float64) AS Decimal128(30, 15)) as Float64(0.2) * AVG(lineitem.l_quantity), l_partkey@0 as l_partkey] -20)--------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)] +19)------------ProjectionExec: expr=[CAST(0.2 * CAST(avg(lineitem.l_quantity)@1 AS Float64) AS Decimal128(30, 15)) as Float64(0.2) * avg(lineitem.l_quantity), l_partkey@0 as l_partkey] +20)--------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey], aggr=[avg(lineitem.l_quantity)] 21)----------------CoalesceBatchesExec: target_batch_size=8192 22)------------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 -23)--------------------AggregateExec: mode=Partial, gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)] +23)--------------------AggregateExec: mode=Partial, gby=[l_partkey@0 as l_partkey], aggr=[avg(lineitem.l_quantity)] 24)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q21.slt.part b/datafusion/sqllogictest/test_files/tpch/q21.slt.part index b536dd281eca..74c1c2fa77d7 100644 --- a/datafusion/sqllogictest/test_files/tpch/q21.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q21.slt.part @@ -59,8 +59,8 @@ order by ---- logical_plan 01)Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST -02)--Projection: supplier.s_name, COUNT(*) AS numwait -03)----Aggregate: groupBy=[[supplier.s_name]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +02)--Projection: supplier.s_name, count(*) AS numwait +03)----Aggregate: groupBy=[[supplier.s_name]], aggr=[[count(Int64(1)) AS count(*)]] 04)------Projection: supplier.s_name 05)--------LeftAnti Join: l1.l_orderkey = __correlated_sq_2.l_orderkey Filter: __correlated_sq_2.l_suppkey != l1.l_suppkey 06)----------LeftSemi Join: l1.l_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_suppkey != l1.l_suppkey @@ -92,11 +92,11 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [numwait@1 DESC,s_name@0 ASC NULLS LAST] 02)--SortExec: expr=[numwait@1 DESC,s_name@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[s_name@0 as s_name, COUNT(*)@1 as numwait] -04)------AggregateExec: mode=FinalPartitioned, gby=[s_name@0 as s_name], aggr=[COUNT(*)] +03)----ProjectionExec: expr=[s_name@0 as s_name, count(*)@1 as numwait] +04)------AggregateExec: 
mode=FinalPartitioned, gby=[s_name@0 as s_name], aggr=[count(*)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([s_name@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[s_name@0 as s_name], aggr=[COUNT(*)] +07)------------AggregateExec: mode=Partial, gby=[s_name@0 as s_name], aggr=[count(*)] 08)--------------CoalesceBatchesExec: target_batch_size=8192 09)----------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(l_orderkey@1, l_orderkey@0)], filter=l_suppkey@1 != l_suppkey@0, projection=[s_name@0] 10)------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/tpch/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/q22.slt.part index d05666b2513c..b3bfc329642f 100644 --- a/datafusion/sqllogictest/test_files/tpch/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q22.slt.part @@ -57,11 +57,11 @@ order by ---- logical_plan 01)Sort: custsale.cntrycode ASC NULLS LAST -02)--Projection: custsale.cntrycode, COUNT(*) AS numcust, sum(custsale.c_acctbal) AS totacctbal -03)----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[COUNT(Int64(1)) AS COUNT(*), sum(custsale.c_acctbal)]] +02)--Projection: custsale.cntrycode, count(*) AS numcust, sum(custsale.c_acctbal) AS totacctbal +03)----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[count(Int64(1)) AS count(*), sum(custsale.c_acctbal)]] 04)------SubqueryAlias: custsale 05)--------Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal -06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.AVG(customer.c_acctbal) +06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.avg(customer.c_acctbal) 07)------------Projection: customer.c_phone, customer.c_acctbal 08)--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey 09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) @@ -69,21 +69,21 @@ logical_plan 11)----------------SubqueryAlias: __correlated_sq_1 12)------------------TableScan: orders projection=[o_custkey] 13)------------SubqueryAlias: __scalar_sq_2 -14)--------------Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] +14)--------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] 15)----------------Projection: customer.c_acctbal 16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) 17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2) AS customer.c_acctbal > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]), customer.c_acctbal > Decimal128(Some(0),15,2)] physical_plan 01)SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] 02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[cntrycode@0 as cntrycode, COUNT(*)@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal] -04)------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[COUNT(*), sum(custsale.c_acctbal)] +03)----ProjectionExec: expr=[cntrycode@0 as 
cntrycode, count(*)@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal] +04)------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[count(*), sum(custsale.c_acctbal)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([cntrycode@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[COUNT(*), sum(custsale.c_acctbal)] +07)------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[count(*), sum(custsale.c_acctbal)] 08)--------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] 09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -10)------------------NestedLoopJoinExec: join_type=Inner, filter=CAST(c_acctbal@0 AS Decimal128(19, 6)) > AVG(customer.c_acctbal)@1 +10)------------------NestedLoopJoinExec: join_type=Inner, filter=CAST(c_acctbal@0 AS Decimal128(19, 6)) > avg(customer.c_acctbal)@1 11)--------------------CoalescePartitionsExec 12)----------------------CoalesceBatchesExec: target_batch_size=8192 13)------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] @@ -96,9 +96,9 @@ physical_plan 20)--------------------------CoalesceBatchesExec: target_batch_size=8192 21)----------------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4 22)------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], has_header=false -23)--------------------AggregateExec: mode=Final, gby=[], aggr=[AVG(customer.c_acctbal)] +23)--------------------AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] 24)----------------------CoalescePartitionsExec -25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[AVG(customer.c_acctbal)] +25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] 26)--------------------------ProjectionExec: expr=[c_acctbal@1 as c_acctbal] 27)----------------------------CoalesceBatchesExec: target_batch_size=8192 28)------------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND Use substr(c_phone@0, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]) diff --git a/datafusion/sqllogictest/test_files/tpch/q4.slt.part b/datafusion/sqllogictest/test_files/tpch/q4.slt.part index e2a5b9c5f009..b5a40e5b62d1 100644 --- a/datafusion/sqllogictest/test_files/tpch/q4.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q4.slt.part @@ -41,8 +41,8 @@ order by ---- logical_plan 01)Sort: orders.o_orderpriority ASC NULLS LAST -02)--Projection: orders.o_orderpriority, COUNT(*) AS order_count -03)----Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +02)--Projection: orders.o_orderpriority, count(*) AS order_count +03)----Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[count(Int64(1)) AS 
count(*)]] 04)------Projection: orders.o_orderpriority 05)--------LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey 06)----------Projection: orders.o_orderkey, orders.o_orderpriority @@ -55,11 +55,11 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST] 02)--SortExec: expr=[o_orderpriority@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[o_orderpriority@0 as o_orderpriority, COUNT(*)@1 as order_count] -04)------AggregateExec: mode=FinalPartitioned, gby=[o_orderpriority@0 as o_orderpriority], aggr=[COUNT(*)] +03)----ProjectionExec: expr=[o_orderpriority@0 as o_orderpriority, count(*)@1 as order_count] +04)------AggregateExec: mode=FinalPartitioned, gby=[o_orderpriority@0 as o_orderpriority], aggr=[count(*)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([o_orderpriority@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[o_orderpriority@0 as o_orderpriority], aggr=[COUNT(*)] +07)------------AggregateExec: mode=Partial, gby=[o_orderpriority@0 as o_orderpriority], aggr=[count(*)] 08)--------------CoalesceBatchesExec: target_batch_size=8192 09)----------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(o_orderkey@0, l_orderkey@0)], projection=[o_orderpriority@1] 10)------------------CoalesceBatchesExec: target_batch_size=8192 diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index 36f024961875..7b91e97e4a3e 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -420,16 +420,16 @@ SELECT count(*) FROM ( ) GROUP BY name ---- logical_plan -01)Projection: COUNT(*) -02)--Aggregate: groupBy=[[t1.name]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Projection: count(*) +02)--Aggregate: groupBy=[[t1.name]], aggr=[[count(Int64(1)) AS count(*)]] 03)----Union 04)------Aggregate: groupBy=[[t1.name]], aggr=[[]] 05)--------TableScan: t1 projection=[name] 06)------Aggregate: groupBy=[[t2.name]], aggr=[[]] 07)--------TableScan: t2 projection=[name] physical_plan -01)ProjectionExec: expr=[COUNT(*)@1 as COUNT(*)] -02)--AggregateExec: mode=SinglePartitioned, gby=[name@0 as name], aggr=[COUNT(*)] +01)ProjectionExec: expr=[count(*)@1 as count(*)] +02)--AggregateExec: mode=SinglePartitioned, gby=[name@0 as name], aggr=[count(*)] 03)----InterleaveExec 04)------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] 05)--------CoalesceBatchesExec: target_batch_size=2 @@ -565,8 +565,8 @@ select x, y from (select 1 as x , max(10) as y) b ---- logical_plan 01)Union -02)--Projection: COUNT(*) AS count, a.n -03)----Aggregate: groupBy=[[a.n]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +02)--Projection: count(*) AS count, a.n +03)----Aggregate: groupBy=[[a.n]], aggr=[[count(Int64(1)) AS count(*)]] 04)------SubqueryAlias: a 05)--------Projection: Int64(5) AS n 06)----------EmptyRelation @@ -577,11 +577,11 @@ logical_plan 11)----------EmptyRelation physical_plan 01)UnionExec -02)--ProjectionExec: expr=[COUNT(*)@1 as count, n@0 as n] -03)----AggregateExec: mode=FinalPartitioned, gby=[n@0 as n], aggr=[COUNT(*)], ordering_mode=Sorted +02)--ProjectionExec: expr=[count(*)@1 as count, n@0 as n] +03)----AggregateExec: mode=FinalPartitioned, gby=[n@0 as n], aggr=[count(*)], ordering_mode=Sorted 04)------CoalesceBatchesExec: target_batch_size=2 05)--------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=1 -06)----------AggregateExec: 
mode=Partial, gby=[n@0 as n], aggr=[COUNT(*)], ordering_mode=Sorted +06)----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[count(*)], ordering_mode=Sorted 07)------------ProjectionExec: expr=[5 as n] 08)--------------PlaceholderRowExec 09)--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 77b839f3f77a..e6f3e70c1ebd 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -1303,14 +1303,14 @@ EXPLAIN SELECT FROM aggregate_test_100 ---- logical_plan -01)Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING -02)--WindowAggr: windowExpr=[[COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +01)Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING +02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 03)----Projection: aggregate_test_100.c1, aggregate_test_100.c2, sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING 04)------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c1, c2, c4] physical_plan -01)ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@2 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] -02)--BoundedWindowAggExec: wdw=[COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: 
WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +01)ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@2 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] +02)--BoundedWindowAggExec: wdw=[count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] 03)----SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=4096 05)--------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 @@ -1763,8 +1763,8 @@ EXPLAIN SELECT count(*) as global_count FROM ORDER BY c1 ) AS a ---- logical_plan -01)Projection: COUNT(*) AS global_count -02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Projection: count(*) AS global_count +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 03)----SubqueryAlias: a 04)------Projection: 05)--------Sort: aggregate_test_100.c1 ASC NULLS LAST @@ -1773,10 +1773,10 @@ logical_plan 08)--------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") 09)----------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan -01)ProjectionExec: expr=[COUNT(*)@0 as global_count] -02)--AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +01)ProjectionExec: expr=[count(*)@0 as global_count] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(*)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(*)] 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=2 06)----------ProjectionExec: expr=[] 07)------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] @@ -2573,22 +2573,22 @@ logical_plan 01)Projection: sum1, sum2, sum3, min1, min2, min3, max1, max2, max3, cnt1, cnt2, sumr1, sumr2, sumr3, minr1, minr2, minr3, maxr1, maxr2, maxr3, cntr1, cntr2, sum4, cnt3 02)--Limit: skip=0, fetch=5 03)----Sort: annotated_data_finite.inc_col DESC NULLS FIRST, fetch=5 -04)------Projection: sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS sum2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 
PRECEDING AND 10 FOLLOWING AS sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS max3, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS cnt1, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING AS sumr1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING AS sumr2, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS maxr3, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS cntr1, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cntr2, sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS sum4, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt3, annotated_data_finite.inc_col -05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] -06)----------Projection: __common_expr_1, annotated_data_finite.inc_col, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] 
RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING -07)------------WindowAggr: windowExpr=[[sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, 
COUNT(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] -08)--------------WindowAggr: windowExpr=[[sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +04)------Projection: sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS sum2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS max3, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS cnt1, count(*) ORDER BY 
[annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING AS sumr1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING AS sumr2, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS maxr3, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS cntr1, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cntr2, sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS sum4, count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt3, annotated_data_finite.inc_col +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, count(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +06)----------Projection: __common_expr_1, annotated_data_finite.inc_col, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS 
LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING +07)------------WindowAggr: windowExpr=[[sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +08)--------------WindowAggr: windowExpr=[[sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC 
NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] 09)----------------Projection: CAST(annotated_data_finite.desc_col AS Int64) AS __common_expr_1, CAST(annotated_data_finite.inc_col AS Int64) AS __common_expr_2, annotated_data_finite.ts, annotated_data_finite.inc_col, annotated_data_finite.desc_col 10)------------------TableScan: annotated_data_finite projection=[ts, inc_col, desc_col] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7 as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1, sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2, minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3, cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3] 02)--GlobalLimitExec: skip=0, fetch=5 03)----SortExec: TopK(fetch=5), expr=[inc_col@24 DESC], preserve_partitioning=[false] -04)------ProjectionExec: expr=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as sum1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@14 as sum2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@15 as sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as max3, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@22 as cnt1, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS 
BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cnt2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@2 as sumr1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@3 as sumr2, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as maxr3, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@11 as cntr1, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cntr2, sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as sum4, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as cnt3, inc_col@1 as inc_col] -05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -06)----------ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, inc_col@3 as inc_col, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@5 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@6 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@7 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as 
MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@12 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@13 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@14 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@15 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY 
[annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@22 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@23 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@25 as COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@26 as COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING] -07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { 
units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)), is_causal: false }, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -08)--------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] 
RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY 
[annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)), is_causal: false }, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)), is_causal: false }], mode=[Sorted] +04)------ProjectionExec: expr=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as sum1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@14 as sum2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@15 as sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as max3, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@22 as cnt1, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cnt2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@2 as sumr1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@3 as sumr2, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 
FOLLOWING@7 as minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as maxr3, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@11 as cntr1, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cntr2, sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as sum4, count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as cnt3, inc_col@1 as inc_col] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +06)----------ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, inc_col@3 as inc_col, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@5 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@6 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@7 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE 
BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@12 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@13 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@14 as count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@15 as count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@22 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@23 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@25 as count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE 
BETWEEN 4 PRECEDING AND 8 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@26 as count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING] +07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: 
false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)), is_causal: false }, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +08)--------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, MIN(annotated_data_finite.inc_col) 
ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)), is_causal: false }, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 
PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)), is_causal: false }], mode=[Sorted] 09)----------------ProjectionExec: expr=[CAST(desc_col@2 AS Int64) as __common_expr_1, CAST(inc_col@1 AS Int64) as __common_expr_2, ts@0 as ts, inc_col@1 as inc_col, desc_col@2 as desc_col] 10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col, desc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true @@ -2727,8 +2727,8 @@ EXPLAIN SELECT MAX(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as max2, COUNT(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as count1, COUNT(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as count2, - AVG(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as avg1, - AVG(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as avg2 + avg(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as avg1, + avg(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as avg2 FROM annotated_data_finite ORDER BY inc_col ASC LIMIT 5 @@ -2737,18 +2737,18 @@ logical_plan 01)Projection: sum1, sum2, min1, min2, max1, max2, count1, count2, avg1, avg2 02)--Limit: skip=0, fetch=5 03)----Sort: annotated_data_finite.inc_col ASC NULLS LAST, fetch=5 -04)------Projection: sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS sum1, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS max2, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS count1, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS avg1, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS avg2, annotated_data_finite.inc_col -05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS 
LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, AVG(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, AVG(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +04)------Projection: sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS sum1, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS max2, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS count1, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS avg1, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS avg2, annotated_data_finite.inc_col +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, avg(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING]] +06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS 
FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, avg(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] 07)------------Projection: CAST(annotated_data_finite.inc_col AS Int64) AS __common_expr_1, CAST(annotated_data_finite.inc_col AS Float64) AS __common_expr_2, annotated_data_finite.ts, annotated_data_finite.inc_col 08)--------------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, min1@2 as min1, min2@3 as min2, max1@4 as max1, max2@5 as max2, count1@6 as count1, count2@7 as count2, avg1@8 as avg1, avg2@9 as avg2] 02)--GlobalLimitExec: skip=0, fetch=5 03)----SortExec: TopK(fetch=5), expr=[inc_col@10 ASC NULLS LAST], preserve_partitioning=[false] -04)------ProjectionExec: expr=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@9 as sum1, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@10 as min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@5 as min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@11 as max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@6 as max2, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@12 as count1, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@7 as count2, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@13 as avg1, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@8 as avg2, inc_col@3 as inc_col] -05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY 
[annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }], mode=[Sorted] -06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { 
name: "COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }], mode=[Sorted] +04)------ProjectionExec: expr=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@9 as sum1, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@10 as min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@5 as min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@11 as max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@6 as max2, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@12 as count1, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@7 as count2, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@13 as avg1, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@8 as avg2, inc_col@3 as inc_col] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC 
NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, 
avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }], mode=[Sorted] 07)------------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, CAST(inc_col@1 AS Float64) as __common_expr_2, ts@0 as ts, inc_col@1 as inc_col] 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true @@ -2838,17 +2838,17 @@ logical_plan 01)Projection: sum1, sum2, count1, count2 02)--Limit: skip=0, fetch=5 03)----Sort: annotated_data_infinite.ts ASC NULLS LAST, fetch=5 -04)------Projection: sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts -05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +04)------Projection: sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] +06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING 
AND UNBOUNDED FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] 07)------------Projection: CAST(annotated_data_infinite.inc_col AS Int64) AS __common_expr_1, annotated_data_infinite.ts, annotated_data_infinite.inc_col 08)--------------TableScan: annotated_data_infinite projection=[ts, inc_col] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2@3 as count2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----ProjectionExec: expr=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2, ts@1 as ts] -04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +03)----ProjectionExec: expr=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, 
sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2, ts@1 as ts] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] 06)----------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, ts@0 as ts, inc_col@1 as inc_col] 07)------------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] @@ -2885,17 +2885,17 @@ logical_plan 01)Projection: sum1, sum2, count1, count2 02)--Limit: skip=0, fetch=5 03)----Sort: annotated_data_infinite.ts ASC NULLS LAST, fetch=5 -04)------Projection: sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, 
COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts -05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +04)------Projection: sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] +06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] 07)------------Projection: CAST(annotated_data_infinite.inc_col AS Int64) AS __common_expr_1, annotated_data_infinite.ts, annotated_data_infinite.inc_col 08)--------------TableScan: annotated_data_infinite projection=[ts, inc_col] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2@3 as count2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----ProjectionExec: expr=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2, ts@1 as ts] -04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING 
AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +03)----ProjectionExec: expr=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2, ts@1 as ts] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), 
is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] 06)----------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, ts@0 as ts, inc_col@1 as inc_col] 07)------------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] @@ -3630,7 +3630,7 @@ set datafusion.execution.target_partitions = 2; # we should still have the orderings [a ASC, b ASC], [c ASC]. query TT EXPLAIN SELECT *, - AVG(d) OVER sliding_window AS avg_d + avg(d) OVER sliding_window AS avg_d FROM multiple_ordered_table_inf WINDOW sliding_window AS ( PARTITION BY d @@ -3640,13 +3640,13 @@ ORDER BY c ---- logical_plan 01)Sort: multiple_ordered_table_inf.c ASC NULLS LAST -02)--Projection: multiple_ordered_table_inf.a0, multiple_ordered_table_inf.a, multiple_ordered_table_inf.b, multiple_ordered_table_inf.c, multiple_ordered_table_inf.d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW AS avg_d -03)----WindowAggr: windowExpr=[[AVG(CAST(multiple_ordered_table_inf.d AS Float64)) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW]] +02)--Projection: multiple_ordered_table_inf.a0, multiple_ordered_table_inf.a, multiple_ordered_table_inf.b, multiple_ordered_table_inf.c, multiple_ordered_table_inf.d, avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW AS avg_d +03)----WindowAggr: windowExpr=[[avg(CAST(multiple_ordered_table_inf.d AS Float64)) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW]] 04)------TableScan: multiple_ordered_table_inf projection=[a0, a, b, c, d] physical_plan 01)SortPreservingMergeExec: [c@3 ASC NULLS LAST] -02)--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] -03)----BoundedWindowAggExec: wdw=[AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 
PRECEDING AND CURRENT ROW: Ok(Field { name: "AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] +02)--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] +03)----BoundedWindowAggExec: wdw=[avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] 04)------CoalesceBatchesExec: target_batch_size=4096 05)--------RepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST 06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -4134,13 +4134,13 @@ query TT EXPLAIN select count(*) over (partition by a order by a) from (select * from a where a = 1); ---- logical_plan -01)Projection: COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -02)--WindowAggr: windowExpr=[[COUNT(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 03)----Filter: a.a = Int64(1) 04)------TableScan: a projection=[a] physical_plan -01)ProjectionExec: expr=[COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -02)--BoundedWindowAggExec: wdw=[COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)ProjectionExec: expr=[count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS 
LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--BoundedWindowAggExec: wdw=[count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index d934dba4cfea..f3f8f6e3abca 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,7 +26,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.75" +rust-version = "1.76" [lints] workspace = true diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 93f197885c0a..9bc842a12af4 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -250,14 +250,14 @@ pub async fn from_substrait_plan( match plan { // If the last node of the plan produces expressions, bake the renames into those expressions. // This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)), + LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)), LogicalPlan::Aggregate(a) => { - let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?; + let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?; Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?)) }, // There are probably more plans where we could bake things in, can add them later as needed. // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?)) + _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?)) } } }, @@ -308,34 +308,46 @@ pub fn extract_projection( } } +/// Ensure the expressions have the right name(s) according to the new schema. +/// This includes the top-level (column) name, which will be renamed through aliasing if needed, +/// as well as nested names (if the expression produces any struct types), which will be renamed +/// through casting if needed. fn rename_expressions( exprs: impl IntoIterator, input_schema: &DFSchema, - new_schema: DFSchemaRef, + new_schema: &DFSchema, ) -> Result> { exprs .into_iter() .zip(new_schema.fields()) .map(|(old_expr, new_field)| { - if &old_expr.get_type(input_schema)? 
== new_field.data_type() { - // Alias column if needed - old_expr.alias_if_changed(new_field.name().into()) - } else { - // Use Cast to rename inner struct fields + alias column if needed + // Check if type (i.e. nested struct field names) match, use Cast to rename if needed + let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { Expr::Cast(Cast::new( Box::new(old_expr), new_field.data_type().to_owned(), )) - .alias_if_changed(new_field.name().into()) + } else { + old_expr + }; + // Alias column if needed to fix the top-level name + match &new_expr { + // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier + Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), + _ => new_expr.alias_if_changed(new_field.name().to_owned()), } }) .collect() } +/// Produce a version of the given schema with names matching the given list of names. +/// Substrait doesn't deal with column (incl. nested struct field) names within the schema, +/// but it does give us the list of expected names at the end of the plan, so we use this +/// to rename the schema to match the expected names. fn make_renamed_schema( schema: &DFSchemaRef, dfs_names: &Vec, -) -> Result { +) -> Result { fn rename_inner_fields( dtype: &DataType, dfs_names: &Vec, @@ -401,10 +413,10 @@ fn make_renamed_schema( dfs_names.len()); } - Ok(Arc::new(DFSchema::from_field_specific_qualified_schema( + DFSchema::from_field_specific_qualified_schema( qualifiers, &Arc::new(Schema::new(fields)), - )?)) + ) } /// Convert Substrait Rel to DataFusion DataFrame @@ -594,6 +606,8 @@ pub async fn from_substrait_rel( let right = LogicalPlanBuilder::from( from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?, ); + let (left, right) = requalify_sides_if_needed(left, right)?; + let join_type = from_substrait_jointype(join.r#type)?; // The join condition expression needs full input schema and not the output schema from join since we lose columns from // certain join types such as semi and anti joins @@ -627,13 +641,15 @@ pub async fn from_substrait_rel( } } Some(RelType::Cross(cross)) => { - let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + let left = LogicalPlanBuilder::from( from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, ); - let right = + let right = LogicalPlanBuilder::from( from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) - .await?; - left.cross_join(right)?.build() + .await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + left.cross_join(right.build()?)?.build() } Some(RelType::Read(read)) => match &read.as_ref().read_type { Some(ReadType::NamedTable(nt)) => { @@ -846,6 +862,34 @@ pub async fn from_substrait_rel( } } +/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise +/// conflict with the columns from the other. +/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For +/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion +/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). 
+fn requalify_sides_if_needed( + left: LogicalPlanBuilder, + right: LogicalPlanBuilder, +) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder)> { + let left_cols = left.schema().columns(); + let right_cols = right.schema().columns(); + if left_cols.iter().any(|l| { + right_cols.iter().any(|r| { + l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) + }) + }) { + // These names have no connection to the original plan, but they'll make the columns + // (mostly) unique. There may be cases where this still causes duplicates, if either left + // or right side itself contains duplicate names with different qualifiers. + Ok(( + left.alias(TableReference::bare("left"))?, + right.alias(TableReference::bare("right"))?, + )) + } else { + Ok((left, right)) + } +} + fn from_substrait_jointype(join_type: i32) -> Result { if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) { match substrait_join_type { @@ -983,7 +1027,7 @@ pub async fn from_substrait_agg_func( // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { // deal with situation that count(*) got no arguments - if fun.name() == "COUNT" && args.is_empty() { + if fun.name() == "count" && args.is_empty() { args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); } @@ -1405,6 +1449,36 @@ fn from_substrait_type( )?, } } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + let key_field = Arc::new(Field::new( + "key", + from_substrait_type(key_type, dfs_names, name_idx)?, + false, + )); + let value_field = Arc::new(Field::new( + "value", + from_substrait_type(value_type, dfs_names, name_idx)?, + true, + )); + match map.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => { + Ok(DataType::Map(Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), false)) + }, + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + )?, + } + } r#type::Kind::Decimal(d) => match d.type_variation_reference { DECIMAL_128_TYPE_VARIATION_REF => { Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) @@ -1686,7 +1760,7 @@ fn from_substrait_literal( let element_type = elements[0].data_type(); match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( - ScalarValue::new_list(elements.as_slice(), &element_type), + ScalarValue::new_list_nullable(elements.as_slice(), &element_type), ), LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( ScalarValue::new_large_list(elements.as_slice(), &element_type), @@ -1704,7 +1778,7 @@ fn from_substrait_literal( )?; match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::List(ScalarValue::new_list(&[], &element_type)) + ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) } LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( ScalarValue::new_large_list(&[], &element_type), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c0469d333164..302f38606bfb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -664,7 +664,12 @@ fn to_substrait_join_expr( extension_info, )?; // AND 
with existing expression - exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); + exprs.push(make_binary_op_scalar_func( + &l, + &r, + eq_op.clone(), + extension_info, + )); } let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { @@ -1154,7 +1159,12 @@ pub fn to_substrait_rex( let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; - Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) + Ok(make_binary_op_scalar_func( + &l, + &r, + op.clone(), + extension_info, + )) } Expr::Case(Case { expr, @@ -1619,6 +1629,27 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result match inner.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + let key_type = to_substrait_type( + key_and_value[0].data_type(), + key_and_value[0].is_nullable(), + )?; + let value_type = to_substrait_type( + key_and_value[1].data_type(), + key_and_value[1].is_nullable(), + )?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, DataType::Struct(fields) => { let field_types = fields .iter() @@ -2231,11 +2262,11 @@ mod test { round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; - round_trip_literal(ScalarValue::List(ScalarValue::new_list( + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( &[ScalarValue::Float32(Some(1.0))], &DataType::Float32, )))?; - round_trip_literal(ScalarValue::List(ScalarValue::new_list( + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( &[], &DataType::Float32, )))?; @@ -2316,6 +2347,19 @@ mod test { Field::new_list_field(DataType::Int32, true).into(), ))?; + round_trip_type(DataType::Map( + Field::new_struct( + "entries", + [ + Field::new("key", DataType::Utf8, false).into(), + Field::new("value", DataType::Int32, true).into(), + ], + false, + ) + .into(), + false, + ))?; + round_trip_type(DataType::Struct( vec![ Field::new("c0", DataType::Int32, true), diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index e0151ecc3a4f..8ea3a69cab61 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -46,9 +46,9 @@ mod tests { let plan_str = format!("{:?}", plan); assert_eq!( plan_str, - "Projection: FILENAME_PLACEHOLDER_0.l_returnflag AS L_RETURNFLAG, FILENAME_PLACEHOLDER_0.l_linestatus AS L_LINESTATUS, sum(FILENAME_PLACEHOLDER_0.l_quantity) AS SUM_QTY, sum(FILENAME_PLACEHOLDER_0.l_extendedprice) AS SUM_BASE_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount) AS SUM_DISC_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax) AS SUM_CHARGE, AVG(FILENAME_PLACEHOLDER_0.l_quantity) AS AVG_QTY, AVG(FILENAME_PLACEHOLDER_0.l_extendedprice) AS AVG_PRICE, AVG(FILENAME_PLACEHOLDER_0.l_discount) AS AVG_DISC, COUNT(Int64(1)) AS COUNT_ORDER\ + "Projection: FILENAME_PLACEHOLDER_0.l_returnflag AS L_RETURNFLAG, FILENAME_PLACEHOLDER_0.l_linestatus AS L_LINESTATUS, 
sum(FILENAME_PLACEHOLDER_0.l_quantity) AS SUM_QTY, sum(FILENAME_PLACEHOLDER_0.l_extendedprice) AS SUM_BASE_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount) AS SUM_DISC_PRICE, sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax) AS SUM_CHARGE, avg(FILENAME_PLACEHOLDER_0.l_quantity) AS AVG_QTY, avg(FILENAME_PLACEHOLDER_0.l_extendedprice) AS AVG_PRICE, avg(FILENAME_PLACEHOLDER_0.l_discount) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER\ \n Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\ - \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[sum(FILENAME_PLACEHOLDER_0.l_quantity), sum(FILENAME_PLACEHOLDER_0.l_extendedprice), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[sum(FILENAME_PLACEHOLDER_0.l_quantity), sum(FILENAME_PLACEHOLDER_0.l_extendedprice), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), sum(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), avg(FILENAME_PLACEHOLDER_0.l_quantity), avg(FILENAME_PLACEHOLDER_0.l_extendedprice), avg(FILENAME_PLACEHOLDER_0.l_discount), count(Int64(1))]]\ \n Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\ \n Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\ \n TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]" diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 4e4fa45a15a6..7ed376f62ba0 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -239,7 +239,7 @@ async fn aggregate_grouping_sets() -> Result<()> { async fn aggregate_grouping_rollup() -> Result<()> { assert_expected_plan( "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)", - "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[AVG(data.b)]]\ + "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ \n TableScan: data projection=[a, b, c, e]", true ).await @@ -491,6 +491,23 @@ async fn 
roundtrip_outer_join() -> Result<()> { roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = data2.a").await } +#[tokio::test] +async fn roundtrip_self_join() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. + // This roundtrip works because we set aliases to what the Substrait consumer will generate. + roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.a = right.a").await?; + roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.b = right.b").await +} + +#[tokio::test] +async fn roundtrip_self_implicit_cross_join() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. + // This roundtrip works because we set aliases to what the Substrait consumer will generate. + roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await +} + #[tokio::test] async fn roundtrip_arithmetic_ops() -> Result<()> { roundtrip("SELECT a - a FROM data").await?; @@ -594,10 +611,10 @@ async fn roundtrip_union_all() -> Result<()> { #[tokio::test] async fn simple_intersect() -> Result<()> { - // Substrait treats both COUNT(*) and COUNT(1) the same + // Substrait treats both count(*) and count(1) the same assert_expected_plan( - "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\ + "SELECT count(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\ \n Projection: \ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ @@ -610,7 +627,22 @@ async fn simple_intersect() -> Result<()> { #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { - roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. + // In this case the aliasing happens at a different point in the plan, so we cannot use roundtrip. + // Schema check works because we set aliases to what the Substrait consumer will generate. 
+ assert_expected_plan( + "SELECT count(1) FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);", + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ + \n Projection: \ + \n LeftSemi Join: left.a = right.a\ + \n SubqueryAlias: left\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n SubqueryAlias: right\ + \n TableScan: data projection=[a]", + true + ).await } #[tokio::test] @@ -628,32 +660,6 @@ async fn qualified_catalog_schema_table_reference() -> Result<()> { roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await } -#[tokio::test] -async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> { - assert_expected_plan( - "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a", - "Projection: data.b, data.c\ - \n Inner Join: data.a = data.a\ - \n TableScan: data projection=[a, b]\ - \n TableScan: data projection=[a, c]", - false, // "d1" vs "data" field qualifier - ) - .await -} - -#[tokio::test] -async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> { - assert_expected_plan( - "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", - "Projection: data.b, data.c\ - \n Inner Join: data.b = data.b\ - \n TableScan: data projection=[b]\ - \n TableScan: data projection=[b, c]", - false, // "d1" vs "data" field qualifier - ) - .await -} - /// Construct a plan that contains several literals of types that are currently supported. /// This case ignores: /// - Date64, for this literal is not supported @@ -707,20 +713,17 @@ async fn roundtrip_literal_struct() -> Result<()> { #[tokio::test] async fn roundtrip_values() -> Result<()> { // TODO: would be nice to have a struct inside the LargeList, but arrow_cast doesn't support that currently - let values = "(\ + assert_expected_plan( + "VALUES \ + (\ 1, \ 'a', \ [[-213.1, NULL, 5.5, 2.0, 1.0], []], \ arrow_cast([1,2,3], 'LargeList(Int64)'), \ STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \ [STRUCT(STRUCT('a' AS string_field) AS struct_field)]\ - )"; - - // Test LogicalPlan::Values - assert_expected_plan( - format!("VALUES \ - {values}, \ - (NULL, NULL, NULL, NULL, NULL, NULL)").as_str(), + ), \ + (NULL, NULL, NULL, NULL, NULL, NULL)", "Values: \ (\ Int64(1), \ @@ -731,11 +734,28 @@ async fn roundtrip_values() -> Result<()> { List([{struct_field: {string_field: a}}])\ ), \ (Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())", - true) - .await?; + true).await +} + +#[tokio::test] +async fn roundtrip_values_empty_relation() -> Result<()> { + roundtrip("SELECT * FROM (VALUES ('a')) LIMIT 0").await +} - // Test LogicalPlan::EmptyRelation - roundtrip(format!("SELECT * FROM (VALUES {values}) LIMIT 0").as_str()).await +#[tokio::test] +async fn roundtrip_values_duplicate_column_join() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. + // This roundtrip works because we set aliases to what the Substrait consumer will generate. + roundtrip( + "SELECT left.column1 as c1, right.column1 as c2 \ + FROM \ + (VALUES (1)) AS left \ + JOIN \ + (VALUES (2)) AS right \ + ON left.column1 == right.column1", + ) + .await } /// Construct a plan that cast columns. Only those SQL types are supported for now. 
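The Map support added to to_substrait_type/from_substrait_type above assumes the conventional Arrow layout for maps: a non-nullable "entries" struct whose "key" field is non-nullable and whose "value" field may be nullable (the inner entries field is always non-nullable, per Arrow #1697, and keys are treated as unsorted). A minimal sketch of constructing such a type follows; the helper name and imports are illustrative assumptions, not part of this change.

use std::sync::Arc;
use datafusion::arrow::datatypes::{DataType, Field};

// Build the Arrow map type that the Substrait Map conversion above round-trips:
// Map(entries: Struct { key: Utf8 (non-null), value: Int32 (nullable) }, keys_sorted = false)
fn example_map_type() -> DataType {
    let key = Field::new("key", DataType::Utf8, false);     // map keys are never null
    let value = Field::new("value", DataType::Int32, true); // values may be null
    // The wrapping "entries" struct itself is always non-nullable.
    let entries = Field::new_struct("entries", vec![key, value], false);
    DataType::Map(Arc::new(entries), false)
}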
diff --git a/docs/source/index.rst b/docs/source/index.rst index 8500c70623d8..8677b63c916a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -41,9 +41,9 @@ DataFusion offers SQL and Dataframe APIs, excellent CSV, Parquet, JSON, and Avro, extensive customization, and a great community. -The `example usage`_ section in the user guide and the `datafusion-examples`_ code in the crate contain information on using DataFusion. +To get started with examples, see the `example usage`_ section of the user guide and the `datafusion-examples`_ directory. -Please see the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. +See the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. .. _example usage: user-guide/example-usage.html .. _datafusion-examples: https://github.com/apache/datafusion/tree/main/datafusion-examples diff --git a/docs/source/user-guide/cli/installation.md b/docs/source/user-guide/cli/installation.md index 3a71240783e5..f5114cafe54a 100644 --- a/docs/source/user-guide/cli/installation.md +++ b/docs/source/user-guide/cli/installation.md @@ -56,8 +56,7 @@ this to work. ```bash git clone https://github.com/apache/datafusion -cd arrow-datafusion -git checkout 12.0.0 +cd datafusion docker build -f datafusion-cli/Dockerfile . --tag datafusion-cli docker run -it -v $(your_data_location):/data datafusion-cli ``` diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 80d88632ffdb..c5f22725e0a3 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -113,3 +113,4 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | | datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | | datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. | +| datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. |
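The new `datafusion.sql_parser.support_varchar_with_length` option documented in configs.md above can also be set programmatically. A minimal sketch under stated assumptions (the tokio runtime setup, the table name, and the use of the generic SessionConfig::set_bool setter are illustrative, not part of this change):

use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Disable the default leniency: a VARCHAR with a length should now be rejected
    // instead of being accepted with the length silently ignored.
    let config = SessionConfig::new()
        .set_bool("datafusion.sql_parser.support_varchar_with_length", false);
    let ctx = SessionContext::new_with_config(config);

    // Planning is expected to fail here, since the Arrow type system has no notion
    // of a maximum string length that DataFusion could enforce.
    let result = ctx.sql("CREATE TABLE t (s VARCHAR(20))").await;
    assert!(result.is_err());
    Ok(())
}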