diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index ce4b4b06cf44..2ddeebbc558e 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -581,4 +581,4 @@ jobs: run: cargo msrv verify - name: Check datafusion-cli working-directory: datafusion-cli - run: cargo msrv verify + run: cargo msrv verify \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 9d2d2d81d680..000000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,69 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# To use this, install the python package `pre-commit` and -# run once `pre-commit install`. This will setup a git pre-commit-hook -# that is executed on each commit and will report the linting problems. -# To run all hooks on all files use `pre-commit run -a` - -repos: - - repo: local - hooks: - - id: rat - name: Release Audit Tool - language: system - entry: bash -c "git archive HEAD --prefix=apache-arrow/ --output=arrow-src.tar && ./dev/release/run-rat.sh arrow-src.tar" - always_run: true - pass_filenames: false - - id: rustfmt - name: Rust Format - language: system - entry: bash -c "cd rust && cargo +stable fmt --all -- --check" - files: ^rust/.*\.rs$ - types: - - file - - rust - - id: cmake-format - name: CMake Format - language: python - entry: python run-cmake-format.py - types: [cmake] - additional_dependencies: - - cmake_format==0.5.2 - - id: hadolint - name: Docker Format - language: docker_image - types: - - dockerfile - entry: --entrypoint /bin/hadolint hadolint/hadolint:latest - - exclude: ^dev/.*$ - - repo: git://github.com/pre-commit/pre-commit-hooks - sha: v1.2.3 - hooks: - - id: flake8 - name: Python Format - files: ^(python|dev|integration)/ - types: - - file - - python - - id: flake8 - name: Cython Format - files: ^python/ - types: - - file - - cython - args: [--config=python/.flake8.cython] diff --git a/Cargo.toml b/Cargo.toml index 6a6928e25bdd..968a74e37f10 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,22 +64,22 @@ version = "39.0.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "52.0.0", features = [ +arrow = { version = "52.1.0", features = [ "prettyprint", ] } -arrow-array = { version = "52.0.0", default-features = false, features = [ +arrow-array = { version = "52.1.0", default-features = false, features = [ "chrono-tz", ] } -arrow-buffer = { version = "52.0.0", default-features = false } -arrow-flight = { version = "52.0.0", features = [ +arrow-buffer = { version = "52.1.0", default-features = false } +arrow-flight = { version = "52.1.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "52.0.0", default-features = false, features = [ +arrow-ipc = { version = 
"52.1.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "52.0.0", default-features = false } -arrow-schema = { version = "52.0.0", default-features = false } -arrow-string = { version = "52.0.0", default-features = false } +arrow-ord = { version = "52.1.0", default-features = false } +arrow-schema = { version = "52.1.0", default-features = false } +arrow-string = { version = "52.1.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" @@ -114,7 +114,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.10.1", default-features = false } parking_lot = "0.12" -parquet = { version = "52.0.0", default-features = false, features = [ +parquet = { version = "52.1.0", default-features = false, features = [ "arrow", "async", "object_store", diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh index 0415090665d2..1bb97c88106f 100755 --- a/ci/scripts/rust_example.sh +++ b/ci/scripts/rust_example.sh @@ -19,7 +19,6 @@ set -ex cd datafusion-examples/examples/ -cargo fmt --all -- --check cargo check --examples files=$(ls .) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5fc8dbcfdfb3..500e731a5b4f 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ae9728f104939be6d8d9b368a354b4929b0569160ea1641f0721b55a861ce38" +checksum = "6127ea5e585a12ec9f742232442828ebaf264dfa5eefdd71282376c599562b77" dependencies = [ "arrow-arith", "arrow-array", @@ -151,9 +151,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7029a5b3efbeafbf4a12d12dc16b8f9e9bff20a410b8c25c5d28acc089e1043" +checksum = "7add7f39210b7d726e2a8efc0083e7bf06e8f2d15bdb4896b564dce4410fbf5d" dependencies = [ "arrow-array", "arrow-buffer", @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d33238427c60271710695f17742f45b1a5dc5bcfc5c15331c25ddfe7abf70d97" +checksum = "81c16ec702d3898c2f5cfdc148443c6cd7dbe5bac28399859eb0a3d38f072827" dependencies = [ "ahash", "arrow-buffer", @@ -183,9 +183,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9b95e825ae838efaf77e366c00d3fc8cca78134c9db497d6bda425f2e7b7c1" +checksum = "cae6970bab043c4fbc10aee1660ceb5b306d0c42c8cc5f6ae564efcd9759b663" dependencies = [ "bytes", "half", @@ -194,9 +194,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cf8385a9d5b5fcde771661dd07652b79b9139fea66193eda6a88664400ccab" +checksum = "1c7ef44f26ef4f8edc392a048324ed5d757ad09135eff6d5509e6450d39e0398" dependencies = [ "arrow-array", "arrow-buffer", @@ -215,9 +215,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea5068bef430a86690059665e40034625ec323ffa4dd21972048eebb0127adc" +checksum = 
"5f843490bd258c5182b66e888161bb6f198f49f3792f7c7f98198b924ae0f564" dependencies = [ "arrow-array", "arrow-buffer", @@ -234,9 +234,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb29be98f987bcf217b070512bb7afba2f65180858bca462edf4a39d84a23e10" +checksum = "a769666ffac256dd301006faca1ca553d0ae7cffcf4cd07095f73f95eb226514" dependencies = [ "arrow-buffer", "arrow-schema", @@ -246,9 +246,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffc68f6523970aa6f7ce1dc9a33a7d9284cfb9af77d4ad3e617dbe5d79cc6ec8" +checksum = "dbf9c3fb57390a1af0b7bb3b5558c1ee1f63905f3eccf49ae7676a8d1e6e5a72" dependencies = [ "arrow-array", "arrow-buffer", @@ -261,9 +261,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2041380f94bd6437ab648e6c2085a045e45a0c44f91a1b9a4fe3fed3d379bfb1" +checksum = "654e7f3724176b66ddfacba31af397c48e106fbe4d281c8144e7d237df5acfd7" dependencies = [ "arrow-array", "arrow-buffer", @@ -281,9 +281,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb56ed1547004e12203652f12fe12e824161ff9d1e5cf2a7dc4ff02ba94f413" +checksum = "e8008370e624e8e3c68174faaf793540287106cfda8ad1da862fdc53d8e096b4" dependencies = [ "arrow-array", "arrow-buffer", @@ -296,9 +296,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "575b42f1fc588f2da6977b94a5ca565459f5ab07b60545e17243fb9a7ed6d43e" +checksum = "ca5e3a6b7fda8d9fe03f3b18a2d946354ea7f3c8e4076dbdb502ad50d9d44824" dependencies = [ "ahash", "arrow-array", @@ -311,15 +311,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32aae6a60458a2389c0da89c9de0b7932427776127da1a738e2efc21d32f3393" +checksum = "dab1c12b40e29d9f3b699e0203c2a73ba558444c05e388a4377208f8f9c97eee" [[package]] name = "arrow-select" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de36abaef8767b4220d7b4a8c2fe5ffc78b47db81b03d77e2136091c3ba39102" +checksum = "e80159088ffe8c48965cb9b1a7c968b2729f29f37363df7eca177fc3281fe7c3" dependencies = [ "ahash", "arrow-array", @@ -331,9 +331,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e435ada8409bcafc910bc3e0077f532a4daa20e99060a496685c0e3e53cc2597" +checksum = "0fd04a6ea7de183648edbcb7a6dd925bbd04c210895f6384c780e27a9b54afcd" dependencies = [ "arrow-array", "arrow-buffer", @@ -375,8 +375,8 @@ dependencies = [ "pin-project-lite", "tokio", "xz2", - "zstd 0.13.0", - "zstd-safe 7.0.0", + "zstd 0.13.2", + "zstd-safe 7.2.0", ] [[package]] @@ -875,9 +875,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.103" +version = "1.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2755ff20a1d93490d26ba33a6f092a38a508398a5320df5d4b3014fcccce9410" +checksum = "74b6a57f98764a267ff415d50a25e6e166f3831a5071af4995296ea97d210490" 
dependencies = [ "jobserver", "libc", @@ -900,7 +900,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1172,7 +1172,7 @@ dependencies = [ "url", "uuid", "xz2", - "zstd 0.13.0", + "zstd 0.13.2", ] [[package]] @@ -1319,6 +1319,7 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions", + "datafusion-functions-aggregate", "itertools", "log", "paste", @@ -1964,9 +1965,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "c4fe55fb7a772d59a5ff1dfbff4fe0258d19b89fec4b233e75d35d5d2316badc" dependencies = [ "bytes", "futures-channel", @@ -2005,10 +2006,10 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.4.0", "hyper-util", "rustls 0.23.10", - "rustls-native-certs 0.7.0", + "rustls-native-certs 0.7.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -2017,16 +2018,16 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56" +checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.3.1", + "hyper 1.4.0", "pin-project-lite", "socket2", "tokio", @@ -2501,7 +2502,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.3.1", + "hyper 1.4.0", "itertools", "md-5", "parking_lot", @@ -2573,14 +2574,14 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] name = "parquet" -version = "52.0.0" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c3b5322cc1bbf67f11c079c42be41a55949099b78732f7dba9e15edde40eab" +checksum = "0f22ba0d95db56dde8685e3fadcb915cdaadda31ab8abbe3ff7f0ad1ef333267" dependencies = [ "ahash", "arrow-array", @@ -2608,7 +2609,7 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd 0.13.0", + "zstd 0.13.2", "zstd-sys", ] @@ -2975,7 +2976,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.0", "hyper-rustls 0.27.2", "hyper-util", "ipnet", @@ -2987,7 +2988,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls 0.23.10", - "rustls-native-certs 0.7.0", + "rustls-native-certs 0.7.1", "rustls-pemfile 2.1.2", "rustls-pki-types", "serde", @@ -3142,9 +3143,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" dependencies = [ "openssl-probe", "rustls-pemfile 2.1.2", @@ -3180,9 +3181,9 @@ checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = 
"f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78" dependencies = [ "ring 0.17.8", "rustls-pki-types", @@ -3315,9 +3316,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.119" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8eddb61f0697cc3989c5d64b452f5488e2b8a60fd7d5076a3045076ffef8cb0" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -3646,9 +3647,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" +checksum = "ce6b6a2fb3a985e99cebfaefa9faa3024743da73304ca1c683a36429613d3d22" dependencies = [ "tinyvec_macros", ] @@ -4097,7 +4098,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4115,7 +4116,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4135,18 +4136,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4157,9 +4158,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4169,9 +4170,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4181,15 +4182,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] 
name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4199,9 +4200,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4211,9 +4212,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4223,9 +4224,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -4235,9 +4236,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winreg" @@ -4266,18 +4267,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", @@ -4301,11 +4302,11 @@ dependencies = [ [[package]] name = "zstd" -version = "0.13.0" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" dependencies = [ - "zstd-safe 7.0.0", + "zstd-safe 7.2.0", ] [[package]] @@ -4320,18 +4321,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.0.0" +version = "7.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +checksum = "fa556e971e7b568dc775c136fc9de8c779b1c2fc3a63defaafadffdbd3181afa" 
dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.9+zstd.1.5.5" +version = "2.0.11+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +checksum = "75652c55c0b6f3e6f12eb786fe1bc960396bf05a1eb3bf1f3691c3610ac2e6d4" dependencies = [ "cc", "pkg-config", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 8578476ed43d..bcacf1d52a9b 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -30,7 +30,7 @@ rust-version = "1.76" readme = "README.md" [dependencies] -arrow = { version = "52.0.0" } +arrow = { version = "52.1.0" } async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" @@ -51,7 +51,7 @@ futures = "0.3" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.10.1", features = ["aws", "gcp", "http"] } parking_lot = { version = "0.12" } -parquet = { version = "52.0.0", default-features = false } +parquet = { version = "52.1.0", default-features = false } regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index f93aaec4218d..9b1f2aa125c2 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -252,7 +252,7 @@ mod tests { fn unescape_readline_input() -> Result<()> { let validator = CliHelper::default(); - // shoule be valid + // should be valid let result = readline_direct( Cursor::new( r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',');" @@ -326,7 +326,7 @@ mod tests { fn sql_dialect() -> Result<()> { let mut validator = CliHelper::default(); - // shoule be invalid in generic dialect + // should be invalid in generic dialect let result = readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?; assert!( diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 1c5651ad8ac3..90469e6715a6 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -36,9 +36,12 @@ cd datafusion # Download test data git submodule update --init -# Run the `csv_sql` example: +# Change to the examples directory +cd datafusion-examples/examples + +# Run the `dataframe` example: # ... 
use the equivalent for other examples -cargo run --example csv_sql +cargo run --example dataframe ``` ## Single Process @@ -47,10 +50,9 @@ cargo run --example csv_sql - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) - [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files -- [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file +- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics (row level access control) - [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog - [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization -- [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file - [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) - [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 @@ -66,16 +68,14 @@ cargo run --example csv_sql - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es - [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates - [`parquet_index.rs`](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries -- [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files - [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution - [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into Datafusion `Expr`. 
- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from Datafusion `Expr` and `LogicalPlan` -- [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics +- [`pruning.rs`](examples/pruning.rs): Use pruning to rule out files based on statistics - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions -- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 2c672a18a738..1259f90d6449 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -92,7 +92,7 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new("prod", args.return_type.clone(), true), Field::new("n", DataType::UInt32, true), @@ -339,7 +339,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { Ok(()) } - /// Generate output, as specififed by `emit_to` and update the intermediate state + /// Generate output, as specified by `emit_to` and update the intermediate state fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { let counts = emit_to.take_needed(&mut self.counts); let prods = emit_to.take_needed(&mut self.prods); diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 41c6381df5d4..11fb6f6ccc48 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -75,7 +75,7 @@ impl WindowUDFImpl for SmoothItUdf { Ok(DataType::Float64) } - /// Create a `PartitionEvalutor` to evaluate this function on a new + /// Create a `PartitionEvaluator` to evaluate this function on a new /// partition. fn partition_evaluator(&self) -> Result> { Ok(Box::new(MyPartitionEvaluator::new())) diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs new file mode 100644 index 000000000000..bd067be97b8b --- /dev/null +++ b/datafusion-examples/examples/analyzer_rule.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use datafusion::prelude::SessionContext; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_optimizer::analyzer::AnalyzerRule; +use std::sync::{Arc, Mutex}; + +/// This example demonstrates how to add your own [`AnalyzerRule`] to +/// DataFusion. +/// +/// [`AnalyzerRule`]s transform [`LogicalPlan`]s prior to the DataFusion +/// optimization process, and can be used to change the plan's semantics (e.g. +/// output types). +/// +/// This example shows an `AnalyzerRule` which implements a simplistic of row +/// level access control scheme by introducing a filter to the query. +/// +/// See [optimizer_rule.rs] for an example of a optimizer rule +#[tokio::main] +pub async fn main() -> Result<()> { + // AnalyzerRules run before OptimizerRules. + // + // DataFusion includes several built in AnalyzerRules for tasks such as type + // coercion which change the types of expressions in the plan. Add our new + // rule to the context to run it during the analysis phase. + let rule = Arc::new(RowLevelAccessControl::new()); + let ctx = SessionContext::new(); + ctx.add_analyzer_rule(Arc::clone(&rule) as _); + + ctx.register_batch("employee", employee_batch())?; + + // Now, planning any SQL statement also invokes the AnalyzerRule + let plan = ctx + .sql("SELECT * FROM employee") + .await? + .into_optimized_plan()?; + + // Printing the query plan shows a filter has been added + // + // Filter: employee.position = Utf8("Engineer") + // TableScan: employee projection=[name, age, position] + println!("Logical Plan:\n\n{}\n", plan.display_indent()); + + // Execute the query, and indeed no Manager's are returned + // + // +-----------+-----+----------+ + // | name | age | position | + // +-----------+-----+----------+ + // | Andy | 11 | Engineer | + // | Oleks | 33 | Engineer | + // | Xiangpeng | 55 | Engineer | + // +-----------+-----+----------+ + ctx.sql("SELECT * FROM employee").await?.show().await?; + + // We can now change the access level to "Manager" and see the results + // + // +----------+-----+----------+ + // | name | age | position | + // +----------+-----+----------+ + // | Andrew | 22 | Manager | + // | Chunchun | 44 | Manager | + // +----------+-----+----------+ + rule.set_show_position("Manager"); + ctx.sql("SELECT * FROM employee").await?.show().await?; + + // The filters introduced by our AnalyzerRule are treated the same as any + // other filter by the DataFusion optimizer, including predicate push down + // (including into scans), simplifications, and similar optimizations. + // + // For example adding another predicate to the query + let plan = ctx + .sql("SELECT * FROM employee WHERE age > 30") + .await? 
+ .into_optimized_plan()?; + + // We can see the DataFusion Optimizer has combined the filters together + // when we print out the plan + // + // Filter: employee.age > Int32(30) AND employee.position = Utf8("Manager") + // TableScan: employee projection=[name, age, position] + println!("Logical Plan:\n\n{}\n", plan.display_indent()); + + Ok(()) +} + +/// Example AnalyzerRule that implements a very basic "row level access +/// control" +/// +/// In this case, it adds a filter to the plan that removes all managers from +/// the result set. +#[derive(Debug)] +struct RowLevelAccessControl { + /// Models the current access level of the session + /// + /// This is value of the position column which should be included in the + /// result set. It is wrapped in a `Mutex` so we can change it during query + show_position: Mutex, +} + +impl RowLevelAccessControl { + fn new() -> Self { + Self { + show_position: Mutex::new("Engineer".to_string()), + } + } + + /// return the current position to show, as an expression + fn show_position(&self) -> Expr { + lit(self.show_position.lock().unwrap().clone()) + } + + /// specifies a different position to show in the result set + fn set_show_position(&self, access_level: impl Into) { + *self.show_position.lock().unwrap() = access_level.into(); + } +} + +impl AnalyzerRule for RowLevelAccessControl { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + // use the TreeNode API to recursively walk the LogicalPlan tree + // and all of its children (inputs) + let transfomed_plan = plan.transform(|plan| { + // This closure is called for each LogicalPlan node + // if it is a Scan node, add a filter to remove all managers + if is_employee_table_scan(&plan) { + // Use the LogicalPlanBuilder to add a filter to the plan + let filter = LogicalPlanBuilder::from(plan) + // Filter Expression: position = + .filter(col("position").eq(self.show_position()))? + .build()?; + + // `Transformed::yes` signals the plan was changed + Ok(Transformed::yes(filter)) + } else { + // `Transformed::no` + // signals the plan was not changed + Ok(Transformed::no(plan)) + } + })?; + + // the result of calling transform is a `Transformed` structure which + // contains + // + // 1. a flag signaling if any rewrite took place + // 2. a flag if the recursion stopped early + // 3. 
The actual transformed data (a LogicalPlan in this case) + // + // This example does not need the value of either flag, so simply + // extract the LogicalPlan "data" + Ok(transfomed_plan.data) + } + + fn name(&self) -> &str { + "table_access" + } +} + +fn is_employee_table_scan(plan: &LogicalPlan) -> bool { + if let LogicalPlan::TableScan(scan) = plan { + scan.table_name.table() == "employee" + } else { + false + } +} + +/// Return a RecordBatch with made up data about fictional employees +fn employee_batch() -> RecordBatch { + let name: ArrayRef = Arc::new(StringArray::from_iter_values([ + "Andy", + "Andrew", + "Oleks", + "Chunchun", + "Xiangpeng", + ])); + let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33, 44, 55])); + let position = Arc::new(StringArray::from_iter_values([ + "Engineer", "Manager", "Engineer", "Manager", "Engineer", + ])); + RecordBatch::try_from_iter(vec![("name", name), ("age", age), ("position", position)]) + .unwrap() +} diff --git a/datafusion-examples/examples/avro_sql.rs b/datafusion-examples/examples/avro_sql.rs deleted file mode 100644 index ac1053aa1881..000000000000 --- a/datafusion-examples/examples/avro_sql.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use datafusion::arrow::util::pretty; - -use datafusion::error::Result; -use datafusion::prelude::*; - -/// This example demonstrates executing a simple query against an Arrow data source (Avro) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { - // create local execution context - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::arrow_test_data(); - - // register avro file with the execution context - let avro_file = &format!("{testdata}/avro/alltypes_plain.avro"); - ctx.register_avro("alltypes_plain", avro_file, AvroReadOptions::default()) - .await?; - - // execute the query - let df = ctx - .sql( - "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ - FROM alltypes_plain \ - WHERE id > 1 AND tinyint_col < double_col", - ) - .await?; - let results = df.collect().await?; - - // print the results - pretty::print_batches(&results)?; - - Ok(()) -} diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/catalog.rs index 5bc2cadac128..b9188e1cd5e0 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/catalog.rs @@ -83,7 +83,7 @@ async fn main() -> Result<()> { // register our catalog in the context ctx.register_catalog("dircat", Arc::new(catalog)); { - // catalog was passed down into our custom catalog list since we overide the ctx's default + // catalog was passed down into our custom catalog list since we override the ctx's default let catalogs = catlist.catalogs.read().unwrap(); assert!(catalogs.contains_key("dircat")); }; diff --git a/datafusion-examples/examples/csv_sql.rs b/datafusion-examples/examples/csv_sql.rs deleted file mode 100644 index 851fdcb626d2..000000000000 --- a/datafusion-examples/examples/csv_sql.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use datafusion::datasource::file_format::file_compression_type::FileCompressionType; -use datafusion::error::Result; -use datafusion::prelude::*; - -/// This example demonstrates executing a simple query against an Arrow data source (CSV) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { - // create local execution context - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::arrow_test_data(); - - // register csv file with the execution context - ctx.register_csv( - "aggregate_test_100", - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::new(), - ) - .await?; - - // execute the query - let df = ctx - .sql( - "SELECT c1, MIN(c12), MAX(c12) \ - FROM aggregate_test_100 \ - WHERE c11 > 0.1 AND c11 < 0.9 \ - GROUP BY c1", - ) - .await?; - - // print the results - df.show().await?; - - // query compressed CSV with specific options - let csv_options = CsvReadOptions::default() - .has_header(true) - .file_compression_type(FileCompressionType::GZIP) - .file_extension("csv.gz"); - let df = ctx - .read_csv( - &format!("{testdata}/csv/aggregate_test_100.csv.gz"), - csv_options, - ) - .await?; - let df = df - .filter(col("c1").eq(lit("a")))? - .select_columns(&["c2", "c3"])?; - - df.show().await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 36ce3badcb5e..43729a913e5d 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -46,7 +46,7 @@ use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator}; /// /// The code in this example shows how to: /// 1. Create [`Expr`]s using different APIs: [`main`]` -/// 2. Use the fluent API to easly create complex [`Expr`]s: [`expr_fn_demo`] +/// 2. Use the fluent API to easily create complex [`Expr`]s: [`expr_fn_demo`] /// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`] /// 4. Simplify expressions: [`simplify_demo`] /// 5. 
Analyze predicates for boundary ranges: [`range_analysis_demo`] diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index 057852946341..b4663b345f64 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -19,7 +19,7 @@ use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; use arrow_schema::DataType; use datafusion::prelude::SessionContext; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{assert_batches_eq, Result, ScalarValue}; use datafusion_expr::{ BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility, @@ -54,39 +54,46 @@ pub async fn main() -> Result<()> { // We can see the effect of our rewrite on the output plan that the filter // has been rewritten to `my_eq` - // - // Filter: my_eq(person.age, Int32(22)) - // TableScan: person projection=[name, age] - println!("Logical Plan:\n\n{}\n", plan.display_indent()); + assert_eq!( + plan.display_indent().to_string(), + "Filter: my_eq(person.age, Int32(22))\ + \n TableScan: person projection=[name, age]" + ); // The query below doesn't respect a filter `where age = 22` because // the plan has been rewritten using UDF which returns always true // // And the output verifies the predicates have been changed (as the my_eq // function always returns true) - // - // +--------+-----+ - // | name | age | - // +--------+-----+ - // | Andy | 11 | - // | Andrew | 22 | - // | Oleks | 33 | - // +--------+-----+ - ctx.sql(sql).await?.show().await?; + assert_batches_eq!( + [ + "+--------+-----+", + "| name | age |", + "+--------+-----+", + "| Andy | 11 |", + "| Andrew | 22 |", + "| Oleks | 33 |", + "+--------+-----+", + ], + &ctx.sql(sql).await?.collect().await? + ); // however we can see the rule doesn't trigger for queries with predicates // other than `=` - // - // +-------+-----+ - // | name | age | - // +-------+-----+ - // | Andy | 11 | - // | Oleks | 33 | - // +-------+-----+ - ctx.sql("SELECT * FROM person WHERE age <> 22") - .await? - .show() - .await?; + assert_batches_eq!( + [ + "+-------+-----+", + "| name | age |", + "+-------+-----+", + "| Andy | 11 |", + "| Oleks | 33 |", + "+-------+-----+", + ], + &ctx.sql("SELECT * FROM person WHERE age <> 22") + .await? + .collect() + .await? + ); Ok(()) } diff --git a/datafusion-examples/examples/parquet_sql.rs b/datafusion-examples/examples/parquet_sql.rs deleted file mode 100644 index fb438a7832cb..000000000000 --- a/datafusion-examples/examples/parquet_sql.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use datafusion::error::Result; -use datafusion::prelude::*; - -/// This example demonstrates executing a simple query against an Arrow data source (Parquet) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { - // create local session context - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::parquet_test_data(); - - // register parquet file with the execution context - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; - - // execute the query - let df = ctx - .sql( - "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ - FROM alltypes_plain \ - WHERE id > 1 AND tinyint_col < double_col", - ) - .await?; - - // print the results - df.show().await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index 6444eb68b6b2..a1fc5d269a04 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -113,7 +113,7 @@ async fn query_parquet_demo() -> Result<()> { vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], )? // Directly parsing the SQL text into a sort expression is not supported yet, so - // construct it programatically + // construct it programmatically .sort(vec![col("double_col").sort(false, false)])? .limit(0, Some(1))?; diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs deleted file mode 100644 index 06286d5d66ed..000000000000 --- a/datafusion-examples/examples/rewrite_expr.rs +++ /dev/null @@ -1,251 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{ - AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, -}; -use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; -use datafusion_optimizer::optimizer::{ApplyOrder, Optimizer}; -use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; -use datafusion_sql::planner::{ContextProvider, SqlToRel}; -use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; -use datafusion_sql::sqlparser::parser::Parser; -use datafusion_sql::TableReference; -use std::any::Any; -use std::sync::Arc; - -pub fn main() -> Result<()> { - // produce a logical plan using the datafusion-sql crate - let dialect = PostgreSqlDialect {}; - let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32"; - let statements = Parser::parse_sql(&dialect, sql)?; - - // produce a logical plan using the datafusion-sql crate - let context_provider = MyContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context_provider); - let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; - println!( - "Unoptimized Logical Plan:\n\n{}\n", - logical_plan.display_indent() - ); - - // run the analyzer with our custom rule - let config = OptimizerContext::default().with_skip_failing_rules(false); - let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]); - let analyzed_plan = - analyzer.execute_and_check(logical_plan, config.options(), |_, _| {})?; - println!( - "Analyzed Logical Plan:\n\n{}\n", - analyzed_plan.display_indent() - ); - - // then run the optimizer with our custom rule - let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; - println!( - "Optimized Logical Plan:\n\n{}\n", - optimized_plan.display_indent() - ); - - Ok(()) -} - -fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { - println!( - "After applying rule '{}':\n{}\n", - rule.name(), - plan.display_indent() - ) -} - -/// An example analyzer rule that changes Int64 literals to UInt64 -struct MyAnalyzerRule {} - -impl AnalyzerRule for MyAnalyzerRule { - fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { - Self::analyze_plan(plan) - } - - fn name(&self) -> &str { - "my_analyzer_rule" - } -} - -impl MyAnalyzerRule { - fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(|plan| { - Ok(match plan { - LogicalPlan::Filter(filter) => { - let predicate = Self::analyze_expr(filter.predicate.clone())?; - Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input, - )?)) - } - _ => Transformed::no(plan), - }) - }) - .data() - } - - fn analyze_expr(expr: Expr) -> Result { - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Literal(ScalarValue::Int64(i)) => { - // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) - } - _ => Transformed::no(expr), - }) - }) - .data() - } -} - -/// An example optimizer rule that rewrite BETWEEN expression to binary compare expressions -struct MyOptimizerRule {} - -impl OptimizerRule for MyOptimizerRule { - fn name(&self) -> &str { - "my_optimizer_rule" - } - - fn apply_order(&self) -> Option { - 
Some(ApplyOrder::BottomUp) - } - - fn supports_rewrite(&self) -> bool { - true - } - - fn rewrite( - &self, - plan: LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result, DataFusionError> { - match plan { - LogicalPlan::Filter(filter) => { - let predicate = my_rewrite(filter.predicate.clone())?; - Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input.clone(), - )?))) - } - _ => Ok(Transformed::no(plan)), - } - } -} - -/// use rewrite_expr to modify the expression tree. -fn my_rewrite(expr: Expr) -> Result { - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - // unbox - let expr: Expr = *expr; - let low: Expr = *low; - let high: Expr = *high; - if negated { - Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) - } else { - Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) - } - } - _ => Transformed::no(expr), - }) - }) - .data() -} - -#[derive(Default)] -struct MyContextProvider { - options: ConfigOptions, -} - -impl ContextProvider for MyContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - if name.table() == "person" { - Ok(Arc::new(MyTableSource { - schema: Arc::new(Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("age", DataType::UInt8, false), - ])), - })) - } else { - plan_err!("table not found") - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn udf_names(&self) -> Vec { - Vec::new() - } - - fn udaf_names(&self) -> Vec { - Vec::new() - } - - fn udwf_names(&self) -> Vec { - Vec::new() - } -} - -struct MyTableSource { - schema: SchemaRef, -} - -impl TableSource for MyTableSource { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index 95339eff1cae..563f02cee6a6 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -132,7 +132,7 @@ async fn main() -> Result<()> { Ok(()) } -/// Create a `PartitionEvalutor` to evaluate this function on a new +/// Create a `PartitionEvaluator` to evaluate this function on a new /// partition. 
fn make_partition_evaluator() -> Result> { Ok(Box::new(MyPartitionEvaluator::new())) diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs index 3995988751c7..9a2aabaa79c2 100644 --- a/datafusion-examples/examples/sql_analysis.rs +++ b/datafusion-examples/examples/sql_analysis.rs @@ -87,7 +87,7 @@ fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { let mut groups = vec![]; while let Some(node) = to_visit.pop() { - // if we encouter a join, we know were at the root of the tree + // if we encounter a join, we know were at the root of the tree // count this tree and recurse on it's inputs if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) { let (group_count, inputs) = count_tree(node); diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index bd2265c85003..26e03a3b9893 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -126,7 +126,7 @@ pub use struct_builder::ScalarStructBuilder; /// /// # Nested Types /// -/// `List` / `LargeList` / `FixedSizeList` / `Struct` are represented as a +/// `List` / `LargeList` / `FixedSizeList` / `Struct` / `Map` are represented as a /// single element array of the corresponding type. /// /// ## Example: Creating [`ScalarValue::Struct`] using [`ScalarStructBuilder`] @@ -247,6 +247,8 @@ pub enum ScalarValue { /// Represents a single element [`StructArray`] as an [`ArrayRef`]. See /// [`ScalarValue`] for examples of how to create instances of this type. Struct(Arc), + /// Represents a single element [`MapArray`] as an [`ArrayRef`]. + Map(Arc), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -370,6 +372,8 @@ impl PartialEq for ScalarValue { (LargeList(_), _) => false, (Struct(v1), Struct(v2)) => v1.eq(v2), (Struct(_), _) => false, + (Map(v1), Map(v2)) => v1.eq(v2), + (Map(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -502,6 +506,8 @@ impl PartialOrd for ScalarValue { partial_cmp_struct(struct_arr1, struct_arr2) } (Struct(_), _) => None, + (Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2), + (Map(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -631,6 +637,34 @@ fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option, m2: &Arc) -> Option { + if m1.len() != m2.len() { + return None; + } + + if m1.data_type() != m2.data_type() { + return None; + } + + for col_index in 0..m1.len() { + let arr1 = m1.entries().column(col_index); + let arr2 = m2.entries().column(col_index); + + let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + } + Some(Ordering::Equal) +} + impl Eq for ScalarValue {} //Float wrapper over f32/f64. 
Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper @@ -696,6 +730,9 @@ impl std::hash::Hash for ScalarValue { Struct(arr) => { hash_nested_array(arr.to_owned() as ArrayRef, state); } + Map(arr) => { + hash_nested_array(arr.to_owned() as ArrayRef, state); + } Date32(v) => v.hash(state), Date64(v) => v.hash(state), Time32Second(v) => v.hash(state), @@ -1132,6 +1169,7 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.data_type().to_owned(), ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Struct(arr) => arr.data_type().to_owned(), + ScalarValue::Map(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1403,6 +1441,7 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.len() == arr.null_count(), ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Struct(arr) => arr.len() == arr.null_count(), + ScalarValue::Map(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1420,7 +1459,10 @@ impl ScalarValue { ScalarValue::DurationMillisecond(v) => v.is_none(), ScalarValue::DurationMicrosecond(v) => v.is_none(), ScalarValue::DurationNanosecond(v) => v.is_none(), - ScalarValue::Union(v, _, _) => v.is_none(), + ScalarValue::Union(v, _, _) => match v { + Some((_, s)) => s.is_null(), + None => true, + }, ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -2172,6 +2214,9 @@ impl ScalarValue { ScalarValue::Struct(arr) => { Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } + ScalarValue::Map(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) } @@ -2802,6 +2847,9 @@ impl ScalarValue { ScalarValue::Struct(arr) => { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } + ScalarValue::Map(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) + } ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? } @@ -2937,6 +2985,7 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.get_array_memory_size(), ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(arr) => arr.get_array_memory_size(), + ScalarValue::Map(arr) => arr.get_array_memory_size(), ScalarValue::Union(vals, fields, _mode) => { vals.as_ref() .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) @@ -3269,6 +3318,12 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), + DataType::Map(fields, sorted) => ScalarValue::Map( + new_null_array(&DataType::Map(fields.to_owned(), sorted.to_owned()), 1) + .as_map() + .to_owned() + .into(), + ), DataType::Union(fields, mode) => { ScalarValue::Union(None, fields.clone(), *mode) } @@ -3399,6 +3454,43 @@ impl fmt::Display for ScalarValue { .join(",") )? 
} + ScalarValue::Map(map_arr) => { + if map_arr.null_count() == map_arr.len() { + write!(f, "NULL")?; + return Ok(()); + } + + write!( + f, + "[{}]", + map_arr + .iter() + .map(|struct_array| { + if let Some(arr) = struct_array { + let mut buffer = VecDeque::new(); + for i in 0..arr.len() { + let key = + array_value_to_string(arr.column(0), i).unwrap(); + let value = + array_value_to_string(arr.column(1), i).unwrap(); + buffer.push_back(format!("{}:{}", key, value)); + } + format!( + "{{{}}}", + buffer + .into_iter() + .collect::>() + .join(",") + .as_str() + ) + } else { + "NULL".to_string() + } + }) + .collect::>() + .join(",") + )? + } ScalarValue::Union(val, _fields, _mode) => match val { Some((id, val)) => write!(f, "{}:{}", id, val)?, None => write!(f, "NULL")?, @@ -3492,6 +3584,33 @@ impl fmt::Debug for ScalarValue { .join(",") ) } + ScalarValue::Map(map_arr) => { + write!( + f, + "Map([{}])", + map_arr + .iter() + .map(|struct_array| { + if let Some(arr) = struct_array { + let buffer: Vec = (0..arr.len()) + .map(|i| { + let key = array_value_to_string(arr.column(0), i) + .unwrap(); + let value = + array_value_to_string(arr.column(1), i) + .unwrap(); + format!("{key:?}:{value:?}") + }) + .collect(); + format!("{{{}}}", buffer.join(",")) + } else { + "NULL".to_string() + } + }) + .collect::>() + .join(",") + ) + } ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3580,7 +3699,7 @@ mod tests { use super::*; use crate::cast::{ - as_string_array, as_struct_array, as_uint32_array, as_uint64_array, + as_map_array, as_string_array, as_struct_array, as_uint32_array, as_uint64_array, }; use crate::assert_batches_eq; @@ -3594,6 +3713,31 @@ mod tests { use chrono::NaiveDate; use rand::Rng; + #[test] + fn test_scalar_value_from_for_map() { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + + let expected = builder.finish(); + + let sv = ScalarValue::Map(Arc::new(expected.clone())); + let map_arr = sv.to_array().unwrap(); + let actual = as_map_array(&map_arr).unwrap(); + assert_eq!(actual, &expected); + } + #[test] fn test_scalar_value_from_for_struct() { let boolean = Arc::new(BooleanArray::from(vec![false])); @@ -6158,6 +6302,7 @@ mod tests { .unwrap(); assert_eq!(s.to_string(), "{a:1,b:}"); + assert_eq!(format!("{s:?}"), r#"Struct({a:1,b:})"#); let ScalarValue::Struct(arr) = s else { panic!("Expected struct"); @@ -6199,6 +6344,50 @@ mod tests { assert_batches_eq!(&expected, &[batch]); } + #[test] + fn test_map_display_and_debug() { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + 
builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + + let map_value = ScalarValue::Map(Arc::new(builder.finish())); + + assert_eq!(map_value.to_string(), "[{joe:1},{blogs:2,foo:4},{},NULL]"); + assert_eq!( + format!("{map_value:?}"), + r#"Map([{"joe":"1"},{"blogs":"2","foo":"4"},{},NULL])"# + ); + + let ScalarValue::Map(arr) = map_value else { + panic!("Expected map"); + }; + + //verify compared to arrow display + let batch = RecordBatch::try_from_iter(vec![("m", arr as _)]).unwrap(); + let expected = [ + "+--------------------+", + "| m |", + "+--------------------+", + "| {joe: 1} |", + "| {blogs: 2, foo: 4} |", + "| {} |", + "| |", + "+--------------------+", + ]; + assert_batches_eq!(&expected, &[batch]); + } + #[test] fn test_build_timestamp_millisecond_list() { let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; @@ -6328,4 +6517,33 @@ mod tests { } intervals } + + fn union_fields() -> UnionFields { + [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + ] + .into_iter() + .collect() + } + + #[test] + fn sparse_scalar_union_is_null() { + let sparse_scalar = ScalarValue::Union( + Some((0_i8, Box::new(ScalarValue::Int32(None)))), + union_fields(), + UnionMode::Sparse, + ); + assert!(sparse_scalar.is_null()); + } + + #[test] + fn dense_scalar_union_is_null() { + let dense_scalar = ScalarValue::Union( + Some((0_i8, Box::new(ScalarValue::Int32(None)))), + union_fields(), + UnionMode::Dense, + ); + assert!(dense_scalar.is_null()); + } } diff --git a/datafusion/core/benches/parquet_statistic.rs b/datafusion/core/benches/parquet_statistic.rs index b58ecc13aee0..3595e8773b07 100644 --- a/datafusion/core/benches/parquet_statistic.rs +++ b/datafusion/core/benches/parquet_statistic.rs @@ -18,20 +18,26 @@ //! 
Benchmarks of benchmark for extracting arrow statistics from parquet use arrow::array::{ArrayRef, DictionaryArray, Float64Array, StringArray, UInt64Array}; -use arrow_array::{Int32Array, RecordBatch}; +use arrow_array::{Int32Array, Int64Array, RecordBatch}; use arrow_schema::{ DataType::{self, *}, Field, Schema, }; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion::datasource::physical_plan::parquet::StatisticsConverter; -use parquet::arrow::{arrow_reader::ArrowReaderBuilder, ArrowWriter}; -use parquet::file::properties::WriterProperties; +use parquet::{ + arrow::arrow_reader::ArrowReaderOptions, file::properties::WriterProperties, +}; +use parquet::{ + arrow::{arrow_reader::ArrowReaderBuilder, ArrowWriter}, + file::properties::EnabledStatistics, +}; use std::sync::Arc; use tempfile::NamedTempFile; #[derive(Debug, Clone)] enum TestTypes { UInt64, + Int64, F64, String, Dictionary, @@ -43,6 +49,7 @@ impl fmt::Display for TestTypes { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { TestTypes::UInt64 => write!(f, "UInt64"), + TestTypes::Int64 => write!(f, "Int64"), TestTypes::F64 => write!(f, "F64"), TestTypes::String => write!(f, "String"), TestTypes::Dictionary => write!(f, "Dictionary(Int32, String)"), @@ -50,11 +57,18 @@ impl fmt::Display for TestTypes { } } -fn create_parquet_file(dtype: TestTypes, row_groups: usize) -> NamedTempFile { +fn create_parquet_file( + dtype: TestTypes, + row_groups: usize, + data_page_row_count_limit: &Option, +) -> NamedTempFile { let schema = match dtype { TestTypes::UInt64 => { Arc::new(Schema::new(vec![Field::new("col", DataType::UInt64, true)])) } + TestTypes::Int64 => { + Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, true)])) + } TestTypes::F64 => Arc::new(Schema::new(vec![Field::new( "col", DataType::Float64, @@ -70,7 +84,14 @@ fn create_parquet_file(dtype: TestTypes, row_groups: usize) -> NamedTempFile { )])), }; - let props = WriterProperties::builder().build(); + let mut props = WriterProperties::builder().set_max_row_group_size(row_groups); + if let Some(limit) = data_page_row_count_limit { + props = props + .set_data_page_row_count_limit(*limit) + .set_statistics_enabled(EnabledStatistics::Page); + }; + let props = props.build(); + let file = tempfile::Builder::new() .suffix(".parquet") .tempfile() @@ -82,11 +103,21 @@ fn create_parquet_file(dtype: TestTypes, row_groups: usize) -> NamedTempFile { for _ in 0..row_groups { let batch = match dtype { TestTypes::UInt64 => make_uint64_batch(), + TestTypes::Int64 => make_int64_batch(), TestTypes::F64 => make_f64_batch(), TestTypes::String => make_string_batch(), TestTypes::Dictionary => make_dict_batch(), }; - writer.write(&batch).unwrap(); + if data_page_row_count_limit.is_some() { + // Send batches one at a time. This allows the + // writer to apply the page limit, that is only + // checked on RecordBatch boundaries. 
+ for i in 0..batch.num_rows() { + writer.write(&batch.slice(i, 1)).unwrap(); + } + } else { + writer.write(&batch).unwrap(); + } } writer.close().unwrap(); file @@ -109,6 +140,23 @@ fn make_uint64_batch() -> RecordBatch { .unwrap() } +fn make_int64_batch() -> RecordBatch { + let array: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])); + RecordBatch::try_new( + Arc::new(arrow::datatypes::Schema::new(vec![ + arrow::datatypes::Field::new("col", Int64, false), + ])), + vec![array], + ) + .unwrap() +} + fn make_f64_batch() -> RecordBatch { let array: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0])); RecordBatch::try_new( @@ -150,36 +198,88 @@ fn make_dict_batch() -> RecordBatch { fn criterion_benchmark(c: &mut Criterion) { let row_groups = 100; use TestTypes::*; - let types = vec![UInt64, F64, String, Dictionary]; + let types = vec![Int64, UInt64, F64, String, Dictionary]; + let data_page_row_count_limits = vec![None, Some(1)]; for dtype in types { - let file = create_parquet_file(dtype.clone(), row_groups); - let file = file.reopen().unwrap(); - let reader = ArrowReaderBuilder::try_new(file).unwrap(); - let metadata = reader.metadata(); - let row_groups = metadata.row_groups(); - - let mut group = - c.benchmark_group(format!("Extract statistics for {}", dtype.clone())); - group.bench_function( - BenchmarkId::new("extract_statistics", dtype.clone()), - |b| { - b.iter(|| { - let converter = StatisticsConverter::try_new( - "col", - reader.schema(), - reader.parquet_schema(), - ) - .unwrap(); - - let _ = converter.row_group_mins(row_groups.iter()).unwrap(); - let _ = converter.row_group_maxes(row_groups.iter()).unwrap(); - let _ = converter.row_group_null_counts(row_groups.iter()).unwrap(); - let _ = converter.row_group_row_counts(row_groups.iter()).unwrap(); - }) - }, - ); - group.finish(); + for data_page_row_count_limit in &data_page_row_count_limits { + let file = + create_parquet_file(dtype.clone(), row_groups, data_page_row_count_limit); + let file = file.reopen().unwrap(); + let options = ArrowReaderOptions::new().with_page_index(true); + let reader = ArrowReaderBuilder::try_new_with_options(file, options).unwrap(); + let metadata = reader.metadata(); + let row_groups = metadata.row_groups(); + let row_group_indices: Vec<_> = (0..row_groups.len()).collect(); + + let statistic_type = if data_page_row_count_limit.is_some() { + "data page" + } else { + "row group" + }; + + let mut group = c.benchmark_group(format!( + "Extract {} statistics for {}", + statistic_type, + dtype.clone() + )); + group.bench_function( + BenchmarkId::new("extract_statistics", dtype.clone()), + |b| { + b.iter(|| { + let converter = StatisticsConverter::try_new( + "col", + reader.schema(), + reader.parquet_schema(), + ) + .unwrap(); + + if data_page_row_count_limit.is_some() { + let column_page_index = reader + .metadata() + .column_index() + .expect("File should have column page indices"); + + let column_offset_index = reader + .metadata() + .offset_index() + .expect("File should have column offset indices"); + + let _ = converter.data_page_mins( + column_page_index, + column_offset_index, + &row_group_indices, + ); + let _ = converter.data_page_maxes( + column_page_index, + column_offset_index, + &row_group_indices, + ); + let _ = converter.data_page_null_counts( + column_page_index, + column_offset_index, + &row_group_indices, + ); + let _ = converter.data_page_row_counts( + column_offset_index, + row_groups, + &row_group_indices, + ); + } 
else { + let _ = converter.row_group_mins(row_groups.iter()).unwrap(); + let _ = converter.row_group_maxes(row_groups.iter()).unwrap(); + let _ = converter + .row_group_null_counts(row_groups.iter()) + .unwrap(); + let _ = converter + .row_group_row_counts(row_groups.iter()) + .unwrap(); + } + }) + }, + ); + group.finish(); + } } } diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index c0e02d388af4..00f6d5916751 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -82,7 +82,7 @@ fn register_defs(ctx: SessionContext, defs: Vec) -> SessionContext { defs.iter().for_each(|TableDef { name, schema }| { ctx.register_table( name, - Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![]).unwrap()), + Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![vec![]]).unwrap()), ) .unwrap(); }); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index fcecae27a52f..59369aba57a9 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -19,16 +19,19 @@ // TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 +use arrow::array::{ + BooleanBuilder, FixedSizeBinaryBuilder, LargeStringBuilder, StringBuilder, +}; use arrow::datatypes::i256; use arrow::{array::ArrayRef, datatypes::DataType}; use arrow_array::{ new_empty_array, new_null_array, BinaryArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, - LargeStringArray, StringArray, Time32MillisecondArray, Time32SecondArray, - Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }; use arrow_schema::{Field, FieldRef, Schema, TimeUnit}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; @@ -392,51 +395,73 @@ macro_rules! 
get_statistics { }) }, DataType::Binary => Ok(Arc::new(BinaryArray::from_iter( - [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator).map(|x| x.map(|x| x.to_vec())), + [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator) ))), DataType::LargeBinary => Ok(Arc::new(LargeBinaryArray::from_iter( - [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator).map(|x| x.map(|x|x.to_vec())), - ))), - DataType::Utf8 => Ok(Arc::new(StringArray::from_iter( - [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| { - let res = std::str::from_utf8(x).map(|s| s.to_string()).ok(); - if res.is_none() { - log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); - } - res - }) - }), + [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator) ))), + DataType::Utf8 => { + let iterator = [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator); + let mut builder = StringBuilder::new(); + for x in iterator { + let Some(x) = x else { + builder.append_null(); // no statistics value + continue; + }; + + let Ok(x) = std::str::from_utf8(x) else { + log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); + builder.append_null(); + continue; + }; + + builder.append_value(x); + } + Ok(Arc::new(builder.finish())) + }, DataType::LargeUtf8 => { - Ok(Arc::new(LargeStringArray::from_iter( - [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| { - let res = std::str::from_utf8(x).map(|s| s.to_string()).ok(); - if res.is_none() { - log::debug!("LargeUtf8 statistics is a non-UTF8 value, ignoring it."); - } - res - }) - }), - ))) - } - DataType::FixedSizeBinary(size) => Ok(Arc::new(FixedSizeBinaryArray::from( - [<$stat_type_prefix FixedLenByteArrayStatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| { - if x.len().try_into() == Ok(*size) { - Some(x) - } else { - log::debug!( - "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", - size, - x.len(), - ); - None - } - }) - }).collect::>(), - ))), + let iterator = [<$stat_type_prefix ByteArrayStatsIterator>]::new($iterator); + let mut builder = LargeStringBuilder::new(); + for x in iterator { + let Some(x) = x else { + builder.append_null(); // no statistics value + continue; + }; + + let Ok(x) = std::str::from_utf8(x) else { + log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); + builder.append_null(); + continue; + }; + + builder.append_value(x); + } + Ok(Arc::new(builder.finish())) + }, + DataType::FixedSizeBinary(size) => { + let iterator = [<$stat_type_prefix FixedLenByteArrayStatsIterator>]::new($iterator); + let mut builder = FixedSizeBinaryBuilder::new(*size); + for x in iterator { + let Some(x) = x else { + builder.append_null(); // no statistics value + continue; + }; + + // ignore invalid values + if x.len().try_into() != Ok(*size){ + log::debug!( + "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", + size, + x.len(), + ); + builder.append_null(); + continue; + } + + builder.append_value(x).expect("ensure to append successfully here, because size have been checked before"); + } + Ok(Arc::new(builder.finish())) + }, DataType::Decimal128(precision, scale) => { let arr = Decimal128Array::from_iter( [<$stat_type_prefix Decimal128StatsIterator>]::new($iterator) @@ -600,6 +625,31 @@ make_data_page_stats_iterator!( Index::DOUBLE, f64 ); +make_data_page_stats_iterator!( + MinByteArrayDataPageStatsIterator, + |x: &PageIndex| { x.min.clone() }, + Index::BYTE_ARRAY, + ByteArray +); +make_data_page_stats_iterator!( + 
MaxByteArrayDataPageStatsIterator, + |x: &PageIndex| { x.max.clone() }, + Index::BYTE_ARRAY, + ByteArray +); +make_data_page_stats_iterator!( + MaxFixedLenByteArrayDataPageStatsIterator, + |x: &PageIndex| { x.max.clone() }, + Index::FIXED_LEN_BYTE_ARRAY, + FixedLenByteArray +); + +make_data_page_stats_iterator!( + MinFixedLenByteArrayDataPageStatsIterator, + |x: &PageIndex| { x.min.clone() }, + Index::FIXED_LEN_BYTE_ARRAY, + FixedLenByteArray +); macro_rules! get_decimal_page_stats_iterator { ($iterator_type: ident, $func: ident, $stat_value_type: ident, $convert_func: ident) => { @@ -634,9 +684,7 @@ macro_rules! get_decimal_page_stats_iterator { .indexes .iter() .map(|x| { - Some($stat_value_type::from( - x.$func.unwrap_or_default(), - )) + x.$func.and_then(|x| Some($stat_value_type::from(x))) }) .collect::>(), ), @@ -645,9 +693,7 @@ macro_rules! get_decimal_page_stats_iterator { .indexes .iter() .map(|x| { - Some($stat_value_type::from( - x.$func.unwrap_or_default(), - )) + x.$func.and_then(|x| Some($stat_value_type::from(x))) }) .collect::>(), ), @@ -656,9 +702,9 @@ macro_rules! get_decimal_page_stats_iterator { .indexes .iter() .map(|x| { - Some($convert_func( - x.clone().$func.unwrap_or_default().data(), - )) + x.clone() + .$func + .and_then(|x| Some($convert_func(x.data()))) }) .collect::>(), ), @@ -667,9 +713,9 @@ macro_rules! get_decimal_page_stats_iterator { .indexes .iter() .map(|x| { - Some($convert_func( - x.clone().$func.unwrap_or_default().data(), - )) + x.clone() + .$func + .and_then(|x| Some($convert_func(x.data()))) }) .collect::>(), ), @@ -713,37 +759,30 @@ get_decimal_page_stats_iterator!( i256, from_bytes_to_i256 ); -make_data_page_stats_iterator!( - MinByteArrayDataPageStatsIterator, - |x: &PageIndex| { x.min.clone() }, - Index::BYTE_ARRAY, - ByteArray -); -make_data_page_stats_iterator!( - MaxByteArrayDataPageStatsIterator, - |x: &PageIndex| { x.max.clone() }, - Index::BYTE_ARRAY, - ByteArray -); macro_rules! get_data_page_statistics { ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { paste! { match $data_type { - Some(DataType::Boolean) => Ok(Arc::new( - BooleanArray::from_iter( - [<$stat_type_prefix BooleanDataPageStatsIterator>]::new($iterator) - .flatten() - // BooleanArray::from_iter required a sized iterator, so collect into Vec first - .collect::>() - .into_iter() - ) - )), + Some(DataType::Boolean) => { + let iterator = [<$stat_type_prefix BooleanDataPageStatsIterator>]::new($iterator); + let mut builder = BooleanBuilder::new(); + for x in iterator { + for x in x.into_iter() { + let Some(x) = x else { + builder.append_null(); // no statistics value + continue; + }; + builder.append_value(x); + } + } + Ok(Arc::new(builder.finish())) + }, Some(DataType::UInt8) => Ok(Arc::new( UInt8Array::from_iter( [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) .map(|x| { - x.into_iter().filter_map(|x| { + x.into_iter().map(|x| { x.and_then(|x| u8::try_from(x).ok()) }) }) @@ -754,7 +793,7 @@ macro_rules! get_data_page_statistics { UInt16Array::from_iter( [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) .map(|x| { - x.into_iter().filter_map(|x| { + x.into_iter().map(|x| { x.and_then(|x| u16::try_from(x).ok()) }) }) @@ -765,8 +804,8 @@ macro_rules! 
get_data_page_statistics { UInt32Array::from_iter( [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) .map(|x| { - x.into_iter().filter_map(|x| { - x.and_then(|x| u32::try_from(x).ok()) + x.into_iter().map(|x| { + x.and_then(|x| Some(x as u32)) }) }) .flatten() @@ -775,8 +814,8 @@ macro_rules! get_data_page_statistics { UInt64Array::from_iter( [<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator) .map(|x| { - x.into_iter().filter_map(|x| { - x.and_then(|x| u64::try_from(x).ok()) + x.into_iter().map(|x| { + x.and_then(|x| Some(x as u64)) }) }) .flatten() @@ -785,7 +824,7 @@ macro_rules! get_data_page_statistics { Int8Array::from_iter( [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) .map(|x| { - x.into_iter().filter_map(|x| { + x.into_iter().map(|x| { x.and_then(|x| i8::try_from(x).ok()) }) }) @@ -796,7 +835,7 @@ macro_rules! get_data_page_statistics { Int16Array::from_iter( [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) .map(|x| { - x.into_iter().filter_map(|x| { + x.into_iter().map(|x| { x.and_then(|x| i16::try_from(x).ok()) }) }) @@ -809,8 +848,8 @@ macro_rules! get_data_page_statistics { Float16Array::from_iter( [<$stat_type_prefix Float16DataPageStatsIterator>]::new($iterator) .map(|x| { - x.into_iter().filter_map(|x| { - x.and_then(|x| Some(from_bytes_to_f16(x.data()))) + x.into_iter().map(|x| { + x.and_then(|x| from_bytes_to_f16(x.data())) }) }) .flatten() @@ -820,32 +859,48 @@ macro_rules! get_data_page_statistics { Some(DataType::Float64) => Ok(Arc::new(Float64Array::from_iter([<$stat_type_prefix Float64DataPageStatsIterator>]::new($iterator).flatten()))), Some(DataType::Binary) => Ok(Arc::new(BinaryArray::from_iter([<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).flatten()))), Some(DataType::LargeBinary) => Ok(Arc::new(LargeBinaryArray::from_iter([<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).flatten()))), - Some(DataType::Utf8) => Ok(Arc::new(StringArray::from( - [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).map(|x| { - x.into_iter().filter_map(|x| { - x.and_then(|x| { - let res = std::str::from_utf8(x.data()).map(|s| s.to_string()).ok(); - if res.is_none() { + Some(DataType::Utf8) => { + let mut builder = StringBuilder::new(); + let iterator = [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator); + for x in iterator { + for x in x.into_iter() { + let Some(x) = x else { + builder.append_null(); // no statistics value + continue; + }; + + let Ok(x) = std::str::from_utf8(x.data()) else { log::debug!("Utf8 statistics is a non-UTF8 value, ignoring it."); - } - res - }) - }) - }).flatten().collect::>(), - ))), - Some(DataType::LargeUtf8) => Ok(Arc::new(LargeStringArray::from( - [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator).map(|x| { - x.into_iter().filter_map(|x| { - x.and_then(|x| { - let res = std::str::from_utf8(x.data()).map(|s| s.to_string()).ok(); - if res.is_none() { - log::debug!("LargeUtf8 statistics is a non-UTF8 value, ignoring it."); - } - res - }) - }) - }).flatten().collect::>(), - ))), + builder.append_null(); + continue; + }; + + builder.append_value(x); + } + } + Ok(Arc::new(builder.finish())) + }, + Some(DataType::LargeUtf8) => { + let mut builder = LargeStringBuilder::new(); + let iterator = [<$stat_type_prefix ByteArrayDataPageStatsIterator>]::new($iterator); + for x in iterator { + for x in x.into_iter() { + let Some(x) = x else { + builder.append_null(); // no statistics value + continue; + }; + + let 
Ok(x) = std::str::from_utf8(x.data()) else { + log::debug!("LargeUtf8 statistics is a non-UTF8 value, ignoring it."); + builder.append_null(); + continue; + }; + + builder.append_value(x); + } + } + Ok(Arc::new(builder.finish())) + }, Some(DataType::Dictionary(_, value_type)) => { [<$stat_type_prefix:lower _ page_statistics>](Some(value_type), $iterator) }, @@ -861,14 +916,14 @@ macro_rules! get_data_page_statistics { Some(DataType::Date32) => Ok(Arc::new(Date32Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten()))), Some(DataType::Date64) => Ok( Arc::new( - Date64Array::from([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) + Date64Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) .map(|x| { x.into_iter() - .filter_map(|x| { + .map(|x| { x.and_then(|x| i64::try_from(x).ok()) }) - .map(|x| x * 24 * 60 * 60 * 1000) - }).flatten().collect::>() + .map(|x| x.map(|x| x * 24 * 60 * 60 * 1000)) + }).flatten() ) ) ), @@ -903,7 +958,31 @@ macro_rules! get_data_page_statistics { new_empty_array(&DataType::Time64(unit.clone())) } }) - } + }, + Some(DataType::FixedSizeBinary(size)) => { + let mut builder = FixedSizeBinaryBuilder::new(*size); + let iterator = [<$stat_type_prefix FixedLenByteArrayDataPageStatsIterator>]::new($iterator); + for x in iterator { + for x in x.into_iter() { + let Some(x) = x else { + builder.append_null(); // no statistics value + continue; + }; + + if x.len() == *size as usize { + let _ = builder.append_value(x.data()); + } else { + log::debug!( + "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", + size, + x.len(), + ); + builder.append_null(); + } + } + } + Ok(Arc::new(builder.finish())) + }, _ => unimplemented!() } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 4685f194fe29..04debf498aa9 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1392,6 +1392,10 @@ impl FunctionRegistry for SessionContext { self.state.write().register_function_rewrite(rewrite) } + fn expr_planners(&self) -> Vec> { + self.state.read().expr_planners() + } + fn register_user_defined_sql_planner( &mut self, user_defined_sql_planner: Arc, diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index aa81d77cf682..d056b91c2747 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -231,10 +231,24 @@ impl SessionState { ); } + let user_defined_sql_planners: Vec> = vec![ + Arc::new(functions::core::planner::CoreFunctionPlanner::default()), + // register crate of array expressions (if enabled) + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::ArrayFunctionPlanner), + #[cfg(feature = "array_expressions")] + Arc::new(functions_array::planner::FieldAccessPlanner), + #[cfg(any( + feature = "datetime_expressions", + feature = "unicode_expressions" + ))] + Arc::new(functions::planner::UserDefinedFunctionPlanner), + ]; + let mut new_self = SessionState { session_id, analyzer: Analyzer::new(), - user_defined_sql_planners: vec![], + user_defined_sql_planners, optimizer: Optimizer::new(), physical_optimizers: PhysicalOptimizer::new(), query_planner: Arc::new(DefaultQueryPlanner {}), @@ -958,23 +972,7 @@ impl SessionState { query = query.with_user_defined_planner(planner.clone()); } - // register crate of array expressions (if enabled) - 
#[cfg(feature = "array_expressions")] - { - let array_planner = - Arc::new(functions_array::planner::ArrayFunctionPlanner) as _; - - let field_access_planner = - Arc::new(functions_array::planner::FieldAccessPlanner) as _; - - query - .with_user_defined_planner(array_planner) - .with_user_defined_planner(field_access_planner) - } - #[cfg(not(feature = "array_expressions"))] - { - query - } + query } } @@ -1186,6 +1184,10 @@ impl FunctionRegistry for SessionState { Ok(()) } + fn expr_planners(&self) -> Vec> { + self.user_defined_sql_planners.clone() + } + fn register_user_defined_sql_planner( &mut self, user_defined_sql_planner: Arc, diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 2810dca46365..fb7abcd795e8 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -631,3 +631,9 @@ doc_comment::doctest!( "../../../docs/source/user-guide/expressions.md", user_guide_expressions ); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/using-the-sql-api.md", + library_user_guide_example_usage +); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index c3bc2fcca2b5..2d1904d9e166 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -29,8 +29,9 @@ use arrow::{ }, record_batch::RecordBatch, }; -use arrow_array::Float32Array; -use arrow_schema::ArrowError; +use arrow_array::{Array, Float32Array, Float64Array, UnionArray}; +use arrow_buffer::ScalarBuffer; +use arrow_schema::{ArrowError, UnionFields, UnionMode}; use datafusion_functions_aggregate::count::count_udaf; use object_store::local::LocalFileSystem; use std::fs; @@ -234,6 +235,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { Ok(()) } + #[tokio::test] async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { let ctx = create_join_context()?; @@ -2194,3 +2196,163 @@ async fn write_parquet_results() -> Result<()> { Ok(()) } + +fn union_fields() -> UnionFields { + [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + (2, Arc::new(Field::new("C", DataType::Utf8, true))), + ] + .into_iter() + .collect() +} + +#[tokio::test] +async fn sparse_union_is_null() { + // union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}] + let int_array = Int32Array::from(vec![Some(1), None, None, None, None, None]); + let float_array = Float64Array::from(vec![None, None, Some(3.2), None, None, None]); + let str_array = StringArray::from(vec![None, None, None, None, Some("a"), None]); + let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::>(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(float_array), + Arc::new(str_array), + ]; + + let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); + + let field = Field::new( + "my_union", + DataType::Union(union_fields(), UnionMode::Sparse), + true, + ); + let schema = Arc::new(Schema::new(vec![field])); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap(); + + let ctx = SessionContext::new(); + + ctx.register_batch("union_batch", batch).unwrap(); + + let df = ctx.table("union_batch").await.unwrap(); + + // view_all + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=1} |", + "| {A=} |", + "| {B=3.2} |", + "| {B=} |", + "| {C=a} |", + "| {C=} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &df.clone().collect().await.unwrap()); + + // filter where is null 
+ let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=} |", + "| {B=} |", + "| {C=} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap()); + + // filter where is not null + let result_df = df.filter(col("my_union").is_not_null()).unwrap(); + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=1} |", + "| {B=3.2} |", + "| {C=a} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap()); +} + +#[tokio::test] +async fn dense_union_is_null() { + // union of [{A=1}, null, {B=3.2}, {A=34}] + let int_array = Int32Array::from(vec![Some(1), None]); + let float_array = Float64Array::from(vec![Some(3.2), None]); + let str_array = StringArray::from(vec![Some("a"), None]); + let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::>(); + let offsets = [0, 1, 0, 1, 0, 1] + .into_iter() + .collect::>(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(float_array), + Arc::new(str_array), + ]; + + let array = + UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap(); + + let field = Field::new( + "my_union", + DataType::Union(union_fields(), UnionMode::Dense), + true, + ); + let schema = Arc::new(Schema::new(vec![field])); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap(); + + let ctx = SessionContext::new(); + + ctx.register_batch("union_batch", batch).unwrap(); + + let df = ctx.table("union_batch").await.unwrap(); + + // view_all + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=1} |", + "| {A=} |", + "| {B=3.2} |", + "| {B=} |", + "| {C=a} |", + "| {C=} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &df.clone().collect().await.unwrap()); + + // filter where is null + let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=} |", + "| {B=} |", + "| {C=} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap()); + + // filter where is not null + let result_df = df.filter(col("my_union").is_not_null()).unwrap(); + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=1} |", + "| {B=3.2} |", + "| {C=a} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap()); +} diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index f00d17a06ffc..ceae13a469f0 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -80,7 +80,7 @@ mod sp_repartition_fuzz_tests { // Define a and f are aliases eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. 
- eq_properties = eq_properties.add_constants([ConstExpr::new(col_e.clone())]); + eq_properties = eq_properties.add_constants([ConstExpr::from(col_e)]); // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index ea83c1fa788d..2b4ba0b17133 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -29,15 +29,15 @@ use arrow::datatypes::{ TimestampNanosecondType, TimestampSecondType, }; use arrow_array::{ - make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, - LargeStringArray, RecordBatch, StringArray, Time32MillisecondArray, + make_array, new_null_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, + Date64Array, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, + LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; -use arrow_schema::{DataType, Field, Schema}; +use arrow_schema::{DataType, Field, Schema, TimeUnit}; use datafusion::datasource::physical_plan::parquet::StatisticsConverter; use half::f16; use parquet::arrow::arrow_reader::{ @@ -91,51 +91,60 @@ impl Int64Case { // Create a parquet file with the specified settings pub fn build(&self) -> ParquetRecordBatchReaderBuilder { - let mut output_file = tempfile::Builder::new() - .prefix("parquert_statistics_test") - .suffix(".parquet") - .tempfile() - .expect("tempfile creation"); - - let mut builder = - WriterProperties::builder().set_max_row_group_size(self.row_per_group); - if let Some(enable_stats) = self.enable_stats { - builder = builder.set_statistics_enabled(enable_stats); - } - if let Some(data_page_row_count_limit) = self.data_page_row_count_limit { - builder = builder.set_data_page_row_count_limit(data_page_row_count_limit); - } - let props = builder.build(); - let batches = vec![self.make_int64_batches_with_null()]; + build_parquet_file( + self.row_per_group, + self.enable_stats, + self.data_page_row_count_limit, + batches, + ) + } +} + +fn build_parquet_file( + row_per_group: usize, + enable_stats: Option, + data_page_row_count_limit: Option, + batches: Vec, +) -> ParquetRecordBatchReaderBuilder { + let mut output_file = tempfile::Builder::new() + .prefix("parquert_statistics_test") + .suffix(".parquet") + .tempfile() + .expect("tempfile creation"); - let schema = batches[0].schema(); + let mut builder = WriterProperties::builder().set_max_row_group_size(row_per_group); + if let Some(enable_stats) = enable_stats { + builder = builder.set_statistics_enabled(enable_stats); + } + if let Some(data_page_row_count_limit) = data_page_row_count_limit { + builder = builder.set_data_page_row_count_limit(data_page_row_count_limit); + } + let props = builder.build(); - let mut writer = - ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); + let schema = batches[0].schema(); - // if we have a datapage limit send the batches in one at a time to give - // the writer a 
chance to be split into multiple pages - if self.data_page_row_count_limit.is_some() { - for batch in batches { - for i in 0..batch.num_rows() { - writer.write(&batch.slice(i, 1)).expect("writing batch"); - } - } - } else { - for batch in batches { - writer.write(&batch).expect("writing batch"); + let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); + + // if we have a datapage limit send the batches in one at a time to give + // the writer a chance to be split into multiple pages + if data_page_row_count_limit.is_some() { + for batch in &batches { + for i in 0..batch.num_rows() { + writer.write(&batch.slice(i, 1)).expect("writing batch"); } } + } else { + for batch in &batches { + writer.write(batch).expect("writing batch"); + } + } - // close file - let _file_meta = writer.close().unwrap(); + let _file_meta = writer.close().unwrap(); - // open the file & get the reader - let file = output_file.reopen().unwrap(); - let options = ArrowReaderOptions::new().with_page_index(true); - ArrowReaderBuilder::try_new_with_options(file, options).unwrap() - } + let file = output_file.reopen().unwrap(); + let options = ArrowReaderOptions::new().with_page_index(true); + ArrowReaderBuilder::try_new_with_options(file, options).unwrap() } /// Defines what data to create in a parquet file @@ -386,7 +395,7 @@ async fn test_one_row_group_without_null() { // 3 rows expected_row_counts: Some(UInt64Array::from(vec![3])), column_name: "i64", - check: Check::RowGroup, + check: Check::Both, } .run() } @@ -413,7 +422,7 @@ async fn test_one_row_group_with_null_and_negative() { // 8 rows expected_row_counts: Some(UInt64Array::from(vec![8])), column_name: "i64", - check: Check::RowGroup, + check: Check::Both, } .run() } @@ -440,7 +449,7 @@ async fn test_two_row_group_with_null() { // row counts are [10, 5] expected_row_counts: Some(UInt64Array::from(vec![10, 5])), column_name: "i64", - check: Check::RowGroup, + check: Check::Both, } .run() } @@ -467,7 +476,7 @@ async fn test_two_row_groups_with_all_nulls_in_one() { // row counts are [5, 3] expected_row_counts: Some(UInt64Array::from(vec![5, 3])), column_name: "i64", - check: Check::RowGroup, + check: Check::Both, } .run() } @@ -503,6 +512,71 @@ async fn test_multiple_data_pages_nulls_and_negatives() { .run() } +#[tokio::test] +async fn test_data_page_stats_with_all_null_page() { + for data_type in &[ + DataType::Boolean, + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + DataType::Float16, + DataType::Float32, + DataType::Float64, + DataType::Date32, + DataType::Date64, + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Second), + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Binary, + DataType::LargeBinary, + DataType::FixedSizeBinary(3), + DataType::Utf8, + DataType::LargeUtf8, + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + DataType::Decimal128(8, 2), // as INT32 + DataType::Decimal128(10, 2), // as INT64 + DataType::Decimal128(20, 2), // as FIXED_LEN_BYTE_ARRAY + DataType::Decimal256(8, 2), // as INT32 + DataType::Decimal256(10, 2), // as INT64 + DataType::Decimal256(20, 2), // as FIXED_LEN_BYTE_ARRAY + ] { + let batch = + 
RecordBatch::try_from_iter(vec![("col", new_null_array(data_type, 4))]) + .expect("record batch creation"); + + let reader = + build_parquet_file(4, Some(EnabledStatistics::Page), Some(4), vec![batch]); + + let expected_data_type = match data_type { + DataType::Dictionary(_, value_type) => value_type.as_ref(), + _ => data_type, + }; + + // There is one data page with 4 nulls + // The statistics should be present but null + Test { + reader: &reader, + expected_min: new_null_array(expected_data_type, 1), + expected_max: new_null_array(expected_data_type, 1), + expected_null_counts: UInt64Array::from(vec![4]), + expected_row_counts: Some(UInt64Array::from(vec![4])), + column_name: "col", + check: Check::DataPage, + } + .run() + } +} + /////////////// MORE GENERAL TESTS ////////////////////// // . Many columns in a file // . Differnet data types @@ -1408,7 +1482,7 @@ async fn test_int32_range() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: Some(UInt64Array::from(vec![4])), column_name: "i", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1431,7 +1505,7 @@ async fn test_uint32_range() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: Some(UInt64Array::from(vec![4])), column_name: "u", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1453,7 +1527,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "u8", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1464,7 +1538,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "u16", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1475,7 +1549,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "u32", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1486,7 +1560,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "u64", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1508,7 +1582,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "i8", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1519,7 +1593,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "i16", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1530,7 +1604,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "i32", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1541,7 +1615,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "i64", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1563,7 +1637,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "f32", - check: 
Check::RowGroup, + check: Check::Both, } .run(); @@ -1574,7 +1648,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "f64", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1585,7 +1659,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "f32_nan", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1596,7 +1670,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "f64_nan", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1619,7 +1693,7 @@ async fn test_float64() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "f", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1652,7 +1726,7 @@ async fn test_float16() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5, 5])), column_name: "f", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1741,7 +1815,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "string_dict_i8", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1763,7 +1837,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 2])), column_name: "int_dict_i8", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -1861,7 +1935,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "service_fixedsize", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1915,7 +1989,7 @@ async fn test_period_in_column_names() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "name", - check: Check::RowGroup, + check: Check::Both, } .run(); @@ -1929,7 +2003,7 @@ async fn test_period_in_column_names() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: Some(UInt64Array::from(vec![5, 5, 5])), column_name: "service.name", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -2041,7 +2115,7 @@ async fn test_missing_statistics() { expected_null_counts: UInt64Array::from(vec![None]), expected_row_counts: Some(UInt64Array::from(vec![3])), // still has row count statistics column_name: "i64", - check: Check::RowGroup, + check: Check::Both, } .run(); } @@ -2063,7 +2137,7 @@ async fn test_column_not_found() { expected_null_counts: UInt64Array::from(vec![2, 2]), expected_row_counts: Some(UInt64Array::from(vec![13, 7])), column_name: "not_a_column", - check: Check::RowGroup, + check: Check::Both, } .run_col_not_found(); } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 5e3c44c039ab..ae8a009c6292 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -571,6 +571,17 @@ async fn test_user_defined_functions_cast_to_i64() 
-> Result<()> { Ok(()) } +#[tokio::test] +async fn test_user_defined_sql_functions() -> Result<()> { + let ctx = SessionContext::new(); + + let sql_planners = ctx.expr_planners(); + + assert!(!sql_planners.is_empty()); + + Ok(()) +} + #[tokio::test] async fn deregister_udf() -> Result<()> { let cast2i64 = ScalarUDF::from(CastToI64UDF::new()); diff --git a/datafusion/execution/src/cache/cache_manager.rs b/datafusion/execution/src/cache/cache_manager.rs index 97529263688b..c2403e34c665 100644 --- a/datafusion/execution/src/cache/cache_manager.rs +++ b/datafusion/execution/src/cache/cache_manager.rs @@ -54,10 +54,10 @@ impl CacheManager { pub fn try_new(config: &CacheManagerConfig) -> Result> { let mut manager = CacheManager::default(); if let Some(cc) = &config.table_files_statistics_cache { - manager.file_statistic_cache = Some(cc.clone()) + manager.file_statistic_cache = Some(Arc::clone(cc)) } if let Some(lc) = &config.list_files_cache { - manager.list_files_cache = Some(lc.clone()) + manager.list_files_cache = Some(Arc::clone(lc)) } Ok(Arc::new(manager)) } diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index 25f9b9fa4d68..a9291659a3ef 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -39,7 +39,7 @@ impl CacheAccessor> for DefaultFileStatisticsCache { fn get(&self, k: &Path) -> Option> { self.statistics .get(k) - .map(|s| Some(s.value().1.clone())) + .map(|s| Some(Arc::clone(&s.value().1))) .unwrap_or(None) } @@ -55,7 +55,7 @@ impl CacheAccessor> for DefaultFileStatisticsCache { // file has changed None } else { - Some(statistics.clone()) + Some(Arc::clone(statistics)) } }) .unwrap_or(None) @@ -108,7 +108,7 @@ impl CacheAccessor>> for DefaultListFilesCache { type Extra = ObjectMeta; fn get(&self, k: &Path) -> Option>> { - self.statistics.get(k).map(|x| x.value().clone()) + self.statistics.get(k).map(|x| Arc::clone(x.value())) } fn get_with_extra( diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 85cc6f8499f0..cca25c7c3e88 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -139,7 +139,7 @@ impl DiskManager { let dir_index = thread_rng().gen_range(0..local_dirs.len()); Ok(RefCountedTempFile { - parent_temp_dir: local_dirs[dir_index].clone(), + parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Builder::new() .tempfile_in(local_dirs[dir_index].as_ref()) .map_err(DataFusionError::IoError)?, diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index 2fe0d83b1d1c..909364fa805d 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! 
DataFusion execution configuration and runtime structures diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 321a89127be7..3f66a304dc18 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -268,7 +268,7 @@ impl MemoryReservation { self.size = self.size.checked_sub(capacity).unwrap(); Self { size: capacity, - registration: self.registration.clone(), + registration: Arc::clone(&self.registration), } } @@ -276,7 +276,7 @@ impl MemoryReservation { pub fn new_empty(&self) -> Self { Self { size: 0, - registration: self.registration.clone(), + registration: Arc::clone(&self.registration), } } diff --git a/datafusion/execution/src/object_store.rs b/datafusion/execution/src/object_store.rs index 3ba21e399f93..9e1d94b346eb 100644 --- a/datafusion/execution/src/object_store.rs +++ b/datafusion/execution/src/object_store.rs @@ -234,7 +234,7 @@ impl ObjectStoreRegistry for DefaultObjectStoreRegistry { let s = get_url_key(url); self.object_stores .get(&s) - .map(|o| o.value().clone()) + .map(|o| Arc::clone(o.value())) .ok_or_else(|| { DataFusionError::Internal(format!( "No suitable object store found for {url}. See `RuntimeEnv::register_object_store`" diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index b3a510ef2a3f..24d61e6a8b72 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -20,15 +20,15 @@ use std::{ sync::Arc, }; -use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; - use crate::{ config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, runtime_env::{RuntimeConfig, RuntimeEnv}, }; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::planner::UserDefinedSQLPlanner; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; /// Task Execution Context /// @@ -121,7 +121,7 @@ impl TaskContext { /// Return the [RuntimeEnv] associated with this [TaskContext] pub fn runtime_env(&self) -> Arc<RuntimeEnv> { - self.runtime.clone() + Arc::clone(&self.runtime) } /// Update the [`SessionConfig`] @@ -172,22 +172,29 @@ impl FunctionRegistry for TaskContext { udaf: Arc<AggregateUDF>, ) -> Result<Option<Arc<AggregateUDF>>> { udaf.aliases().iter().for_each(|alias| { - self.aggregate_functions.insert(alias.clone(), udaf.clone()); + self.aggregate_functions + .insert(alias.clone(), Arc::clone(&udaf)); }); Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) } fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> { udwf.aliases().iter().for_each(|alias| { - self.window_functions.insert(alias.clone(), udwf.clone()); + self.window_functions + .insert(alias.clone(), Arc::clone(&udwf)); }); Ok(self.window_functions.insert(udwf.name().into(), udwf)) } fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> { udf.aliases().iter().for_each(|alias| { - self.scalar_functions.insert(alias.clone(), udf.clone()); + self.scalar_functions + .insert(alias.clone(), Arc::clone(&udf)); }); Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + + fn expr_planners(&self) -> Vec<Arc<dyn UserDefinedSQLPlanner>> { + vec![] + } } #[cfg(test)] diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 760952d94815..23e98714dfa4 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -39,8 +39,6 @@ pub enum AggregateFunction { Max, /// Aggregation into an array ArrayAgg, - /// N'th value in a group
according to some ordering - NthValue, } impl AggregateFunction { @@ -50,7 +48,6 @@ impl AggregateFunction { Min => "MIN", Max => "MAX", ArrayAgg => "ARRAY_AGG", - NthValue => "NTH_VALUE", } } } @@ -69,7 +66,6 @@ impl FromStr for AggregateFunction { "max" => AggregateFunction::Max, "min" => AggregateFunction::Min, "array_agg" => AggregateFunction::ArrayAgg, - "nth_value" => AggregateFunction::NthValue, _ => { return plan_err!("There is no built-in function named {name}"); } @@ -114,7 +110,6 @@ impl AggregateFunction { coerced_data_types[0].clone(), input_expr_nullable[0], )))), - AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), } } @@ -124,7 +119,6 @@ impl AggregateFunction { match self { AggregateFunction::Max | AggregateFunction::Min => Ok(true), AggregateFunction::ArrayAgg => Ok(false), - AggregateFunction::NthValue => Ok(true), } } } @@ -147,7 +141,6 @@ impl AggregateFunction { .collect::>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 579f5fed578f..ecece6dbfce7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1413,12 +1413,19 @@ impl Expr { .unwrap() } + /// Returns true if the expression node is volatile, i.e. whether it can return + /// different results when evaluated multiple times with the same input. + /// Note: unlike [`Self::is_volatile`], this function does not consider inputs: + /// - `rand()` returns `true`, + /// - `a + rand()` returns `false` + pub fn is_volatile_node(&self) -> bool { + matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile) + } + /// Returns true if the expression is volatile, i.e. whether it can return different /// results when evaluated multiple times with the same input. pub fn is_volatile(&self) -> Result { - self.exists(|expr| { - Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile )) - }) + self.exists(|expr| Ok(expr.is_volatile_node())) } /// Recursively find all [`Expr::Placeholder`] expressions, and diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 024e4a0ceae5..91bec501f4a0 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -215,7 +215,7 @@ pub fn coerce_plan_expr_for_schema( LogicalPlan::Projection(Projection { expr, input, .. 
}) => { let new_exprs = coerce_exprs_for_schema(expr.clone(), input.schema(), schema)?; - let projection = Projection::try_new(new_exprs, input.clone())?; + let projection = Projection::try_new(new_exprs, Arc::clone(input))?; Ok(LogicalPlan::Projection(projection)) } _ => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 8bb655eda575..1df5d6c4d736 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -118,18 +118,18 @@ impl ExprSchemable for Expr { Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list - match arg_data_type{ - DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) =>{ - Ok(field.data_type().clone()) - } - DataType::Struct(_) => { - Ok(arg_data_type) - } + match arg_data_type { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()), + DataType::Struct(_) => Ok(arg_data_type), DataType::Null => { not_impl_err!("unnest() does not support null yet") } _ => { - plan_err!("unnest() can only be applied to array, struct and null") + plan_err!( + "unnest() can only be applied to array, struct and null" + ) } } } @@ -138,22 +138,22 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| { - plan_datafusion_err!( - "{} {}", - err, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_data_types, - ) - ) - })?; - - // perform additional function arguments validation (due to limited - // expressiveness of `TypeSignature`), then infer return type - Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?) + // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| { + plan_datafusion_err!( + "{} {}", + err, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_data_types, + ) + ) + })?; + + // perform additional function arguments validation (due to limited + // expressiveness of `TypeSignature`), then infer return type + Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args @@ -166,7 +166,8 @@ impl ExprSchemable for Expr { .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udf) => { - let new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| { + let new_types = data_types_with_aggregate_udf(&data_types, udf) + .map_err(|err| { plan_datafusion_err!( "{} {}", err, @@ -179,9 +180,7 @@ impl ExprSchemable for Expr { })?; Ok(fun.return_type(&new_types, &nullability)?) } - _ => { - fun.return_type(&data_types, &nullability) - } + _ => fun.return_type(&data_types, &nullability), } } Expr::AggregateFunction(AggregateFunction { func_def, args, .. 
}) => { @@ -198,7 +197,8 @@ impl ExprSchemable for Expr { fun.return_type(&data_types, &nullability) } AggregateFunctionDefinition::UDF(fun) => { - let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| { + let new_types = data_types_with_aggregate_udf(&data_types, fun) + .map_err(|err| { plan_datafusion_err!( "{} {}", err, @@ -237,7 +237,11 @@ impl ExprSchemable for Expr { Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), Expr::Placeholder(Placeholder { data_type, .. }) => { data_type.clone().ok_or_else(|| { - plan_datafusion_err!("Placeholder type could not be resolved. Make sure that the placeholder is bound to a concrete type, e.g. by providing parameter values.") + plan_datafusion_err!( + "Placeholder type could not be resolved. Make sure that the \ + placeholder is bound to a concrete type, e.g. by providing \ + parameter values." + ) }) } Expr::Wildcard { qualifier } => { @@ -326,6 +330,9 @@ impl ExprSchemable for Expr { match func_def { AggregateFunctionDefinition::BuiltIn(fun) => fun.nullable(), // TODO: UDF should be able to customize nullability + AggregateFunctionDefinition::UDF(udf) if udf.name() == "count" => { + Ok(false) + } AggregateFunctionDefinition::UDF(_) => Ok(true), } } @@ -520,7 +527,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 5f1d3c9d5c6b..e1943c890e7c 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! [DataFusion](https://github.com/apache/datafusion) //! 
is an extensible query execution framework that uses diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index cc4348d58c33..4ad3bd5018a4 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1223,17 +1223,17 @@ pub fn build_join_schema( JoinType::Inner => { // left then right let left_fields = left_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); let right_fields = right_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); left_fields.into_iter().chain(right_fields).collect() } JoinType::Left => { // left then right, right set to nullable in case of not matched scenario let left_fields = left_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); left_fields .into_iter() @@ -1243,7 +1243,7 @@ pub fn build_join_schema( JoinType::Right => { // left then right, left set to nullable in case of not matched scenario let right_fields = right_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); nullify_fields(left_fields) .into_iter() @@ -1259,11 +1259,15 @@ pub fn build_join_schema( } JoinType::LeftSemi | JoinType::LeftAnti => { // Only use the left side for the schema - left_fields.map(|(q, f)| (q.cloned(), f.clone())).collect() + left_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .collect() } JoinType::RightSemi | JoinType::RightAnti => { // Only use the right side for the schema - right_fields.map(|(q, f)| (q.cloned(), f.clone())).collect() + right_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .collect() } }; let func_dependencies = left.functional_dependencies().join( @@ -1577,7 +1581,7 @@ impl TableSource for LogicalTableSource { } fn schema(&self) -> SchemaRef { - self.table_schema.clone() + Arc::clone(&self.table_schema) } fn supports_filters_pushdown( @@ -1691,7 +1695,10 @@ pub fn unnest_with_options( } None => { dependency_indices.push(index); - Ok(vec![(original_qualifier.cloned(), original_field.clone())]) + Ok(vec![( + original_qualifier.cloned(), + Arc::clone(original_field), + )]) } } }) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 2921541934f8..8fd5982a0f2e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -762,9 +762,9 @@ impl LogicalPlan { // If inputs are not pruned do not change schema // TODO this seems wrong (shouldn't we always use the schema of the input?) let schema = if schema.fields().len() == input_schema.fields().len() { - schema.clone() + Arc::clone(&schema) } else { - input_schema.clone() + Arc::clone(input_schema) }; Ok(LogicalPlan::Union(Union { inputs, schema })) } @@ -850,7 +850,7 @@ impl LogicalPlan { .. }) => Ok(LogicalPlan::Dml(DmlStatement::new( table_name.clone(), - table_schema.clone(), + Arc::clone(table_schema), op.clone(), Arc::new(inputs.swap_remove(0)), ))), @@ -863,13 +863,13 @@ impl LogicalPlan { }) => Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(inputs.swap_remove(0)), output_url: output_url.clone(), - file_type: file_type.clone(), + file_type: Arc::clone(file_type), options: options.clone(), partition_by: partition_by.clone(), })), LogicalPlan::Values(Values { schema, .. 
}) => { Ok(LogicalPlan::Values(Values { - schema: schema.clone(), + schema: Arc::clone(schema), values: expr .chunks_exact(schema.fields().len()) .map(|s| s.to_vec()) @@ -1027,9 +1027,9 @@ impl LogicalPlan { let input_schema = inputs[0].schema(); // If inputs are not pruned do not change schema. let schema = if schema.fields().len() == input_schema.fields().len() { - schema.clone() + Arc::clone(schema) } else { - input_schema.clone() + Arc::clone(input_schema) }; Ok(LogicalPlan::Union(Union { inputs: inputs.into_iter().map(Arc::new).collect(), @@ -1073,7 +1073,7 @@ impl LogicalPlan { assert_eq!(inputs.len(), 1); Ok(LogicalPlan::Analyze(Analyze { verbose: a.verbose, - schema: a.schema.clone(), + schema: Arc::clone(&a.schema), input: Arc::new(inputs.swap_remove(0)), })) } @@ -1087,7 +1087,7 @@ impl LogicalPlan { verbose: e.verbose, plan: Arc::new(inputs.swap_remove(0)), stringified_plans: e.stringified_plans.clone(), - schema: e.schema.clone(), + schema: Arc::clone(&e.schema), logical_optimization_succeeded: e.logical_optimization_succeeded, })) } @@ -1369,7 +1369,7 @@ impl LogicalPlan { param_values: &ParamValues, ) -> Result { self.transform_up_with_subqueries(|plan| { - let schema = plan.schema().clone(); + let schema = Arc::clone(plan.schema()); plan.map_expressions(|e| { e.infer_placeholder_types(&schema)?.transform_up(|e| { if let Expr::Placeholder(Placeholder { id, .. }) = e { @@ -2227,7 +2227,7 @@ impl Window { let fields: Vec<(Option, Arc)> = input .schema() .iter() - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect(); let input_len = fields.len(); let mut window_fields = fields; @@ -3352,7 +3352,7 @@ digraph { vec![col("a")], Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: empty_schema.clone(), + schema: Arc::clone(&empty_schema), })), empty_schema, ); @@ -3467,9 +3467,9 @@ digraph { ); let scan = Arc::new(LogicalPlan::TableScan(TableScan { table_name: TableReference::bare("tab"), - source: source.clone(), + source: Arc::clone(&source) as Arc, projection: None, - projected_schema: schema.clone(), + projected_schema: Arc::clone(&schema), filters: vec![], fetch: None, })); @@ -3499,7 +3499,7 @@ digraph { table_name: TableReference::bare("tab"), source, projection: None, - projected_schema: unique_schema.clone(), + projected_schema: Arc::clone(&unique_schema), filters: vec![], fetch: None, })); diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index c928ab39194d..c255edbea5ae 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -84,7 +84,8 @@ pub trait ContextProvider { /// This trait allows users to customize the behavior of the SQL planner pub trait UserDefinedSQLPlanner: Send + Sync { - /// Plan the binary operation between two expressions, returns OriginalBinaryExpr if not possible + /// Plan the binary operation between two expressions, returns original + /// BinaryExpr if not possible fn plan_binary_op( &self, expr: RawBinaryExpr, @@ -93,7 +94,9 @@ pub trait UserDefinedSQLPlanner: Send + Sync { Ok(PlannerResult::Original(expr)) } - /// Plan the field access expression, returns OriginalFieldAccessExpr if not possible + /// Plan the field access expression + /// + /// returns original FieldAccessExpr if not possible fn plan_field_access( &self, expr: RawFieldAccessExpr, @@ -102,7 +105,9 @@ pub trait UserDefinedSQLPlanner: Send + Sync { Ok(PlannerResult::Original(expr)) } - // Plan the array literal, returns OriginalArray if not possible + /// 
Plan the array literal, returns OriginalArray if not possible + /// + /// Returns origin expression arguments if not possible fn plan_array_literal( &self, exprs: Vec, @@ -110,6 +115,52 @@ pub trait UserDefinedSQLPlanner: Send + Sync { ) -> Result>> { Ok(PlannerResult::Original(exprs)) } + + // Plan the POSITION expression, e.g., POSITION( in ) + // returns origin expression arguments if not possible + fn plan_position(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plan the dictionary literal `{ key: value, ...}` + /// + /// Returns origin expression arguments if not possible + fn plan_dictionary_literal( + &self, + expr: RawDictionaryExpr, + _schema: &DFSchema, + ) -> Result> { + Ok(PlannerResult::Original(expr)) + } + + /// Plan an extract expression, e.g., `EXTRACT(month FROM foo)` + /// + /// Returns origin expression arguments if not possible + fn plan_extract(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plan an substring expression, e.g., `SUBSTRING( [FROM ] [FOR ])` + /// + /// Returns origin expression arguments if not possible + fn plan_substring(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plans a struct `struct(expression1[, ..., expression_n])` + /// literal based on the given input expressions. + /// This function takes a vector of expressions and a boolean flag indicating whether + /// the struct uses the optional name + /// + /// Returns a `PlannerResult` containing either the planned struct expressions or the original + /// input expressions if planning is not possible. + fn plan_struct_literal( + &self, + args: Vec, + _is_named_struct: bool, + ) -> Result>> { + Ok(PlannerResult::Original(args)) + } } /// An operator with two arguments to plan @@ -136,6 +187,16 @@ pub struct RawFieldAccessExpr { pub expr: Expr, } +/// A Dictionary literal expression `{ key: value, ...}` +/// +/// This structure is used by [`UserDefinedSQLPlanner`] to plan operators with +/// custom expressions. +#[derive(Debug, Clone)] +pub struct RawDictionaryExpr { + pub keys: Vec, + pub values: Vec, +} + /// Result of planning a raw expr with [`UserDefinedSQLPlanner`] #[derive(Debug, Clone)] pub enum PlannerResult { diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index c276fe30f897..6a27c05bb451 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -110,6 +110,9 @@ pub trait FunctionRegistry { not_impl_err!("Registering FunctionRewrite") } + /// Set of all registered [`UserDefinedSQLPlanner`]s + fn expr_planners(&self) -> Vec>; + /// Registers a new [`UserDefinedSQLPlanner`] with the registry. 
fn register_user_defined_sql_planner( &mut self, @@ -192,4 +195,8 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn register_udwf(&mut self, udaf: Arc) -> Result>> { Ok(self.udwfs.insert(udaf.name().into(), udaf)) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 0f7464b96b3e..fbec6e2f8024 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -101,7 +101,6 @@ pub fn coerce_types( // unpack the dictionary to get the value get_min_max_result_type(input_types) } - AggregateFunction::NthValue => Ok(input_types.to_vec()), } } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 5645a2a4dede..83a7da046844 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -1048,16 +1048,16 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(lhs_tz.clone()), - (lhs, rhs) if lhs == rhs => Some(lhs_tz.clone()), + ("UTC", "+00:00") | ("+00:00", "UTC") => Some(Arc::clone(lhs_tz)), + (lhs, rhs) if lhs == rhs => Some(Arc::clone(lhs_tz)), // can't cast across timezones _ => { return None; } } } - (Some(lhs_tz), None) => Some(lhs_tz.clone()), - (None, Some(rhs_tz)) => Some(rhs_tz.clone()), + (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)), + (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)), (None, None) => None, }; @@ -1076,7 +1076,7 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Microsecond, (l, r) => { assert_eq!(l, r); - l.clone() + *l } }; diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 5f060a4a4f16..b430b343e484 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -598,7 +598,7 @@ fn coerced_from<'a>( Arc::new(f_into.as_ref().clone().with_data_type(data_type)); Some(FixedSizeList(new_field, *size_from)) } - Some(_) => Some(FixedSizeList(f_into.clone(), *size_from)), + Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)), _ => None, } } @@ -607,11 +607,11 @@ fn coerced_from<'a>( (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { - Some(Timestamp(unit.clone(), Some(from_tz.clone()))) + Some(Timestamp(*unit, Some(Arc::clone(from_tz)))) } Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => { // In the absence of any other information assume the time zone is "+00" (UTC). 
- Some(Timestamp(unit.clone(), Some("+00".into()))) + Some(Timestamp(*unit, Some("+00".into()))) } _ => None, } @@ -715,12 +715,12 @@ mod tests { fn test_fixed_list_wildcard_coerce() -> Result<()> { let inner = Arc::new(Field::new("item", DataType::Int32, false)); let current_types = vec![ - DataType::FixedSizeList(inner.clone(), 2), // able to coerce for any size + DataType::FixedSizeList(Arc::clone(&inner), 2), // able to coerce for any size ]; let signature = Signature::exact( vec![DataType::FixedSizeList( - inner.clone(), + Arc::clone(&inner), FIXED_SIZE_LIST_WILDCARD, )], Volatility::Stable, @@ -731,7 +731,7 @@ mod tests { // make sure it can't coerce to a different size let signature = Signature::exact( - vec![DataType::FixedSizeList(inner.clone(), 3)], + vec![DataType::FixedSizeList(Arc::clone(&inner), 3)], Volatility::Stable, ); let coerced_data_types = data_types(¤t_types, &signature); @@ -739,7 +739,7 @@ mod tests { // make sure it works with the same type. let signature = Signature::exact( - vec![DataType::FixedSizeList(inner.clone(), 2)], + vec![DataType::FixedSizeList(Arc::clone(&inner), 2)], Volatility::Stable, ); let coerced_data_types = data_types(¤t_types, &signature).unwrap(); diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index c8362691452b..7a054abea75b 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -106,8 +106,8 @@ impl AggregateUDF { Self::new_from_impl(AggregateUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), - accumulator: accumulator.clone(), + return_type: Arc::clone(return_type), + accumulator: Arc::clone(accumulator), }) } @@ -133,7 +133,10 @@ impl AggregateUDF { /// /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedAggregateUDFImpl::new( + Arc::clone(&self.inner), + aliases, + )) } /// creates an [`Expr`] that calls the aggregate function. diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 03650b1d4748..68d3af6ace3c 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -87,8 +87,8 @@ impl ScalarUDF { Self::new_from_impl(ScalarUdfLegacyWrapper { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), - fun: fun.clone(), + return_type: Arc::clone(return_type), + fun: Arc::clone(fun), }) } @@ -114,7 +114,7 @@ impl ScalarUDF { /// /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. 
pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedScalarUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified @@ -199,7 +199,7 @@ impl ScalarUDF { /// Returns a `ScalarFunctionImplementation` that can invoke the function /// during execution pub fn fun(&self) -> ScalarFunctionImplementation { - let captured = self.inner.clone(); + let captured = Arc::clone(&self.inner); Arc::new(move |args| captured.invoke(args)) } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index a17bb0ade8e3..70b44e5e307a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -90,8 +90,8 @@ impl WindowUDF { Self::new_from_impl(WindowUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), - return_type: return_type.clone(), - partition_evaluator_factory: partition_evaluator_factory.clone(), + return_type: Arc::clone(return_type), + partition_evaluator_factory: Arc::clone(partition_evaluator_factory), }) } @@ -117,7 +117,7 @@ impl WindowUDF { /// /// If you implement [`WindowUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedWindowUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedWindowUDFImpl::new(Arc::clone(&self.inner), aliases)) } /// creates a [`Expr`] that calls the window function given diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 286f05309ea7..34e007207427 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -864,6 +864,7 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::List(_) => true, DataType::LargeList(_) => true, DataType::FixedSizeList(_, _) => true, + DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), _ => false, } } @@ -1199,7 +1200,7 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { /// merge inputs schema into a single schema. pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { if inputs.len() == 1 { - inputs[0].schema().clone().as_ref().clone() + inputs[0].schema().as_ref().clone() } else { inputs.iter().map(|input| input.schema()).fold( DFSchema::empty(), diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 6f1e97a16380..7c6aef9944f6 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -211,7 +211,7 @@ where impl Accumulator for NumericHLLAccumulator where - T: ArrowPrimitiveType + std::fmt::Debug, + T: ArrowPrimitiveType + Debug, T::Native: Hash, { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 5ae5684d9cab..bbe7d21e2486 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -191,7 +191,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { } #[allow(rustdoc::private_intra_doc_links)] - /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised + /// See [`TDigest::to_scalar_state()`] for a description of the serialised /// state. 
fn state_fields( &self, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 1dc1f10afce6..18642fb84329 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -489,7 +489,7 @@ where .into_iter() .zip(counts.into_iter()) .map(|(sum, count)| (self.avg_fn)(sum, count)) - .collect::>>()?; + .collect::>>()?; PrimitiveArray::new(averages.into(), Some(nulls)) // no copy .with_data_type(self.return_data_type.clone()) }; diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index ba9964270443..9224b06e407a 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -245,7 +245,7 @@ struct BitAndAccumulator { } impl std::fmt::Debug for BitAndAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "BitAndAccumulator({})", T::DATA_TYPE) } } @@ -290,7 +290,7 @@ struct BitOrAccumulator { } impl std::fmt::Debug for BitOrAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "BitOrAccumulator({})", T::DATA_TYPE) } } @@ -335,7 +335,7 @@ struct BitXorAccumulator { } impl std::fmt::Debug for BitXorAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "BitXorAccumulator({})", T::DATA_TYPE) } } @@ -380,7 +380,7 @@ struct DistinctBitXorAccumulator { } impl std::fmt::Debug for DistinctBitXorAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE) } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 0fc8e32d7240..bd0155df0271 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -62,17 +62,15 @@ make_udaf_expr_and_func!( count_udaf ); -pub fn count_distinct(expr: Expr) -> datafusion_expr::Expr { - datafusion_expr::Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new_udf( - count_udaf(), - vec![expr], - true, - None, - None, - None, - ), - ) +pub fn count_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + count_udaf(), + vec![expr], + true, + None, + None, + None, + )) } pub struct Count { diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index fc485a284ab4..6ae2dfb3697c 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -74,6 +74,7 @@ pub mod average; pub mod bit_and_or_xor; pub mod bool_and_or; pub mod grouping; +pub mod nth_value; pub mod string_agg; use crate::approx_percentile_cont::approx_percentile_cont_udaf; @@ -105,6 +106,7 @@ pub mod expr_fn { pub use super::first_last::last_value; pub use super::grouping::grouping; pub use super::median::median; + pub use super::nth_value::nth_value; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; @@ -157,6 +159,7 @@ pub fn all_default_aggregate_functions() -> Vec> { bool_and_or::bool_or_udaf(), average::avg_udaf(), 
grouping::grouping_udaf(), + nth_value::nth_value_udaf(), ] } diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs similarity index 77% rename from datafusion/physical-expr/src/aggregate/nth_value.rs rename to datafusion/functions-aggregate/src/nth_value.rs index f6d25348f222..6719c673c55b 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -22,149 +22,149 @@ use std::any::Any; use std::collections::VecDeque; use std::sync::Arc; -use crate::aggregate::array_agg_ordered::merge_ordered_arrays; -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use crate::expressions::{format_state_name, Literal}; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; - -use arrow_array::cast::AsArray; -use arrow_array::{new_empty_array, ArrayRef, StructArray}; +use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; use arrow_schema::{DataType, Field, Fields}; + use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; -use datafusion_expr::utils::AggregateOrderSensitivity; -use datafusion_expr::Accumulator; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, Expr, ReversedUDAF, Signature, + Volatility, +}; +use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; +use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +use datafusion_physical_expr_common::sort_expr::{ + limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, +}; + +make_udaf_expr_and_func!( + NthValueAgg, + nth_value, + "Returns the nth value in a group of values.", + nth_value_udaf +); /// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi /// partition setting, partial aggregations are computed for every partition, /// and then their results are merged. #[derive(Debug)] pub struct NthValueAgg { - /// Column name - name: String, - /// The `DataType` for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// The `N` value. - n: i64, - /// If the input expression can have `NULL`s - nullable: bool, - /// Ordering data types - order_by_data_types: Vec, - /// Ordering requirement - ordering_req: LexOrdering, + signature: Signature, + /// Determines whether `N` is relative to the beginning or the end + /// of the aggregation. When set to `true`, then `N` is from the end. 
+ reversed: bool, } impl NthValueAgg { /// Create a new `NthValueAgg` aggregate function - pub fn new( - expr: Arc, - n: i64, - name: impl Into, - input_data_type: DataType, - nullable: bool, - order_by_data_types: Vec, - ordering_req: LexOrdering, - ) -> Self { + pub fn new() -> Self { Self { - name: name.into(), - input_data_type, - expr, - n, - nullable, - order_by_data_types, - ordering_req, + signature: Signature::any(2, Volatility::Immutable), + reversed: false, } } + + pub fn with_reversed(mut self, reversed: bool) -> Self { + self.reversed = reversed; + self + } } -impl AggregateExpr for NthValueAgg { +impl Default for NthValueAgg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for NthValueAgg { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) + fn name(&self) -> &str { + "nth_value" + } + + fn signature(&self) -> &Signature { + &self.signature } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(NthValueAccumulator::try_new( - self.n, - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - )?)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) } - fn state_fields(&self) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let n = match acc_args.input_exprs[1] { + Expr::Literal(ScalarValue::Int64(Some(value))) => { + if self.reversed { + Ok(-value) + } else { + Ok(value) + } + } + _ => not_impl_err!( + "{} not supported for n: {}", + self.name(), + &acc_args.input_exprs[1] + ), + }?; + + let ordering_req = limited_convert_logical_sort_exprs_to_physical( + acc_args.sort_exprs, + acc_args.schema, + )?; + + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + NthValueAccumulator::try_new( + n, + acc_args.input_type, + &ordering_dtypes, + ordering_req, + ) + .map(|acc| Box::new(acc) as _) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new_list( - format_state_name(&self.name, "nth_value"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, // This should be the same as field() + format_state_name(self.name(), "nth_value"), + // TODO: The nullability of the list element should be configurable. + // The hard-coded `true` should be changed once the field for + // nullability is added to `StateFieldArgs` struct. 
+ // See: https://github.com/apache/datafusion/pull/11063 + Field::new("item", args.input_type.clone(), true), + false, )]; - if !self.ordering_req.is_empty() { - let orderings = - ordering_fields(&self.ordering_req, &self.order_by_data_types); + let orderings = args.ordering_fields.to_vec(); + if !orderings.is_empty() { fields.push(Field::new_list( - format_state_name(&self.name, "nth_value_orderings"), + format_state_name(self.name(), "nth_value_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), - self.nullable, + false, )); } Ok(fields) } - fn expressions(&self) -> Vec> { - let n = Arc::new(Literal::new(ScalarValue::Int64(Some(self.n)))) as _; - vec![self.expr.clone(), n] + fn aliases(&self) -> &[String] { + &[] } - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - - fn order_sensitivity(&self) -> AggregateOrderSensitivity { - AggregateOrderSensitivity::HardRequirement - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.to_string(), - input_data_type: self.input_data_type.clone(), - expr: self.expr.clone(), - // index should be from the opposite side - n: -self.n, - nullable: self.nullable, - order_by_data_types: self.order_by_data_types.clone(), - // reverse requirement - ordering_req: reverse_order_bys(&self.ordering_req), - }) as _) - } -} - -impl PartialEq for NthValueAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Reversed(Arc::from(AggregateUDF::from( + Self::new().with_reversed(!self.reversed), + ))) } } #[derive(Debug)] -pub(crate) struct NthValueAccumulator { +pub struct NthValueAccumulator { + /// The `N` value. n: i64, /// Stores entries in the `NTH_VALUE` result. 
values: VecDeque, diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index eb1ef9e03f31..73c5b9114a2c 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -49,6 +49,7 @@ datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } +datafusion-functions-aggregate = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" diff --git a/datafusion/functions-array/src/planner.rs b/datafusion/functions-array/src/planner.rs index cfb3e5ed0729..01853fb56908 100644 --- a/datafusion/functions-array/src/planner.rs +++ b/datafusion/functions-array/src/planner.rs @@ -23,6 +23,7 @@ use datafusion_expr::{ sqlparser, AggregateFunction, Expr, ExprSchemable, GetFieldAccess, }; use datafusion_functions::expr_fn::get_field; +use datafusion_functions_aggregate::nth_value::nth_value_udaf; use crate::{ array_has::array_has_all, @@ -119,8 +120,8 @@ impl UserDefinedSQLPlanner for FieldAccessPlanner { // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { Ok(PlannerResult::Planned(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new( - AggregateFunction::NthValue, + datafusion_expr::expr::AggregateFunction::new_udf( + nth_value_udaf(), agg_func .args .into_iter() diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index a2742220f3e9..062a4a104d54 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -29,6 +29,7 @@ pub mod named_struct; pub mod nullif; pub mod nvl; pub mod nvl2; +pub mod planner; pub mod r#struct; // create UDFs @@ -92,7 +93,6 @@ pub fn functions() -> Vec> { nvl(), nvl2(), arrow_typeof(), - r#struct(), named_struct(), get_field(), coalesce(), diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs new file mode 100644 index 000000000000..e803c92dd0b3 --- /dev/null +++ b/datafusion/functions/src/core/planner.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_common::DFSchema; +use datafusion_common::Result; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::{PlannerResult, RawDictionaryExpr, UserDefinedSQLPlanner}; +use datafusion_expr::Expr; + +use super::named_struct; + +#[derive(Default)] +pub struct CoreFunctionPlanner {} + +impl UserDefinedSQLPlanner for CoreFunctionPlanner { + fn plan_dictionary_literal( + &self, + expr: RawDictionaryExpr, + _schema: &DFSchema, + ) -> Result> { + let mut args = vec![]; + for (k, v) in expr.keys.into_iter().zip(expr.values.into_iter()) { + args.push(k); + args.push(v); + } + Ok(PlannerResult::Planned(named_struct().call(args))) + } + + fn plan_struct_literal( + &self, + args: Vec, + is_named_struct: bool, + ) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + if is_named_struct { + crate::core::named_struct() + } else { + crate::core::r#struct() + }, + args, + ), + ))) + } +} diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index e777e5ea95d0..997f1a36ad04 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -57,35 +57,35 @@ impl DateBinFunc { vec![ Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), Timestamp(Nanosecond, None), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), Timestamp(Nanosecond, None), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), ]), Exact(vec![ DataType::Interval(DayTime), diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 4bc24931d06b..433a4f90d95b 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -130,6 +130,9 @@ make_stub_package!(crypto, "crypto_expressions"); pub mod unicode; make_stub_package!(unicode, "unicode_expressions"); +#[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))] +pub mod planner; + mod utils; /// Fluent-style API for creating `Expr`s diff --git a/datafusion/functions/src/planner.rs b/datafusion/functions/src/planner.rs new file mode 100644 index 000000000000..41ff92f26111 --- /dev/null +++ b/datafusion/functions/src/planner.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL planning extensions like [`UserDefinedFunctionPlanner`] + +use datafusion_common::Result; +use datafusion_expr::{ + expr::ScalarFunction, + planner::{PlannerResult, UserDefinedSQLPlanner}, + Expr, +}; + +#[derive(Default)] +pub struct UserDefinedFunctionPlanner; + +impl UserDefinedSQLPlanner for UserDefinedFunctionPlanner { + #[cfg(feature = "datetime_expressions")] + fn plan_extract(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::datetime::date_part(), args), + ))) + } + + #[cfg(feature = "unicode_expressions")] + fn plan_position(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::unicode::strpos(), args), + ))) + } + + #[cfg(feature = "unicode_expressions")] + fn plan_substring(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::unicode::substr(), args), + ))) + } +} diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index c297182057fe..9d15920bb655 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -32,6 +32,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub struct SubstrFunc { signature: Signature, + aliases: Vec, } impl Default for SubstrFunc { @@ -53,6 +54,7 @@ impl SubstrFunc { ], Volatility::Immutable, ), + aliases: vec![String::from("substring")], } } } @@ -81,6 +83,10 @@ impl ScalarUDFImpl for SubstrFunc { other => exec_err!("Unsupported data type {other:?} for function substr"), } } + + fn aliases(&self) -> &[String] { + &self.aliases + } } /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index 2f1f85e3a57a..5aacfaf59cb1 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -67,8 +67,10 @@ let optimizer = Optimizer::with_rules(vec![ ## Writing Optimization Rules -Please refer to the [rewrite_expr example](../../datafusion-examples/examples/rewrite_expr.rs) to learn more about -the general approach to writing optimizer rules and then move onto studying the existing rules. +Please refer to the +[optimizer_rule.rs](../../datafusion-examples/examples/optimizer_rule.rs) +example to learn more about the general approach to writing optimizer rules and +then move onto studying the existing rules. All rules must implement the `OptimizerRule` trait. diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 34f9802b1fd9..959ffdaaa212 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -127,9 +127,9 @@ mod tests { .project(vec![count(wildcard())])? .sort(vec![count(wildcard()).sort(true, false)])? 
.build()?; - let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64;N]\ - \n Projection: count(*) [count(*):Int64;N]\ - \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64;N]\ + let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) } @@ -152,9 +152,9 @@ mod tests { .build()?; let expected = "Filter: t1.a IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [count(*):Int64;N]\ - \n Projection: count(*) [count(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ + \n Subquery: [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) @@ -175,9 +175,9 @@ mod tests { .build()?; let expected = "Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [count(*):Int64;N]\ - \n Projection: count(*) [count(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ + \n Subquery: [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) @@ -207,9 +207,9 @@ mod tests { let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\ \n Filter: () > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [count(Int64(1)):Int64;N]\ - \n Projection: count(Int64(1)) [count(Int64(1)):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64;N]\ + \n Subquery: [count(Int64(1)):Int64]\ + \n Projection: count(Int64(1)) [count(Int64(1)):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64]\ \n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; @@ -235,7 +235,7 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64;N]\ + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ \n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) @@ -249,8 +249,8 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: count(*) [count(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64;N]\ + let expected = "Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) } @@ -272,7 +272,7 @@ mod tests { .project(vec![count(wildcard())])? 
.build()?; - let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64;N]\ + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ \n Aggregate: groupBy=[[]], aggr=[[MAX(count(Int64(1))) AS MAX(count(*))]] [MAX(count(*)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(plan, expected) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 5eec5ba09d7e..4a4933fe9cfd 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -56,9 +56,13 @@ struct Identifier<'n> { } impl<'n> Identifier<'n> { - fn new(expr: &'n Expr, random_state: &RandomState) -> Self { + fn new(expr: &'n Expr, is_tree: bool, random_state: &RandomState) -> Self { let mut hasher = random_state.build_hasher(); - expr.hash_node(&mut hasher); + if is_tree { + expr.hash(&mut hasher); + } else { + expr.hash_node(&mut hasher); + } let hash = hasher.finish(); Self { hash, expr } } @@ -187,24 +191,19 @@ impl CommonSubexprEliminate { id_array: &mut IdArray<'n>, expr_mask: ExprMask, ) -> Result { - // Don't consider volatile expressions for CSE. - Ok(if expr.is_volatile()? { - false - } else { - let mut visitor = ExprIdentifierVisitor { - expr_stats, - id_array, - visit_stack: vec![], - down_index: 0, - up_index: 0, - expr_mask, - random_state: &self.random_state, - found_common: false, - }; - expr.visit(&mut visitor)?; + let mut visitor = ExprIdentifierVisitor { + expr_stats, + id_array, + visit_stack: vec![], + down_index: 0, + up_index: 0, + expr_mask, + random_state: &self.random_state, + found_common: false, + }; + expr.visit(&mut visitor)?; - visitor.found_common - }) + Ok(visitor.found_common) } /// Rewrites `exprs_list` with common sub-expressions replaced with a new @@ -911,31 +910,53 @@ struct ExprIdentifierVisitor<'a, 'n> { found_common: bool, } -/// Record item that used when traversing a expression tree. +/// Record item that used when traversing an expression tree. enum VisitRecord<'n> { - /// `usize` postorder index assigned in `f-down`(). Starts from 0. - EnterMark(usize), - /// the node's children were skipped => jump to f_up on same node - JumpMark, - /// Accumulated identifier of sub expression. - ExprItem(Identifier<'n>), + /// Marks the beginning of expression. It contains: + /// - The post-order index assigned during the first, visiting traversal. + /// - A boolean flag if the record marks an expression subtree (not just a single + /// node). + EnterMark(usize, bool), + + /// Marks an accumulated subexpression tree. It contains: + /// - The accumulated identifier of a subexpression. + /// - A boolean flag if the expression is valid for subexpression elimination. + /// The flag is propagated up from children to parent. (E.g. volatile expressions + /// are not valid and can't be extracted, but non-volatile children of volatile + /// expressions can be extracted.) + ExprItem(Identifier<'n>, bool), } impl<'n> ExprIdentifierVisitor<'_, 'n> { - /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` - /// before it. - fn pop_enter_mark(&mut self) -> Option<(usize, Option>)> { + /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` before + /// it. Returns a tuple that contains: + /// - The pre-order index of the expression we marked. + /// - A boolean flag if we marked an expression subtree (not just a single node). 
+ /// If true we didn't recurse into the node's children, so we need to calculate the + /// hash of the marked expression tree (not just the node) and we need to validate + /// the expression tree (not just the node). + /// - The accumulated identifier of the children of the marked expression. + /// - An accumulated boolean flag from the children of the marked expression if all + /// children are valid for subexpression elimination (i.e. it is safe to extract the + /// expression as a common expression from its children POV). + /// (E.g. if any of the children of the marked expression is not valid (e.g. is + /// volatile) then the expression is also not valid, so we can propagate this + /// information up from children to parents via `visit_stack` during the first, + /// visiting traversal and no need to test the expression's validity beforehand with + /// an extra traversal). + fn pop_enter_mark(&mut self) -> (usize, bool, Option>, bool) { let mut expr_id = None; + let mut is_valid = true; while let Some(item) = self.visit_stack.pop() { match item { - VisitRecord::EnterMark(idx) => { - return Some((idx, expr_id)); + VisitRecord::EnterMark(down_index, is_tree) => { + return (down_index, is_tree, expr_id, is_valid); } - VisitRecord::ExprItem(id) => { - expr_id = Some(id.combine(expr_id)); + VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => { + expr_id = Some(sub_expr_id.combine(expr_id)); + is_valid &= sub_expr_is_valid; } - VisitRecord::JumpMark => return None, } } unreachable!("Enter mark should paired with node number"); @@ -946,34 +967,43 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { type Node = Expr; fn f_down(&mut self, expr: &'n Expr) -> Result { - // TODO: consider non-volatile sub-expressions for CSE + // If an expression can short circuit its children then don't consider its + // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814). + // This means that we don't recurse into its children, but handle the expression + // as a subtree when we calculate its identifier. // TODO: consider surely executed children of "short circuited"s for CSE - - // If an expression can short circuit its children then don't consider it for CSE - // (https://github.com/apache/arrow-datafusion/issues/8814). 
- if expr.short_circuits() { - self.visit_stack.push(VisitRecord::JumpMark); - - return Ok(TreeNodeRecursion::Jump); - } + let is_tree = expr.short_circuits(); + let tnr = if is_tree { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }; self.id_array.push((0, None)); self.visit_stack - .push(VisitRecord::EnterMark(self.down_index)); + .push(VisitRecord::EnterMark(self.down_index, is_tree)); self.down_index += 1; - Ok(TreeNodeRecursion::Continue) + Ok(tnr) } fn f_up(&mut self, expr: &'n Expr) -> Result { - let Some((down_index, sub_expr_id)) = self.pop_enter_mark() else { - return Ok(TreeNodeRecursion::Continue); - }; + let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark(); - let expr_id = Identifier::new(expr, self.random_state).combine(sub_expr_id); + let (expr_id, is_valid) = if is_tree { + ( + Identifier::new(expr, true, self.random_state), + !expr.is_volatile()?, + ) + } else { + ( + Identifier::new(expr, false, self.random_state).combine(sub_expr_id), + !expr.is_volatile_node() && sub_expr_is_valid, + ) + }; self.id_array[down_index].0 = self.up_index; - if !self.expr_mask.ignores(expr) { + if is_valid && !self.expr_mask.ignores(expr) { self.id_array[down_index].1 = Some(expr_id); let count = self.expr_stats.entry(expr_id).or_insert(0); *count += 1; @@ -981,7 +1011,8 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { self.found_common = true; } } - self.visit_stack.push(VisitRecord::ExprItem(expr_id)); + self.visit_stack + .push(VisitRecord::ExprItem(expr_id, is_valid)); self.up_index += 1; Ok(TreeNodeRecursion::Continue) @@ -1015,19 +1046,22 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { self.alias_counter += 1; } - // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate - // the `id_array`, which records the expr's identifier used to rewrite expr. So if we + // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate the + // `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. - if expr.short_circuits() { - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); - } + let is_tree = expr.short_circuits(); + let tnr = if is_tree { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }; let (up_index, expr_id) = self.id_array[self.down_index]; self.down_index += 1; // skip `Expr`s without identifier (empty identifier). 
let Some(expr_id) = expr_id else { - return Ok(Transformed::no(expr)); + return Ok(Transformed::new(expr, false, tnr)); }; let count = self.expr_stats.get(&expr_id).unwrap(); @@ -1055,7 +1089,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_, '_> { Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)) } else { - Ok(Transformed::no(expr)) + Ok(Transformed::new(expr, false, tnr)) } } @@ -1093,6 +1127,7 @@ fn replace_common_expr<'n>( #[cfg(test)] mod test { + use std::any::Any; use std::collections::HashSet; use std::iter; @@ -1100,8 +1135,9 @@ mod test { use datafusion_expr::expr::AggregateFunction; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Signature, - SimpleAggregateUDF, Volatility, + grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, + ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + Volatility, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; @@ -1802,4 +1838,124 @@ mod test { assert!(result.len() == 1); Ok(()) } + + #[test] + fn test_short_circuits() -> Result<()> { + let table_scan = test_table_scan()?; + + let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0))); + let not_extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0)); + let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0)); + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + extracted_short_circuit.clone().alias("c1"), + extracted_short_circuit.alias("c2"), + not_extracted_short_circuit_leg_1.clone().alias("c3"), + not_extracted_short_circuit_leg_2.clone().alias("c4"), + not_extracted_short_circuit_leg_1 + .or(not_extracted_short_circuit_leg_2) + .alias("c5"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, test.a + test.b = Int32(0) AS c3, test.a - test.b = Int32(0) AS c4, test.a + test.b = Int32(0) OR test.a - test.b = Int32(0) AS c5\ + \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_volatile() -> Result<()> { + let table_scan = test_table_scan()?; + + let extracted_child = col("a") + col("b"); + let rand = rand_func().call(vec![]); + let not_extracted_volatile = extracted_child + rand; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + not_extracted_volatile.clone().alias("c1"), + not_extracted_volatile.alias("c2"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_volatile_short_circuits() -> Result<()> { + let table_scan = test_table_scan()?; + + let rand = rand_func().call(vec![]); + let not_extracted_volatile_short_circuit_2 = + rand.clone().eq(lit(0)).or(col("b").eq(lit(0))); + let not_extracted_volatile_short_circuit_1 = + col("a").eq(lit(0)).or(rand.eq(lit(0))); + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + not_extracted_volatile_short_circuit_1.clone().alias("c1"), + not_extracted_volatile_short_circuit_1.alias("c2"), + not_extracted_volatile_short_circuit_2.clone().alias("c3"), + not_extracted_volatile_short_circuit_2.alias("c4"), + ])? 
+ .build()?; + + let expected = "Projection: test.a = Int32(0) OR random() = Int32(0) AS c1, test.a = Int32(0) OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\ + \n TableScan: test"; + + assert_non_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + /// returns a "random" function that is marked volatile (aka each invocation + /// returns a different value) + /// + /// Does not use datafusion_functions::rand to avoid introducing a + /// dependency on that crate. + fn rand_func() -> ScalarUDF { + ScalarUDF::new_from_impl(RandomStub::new()) + } + + #[derive(Debug)] + struct RandomStub { + signature: Signature, + } + + impl RandomStub { + fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } + } + impl ScalarUDFImpl for RandomStub { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "random" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 09407aed53cd..3732f7ed90c8 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -21,7 +21,9 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::{Distinct, LogicalPlan, Union}; +use itertools::Itertools; use std::sync::Arc; #[derive(Default)] @@ -56,32 +58,34 @@ impl OptimizerRule for EliminateNestedUnion { match plan { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs - .iter() + .into_iter() .flat_map(extract_plans_from_union) .collect::>(); Ok(Transformed::yes(LogicalPlan::Union(Union { - inputs, + inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema, }))) } - LogicalPlan::Distinct(Distinct::All(ref nested_plan)) => { - match nested_plan.as_ref() { + LogicalPlan::Distinct(Distinct::All(nested_plan)) => { + match unwrap_arc(nested_plan) { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs - .iter() + .into_iter() .map(extract_plan_from_distinct) .flat_map(extract_plans_from_union) .collect::>(); Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All( Arc::new(LogicalPlan::Union(Union { - inputs, + inputs: inputs.into_iter().map(Arc::new).collect_vec(), schema: schema.clone(), })), )))) } - _ => Ok(Transformed::no(plan)), + nested_plan => Ok(Transformed::no(LogicalPlan::Distinct( + Distinct::All(Arc::new(nested_plan)), + ))), } } _ => Ok(Transformed::no(plan)), @@ -89,20 +93,20 @@ impl OptimizerRule for EliminateNestedUnion { } } -fn extract_plans_from_union(plan: &Arc) -> Vec> { - match plan.as_ref() { +fn extract_plans_from_union(plan: Arc) -> Vec { + match unwrap_arc(plan) { LogicalPlan::Union(Union { inputs, schema }) => inputs - .iter() - .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap())) + .into_iter() + .map(|plan| coerce_plan_expr_for_schema(&plan, &schema).unwrap()) .collect::>(), - _ => vec![plan.clone()], + plan => vec![plan], } } -fn extract_plan_from_distinct(plan: &Arc) -> &Arc { - match plan.as_ref() { +fn 
extract_plan_from_distinct(plan: Arc) -> Arc { + match unwrap_arc(plan) { LogicalPlan::Distinct(Distinct::All(plan)) => plan, - _ => plan, + plan => Arc::new(plan), } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index ccc637a0eb01..13c483c6dfcc 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -18,6 +18,7 @@ //! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, Result}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; use datafusion_expr::{Expr, Filter, Operator}; @@ -78,7 +79,7 @@ impl OptimizerRule for EliminateOuterJoin { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Filter(filter) => match filter.input.as_ref() { + LogicalPlan::Filter(mut filter) => match unwrap_arc(filter.input) { LogicalPlan::Join(join) => { let mut non_nullable_cols: Vec = vec![]; @@ -109,9 +110,10 @@ impl OptimizerRule for EliminateOuterJoin { } else { join.join_type }; + let new_join = Arc::new(LogicalPlan::Join(Join { - left: Arc::new((*join.left).clone()), - right: Arc::new((*join.right).clone()), + left: join.left, + right: join.right, join_type: new_join_type, join_constraint: join.join_constraint, on: join.on.clone(), @@ -122,7 +124,10 @@ impl OptimizerRule for EliminateOuterJoin { Filter::try_new(filter.predicate, new_join) .map(|f| Transformed::yes(LogicalPlan::Filter(f))) } - _ => Ok(Transformed::no(LogicalPlan::Filter(filter))), + filter_input => { + filter.input = Arc::new(filter_input); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } }, _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 576dabe305e6..88bd1b17883b 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -182,7 +182,9 @@ impl OptimizerRule for PropagateEmptyRelation { }, ))) } else if new_inputs.len() == 1 { - let child = unwrap_arc(new_inputs[0].clone()); + let mut new_inputs = new_inputs; + let input_plan = new_inputs.pop().unwrap(); // length checked + let child = unwrap_arc(input_plan); if child.schema().eq(plan.schema()) { Ok(Transformed::yes(child)) } else { diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index fa432ad76de5..1c3186b762b7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -134,7 +134,15 @@ use crate::{OptimizerConfig, OptimizerRule}; #[derive(Default)] pub struct PushDownFilter {} -/// For a given JOIN type, determine whether each side of the join is preserved. +/// For a given JOIN type, determine whether each input of the join is preserved +/// for post-join (`WHERE` clause) filters. +/// +/// It is only correct to push filters below a join for preserved inputs. +/// +/// # Return Value +/// A tuple of booleans - (left_preserved, right_preserved). 
+/// +/// # "Preserved" input definition /// /// We say a join side is preserved if the join returns all or a subset of the rows from /// the relevant side, such that each row of the output table directly maps to a row of @@ -145,15 +153,11 @@ pub struct PushDownFilter {} /// For example: /// - In an inner join, both sides are preserved, because each row of the output /// maps directly to a row from each side. -/// - In a left join, the left side is preserved and the right is not, because -/// there may be rows in the output that don't directly map to a row in the -/// right input (due to nulls filling where there is no match on the right). /// -/// This is important because we can always push down post-join filters to a preserved -/// side of the join, assuming the filter only references columns from that side. For the -/// non-preserved side it can be more tricky. -/// -/// Returns a tuple of booleans - (left_preserved, right_preserved). +/// - In a left join, the left side is preserved (we can push predicates) but +/// the right is not, because there may be rows in the output that don't +/// directly map to a row in the right input (due to nulls filling where there +/// is no match on the right). fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { match join_type { JoinType::Inner => Ok((true, true)), @@ -169,9 +173,15 @@ fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { } } -/// For a given JOIN logical plan, determine whether each side of the join is preserved -/// in terms on join filtering. -/// Predicates from join filter can only be pushed to preserved join side. +/// For a given JOIN type, determine whether each input of the join is preserved +/// for the join condition (`ON` clause filters). +/// +/// It is only correct to push filters below a join for preserved inputs. +/// +/// # Return Value +/// A tuple of booleans - (left_preserved, right_preserved). +/// +/// See [`lr_is_preserved`] for a definition of "preserved". fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { match join_type { JoinType::Inner => Ok((true, true)), @@ -184,13 +194,50 @@ fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { } } -/// Determine which predicates in state can be pushed down to a given side of a join. -/// To determine this, we need to know the schema of the relevant join side and whether -/// or not the side's rows are preserved when joining. If the side is not preserved, we -/// do not push down anything. Otherwise we can push down predicates where all of the -/// relevant columns are contained on the relevant join side's schema. 
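The preserved/not-preserved behaviour described above can be written out explicitly. An illustrative sketch only (not the actual `lr_is_preserved` function; semi/anti join variants are collapsed into a conservative default here):

```rust
use datafusion_common::JoinType;

/// Illustrative only: which join inputs are preserved for post-join (`WHERE`) filters.
fn preserved_sides(join_type: JoinType) -> (bool, bool) {
    match join_type {
        JoinType::Inner => (true, true),  // every output row maps to a row on both sides
        JoinType::Left => (true, false),  // right-side rows may be null-extended
        JoinType::Right => (false, true), // left-side rows may be null-extended
        JoinType::Full => (false, false), // either side may be null-extended
        _ => (false, false),              // conservative default for this sketch
    }
}

fn main() {
    // For `t1 LEFT JOIN t2 ... WHERE t1.a = 5 AND t2.b = 7`, only `t1.a = 5`
    // references a preserved input and may be pushed below the join.
    assert_eq!(preserved_sides(JoinType::Left), (true, false));
}
```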
-fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result<bool> { - let schema_columns = schema +/// Evaluates the columns referenced in the given expression to see if they refer +/// only to the left or right columns +#[derive(Debug)] +struct ColumnChecker<'a> { + /// schema of left join input + left_schema: &'a DFSchema, + /// columns in left_schema, computed on demand + left_columns: Option<HashSet<Column>>, + /// schema of right join input + right_schema: &'a DFSchema, + /// columns in right_schema, computed on demand + right_columns: Option<HashSet<Column>>, +} + +impl<'a> ColumnChecker<'a> { + fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self { + Self { + left_schema, + left_columns: None, + right_schema, + right_columns: None, + } + } + + /// Return true if the expression references only columns from the left side of the join + fn is_left_only(&mut self, predicate: &Expr) -> bool { + if self.left_columns.is_none() { + self.left_columns = Some(schema_columns(self.left_schema)); + } + has_all_column_refs(predicate, self.left_columns.as_ref().unwrap()) + } + + /// Return true if the expression references only columns from the right side of the join + fn is_right_only(&mut self, predicate: &Expr) -> bool { + if self.right_columns.is_none() { + self.right_columns = Some(schema_columns(self.right_schema)); + } + has_all_column_refs(predicate, self.right_columns.as_ref().unwrap()) + } +} + +/// Returns all columns in the schema +fn schema_columns(schema: &DFSchema) -> HashSet<Column> { + schema .iter() .flat_map(|(qualifier, field)| { [ @@ -199,8 +246,7 @@ fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result<bool> { Column::new_unqualified(field.name()), ] }) - .collect::<HashSet<_>>(); - Ok(has_all_column_refs(predicate, &schema_columns)) + .collect::<HashSet<_>>() } /// Determine whether the predicate can evaluate as the join conditions @@ -285,16 +331,7 @@ fn extract_or_clauses_for_join<'a>( filters: &'a [Expr], schema: &'a DFSchema, ) -> impl Iterator<Item = Expr> + 'a { - let schema_columns = schema - .iter() - .flat_map(|(qualifier, field)| { - [ - Column::new(qualifier.cloned(), field.name()), - // we need to push down filter using unqualified column as well - Column::new_unqualified(field.name()), - ] - }) - .collect::<HashSet<_>>(); + let schema_columns = schema_columns(schema); // new formed OR clauses and their column references filters.iter().filter_map(move |expr| { @@ -397,12 +434,11 @@ fn push_down_all_join( let mut right_push = vec![]; let mut keep_predicates = vec![]; let mut join_conditions = vec![]; + let mut checker = ColumnChecker::new(left_schema, right_schema); for predicate in predicates { - if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved - && can_pushdown_join_predicate(&predicate, right_schema)? - { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate @@ -415,26 +451,24 @@ fn push_down_all_join( // For infer predicates, if they can not push through join, just drop them for predicate in inferred_join_predicates { - if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved - && can_pushdown_join_predicate(&predicate, right_schema)? 
- { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } } + let mut on_filter_join_conditions = vec![]; + let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?; + if !on_filter.is_empty() { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?; for on in on_filter { - if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { + if on_left_preserved && checker.is_left_only(&on) { left_push.push(on) - } else if on_right_preserved - && can_pushdown_join_predicate(&on, right_schema)? - { + } else if on_right_preserved && checker.is_right_only(&on) { right_push.push(on) } else { - join_conditions.push(on) + on_filter_join_conditions.push(on) } } } @@ -450,6 +484,21 @@ fn push_down_all_join( right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); } + // For predicates from join filter, we should check with if a join side is preserved + // in term of join filtering. + if on_left_preserved { + left_push.extend(extract_or_clauses_for_join( + &on_filter_join_conditions, + left_schema, + )); + } + if on_right_preserved { + right_push.extend(extract_or_clauses_for_join( + &on_filter_join_conditions, + right_schema, + )); + } + if let Some(predicate) = conjunction(left_push) { join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); } @@ -459,6 +508,7 @@ fn push_down_all_join( } // Add any new join conditions as the non join predicates + join_conditions.extend(on_filter_join_conditions); join.filter = conjunction(join_conditions); // wrap the join on the filter whose predicates must be kept, if any diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index b3562b7065e1..7c66d659cbaf 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -396,8 +396,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\ + let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -419,7 +419,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -437,7 +437,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -456,7 +456,7 @@ mod tests { .build()?; // Should not be optimized - let 
expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -470,8 +470,8 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64;N]\ + let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -487,8 +487,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64;N]\ + let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -507,7 +507,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -534,8 +534,8 @@ mod tests { )? 
.build()?; // Should work - let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64;N, MAX(alias1):UInt32;N]\ + let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), MAX(alias1)]] [a:UInt32, count(alias1):Int64, MAX(alias1):UInt32;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -554,7 +554,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64;N, count(test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -569,8 +569,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ + let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64]\ + \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -599,8 +599,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64;N, MAX(alias1):UInt32;N]\ + let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), MAX(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, MAX(alias1):UInt32;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -608,7 +608,7 @@ mod tests { } #[test] - fn one_distinctand_and_two_common() -> Result<()> { + fn one_distinct_and_two_common() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -618,8 +618,8 @@ mod tests { )? 
.build()?; // Should work - let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64;N]\ + let expected = "Projection: test.a, sum(alias2) AS sum(test.c), MAX(alias3) AS MAX(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, MAX(test.c):UInt32;N, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), MAX(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, MAX(alias3):UInt32;N, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -637,8 +637,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64;N]\ + let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), count(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -662,7 +662,7 @@ mod tests { .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -682,7 +682,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -705,7 +705,7 @@ mod tests { .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? 
.build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a]:UInt64;N, count(DISTINCT test.b):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -725,7 +725,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -746,7 +746,7 @@ mod tests { .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index c501d5aaa4bf..c0863839dba1 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -294,6 +294,21 @@ fn eliminate_nested_filters() { assert_eq!(expected, format!("{plan:?}")); } +#[test] +fn eliminate_redundant_null_check_on_count() { + let sql = "\ + SELECT col_int32, count(*) c + FROM test + GROUP BY col_int32 + HAVING c IS NOT NULL"; + let plan = test_sql(sql).unwrap(); + let expected = "\ + Projection: test.col_int32, count(*) AS c\ + \n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1)) AS count(*)]]\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{plan:?}")); +} + #[test] fn test_propagate_empty_relation_inner_join_and_unions() { let sql = "\ diff --git a/datafusion/physical-expr-common/src/aggregate/merge_arrays.rs b/datafusion/physical-expr-common/src/aggregate/merge_arrays.rs new file mode 100644 index 000000000000..544bdc182829 --- /dev/null +++ b/datafusion/physical-expr-common/src/aggregate/merge_arrays.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::SortOptions; +use datafusion_common::utils::compare_rows; +use datafusion_common::{exec_err, ScalarValue}; +use std::cmp::Ordering; +use std::collections::{BinaryHeap, VecDeque}; + +/// This is a wrapper struct to be able to correctly merge `ARRAY_AGG` data from +/// multiple partitions using `BinaryHeap`. When used inside `BinaryHeap`, this +/// struct returns smallest `CustomElement`, where smallest is determined by +/// `ordering` values (`Vec`) according to `sort_options`. +#[derive(Debug, PartialEq, Eq)] +struct CustomElement<'a> { + /// Stores the partition this entry came from + branch_idx: usize, + /// Values to merge + value: ScalarValue, + // Comparison "key" + ordering: Vec, + /// Options defining the ordering semantics + sort_options: &'a [SortOptions], +} + +impl<'a> CustomElement<'a> { + fn new( + branch_idx: usize, + value: ScalarValue, + ordering: Vec, + sort_options: &'a [SortOptions], + ) -> Self { + Self { + branch_idx, + value, + ordering, + sort_options, + } + } + + fn ordering( + &self, + current: &[ScalarValue], + target: &[ScalarValue], + ) -> datafusion_common::Result { + // Calculate ordering according to `sort_options` + compare_rows(current, target, self.sort_options) + } +} + +// Overwrite ordering implementation such that +// - `self.ordering` values are used for comparison, +// - When used inside `BinaryHeap` it is a min-heap. +impl<'a> Ord for CustomElement<'a> { + fn cmp(&self, other: &Self) -> Ordering { + // Compares according to custom ordering + self.ordering(&self.ordering, &other.ordering) + // Convert max heap to min heap + .map(|ordering| ordering.reverse()) + // This function return error, when `self.ordering` and `other.ordering` + // have different types (such as one is `ScalarValue::Int64`, other is `ScalarValue::Float32`) + // Here this case won't happen, because data from each partition will have same type + .unwrap() + } +} + +impl<'a> PartialOrd for CustomElement<'a> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// This functions merges `values` array (`&[Vec]`) into single array `Vec` +/// Merging done according to ordering values stored inside `ordering_values` (`&[Vec>]`) +/// Inner `Vec` in the `ordering_values` can be thought as ordering information for the +/// each `ScalarValue` in the `values` array. +/// Desired ordering specified by `sort_options` argument (Should have same size with inner `Vec` +/// of the `ordering_values` array). +/// +/// As an example +/// values can be \[ +/// \[1, 2, 3, 4, 5\], +/// \[1, 2, 3, 4\], +/// \[1, 2, 3, 4, 5, 6\], +/// \] +/// In this case we will be merging three arrays (doesn't have to be same size) +/// and produce a merged array with size 15 (sum of 5+4+6) +/// Merging will be done according to ordering at `ordering_values` vector. 
+/// As an example `ordering_values` can be [ +/// \[(1, a), (2, b), (3, b), (4, a), (5, b) \], +/// \[(1, a), (2, b), (3, b), (4, a) \], +/// \[(1, b), (2, c), (3, d), (4, e), (5, a), (6, b) \], +/// ] +/// For each ScalarValue in the `values` we have a corresponding `Vec` (like timestamp of it) +/// for the example above `sort_options` will have size two, that defines ordering requirement of the merge. +/// Inner `Vec`s of the `ordering_values` will be compared according `sort_options` (Their sizes should match) +pub fn merge_ordered_arrays( + // We will merge values into single `Vec`. + values: &mut [VecDeque], + // `values` will be merged according to `ordering_values`. + // Inner `Vec` can be thought as ordering information for the + // each `ScalarValue` in the values`. + ordering_values: &mut [VecDeque>], + // Defines according to which ordering comparisons should be done. + sort_options: &[SortOptions], +) -> datafusion_common::Result<(Vec, Vec>)> { + // Keep track the most recent data of each branch, in binary heap data structure. + let mut heap = BinaryHeap::::new(); + + if values.len() != ordering_values.len() + || values + .iter() + .zip(ordering_values.iter()) + .any(|(vals, ordering_vals)| vals.len() != ordering_vals.len()) + { + return exec_err!( + "Expects values arguments and/or ordering_values arguments to have same size" + ); + } + let n_branch = values.len(); + let mut merged_values = vec![]; + let mut merged_orderings = vec![]; + // Continue iterating the loop until consuming data of all branches. + loop { + let minimum = if let Some(minimum) = heap.pop() { + minimum + } else { + // Heap is empty, fill it with the next entries from each branch. + for branch_idx in 0..n_branch { + if let Some(orderings) = ordering_values[branch_idx].pop_front() { + // Their size should be same, we can safely .unwrap here. + let value = values[branch_idx].pop_front().unwrap(); + // Push the next element to the heap: + heap.push(CustomElement::new( + branch_idx, + value, + orderings, + sort_options, + )); + } + // If None, we consumed this branch, skip it. + } + + // Now we have filled the heap, get the largest entry (this will be + // the next element in merge). + if let Some(minimum) = heap.pop() { + minimum + } else { + // Heap is empty, this means that all indices are same with + // `end_indices`. We have consumed all of the branches, merge + // is completed, exit from the loop: + break; + } + }; + let CustomElement { + branch_idx, + value, + ordering, + .. + } = minimum; + // Add minimum value in the heap to the result + merged_values.push(value); + merged_orderings.push(ordering); + + // If there is an available entry, push next entry in the most + // recently consumed branch to the heap. + if let Some(orderings) = ordering_values[branch_idx].pop_front() { + // Their size should be same, we can safely .unwrap here. 
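With `merge_ordered_arrays` now public in `datafusion_physical_expr_common`, it can be exercised on its own. A minimal usage sketch, assuming two partitions whose entries are already sorted ascending by a single `Int64` ordering key (the values here are made up for illustration):

```rust
use std::collections::VecDeque;

use arrow::compute::SortOptions;
use datafusion_common::ScalarValue;
use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays;

fn main() -> datafusion_common::Result<()> {
    // Two partitions of ARRAY_AGG values, each already ordered by its key.
    let mut values = [
        VecDeque::from(vec![ScalarValue::from("a"), ScalarValue::from("c")]),
        VecDeque::from(vec![ScalarValue::from("b"), ScalarValue::from("d")]),
    ];
    // One ordering column per value: keys 1, 3 in the first partition and 2, 4 in the second.
    let mut ordering_values = [
        VecDeque::from(vec![vec![ScalarValue::from(1i64)], vec![ScalarValue::from(3i64)]]),
        VecDeque::from(vec![vec![ScalarValue::from(2i64)], vec![ScalarValue::from(4i64)]]),
    ];
    let sort_options = [SortOptions {
        descending: false,
        nulls_first: false,
    }];

    let (merged_values, _merged_keys) =
        merge_ordered_arrays(&mut values, &mut ordering_values, &sort_options)?;

    // The partitions are interleaved into one globally ordered sequence: "a", "b", "c", "d".
    assert_eq!(merged_values.len(), 4);
    Ok(())
}
```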
+ let value = values[branch_idx].pop_front().unwrap(); + // Push the next element to the heap: + heap.push(CustomElement::new( + branch_idx, + value, + orderings, + sort_options, + )); + } + } + + Ok((merged_values, merged_orderings)) +} diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 336e28b4d28e..35666f199ace 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -17,6 +17,7 @@ pub mod count_distinct; pub mod groups_accumulator; +pub mod merge_arrays; pub mod stats; pub mod tdigest; pub mod utils; @@ -221,6 +222,17 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { ) -> Option> { None } + + /// If this function is max, return (output_field, true) + /// if the function is min, return (output_field, false) + /// otherwise return None (the default) + /// + /// output_field is the name of the column produced by this aggregate + /// + /// Note: this is used to use special aggregate implementations in certain conditions + fn get_minmax_desc(&self) -> Option<(Field, bool)> { + None + } } /// Stores the physical expressions used inside the `AggregateExpr`. diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index 6d5ba737a1df..bff571f5b5be 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -255,7 +255,7 @@ where /// the same output type pub fn take(&mut self) -> Self { let mut new_self = Self::new(self.output_type); - std::mem::swap(self, &mut new_self); + mem::swap(self, &mut new_self); new_self } @@ -538,7 +538,7 @@ where /// this set, not including `self` pub fn size(&self) -> usize { self.map_size - + self.buffer.capacity() * std::mem::size_of::() + + self.buffer.capacity() * mem::size_of::() + self.offsets.allocated_size() + self.hashes_buffer.allocated_size() } diff --git a/datafusion/physical-expr-common/src/expressions/cast.rs b/datafusion/physical-expr-common/src/expressions/cast.rs index 31b96889fd62..dd6131ad65c3 100644 --- a/datafusion/physical-expr-common/src/expressions/cast.rs +++ b/datafusion/physical-expr-common/src/expressions/cast.rs @@ -36,6 +36,11 @@ const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { format_options: DEFAULT_FORMAT_OPTIONS, }; +const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: true, + format_options: DEFAULT_FORMAT_OPTIONS, +}; + /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast #[derive(Debug, Clone)] pub struct CastExpr { @@ -150,9 +155,9 @@ impl PhysicalExpr for CastExpr { let child_interval = children[0]; // Get child's datatype: let cast_type = child_interval.data_type(); - Ok(Some( - vec![interval.cast_to(&cast_type, &self.cast_options)?], - )) + Ok(Some(vec![ + interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)? 
+ ])) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -366,9 +371,9 @@ mod tests { generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Decimal128Array, - DataType::Decimal128(20, 6), + Decimal128(20, 6), [ Some(1_234_000), Some(2_222_000), @@ -387,9 +392,9 @@ mod tests { generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Decimal128Array, - DataType::Decimal128(10, 2), + Decimal128(10, 2), [Some(123), Some(222), Some(0), Some(400), Some(500), None], None ); @@ -408,9 +413,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int8Array, - DataType::Int8, + Int8, [ Some(1_i8), Some(2_i8), @@ -430,9 +435,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int16Array, - DataType::Int16, + Int16, [ Some(1_i16), Some(2_i16), @@ -452,9 +457,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int32Array, - DataType::Int32, + Int32, [ Some(1_i32), Some(2_i32), @@ -473,9 +478,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int64Array, - DataType::Int64, + Int64, [ Some(1_i64), Some(2_i64), @@ -503,9 +508,9 @@ mod tests { .with_precision_and_scale(10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Float32Array, - DataType::Float32, + Float32, [ Some(1.234_f32), Some(2.222_f32), @@ -524,9 +529,9 @@ mod tests { .with_precision_and_scale(20, 6)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(20, 6), + Decimal128(20, 6), Float64Array, - DataType::Float64, + Float64, [ Some(0.001234_f64), Some(0.002222_f64), @@ -545,10 +550,10 @@ mod tests { // int8 generic_test_cast!( Int8Array, - DataType::Int8, + Int8, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(3, 0), + Decimal128(3, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -556,10 +561,10 @@ mod tests { // int16 generic_test_cast!( Int16Array, - DataType::Int16, + Int16, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(5, 0), + Decimal128(5, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -567,10 +572,10 @@ mod tests { // int32 generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -578,10 +583,10 @@ mod tests { // int64 generic_test_cast!( Int64Array, - DataType::Int64, + Int64, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(20, 0), + Decimal128(20, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -589,10 +594,10 @@ mod tests { // int64 to different scale generic_test_cast!( Int64Array, - DataType::Int64, + Int64, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(20, 2), + Decimal128(20, 2), [Some(100), Some(200), Some(300), Some(400), Some(500)], None ); @@ -600,10 +605,10 @@ mod tests { // float32 generic_test_cast!( Float32Array, - DataType::Float32, + Float32, vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, - DataType::Decimal128(10, 2), + Decimal128(10, 2), [Some(150), Some(250), Some(300), 
Some(112), Some(550)], None ); @@ -611,10 +616,10 @@ mod tests { // float64 generic_test_cast!( Float64Array, - DataType::Float64, + Float64, vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, - DataType::Decimal128(20, 4), + Decimal128(20, 4), [ Some(15000), Some(25000), @@ -631,10 +636,10 @@ mod tests { fn test_cast_i32_u32() -> Result<()> { generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], UInt32Array, - DataType::UInt32, + UInt32, [ Some(1_u32), Some(2_u32), @@ -651,10 +656,10 @@ mod tests { fn test_cast_i32_utf8() -> Result<()> { generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], StringArray, - DataType::Utf8, + Utf8, [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")], None ); @@ -670,10 +675,10 @@ mod tests { .collect(); generic_test_cast!( Int64Array, - DataType::Int64, + Int64, original, TimestampNanosecondArray, - DataType::Timestamp(TimeUnit::Nanosecond, None), + Timestamp(TimeUnit::Nanosecond, None), expected, None ); @@ -683,7 +688,7 @@ mod tests { #[test] fn invalid_cast() { // Ensure a useful error happens at plan time if invalid casts are used - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let schema = Schema::new(vec![Field::new("a", Int32, false)]); let result = cast( col("a", &schema).unwrap(), @@ -696,11 +701,10 @@ mod tests { #[test] fn invalid_cast_with_options_error() -> Result<()> { // Ensure a useful error happens at plan time if invalid casts are used - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let schema = Schema::new(vec![Field::new("a", Utf8, false)]); let a = StringArray::from(vec!["9.1"]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let expression = - cast_with_options(col("a", &schema)?, &schema, DataType::Int32, None)?; + let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?; let result = expression.evaluate(&batch); match result { @@ -717,15 +721,11 @@ mod tests { #[test] #[ignore] // TODO: https://github.com/apache/datafusion/issues/5396 fn test_cast_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + let schema = Schema::new(vec![Field::new("a", Int64, false)]); let a = Int64Array::from(vec![100]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let expression = cast_with_options( - col("a", &schema)?, - &schema, - DataType::Decimal128(38, 38), - None, - )?; + let expression = + cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?; expression.evaluate(&batch)?; Ok(()) } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index f637355519af..8fb1356a8092 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -39,6 +39,13 @@ pub struct PhysicalSortExpr { pub options: SortOptions, } +impl PhysicalSortExpr { + /// Create a new PhysicalSortExpr + pub fn new(expr: Arc, options: SortOptions) -> Self { + Self { expr, options } + } +} + impl PartialEq for PhysicalSortExpr { fn eq(&self, other: &PhysicalSortExpr) -> bool { self.options == other.options && self.expr.eq(&other.expr) @@ -155,10 +162,7 @@ impl From for PhysicalSortExpr { descending: false, nulls_first: false, }); - PhysicalSortExpr { - expr: value.expr, - options, - } + PhysicalSortExpr::new(value.expr, options) } } @@ -281,16 +285,13 @@ pub fn 
limited_convert_logical_sort_exprs_to_physical( let Expr::Sort(sort) = expr else { return exec_err!("Expects to receive sort expression"); }; - sort_exprs.push(PhysicalSortExpr { - expr: limited_convert_logical_expr_to_physical_expr( - sort.expr.as_ref(), - schema, - )?, - options: SortOptions { + sort_exprs.push(PhysicalSortExpr::new( + limited_convert_logical_expr_to_physical_expr(sort.expr.as_ref(), schema)?, + SortOptions { descending: !sort.asc, nulls_first: sort.nulls_first, }, - }); + )) } Ok(sort_exprs) } diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index d5cd3c6f4af0..44622bd309df 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -104,10 +104,7 @@ pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec { order_bys .iter() - .map(|e| PhysicalSortExpr { - expr: e.expr.clone(), - options: !e.options, - }) + .map(|e| PhysicalSortExpr::new(e.expr.clone(), !e.options)) .collect() } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index c5a0662a2283..634a0a017903 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -91,7 +91,7 @@ impl AggregateExpr for ArrayAgg { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -137,7 +137,7 @@ impl Accumulator for ArrayAggAccumulator { return Ok(()); } assert!(values.len() == 1, "array_agg can only take 1 param!"); - let val = values[0].clone(); + let val = Arc::clone(&values[0]); self.values.push(val); Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index fc838196de20..a59d85e84a20 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -95,7 +95,7 @@ impl AggregateExpr for DistinctArrayAgg { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 1234ab40c188..a64d97637c3b 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -19,8 +19,7 @@ //! that can evaluated at runtime during query execution use std::any::Any; -use std::cmp::Ordering; -use std::collections::{BinaryHeap, VecDeque}; +use std::collections::VecDeque; use std::fmt::Debug; use std::sync::Arc; @@ -33,11 +32,12 @@ use crate::{ use arrow::datatypes::{DataType, Field}; use arrow_array::cast::AsArray; use arrow_array::{new_empty_array, Array, ArrayRef, StructArray}; -use arrow_schema::{Fields, SortOptions}; -use datafusion_common::utils::{array_into_list_array, compare_rows, get_row_at_idx}; +use arrow_schema::Fields; +use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::Accumulator; +use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; /// Expression for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. 
In a multi /// partition setting, partial aggregations are computed for every partition, @@ -127,7 +127,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { @@ -146,7 +146,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { Some(Arc::new(Self { name: self.name.to_string(), input_data_type: self.input_data_type.clone(), - expr: self.expr.clone(), + expr: Arc::clone(&self.expr), nullable: self.nullable, order_by_data_types: self.order_by_data_types.clone(), // Reverse requirement: @@ -384,179 +384,6 @@ impl OrderSensitiveArrayAggAccumulator { } } -/// This is a wrapper struct to be able to correctly merge `ARRAY_AGG` data from -/// multiple partitions using `BinaryHeap`. When used inside `BinaryHeap`, this -/// struct returns smallest `CustomElement`, where smallest is determined by -/// `ordering` values (`Vec`) according to `sort_options`. -#[derive(Debug, PartialEq, Eq)] -struct CustomElement<'a> { - /// Stores the partition this entry came from - branch_idx: usize, - /// Values to merge - value: ScalarValue, - // Comparison "key" - ordering: Vec, - /// Options defining the ordering semantics - sort_options: &'a [SortOptions], -} - -impl<'a> CustomElement<'a> { - fn new( - branch_idx: usize, - value: ScalarValue, - ordering: Vec, - sort_options: &'a [SortOptions], - ) -> Self { - Self { - branch_idx, - value, - ordering, - sort_options, - } - } - - fn ordering( - &self, - current: &[ScalarValue], - target: &[ScalarValue], - ) -> Result { - // Calculate ordering according to `sort_options` - compare_rows(current, target, self.sort_options) - } -} - -// Overwrite ordering implementation such that -// - `self.ordering` values are used for comparison, -// - When used inside `BinaryHeap` it is a min-heap. -impl<'a> Ord for CustomElement<'a> { - fn cmp(&self, other: &Self) -> Ordering { - // Compares according to custom ordering - self.ordering(&self.ordering, &other.ordering) - // Convert max heap to min heap - .map(|ordering| ordering.reverse()) - // This function return error, when `self.ordering` and `other.ordering` - // have different types (such as one is `ScalarValue::Int64`, other is `ScalarValue::Float32`) - // Here this case won't happen, because data from each partition will have same type - .unwrap() - } -} - -impl<'a> PartialOrd for CustomElement<'a> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -/// This functions merges `values` array (`&[Vec]`) into single array `Vec` -/// Merging done according to ordering values stored inside `ordering_values` (`&[Vec>]`) -/// Inner `Vec` in the `ordering_values` can be thought as ordering information for the -/// each `ScalarValue` in the `values` array. -/// Desired ordering specified by `sort_options` argument (Should have same size with inner `Vec` -/// of the `ordering_values` array). -/// -/// As an example -/// values can be \[ -/// \[1, 2, 3, 4, 5\], -/// \[1, 2, 3, 4\], -/// \[1, 2, 3, 4, 5, 6\], -/// \] -/// In this case we will be merging three arrays (doesn't have to be same size) -/// and produce a merged array with size 15 (sum of 5+4+6) -/// Merging will be done according to ordering at `ordering_values` vector. 
-/// As an example `ordering_values` can be [ -/// \[(1, a), (2, b), (3, b), (4, a), (5, b) \], -/// \[(1, a), (2, b), (3, b), (4, a) \], -/// \[(1, b), (2, c), (3, d), (4, e), (5, a), (6, b) \], -/// ] -/// For each ScalarValue in the `values` we have a corresponding `Vec` (like timestamp of it) -/// for the example above `sort_options` will have size two, that defines ordering requirement of the merge. -/// Inner `Vec`s of the `ordering_values` will be compared according `sort_options` (Their sizes should match) -pub(crate) fn merge_ordered_arrays( - // We will merge values into single `Vec`. - values: &mut [VecDeque], - // `values` will be merged according to `ordering_values`. - // Inner `Vec` can be thought as ordering information for the - // each `ScalarValue` in the values`. - ordering_values: &mut [VecDeque>], - // Defines according to which ordering comparisons should be done. - sort_options: &[SortOptions], -) -> Result<(Vec, Vec>)> { - // Keep track the most recent data of each branch, in binary heap data structure. - let mut heap = BinaryHeap::::new(); - - if values.len() != ordering_values.len() - || values - .iter() - .zip(ordering_values.iter()) - .any(|(vals, ordering_vals)| vals.len() != ordering_vals.len()) - { - return exec_err!( - "Expects values arguments and/or ordering_values arguments to have same size" - ); - } - let n_branch = values.len(); - let mut merged_values = vec![]; - let mut merged_orderings = vec![]; - // Continue iterating the loop until consuming data of all branches. - loop { - let minimum = if let Some(minimum) = heap.pop() { - minimum - } else { - // Heap is empty, fill it with the next entries from each branch. - for branch_idx in 0..n_branch { - if let Some(orderings) = ordering_values[branch_idx].pop_front() { - // Their size should be same, we can safely .unwrap here. - let value = values[branch_idx].pop_front().unwrap(); - // Push the next element to the heap: - heap.push(CustomElement::new( - branch_idx, - value, - orderings, - sort_options, - )); - } - // If None, we consumed this branch, skip it. - } - - // Now we have filled the heap, get the largest entry (this will be - // the next element in merge). - if let Some(minimum) = heap.pop() { - minimum - } else { - // Heap is empty, this means that all indices are same with - // `end_indices`. We have consumed all of the branches, merge - // is completed, exit from the loop: - break; - } - }; - let CustomElement { - branch_idx, - value, - ordering, - .. - } = minimum; - // Add minimum value in the heap to the result - merged_values.push(value); - merged_orderings.push(ordering); - - // If there is an available entry, push next entry in the most - // recently consumed branch to the heap. - if let Some(orderings) = ordering_values[branch_idx].pop_front() { - // Their size should be same, we can safely .unwrap here. 
- let value = values[branch_idx].pop_front().unwrap(); - // Push the next element to the heap: - heap.push(CustomElement::new( - branch_idx, - value, - orderings, - sort_options, - )); - } - } - - Ok((merged_values, merged_orderings)) -} - #[cfg(test)] mod tests { use std::collections::VecDeque; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index adbbbd3e631e..d4cd3d51d174 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,10 +30,10 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{not_impl_err, Result}; use datafusion_expr::AggregateFunction; -use crate::expressions::{self, Literal}; +use crate::expressions::{self}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; /// Create a physical aggregation expression. @@ -61,7 +61,7 @@ pub fn create_aggregate_expr( let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { (AggregateFunction::ArrayAgg, false) => { - let expr = input_phy_exprs[0].clone(); + let expr = Arc::clone(&input_phy_exprs[0]); let nullable = expr.nullable(input_schema)?; if ordering_req.is_empty() { @@ -83,7 +83,7 @@ pub fn create_aggregate_expr( "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" ); } - let expr = input_phy_exprs[0].clone(); + let expr = Arc::clone(&input_phy_exprs[0]); let is_expr_nullable = expr.nullable(input_schema)?; Arc::new(expressions::DistinctArrayAgg::new( expr, @@ -93,35 +93,15 @@ pub fn create_aggregate_expr( )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( - input_phy_exprs[0].clone(), + Arc::clone(&input_phy_exprs[0]), name, data_type, )), (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( - input_phy_exprs[0].clone(), + Arc::clone(&input_phy_exprs[0]), name, data_type, )), - (AggregateFunction::NthValue, _) => { - let expr = &input_phy_exprs[0]; - let Some(n) = input_phy_exprs[1] - .as_any() - .downcast_ref::() - .map(|literal| literal.value()) - else { - return exec_err!("Second argument of NTH_VALUE needs to be a literal"); - }; - let nullable = expr.nullable(input_schema)?; - Arc::new(expressions::NthValueAgg::new( - expr.clone(), - n.clone().try_into()?, - name, - input_phy_types[0].clone(), - nullable, - ordering_types, - ordering_req.to_vec(), - )) - } }) } @@ -320,7 +300,7 @@ mod tests { input_exprs .iter() .zip(coerced_types) - .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) + .map(|(expr, coerced_type)| try_cast(Arc::clone(expr), schema, coerced_type)) .collect::>>() } } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 8d07f0df0742..65bb9e478c3d 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -162,7 +162,7 @@ impl AggregateExpr for Max { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn create_accumulator(&self) -> Result> { @@ -266,6 +266,10 @@ impl AggregateExpr for Max { fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?)) } + + fn get_minmax_desc(&self) -> Option<(Field, bool)> { + Some((self.field().ok()?, true)) + } } impl PartialEq for Max { @@ -923,7 +927,7 @@ impl AggregateExpr for Min { } fn expressions(&self) -> Vec> { - 
vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -1018,6 +1022,10 @@ impl AggregateExpr for Min { fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?)) } + + fn get_minmax_desc(&self) -> Option<(Field, bool)> { + Some((self.field().ok()?, false)) + } } impl PartialEq for Min { @@ -1161,7 +1169,7 @@ mod tests { let mut min = MinAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth)) .unwrap(); - min.update_batch(&[b.clone()]).unwrap(); + min.update_batch(&[Arc::clone(&b)]).unwrap(); let min_res = min.evaluate().unwrap(); assert_eq!( min_res, @@ -1173,7 +1181,7 @@ mod tests { let mut max = MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth)) .unwrap(); - max.update_batch(&[b.clone()]).unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); let max_res = max.evaluate().unwrap(); assert_eq!( max_res, @@ -1194,7 +1202,7 @@ mod tests { let mut min = MinAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap(); - min.update_batch(&[b.clone()]).unwrap(); + min.update_batch(&[Arc::clone(&b)]).unwrap(); let min_res = min.evaluate().unwrap(); assert_eq!( min_res, @@ -1203,7 +1211,7 @@ mod tests { let mut max = MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap(); - max.update_batch(&[b.clone()]).unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); let max_res = max.evaluate().unwrap(); assert_eq!( max_res, @@ -1223,7 +1231,7 @@ mod tests { let mut min = MinAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano)) .unwrap(); - min.update_batch(&[b.clone()]).unwrap(); + min.update_batch(&[Arc::clone(&b)]).unwrap(); let min_res = min.evaluate().unwrap(); assert_eq!( min_res, @@ -1235,7 +1243,7 @@ mod tests { let mut max = MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano)) .unwrap(); - max.update_batch(&[b.clone()]).unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); let max_res = max.evaluate().unwrap(); assert_eq!( max_res, diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index f0de7446f6f1..b9d803900f53 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -20,7 +20,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; -pub(crate) mod nth_value; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index e7b199af3743..bcf1c8e510b1 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -163,7 +163,7 @@ pub fn analyze( ) -> Result { let target_boundaries = context.boundaries; - let mut graph = ExprIntervalGraph::try_new(expr.clone(), schema)?; + let mut graph = ExprIntervalGraph::try_new(Arc::clone(expr), schema)?; let columns = collect_columns(expr) .into_iter() diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 6c12acb934be..e483f935b75c 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -42,12 +42,28 @@ use datafusion_common::JoinType; /// - `across_partitions`: A boolean flag indicating whether the constant expression is /// valid across partitions. 
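Before the `ConstExpr` additions below, a note on the mechanical change repeated throughout this diff: `expr.clone()` becomes `Arc::clone(&expr)`. Both compile to the same cheap reference-count increment; the fully qualified form (the style clippy's `clone_on_ref_ptr` lint asks for, presumably the motivation here) makes it obvious at the call site that only the pointer is copied, never the underlying expression. A tiny illustration with a stand-in payload type rather than `dyn PhysicalExpr`:

```rust
use std::sync::Arc;

#[derive(Debug)]
struct Payload {
    name: String,
}

fn main() {
    let expr: Arc<Payload> = Arc::new(Payload { name: "a + b".to_string() });

    // Method syntax: works, but reads the same as a deep clone of `Payload`.
    let via_method = expr.clone();

    // Fully qualified syntax: unambiguous that only the `Arc` handle
    // (pointer + reference count) is cloned, never the payload.
    let via_fn = Arc::clone(&expr);

    assert_eq!(Arc::strong_count(&expr), 3);
    assert!(Arc::ptr_eq(&via_method, &via_fn));
    println!("both clones share {:?}", expr.name);
}
```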
If set to `true`, the constant expression has same value for all partitions. /// If set to `false`, the constant expression may have different values for different partitions. +/// +/// # Example +/// +/// ```rust +/// # use datafusion_physical_expr::ConstExpr; +/// # use datafusion_physical_expr_common::expressions::lit; +/// let col = lit(5); +/// // Create a constant expression from a physical expression ref +/// let const_expr = ConstExpr::from(&col); +/// // create a constant expression from a physical expression +/// let const_expr = ConstExpr::from(col); +/// ``` pub struct ConstExpr { expr: Arc, across_partitions: bool, } impl ConstExpr { + /// Create a new constant expression from a physical expression. + /// + /// Note you can also use `ConstExpr::from` to create a constant expression + /// from a reference as well pub fn new(expr: Arc) -> Self { Self { expr, @@ -85,6 +101,18 @@ impl ConstExpr { } } +impl From> for ConstExpr { + fn from(expr: Arc) -> Self { + Self::new(expr) + } +} + +impl From<&Arc> for ConstExpr { + fn from(expr: &Arc) -> Self { + Self::new(Arc::clone(expr)) + } +} + /// Checks whether `expr` is among in the `const_exprs`. pub fn const_exprs_contains( const_exprs: &[ConstExpr], @@ -267,17 +295,19 @@ impl EquivalenceGroup { } (Some(group_idx), None) => { // Right side is new, extend left side's class: - self.classes[group_idx].push(right.clone()); + self.classes[group_idx].push(Arc::clone(right)); } (None, Some(group_idx)) => { // Left side is new, extend right side's class: - self.classes[group_idx].push(left.clone()); + self.classes[group_idx].push(Arc::clone(left)); } (None, None) => { // None of the expressions is among existing classes. // Create a new equivalence class and extend the group. - self.classes - .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); + self.classes.push(EquivalenceClass::new(vec![ + Arc::clone(left), + Arc::clone(right), + ])); } } } @@ -328,7 +358,7 @@ impl EquivalenceGroup { /// The expression is replaced with the first expression in the equivalence /// class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { - expr.clone() + Arc::clone(&expr) .transform(|expr| { for cls in self.iter() { if cls.contains(&expr) { @@ -429,7 +459,7 @@ impl EquivalenceGroup { .get_equivalence_class(source) .map_or(false, |group| group.contains(expr)) { - return Some(target.clone()); + return Some(Arc::clone(target)); } } } @@ -443,7 +473,7 @@ impl EquivalenceGroup { .into_iter() .map(|child| self.project_expr(mapping, child)) .collect::>>() - .map(|children| expr.clone().with_new_children(children).unwrap()) + .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) } /// Projects this equivalence group according to the given projection mapping. @@ -461,13 +491,13 @@ impl EquivalenceGroup { let mut new_classes = vec![]; for (source, target) in mapping.iter() { if new_classes.is_empty() { - new_classes.push((source, vec![target.clone()])); + new_classes.push((source, vec![Arc::clone(target)])); } if let Some((_, values)) = new_classes.iter_mut().find(|(key, _)| key.eq(source)) { if !physical_exprs_contains(values, target) { - values.push(target.clone()); + values.push(Arc::clone(target)); } } } @@ -515,10 +545,9 @@ impl EquivalenceGroup { // are equal in the resulting table. 
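The new `From<Arc<dyn PhysicalExpr>>` and `From<&Arc<dyn PhysicalExpr>>` impls mean callers can write `ConstExpr::from(expr)` (or `expr.into()`) whether they own the `Arc` or only borrow it, which is exactly what the added doc example demonstrates. The two-impl pattern can be sketched with a simplified wrapper; `ConstWrapper` and the `Expr` alias below are illustrative stand-ins, not DataFusion types:

```rust
use std::sync::Arc;

// Stand-in for `Arc<dyn PhysicalExpr>`.
type Expr = Arc<str>;

/// Simplified stand-in for `ConstExpr`: an expression known to be constant.
struct ConstWrapper {
    expr: Expr,
    across_partitions: bool,
}

impl ConstWrapper {
    fn new(expr: Expr) -> Self {
        Self { expr, across_partitions: false }
    }
}

// Owned conversion: consumes the caller's handle.
impl From<Expr> for ConstWrapper {
    fn from(expr: Expr) -> Self {
        Self::new(expr)
    }
}

// Borrowed conversion: cheaply clones the `Arc`, leaving the caller's handle intact.
impl From<&Expr> for ConstWrapper {
    fn from(expr: &Expr) -> Self {
        Self::new(Arc::clone(expr))
    }
}

fn main() {
    let e: Expr = Arc::from("e");
    let by_ref = ConstWrapper::from(&e); // caller keeps `e`
    let by_val = ConstWrapper::from(e);  // caller gives up `e`
    assert_eq!(&*by_ref.expr, &*by_val.expr);
    assert!(!by_ref.across_partitions && !by_val.across_partitions);
}
```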
if join_type == &JoinType::Inner { for (lhs, rhs) in on.iter() { - let new_lhs = lhs.clone() as _; + let new_lhs = Arc::clone(lhs) as _; // Rewrite rhs to point to the right side of the join: - let new_rhs = rhs - .clone() + let new_rhs = Arc::clone(rhs) .transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() @@ -649,7 +678,7 @@ mod tests { let eq_group = eq_properties.eq_group(); for (expr, expected_eq) in expressions { assert!( - expected_eq.eq(&eq_group.normalize_expr(expr.clone())), + expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))), "error in test: expr: {expr:?}" ); } @@ -669,9 +698,11 @@ mod tests { Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); - let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); - let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + let cls1 = + EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]); + let cls2 = + EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]); + let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]); // lit_true is common assert!(cls1.contains_any(&cls2)); diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 5eb8a19e3d67..83f94057f740 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -145,7 +145,7 @@ mod tests { let col_e = &col("e", &test_schema)?; let col_f = &col("f", &test_schema)?; let col_g = &col("g", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); eq_properties.add_equal_conditions(col_a, col_c)?; let option_asc = SortOptions { @@ -201,11 +201,11 @@ mod tests { let col_f = &col("f", &test_schema)?; let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Define a and f are aliases eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. 
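The inner-join handling above rewrites right-side join columns so their indices point into the combined output schema, and the `add_offset_to_expr` calls further down do the same for sort expressions: the joined schema is the left fields followed by the right fields, so every right-side column index shifts by the left field count. A minimal stand-in of that rewrite (this `Column` struct is illustrative, not DataFusion's):

```rust
/// Minimal stand-in for a physical column reference (name + index into a schema).
#[derive(Debug, PartialEq)]
struct Column {
    name: String,
    index: usize,
}

/// After a join, the output schema is `left fields ++ right fields`, so a
/// column that referred to the right input must have its index shifted by
/// the number of fields on the left.
fn add_offset(col: &Column, left_field_count: usize) -> Column {
    Column {
        name: col.name.clone(),
        index: col.index + left_field_count,
    }
}

fn main() {
    // Right input schema: [x, y]; the left input has 3 fields.
    let right_y = Column { name: "y".into(), index: 1 };
    let shifted = add_offset(&right_y, 3);
    // In the joined schema, `y` now lives at position 3 + 1 = 4.
    assert_eq!(shifted, Column { name: "y".into(), index: 4 });
}
```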
- eq_properties = eq_properties.add_constants([ConstExpr::new(col_e.clone())]); + eq_properties = eq_properties.add_constants([ConstExpr::from(col_e)]); // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -223,7 +223,7 @@ mod tests { let ordering = remaining_exprs .drain(0..n_sort_expr) .map(|expr| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options: options_asc, }) .collect(); @@ -241,7 +241,7 @@ mod tests { in_data .iter() .map(|(expr, options)| { - PhysicalSortRequirement::new((*expr).clone(), *options) + PhysicalSortRequirement::new(Arc::clone(*expr), *options) }) .collect() } @@ -253,7 +253,7 @@ mod tests { in_data .iter() .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), + expr: Arc::clone(*expr), options: *options, }) .collect() @@ -276,7 +276,7 @@ mod tests { in_data .iter() .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), + expr: Arc::clone(expr), options: *options, }) .collect() @@ -309,9 +309,9 @@ mod tests { .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) .collect::>>()?; let projected_batch = if projected_values.is_empty() { - RecordBatch::new_empty(output_schema.clone()) + RecordBatch::new_empty(Arc::clone(&output_schema)) } else { - RecordBatch::try_new(output_schema.clone(), projected_values)? + RecordBatch::try_new(Arc::clone(&output_schema), projected_values)? }; let projected_eq = @@ -399,7 +399,7 @@ mod tests { let vals: Vec = (0..n_row).collect::>(); let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; - columns.push(unique_col.clone()); + columns.push(Arc::clone(&unique_col)); // Create a new schema with the added unique column let unique_col_name = "unique"; @@ -414,7 +414,7 @@ mod tests { let schema = Arc::new(Schema::new(fields)); // Create a new batch with the added column - let new_batch = RecordBatch::try_new(schema.clone(), columns)?; + let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; // Add the unique column to the required ordering to ensure deterministic results required_ordering.push(PhysicalSortExpr { @@ -454,7 +454,7 @@ mod tests { let col = expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); if let Some(res) = &existing_vec[idx] { - return Some(res.clone()); + return Some(Arc::clone(res)); } } None @@ -516,13 +516,13 @@ mod tests { // Fill columns based on equivalence groups for eq_group in eq_properties.eq_group.iter() { let representative_array = - get_representative_arr(eq_group, &schema_vec, schema.clone()) + get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); for expr in eq_group.iter() { let col = expr.as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(representative_array.clone()); + schema_vec[idx] = Some(Arc::clone(&representative_array)); } } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index ac9d64e486ac..c4b8a5c46563 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -174,7 +174,7 @@ impl OrderingEquivalenceClass { pub fn add_offset(&mut self, offset: usize) { for ordering in self.orderings.iter_mut() { for sort_expr in ordering { - sort_expr.expr = 
add_offset_to_expr(sort_expr.expr.clone(), offset); + sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset); } } } @@ -264,12 +264,14 @@ mod tests { }, ]; // finer ordering satisfies, crude ordering should return true - let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); + let mut eq_properties_finer = + EquivalenceProperties::new(Arc::clone(&input_schema)); eq_properties_finer.oeq_class.push(finer.clone()); assert!(eq_properties_finer.ordering_satisfy(&crude)); // Crude ordering doesn't satisfy finer ordering. should return false - let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); + let mut eq_properties_crude = + EquivalenceProperties::new(Arc::clone(&input_schema)); eq_properties_crude.oeq_class.push(crude.clone()); assert!(!eq_properties_crude.ordering_satisfy(&finer)); Ok(()) @@ -307,9 +309,9 @@ mod tests { &DFSchema::empty(), )?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let options = SortOptions { descending: false, @@ -541,7 +543,7 @@ mod tests { for (orderings, eq_group, constants, reqs, expected) in test_cases { let err_msg = format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); let orderings = convert_to_orderings(&orderings); eq_properties.add_new_orderings(orderings); let eq_group = eq_group @@ -556,7 +558,7 @@ mod tests { let constants = constants .into_iter() - .map(|expr| ConstExpr::new(expr.clone()).with_across_partitions(true)); + .map(|expr| ConstExpr::from(expr).with_across_partitions(true)); eq_properties = eq_properties.add_constants(constants); let reqs = convert_to_sort_exprs(&reqs); @@ -717,7 +719,7 @@ mod tests { let required = cols .into_iter() .map(|(expr, options)| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options, }) .collect::>(); @@ -769,7 +771,7 @@ mod tests { let requirement = exprs .into_iter() .map(|expr| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options: SORT_OPTIONS, }) .collect::>(); @@ -842,7 +844,7 @@ mod tests { let requirement = exprs .into_iter() .map(|expr| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options: SORT_OPTIONS, }) .collect::>(); diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index b5ac149d8b71..f1ce3f04489e 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -56,8 +56,7 @@ impl ProjectionMapping { .enumerate() .map(|(expr_idx, (expression, name))| { let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - expression - .clone() + Arc::clone(expression) .transform_down(|e| match e.as_any().downcast_ref::() { Some(col) => { // Sometimes, an expression and its name in the input_schema @@ -107,7 +106,7 @@ impl ProjectionMapping { self.map .iter() .find(|(source, _)| source.eq(expr)) - .map(|(_, target)| target.clone()) + .map(|(_, target)| Arc::clone(target)) } } @@ -149,24 +148,24 @@ mod tests { let col_e = &col("e", &schema)?; let col_ts = &col("ts", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + 
Arc::clone(col_b), )) as Arc; let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(col_b), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let b_plus_e = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(col_b), Operator::Plus, - col_e.clone(), + Arc::clone(col_e), )) as Arc; let c_plus_d = Arc::new(BinaryExpr::new( - col_c.clone(), + Arc::clone(col_c), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let option_asc = SortOptions { @@ -587,14 +586,14 @@ mod tests { for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(&orderings); eq_properties.add_new_orderings(orderings); let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -643,15 +642,15 @@ mod tests { let col_c = &col("c", &schema)?; let col_ts = &col("ts", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); let round_c = &create_physical_expr( &test_fun, - &[col_c.clone()], + &[Arc::clone(col_c)], &schema, &[], &DFSchema::empty(), @@ -670,7 +669,7 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -680,9 +679,9 @@ mod tests { let col_c_new = &col("c_new", &output_schema)?; let col_round_c_res = &col("round_c_res", &output_schema)?; let a_new_plus_b_new = Arc::new(BinaryExpr::new( - col_a_new.clone(), + Arc::clone(col_a_new), Operator::Plus, - col_b_new.clone(), + Arc::clone(col_b_new), )) as Arc; let test_cases = vec![ @@ -793,7 +792,7 @@ mod tests { ]; for (idx, (orderings, expected)) in test_cases.iter().enumerate() { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(orderings); eq_properties.add_new_orderings(orderings); @@ -801,7 +800,7 @@ mod tests { let expected = convert_to_orderings(expected); let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); + eq_properties.project(&projection_mapping, Arc::clone(&output_schema)); let orderings = projected_eq.oeq_class(); let err_msg = format!( @@ -834,9 +833,9 @@ mod tests { let col_e = &col("e", &schema)?; let col_f = &col("f", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let option_asc = SortOptions { @@ -851,7 +850,7 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -936,7 +935,7 @@ mod tests { ), ]; for (orderings, equal_columns, 
expected) in test_cases { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); for (lhs, rhs) in equal_columns { eq_properties.add_equal_conditions(lhs, rhs)?; } @@ -947,7 +946,7 @@ mod tests { let expected = convert_to_orderings(&expected); let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); + eq_properties.project(&projection_mapping, Arc::clone(&output_schema)); let orderings = projected_eq.oeq_class(); let err_msg = format!( @@ -1006,7 +1005,7 @@ mod tests { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) .collect::>(); let (projected_batch, projected_eq) = apply_projection( proj_exprs.clone(), @@ -1084,7 +1083,7 @@ mod tests { for proj_exprs in proj_exprs.iter().combinations(n_req) { let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) .collect::>(); let (projected_batch, projected_eq) = apply_projection( proj_exprs.clone(), @@ -1097,7 +1096,7 @@ mod tests { let projected_exprs = projection_mapping .iter() - .map(|(_source, target)| target.clone()) + .map(|(_source, target)| Arc::clone(target)) .collect::>(); for n_req in 0..=projected_exprs.len() { @@ -1105,7 +1104,7 @@ mod tests { let requirement = exprs .into_iter() .map(|expr| PhysicalSortExpr { - expr: expr.clone(), + expr: Arc::clone(expr), options: SORT_OPTIONS, }) .collect::>(); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index e3a2d1c753ca..d9d19c0bcf47 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -213,13 +213,13 @@ impl EquivalenceProperties { // Left expression is constant, add right as constant if !const_exprs_contains(&self.constants, right) { self.constants - .push(ConstExpr::new(right.clone()).with_across_partitions(true)); + .push(ConstExpr::from(right).with_across_partitions(true)); } } else if self.is_expr_constant(right) { // Right expression is constant, add left as constant if !const_exprs_contains(&self.constants, left) { self.constants - .push(ConstExpr::new(left.clone()).with_across_partitions(true)); + .push(ConstExpr::from(left).with_across_partitions(true)); } } @@ -300,7 +300,7 @@ impl EquivalenceProperties { { if !const_exprs_contains(&self.constants, &expr) { let const_expr = - ConstExpr::new(expr).with_across_partitions(across_partitions); + ConstExpr::from(expr).with_across_partitions(across_partitions); self.constants.push(const_expr); } } @@ -357,7 +357,7 @@ impl EquivalenceProperties { constant_exprs.extend( self.constants .iter() - .map(|const_expr| const_expr.expr().clone()), + .map(|const_expr| Arc::clone(const_expr.expr())), ); let constants_normalized = self.eq_group.normalize_exprs(constant_exprs); // Prune redundant sections in the requirement: @@ -404,7 +404,7 @@ impl EquivalenceProperties { // we add column `a` as constant to the algorithm state. This enables us // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. 
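The comment above is the heart of the ordering-satisfaction check: scanning a lexicographic requirement left to right, any key that has already been matched (here `a`) can be treated like a constant for the entries that follow, and genuinely constant columns never break an ordering. A heavily simplified, name-based sketch of that idea, ignoring sort options, expressions, and equivalence classes:

```rust
use std::collections::HashSet;

/// Does `existing` (an ordering we already have) satisfy `required`,
/// given a set of columns known to hold a single constant value?
fn ordering_satisfied(
    existing: &[&str],
    required: &[&str],
    constants: &HashSet<&str>,
) -> bool {
    // Columns we may ignore while matching: true constants plus any
    // requirement entry already satisfied (the trick described above).
    let mut ignorable: HashSet<&str> = constants.iter().copied().collect();
    let mut pos = 0; // next unmatched position in `existing`

    for col in required {
        if ignorable.contains(col) {
            continue; // pinned to a single value, nothing to check
        }
        // Existing sort keys that are ignorable cannot break the order.
        while pos < existing.len() && ignorable.contains(&existing[pos]) {
            pos += 1;
        }
        if pos < existing.len() && existing[pos] == *col {
            ignorable.insert(*col); // effectively constant for later entries
            pos += 1;
        } else {
            return false;
        }
    }
    true
}

fn main() {
    let constants: HashSet<&str> = ["e"].into_iter().collect();
    // [a, b] satisfies [a, e, b]: `e` is constant, and once `a` is matched
    // it behaves like a constant for everything after it.
    assert!(ordering_satisfied(&["a", "b"], &["a", "e", "b"], &constants));
    assert!(ordering_satisfied(&["a", "b"], &["a", "a", "b"], &constants));
    assert!(!ordering_satisfied(&["a", "b"], &["b", "a"], &constants));
    println!("ordering checks ok");
}
```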
eq_properties = eq_properties - .add_constants(std::iter::once(ConstExpr::new(normalized_req.expr))); + .add_constants(std::iter::once(ConstExpr::from(normalized_req.expr))); } true } @@ -424,11 +424,11 @@ impl EquivalenceProperties { fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { let ExprProperties { sort_properties, .. - } = self.get_expr_properties(req.expr.clone()); + } = self.get_expr_properties(Arc::clone(&req.expr)); match sort_properties { SortProperties::Ordered(options) => { let sort_expr = PhysicalSortExpr { - expr: req.expr.clone(), + expr: Arc::clone(&req.expr), options, }; sort_expr.satisfy(req, self.schema()) @@ -572,7 +572,7 @@ impl EquivalenceProperties { && cast_expr.is_bigger_cast(expr_type) { res.push(PhysicalSortExpr { - expr: r_expr.clone(), + expr: Arc::clone(&r_expr), options: sort_expr.options, }); } @@ -715,8 +715,9 @@ impl EquivalenceProperties { map: mapping .iter() .map(|(source, target)| { - let normalized_source = self.eq_group.normalize_expr(source.clone()); - (normalized_source, target.clone()) + let normalized_source = + self.eq_group.normalize_expr(Arc::clone(source)); + (normalized_source, Arc::clone(target)) }) .collect(), } @@ -758,7 +759,7 @@ impl EquivalenceProperties { }) .flat_map(|(options, relevant_deps)| { let sort_expr = PhysicalSortExpr { - expr: target.clone(), + expr: Arc::clone(target), options, }; // Generate dependent orderings (i.e. prefixes for `sort_expr`): @@ -832,7 +833,7 @@ impl EquivalenceProperties { { // Expression evaluates to single value projected_constants - .push(ConstExpr::new(target.clone()).with_across_partitions(true)); + .push(ConstExpr::from(target).with_across_partitions(true)); } } projected_constants @@ -889,11 +890,11 @@ impl EquivalenceProperties { .flat_map(|&idx| { let ExprProperties { sort_properties, .. - } = eq_properties.get_expr_properties(exprs[idx].clone()); + } = eq_properties.get_expr_properties(Arc::clone(&exprs[idx])); match sort_properties { SortProperties::Ordered(options) => Some(( PhysicalSortExpr { - expr: exprs[idx].clone(), + expr: Arc::clone(&exprs[idx]), options, }, idx, @@ -903,7 +904,7 @@ impl EquivalenceProperties { let options = SortOptions::default(); Some(( PhysicalSortExpr { - expr: exprs[idx].clone(), + expr: Arc::clone(&exprs[idx]), options, }, idx, @@ -925,8 +926,8 @@ impl EquivalenceProperties { // Note that these expressions are not properly "constants". This is just // an implementation strategy confined to this function. for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { - eq_properties = eq_properties - .add_constants(std::iter::once(ConstExpr::new(expr.clone()))); + eq_properties = + eq_properties.add_constants(std::iter::once(ConstExpr::from(expr))); search_indices.shift_remove(idx); } // Add new ordered section to the state. @@ -954,9 +955,9 @@ impl EquivalenceProperties { let const_exprs = self .constants .iter() - .map(|const_expr| const_expr.expr().clone()); + .map(|const_expr| Arc::clone(const_expr.expr())); let normalized_constants = self.eq_group.normalize_exprs(const_exprs); - let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); is_constant_recurse(&normalized_constants, &normalized_expr) } @@ -1022,7 +1023,9 @@ fn update_properties( Interval::make_unbounded(&node.expr.data_type(eq_properties.schema())?)? 
} // Now, check what we know about orderings: - let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); + let normalized_expr = eq_properties + .eq_group + .normalize_expr(Arc::clone(&node.expr)); if eq_properties.is_expr_constant(&normalized_expr) { node.data.sort_properties = SortProperties::Singleton; } else if let Some(options) = eq_properties @@ -1108,7 +1111,7 @@ fn referred_dependencies( .keys() .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) { - let key = ExprWrapper(sort_expr.expr.clone()); + let key = ExprWrapper(Arc::clone(&sort_expr.expr)); expr_to_sort_exprs .entry(key) .or_default() @@ -1484,25 +1487,25 @@ mod tests { Field::new("c", DataType::Int64, true), ])); - let input_properties = EquivalenceProperties::new(input_schema.clone()); + let input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); let col_a = col("a", &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), + (Arc::clone(&col_a), "a1".to_string()), + (Arc::clone(&col_a), "a2".to_string()), + (Arc::clone(&col_a), "a3".to_string()), + (Arc::clone(&col_a), "a4".to_string()), ]; let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; let out_schema = output_schema(&projection_mapping, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), + (Arc::clone(&col_a), "a1".to_string()), + (Arc::clone(&col_a), "a2".to_string()), + (Arc::clone(&col_a), "a3".to_string()), + (Arc::clone(&col_a), "a4".to_string()), ]; let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; @@ -1532,8 +1535,8 @@ mod tests { let col_b = &col("b", &schema)?; let col_c = &col("c", &schema)?; let offset = schema.fields.len(); - let col_a2 = &add_offset_to_expr(col_a.clone(), offset); - let col_b2 = &add_offset_to_expr(col_b.clone(), offset); + let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset); + let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset); let option_asc = SortOptions { descending: false, nulls_first: false, @@ -1577,8 +1580,8 @@ mod tests { ), ]; for (left_orderings, right_orderings, expected) in test_cases { - let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); - let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); + let mut left_eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + let mut right_eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let left_orderings = convert_to_orderings(&left_orderings); let right_orderings = convert_to_orderings(&right_orderings); let expected = convert_to_orderings(&expected); @@ -1626,17 +1629,17 @@ mod tests { let col_b = col("b", &schema)?; let col_d = col("d", &schema)?; let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(&col_b), Operator::Plus, - col_d.clone(), + Arc::clone(&col_d), )) as Arc; - let constants = vec![col_a.clone(), col_b.clone()]; - let expr = b_plus_d.clone(); + let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b)]; + let expr = Arc::clone(&b_plus_d); assert!(!is_constant_recurse(&constants, &expr)); - let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; - let expr = b_plus_d.clone(); + let constants = 
vec![Arc::clone(&col_a), Arc::clone(&col_b), Arc::clone(&col_d)]; + let expr = Arc::clone(&b_plus_d); assert!(is_constant_recurse(&constants, &expr)); Ok(()) } @@ -1726,11 +1729,11 @@ mod tests { eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr)?; let others = vec![ vec![PhysicalSortExpr { - expr: col_b_expr.clone(), + expr: Arc::clone(&col_b_expr), options: sort_options, }], vec![PhysicalSortExpr { - expr: col_c_expr.clone(), + expr: Arc::clone(&col_c_expr), options: sort_options, }], ]; @@ -1739,11 +1742,11 @@ mod tests { let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); expected_eqs.add_new_orderings([ vec![PhysicalSortExpr { - expr: col_b_expr.clone(), + expr: Arc::clone(&col_b_expr), options: sort_options, }], vec![PhysicalSortExpr { - expr: col_c_expr.clone(), + expr: Arc::clone(&col_c_expr), options: sort_options, }], ]); @@ -1766,7 +1769,7 @@ mod tests { ]); let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; + let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); eq_properties.add_new_orderings([vec![ PhysicalSortExpr { @@ -1784,11 +1787,11 @@ mod tests { result, vec![ PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: sort_options_not }, PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: sort_options } ] @@ -1801,7 +1804,7 @@ mod tests { ]); let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; + let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); eq_properties.add_new_orderings([ vec![PhysicalSortExpr { @@ -1825,11 +1828,11 @@ mod tests { result, vec![ PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: sort_options_not }, PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: sort_options } ] @@ -1890,11 +1893,11 @@ mod tests { // [b ASC], [d ASC] eq_properties.add_new_orderings(vec![ vec![PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: option_asc, }], vec![PhysicalSortExpr { - expr: col_d.clone(), + expr: Arc::clone(col_d), options: option_asc, }], ]); @@ -1903,22 +1906,22 @@ mod tests { // d + b ( Arc::new(BinaryExpr::new( - col_d.clone(), + Arc::clone(col_d), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc, SortProperties::Ordered(option_asc), ), // b - (col_b.clone(), SortProperties::Ordered(option_asc)), + (Arc::clone(col_b), SortProperties::Ordered(option_asc)), // a - (col_a.clone(), SortProperties::Ordered(option_asc)), + (Arc::clone(col_a), SortProperties::Ordered(option_asc)), // a + c ( Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_c.clone(), + Arc::clone(col_c), )), SortProperties::Unordered, ), @@ -1929,7 +1932,7 @@ mod tests { .iter() .flat_map(|ordering| ordering.first().cloned()) .collect::>(); - let expr_props = eq_properties.get_expr_properties(expr.clone()); + let expr_props = eq_properties.get_expr_properties(Arc::clone(&expr)); let err_msg = format!( "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", expr, expected, expr_props.sort_properties @@ -1987,7 +1990,7 @@ mod tests { .iter() .zip(ordering.iter()) .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: exprs[idx].clone(), + expr: Arc::clone(&exprs[idx]), 
options: sort_expr.options, }) .collect::>(); @@ -2034,9 +2037,9 @@ mod tests { let col_h = &col("h", &test_schema)?; // a + d let a_plus_d = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let option_asc = SortOptions { @@ -2050,11 +2053,11 @@ mod tests { // [d ASC, h DESC] also satisfies schema. eq_properties.add_new_orderings([vec![ PhysicalSortExpr { - expr: col_d.clone(), + expr: Arc::clone(col_d), options: option_asc, }, PhysicalSortExpr { - expr: col_h.clone(), + expr: Arc::clone(col_h), options: option_desc, }, ]]); @@ -2143,7 +2146,7 @@ mod tests { let col_h = &col("h", &test_schema)?; // Add column h as constant - eq_properties = eq_properties.add_constants(vec![ConstExpr::new(col_h.clone())]); + eq_properties = eq_properties.add_constants(vec![ConstExpr::from(col_h)]); let test_cases = vec![ // TEST CASE 1 @@ -2382,20 +2385,21 @@ mod tests { Field::new("b", DataType::Utf8, true), Field::new("c", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); - let base_properties = EquivalenceProperties::new(schema.clone()).with_reorder( - ["a", "b", "c"] - .into_iter() - .map(|c| { - col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { - expr, - options: SortOptions { - descending: false, - nulls_first: true, - }, + let base_properties = EquivalenceProperties::new(Arc::clone(&schema)) + .with_reorder( + ["a", "b", "c"] + .into_iter() + .map(|c| { + col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { + expr, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }) }) - }) - .collect::>>()?, - ); + .collect::>>()?, + ); struct TestCase { name: &'static str, @@ -2414,8 +2418,11 @@ mod tests { TestCase { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order - constants: vec![col_b.clone()], - equal_conditions: vec![[cast_c.clone(), col_a.clone()]], + constants: vec![Arc::clone(&col_b)], + equal_conditions: vec![[ + Arc::clone(&cast_c) as Arc, + Arc::clone(&col_a), + ]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -2425,7 +2432,10 @@ mod tests { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![col_b], - equal_conditions: vec![[col_a.clone(), cast_c.clone()]], + equal_conditions: vec![[ + Arc::clone(&col_a), + Arc::clone(&cast_c) as Arc, + ]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -2434,7 +2444,10 @@ mod tests { // b is not constant anymore constants: vec![], // a and c are still compatible, but this is irrelevant since the original ordering is (a, b, c) - equal_conditions: vec![[cast_c.clone(), col_a.clone()]], + equal_conditions: vec![[ + Arc::clone(&cast_c) as Arc, + Arc::clone(&col_a), + ]], sort_columns: &["c"], should_satisfy_ordering: false, }, @@ -2443,7 +2456,7 @@ mod tests { for case in cases { let mut properties = base_properties .clone() - .add_constants(case.constants.into_iter().map(ConstExpr::new)); + .add_constants(case.constants.into_iter().map(ConstExpr::from)); for [left, right] in &case.equal_conditions { properties.add_equal_conditions(left, right)? 
} diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index d19279c20d10..f1e40575bc64 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -328,9 +328,9 @@ impl PhysicalExpr for BinaryExpr { children: Vec>, ) -> Result> { Ok(Arc::new(BinaryExpr::new( - children[0].clone(), + Arc::clone(&children[0]), self.op.clone(), - children[1].clone(), + Arc::clone(&children[1]), ))) } @@ -1493,8 +1493,11 @@ mod tests { let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); apply_arithmetic::( - schema.clone(), - vec![a.clone(), b.clone()], + Arc::clone(&schema), + vec![ + Arc::clone(&a) as Arc, + Arc::clone(&b) as Arc, + ], Operator::Minus, Int32Array::from(vec![0, 0, 1, 4, 11]), )?; @@ -2376,8 +2379,8 @@ mod tests { expected: BooleanArray, ) -> Result<()> { let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; - let data: Vec = vec![left.clone(), right.clone()]; - let batch = RecordBatch::try_new(schema.clone(), data)?; + let data: Vec = vec![Arc::clone(left), Arc::clone(right)]; + let batch = RecordBatch::try_new(Arc::clone(schema), data)?; let result = op .evaluate(&batch)? .into_array(batch.num_rows()) @@ -3471,8 +3474,8 @@ mod tests { expected: ArrayRef, ) -> Result<()> { let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; - let data: Vec = vec![left.clone(), right.clone()]; - let batch = RecordBatch::try_new(schema.clone(), data)?; + let data: Vec = vec![Arc::clone(left), Arc::clone(right)]; + let batch = RecordBatch::try_new(Arc::clone(schema), data)?; let result = arithmetic_op .evaluate(&batch)? .into_array(batch.num_rows()) @@ -3767,15 +3770,15 @@ mod tests { let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; let right = Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef; - let mut result = bitwise_and_dyn(left.clone(), right.clone())?; + let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(0), None, Some(3)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_or_dyn(left.clone(), right.clone())?; + result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(13), None, Some(15)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_xor_dyn(left.clone(), right.clone())?; + result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(13), None, Some(12)]); assert_eq!(result.as_ref(), &expected); @@ -3783,15 +3786,15 @@ mod tests { Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; let right = Arc::new(UInt32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef; - let mut result = bitwise_and_dyn(left.clone(), right.clone())?; + let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(0), None, Some(3)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_or_dyn(left.clone(), right.clone())?; + result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(13), None, Some(15)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_xor_dyn(left.clone(), right.clone())?; + result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(13), None, Some(12)]); assert_eq!(result.as_ref(), &expected); @@ 
-3803,24 +3806,26 @@ mod tests { let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef; let modules = Arc::new(Int32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef; - let mut result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let mut result = + bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = Int32Array::from(vec![Some(8), None, Some(2560)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_shift_right_dyn(result.clone(), modules.clone())?; + result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?; assert_eq!(result.as_ref(), &input); let input = Arc::new(UInt32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef; let modules = Arc::new(UInt32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef; - let mut result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let mut result = + bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = UInt32Array::from(vec![Some(8), None, Some(2560)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_shift_right_dyn(result.clone(), modules.clone())?; + result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?; assert_eq!(result.as_ref(), &input); Ok(()) } @@ -3829,14 +3834,14 @@ mod tests { fn bitwise_shift_array_overflow_test() -> Result<()> { let input = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef; let modules = Arc::new(Int32Array::from(vec![Some(100)])) as ArrayRef; - let result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = Int32Array::from(vec![Some(32)]); assert_eq!(result.as_ref(), &expected); let input = Arc::new(UInt32Array::from(vec![Some(2)])) as ArrayRef; let modules = Arc::new(UInt32Array::from(vec![Some(100)])) as ArrayRef; - let result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = UInt32Array::from(vec![Some(32)]); assert_eq!(result.as_ref(), &expected); @@ -3973,9 +3978,12 @@ mod tests { Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef; // Casting Dictionary to Int32 - let casted = - to_result_type_array(&Operator::Plus, dictionary.clone(), &DataType::Int32) - .unwrap(); + let casted = to_result_type_array( + &Operator::Plus, + Arc::clone(&dictionary), + &DataType::Int32, + ) + .unwrap(); assert_eq!( &casted, &(Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)])) @@ -3985,16 +3993,19 @@ mod tests { // Array has same datatype as result type, no casting let casted = to_result_type_array( &Operator::Plus, - dictionary.clone(), + Arc::clone(&dictionary), dictionary.data_type(), ) .unwrap(); assert_eq!(&casted, &dictionary); // Not numerical operator, no casting - let casted = - to_result_type_array(&Operator::Eq, dictionary.clone(), &DataType::Int32) - .unwrap(); + let casted = to_result_type_array( + &Operator::Eq, + Arc::clone(&dictionary), + &DataType::Int32, + ) + .unwrap(); assert_eq!(&casted, &dictionary); } } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 08d8cd441334..cd73c5cb579c 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -173,8 +173,8 @@ impl CaseExpr { if let Some(e) = &self.else_expr { // keep `else_expr`'s data type and return type 
consistent - let expr = try_cast(e.clone(), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| e.clone()); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); // null and unmatched tuples should be assigned else value remainder = or(&base_nulls, &remainder)?; let else_ = expr @@ -246,8 +246,8 @@ impl CaseExpr { if let Some(e) = &self.else_expr { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(e.clone(), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| e.clone()); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); let else_ = expr .evaluate_selection(batch, &remainder)? .into_array(batch.num_rows())?; @@ -870,7 +870,7 @@ mod tests { ); assert!(expr.is_ok()); let result_type = expr.unwrap().data_type(schema.as_ref())?; - assert_eq!(DataType::Float64, result_type); + assert_eq!(Float64, result_type); Ok(()) } @@ -887,26 +887,26 @@ mod tests { let expr1 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; let expr2 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; let expr3 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), - vec![(when1.clone(), then1.clone()), (when2, then2)], + vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)], None, &schema, )?; @@ -943,15 +943,14 @@ mod tests { let expr = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; - let expr2 = expr - .clone() + let expr2 = Arc::clone(&expr) .transform(|e| { let transformed = match e.as_any().downcast_ref::() { @@ -972,8 +971,7 @@ mod tests { .data() .unwrap(); - let expr3 = expr - .clone() + let expr3 = Arc::clone(&expr) .transform_down(|e| { let transformed = match e.as_any().downcast_ref::() { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 53c790ff6b54..8a3885030b9d 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -389,7 +389,7 @@ impl PhysicalExpr for InListExpr { ) -> Result> { // assume the static_filter will not change during the rewrite process Ok(Arc::new(InListExpr::new( - children[0].clone(), + Arc::clone(&children[0]), children[1..].to_vec(), self.negated, self.static_filter.clone(), @@ -540,7 +540,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -551,7 +551,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -562,7 +562,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -573,7 +573,7 @@ mod tests { list, &true, 
vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -598,7 +598,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -608,7 +608,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -623,7 +623,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -633,7 +633,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -654,7 +654,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -665,7 +665,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -676,7 +676,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -687,7 +687,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -714,7 +714,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(false), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -725,7 +725,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(true), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -736,7 +736,7 @@ mod tests { list, &false, vec![Some(true), None, None, None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -747,7 +747,7 @@ mod tests { list, &true, vec![Some(false), None, None, None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -758,7 +758,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(true), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -769,7 +769,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(false), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -780,7 +780,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(false), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -791,7 +791,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(true), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -812,7 +812,7 @@ mod tests { list, &false, vec![Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -823,7 +823,7 @@ mod tests { list, &true, vec![Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -834,7 +834,7 @@ mod tests { list, &false, vec![Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -845,7 +845,7 @@ mod tests { list, &true, vec![Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -869,7 +869,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -883,7 +883,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -898,7 +898,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -913,7 +913,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -937,7 +937,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -951,7 +951,7 @@ mod tests { list, &true, vec![Some(false), 
Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -966,7 +966,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -981,7 +981,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1008,7 +1008,7 @@ mod tests { list, &false, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); // expression: "a not in (100,200) @@ -1018,7 +1018,7 @@ mod tests { list, &true, vec![Some(false), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1029,7 +1029,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); // expression: "a not in (200,NULL), the data type of list is INT32 AND NULL @@ -1038,7 +1038,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1049,7 +1049,7 @@ mod tests { list, &false, vec![Some(true), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1060,7 +1060,7 @@ mod tests { list, &true, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1073,7 +1073,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1082,7 +1082,7 @@ mod tests { list, &true, vec![Some(false), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1168,7 +1168,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1177,7 +1177,7 @@ mod tests { list.clone(), &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); Ok(()) @@ -1219,13 +1219,13 @@ mod tests { vec![Arc::new(a), Arc::new(b), Arc::new(c)], )?; - let list = vec![col_b.clone(), col_c.clone()]; + let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)]; in_list!( batch, list.clone(), &false, vec![Some(false), Some(true), None, Some(true), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1234,7 +1234,7 @@ mod tests { list, &true, vec![Some(true), Some(false), None, Some(false), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1262,22 +1262,22 @@ mod tests { // static_filter has no nulls let list = vec![lit(1_i64), lit(2_i64)]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false); + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); // static_filter has nulls let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); - let list = vec![c1_nullable.clone()]; - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true); + let list = vec![Arc::clone(&c1_nullable)]; + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); - let list = vec![c2_non_nullable.clone()]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); + let list = vec![Arc::clone(&c2_non_nullable)]; + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); 
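The `in_list` tests above encode SQL's three-valued logic: `a IN (...)` is true on a match, false only when `a` is non-null and the list contains no NULL, and NULL when `a` is NULL or when nothing matched but the list contains a NULL; `NOT IN` simply negates the known cases. A scalar model of that truth table, independent of the Arrow kernels the real implementation uses:

```rust
/// SQL-style `value IN (list)` over nullable i32 values.
/// `None` models SQL NULL; the result is itself three-valued.
fn sql_in_list(value: Option<i32>, list: &[Option<i32>], negated: bool) -> Option<bool> {
    let result = match value {
        None => None, // NULL IN (...) is always NULL
        Some(v) => {
            if list.iter().any(|item| *item == Some(v)) {
                Some(true) // definite match
            } else if list.iter().any(|item| item.is_none()) {
                None // no match, but a NULL in the list means "unknown"
            } else {
                Some(false) // definitely not in the list
            }
        }
    };
    // NOT IN negates the known cases and leaves NULL as NULL.
    if negated {
        result.map(|b| !b)
    } else {
        result
    }
}

fn main() {
    let list = [Some(100), Some(200), None];
    assert_eq!(sql_in_list(Some(100), &list, false), Some(true));
    assert_eq!(sql_in_list(Some(300), &list, false), None); // might equal the NULL
    assert_eq!(sql_in_list(None, &list, false), None);
    assert_eq!(sql_in_list(Some(300), &[Some(100)], false), Some(false));
    assert_eq!(sql_in_list(Some(300), &[Some(100)], true), Some(true));
    println!("in-list truth table ok");
}
```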
- let list = vec![c2_non_nullable.clone(), c2_non_nullable.clone()]; - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false); + let list = vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)]; + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); Ok(()) } @@ -1370,7 +1370,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1382,7 +1382,7 @@ mod tests { list.clone(), &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1402,7 +1402,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1414,7 +1414,7 @@ mod tests { list.clone(), &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 1918f0891fff..9f7438d13e05 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -73,9 +73,11 @@ impl PhysicalExpr for IsNotNullExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; match arg { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( - compute::is_not_null(array.as_ref())?, - ))), + ColumnarValue::Array(array) => { + let is_null = super::is_null::compute_is_null(array)?; + let is_not_null = compute::not(&is_null)?; + Ok(ColumnarValue::Array(Arc::new(is_not_null))) + } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(!scalar.is_null())), )), @@ -90,7 +92,7 @@ impl PhysicalExpr for IsNotNullExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(IsNotNullExpr::new(children[0].clone()))) + Ok(Arc::new(IsNotNullExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -120,6 +122,8 @@ mod tests { array::{BooleanArray, StringArray}, datatypes::*, }; + use arrow_array::{Array, Float64Array, Int32Array, UnionArray}; + use arrow_buffer::ScalarBuffer; use datafusion_common::cast::as_boolean_array; #[test] @@ -143,4 +147,48 @@ mod tests { Ok(()) } + + #[test] + fn union_is_not_null_op() { + // union of [{A=1}, {A=}, {B=1.1}, {B=1.2}, {B=}] + let int_array = Int32Array::from(vec![Some(1), None, None, None, None]); + let float_array = + Float64Array::from(vec![None, None, Some(1.1), Some(1.2), None]); + let type_ids = [0, 0, 1, 1, 1].into_iter().collect::>(); + + let children = vec![Arc::new(int_array) as Arc, Arc::new(float_array)]; + + let union_fields: UnionFields = [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + ] + .into_iter() + .collect(); + + let array = + UnionArray::try_new(union_fields.clone(), type_ids, None, children).unwrap(); + + let field = Field::new( + "my_union", + DataType::Union(union_fields, UnionMode::Sparse), + true, + ); + + let schema = Schema::new(vec![field]); + let expr = is_not_null(col("my_union", &schema).unwrap()).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap(); + + // expression: "a is not null" + let actual = expr + .evaluate(&batch) + .unwrap() + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let actual = as_boolean_array(&actual).unwrap(); + + let expected = &BooleanArray::from(vec![true, false, true, true, false]); 
+ + assert_eq!(expected, actual); + } } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 3430efcd7635..e2dc941e26bc 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -25,6 +25,9 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use arrow_array::{Array, ArrayRef, BooleanArray, Int8Array, UnionArray}; +use arrow_buffer::{BooleanBuffer, ScalarBuffer}; +use arrow_ord::cmp; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; @@ -74,9 +77,9 @@ impl PhysicalExpr for IsNullExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; match arg { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( - compute::is_null(array.as_ref())?, - ))), + ColumnarValue::Array(array) => { + Ok(ColumnarValue::Array(Arc::new(compute_is_null(array)?))) + } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(scalar.is_null())), )), @@ -91,7 +94,7 @@ impl PhysicalExpr for IsNullExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(IsNullExpr::new(children[0].clone()))) + Ok(Arc::new(IsNullExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -100,6 +103,55 @@ impl PhysicalExpr for IsNullExpr { } } +/// workaround , +/// this can be replaced with a direct call to `arrow::compute::is_null` once it's fixed. +pub(crate) fn compute_is_null(array: ArrayRef) -> Result { + if let Some(union_array) = array.as_any().downcast_ref::() { + if let Some(offsets) = union_array.offsets() { + dense_union_is_null(union_array, offsets) + } else { + sparse_union_is_null(union_array) + } + } else { + compute::is_null(array.as_ref()).map_err(Into::into) + } +} + +fn dense_union_is_null( + union_array: &UnionArray, + offsets: &ScalarBuffer, +) -> Result { + let child_arrays = (0..union_array.type_names().len()) + .map(|type_id| { + compute::is_null(&union_array.child(type_id as i8)).map_err(Into::into) + }) + .collect::>>()?; + + let buffer: BooleanBuffer = offsets + .iter() + .zip(union_array.type_ids()) + .map(|(offset, type_id)| child_arrays[*type_id as usize].value(*offset as usize)) + .collect(); + + Ok(BooleanArray::new(buffer, None)) +} + +fn sparse_union_is_null(union_array: &UnionArray) -> Result { + let type_ids = Int8Array::new(union_array.type_ids().clone(), None); + + let mut union_is_null = + BooleanArray::new(BooleanBuffer::new_unset(union_array.len()), None); + for type_id in 0..union_array.type_names().len() { + let type_id = type_id as i8; + let union_is_child = cmp::eq(&type_ids, &Int8Array::new_scalar(type_id))?; + let child = union_array.child(type_id); + let child_array_is_null = compute::is_null(&child)?; + let child_is_null = compute::and(&union_is_child, &child_array_is_null)?; + union_is_null = compute::or(&union_is_null, &child_is_null)?; + } + Ok(union_is_null) +} + impl PartialEq for IsNullExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) @@ -108,6 +160,7 @@ impl PartialEq for IsNullExpr { .unwrap_or(false) } } + /// Create an IS NULL expression pub fn is_null(arg: Arc) -> Result> { Ok(Arc::new(IsNullExpr::new(arg))) @@ -121,6 +174,8 @@ mod tests { array::{BooleanArray, StringArray}, datatypes::*, }; + use arrow_array::{Float64Array, Int32Array}; + use arrow_buffer::ScalarBuffer; use datafusion_common::cast::as_boolean_array; #[test] @@ -145,4 +200,72 @@ mod tests { Ok(()) 
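// Editorial aside, not part of the patch: a plain-Vec model (no arrow types) of the
// null semantics implemented by the sparse_union_is_null / dense_union_is_null
// functions added above. A union row is null iff the child array selected by its
// type id is null at the looked-up position: the row index itself for a sparse
// union, the stored offset for a dense union. All names below are illustrative.
fn sparse_union_nulls(type_ids: &[usize], child_nulls: &[Vec<bool>]) -> Vec<bool> {
    type_ids
        .iter()
        .enumerate()
        .map(|(row, &tid)| child_nulls[tid][row]) // same row in the selected child
        .collect()
}

fn dense_union_nulls(
    type_ids: &[usize],
    offsets: &[usize],
    child_nulls: &[Vec<bool>],
) -> Vec<bool> {
    type_ids
        .iter()
        .zip(offsets)
        .map(|(&tid, &off)| child_nulls[tid][off]) // indirect through the offset
        .collect()
}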
} + + fn union_fields() -> UnionFields { + [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + (2, Arc::new(Field::new("C", DataType::Utf8, true))), + ] + .into_iter() + .collect() + } + + #[test] + fn sparse_union_is_null() { + // union of [{A=1}, {A=}, {B=1.1}, {B=1.2}, {B=}, {C=}, {C="a"}] + let int_array = + Int32Array::from(vec![Some(1), None, None, None, None, None, None]); + let float_array = + Float64Array::from(vec![None, None, Some(1.1), Some(1.2), None, None, None]); + let str_array = + StringArray::from(vec![None, None, None, None, None, None, Some("a")]); + let type_ids = [0, 0, 1, 1, 1, 2, 2] + .into_iter() + .collect::>(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(float_array), + Arc::new(str_array), + ]; + + let array = + UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); + + let array_ref = Arc::new(array) as ArrayRef; + let result = compute_is_null(array_ref).unwrap(); + + let expected = + &BooleanArray::from(vec![false, true, false, false, true, true, false]); + assert_eq!(expected, &result); + } + + #[test] + fn dense_union_is_null() { + // union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}] + let int_array = Int32Array::from(vec![Some(1), None]); + let float_array = Float64Array::from(vec![Some(3.2), None]); + let str_array = StringArray::from(vec![Some("a"), None]); + let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::>(); + let offsets = [0, 1, 0, 1, 0, 1] + .into_iter() + .collect::>(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(float_array), + Arc::new(str_array), + ]; + + let array = + UnionArray::try_new(union_fields(), type_ids, Some(offsets), children) + .unwrap(); + + let array_ref = Arc::new(array) as ArrayRef; + let result = compute_is_null(array_ref).unwrap(); + + let expected = &BooleanArray::from(vec![false, true, false, true, false, true]); + assert_eq!(expected, &result); + } } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index e0c02b0a90e9..b84ba82b642d 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -123,8 +123,8 @@ impl PhysicalExpr for LikeExpr { Ok(Arc::new(LikeExpr::new( self.negated, self.case_insensitive, - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), ))) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 1f2c955ad07e..7d8f12091f46 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -39,7 +39,6 @@ pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; -pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::stats::StatsType; pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::lead_lag::{lag, lead, WindowShift}; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index aed2675e0447..b5ebc250cb89 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -97,7 +97,7 @@ impl PhysicalExpr for 
NegativeExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(NegativeExpr::new(children[0].clone()))) + Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 9aaab0658d39..b69954e00bba 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -97,7 +97,7 @@ impl PhysicalExpr for NotExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(NotExpr::new(children[0].clone()))) + Ok(Arc::new(NotExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index d31306e239bd..3549a3df83bb 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -106,7 +106,7 @@ impl PhysicalExpr for TryCastExpr { children: Vec>, ) -> Result> { Ok(Arc::new(TryCastExpr::new( - children[0].clone(), + Arc::clone(&children[0]), self.cast_type.clone(), ))) } @@ -137,7 +137,7 @@ pub fn try_cast( ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { - Ok(expr.clone()) + Ok(Arc::clone(&expr)) } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(TryCastExpr::new(expr, cast_type))) } else { diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 6fbcd461af66..ef9dd36cfb50 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -180,7 +180,7 @@ impl ExprIntervalGraphNode { /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). 
pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { - let expr = node.expr.clone(); + let expr = Arc::clone(&node.expr); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); Interval::try_new(value.clone(), value.clone()) @@ -422,7 +422,7 @@ impl ExprIntervalGraph { let mut removals = vec![]; let mut expr_node_indices = exprs .iter() - .map(|e| (e.clone(), usize::MAX)) + .map(|e| (Arc::clone(e), usize::MAX)) .collect::>(); while let Some(node) = bfs.next(graph) { // Get the plan corresponding to this node: @@ -744,16 +744,17 @@ mod tests { schema: &Schema, ) -> Result<()> { let col_stats = vec![ - (exprs_with_interval.0.clone(), left_interval), - (exprs_with_interval.1.clone(), right_interval), + (Arc::clone(&exprs_with_interval.0), left_interval), + (Arc::clone(&exprs_with_interval.1), right_interval), ]; let expected = vec![ - (exprs_with_interval.0.clone(), left_expected), - (exprs_with_interval.1.clone(), right_expected), + (Arc::clone(&exprs_with_interval.0), left_expected), + (Arc::clone(&exprs_with_interval.1), right_expected), ]; let mut graph = ExprIntervalGraph::try_new(expr, schema)?; - let expr_indexes = graph - .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); + let expr_indexes = graph.gather_node_indices( + &col_stats.iter().map(|(e, _)| Arc::clone(e)).collect_vec(), + ); let mut col_stat_nodes = col_stats .iter() @@ -870,14 +871,21 @@ mod tests { // left_watermark > right_watermark + 5 let left_and_1 = Arc::new(BinaryExpr::new( - left_col.clone(), + Arc::clone(&left_col) as Arc, Operator::Plus, Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), )); - let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); + let expr = Arc::new(BinaryExpr::new( + left_and_1, + Operator::Gt, + Arc::clone(&right_col) as Arc, + )); experiment( expr, - (left_col.clone(), right_col.clone()), + ( + Arc::clone(&left_col) as Arc, + Arc::clone(&right_col) as Arc, + ), Interval::make(Some(10_i32), Some(20_i32))?, Interval::make(Some(100), None)?, Interval::make(Some(10), Some(20))?, diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index 075b8240353d..cedf55bccbf2 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -41,12 +41,12 @@ pub fn gen_conjunctive_numerical_expr( ) -> Arc { let (op_1, op_2, op_3, op_4) = op; let left_and_1 = Arc::new(BinaryExpr::new( - left_col.clone(), + Arc::clone(&left_col), op_1, Arc::new(Literal::new(a)), )); let left_and_2 = Arc::new(BinaryExpr::new( - right_col.clone(), + Arc::clone(&right_col), op_2, Arc::new(Literal::new(b)), )); @@ -78,8 +78,18 @@ pub fn gen_conjunctive_temporal_expr( d: ScalarValue, schema: &Schema, ) -> Result, DataFusionError> { - let left_and_1 = binary(left_col.clone(), op_1, Arc::new(Literal::new(a)), schema)?; - let left_and_2 = binary(right_col.clone(), op_2, Arc::new(Literal::new(b)), schema)?; + let left_and_1 = binary( + Arc::clone(&left_col), + op_1, + Arc::new(Literal::new(a)), + schema, + )?; + let left_and_2 = binary( + Arc::clone(&right_col), + op_2, + Arc::new(Literal::new(b)), + schema, + )?; let right_and_1 = binary(left_col, op_3, Arc::new(Literal::new(c)), schema)?; let right_and_2 = binary(right_col, op_4, Arc::new(Literal::new(d)), schema)?; let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); diff --git a/datafusion/physical-expr/src/lib.rs 
b/datafusion/physical-expr/src/lib.rs index 06c73636773e..4f83ae01959b 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] pub mod aggregate; pub mod analysis; diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index 273c77fb1d5e..821b2c9fe17a 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -169,11 +169,11 @@ impl Partitioning { if !eq_groups.is_empty() { let normalized_required_exprs = required_exprs .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect::>(); let normalized_partition_exprs = partition_exprs .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect::>(); return physical_exprs_equal( &normalized_required_exprs, diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 127194f681a5..c60a772b9ce2 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -117,12 +117,12 @@ mod tests { // lit(true), lit(false), lit(4), lit(2), Col(a), Col(b) let physical_exprs: Vec> = vec![ - lit_true.clone(), - lit_false.clone(), - lit4.clone(), - lit2.clone(), - col_a_expr.clone(), - col_b_expr.clone(), + Arc::clone(&lit_true), + Arc::clone(&lit_false), + Arc::clone(&lit4), + Arc::clone(&lit2), + Arc::clone(&col_a_expr), + Arc::clone(&col_b_expr), ]; // below expressions are inside physical_exprs assert!(physical_exprs_contains(&physical_exprs, &lit_true)); @@ -146,10 +146,10 @@ mod tests { Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let vec1 = vec![lit_true.clone(), lit_false.clone()]; - let vec2 = vec![lit_true.clone(), col_b_expr.clone()]; - let vec3 = vec![lit2.clone(), lit1.clone()]; - let vec4 = vec![lit_true.clone(), lit_false.clone()]; + let vec1 = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]; + let vec2 = vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]; + let vec3 = vec![Arc::clone(&lit2), Arc::clone(&lit1)]; + let vec4 = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]; // these vectors are same assert!(physical_exprs_equal(&vec1, &vec1)); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 8fe99cdca591..dbebf4c18b79 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -242,7 +242,7 @@ pub fn create_physical_expr( when_expr .iter() .zip(then_expr.iter()) - .map(|(w, t)| (w.clone(), t.clone())) + .map(|(w, t)| (Arc::clone(w), Arc::clone(t))) .collect(); let else_expr: Option> = if let Some(e) = &case.else_expr { @@ -288,7 +288,7 @@ pub fn create_physical_expr( create_physical_exprs(args, input_dfschema, execution_props)?; scalar_function::create_physical_expr( - func.clone().as_ref(), + Arc::clone(func).as_ref(), &physical_args, input_schema, args, @@ -307,9 +307,19 @@ pub fn create_physical_expr( // rewrite the between into the two binary operators let binary_expr = binary( - binary(value_expr.clone(), Operator::GtEq, low_expr, input_schema)?, + 
binary( + Arc::clone(&value_expr), + Operator::GtEq, + low_expr, + input_schema, + )?, Operator::And, - binary(value_expr.clone(), Operator::LtEq, high_expr, input_schema)?, + binary( + Arc::clone(&value_expr), + Operator::LtEq, + high_expr, + input_schema, + )?, input_schema, ); diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 10e29b41031d..83272fc9b269 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -153,7 +153,7 @@ impl PhysicalExpr for ScalarFunctionExpr { ) -> Result> { Ok(Arc::new(ScalarFunctionExpr::new( &self.name, - self.fun.clone(), + Arc::clone(&self.fun), children, self.return_type().clone(), ))) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 070034116fb4..42e5e6fcf3ac 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -870,14 +870,12 @@ mod test { // Schema for testing fn schema() -> SchemaRef { - SCHEMA - .get_or_init(|| { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Int32, false), - ])) - }) - .clone() + Arc::clone(SCHEMA.get_or_init(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])) + })) } static SCHEMA: OnceLock = OnceLock::new(); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 492cb02941df..a33f65f92a61 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -111,7 +111,7 @@ pub fn convert_to_expr>( ) -> Vec> { sequence .into_iter() - .map(|elem| elem.borrow().expr.clone()) + .map(|elem| Arc::clone(&elem.borrow().expr)) .collect() } @@ -166,7 +166,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> for expr_node in node.children.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } - self.visited_plans.push((expr.clone(), node_idx)); + self.visited_plans.push((Arc::clone(expr), node_idx)); node_idx } }; @@ -379,7 +379,7 @@ pub(crate) mod tests { } fn make_dummy_node(node: &ExprTreeNode) -> Result { - let expr = node.expr.clone(); + let expr = Arc::clone(&node.expr); let dummy_property = if expr.as_any().is::() { "Binary" } else if expr.as_any().is::() { diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 065260a73e0b..04d359903eae 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -137,7 +137,7 @@ impl WindowExpr for BuiltInWindowExpr { let order_bys_ref = &values[n_args..]; let mut window_frame_ctx = - WindowFrameContext::new(self.window_frame.clone(), sort_options); + WindowFrameContext::new(Arc::clone(&self.window_frame), sort_options); let mut last_range = Range { start: 0, end: 0 }; // We iterate on each row to calculate window frame range and and window function result for idx in 0..num_rows { @@ -217,7 +217,7 @@ impl WindowExpr for BuiltInWindowExpr { .window_frame_ctx .get_or_insert_with(|| { WindowFrameContext::new( - self.window_frame.clone(), + Arc::clone(&self.window_frame), sort_options.clone(), ) }) diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index 9a7c89dca56c..1656b7c3033a 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ 
b/datafusion/physical-expr/src/window/lead_lag.rs @@ -104,7 +104,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -125,7 +125,7 @@ impl BuiltInWindowFunctionExpr for WindowShift { name: self.name.clone(), data_type: self.data_type.clone(), shift_offset: -self.shift_offset, - expr: self.expr.clone(), + expr: Arc::clone(&self.expr), default_value: self.default_value.clone(), ignore_nulls: self.ignore_nulls, })) @@ -209,7 +209,7 @@ fn shift_with_default_value( let value_len = array.len() as i64; if offset == 0 { - Ok(array.clone()) + Ok(Arc::clone(array)) } else if offset == i64::MIN || offset.abs() >= value_len { default_value.to_array_of_size(value_len as usize) } else { diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 4bd40066ff34..87c74579c639 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -116,7 +116,7 @@ impl BuiltInWindowFunctionExpr for NthValue { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -142,7 +142,7 @@ impl BuiltInWindowFunctionExpr for NthValue { }; Some(Arc::new(Self { name: self.name.clone(), - expr: self.expr.clone(), + expr: Arc::clone(&self.expr), data_type: self.data_type.clone(), kind: reversed_kind, ignore_nulls: self.ignore_nulls, diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 961f0884dd87..50e9632b2196 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -163,7 +163,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { aggregate: self.aggregate.with_new_expressions(args, vec![])?, partition_by: partition_bys, order_by: new_order_by, - window_frame: self.window_frame.clone(), + window_frame: Arc::clone(&self.window_frame), })) } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 3cf68379d72b..7020f7f5cf83 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -137,7 +137,7 @@ pub trait WindowExpr: Send + Sync + Debug { let order_by_exprs = self .order_by() .iter() - .map(|sort_expr| sort_expr.expr.clone()) + .map(|sort_expr| Arc::clone(&sort_expr.expr)) .collect::>(); WindowPhysicalExpressions { args, @@ -193,7 +193,7 @@ pub trait AggregateWindowExpr: WindowExpr { let sort_options: Vec = self.order_by().iter().map(|o| o.options).collect(); let mut window_frame_ctx = - WindowFrameContext::new(self.get_window_frame().clone(), sort_options); + WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options); self.get_result_column( &mut accumulator, batch, @@ -241,7 +241,7 @@ pub trait AggregateWindowExpr: WindowExpr { let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { let sort_options: Vec = self.order_by().iter().map(|o| o.options).collect(); - WindowFrameContext::new(self.get_window_frame().clone(), sort_options) + WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options) }); let out_col = self.get_result_column( accumulator, diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2bf32e8d7084..8caf10acf09b 100644 --- 
a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -41,7 +41,7 @@ use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ equivalence::{collapse_lex_req, ProjectionMapping}, - expressions::{Column, Max, Min, UnKnownColumn}, + expressions::{Column, UnKnownColumn}, physical_exprs_contains, AggregateExpr, EquivalenceProperties, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortRequirement, }; @@ -188,7 +188,7 @@ impl PhysicalGroupBy { pub fn input_exprs(&self) -> Vec> { self.expr .iter() - .map(|(expr, _alias)| expr.clone()) + .map(|(expr, _alias)| Arc::clone(expr)) .collect() } @@ -283,9 +283,9 @@ impl AggregateExec { group_by: self.group_by.clone(), filter_expr: self.filter_expr.clone(), limit: self.limit, - input: self.input.clone(), - schema: self.schema.clone(), - input_schema: self.input_schema.clone(), + input: Arc::clone(&self.input), + schema: Arc::clone(&self.schema), + input_schema: Arc::clone(&self.input_schema), } } @@ -355,7 +355,7 @@ impl AggregateExec { let mut new_requirement = indices .iter() .map(|&idx| PhysicalSortRequirement { - expr: groupby_exprs[idx].clone(), + expr: Arc::clone(&groupby_exprs[idx]), options: None, }) .collect::>(); @@ -369,14 +369,26 @@ impl AggregateExec { new_requirement.extend(req); new_requirement = collapse_lex_req(new_requirement); - let input_order_mode = - if indices.len() == groupby_exprs.len() && !indices.is_empty() { - InputOrderMode::Sorted - } else if !indices.is_empty() { - InputOrderMode::PartiallySorted(indices) - } else { - InputOrderMode::Linear - }; + // If our aggregation has grouping sets then our base grouping exprs will + // be expanded based on the flags in `group_by.groups` where for each + // group we swap the grouping expr for `null` if the flag is `true` + // That means that each index in `indices` is valid if and only if + // it is not null in every group + let indices: Vec = indices + .into_iter() + .filter(|idx| group_by.groups.iter().all(|group| !group[*idx])) + .collect(); + + let input_order_mode = if indices.len() == groupby_exprs.len() + && !indices.is_empty() + && group_by.groups.len() == 1 + { + InputOrderMode::Sorted + } else if !indices.is_empty() { + InputOrderMode::PartiallySorted(indices) + } else { + InputOrderMode::Linear + }; // construct a map from the input expression to the output expression of the Aggregation group by let projection_mapping = @@ -387,7 +399,7 @@ impl AggregateExec { let cache = Self::compute_properties( &input, - schema.clone(), + Arc::clone(&schema), &projection_mapping, &mode, &input_order_mode, @@ -446,7 +458,7 @@ impl AggregateExec { /// Get the input schema before any aggregates are applied pub fn input_schema(&self) -> SchemaRef { - self.input_schema.clone() + Arc::clone(&self.input_schema) } /// number of rows soft limit of the AggregateExec @@ -484,13 +496,7 @@ impl AggregateExec { /// Finds the DataType and SortDirection for this Aggregate, if there is one pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; - if let Some(max) = agg_expr.as_any().downcast_ref::() { - Some((max.field().ok()?, true)) - } else if let Some(min) = agg_expr.as_any().downcast_ref::() { - Some((min.field().ok()?, false)) - } else { - None - } + agg_expr.get_minmax_desc() } /// true, if this Aggregate has a group-by with no required or explicit ordering, @@ -700,9 +706,9 @@ impl ExecutionPlan for AggregateExec { self.group_by.clone(), 
self.aggr_expr.clone(), self.filter_expr.clone(), - children[0].clone(), - self.input_schema.clone(), - self.schema.clone(), + Arc::clone(&children[0]), + Arc::clone(&self.input_schema), + Arc::clone(&self.schema), )?; me.limit = self.limit; @@ -999,7 +1005,7 @@ fn aggregate_expressions( // way order sensitive aggregators can satisfy requirement // themselves. if let Some(ordering_req) = agg.order_bys() { - result.extend(ordering_req.iter().map(|item| item.expr.clone())); + result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr))); } result }) @@ -1159,9 +1165,9 @@ pub(crate) fn evaluate_group_by( .enumerate() .map(|(idx, is_null)| { if *is_null { - null_exprs[idx].clone() + Arc::clone(&null_exprs[idx]) } else { - exprs[idx].clone() + Arc::clone(&exprs[idx]) } }) .collect() @@ -1186,6 +1192,7 @@ mod tests { use arrow::array::{Float64Array, UInt32Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; + use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, ScalarValue, @@ -1201,7 +1208,9 @@ mod tests { use datafusion_physical_expr::expressions::{lit, OrderSensitiveArrayAgg}; use datafusion_physical_expr::PhysicalSortExpr; + use crate::common::collect; use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::expressions::Literal; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1226,10 +1235,10 @@ mod tests { // define data. ( - schema.clone(), + Arc::clone(&schema), vec![ RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), @@ -1261,10 +1270,10 @@ mod tests { // the expected result by accident, but merging actually works properly; // i.e. it doesn't depend on the data insertion order. 
( - schema.clone(), + Arc::clone(&schema), vec![ RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), @@ -1272,7 +1281,7 @@ mod tests { ) .unwrap(), RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])), @@ -1280,7 +1289,7 @@ mod tests { ) .unwrap(), RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])), @@ -1360,11 +1369,11 @@ mod tests { aggregates.clone(), vec![None], input, - input_schema.clone(), + Arc::clone(&input_schema), )?); let result = - common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; + common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { vec![ @@ -1442,7 +1451,7 @@ mod tests { )?); let result = - common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?; + common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let batch = concat_batches(&result[0].schema(), &result)?; assert_eq!(batch.num_columns(), 3); assert_eq!(batch.num_rows(), 12); @@ -1510,11 +1519,11 @@ mod tests { aggregates.clone(), vec![None], input, - input_schema.clone(), + Arc::clone(&input_schema), )?); let result = - common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; + common::collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { vec![ @@ -1564,7 +1573,7 @@ mod tests { // enlarge memory limit to let the final aggregation finish new_spill_ctx(2, 2600) } else { - task_ctx.clone() + Arc::clone(&task_ctx) }; let result = common::collect(merged_aggregate.execute(0, task_ctx)?).await?; let batch = concat_batches(&result[0].schema(), &result)?; @@ -1847,11 +1856,11 @@ mod tests { groups, aggregates, vec![None; n_aggr], - input.clone(), - input_schema.clone(), + Arc::clone(&input), + Arc::clone(&input_schema), )?); - let stream = partial_aggregate.execute_typed(0, task_ctx.clone())?; + let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?; // ensure that we really got the version we wanted match version { @@ -2103,7 +2112,7 @@ mod tests { vec![partition3], vec![partition4], ], - schema.clone(), + Arc::clone(&schema), None, )?); let aggregate_exec = Arc::new(AggregateExec::try_new( @@ -2112,7 +2121,7 @@ mod tests { aggregates.clone(), vec![None], memory_exec, - schema.clone(), + Arc::clone(&schema), )?); let coalesce = if use_coalesce_batches { let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)); @@ -2177,41 +2186,41 @@ mod tests { let order_by_exprs = vec![ None, Some(vec![PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: options1, }]), Some(vec![ PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: options1, }, PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: options1, }, PhysicalSortExpr { - expr: col_c.clone(), + expr: Arc::clone(col_c), options: options1, }, ]), Some(vec![ PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: options1, }, PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: options1, }, ]), ]; let common_requirement = vec![ PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: options1, }, 
PhysicalSortExpr { - expr: col_c.clone(), + expr: Arc::clone(col_c), options: options1, }, ]; @@ -2219,7 +2228,7 @@ mod tests { .into_iter() .map(|order_by_expr| { Arc::new(OrderSensitiveArrayAgg::new( - col_a.clone(), + Arc::clone(col_a), "array_agg", DataType::Int32, false, @@ -2264,13 +2273,105 @@ mod tests { groups, aggregates.clone(), vec![None, None], - blocking_exec.clone(), + Arc::clone(&blocking_exec) as Arc, schema, )?); - let new_agg = aggregate_exec - .clone() - .with_new_children(vec![blocking_exec])?; + let new_agg = + Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?; assert_eq!(new_agg.schema(), aggregate_exec.schema()); Ok(()) } + + #[tokio::test] + async fn test_agg_exec_group_by_const() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + Field::new("const", DataType::Int32, false), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1)))); + + let groups = PhysicalGroupBy::new( + vec![ + (col_a, "a".to_string()), + (col_b, "b".to_string()), + (const_expr, "const".to_string()), + ], + vec![ + ( + Arc::new(Literal::new(ScalarValue::Float32(None))), + "a".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Float32(None))), + "b".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Int32(None))), + "const".to_string(), + ), + ], + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + ); + + let aggregates: Vec> = vec![create_aggregate_expr( + count_udaf().as_ref(), + &[lit(1)], + &[datafusion_expr::lit(1)], + &[], + &[], + schema.as_ref(), + "1", + false, + false, + )?]; + + let input_batches = (0..4) + .map(|_| { + let a = Arc::new(Float32Array::from(vec![0.; 8192])); + let b = Arc::new(Float32Array::from(vec![0.; 8192])); + let c = Arc::new(Int32Array::from(vec![1; 8192])); + + RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap() + }) + .collect(); + + let input = Arc::new(MemoryExec::try_new( + &[input_batches], + Arc::clone(&schema), + None, + )?); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates.clone(), + vec![None], + input, + schema, + )?); + + let output = + collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; + + let expected = [ + "+-----+-----+-------+----------+", + "| a | b | const | 1[count] |", + "+-----+-----+-------+----------+", + "| | 0.0 | | 32768 |", + "| 0.0 | | | 32768 |", + "| | | 1 | 32768 |", + "+-----+-----+-------+----------+", + ]; + assert_batches_sorted_eq!(expected, &output); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 5ec95bd79942..f85164f7f1e2 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -140,8 +140,11 @@ impl AggregateStream { let result = finalize_aggregation(&mut this.accumulators, &this.mode) .and_then(|columns| { - RecordBatch::try_new(this.schema.clone(), columns) - .map_err(Into::into) + RecordBatch::try_new( + Arc::clone(&this.schema), + columns, + ) + .map_err(Into::into) }) .record_output(&this.baseline_metrics); @@ -181,7 +184,7 @@ impl Stream for AggregateStream { impl RecordBatchStream for AggregateStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } 
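// Editorial aside, not part of the patch: a minimal model of the new index filtering
// added to AggregateExec::try_new above. `groups[g][i] == true` means grouping set
// `g` replaces group expression `i` with null, so an index can only contribute to
// InputOrderMode::Sorted / PartiallySorted if it survives every grouping set (the
// patch additionally requires a single grouping set for the fully Sorted mode).
// Names are illustrative only.
fn ordered_group_indices(indices: Vec<usize>, groups: &[Vec<bool>]) -> Vec<usize> {
    indices
        .into_iter()
        .filter(|&idx| groups.iter().all(|group| !group[idx]))
        .collect()
}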
diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index ecd37c913e98..f8fd86ff8b50 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -22,6 +22,7 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use datafusion_physical_expr::PhysicalSortExpr; +use std::sync::Arc; /// Tracks grouping state when the data is ordered by some subset of /// the group keys. @@ -138,7 +139,7 @@ impl GroupOrderingPartial { let sort_values: Vec<_> = self .order_indices .iter() - .map(|&idx| group_values[idx].clone()) + .map(|&idx| Arc::clone(&group_values[idx])) .collect(); Ok(self.row_converter.convert_columns(&sort_values)?) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 27577e6c8bf8..a1d3378181c2 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -358,7 +358,7 @@ impl GroupedHashAggregateStream { let spill_state = SpillState { spills: vec![], spill_expr, - spill_schema: agg_schema.clone(), + spill_schema: Arc::clone(&agg_schema), is_stream_merging: false, merging_aggregate_arguments, merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), @@ -401,7 +401,7 @@ pub(crate) fn create_group_accumulator( "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}", agg_expr.name() ); - let agg_expr_captured = agg_expr.clone(); + let agg_expr_captured = Arc::clone(agg_expr); let factory = move || agg_expr_captured.create_accumulator(); Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) } @@ -515,7 +515,7 @@ impl Stream for GroupedHashAggregateStream { impl RecordBatchStream for GroupedHashAggregateStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -625,7 +625,7 @@ impl GroupedHashAggregateStream { /// accumulator states/values specified in emit_to fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result { let schema = if spilling { - self.spill_state.spill_schema.clone() + Arc::clone(&self.spill_state.spill_schema) } else { self.schema() }; @@ -746,13 +746,13 @@ impl GroupedHashAggregateStream { let expr = self.spill_state.spill_expr.clone(); let schema = batch.schema(); streams.push(Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), + Arc::clone(&schema), futures::stream::once(futures::future::lazy(move |_| { sort_batch(&batch, &expr, None) })), ))); for spill in self.spill_state.spills.drain(..) 
{ - let stream = read_spill_as_stream(spill, schema.clone(), 2)?; + let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?; streams.push(stream); } self.spill_state.is_stream_merging = true; diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 9f25473cb9b4..075d8c5f2883 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -84,14 +84,14 @@ impl GroupedTopKAggregateStream { impl RecordBatchStream for GroupedTopKAggregateStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } impl GroupedTopKAggregateStream { fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> { let len = ids.len(); - self.priority_map.set_batch(ids, vals.clone()); + self.priority_map.set_batch(ids, Arc::clone(&vals)); let has_nulls = vals.null_count() > 0; for row_idx in 0..len { @@ -139,14 +139,14 @@ impl Stream for GroupedTopKAggregateStream { 1, "Exactly 1 group value required" ); - let group_by_values = group_by_values[0][0].clone(); + let group_by_values = Arc::clone(&group_by_values[0][0]); let input_values = evaluate_many( &self.aggregate_arguments, batches.first().unwrap(), )?; assert_eq!(input_values.len(), 1, "Exactly 1 input required"); assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); - let input_values = input_values[0][0].clone(); + let input_values = Arc::clone(&input_values[0][0]); // iterate over each column of group_by values (*self).intern(group_by_values, input_values)?; @@ -158,7 +158,7 @@ impl Stream for GroupedTopKAggregateStream { return Poll::Ready(None); } let cols = self.priority_map.emit()?; - let batch = RecordBatch::try_new(self.schema.clone(), cols)?; + let batch = RecordBatch::try_new(Arc::clone(&self.schema), cols)?; trace!( "partition {} emit batch with {} rows", self.partition, diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 5b859804163b..b4c1e25e6191 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -59,7 +59,7 @@ impl AnalyzeExec { input: Arc, schema: SchemaRef, ) -> Self { - let cache = Self::compute_properties(&input, schema.clone()); + let cache = Self::compute_properties(&input, Arc::clone(&schema)); AnalyzeExec { verbose, show_statistics, @@ -141,7 +141,7 @@ impl ExecutionPlan for AnalyzeExec { self.verbose, self.show_statistics, children.pop().unwrap(), - self.schema.clone(), + Arc::clone(&self.schema), ))) } @@ -164,13 +164,17 @@ impl ExecutionPlan for AnalyzeExec { RecordBatchReceiverStream::builder(self.schema(), num_input_partitions); for input_partition in 0..num_input_partitions { - builder.run_input(self.input.clone(), input_partition, context.clone()); + builder.run_input( + Arc::clone(&self.input), + input_partition, + Arc::clone(&context), + ); } // Create future that computes thefinal output let start = Instant::now(); - let captured_input = self.input.clone(); - let captured_schema = self.schema.clone(); + let captured_input = Arc::clone(&self.input); + let captured_schema = Arc::clone(&self.schema); let verbose = self.verbose; let show_statistics = self.show_statistics; @@ -196,7 +200,7 @@ impl ExecutionPlan for AnalyzeExec { }; Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), futures::stream::once(output), ))) } diff --git a/datafusion/physical-plan/src/coalesce_batches.rs 
b/datafusion/physical-plan/src/coalesce_batches.rs index 804fabff71ac..b9bdfcdee712 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -134,7 +134,7 @@ impl ExecutionPlan for CoalesceBatchesExec { children: Vec>, ) -> Result> { Ok(Arc::new(CoalesceBatchesExec::new( - children[0].clone(), + Arc::clone(&children[0]), self.target_batch_size, ))) } @@ -272,7 +272,7 @@ impl CoalesceBatchesStream { impl RecordBatchStream for CoalesceBatchesStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -329,7 +329,7 @@ mod tests { target_batch_size: usize, ) -> Result>> { // create physical plan - let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; + let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?; let exec: Arc = @@ -341,7 +341,7 @@ mod tests { for i in 0..output_partition_count { // execute this *output* partition and collect all batches let task_ctx = Arc::new(TaskContext::default()); - let mut stream = exec.execute(i, task_ctx.clone())?; + let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 93f449f2d39b..ef6afee80307 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -114,7 +114,9 @@ impl ExecutionPlan for CoalescePartitionsExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(CoalescePartitionsExec::new(children[0].clone()))) + Ok(Arc::new(CoalescePartitionsExec::new(Arc::clone( + &children[0], + )))) } fn execute( @@ -152,7 +154,11 @@ impl ExecutionPlan for CoalescePartitionsExec { // spawn independent tasks whose resulting streams (of batches) // are sent to the channel for consumption. for part_i in 0..input_partitions { - builder.run_input(self.input.clone(), part_i, context.clone()); + builder.run_input( + Arc::clone(&self.input), + part_i, + Arc::clone(&context), + ); } let stream = builder.build(); diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index c61e9a05bfa6..bf9d14e73dd8 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -618,16 +618,17 @@ mod tests { expr: col("f32", &schema).unwrap(), options: SortOptions::default(), }]; - let memory_exec = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) as _; + let memory_exec = + Arc::new(MemoryExec::try_new(&[], Arc::clone(&schema), None)?) as _; let sort_exec = Arc::new(SortExec::new(sort_expr.clone(), memory_exec)) as Arc; let memory_exec2 = Arc::new(MemoryExec::try_new(&[], schema, None)?) 
as _; // memory_exec2 doesn't have output ordering - let union_exec = UnionExec::new(vec![sort_exec.clone(), memory_exec2]); + let union_exec = UnionExec::new(vec![Arc::clone(&sort_exec), memory_exec2]); let res = get_meet_of_orderings(union_exec.inputs()); assert!(res.is_none()); - let union_exec = UnionExec::new(vec![sort_exec.clone(), sort_exec]); + let union_exec = UnionExec::new(vec![Arc::clone(&sort_exec), sort_exec]); let res = get_meet_of_orderings(union_exec.inputs()); assert_eq!(res, Some(&sort_expr[..])); Ok(()) diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 11af0624db15..4bacea48c347 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -47,7 +47,7 @@ pub struct EmptyExec { impl EmptyExec { /// Create a new EmptyExec pub fn new(schema: SchemaRef) -> Self { - let cache = Self::compute_properties(schema.clone(), 1); + let cache = Self::compute_properties(Arc::clone(&schema), 1); EmptyExec { schema, partitions: 1, @@ -142,7 +142,7 @@ impl ExecutionPlan for EmptyExec { Ok(Box::pin(MemoryStream::try_new( self.data()?, - self.schema.clone(), + Arc::clone(&self.schema), None, )?)) } @@ -170,7 +170,7 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(schema.clone()); + let empty = EmptyExec::new(Arc::clone(&schema)); assert_eq!(empty.schema(), schema); // we should have no results @@ -184,9 +184,12 @@ mod tests { #[test] fn with_new_children() -> Result<()> { let schema = test::aggr_test_schema(); - let empty = Arc::new(EmptyExec::new(schema.clone())); + let empty = Arc::new(EmptyExec::new(Arc::clone(&schema))); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?; + let empty2 = with_new_children_if_necessary( + Arc::clone(&empty) as Arc, + vec![], + )?; assert_eq!(empty.schema(), empty2.schema()); let too_many_kids = vec![empty2]; @@ -204,7 +207,7 @@ mod tests { let empty = EmptyExec::new(schema); // ask for the wrong partition - assert!(empty.execute(1, task_ctx.clone()).is_err()); + assert!(empty.execute(1, Arc::clone(&task_ctx)).is_err()); assert!(empty.execute(20, task_ctx).is_err()); Ok(()) } diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 4b2edbf2045d..56dc35e8819d 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -53,7 +53,7 @@ impl ExplainExec { stringified_plans: Vec, verbose: bool, ) -> Self { - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(Arc::clone(&schema)); ExplainExec { schema, stringified_plans, @@ -160,7 +160,7 @@ impl ExecutionPlan for ExplainExec { } let record_batch = RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(type_builder.finish()), Arc::new(plan_builder.finish()), @@ -171,7 +171,7 @@ impl ExecutionPlan for ExplainExec { "Before returning RecordBatchStream in ExplainExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), futures::stream::iter(vec![Ok(record_batch)]), ))) } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index c141958c1171..96ec6c0cf34d 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -78,7 +78,7 @@ impl 
FilterExec { Self::compute_properties(&input, &predicate, default_selectivity)?; Ok(Self { predicate, - input: input.clone(), + input: Arc::clone(&input), metrics: ExecutionPlanMetricsSet::new(), default_selectivity, cache, @@ -173,13 +173,11 @@ impl FilterExec { // Filter evaluates to single value for all partitions if input_eqs.is_expr_constant(binary.left()) { res_constants.push( - ConstExpr::new(binary.right().clone()) - .with_across_partitions(true), + ConstExpr::from(binary.right()).with_across_partitions(true), ) } else if input_eqs.is_expr_constant(binary.right()) { res_constants.push( - ConstExpr::new(binary.left().clone()) - .with_across_partitions(true), + ConstExpr::from(binary.left()).with_across_partitions(true), ) } } @@ -265,7 +263,7 @@ impl ExecutionPlan for FilterExec { self: Arc, mut children: Vec>, ) -> Result> { - FilterExec::try_new(self.predicate.clone(), children.swap_remove(0)) + FilterExec::try_new(Arc::clone(&self.predicate), children.swap_remove(0)) .and_then(|e| { let selectivity = e.default_selectivity(); e.with_default_selectivity(selectivity) @@ -282,7 +280,7 @@ impl ExecutionPlan for FilterExec { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(FilterExecStream { schema: self.input.schema(), - predicate: self.predicate.clone(), + predicate: Arc::clone(&self.predicate), input: self.input.execute(partition, context)?, baseline_metrics, })) @@ -407,7 +405,7 @@ impl Stream for FilterExecStream { impl RecordBatchStream for FilterExecStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1126,7 +1124,7 @@ mod tests { binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?, &schema, )?, - Arc::new(EmptyExec::new(schema.clone())), + Arc::new(EmptyExec::new(Arc::clone(&schema))), )?; exec.statistics().unwrap(); diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 30c3353d4b71..1c21991d93c5 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -153,7 +153,7 @@ impl DataSinkExec { } else { // Check not null constraint on the input stream Ok(Box::pin(RecordBatchStreamAdapter::new( - self.sink_schema.clone(), + Arc::clone(&self.sink_schema), input_stream .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), ))) @@ -252,9 +252,9 @@ impl ExecutionPlan for DataSinkExec { children: Vec>, ) -> Result> { Ok(Arc::new(Self::new( - children[0].clone(), - self.sink.clone(), - self.sink_schema.clone(), + Arc::clone(&children[0]), + Arc::clone(&self.sink), + Arc::clone(&self.sink_schema), self.sort_order.clone(), ))) } @@ -269,10 +269,10 @@ impl ExecutionPlan for DataSinkExec { if partition != 0 { return internal_err!("DataSinkExec can only be called on partition 0!"); } - let data = self.execute_input_stream(0, context.clone())?; + let data = self.execute_input_stream(0, Arc::clone(&context))?; - let count_schema = self.count_schema.clone(); - let sink = self.sink.clone(); + let count_schema = Arc::clone(&self.count_schema); + let sink = Arc::clone(&self.sink); let stream = futures::stream::once(async move { sink.write_all(data, &context).await.map(make_count_batch) diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 92443d06856a..33a9c061bf31 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -79,7 +79,7 @@ impl CrossJoinExec { }; let schema = 
Arc::new(Schema::new(all_columns)); - let cache = Self::compute_properties(&left, &right, schema.clone()); + let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); CrossJoinExec { left, right, @@ -220,8 +220,8 @@ impl ExecutionPlan for CrossJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(CrossJoinExec::new( - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), ))) } @@ -237,7 +237,7 @@ impl ExecutionPlan for CrossJoinExec { partition: usize, context: Arc, ) -> Result { - let stream = self.right.execute(partition, context.clone())?; + let stream = self.right.execute(partition, Arc::clone(&context))?; let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); @@ -247,7 +247,7 @@ impl ExecutionPlan for CrossJoinExec { let left_fut = self.left_fut.once(|| { load_left_input( - self.left.clone(), + Arc::clone(&self.left), context, join_metrics.clone(), reservation, @@ -255,7 +255,7 @@ impl ExecutionPlan for CrossJoinExec { }); Ok(Box::pin(CrossJoinStream { - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), left_fut, right: stream, left_index: 0, @@ -337,7 +337,7 @@ struct CrossJoinStream { impl RecordBatchStream for CrossJoinStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index b2f9ef560745..2f4ee00da35f 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -367,7 +367,7 @@ impl HashJoinExec { let cache = Self::compute_properties( &left, &right, - join_schema.clone(), + Arc::clone(&join_schema), *join_type, &on, partition_mode, @@ -433,7 +433,10 @@ impl HashJoinExec { false, matches!( join_type, - JoinType::Inner | JoinType::RightAnti | JoinType::RightSemi + JoinType::Inner + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightSemi ), ] } @@ -461,8 +464,8 @@ impl HashJoinExec { None => None, }; Self::try_new( - self.left.clone(), - self.right.clone(), + Arc::clone(&self.left), + Arc::clone(&self.right), self.on.clone(), self.filter.clone(), &self.join_type, @@ -487,7 +490,7 @@ impl HashJoinExec { left.equivalence_properties().clone(), right.equivalence_properties().clone(), &join_type, - schema.clone(), + Arc::clone(&schema), &Self::maintains_input_order(join_type), Some(Self::probe_side()), on, @@ -635,8 +638,11 @@ impl ExecutionPlan for HashJoinExec { Distribution::UnspecifiedDistribution, ], PartitionMode::Partitioned => { - let (left_expr, right_expr) = - self.on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); vec![ Distribution::HashPartitioned(left_expr), Distribution::HashPartitioned(right_expr), @@ -678,8 +684,8 @@ impl ExecutionPlan for HashJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(HashJoinExec::try_new( - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), self.on.clone(), self.filter.clone(), &self.join_type, @@ -694,8 +700,16 @@ impl ExecutionPlan for HashJoinExec { partition: usize, context: Arc, ) -> Result { - let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); - let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); + let on_left = self + .on + .iter() + .map(|on| Arc::clone(&on.0)) + .collect::>(); + let on_right = self + .on + .iter() + .map(|on| 
Arc::clone(&on.1)) + .collect::>(); let left_partitions = self.left.output_partitioning().partition_count(); let right_partitions = self.right.output_partitioning().partition_count(); @@ -715,9 +729,9 @@ impl ExecutionPlan for HashJoinExec { collect_left_input( None, self.random_state.clone(), - self.left.clone(), + Arc::clone(&self.left), on_left.clone(), - context.clone(), + Arc::clone(&context), join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), @@ -732,9 +746,9 @@ impl ExecutionPlan for HashJoinExec { OnceFut::new(collect_left_input( Some(partition), self.random_state.clone(), - self.left.clone(), + Arc::clone(&self.left), on_left.clone(), - context.clone(), + Arc::clone(&context), join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), @@ -779,6 +793,7 @@ impl ExecutionPlan for HashJoinExec { build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, hashes_buffer: vec![], + right_side_ordered: self.right.output_ordering().is_some(), })) } @@ -791,8 +806,8 @@ impl ExecutionPlan for HashJoinExec { // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` let mut stats = estimate_join_statistics( - self.left.clone(), - self.right.clone(), + Arc::clone(&self.left), + Arc::clone(&self.right), self.on.clone(), &self.join_type, &self.join_schema, @@ -836,7 +851,7 @@ async fn collect_left_input( }; // Depending on partition argument load single partition or whole left side in memory - let stream = left_input.execute(left_input_partition, context.clone())?; + let stream = left_input.execute(left_input_partition, Arc::clone(&context))?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -1107,11 +1122,13 @@ struct HashJoinStream { batch_size: usize, /// Scratch space for computing hashes hashes_buffer: Vec, + /// Specifies whether the right side has an ordering to potentially preserve + right_side_ordered: bool, } impl RecordBatchStream for HashJoinStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1212,11 +1229,16 @@ fn eq_dyn_null( right: &dyn Array, null_equals_null: bool, ) -> Result { - // Nested datatypes cannot use the underlying not_distinct function and must use a special + // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special // implementation // - if left.data_type().is_nested() && null_equals_null { - return Ok(compare_op_for_nested(&Operator::Eq, &left, &right)?); + if left.data_type().is_nested() { + let op = if null_equals_null { + Operator::IsNotDistinctFrom + } else { + Operator::Eq + }; + return Ok(compare_op_for_nested(&op, &left, &right)?); } match (left.data_type(), right.data_type()) { _ if null_equals_null => not_distinct(&left, &right), @@ -1444,6 +1466,7 @@ impl HashJoinStream { right_indices, index_alignment_range_start..index_alignment_range_end, self.join_type, + self.right_side_ordered, ); let result = build_batch_from_indices( @@ -1537,7 +1560,6 @@ impl Stream for HashJoinStream { #[cfg(test)] mod tests { - use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, @@ -1546,6 +1568,8 @@ mod tests { use arrow::array::{Date32Array, Int32Array, UInt32Builder, UInt64Builder}; use arrow::datatypes::{DataType, Field}; + use arrow_array::StructArray; + use arrow_buffer::NullBuffer; use datafusion_common::{ assert_batches_eq, 
assert_batches_sorted_eq, assert_contains, exec_err, ScalarValue, @@ -1669,8 +1693,10 @@ mod tests { ) -> Result<(Vec, Vec)> { let partition_count = 4; - let (left_expr, right_expr) = - on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); let left_repartitioned: Arc = match partition_mode { PartitionMode::CollectLeft => Arc::new(CoalescePartitionsExec::new(left)), @@ -1719,7 +1745,7 @@ mod tests { let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; + let stream = join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -1753,8 +1779,8 @@ mod tests { )]; let (columns, batches) = join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &JoinType::Inner, false, @@ -1800,8 +1826,8 @@ mod tests { )]; let (columns, batches) = partitioned_join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &JoinType::Inner, false, @@ -2104,7 +2130,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; // expected joined records = 1 (first right batch) @@ -2127,7 +2153,7 @@ mod tests { assert_batches_eq!(expected, &batches); // second part - let stream = join.execute(1, task_ctx.clone())?; + let stream = join.execute(1, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; // expected joined records = 2 (second right batch) @@ -2342,8 +2368,8 @@ mod tests { )]; let (columns, batches) = join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &JoinType::Left, false, @@ -2386,8 +2412,8 @@ mod tests { )]; let (columns, batches) = partitioned_join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &JoinType::Left, false, @@ -2498,8 +2524,8 @@ mod tests { ); let join = join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter, &JoinType::LeftSemi, @@ -2509,7 +2535,7 @@ mod tests { let columns_header = columns(&join.schema()); assert_eq!(columns_header.clone(), vec!["a1", "b1", "c1"]); - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; let expected = [ @@ -2622,8 +2648,8 @@ mod tests { ); let join = join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter, &JoinType::RightSemi, @@ -2633,7 +2659,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a2", "b2", "c2"]); - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; let expected = [ @@ -2744,8 +2770,8 @@ mod tests { ); let join = join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter, &JoinType::LeftAnti, @@ -2755,7 +2781,7 @@ mod tests { let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a1", "b1", "c1"]); - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = 
common::collect(stream).await?; let expected = [ @@ -2873,8 +2899,8 @@ mod tests { ); let join = join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter, &JoinType::RightAnti, @@ -2884,7 +2910,7 @@ mod tests { let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a2", "b2", "c2"]); - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; let expected = [ @@ -3074,8 +3100,11 @@ mod tests { let random_state = RandomState::with_seeds(0, 0, 0, 0); let hashes_buff = &mut vec![0; left.num_rows()]; - let hashes = - create_hashes(&[left.columns()[0].clone()], &random_state, hashes_buff)?; + let hashes = create_hashes( + &[Arc::clone(&left.columns()[0])], + &random_state, + hashes_buff, + )?; // Create hash collisions (same hashes) hashmap_left.insert(hashes[0], (hashes[0], 1), |(h, _)| *h); @@ -3103,7 +3132,7 @@ mod tests { &join_hash_map, &left, &right, - &[key_column.clone()], + &[Arc::clone(&key_column)], &[key_column], false, &hashes_buffer, @@ -3463,13 +3492,13 @@ mod tests { for (join_type, expected) in test_cases { let (_, batches) = join_collect_with_partition_mode( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &join_type, PartitionMode::CollectLeft, false, - task_ctx.clone(), + Arc::clone(&task_ctx), ) .await?; assert_batches_sorted_eq!(expected, &batches); @@ -3487,13 +3516,14 @@ mod tests { let dates: ArrayRef = Arc::new(Date32Array::from(vec![19107, 19108, 19109])); let n: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); - let batch = RecordBatch::try_new(schema.clone(), vec![dates, n])?; - let left = - Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap()); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![dates, n])?; + let left = Arc::new( + MemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None).unwrap(), + ); let dates: ArrayRef = Arc::new(Date32Array::from(vec![19108, 19108, 19109])); let n: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); - let batch = RecordBatch::try_new(schema.clone(), vec![dates, n])?; + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![dates, n])?; let right = Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()); let on = vec![( @@ -3555,8 +3585,8 @@ mod tests { for join_type in join_types { let join = join( - left.clone(), - right_input.clone(), + Arc::clone(&left), + Arc::clone(&right_input) as Arc, on.clone(), &join_type, false, @@ -3671,9 +3701,14 @@ mod tests { for batch_size in (1..21).rev() { let task_ctx = prepare_task_ctx(batch_size); - let join = - join(left.clone(), right.clone(), on.clone(), &join_type, false) - .unwrap(); + let join = join( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &join_type, + false, + ) + .unwrap(); let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); @@ -3746,7 +3781,13 @@ mod tests { let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); - let join = join(left.clone(), right.clone(), on.clone(), &join_type, false)?; + let join = join( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &join_type, + false, + )?; let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); @@ -3819,8 +3860,8 @@ mod tests { let task_ctx = Arc::new(task_ctx); let join = HashJoinExec::try_new( - 
left.clone(), - right.clone(), + Arc::clone(&left) as Arc, + Arc::clone(&right) as Arc, on.clone(), None, &join_type, @@ -3844,6 +3885,104 @@ mod tests { Ok(()) } + fn build_table_struct( + struct_name: &str, + field_name_and_values: (&str, &Vec>), + nulls: Option, + ) -> Arc { + let (field_name, values) = field_name_and_values; + let inner_fields = vec![Field::new(field_name, DataType::Int32, true)]; + let schema = Schema::new(vec![Field::new( + struct_name, + DataType::Struct(inner_fields.clone().into()), + nulls.is_some(), + )]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(StructArray::new( + inner_fields.into(), + vec![Arc::new(Int32Array::from(values.clone()))], + nulls, + ))], + ) + .unwrap(); + let schema_ref = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema_ref, None).unwrap()) + } + + #[tokio::test] + async fn join_on_struct() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = + build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None); + let right = + build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None); + let on = vec![( + Arc::new(Column::new_with_schema("n1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("n2", &right.schema())?) as _, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["n1", "n2"]); + + let expected = [ + "+--------+--------+", + "| n1 | n2 |", + "+--------+--------+", + "| {a: } | {a: } |", + "| {a: 1} | {a: 1} |", + "| {a: 2} | {a: 2} |", + "+--------+--------+", + ]; + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_on_struct_with_nulls() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = + build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1))); + let right = + build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1))); + let on = vec![( + Arc::new(Column::new_with_schema("n1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("n2", &right.schema())?) 
as _, + )]; + + let (_, batches_null_eq) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::Inner, + true, + Arc::clone(&task_ctx), + ) + .await?; + + let expected_null_eq = [ + "+----+----+", + "| n1 | n2 |", + "+----+----+", + "| | |", + "+----+----+", + ]; + assert_batches_eq!(expected_null_eq, &batches_null_eq); + + let (_, batches_null_neq) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + let expected_null_neq = + ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; + assert_batches_eq!(expected_null_neq, &batches_null_neq); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 6be124cce06f..754e55e49650 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -173,7 +173,8 @@ impl NestedLoopJoinExec { let (schema, column_indices) = build_join_schema(&left_schema, &right_schema, join_type); let schema = Arc::new(schema); - let cache = Self::compute_properties(&left, &right, schema.clone(), *join_type); + let cache = + Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type); Ok(NestedLoopJoinExec { left, @@ -287,8 +288,8 @@ impl ExecutionPlan for NestedLoopJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(NestedLoopJoinExec::try_new( - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), self.filter.clone(), &self.join_type, )?)) @@ -308,8 +309,8 @@ impl ExecutionPlan for NestedLoopJoinExec { let inner_table = self.inner_table.once(|| { collect_left_input( - self.left.clone(), - context.clone(), + Arc::clone(&self.left), + Arc::clone(&context), join_metrics.clone(), load_reservation, need_produce_result_in_final(self.join_type), @@ -319,7 +320,7 @@ impl ExecutionPlan for NestedLoopJoinExec { let outer_table = self.right.execute(partition, context)?; Ok(Box::pin(NestedLoopJoinStream { - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), filter: self.filter.clone(), join_type: self.join_type, outer_table, @@ -336,8 +337,8 @@ impl ExecutionPlan for NestedLoopJoinExec { fn statistics(&self) -> Result { estimate_join_statistics( - self.left.clone(), - self.right.clone(), + Arc::clone(&self.left), + Arc::clone(&self.right), vec![], &self.join_type, &self.schema, @@ -604,6 +605,7 @@ fn join_left_and_right_batch( right_side, 0..right_batch.num_rows(), join_type, + false, ); build_batch_from_indices( @@ -641,13 +643,12 @@ impl Stream for NestedLoopJoinStream { impl RecordBatchStream for NestedLoopJoinStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } #[cfg(test)] mod tests { - use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, @@ -752,7 +753,7 @@ mod tests { let columns = columns(&nested_loop_join.schema()); let mut batches = vec![]; for i in 0..partition_count { - let stream = nested_loop_join.execute(i, context.clone())?; + let stream = nested_loop_join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -1037,8 +1038,8 @@ mod tests { let task_ctx = Arc::new(task_ctx); let err = multi_partitioned_join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + 
Arc::clone(&right), &join_type, Some(filter.clone()), task_ctx, diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 91b2151d32e7..e9124a72970a 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -126,11 +126,11 @@ impl SortMergeJoinExec { .zip(sort_options.iter()) .map(|((l, r), sort_op)| { let left = PhysicalSortExpr { - expr: l.clone(), + expr: Arc::clone(l), options: *sort_op, }; let right = PhysicalSortExpr { - expr: r.clone(), + expr: Arc::clone(r), options: *sort_op, }; (left, right) @@ -140,7 +140,7 @@ impl SortMergeJoinExec { let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); let cache = - Self::compute_properties(&left, &right, schema.clone(), join_type, &on); + Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on); Ok(Self { left, right, @@ -271,8 +271,11 @@ impl ExecutionPlan for SortMergeJoinExec { } fn required_input_distribution(&self) -> Vec { - let (left_expr, right_expr) = - self.on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); vec![ Distribution::HashPartitioned(left_expr), Distribution::HashPartitioned(right_expr), @@ -304,8 +307,8 @@ impl ExecutionPlan for SortMergeJoinExec { ) -> Result> { match &children[..] { [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), self.on.clone(), self.filter.clone(), self.join_type, @@ -332,14 +335,24 @@ impl ExecutionPlan for SortMergeJoinExec { let (on_left, on_right) = self.on.iter().cloned().unzip(); let (streamed, buffered, on_streamed, on_buffered) = if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { - (self.left.clone(), self.right.clone(), on_left, on_right) + ( + Arc::clone(&self.left), + Arc::clone(&self.right), + on_left, + on_right, + ) } else { - (self.right.clone(), self.left.clone(), on_right, on_left) + ( + Arc::clone(&self.right), + Arc::clone(&self.left), + on_right, + on_left, + ) }; // execute children plans - let streamed = streamed.execute(partition, context.clone())?; - let buffered = buffered.execute(partition, context.clone())?; + let streamed = streamed.execute(partition, Arc::clone(&context))?; + let buffered = buffered.execute(partition, Arc::clone(&context))?; // create output buffer let batch_size = context.session_config().batch_size(); @@ -350,7 +363,7 @@ impl ExecutionPlan for SortMergeJoinExec { // create join stream Ok(Box::pin(SMJStream::try_new( - self.schema.clone(), + Arc::clone(&self.schema), self.sort_options.clone(), self.null_equals_null, streamed, @@ -374,8 +387,8 @@ impl ExecutionPlan for SortMergeJoinExec { // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` estimate_join_statistics( - self.left.clone(), - self.right.clone(), + Arc::clone(&self.left), + Arc::clone(&self.right), self.on.clone(), &self.join_type, &self.schema, @@ -657,7 +670,7 @@ struct SMJStream { impl RecordBatchStream for SMJStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -780,7 +793,7 @@ impl SMJStream { sort_options, null_equals_null, schema, - streamed_schema: streamed_schema.clone(), + streamed_schema: Arc::clone(&streamed_schema), buffered_schema, streamed, buffered, @@ 
-1233,7 +1246,7 @@ impl SMJStream { }; let output_batch = - RecordBatch::try_new(self.schema.clone(), columns.clone())?; + RecordBatch::try_new(Arc::clone(&self.schema), columns.clone())?; // Apply join filter if any if !filter_columns.is_empty() { @@ -1353,8 +1366,10 @@ impl SMJStream { }; // Push the streamed/buffered batch joined nulls to the output - let null_joined_streamed_batch = - RecordBatch::try_new(self.schema.clone(), columns.clone())?; + let null_joined_streamed_batch = RecordBatch::try_new( + Arc::clone(&self.schema), + columns.clone(), + )?; self.output_record_batches.push(null_joined_streamed_batch); // For full join, we also need to output the null joined rows from the buffered side. @@ -1430,14 +1445,14 @@ fn get_filter_column( .column_indices() .iter() .filter(|col_index| col_index.side == JoinSide::Left) - .map(|i| streamed_columns[i.index].clone()) + .map(|i| Arc::clone(&streamed_columns[i.index])) .collect::>(); let right_columns = f .column_indices() .iter() .filter(|col_index| col_index.side == JoinSide::Right) - .map(|i| buffered_columns[i.index].clone()) + .map(|i| Arc::clone(&buffered_columns[i.index])) .collect::>(); filter_columns.extend(left_columns); @@ -1476,7 +1491,7 @@ fn produce_buffered_null_batch( streamed_columns.extend(buffered_columns); Ok(Some(RecordBatch::try_new( - schema.clone(), + Arc::clone(schema), streamed_columns, )?)) } @@ -1927,7 +1942,7 @@ mod tests { Field::new(c.0, DataType::Int32, true), ])); let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(Int32Array::from(a.1.clone())), Arc::new(Int32Array::from(b.1.clone())), @@ -2771,8 +2786,8 @@ mod tests { let task_ctx = Arc::new(task_ctx); let join = join_with_options( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), join_type, sort_options.clone(), @@ -2849,8 +2864,8 @@ mod tests { .with_runtime(runtime); let task_ctx = Arc::new(task_ctx); let join = join_with_options( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), join_type, sort_options.clone(), diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 46d3ac5acf1e..ba9384aef1a6 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -270,7 +270,7 @@ pub fn convert_sort_expr_with_filter_schema( sort_expr: &PhysicalSortExpr, ) -> Result>> { let column_map = map_origin_col_to_filter_col(filter, schema, side)?; - let expr = sort_expr.expr.clone(); + let expr = Arc::clone(&sort_expr.expr); // Get main schema columns: let expr_columns = collect_columns(&expr); // Calculation is possible with `column_map` since sort exprs belong to a child. 
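Most hunks in this and the surrounding join modules are the same mechanical substitution: replacing method-call `.clone()` on reference-counted pointers with the explicit `Arc::clone(&...)` form, which the `clippy::clone_on_ref_ptr` lint added to `lib.rs` later in this patch enforces. A minimal sketch of the distinction, using a plain `Vec` as a stand-in for DataFusion's `SchemaRef` (which is an `Arc<Schema>`):

use std::sync::Arc;

fn main() {
    let schema = Arc::new(vec!["a", "b", "c"]); // stand-in for a SchemaRef
    // Method syntax compiles, but reads as if it might deep-copy the payload.
    let by_method = schema.clone();
    // The explicit form makes it obvious that only the reference count is bumped.
    let by_assoc_fn = Arc::clone(&schema);
    // Both handles point at the same allocation; there are now three owners.
    assert_eq!(Arc::strong_count(&schema), 3);
    drop((by_method, by_assoc_fn));
}

Both forms behave identically at runtime; the lint only makes cheap clones visible at the call site, which is the stated motivation in the issue linked from `lib.rs`.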
@@ -697,7 +697,7 @@ fn update_sorted_exprs_with_node_indices( // Extract filter expressions from the sorted expressions: let filter_exprs = sorted_exprs .iter() - .map(|expr| expr.filter_expr().clone()) + .map(|expr| Arc::clone(expr.filter_expr())) .collect::>(); // Gather corresponding node indices for the extracted filter expressions from the graph: @@ -756,7 +756,7 @@ pub fn prepare_sorted_exprs( // Build the expression interval graph let mut graph = - ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; + ExprIntervalGraph::try_new(Arc::clone(filter.expression()), filter.schema())?; // Update sorted expressions with node indices update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); @@ -818,9 +818,9 @@ pub mod tests { &intermediate_schema, )?; let filter_expr = binary( - filter_left.clone(), + Arc::clone(&filter_left), Operator::Gt, - filter_right.clone(), + Arc::clone(&filter_right), &intermediate_schema, )?; let column_indices = vec![ diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 813f670147bc..c23dc2032c4b 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -233,7 +233,7 @@ impl SymmetricHashJoinExec { let random_state = RandomState::with_seeds(0, 0, 0, 0); let schema = Arc::new(schema); let cache = - Self::compute_properties(&left, &right, schema.clone(), *join_type, &on); + Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type, &on); Ok(SymmetricHashJoinExec { left, right, @@ -397,7 +397,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { let (left_expr, right_expr) = self .on .iter() - .map(|(l, r)| (l.clone() as _, r.clone() as _)) + .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _)) .unzip(); vec![ Distribution::HashPartitioned(left_expr), @@ -430,8 +430,8 @@ impl ExecutionPlan for SymmetricHashJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(SymmetricHashJoinExec::try_new( - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), self.on.clone(), self.filter.clone(), &self.join_type, @@ -489,9 +489,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_side_joiner = OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema()); - let left_stream = self.left.execute(partition, context.clone())?; + let left_stream = self.left.execute(partition, Arc::clone(&context))?; - let right_stream = self.right.execute(partition, context.clone())?; + let right_stream = self.right.execute(partition, Arc::clone(&context))?; let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) @@ -559,7 +559,7 @@ struct SymmetricHashJoinStream { impl RecordBatchStream for SymmetricHashJoinStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1634,13 +1634,13 @@ mod tests { task_ctx: Arc, ) -> Result<()> { let first_batches = partitioned_sym_join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter.clone(), &join_type, false, - task_ctx.clone(), + Arc::clone(&task_ctx), ) .await?; let second_batches = partitioned_hash_join_with_filter( diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 7e05ded6f69d..264f297ffb4c 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ 
b/datafusion/physical-plan/src/joins/test_utils.rs @@ -78,17 +78,23 @@ pub async fn partitioned_sym_join_with_filter( ) -> Result> { let partition_count = 4; - let left_expr = on.iter().map(|(l, _)| l.clone() as _).collect::>(); + let left_expr = on + .iter() + .map(|(l, _)| Arc::clone(l) as _) + .collect::>(); - let right_expr = on.iter().map(|(_, r)| r.clone() as _).collect::>(); + let right_expr = on + .iter() + .map(|(_, r)| Arc::clone(r) as _) + .collect::>(); let join = SymmetricHashJoinExec::try_new( Arc::new(RepartitionExec::try_new( - left.clone(), + Arc::clone(&left), Partitioning::Hash(left_expr, partition_count), )?), Arc::new(RepartitionExec::try_new( - right.clone(), + Arc::clone(&right), Partitioning::Hash(right_expr, partition_count), )?), on, @@ -102,7 +108,7 @@ pub async fn partitioned_sym_join_with_filter( let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; + let stream = join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -127,7 +133,7 @@ pub async fn partitioned_hash_join_with_filter( let partition_count = 4; let (left_expr, right_expr) = on .iter() - .map(|(l, r)| (l.clone() as _, r.clone() as _)) + .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _)) .unzip(); let join = Arc::new(HashJoinExec::try_new( @@ -149,7 +155,7 @@ pub async fn partitioned_hash_join_with_filter( let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; + let stream = join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -475,20 +481,29 @@ pub fn build_sides_record_batches( )); let left = RecordBatch::try_from_iter(vec![ - ("la1", ordered.clone()), - ("lb1", cardinality.clone()), + ("la1", Arc::clone(&ordered)), + ("lb1", Arc::clone(&cardinality) as ArrayRef), ("lc1", cardinality_key_left), - ("lt1", time.clone()), - ("la2", ordered.clone()), - ("la1_des", ordered_des.clone()), - ("l_asc_null_first", ordered_asc_null_first.clone()), - ("l_asc_null_last", ordered_asc_null_last.clone()), - ("l_desc_null_first", ordered_desc_null_first.clone()), - ("li1", interval_time.clone()), - ("l_float", float_asc.clone()), + ("lt1", Arc::clone(&time) as ArrayRef), + ("la2", Arc::clone(&ordered)), + ("la1_des", Arc::clone(&ordered_des) as ArrayRef), + ( + "l_asc_null_first", + Arc::clone(&ordered_asc_null_first) as ArrayRef, + ), + ( + "l_asc_null_last", + Arc::clone(&ordered_asc_null_last) as ArrayRef, + ), + ( + "l_desc_null_first", + Arc::clone(&ordered_desc_null_first) as ArrayRef, + ), + ("li1", Arc::clone(&interval_time)), + ("l_float", Arc::clone(&float_asc) as ArrayRef), ])?; let right = RecordBatch::try_from_iter(vec![ - ("ra1", ordered.clone()), + ("ra1", Arc::clone(&ordered)), ("rb1", cardinality), ("rc1", cardinality_key_right), ("rt1", time), diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index dfa1fd4763f4..e3ec242ce8de 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -34,8 +34,9 @@ use arrow::array::{ UInt32BufferBuilder, UInt32Builder, UInt64Array, UInt64BufferBuilder, }; use arrow::compute; -use arrow::datatypes::{Field, Schema, SchemaBuilder}; +use arrow::datatypes::{Field, Schema, SchemaBuilder, UInt32Type, UInt64Type}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use arrow_array::builder::UInt64Builder; use 
arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray}; use arrow_buffer::ArrowNativeType; use datafusion_common::cast::as_boolean_array; @@ -439,7 +440,7 @@ pub fn adjust_right_output_partitioning( Partitioning::Hash(exprs, size) => { let new_exprs = exprs .iter() - .map(|expr| add_offset_to_expr(expr.clone(), left_columns_len)) + .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len)) .collect(); Partitioning::Hash(new_exprs, *size) } @@ -455,12 +456,10 @@ fn replace_on_columns_of_right_ordering( ) -> Result<()> { for (left_col, right_col) in on_columns { for item in right_ordering.iter_mut() { - let new_expr = item - .expr - .clone() + let new_expr = Arc::clone(&item.expr) .transform(|e| { if e.eq(right_col) { - Ok(Transformed::yes(left_col.clone())) + Ok(Transformed::yes(Arc::clone(left_col))) } else { Ok(Transformed::no(e)) } @@ -483,7 +482,7 @@ fn offset_ordering( JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering .iter() .map(|sort_expr| PhysicalSortExpr { - expr: add_offset_to_expr(sort_expr.expr.clone(), offset), + expr: add_offset_to_expr(Arc::clone(&sort_expr.expr), offset), options: sort_expr.options, }) .collect(), @@ -1121,7 +1120,7 @@ impl OnceFut { OnceFutState::Ready(r) => Poll::Ready( r.as_ref() .map(|r| r.as_ref()) - .map_err(|e| DataFusionError::External(Box::new(e.clone()))), + .map_err(|e| DataFusionError::External(Box::new(Arc::clone(e)))), ), } } @@ -1284,6 +1283,7 @@ pub(crate) fn adjust_indices_by_join_type( right_indices: UInt32Array, adjust_range: Range, join_type: JoinType, + preserve_order_for_right: bool, ) -> (UInt64Array, UInt32Array) { match join_type { JoinType::Inner => { @@ -1295,12 +1295,17 @@ pub(crate) fn adjust_indices_by_join_type( (left_indices, right_indices) // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap } - JoinType::Right | JoinType::Full => { - // matched - // unmatched right row will be produced in this batch - let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); + JoinType::Right => { // combine the matched and unmatched right result together - append_right_indices(left_indices, right_indices, right_unmatched_indices) + append_right_indices( + left_indices, + right_indices, + adjust_range, + preserve_order_for_right, + ) + } + JoinType::Full => { + append_right_indices(left_indices, right_indices, adjust_range, false) } JoinType::RightSemi => { // need to remove the duplicated record in the right side @@ -1326,30 +1331,48 @@ pub(crate) fn adjust_indices_by_join_type( } } -/// Appends the `right_unmatched_indices` to the `right_indices`, -/// and fills Null to tail of `left_indices` to -/// keep the length of `right_indices` and `left_indices` consistent. +/// Appends right indices to left indices based on the specified order mode. +/// +/// The function operates in two modes: +/// 1. If `preserve_order_for_right` is true, probe matched and unmatched indices +/// are inserted in order using the `append_probe_indices_in_order()` method. +/// 2. Otherwise, unmatched probe indices are simply appended after matched ones. +/// +/// # Parameters +/// - `left_indices`: UInt64Array of left indices. +/// - `right_indices`: UInt32Array of right indices. +/// - `adjust_range`: Range to adjust the right indices. +/// - `preserve_order_for_right`: Boolean flag to determine the mode of operation. +/// +/// # Returns +/// A tuple of updated `UInt64Array` and `UInt32Array`. 
pub(crate) fn append_right_indices( left_indices: UInt64Array, right_indices: UInt32Array, - right_unmatched_indices: UInt32Array, + adjust_range: Range, + preserve_order_for_right: bool, ) -> (UInt64Array, UInt32Array) { - // left_indices, right_indices and right_unmatched_indices must not contain the null value - if right_unmatched_indices.is_empty() { - (left_indices, right_indices) + if preserve_order_for_right { + append_probe_indices_in_order(left_indices, right_indices, adjust_range) } else { - let unmatched_size = right_unmatched_indices.len(); - // the new left indices: left_indices + null array - // the new right indices: right_indices + right_unmatched_indices - let new_left_indices = left_indices - .iter() - .chain(std::iter::repeat(None).take(unmatched_size)) - .collect::(); - let new_right_indices = right_indices - .iter() - .chain(right_unmatched_indices.iter()) - .collect::(); - (new_left_indices, new_right_indices) + let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); + + if right_unmatched_indices.is_empty() { + (left_indices, right_indices) + } else { + let unmatched_size = right_unmatched_indices.len(); + // the new left indices: left_indices + null array + // the new right indices: right_indices + right_unmatched_indices + let new_left_indices = left_indices + .iter() + .chain(std::iter::repeat(None).take(unmatched_size)) + .collect(); + let new_right_indices = right_indices + .iter() + .chain(right_unmatched_indices.iter()) + .collect(); + (new_left_indices, new_right_indices) + } } } @@ -1379,7 +1402,7 @@ where .filter_map(|idx| { (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) }) - .collect::>() + .collect() } /// Returns intersection of `range` and `input_indices` omitting duplicates @@ -1408,7 +1431,61 @@ where .filter_map(|idx| { (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) }) - .collect::>() + .collect() +} + +/// Appends probe indices in order by considering the given build indices. +/// +/// This function constructs new build and probe indices by iterating through +/// the provided indices, and appends any missing values between previous and +/// current probe index with a corresponding null build index. +/// +/// # Parameters +/// +/// - `build_indices`: `PrimitiveArray` of `UInt64Type` containing build indices. +/// - `probe_indices`: `PrimitiveArray` of `UInt32Type` containing probe indices. +/// - `range`: The range of indices to consider. +/// +/// # Returns +/// +/// A tuple of two arrays: +/// - A `PrimitiveArray` of `UInt64Type` with the newly constructed build indices. +/// - A `PrimitiveArray` of `UInt32Type` with the newly constructed probe indices. +fn append_probe_indices_in_order( + build_indices: PrimitiveArray, + probe_indices: PrimitiveArray, + range: Range, +) -> (PrimitiveArray, PrimitiveArray) { + // Builders for new indices: + let mut new_build_indices = UInt64Builder::new(); + let mut new_probe_indices = UInt32Builder::new(); + // Set previous index as the start index for the initial loop: + let mut prev_index = range.start as u32; + // Zip the two iterators. 
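    // Illustrative walk-through with hypothetical values: given
    //   build_indices = [0, 2], probe_indices = [1, 3], range = 0..5,
    // the loop below emits probe 0 with a null build index, the matched pair
    // (build 0, probe 1), probe 2 with a null build index, the matched pair
    // (build 2, probe 3), and the trailing fill loop emits probe 4 with a
    // null build index, producing
    //   build: [null, 0, null, 2, null]
    //   probe: [0, 1, 2, 3, 4]
    // i.e. every probe-side (right-side) row appears exactly once, in its
    // original order, which is what the `right_side_ordered` flag added to
    // `HashJoinStream` above relies on.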
+ debug_assert!(build_indices.len() == probe_indices.len()); + for (build_index, probe_index) in build_indices + .values() + .into_iter() + .zip(probe_indices.values().into_iter()) + { + // Append values between previous and current probe index with null build index: + for value in prev_index..*probe_index { + new_probe_indices.append_value(value); + new_build_indices.append_null(); + } + // Append current indices: + new_probe_indices.append_value(*probe_index); + new_build_indices.append_value(*build_index); + // Set current probe index as previous for the next iteration: + prev_index = probe_index + 1; + } + // Append remaining probe indices after the last valid probe index with null build index. + for value in prev_index..range.end as u32 { + new_probe_indices.append_value(value); + new_build_indices.append_null(); + } + // Build arrays and return: + (new_build_indices.finish(), new_probe_indices.finish()) } /// Metrics for build & probe joins @@ -2475,7 +2552,7 @@ mod tests { &on_columns, left_columns_len, maintains_input_order, - probe_side + probe_side, ), expected[i] ); diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index aef5b307968c..f3a709ff7670 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Traits for physical query plan, supporting parallel execution for partitioned relations. @@ -155,7 +157,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// Get the schema for this execution plan fn schema(&self) -> SchemaRef { - self.properties().schema().clone() + Arc::clone(self.properties().schema()) } /// Return properties of the output of the `ExecutionPlan`, such as output @@ -736,7 +738,7 @@ pub fn execute_stream( 1 => plan.execute(0, context), _ => { // merge into a single partition - let plan = CoalescePartitionsExec::new(plan.clone()); + let plan = CoalescePartitionsExec::new(Arc::clone(&plan)); // CoalescePartitionsExec must produce a single partition assert_eq!(1, plan.properties().output_partitioning().partition_count()); plan.execute(0, context) @@ -798,7 +800,7 @@ pub fn execute_stream_partitioned( let num_partitions = plan.output_partitioning().partition_count(); let mut streams = Vec::with_capacity(num_partitions); for i in 0..num_partitions { - streams.push(plan.execute(i, context.clone())?); + streams.push(plan.execute(i, Arc::clone(&context))?); } Ok(streams) } diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 4c6d1b3674d5..9c77a3d05cc2 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -145,7 +145,7 @@ impl ExecutionPlan for GlobalLimitExec { children: Vec>, ) -> Result> { Ok(Arc::new(GlobalLimitExec::new( - children[0].clone(), + Arc::clone(&children[0]), self.skip, self.fetch, ))) @@ -352,7 +352,7 @@ impl ExecutionPlan for LocalLimitExec { ) -> Result> { match children.len() { 1 => Ok(Arc::new(LocalLimitExec::new( - children[0].clone(), + Arc::clone(&children[0]), self.fetch, ))), _ => internal_err!("LocalLimitExec wrong number of children"), @@ -551,7 +551,7 @@ impl Stream for LimitStream { impl RecordBatchStream for LimitStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + 
Arc::clone(&self.schema) } } @@ -864,11 +864,11 @@ mod tests { // Adding a "GROUP BY i" changes the input stats from Exact to Inexact. let agg = AggregateExec::try_new( AggregateMode::Final, - build_group_by(&csv.schema().clone(), vec!["i".to_string()]), + build_group_by(&csv.schema(), vec!["i".to_string()]), vec![], vec![], - csv.clone(), - csv.schema().clone(), + Arc::clone(&csv), + Arc::clone(&csv.schema()), )?; let agg_exec: Arc = Arc::new(agg); diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 39ae8d551f4b..6b2c78902eae 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -140,7 +140,7 @@ impl ExecutionPlan for MemoryExec { ) -> Result { Ok(Box::pin(MemoryStream::try_new( self.partitions[partition].clone(), - self.projected_schema.clone(), + Arc::clone(&self.projected_schema), self.projection.clone(), )?)) } @@ -164,7 +164,8 @@ impl MemoryExec { projection: Option>, ) -> Result { let projected_schema = project_schema(&schema, projection.as_ref())?; - let cache = Self::compute_properties(projected_schema.clone(), &[], partitions); + let cache = + Self::compute_properties(Arc::clone(&projected_schema), &[], partitions); Ok(Self { partitions: partitions.to_vec(), schema, @@ -219,7 +220,7 @@ impl MemoryExec { } pub fn original_schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -305,7 +306,7 @@ impl Stream for MemoryStream { impl RecordBatchStream for MemoryStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 3b10cc0ac435..272211d5056e 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -50,7 +50,7 @@ impl PlaceholderRowExec { /// Create a new PlaceholderRowExec pub fn new(schema: SchemaRef) -> Self { let partitions = 1; - let cache = Self::compute_properties(schema.clone(), partitions); + let cache = Self::compute_properties(Arc::clone(&schema), partitions); PlaceholderRowExec { schema, partitions, @@ -160,7 +160,7 @@ impl ExecutionPlan for PlaceholderRowExec { Ok(Box::pin(MemoryStream::try_new( self.data()?, - self.schema.clone(), + Arc::clone(&self.schema), None, )?)) } @@ -188,7 +188,10 @@ mod tests { let placeholder = Arc::new(PlaceholderRowExec::new(schema)); - let placeholder_2 = with_new_children_if_necessary(placeholder.clone(), vec![])?; + let placeholder_2 = with_new_children_if_necessary( + Arc::clone(&placeholder) as Arc, + vec![], + )?; assert_eq!(placeholder.schema(), placeholder_2.schema()); let too_many_kids = vec![placeholder_2]; @@ -206,7 +209,7 @@ mod tests { let placeholder = PlaceholderRowExec::new(schema); // ask for the wrong partition - assert!(placeholder.execute(1, task_ctx.clone()).is_err()); + assert!(placeholder.execute(1, Arc::clone(&task_ctx)).is_err()); assert!(placeholder.execute(20, task_ctx).is_err()); Ok(()) } @@ -234,7 +237,7 @@ mod tests { let placeholder = PlaceholderRowExec::new(schema).with_partitions(partitions); for n in 0..partitions { - let iter = placeholder.execute(n, task_ctx.clone())?; + let iter = placeholder.execute(n, Arc::clone(&task_ctx))?; let batches = common::collect(iter).await?; // should have one item diff --git 
a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 8341549340dd..9efa0422ec75 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -94,7 +94,7 @@ impl ProjectionExec { // construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; let cache = - Self::compute_properties(&input, &projection_mapping, schema.clone())?; + Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?; Ok(Self { expr, schema, @@ -227,8 +227,8 @@ impl ExecutionPlan for ProjectionExec { ) -> Result { trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); Ok(Box::pin(ProjectionStream { - schema: self.schema.clone(), - expr: self.expr.iter().map(|x| x.0.clone()).collect(), + schema: Arc::clone(&self.schema), + expr: self.expr.iter().map(|x| Arc::clone(&x.0)).collect(), input: self.input.execute(partition, context)?, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) @@ -242,7 +242,7 @@ impl ExecutionPlan for ProjectionExec { Ok(stats_projection( self.input.statistics()?, self.expr.iter().map(|(e, _)| Arc::clone(e)), - self.schema.clone(), + Arc::clone(&self.schema), )) } } @@ -311,10 +311,10 @@ impl ProjectionStream { if arrays.is_empty() { let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - RecordBatch::try_new_with_options(self.schema.clone(), arrays, &options) + RecordBatch::try_new_with_options(Arc::clone(&self.schema), arrays, &options) .map_err(Into::into) } else { - RecordBatch::try_new(self.schema.clone(), arrays).map_err(Into::into) + RecordBatch::try_new(Arc::clone(&self.schema), arrays).map_err(Into::into) } } } @@ -351,7 +351,7 @@ impl Stream for ProjectionStream { impl RecordBatchStream for ProjectionStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -370,10 +370,12 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let exec = test::scan_partitioned(1); - let expected = collect(exec.execute(0, task_ctx.clone())?).await.unwrap(); + let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?) 
+ .await + .unwrap(); let projection = ProjectionExec::try_new(vec![], exec)?; - let stream = projection.execute(0, task_ctx.clone())?; + let stream = projection.execute(0, Arc::clone(&task_ctx))?; let output = collect(stream).await.unwrap(); assert_eq!(output.len(), expected.len()); diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 9a0b66caba31..bd9303f97db0 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -82,7 +82,7 @@ impl RecursiveQueryExec { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new()); // Use the same work table for both the WorkTableExec and the recursive term - let recursive_term = assign_work_table(recursive_term, work_table.clone())?; + let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?; let cache = Self::compute_properties(static_term.schema()); Ok(RecursiveQueryExec { name, @@ -147,8 +147,8 @@ impl ExecutionPlan for RecursiveQueryExec { ) -> Result> { RecursiveQueryExec::try_new( self.name.clone(), - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), self.is_distinct, ) .map(|e| Arc::new(e) as _) @@ -167,12 +167,12 @@ impl ExecutionPlan for RecursiveQueryExec { ))); } - let static_stream = self.static_term.execute(partition, context.clone())?; + let static_stream = self.static_term.execute(partition, Arc::clone(&context))?; let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(RecursiveQueryStream::new( context, - self.work_table.clone(), - self.recursive_term.clone(), + Arc::clone(&self.work_table), + Arc::clone(&self.recursive_term), static_stream, baseline_metrics, ))) @@ -313,9 +313,9 @@ impl RecursiveQueryStream { // Downstream plans should not expect any partitioning. 
let partition = 0; - let recursive_plan = reset_plan_states(self.recursive_term.clone())?; + let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?; self.recursive_stream = - Some(recursive_plan.execute(partition, self.task_context.clone())?); + Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?); self.poll_next(cx) } } @@ -334,7 +334,7 @@ fn assign_work_table( } else { work_table_refs += 1; Ok(Transformed::yes(Arc::new( - exec.with_work_table(work_table.clone()), + exec.with_work_table(Arc::clone(&work_table)), ))) } } else if plan.as_any().is::() { @@ -358,8 +358,7 @@ fn reset_plan_states(plan: Arc) -> Result() { Ok(Transformed::no(plan)) } else { - let new_plan = plan - .clone() + let new_plan = Arc::clone(&plan) .with_new_children(plan.children().into_iter().cloned().collect())?; Ok(Transformed::yes(new_plan)) } @@ -407,7 +406,7 @@ impl Stream for RecursiveQueryStream { impl RecordBatchStream for RecursiveQueryStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index d9e16c98eee8..3d4d3058393e 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -133,12 +133,12 @@ impl RepartitionExecState { let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics); let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( - input.clone(), + Arc::clone(&input), i, txs.clone(), partitioning.clone(), r_metrics, - context.clone(), + Arc::clone(&context), )); // In a separate task, wait for each input to be done @@ -616,7 +616,7 @@ impl ExecutionPlan for RepartitionExec { schema: Arc::clone(&schema_captured), receiver, drop_helper: Arc::clone(&abort_helper), - reservation: reservation.clone(), + reservation: Arc::clone(&reservation), }) as SendableRecordBatchStream }) .collect::>(); @@ -866,7 +866,7 @@ impl RepartitionExec { for (_, tx) in txs { // wrap it because need to send error to all output partitions - let err = Err(DataFusionError::External(Box::new(e.clone()))); + let err = Err(DataFusionError::External(Box::new(Arc::clone(&e)))); tx.send(Some(err)).await.ok(); } } @@ -945,7 +945,7 @@ impl Stream for RepartitionStream { impl RecordBatchStream for RepartitionStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -995,7 +995,7 @@ impl Stream for PerPartitionStream { impl RecordBatchStream for PerPartitionStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1117,14 +1117,14 @@ mod tests { ) -> Result>> { let task_ctx = Arc::new(TaskContext::default()); // create physical plan - let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; + let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; // execute and collect results let mut output_partitions = vec![]; for i in 0..exec.partitioning.partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, task_ctx.clone())?; + let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -1301,10 +1301,14 @@ mod tests { let input = Arc::new(make_barrier_exec()); // partition into two 
output streams - let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); + let exec = RepartitionExec::try_new( + Arc::clone(&input) as Arc, + partitioning, + ) + .unwrap(); - let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap(); - let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); + let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap(); + let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced @@ -1349,8 +1353,12 @@ mod tests { // We first collect the results without droping the output stream. let input = Arc::new(make_barrier_exec()); - let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); - let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); + let exec = RepartitionExec::try_new( + Arc::clone(&input) as Arc, + partitioning.clone(), + ) + .unwrap(); + let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); let mut background_task = JoinSet::new(); background_task.spawn(async move { input.wait().await; @@ -1370,9 +1378,13 @@ mod tests { // Now do the same but dropping the stream before waiting for the barrier let input = Arc::new(make_barrier_exec()); - let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap(); - let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); + let exec = RepartitionExec::try_new( + Arc::clone(&input) as Arc, + partitioning, + ) + .unwrap(); + let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap(); + let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced std::mem::drop(output_stream0); @@ -1471,9 +1483,9 @@ mod tests { let schema = batch.schema(); let input = MockExec::new(vec![Ok(batch)], schema); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap(); + let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap(); let batch0 = crate::common::collect(output_stream0).await.unwrap(); - let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); + let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); let batch1 = crate::common::collect(output_stream1).await.unwrap(); assert!(batch0.is_empty() || batch1.is_empty()); Ok(()) @@ -1496,12 +1508,12 @@ mod tests { let task_ctx = Arc::new(task_ctx); // create physical plan - let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; + let exec = MemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; // pull partitions for i in 0..exec.partitioning.partition_count() { - let mut stream = exec.execute(i, task_ctx.clone())?; + let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; let err = arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into()); let err = err.find_root(); @@ -1642,7 +1654,7 @@ mod test { } fn memory_exec(schema: &SchemaRef) -> Arc { - Arc::new(MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap()) + Arc::new(MemoryExec::try_new(&[vec![]], Arc::clone(schema), None).unwrap()) } fn sorted_memory_exec( @@ -1650,7 +1662,7 @@ mod test { sort_exprs: Vec, ) -> Arc { Arc::new( - MemoryExec::try_new(&[vec![]], schema.clone(), None) + MemoryExec::try_new(&[vec![]], Arc::clone(schema), None) 
.unwrap() .with_sort_information(vec![sort_exprs]), ) diff --git a/datafusion/physical-plan/src/sorts/builder.rs b/datafusion/physical-plan/src/sorts/builder.rs index 3527d5738223..d32c60697ec8 100644 --- a/datafusion/physical-plan/src/sorts/builder.rs +++ b/datafusion/physical-plan/src/sorts/builder.rs @@ -20,6 +20,7 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; +use std::sync::Arc; #[derive(Debug, Copy, Clone, Default)] struct BatchCursor { @@ -145,6 +146,9 @@ impl BatchBuilder { retain }); - Ok(Some(RecordBatch::try_new(self.schema.clone(), columns)?)) + Ok(Some(RecordBatch::try_new( + Arc::clone(&self.schema), + columns, + )?)) } } diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 422ff3aebdb3..85418ff36119 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -29,6 +29,7 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; use futures::Stream; use std::pin::Pin; +use std::sync::Arc; use std::task::{ready, Context, Poll}; /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] @@ -324,6 +325,6 @@ impl Stream for SortPreservingMergeStream { impl RecordBatchStream for SortPreservingMergeStream { fn schema(&self) -> SchemaRef { - self.in_progress.schema().clone() + Arc::clone(self.in_progress.schema()) } } diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index ad5d485cffc9..fe6b744935fb 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -260,7 +260,7 @@ impl ExecutionPlan for PartialSortExec { ) -> Result> { let new_partial_sort = PartialSortExec::new( self.expr.clone(), - children[0].clone(), + Arc::clone(&children[0]), self.common_prefix_length, ) .with_fetch(self.fetch) @@ -276,7 +276,7 @@ impl ExecutionPlan for PartialSortExec { ) -> Result { trace!("Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); - let input = self.input.execute(partition, context.clone())?; + let input = self.input.execute(partition, Arc::clone(&context))?; trace!( "End PartialSortExec's input.execute for partition: {}", @@ -485,11 +485,11 @@ mod tests { options: option_asc, }, ], - source.clone(), + Arc::clone(&source), 2, )) as Arc; - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; let expected_after_sort = [ "+---+---+---+", @@ -549,13 +549,13 @@ mod tests { options: option_asc, }, ], - source.clone(), + Arc::clone(&source), common_prefix_length, ) .with_fetch(Some(4)), ) as Arc; - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; let expected_after_sort = [ "+---+---+---+", @@ -621,11 +621,11 @@ mod tests { options: option_asc, }, ], - source.clone(), + Arc::clone(source), common_prefix_length, )); - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!(2, result.len()); assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), @@ -676,7 +676,7 @@ mod tests { Arc::new( MemoryExec::try_new( &[vec![batch1, batch2, batch3, 
batch4]], - schema.clone(), + Arc::clone(&schema), None, ) .unwrap(), @@ -711,7 +711,7 @@ mod tests { options: option_asc, }, ], - mem_exec.clone(), + Arc::clone(&mem_exec), 1, ); let partial_sort_exec = @@ -720,7 +720,7 @@ mod tests { partial_sort_executor.expr, partial_sort_executor.input, )) as Arc; - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), [0, 125, 125, 0, 150] @@ -732,7 +732,7 @@ mod tests { "The sort should have returned all memory used back to the memory manager" ); let partial_sort_result = concat_batches(&schema, &result).unwrap(); - let sort_result = collect(sort_exec, task_ctx.clone()).await?; + let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!(sort_result[0], partial_sort_result); Ok(()) @@ -772,7 +772,7 @@ mod tests { options: option_asc, }, ], - mem_exec.clone(), + Arc::clone(&mem_exec), 1, ) .with_fetch(fetch_size); @@ -783,7 +783,7 @@ mod tests { SortExec::new(partial_sort_executor.expr, partial_sort_executor.input) .with_fetch(fetch_size), ) as Arc; - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), expected_batch_num_rows @@ -795,7 +795,7 @@ mod tests { "The sort should have returned all memory used back to the memory manager" ); let partial_sort_result = concat_batches(&schema, &result)?; - let sort_result = collect(sort_exec, task_ctx.clone()).await?; + let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!(sort_result[0], partial_sort_result); } @@ -822,8 +822,12 @@ mod tests { let data: ArrayRef = Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::()); - let batch = RecordBatch::try_new(schema.clone(), vec![data])?; - let input = Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None)?); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?; + let input = Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?); let partial_sort_exec = Arc::new(PartialSortExec::new( vec![PhysicalSortExpr { @@ -837,13 +841,13 @@ mod tests { let result: Vec = collect(partial_sort_exec, task_ctx).await?; let expected_batch = vec![ RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new( vec![1, 1].into_iter().map(Some).collect::(), )], )?, RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new( vec![2].into_iter().map(Some).collect::(), )], @@ -879,7 +883,7 @@ mod tests { // define data. 
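
Many of these test hunks hand the same `SchemaRef` to both a `RecordBatch` and the exec under test. A small sketch, using only the arrow crate, of why passing `Arc::clone(&schema)` is cheap; the column values are illustrative:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema: SchemaRef =
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // The batch receives its own handle to the same Schema allocation; nothing is
    // rebuilt or deep-copied when the Arc is cloned.
    let data: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?;

    // Both handles point at the identical Schema.
    assert!(Arc::ptr_eq(&schema, &batch.schema()));
    Ok(())
}
```
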
let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(Float32Array::from(vec![ Some(1.0_f32), @@ -961,8 +965,11 @@ mod tests { *partial_sort_exec.schema().field(2).data_type() ); - let result: Vec = - collect(partial_sort_exec.clone(), task_ctx).await?; + let result: Vec = collect( + Arc::clone(&partial_sort_exec) as Arc, + task_ctx, + ) + .await?; assert_batches_eq!(expected, &result); assert_eq!(result.len(), 2); let metrics = partial_sort_exec.metrics().unwrap(); @@ -997,7 +1004,7 @@ mod tests { 1, )); - let fut = collect(sort_exec, task_ctx.clone()); + let fut = collect(sort_exec, Arc::clone(&task_ctx)); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 0bf66bc6e522..f347a0f5b6d5 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -338,13 +338,13 @@ impl ExternalSorter { spill.path() ))); } - let stream = read_spill_as_stream(spill, self.schema.clone(), 2)?; + let stream = read_spill_as_stream(spill, Arc::clone(&self.schema), 2)?; streams.push(stream); } streaming_merge( streams, - self.schema.clone(), + Arc::clone(&self.schema), &self.expr, self.metrics.baseline.clone(), self.batch_size, @@ -354,7 +354,9 @@ impl ExternalSorter { } else if !self.in_mem_batches.is_empty() { self.in_mem_sort_stream(self.metrics.baseline.clone()) } else { - Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) + Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone( + &self.schema, + )))) } } @@ -394,8 +396,11 @@ impl ExternalSorter { let spill_file = self.runtime.disk_manager.create_tmp_file("Sorting")?; let batches = std::mem::take(&mut self.in_mem_batches); - let spilled_rows = - spill_record_batches(batches, spill_file.path().into(), self.schema.clone())?; + let spilled_rows = spill_record_batches( + batches, + spill_file.path().into(), + Arc::clone(&self.schema), + )?; let used = self.reservation.free(); self.metrics.spill_count.add(1); self.metrics.spilled_bytes.add(used); @@ -525,7 +530,7 @@ impl ExternalSorter { streaming_merge( streams, - self.schema.clone(), + Arc::clone(&self.schema), &self.expr, metrics, self.batch_size, @@ -548,7 +553,7 @@ impl ExternalSorter { let schema = batch.schema(); let fetch = self.fetch; - let expressions = self.expr.clone(); + let expressions = Arc::clone(&self.expr); let stream = futures::stream::once(futures::future::lazy(move |_| { let sorted = sort_batch(&batch, &expressions, fetch)?; metrics.record_output(sorted.num_rows()); @@ -844,7 +849,7 @@ impl ExecutionPlan for SortExec { self: Arc, children: Vec>, ) -> Result> { - let new_sort = SortExec::new(self.expr.clone(), children[0].clone()) + let new_sort = SortExec::new(self.expr.clone(), Arc::clone(&children[0])) .with_fetch(self.fetch) .with_preserve_partitioning(self.preserve_partitioning); @@ -858,7 +863,7 @@ impl ExecutionPlan for SortExec { ) -> Result { trace!("Start SortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); - let mut input = self.input.execute(partition, context.clone())?; + let mut input = self.input.execute(partition, Arc::clone(&context))?; let execution_options = &context.session_config().options().execution; @@ -962,7 +967,7 @@ mod tests { Arc::new(CoalescePartitionsExec::new(csv)), )); - let result = collect(sort_exec, task_ctx.clone()).await?; + let result = collect(sort_exec, 
Arc::clone(&task_ctx)).await?; assert_eq!(result.len(), 1); assert_eq!(result[0].num_rows(), 400); @@ -1005,7 +1010,11 @@ mod tests { Arc::new(CoalescePartitionsExec::new(input)), )); - let result = collect(sort_exec.clone(), task_ctx.clone()).await?; + let result = collect( + Arc::clone(&sort_exec) as Arc, + Arc::clone(&task_ctx), + ) + .await?; assert_eq!(result.len(), 2); @@ -1081,7 +1090,11 @@ mod tests { .with_fetch(fetch), ); - let result = collect(sort_exec.clone(), task_ctx.clone()).await?; + let result = collect( + Arc::clone(&sort_exec) as Arc, + Arc::clone(&task_ctx), + ) + .await?; assert_eq!(result.len(), 1); let metrics = sort_exec.metrics().unwrap(); @@ -1111,9 +1124,10 @@ mod tests { let data: ArrayRef = Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); - let batch = RecordBatch::try_new(schema.clone(), vec![data]).unwrap(); - let input = - Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap()); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data]).unwrap(); + let input = Arc::new( + MemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None).unwrap(), + ); let sort_exec = Arc::new(SortExec::new( vec![PhysicalSortExpr { @@ -1128,7 +1142,7 @@ mod tests { let expected_data: ArrayRef = Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); let expected_batch = - RecordBatch::try_new(schema.clone(), vec![expected_data]).unwrap(); + RecordBatch::try_new(Arc::clone(&schema), vec![expected_data]).unwrap(); // Data is correct assert_eq!(&vec![expected_batch], &result); @@ -1154,7 +1168,7 @@ mod tests { // define data. let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![Some(2), None, Some(1), Some(2)])), Arc::new(ListArray::from_iter_primitive::(vec![ @@ -1183,7 +1197,11 @@ mod tests { }, }, ], - Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None)?), + Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?), )); assert_eq!(DataType::Int32, *sort_exec.schema().field(0).data_type()); @@ -1192,7 +1210,8 @@ mod tests { *sort_exec.schema().field(1).data_type() ); - let result: Vec = collect(sort_exec.clone(), task_ctx).await?; + let result: Vec = + collect(Arc::clone(&sort_exec) as Arc, task_ctx).await?; let metrics = sort_exec.metrics().unwrap(); assert!(metrics.elapsed_compute().unwrap() > 0); assert_eq!(metrics.output_rows().unwrap(), 4); @@ -1226,7 +1245,7 @@ mod tests { // define data. 
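
`ExternalSorter::in_mem_sort_stream` above clones the `Arc` of sort expressions before moving it into the lazily evaluated stream, so the field on `self` remains usable afterwards. A standalone sketch of that clone-before-move pattern; a plain thread and an `Arc<[&str]>` stand in for the async task and the sort-expression list:

```rust
use std::sync::Arc;
use std::thread;

fn main() {
    // Stand-in for the shared sort-expression list held by the sorter; the element
    // type here is illustrative only.
    let sort_keys: Arc<[&str]> = Arc::from(vec!["a", "b"]);

    // Clone the handle *before* the `move`, so the original stays usable afterwards.
    let keys_for_task = Arc::clone(&sort_keys);
    let worker = thread::spawn(move || keys_for_task.len());

    assert_eq!(worker.join().unwrap(), 2);
    // The field-like handle is still alive after the task consumed its own clone.
    assert_eq!(sort_keys.len(), 2);
}
```
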
let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(Float32Array::from(vec![ Some(f32::NAN), @@ -1274,7 +1293,8 @@ mod tests { assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type()); assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type()); - let result: Vec = collect(sort_exec.clone(), task_ctx).await?; + let result: Vec = + collect(Arc::clone(&sort_exec) as Arc, task_ctx).await?; let metrics = sort_exec.metrics().unwrap(); assert!(metrics.elapsed_compute().unwrap() > 0); assert_eq!(metrics.output_rows().unwrap(), 8); @@ -1337,7 +1357,7 @@ mod tests { blocking_exec, )); - let fut = collect(sort_exec, task_ctx.clone()); + let fut = collect(sort_exec, Arc::clone(&task_ctx)); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1358,7 +1378,8 @@ mod tests { let schema = Arc::new(Schema::empty()); let options = RecordBatchOptions::new().with_row_count(Some(1)); let batch = - RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); + RecordBatch::try_new_with_options(Arc::clone(&schema), vec![], &options) + .unwrap(); let expressions = vec![PhysicalSortExpr { expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index e364aca3791c..41dfd449dd82 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -188,7 +188,7 @@ impl ExecutionPlan for SortPreservingMergeExec { children: Vec>, ) -> Result> { Ok(Arc::new( - SortPreservingMergeExec::new(self.expr.clone(), children[0].clone()) + SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0])) .with_fetch(self.fetch), )) } @@ -232,7 +232,8 @@ impl ExecutionPlan for SortPreservingMergeExec { _ => { let receivers = (0..input_partitions) .map(|partition| { - let stream = self.input.execute(partition, context.clone())?; + let stream = + self.input.execute(partition, Arc::clone(&context))?; Ok(spawn_buffered(stream, 1)) }) .collect::>()?; @@ -587,8 +588,9 @@ mod tests { }, }]; - let basic = basic_sort(csv.clone(), sort.clone(), task_ctx.clone()).await; - let partition = partition_sort(csv, sort, task_ctx.clone()).await; + let basic = + basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await; + let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -654,10 +656,11 @@ mod tests { }]; let input = - sorted_partitioned_input(sort.clone(), &[10, 3, 11], task_ctx.clone()) + sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx)) .await?; - let basic = basic_sort(input.clone(), sort.clone(), task_ctx.clone()).await; - let partition = sorted_merge(input, sort, task_ctx.clone()).await; + let basic = + basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await; + let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await; assert_eq!(basic.num_rows(), 1200); assert_eq!(partition.num_rows(), 1200); @@ -685,9 +688,9 @@ mod tests { // Test streaming with default batch size let task_ctx = Arc::new(TaskContext::default()); let input = - sorted_partitioned_input(sort.clone(), &[10, 5, 13], task_ctx.clone()) + sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx)) .await?; - let basic = basic_sort(input.clone(), sort.clone(), task_ctx).await; + let basic = 
basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await; // batch size of 23 let task_ctx = TaskContext::default() @@ -805,17 +808,18 @@ mod tests { }]; let batches = - sorted_partitioned_input(sort.clone(), &[5, 7, 3], task_ctx.clone()).await?; + sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx)) + .await?; let partition_count = batches.output_partitioning().partition_count(); let mut streams = Vec::with_capacity(partition_count); for partition in 0..partition_count { - let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1); + let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1); let sender = builder.tx(); - let mut stream = batches.execute(partition, task_ctx.clone()).unwrap(); + let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap(); builder.spawn(async move { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); @@ -849,7 +853,7 @@ mod tests { assert_eq!(merged.len(), 1); let merged = merged.remove(0); - let basic = basic_sort(batches, sort.clone(), task_ctx.clone()).await; + let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -885,7 +889,9 @@ mod tests { let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge.clone(), task_ctx).await.unwrap(); + let collected = collect(Arc::clone(&merge) as Arc, task_ctx) + .await + .unwrap(); let expected = [ "+----+---+", "| a | b |", diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index 135b4fbdece4..c7924edfb1eb 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -109,7 +109,7 @@ impl RowCursorStream { Ok(Self { converter, reservation, - column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), + column_expressions: expressions.iter().map(|x| Arc::clone(&x.expr)).collect(), streams: FusedStreams(streams), }) } diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 99d9367740be..faeb4799f5af 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -382,7 +382,7 @@ where S: Stream>, { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -402,7 +402,7 @@ impl EmptyRecordBatchStream { impl RecordBatchStream for EmptyRecordBatchStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -474,7 +474,7 @@ mod test { let schema = schema(); let num_partitions = 10; - let input = PanicExec::new(schema.clone(), num_partitions); + let input = PanicExec::new(Arc::clone(&schema), num_partitions); consume(input, 10).await } @@ -485,7 +485,7 @@ mod test { // make 2 partitions, second partition panics before the first let num_partitions = 2; - let input = PanicExec::new(schema.clone(), num_partitions) + let input = PanicExec::new(Arc::clone(&schema), num_partitions) .with_partition_panic(0, 10) .with_partition_panic(1, 3); // partition 1 should panic first (after 3 ) @@ -504,12 +504,12 @@ mod test { let schema = schema(); // Make an input that never proceeds - let input = BlockingExec::new(schema.clone(), 1); + let input = BlockingExec::new(Arc::clone(&schema), 1); let refs = input.refs(); // Configure a RecordBatchReceiverStream to 
consume the input let mut builder = RecordBatchReceiverStream::builder(schema, 2); - builder.run_input(Arc::new(input), 0, task_ctx.clone()); + builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx)); let stream = builder.build(); // input should still be present @@ -529,12 +529,14 @@ mod test { let schema = schema(); // make an input that will error twice - let error_stream = - MockExec::new(vec![exec_err!("Test1"), exec_err!("Test2")], schema.clone()) - .with_use_task(false); + let error_stream = MockExec::new( + vec![exec_err!("Test1"), exec_err!("Test2")], + Arc::clone(&schema), + ) + .with_use_task(false); let mut builder = RecordBatchReceiverStream::builder(schema, 2); - builder.run_input(Arc::new(error_stream), 0, task_ctx.clone()); + builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx)); let mut stream = builder.build(); // get the first result, which should be an error @@ -560,7 +562,11 @@ mod test { let mut builder = RecordBatchReceiverStream::builder(input.schema(), num_partitions); for partition in 0..num_partitions { - builder.run_input(input.clone(), partition, task_ctx.clone()); + builder.run_input( + Arc::clone(&input) as Arc, + partition, + Arc::clone(&task_ctx), + ); } let mut stream = builder.build(); diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index ff57adde4e2e..5a9035c8dbfc 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -93,7 +93,7 @@ impl StreamingTableExec { let projected_output_ordering = projected_output_ordering.into_iter().collect::>(); let cache = Self::compute_properties( - projected_schema.clone(), + Arc::clone(&projected_schema), &projected_output_ordering, &partitions, infinite, @@ -240,7 +240,7 @@ impl ExecutionPlan for StreamingTableExec { let stream = self.partitions[partition].execute(ctx); let projected_stream = match self.projection.clone() { Some(projection) => Box::pin(RecordBatchStreamAdapter::new( - self.projected_schema.clone(), + Arc::clone(&self.projected_schema), stream.map(move |x| { x.and_then(|b| b.project(projection.as_ref()).map_err(Into::into)) }), @@ -327,7 +327,7 @@ mod test { /// Set the batches for the stream fn with_batches(mut self, batches: Vec) -> Self { let stream = TestPartitionStream::new_with_batches(batches); - self.schema = Some(stream.schema().clone()); + self.schema = Some(Arc::clone(stream.schema())); self.partitions = vec![Arc::new(stream)]; self } diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index 377b919bb407..f5b4a096018f 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -144,6 +144,9 @@ impl PartitionStream for TestPartitionStream { } fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { let stream = futures::stream::iter(self.batches.clone().into_iter().map(Ok)); - Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream)) + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + stream, + )) } } diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index ad47a484c9f7..ac4eb1ca9e58 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -133,7 +133,7 @@ impl MockExec { /// ensure any poll loops are correct. 
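
The `stream.rs` tests above drive `RecordBatchReceiverStream`: a builder spawns producer tasks that push batches into a channel, and the consumer reads them back as an ordinary `Stream`; dropping the output stream is what makes the producers stop early. A simplified stand-in built from plain `tokio` and `tokio-stream` channels (assumed as dev-dependencies), just to show the shape rather than the DataFusion API:

```rust
use tokio::sync::mpsc;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};

#[tokio::main]
async fn main() {
    // One producer task pushes results into a bounded channel; the consumer sees an
    // ordinary `Stream`. `RecordBatchReceiverStream` wraps the same idea around
    // `RecordBatch` values produced by an input plan.
    let (tx, rx) = mpsc::channel::<Result<i64, String>>(2);

    tokio::spawn(async move {
        for v in 0..3 {
            // If the consumer drops its stream, `send` fails and the producer stops,
            // which is the early-shutdown behaviour the tests above exercise by
            // dropping `output_stream0` before polling it.
            if tx.send(Ok(v)).await.is_err() {
                break;
            }
        }
    });

    let mut stream = ReceiverStream::new(rx);
    while let Some(item) = stream.next().await {
        println!("{item:?}");
    }
}
```
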
This behavior can be /// changed with `with_use_task` pub fn new(data: Vec>, schema: SchemaRef) -> Self { - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(Arc::clone(&schema)); Self { data, schema, @@ -294,7 +294,7 @@ impl BarrierExec { pub fn new(data: Vec>, schema: SchemaRef) -> Self { // wait for all streams and the input let barrier = Arc::new(Barrier::new(data.len() + 1)); - let cache = Self::compute_properties(schema.clone(), &data); + let cache = Self::compute_properties(Arc::clone(&schema), &data); Self { data, schema, @@ -374,7 +374,7 @@ impl ExecutionPlan for BarrierExec { // task simply sends data in order after barrier is reached let data = self.data[partition].clone(); - let b = self.barrier.clone(); + let b = Arc::clone(&self.barrier); let tx = builder.tx(); builder.spawn(async move { println!("Partition {partition} waiting on barrier"); @@ -421,7 +421,7 @@ impl ErrorExec { DataType::Int64, true, )])); - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(schema); Self { cache } } @@ -591,7 +591,7 @@ pub struct BlockingExec { impl BlockingExec { /// Create new [`BlockingExec`] with a give schema and number of partitions. pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { - let cache = Self::compute_properties(schema.clone(), n_partitions); + let cache = Self::compute_properties(Arc::clone(&schema), n_partitions); Self { schema, refs: Default::default(), @@ -735,7 +735,7 @@ impl PanicExec { /// partitions, which will each panic immediately. pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { let batches_until_panics = vec![0; n_partitions]; - let cache = Self::compute_properties(schema.clone(), &batches_until_panics); + let cache = Self::compute_properties(Arc::clone(&schema), &batches_until_panics); Self { schema, batches_until_panics, @@ -845,7 +845,7 @@ impl Stream for PanicStream { if self.ready { self.batches_until_panic -= 1; self.ready = false; - let batch = RecordBatch::new_empty(self.schema.clone()); + let batch = RecordBatch::new_empty(Arc::clone(&self.schema)); return Poll::Ready(Some(Ok(batch))); } else { self.ready = true; diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 6a77bfaf3ccd..5366a5707696 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -131,7 +131,7 @@ impl TopK { ); Ok(Self { - schema: schema.clone(), + schema: Arc::clone(&schema), metrics: TopKMetrics::new(metrics, partition), reservation, batch_size, @@ -355,7 +355,7 @@ impl TopKHeap { /// high, as a single [`RecordBatch`], and a sorted vec of the /// current heap's contents pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec)> { - let schema = self.store.schema().clone(); + let schema = Arc::clone(self.store.schema()); // generate sorted rows let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec(); diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 1570778be69b..96bd0de3d37c 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -62,7 +62,7 @@ impl PlanContext { } pub fn update_plan_from_children(mut self) -> Result { - let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); + let children_plans = self.children.iter().map(|c| Arc::clone(&c.plan)).collect(); self.plan = with_new_children_if_necessary(self.plan, children_plans)?; Ok(self) diff 
--git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 3f88eb4c3732..b39c6aee82b9 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -188,7 +188,7 @@ fn calculate_union_eq_properties( // TODO: Check whether constant expressions evaluates the same value or not for each partition let across_partitions = false; return Some( - ConstExpr::new(meet_constant.owned_expr()) + ConstExpr::from(meet_constant.owned_expr()) .with_across_partitions(across_partitions), ); } @@ -449,7 +449,7 @@ impl ExecutionPlan for InterleaveExec { let mut input_stream_vec = vec![]; for input in self.inputs.iter() { if partition < input.output_partitioning().partition_count() { - input_stream_vec.push(input.execute(partition, context.clone())?); + input_stream_vec.push(input.execute(partition, Arc::clone(&context))?); } else { // Do not find a partition to execute break; @@ -550,7 +550,7 @@ impl CombinedRecordBatchStream { impl RecordBatchStream for CombinedRecordBatchStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -657,7 +657,7 @@ mod tests { in_data .iter() .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), + expr: Arc::clone(*expr), options: *options, }) .collect::>() @@ -842,11 +842,11 @@ mod tests { .map(|ordering| convert_to_sort_exprs(ordering)) .collect::>(); let child1 = Arc::new( - MemoryExec::try_new(&[], schema.clone(), None)? + MemoryExec::try_new(&[], Arc::clone(&schema), None)? .with_sort_information(first_orderings), ); let child2 = Arc::new( - MemoryExec::try_new(&[], schema.clone(), None)? + MemoryExec::try_new(&[], Arc::clone(&schema), None)? .with_sort_information(second_orderings), ); diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index e072b214fd36..bdd56f4b5aa4 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -83,7 +83,7 @@ impl UnnestExec { schema: SchemaRef, options: UnnestOptions, ) -> Self { - let cache = Self::compute_properties(&input, schema.clone()); + let cache = Self::compute_properties(&input, Arc::clone(&schema)); UnnestExec { input, @@ -147,10 +147,10 @@ impl ExecutionPlan for UnnestExec { children: Vec>, ) -> Result> { Ok(Arc::new(UnnestExec::new( - children[0].clone(), + Arc::clone(&children[0]), self.list_column_indices.clone(), self.struct_column_indices.clone(), - self.schema.clone(), + Arc::clone(&self.schema), self.options.clone(), ))) } @@ -169,7 +169,7 @@ impl ExecutionPlan for UnnestExec { Ok(Box::pin(UnnestStream { input, - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), list_type_columns: self.list_column_indices.clone(), struct_column_indices: self.struct_column_indices.iter().copied().collect(), options: self.options.clone(), @@ -237,7 +237,7 @@ struct UnnestStream { impl RecordBatchStream for UnnestStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -330,13 +330,13 @@ fn flatten_struct_cols( data_type ), }, - None => Ok(vec![column_data.clone()]), + None => Ok(vec![Arc::clone(column_data)]), }) .collect::>>()? .into_iter() .flatten() .collect(); - Ok(RecordBatch::try_new(schema.clone(), columns_expanded)?) + Ok(RecordBatch::try_new(Arc::clone(schema), columns_expanded)?) } /// For each row in a `RecordBatch`, some list/struct columns need to be unnested. 
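
The unnest changes that follow revolve around `find_longest_length`: for each input row, the number of output rows is the longest list length across the unnested columns, and a NULL list counts as one row only when nulls are preserved, matching the `verify_longest_length` expectations in the tests below. A scalar sketch of that row-wise rule, independent of Arrow arrays:

```rust
// Row-wise rule: an input row expands into as many output rows as its longest list.
// A NULL list counts as a single row only when nulls are preserved; an empty list
// always contributes zero.
fn longest_length(row_lists: &[Option<usize>], preserve_nulls: bool) -> usize {
    row_lists
        .iter()
        .map(|len| match len {
            Some(n) => *n,
            None if preserve_nulls => 1,
            None => 0,
        })
        .max()
        .unwrap_or(0)
}

fn main() {
    // One row, two list columns: [A, B, C] and NULL.
    assert_eq!(longest_length(&[Some(3), None], false), 3);
    // Every column NULL: the row disappears unless nulls are preserved.
    assert_eq!(longest_length(&[None, None], false), 0);
    assert_eq!(longest_length(&[None, None], true), 1);
    // Empty lists never add rows, preserved nulls or not.
    assert_eq!(longest_length(&[Some(0), Some(0)], true), 0);
}
```
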
@@ -357,7 +357,7 @@ fn build_batch( let list_arrays: Vec = list_type_columns .iter() .map(|index| { - ColumnarValue::Array(batch.column(*index).clone()) + ColumnarValue::Array(Arc::clone(batch.column(*index))) .into_array(batch.num_rows()) }) .collect::>()?; @@ -372,7 +372,7 @@ fn build_batch( })? as usize }; if total_length == 0 { - return Ok(RecordBatch::new_empty(schema.clone())); + return Ok(RecordBatch::new_empty(Arc::clone(schema))); } // Unnest all the list arrays @@ -444,7 +444,7 @@ fn find_longest_length( .collect::>()?; let longest_length = list_lengths.iter().skip(1).try_fold( - list_lengths[0].clone(), + Arc::clone(&list_lengths[0]), |longest, current| { let is_lt = lt(&longest, ¤t)?; zip(&is_lt, ¤t, &longest) @@ -649,7 +649,7 @@ fn flatten_list_cols_from_indices( .iter() .enumerate() .map(|(col_idx, arr)| match unnested_list_arrays.get(&col_idx) { - Some(unnested_array) => Ok(unnested_array.clone()), + Some(unnested_array) => Ok(Arc::clone(unnested_array)), None => Ok(kernels::take::take(arr, indices, None)?), }) .collect::>>()?; @@ -813,27 +813,27 @@ mod tests { // Test with single ListArray // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = Arc::new(make_generic_array::()) as ArrayRef; - verify_longest_length(&[list_array.clone()], false, vec![3, 0, 0, 1, 0, 2])?; - verify_longest_length(&[list_array.clone()], true, vec![3, 0, 1, 1, 1, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], false, vec![3, 0, 0, 1, 0, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], true, vec![3, 0, 1, 1, 1, 2])?; // Test with single LargeListArray // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = Arc::new(make_generic_array::()) as ArrayRef; - verify_longest_length(&[list_array.clone()], false, vec![3, 0, 0, 1, 0, 2])?; - verify_longest_length(&[list_array.clone()], true, vec![3, 0, 1, 1, 1, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], false, vec![3, 0, 0, 1, 0, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], true, vec![3, 0, 1, 1, 1, 2])?; // Test with single FixedSizeListArray // [A, B], NULL, [C, D], NULL, [NULL, F], [NULL, NULL] let list_array = Arc::new(make_fixed_list()) as ArrayRef; - verify_longest_length(&[list_array.clone()], false, vec![2, 0, 2, 0, 2, 2])?; - verify_longest_length(&[list_array.clone()], true, vec![2, 1, 2, 1, 2, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], false, vec![2, 0, 2, 0, 2, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], true, vec![2, 1, 2, 1, 2, 2])?; // Test with multiple list arrays // [A, B, C], [], NULL, [D], NULL, [NULL, F] // [A, B], NULL, [C, D], NULL, [NULL, F], [NULL, NULL] let list1 = Arc::new(make_generic_array::()) as ArrayRef; let list2 = Arc::new(make_fixed_list()) as ArrayRef; - let list_arrays = vec![list1.clone(), list2.clone()]; + let list_arrays = vec![Arc::clone(&list1), Arc::clone(&list2)]; verify_longest_length(&list_arrays, false, vec![3, 0, 2, 1, 2, 2])?; verify_longest_length(&list_arrays, true, vec![3, 1, 2, 1, 2, 2])?; diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 4d385812d4a8..3ea27d62d80b 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -88,7 +88,7 @@ impl ValuesExec { .and_then(ScalarValue::iter_to_array) }) .collect::>>()?; - let batch = RecordBatch::try_new(schema.clone(), arr)?; + let batch = RecordBatch::try_new(Arc::clone(&schema), arr)?; let data: Vec = vec![batch]; Self::try_new_from_batches(schema, data) } @@ -114,7 +114,7 @@ impl 
ValuesExec { } } - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(Arc::clone(&schema)); Ok(ValuesExec { schema, data: batches, @@ -175,7 +175,7 @@ impl ExecutionPlan for ValuesExec { self: Arc, _: Vec>, ) -> Result> { - ValuesExec::try_new_from_batches(self.schema.clone(), self.data.clone()) + ValuesExec::try_new_from_batches(Arc::clone(&self.schema), self.data.clone()) .map(|e| Arc::new(e) as _) } @@ -193,7 +193,7 @@ impl ExecutionPlan for ValuesExec { Ok(Box::pin(MemoryStream::try_new( self.data(), - self.schema.clone(), + Arc::clone(&self.schema), None, )?)) } @@ -260,7 +260,7 @@ mod tests { DataType::UInt32, false, )])); - let _ = ValuesExec::try_new(schema.clone(), vec![vec![lit(1u32)]]).unwrap(); + let _ = ValuesExec::try_new(Arc::clone(&schema), vec![vec![lit(1u32)]]).unwrap(); // Test that a null value is rejected let _ = ValuesExec::try_new(schema, vec![vec![lit(ScalarValue::UInt32(None))]]) .unwrap_err(); diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 9eb29891703e..6311107f7b58 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -289,7 +289,7 @@ impl ExecutionPlan for BoundedWindowAggExec { ) -> Result> { Ok(Arc::new(BoundedWindowAggExec::try_new( self.window_expr.clone(), - children[0].clone(), + Arc::clone(&children[0]), self.partition_keys.clone(), self.input_order_mode.clone(), )?)) @@ -303,7 +303,7 @@ impl ExecutionPlan for BoundedWindowAggExec { let input = self.input.execute(partition, context)?; let search_mode = self.get_search_algo()?; let stream = Box::pin(BoundedWindowAggStream::new( - self.schema.clone(), + Arc::clone(&self.schema), self.window_expr.clone(), input, BaselineMetrics::new(&self.metrics, partition), @@ -394,7 +394,9 @@ trait PartitionSearcher: Send { // as it may not have the "correct" schema in terms of output // nullability constraints. 
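
The adjacent comment explains why `PartitionBatchState` is built from the operator's input schema rather than the schema reported by an incoming batch: the two can agree on types yet disagree on nullability. A tiny arrow-only illustration of that mismatch:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // The plan-level schema declares the column nullable...
    let plan_schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)]));
    // ...while a particular batch may report the same column as non-nullable.
    let batch_schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));

    // Arrow treats these as different schemas, so long-lived per-partition state is
    // keyed off the operator's input schema instead of whichever batch arrives first.
    assert_ne!(plan_schema, batch_schema);
    assert!(plan_schema.field(0).is_nullable());
    assert!(!batch_schema.field(0).is_nullable());
}
```
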
For details, see the following issue: // https://github.com/apache/datafusion/issues/9320 - .or_insert_with(|| PartitionBatchState::new(self.input_schema().clone())); + .or_insert_with(|| { + PartitionBatchState::new(Arc::clone(self.input_schema())) + }); partition_batch_state.extend(&partition_batch)?; } @@ -513,7 +515,7 @@ impl PartitionSearcher for LinearSearch { let length = indices.len(); for (idx, window_agg_state) in window_agg_states.iter().enumerate() { let partition = &window_agg_state[&row]; - let values = partition.state.out_col.slice(0, length).clone(); + let values = Arc::clone(&partition.state.out_col.slice(0, length)); new_columns[idx].push(values); } let partition_batch_state = &mut partition_buffers[&row]; @@ -935,7 +937,7 @@ impl BoundedWindowAggStream { search_mode: Box, ) -> Result { let state = window_expr.iter().map(|_| IndexMap::new()).collect(); - let empty_batch = RecordBatch::new_empty(schema.clone()); + let empty_batch = RecordBatch::new_empty(Arc::clone(&schema)); Ok(Self { schema, input, @@ -957,7 +959,7 @@ impl BoundedWindowAggStream { cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?; } - let schema = self.schema.clone(); + let schema = Arc::clone(&self.schema); let window_expr_out = self.search_mode.calculate_out_columns( &self.input_buffer, &self.window_agg_states, @@ -1114,7 +1116,7 @@ impl BoundedWindowAggStream { impl RecordBatchStream for BoundedWindowAggStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1287,7 +1289,7 @@ mod tests { impl RecordBatchStream for TestStreamPartition { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1467,7 +1469,7 @@ mod tests { } let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(schema), vec![Arc::new(sn1_array.finish()), Arc::new(hash_array.finish())], )?; batches.push(batch); @@ -1500,7 +1502,7 @@ mod tests { // Source has 2 partitions let partitions = vec![ Arc::new(TestStreamPartition { - schema: schema.clone(), + schema: Arc::clone(&schema), batches: batches.clone(), idx: 0, state: PolingState::BatchReturn, @@ -1510,7 +1512,7 @@ mod tests { n_partition ]; let source = Arc::new(StreamingTableExec::try_new( - schema.clone(), + Arc::clone(&schema), partitions, None, orderings, @@ -1533,28 +1535,38 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); // Create a new batch of data to insert into the table let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], )?; let memory_exec = MemoryExec::try_new( &[vec![batch.clone(), batch.clone(), batch.clone()]], - schema.clone(), + Arc::clone(&schema), None, ) .map(|e| Arc::new(e) as Arc)?; let col_a = col("a", &schema)?; - let nth_value_func1 = - NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1, false)? - .reverse_expr() - .unwrap(); - let nth_value_func2 = - NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2, false)? - .reverse_expr() - .unwrap(); + let nth_value_func1 = NthValue::nth( + "nth_value(-1)", + Arc::clone(&col_a), + DataType::Int32, + 1, + false, + )? + .reverse_expr() + .unwrap(); + let nth_value_func2 = NthValue::nth( + "nth_value(-2)", + Arc::clone(&col_a), + DataType::Int32, + 2, + false, + )? 
+ .reverse_expr() + .unwrap(); let last_value_func = Arc::new(NthValue::last( "last", - col_a.clone(), + Arc::clone(&col_a), DataType::Int32, false, )) as _; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 252c8d12b519..7f794556a241 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -64,11 +64,11 @@ pub fn schema_add_window_field( ) -> Result> { let data_types = args .iter() - .map(|e| e.clone().as_ref().data_type(schema)) + .map(|e| Arc::clone(e).as_ref().data_type(schema)) .collect::>>()?; let nullability = args .iter() - .map(|e| e.clone().as_ref().nullable(schema)) + .map(|e| Arc::clone(e).as_ref().nullable(schema)) .collect::>>()?; let window_expr_return_type = window_fn.return_type(&data_types, &nullability)?; let mut window_fields = schema @@ -220,6 +220,24 @@ fn get_scalar_value_from_args( }) } +fn get_signed_integer(value: ScalarValue) -> Result { + if !value.data_type().is_integer() { + return Err(DataFusionError::Execution( + "Expected an integer value".to_string(), + )); + } + value.cast_to(&DataType::Int64)?.try_into() +} + +fn get_unsigned_integer(value: ScalarValue) -> Result { + if !value.data_type().is_integer() { + return Err(DataFusionError::Execution( + "Expected an integer value".to_string(), + )); + } + value.cast_to(&DataType::UInt64)?.try_into() +} + fn get_casted_value( default_value: Option, dtype: &DataType, @@ -259,10 +277,10 @@ fn create_built_in_window_expr( } if n.is_unsigned() { - let n: u64 = n.try_into()?; + let n = get_unsigned_integer(n)?; Arc::new(Ntile::new(name, n, out_data_type)) } else { - let n: i64 = n.try_into()?; + let n: i64 = get_signed_integer(n)?; if n <= 0 { return exec_err!("NTILE requires a positive integer"); } @@ -270,10 +288,10 @@ fn create_built_in_window_expr( } } BuiltInWindowFunction::Lag => { - let arg = args[0].clone(); + let arg = Arc::clone(&args[0]); let shift_offset = get_scalar_value_from_args(args, 1)? - .map(|v| v.try_into()) - .and_then(|v| v.ok()); + .map(get_signed_integer) + .map_or(Ok(None), |v| v.map(Some))?; let default_value = get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; Arc::new(lag( @@ -286,10 +304,10 @@ fn create_built_in_window_expr( )) } BuiltInWindowFunction::Lead => { - let arg = args[0].clone(); + let arg = Arc::clone(&args[0]); let shift_offset = get_scalar_value_from_args(args, 1)? - .map(|v| v.try_into()) - .and_then(|v| v.ok()); + .map(get_signed_integer) + .map_or(Ok(None), |v| v.map(Some))?; let default_value = get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; Arc::new(lead( @@ -302,12 +320,15 @@ fn create_built_in_window_expr( )) } BuiltInWindowFunction::NthValue => { - let arg = args[0].clone(); - let n = args[1].as_any().downcast_ref::().unwrap().value(); - let n: i64 = n - .clone() - .try_into() - .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + let arg = Arc::clone(&args[0]); + let n = get_signed_integer( + args[1] + .as_any() + .downcast_ref::() + .unwrap() + .value() + .clone(), + )?; Arc::new(NthValue::nth( name, arg, @@ -317,7 +338,7 @@ fn create_built_in_window_expr( )?) 
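
The new `get_signed_integer` / `get_unsigned_integer` helpers above normalise whatever integer type a window-function argument arrives as (the NTILE bucket count, LAG/LEAD offsets, the N in NTH_VALUE) instead of requiring an exact `Int64` or `UInt64` scalar. A small usage sketch that restates the signed variant and checks both the accept and reject paths; it assumes the `datafusion-common` and `arrow` crates:

```rust
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result, ScalarValue};

// Same shape as the private helper added above: accept any integer ScalarValue,
// reject everything else, and normalise the result to i64 through a cast.
fn to_signed(value: ScalarValue) -> Result<i64> {
    if !value.data_type().is_integer() {
        return Err(DataFusionError::Execution(
            "Expected an integer value".to_string(),
        ));
    }
    value.cast_to(&DataType::Int64)?.try_into()
}

fn main() -> Result<()> {
    // An Int32 literal, e.g. a LAG/LEAD offset, is accepted and widened.
    assert_eq!(to_signed(ScalarValue::Int32(Some(3)))?, 3);
    // A non-integer argument is rejected up front with a clear error.
    assert!(to_signed(ScalarValue::Utf8(Some("3".into()))).is_err());
    Ok(())
}
```
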
} BuiltInWindowFunction::FirstValue => { - let arg = args[0].clone(); + let arg = Arc::clone(&args[0]); Arc::new(NthValue::first( name, arg, @@ -326,7 +347,7 @@ fn create_built_in_window_expr( )) } BuiltInWindowFunction::LastValue => { - let arg = args[0].clone(); + let arg = Arc::clone(&args[0]); Arc::new(NthValue::last( name, arg, @@ -408,13 +429,16 @@ pub(crate) fn calc_requirements< let mut sort_reqs = partition_by_exprs .into_iter() .map(|partition_by| { - PhysicalSortRequirement::new(partition_by.borrow().clone(), None) + PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None) }) .collect::>(); for element in orderby_sort_exprs.into_iter() { let PhysicalSortExpr { expr, options } = element.borrow(); if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { - sort_reqs.push(PhysicalSortRequirement::new(expr.clone(), Some(*options))); + sort_reqs.push(PhysicalSortRequirement::new( + Arc::clone(expr), + Some(*options), + )); } } // Convert empty result to None. Otherwise wrap result inside Some() @@ -443,7 +467,7 @@ pub(crate) fn get_partition_by_sort_exprs( ) -> Result { let ordered_partition_exprs = ordered_partition_by_indices .iter() - .map(|idx| partition_by_exprs[*idx].clone()) + .map(|idx| Arc::clone(&partition_by_exprs[*idx])) .collect::>(); // Make sure ordered section doesn't move over the partition by expression assert!(ordered_partition_by_indices.len() <= partition_by_exprs.len()); @@ -464,7 +488,7 @@ pub(crate) fn window_equivalence_properties( ) -> EquivalenceProperties { // We need to update the schema, so we can not directly use // `input.equivalence_properties()`. - let mut window_eq_properties = EquivalenceProperties::new(schema.clone()) + let mut window_eq_properties = EquivalenceProperties::new(Arc::clone(schema)) .extend(input.equivalence_properties().clone()); for expr in window_expr { @@ -535,7 +559,7 @@ pub fn get_best_fitting_window( if window_expr.iter().all(|e| e.uses_bounded_memory()) { Ok(Some(Arc::new(BoundedWindowAggExec::try_new( window_expr, - input.clone(), + Arc::clone(input), physical_partition_keys.to_vec(), input_order_mode, )?) as _)) @@ -548,7 +572,7 @@ pub fn get_best_fitting_window( } else { Ok(Some(Arc::new(WindowAggExec::try_new( window_expr, - input.clone(), + Arc::clone(input), physical_partition_keys.to_vec(), )?) as _)) } @@ -573,13 +597,11 @@ pub fn get_window_mode( let mut partition_by_reqs: Vec = vec![]; let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); partition_by_reqs.extend(indices.iter().map(|&idx| PhysicalSortRequirement { - expr: partitionby_exprs[idx].clone(), + expr: Arc::clone(&partitionby_exprs[idx]), options: None, })); // Treat partition by exprs as constant. During analysis of requirements are satisfied. 
- let const_exprs = partitionby_exprs - .iter() - .map(|expr| ConstExpr::new(expr.clone())); + let const_exprs = partitionby_exprs.iter().map(ConstExpr::from); let partition_by_eqs = input_eqs.add_constants(const_exprs); let order_by_reqs = PhysicalSortRequirement::from_sort_exprs(orderby_keys); let reverse_order_by_reqs = @@ -618,7 +640,6 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use futures::FutureExt; - use InputOrderMode::{Linear, PartiallySorted, Sorted}; fn create_test_schema() -> Result { @@ -676,7 +697,7 @@ mod tests { let sort_exprs = sort_exprs.into_iter().collect(); Ok(Arc::new(StreamingTableExec::try_new( - schema.clone(), + Arc::clone(schema), vec![], None, Some(sort_exprs), diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index eb01da2ec094..b6330f65e0b7 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -79,7 +79,7 @@ impl WindowAggExec { let ordered_partition_by_indices = get_ordered_partition_by_indices(window_expr[0].partition_by(), &input); - let cache = Self::compute_properties(schema.clone(), &input, &window_expr); + let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr); Ok(Self { input, window_expr, @@ -220,7 +220,7 @@ impl ExecutionPlan for WindowAggExec { ) -> Result> { Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), - children[0].clone(), + Arc::clone(&children[0]), self.partition_keys.clone(), )?)) } @@ -232,7 +232,7 @@ impl ExecutionPlan for WindowAggExec { ) -> Result { let input = self.input.execute(partition, context)?; let stream = Box::pin(WindowAggStream::new( - self.schema.clone(), + Arc::clone(&self.schema), self.window_expr.clone(), input, BaselineMetrics::new(&self.metrics, partition), @@ -333,7 +333,7 @@ impl WindowAggStream { let _timer = self.baseline_metrics.elapsed_compute().timer(); let batch = concat_batches(&self.input.schema(), &self.batches)?; if batch.num_rows() == 0 { - return Ok(RecordBatch::new_empty(self.schema.clone())); + return Ok(RecordBatch::new_empty(Arc::clone(&self.schema))); } let partition_by_sort_keys = self @@ -366,7 +366,10 @@ impl WindowAggStream { let mut batch_columns = batch.columns().to_vec(); // calculate window cols batch_columns.extend_from_slice(&columns); - Ok(RecordBatch::try_new(self.schema.clone(), batch_columns)?) + Ok(RecordBatch::try_new( + Arc::clone(&self.schema), + batch_columns, + )?) } } @@ -412,6 +415,6 @@ impl WindowAggStream { impl RecordBatchStream for WindowAggStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index 003957947fec..5f3cf6e2aee8 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -110,7 +110,7 @@ pub struct WorkTableExec { impl WorkTableExec { /// Create a new execution plan for a worktable exec. 
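
The `WindowAggStream` hunk above builds its output by reusing the buffered input columns and appending the computed window columns before reconstructing a batch against the operator's output schema. A self-contained arrow sketch of that append-columns step; the `row_number` column and its values are invented for illustration:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Input batch with a single column `a`.
    let input_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&input_schema),
        vec![Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef],
    )?;

    // Pretend this is a computed window column (a running count here).
    let window_col: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2, 3]));

    // Output schema = input fields plus the window field: reuse the input columns
    // and append the new ones, then rebuild the batch against the wider schema.
    let output_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("row_number", DataType::Int64, false),
    ]));
    let mut columns = batch.columns().to_vec();
    columns.push(window_col);
    let output = RecordBatch::try_new(Arc::clone(&output_schema), columns)?;

    assert_eq!(output.num_columns(), 2);
    assert_eq!(output.num_rows(), 3);
    Ok(())
}
```
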
pub fn new(name: String, schema: SchemaRef) -> Self { - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(Arc::clone(&schema)); Self { name, schema, @@ -123,7 +123,7 @@ impl WorkTableExec { pub(super) fn with_work_table(&self, work_table: Arc) -> Self { Self { name: self.name.clone(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), metrics: ExecutionPlanMetricsSet::new(), work_table, cache: self.cache.clone(), @@ -185,7 +185,7 @@ impl ExecutionPlan for WorkTableExec { self: Arc, _: Vec>, ) -> Result> { - Ok(self.clone()) + Ok(Arc::clone(&self) as Arc) } /// Stream the batches that were written to the work table. @@ -202,7 +202,7 @@ impl ExecutionPlan for WorkTableExec { } let batch = self.work_table.take()?; Ok(Box::pin( - MemoryStream::try_new(batch.batches, self.schema.clone(), None)? + MemoryStream::try_new(batch.batches, Arc::clone(&self.schema), None)? .with_reservation(batch.reservation), )) } diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto index 225bb9ddf661..e2a405595fb7 100644 --- a/datafusion/proto-common/proto/datafusion_common.proto +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -164,7 +164,7 @@ message Union{ repeated int32 type_ids = 3; } -// Used for List/FixedSizeList/LargeList/Struct +// Used for List/FixedSizeList/LargeList/Struct/Map message ScalarNestedValue { message Dictionary { bytes ipc_message = 1; @@ -266,6 +266,7 @@ message ScalarValue{ ScalarNestedValue list_value = 17; ScalarNestedValue fixed_size_list_value = 18; ScalarNestedValue struct_value = 32; + ScalarNestedValue map_value = 41; Decimal128 decimal128_value = 20; Decimal256 decimal256_value = 39; diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index de9fede9ee86..df673de4e119 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -380,7 +380,8 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::ListValue(v) | Value::FixedSizeListValue(v) | Value::LargeListValue(v) - | Value::StructValue(v) => { + | Value::StructValue(v) + | Value::MapValue(v) => { let protobuf::ScalarNestedValue { ipc_message, arrow_data, @@ -479,6 +480,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::StructValue(_) => { Self::Struct(arr.as_struct().to_owned().into()) } + Value::MapValue(_) => Self::Map(arr.as_map().to_owned().into()), _ => unreachable!(), } } diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs index 3cf34aeb6d01..be3cc58b23df 100644 --- a/datafusion/proto-common/src/generated/pbjson.rs +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -6409,6 +6409,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::StructValue(v) => { struct_ser.serialize_field("structValue", v)?; } + scalar_value::Value::MapValue(v) => { + struct_ser.serialize_field("mapValue", v)?; + } scalar_value::Value::Decimal128Value(v) => { struct_ser.serialize_field("decimal128Value", v)?; } @@ -6525,6 +6528,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "fixedSizeListValue", "struct_value", "structValue", + "map_value", + "mapValue", "decimal128_value", "decimal128Value", "decimal256_value", @@ -6586,6 +6591,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { ListValue, FixedSizeListValue, StructValue, + MapValue, Decimal128Value, Decimal256Value, Date64Value, @@ -6646,6 
+6652,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "listValue" | "list_value" => Ok(GeneratedField::ListValue), "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "structValue" | "struct_value" => Ok(GeneratedField::StructValue), + "mapValue" | "map_value" => Ok(GeneratedField::MapValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), @@ -6816,6 +6823,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("structValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue) +; + } + GeneratedField::MapValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("mapValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::MapValue) ; } GeneratedField::Decimal128Value => { diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs index 57893321e665..b0674ff28d75 100644 --- a/datafusion/proto-common/src/generated/prost.rs +++ b/datafusion/proto-common/src/generated/prost.rs @@ -184,7 +184,7 @@ pub struct Union { #[prost(int32, repeated, tag = "3")] pub type_ids: ::prost::alloc::vec::Vec, } -/// Used for List/FixedSizeList/LargeList/Struct +/// Used for List/FixedSizeList/LargeList/Struct/Map #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarNestedValue { @@ -326,7 +326,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -380,6 +380,8 @@ pub mod scalar_value { FixedSizeListValue(super::ScalarNestedValue), #[prost(message, tag = "32")] StructValue(super::ScalarNestedValue), + #[prost(message, tag = "41")] + MapValue(super::ScalarNestedValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] @@ -581,7 +583,7 @@ pub struct CsvWriterOptions { /// Optional escape. Defaults to `'\\'` #[prost(string, tag = "10")] pub escape: ::prost::alloc::string::String, - /// Optional flag whether to double quote instead of escaping. Defaults to `true` + /// Optional flag whether to double quotes, instead of escaping. 
Defaults to `true` #[prost(bool, tag = "11")] pub double_quote: bool, } diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 877043f66809..705a479e0178 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -364,6 +364,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { ScalarValue::Struct(arr) => { encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) } + ScalarValue::Map(arr) => { + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + } ScalarValue::Date32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) } @@ -938,7 +941,7 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } -// ScalarValue::List / FixedSizeList / LargeList / Struct are serialized using +// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using // Arrow IPC messages as a single column RecordBatch fn encode_scalar_nested_value( arr: ArrayRef, @@ -992,6 +995,9 @@ fn encode_scalar_nested_value( scalar_list_value, )), }), + ScalarValue::Map(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::MapValue(scalar_list_value)), + }), _ => unreachable!(), } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index ce6c0c53c3fc..345765b08be3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -500,7 +500,7 @@ enum AggregateFunction { // REGR_SYY = 33; // REGR_SXY = 34; // STRING_AGG = 35; - NTH_VALUE_AGG = 36; + // NTH_VALUE_AGG = 36; } message AggregateExprNode { diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 901aa2455e16..83210cb4e41f 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -39,6 +39,7 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; +use datafusion_expr::planner::UserDefinedSQLPlanner; mod registry; @@ -165,6 +166,10 @@ impl Serializeable for Expr { "register_udwf called in Placeholder Registry!" 
) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 4bf2bb3d7b79..075993e2ba76 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -20,6 +20,7 @@ use std::{collections::HashSet, sync::Arc}; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::plan_err; use datafusion_common::Result; +use datafusion_expr::planner::UserDefinedSQLPlanner; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; /// A default [`FunctionRegistry`] registry that does not resolve any @@ -54,4 +55,8 @@ impl FunctionRegistry for NoRegistry { fn register_udwf(&mut self, udwf: Arc) -> Result>> { plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Window Function '{}'", udwf.inner().name()) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs index 875fe8992e90..b0674ff28d75 100644 --- a/datafusion/proto/src/generated/datafusion_proto_common.rs +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -184,7 +184,7 @@ pub struct Union { #[prost(int32, repeated, tag = "3")] pub type_ids: ::prost::alloc::vec::Vec, } -/// Used for List/FixedSizeList/LargeList/Struct +/// Used for List/FixedSizeList/LargeList/Struct/Map #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarNestedValue { @@ -326,7 +326,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -380,6 +380,8 @@ pub mod scalar_value { FixedSizeListValue(super::ScalarNestedValue), #[prost(message, tag = "32")] StructValue(super::ScalarNestedValue), + #[prost(message, tag = "41")] + MapValue(super::ScalarNestedValue), #[prost(message, tag = "20")] Decimal128Value(super::Decimal128), #[prost(message, tag = "39")] diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 347654e52b73..905f0d984955 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,7 +535,6 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::ArrayAgg => "ARRAY_AGG", - Self::NthValueAgg => "NTH_VALUE_AGG", }; serializer.serialize_str(variant) } @@ -550,7 +549,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "ARRAY_AGG", - "NTH_VALUE_AGG", ]; struct GeneratedVisitor; @@ -594,7 +592,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), - "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c74f172482b7..b16d26ee6e1e 100644 --- 
a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1924,7 +1924,7 @@ pub enum AggregateFunction { /// AVG = 3; /// COUNT = 4; /// APPROX_DISTINCT = 5; - ArrayAgg = 6, + /// /// VARIANCE = 7; /// VARIANCE_POP = 8; /// COVARIANCE = 9; @@ -1952,7 +1952,8 @@ pub enum AggregateFunction { /// REGR_SYY = 33; /// REGR_SXY = 34; /// STRING_AGG = 35; - NthValueAgg = 36, + /// NTH_VALUE_AGG = 36; + ArrayAgg = 6, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -1964,7 +1965,6 @@ impl AggregateFunction { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", AggregateFunction::ArrayAgg => "ARRAY_AGG", - AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1973,7 +1973,6 @@ impl AggregateFunction { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), "ARRAY_AGG" => Some(Self::ArrayAgg), - "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f4fb69280436..a58af8afdd04 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -145,7 +145,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Min => Self::Min, protobuf::AggregateFunction::Max => Self::Max, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, - protobuf::AggregateFunction::NthValueAgg => Self::NthValue, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7570040a1d08..d8f8ea002b2d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -117,7 +117,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Min => Self::Min, AggregateFunction::Max => Self::Max, AggregateFunction::ArrayAgg => Self::ArrayAgg, - AggregateFunction::NthValue => Self::NthValueAgg, } } } @@ -377,9 +376,6 @@ pub fn serialize_expr( AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::NthValue => { - protobuf::AggregateFunction::NthValueAgg - } }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 23cdc666e701..5e982ad2afde 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -25,8 +25,8 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ArrayAgg, BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, DistinctArrayAgg, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, - NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, - TryCastExpr, WindowShift, + NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, + WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -255,8 +255,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Min } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Max - } 
else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::NthValueAgg } else { return not_impl_err!("Aggregate function not supported: {expr:?}"); }; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 5fc3a9a8a197..f764a050a6cd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -21,7 +21,9 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use std::vec; -use arrow::array::{ArrayRef, FixedSizeListArray}; +use arrow::array::{ + ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, +}; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, @@ -1247,6 +1249,31 @@ fn round_trip_scalar_values() { ), ]))) .unwrap(), + ScalarValue::try_from(&DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, true), + Field::new("value", DataType::Utf8, false), + ])), + false, + )), + false, + )) + .unwrap(), + ScalarValue::try_from(&DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, true), + Field::new("value", DataType::Utf8, true), + ])), + false, + )), + true, + )) + .unwrap(), + ScalarValue::Map(Arc::new(create_map_array_test_case())), ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())), ScalarValue::FixedSizeBinary(0, None), ScalarValue::FixedSizeBinary(5, None), @@ -1269,6 +1296,25 @@ fn round_trip_scalar_values() { } } +// create a map array [{joe:1}, {blogs:2, foo:4}, {}, null] for testing +fn create_map_array_test_case() -> MapArray { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.finish() +} + #[test] fn round_trip_scalar_types() { let should_pass: Vec = vec![ diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 106247b2d441..d8d85ace1a29 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -38,7 +38,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; -use datafusion::physical_expr::expressions::{Max, NthValueAgg}; +use datafusion::physical_expr::expressions::Max; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -81,6 +81,7 @@ use datafusion_expr::{ ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::nth_value::nth_value_udaf; use 
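The round-trip test below exercises the new map support; here is a condensed sketch of what the added ScalarValue::Map variant (protobuf tag 41) actually carries, reusing the same MapBuilder pattern as create_map_array_test_case:

use std::sync::Arc;
use arrow::array::{Int32Builder, MapBuilder, StringBuilder};
use datafusion_common::ScalarValue;

fn main() {
    // Build a one-entry map {joe: 1} with Utf8 keys and Int32 values.
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
    builder.keys().append_value("joe");
    builder.values().append_value(1);
    builder.append(true).unwrap();

    // The new variant wraps the whole MapArray.
    let scalar = ScalarValue::Map(Arc::new(builder.finish()));
    println!("{:?}", scalar.data_type());
}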
datafusion_functions_aggregate::string_agg::StringAgg; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, @@ -362,15 +363,17 @@ fn rountrip_aggregate() -> Result<()> { false, )?], // NTH_VALUE - vec![Arc::new(NthValueAgg::new( - col("b", &schema)?, - 1, - "NTH_VALUE(b, 1)".to_string(), - DataType::Int64, + vec![udaf::create_aggregate_expr( + &nth_value_udaf(), + &[col("b", &schema)?, lit(1u64)], + &[], + &[], + &[], + &schema, + "NTH_VALUE(b, 1)", false, - Vec::new(), - Vec::new(), - ))], + false, + )?], // STRING_AGG vec![udaf::create_aggregate_expr( &AggregateUDF::new_from_impl(StringAgg::new()), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ea460cb3efc2..d9ddf57eb192 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -415,9 +415,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { // check udaf first let udaf = self.context_provider.get_aggregate_meta(name); - // Skip first value and last value, since we expect window builtin first/last value not udaf version + // Use the builtin window function instead of the user-defined aggregate function if udaf.as_ref().is_some_and(|udaf| { - udaf.name() != "first_value" && udaf.name() != "last_value" + udaf.name() != "first_value" + && udaf.name() != "last_value" + && udaf.name() != "nth_value" }) { Ok(WindowFunctionDefinition::AggregateUDF(udaf.unwrap())) } else { diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 786ea288fa0e..0546a101fcb2 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -18,8 +18,12 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; use datafusion_expr::planner::PlannerResult; +use datafusion_expr::planner::RawDictionaryExpr; use datafusion_expr::planner::RawFieldAccessExpr; -use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; +use sqlparser::ast::{ + CastKind, DictionaryField, Expr as SQLExpr, StructField, Subscript, TrimWhereField, + Value, +}; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result, @@ -175,21 +179,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.parse_value(value, planner_context.prepare_param_data_types()) } SQLExpr::Extract { field, expr } => { - let date_part = self - .context_provider - .get_function_meta("date_part") - .ok_or_else(|| { - internal_datafusion_err!( - "Unable to find expected 'date_part' function" - ) - })?; - let args = vec![ + let mut extract_args = vec![ Expr::Literal(ScalarValue::from(format!("{field}"))), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - date_part, args, - ))) + + for planner in self.planners.iter() { + match planner.plan_extract(extract_args)? 
{ + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => { + extract_args = args; + } + } + } + + not_impl_err!("Extract not supported by UserDefinedExtensionPlanners: {extract_args:?}") } SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), @@ -594,7 +598,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::Struct { values, fields } => { - self.parse_struct(values, fields, schema, planner_context) + self.parse_struct(schema, planner_context, values, fields) } SQLExpr::Position { expr, r#in } => { self.sql_position_to_expr(*expr, *r#in, schema, planner_context) @@ -619,41 +623,106 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }, ))), + SQLExpr::Dictionary(fields) => { + self.try_plan_dictionary_literal(fields, schema, planner_context) + } _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } - /// Parses a struct(..) expression + /// Parses a struct(..) expression and plans it creation fn parse_struct( &self, - values: Vec, - fields: Vec, - input_schema: &DFSchema, + schema: &DFSchema, planner_context: &mut PlannerContext, + values: Vec, + fields: Vec, ) -> Result { if !fields.is_empty() { return not_impl_err!("Struct fields are not supported yet"); } - - if values + let is_named_struct = values .iter() - .any(|value| matches!(value, SQLExpr::Named { .. })) - { - self.create_named_struct(values, input_schema, planner_context) + .any(|value| matches!(value, SQLExpr::Named { .. })); + + let mut create_struct_args = if is_named_struct { + self.create_named_struct_expr(values, schema, planner_context)? } else { - self.create_struct(values, input_schema, planner_context) + self.create_struct_expr(values, schema, planner_context)? + }; + + for planner in self.planners.iter() { + match planner.plan_struct_literal(create_struct_args, is_named_struct)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => create_struct_args = args, + } } + not_impl_err!("Struct not supported by UserDefinedExtensionPlanners: {create_struct_args:?}") + } + + fn sql_position_to_expr( + &self, + substr_expr: SQLExpr, + str_expr: SQLExpr, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let substr = + self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; + let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; + let mut position_args = vec![fullstr, substr]; + for planner in self.planners.iter() { + match planner.plan_position(position_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => { + position_args = args; + } + } + } + + not_impl_err!( + "Position not supported by UserDefinedExtensionPlanners: {position_args:?}" + ) + } + + fn try_plan_dictionary_literal( + &self, + fields: Vec, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let mut keys = vec![]; + let mut values = vec![]; + for field in fields { + let key = lit(field.key.value); + let value = + self.sql_expr_to_logical_expr(*field.value, schema, planner_context)?; + keys.push(key); + values.push(value); + } + + let mut raw_expr = RawDictionaryExpr { keys, values }; + + for planner in self.planners.iter() { + match planner.plan_dictionary_literal(raw_expr, schema)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(expr) => raw_expr = expr, + } + } + not_impl_err!("Unsupported dictionary literal: {raw_expr:?}") } // Handles a call to struct(...) where the arguments are named. 
For example // `struct (v as foo, v2 as bar)` by creating a call to the `named_struct` function - fn create_named_struct( + fn create_named_struct_expr( &self, values: Vec, input_schema: &DFSchema, planner_context: &mut PlannerContext, - ) -> Result { - let args = values + ) -> Result> { + Ok(values .into_iter() .enumerate() .map(|(i, value)| { @@ -682,47 +751,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()? .into_iter() .flatten() - .collect(); - - let named_struct_func = self - .context_provider - .get_function_meta("named_struct") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'named_struct' function") - })?; - - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - named_struct_func, - args, - ))) + .collect()) } // Handles a call to struct(...) where the arguments are not named. For example // `struct (v, v2)` by creating a call to the `struct` function // which will create a struct with fields named `c0`, `c1`, etc. - fn create_struct( + fn create_struct_expr( &self, values: Vec, input_schema: &DFSchema, planner_context: &mut PlannerContext, - ) -> Result { - let args = values + ) -> Result> { + values .into_iter() .map(|value| { self.sql_expr_to_logical_expr(value, input_schema, planner_context) }) - .collect::>>()?; - let struct_func = self - .context_provider - .get_function_meta("struct") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'struct' function") - })?; - - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - struct_func, - args, - ))) + .collect::>>() } fn sql_in_list_to_expr( @@ -889,25 +935,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } - fn sql_position_to_expr( - &self, - substr_expr: SQLExpr, - str_expr: SQLExpr, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - let fun = self - .context_provider - .get_function_meta("strpos") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'strpos' function") - })?; - let substr = - self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; - let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; - let args = vec![fullstr, substr]; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) - } } #[cfg(test)] diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index f58c6f3b94d0..a0dfee1b9d90 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -16,9 +16,9 @@ // under the License. 
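The EXTRACT, POSITION, struct and dictionary rewrites above, and the SUBSTRING rewrite below, all follow the same dispatch shape. A condensed sketch of that pattern, not the exact DataFusion code, and assuming the planner list comes from the new expr_planners() hook:

use std::sync::Arc;

use datafusion_common::{not_impl_err, Result};
use datafusion_expr::planner::{PlannerResult, UserDefinedSQLPlanner};
use datafusion_expr::Expr;

// Each planner either returns a fully planned expression or hands the original
// arguments back so the next planner (or the final "not implemented" error)
// can take over.
fn plan_substring_args(
    planners: &[Arc<dyn UserDefinedSQLPlanner>],
    mut args: Vec<Expr>,
) -> Result<Expr> {
    for planner in planners {
        match planner.plan_substring(args)? {
            PlannerResult::Planned(expr) => return Ok(expr),
            PlannerResult::Original(original) => args = original,
        }
    }
    not_impl_err!("Substring not supported by UserDefinedExtensionPlanners: {args:?}")
}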
use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{internal_datafusion_err, plan_err}; +use datafusion_common::{not_impl_err, plan_err}; use datafusion_common::{DFSchema, Result, ScalarValue}; -use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::Expr; use sqlparser::ast::Expr as SQLExpr; @@ -31,7 +31,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let args = match (substring_from, substring_for) { + let mut substring_args = match (substring_from, substring_for) { (Some(from_expr), Some(for_expr)) => { let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; @@ -68,13 +68,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - let fun = self - .context_provider - .get_function_meta("substr") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'substr' function") - })?; + for planner in self.planners.iter() { + match planner.plan_substring(substring_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => { + substring_args = args; + } + } + } - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + not_impl_err!( + "Substring not supported by UserDefinedExtensionPlanners: {substring_args:?}" + ) } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index ad898de5987a..198186934c84 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -952,6 +952,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), + ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(..) 
=> not_impl_err!("Unsupported scalar: {v:?}"), } @@ -976,8 +977,13 @@ impl Unparser<'_> { } DataType::Float32 => Ok(ast::DataType::Float(None)), DataType::Float64 => Ok(ast::DataType::Double), - DataType::Timestamp(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + DataType::Timestamp(_, tz) => { + let tz_info = match tz { + Some(_) => TimezoneInfo::WithTimeZone, + None => TimezoneInfo::None, + }; + + Ok(ast::DataType::Timestamp(None, tz_info)) } DataType::Date32 => Ok(ast::DataType::Date), DataType::Date64 => Ok(ast::DataType::Datetime(None)), @@ -1062,9 +1068,9 @@ mod tests { use std::ops::{Add, Sub}; use std::{any::Any, sync::Arc, vec}; + use arrow::datatypes::TimeUnit; use arrow::datatypes::{Field, Schema}; use arrow_schema::DataType::Int8; - use datafusion_common::TableReference; use datafusion_expr::{ case, col, cube, exists, grouping_set, interval_datetime_lit, @@ -1157,6 +1163,23 @@ mod tests { }), r#"CAST(a AS DATETIME)"#, ), + ( + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Timestamp( + TimeUnit::Nanosecond, + Some("+08:00".into()), + ), + }), + r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#, + ), + ( + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + }), + r#"CAST(a AS TIMESTAMP)"#, + ), ( Expr::Cast(Cast { expr: Box::new(col("a")), diff --git a/datafusion/sqllogictest/test_files/agg_func_substitute.slt b/datafusion/sqllogictest/test_files/agg_func_substitute.slt index 342d45e7fb24..9a0a1d587433 100644 --- a/datafusion/sqllogictest/test_files/agg_func_substitute.slt +++ b/datafusion/sqllogictest/test_files/agg_func_substitute.slt @@ -39,16 +39,16 @@ EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1] as result GROUP BY a; ---- logical_plan -01)Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result -02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] +01)Projection: multiple_ordered_table.a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result +02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[nth_value(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] 03)----TableScan: multiple_ordered_table projection=[a, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] -02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] +02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=8192 05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +06)----------AggregateExec: 
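The new Timestamp arm above means a timestamp cast can now be unparsed back to SQL instead of returning a not-implemented error. A minimal sketch, assuming the public expr_to_sql helper from the unparser module:

use arrow::datatypes::{DataType, TimeUnit};
use datafusion::logical_expr::expr::Cast;
use datafusion::logical_expr::{col, Expr};
use datafusion::sql::unparser::expr_to_sql;

fn main() -> datafusion::error::Result<()> {
    let expr = Expr::Cast(Cast {
        expr: Box::new(col("a")),
        data_type: DataType::Timestamp(TimeUnit::Millisecond, None),
    });
    // Per the new unit test above, this prints: CAST(a AS TIMESTAMP)
    println!("{}", expr_to_sql(&expr)?);
    Ok(())
}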
mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true @@ -59,16 +59,16 @@ EXPLAIN SELECT a, NTH_VALUE(c, 1 ORDER BY c) as result GROUP BY a; ---- logical_plan -01)Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result -02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] +01)Projection: multiple_ordered_table.a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result +02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[nth_value(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] 03)----TableScan: multiple_ordered_table projection=[a, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] -02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] +02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=8192 05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true @@ -78,16 +78,16 @@ EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1 + 100] as result GROUP BY a; ---- logical_plan -01)Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result -02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(101)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] +01)Projection: multiple_ordered_table.a, nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result +02)--Aggregate: groupBy=[[multiple_ordered_table.a]], 
aggr=[[nth_value(multiple_ordered_table.c, Int64(101)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] 03)----TableScan: multiple_ordered_table projection=[a, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] -02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] +02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=8192 05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true diff --git a/datafusion/sqllogictest/test_files/cast.slt b/datafusion/sqllogictest/test_files/cast.slt index 4554c9292b6e..3466354e54d7 100644 --- a/datafusion/sqllogictest/test_files/cast.slt +++ b/datafusion/sqllogictest/test_files/cast.slt @@ -69,3 +69,23 @@ query ? SELECT CAST(MAKE_ARRAY() AS VARCHAR[]) ---- [] + +statement ok +create table t0(v0 BIGINT); + +statement ok +insert into t0 values (1),(2),(3); + +query I +select * from t0 where v0>1e100; +---- + +query I +select * from t0 where v0<1e100; +---- +1 +2 +3 + +statement ok +drop table t0; diff --git a/datafusion/sqllogictest/test_files/cse.slt b/datafusion/sqllogictest/test_files/cse.slt new file mode 100644 index 000000000000..3579c1c1635c --- /dev/null +++ b/datafusion/sqllogictest/test_files/cse.slt @@ -0,0 +1,173 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +statement ok +CREATE TABLE IF NOT EXISTS t1(a DOUBLE, b DOUBLE) + +# Trivial common expression +query TT +EXPLAIN SELECT + a + 1 AS c1, + a + 1 AS c2 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 +02)--Projection: t1.a + Float64(1) AS __common_expr_1 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 as c1, __common_expr_1@0 as c2] +02)--ProjectionExec: expr=[a@0 + 1 as __common_expr_1] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Common volatile expression +query TT +EXPLAIN SELECT + a + random() AS c1, + a + random() AS c2 +FROM t1 +---- +logical_plan +01)Projection: t1.a + random() AS c1, t1.a + random() AS c2 +02)--TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[a@0 + random() as c1, a@0 + random() as c2] +02)--MemoryExec: partitions=1, partition_sizes=[0] + +# Volatile expression with non-volatile common child +query TT +EXPLAIN SELECT + a + 1 + random() AS c1, + a + 1 + random() AS c2 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2 +02)--Projection: t1.a + Float64(1) AS __common_expr_1 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 + random() as c1, __common_expr_1@0 + random() as c2] +02)--ProjectionExec: expr=[a@0 + 1 as __common_expr_1] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Volatile expression with non-volatile common children +query TT +EXPLAIN SELECT + a + 1 + random() + (a + 2) AS c1, + a + 1 + random() + (a + 2) AS c2 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 + random() + __common_expr_2 AS c1, __common_expr_1 + random() + __common_expr_2 AS c2 +02)--Projection: t1.a + Float64(1) AS __common_expr_1, t1.a + Float64(2) AS __common_expr_2 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 + random() + __common_expr_2@1 as c1, __common_expr_1@0 + random() + __common_expr_2@1 as c2] +02)--ProjectionExec: expr=[a@0 + 1 as __common_expr_1, a@0 + 2 as __common_expr_2] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Common short-circuit expression +query TT +EXPLAIN SELECT + a = 0 AND b = 0 AS c1, + a = 0 AND b = 0 AS c2, + a = 0 OR b = 0 AS c3, + a = 0 OR b = 0 AS c4, + CASE WHEN (a = 0) THEN 0 ELSE 1 END AS c5, + CASE WHEN (a = 0) THEN 0 ELSE 1 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4, __common_expr_3 AS c5, __common_expr_3 AS c6 +02)--Projection: t1.a = Float64(0) AND t1.b = Float64(0) AS __common_expr_1, t1.a = Float64(0) OR t1.b = Float64(0) AS __common_expr_2, CASE WHEN t1.a = Float64(0) THEN Int64(0) ELSE Int64(1) END AS __common_expr_3 +03)----TableScan: t1 projection=[a, b] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 as c1, __common_expr_1@0 as c2, __common_expr_2@1 as c3, __common_expr_2@1 as c4, __common_expr_3@2 as c5, __common_expr_3@2 as c6] +02)--ProjectionExec: expr=[a@0 = 0 AND b@1 = 0 as __common_expr_1, a@0 = 0 OR b@1 = 0 as __common_expr_2, CASE WHEN a@0 = 0 THEN 0 ELSE 1 END as __common_expr_3] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Common children of short-circuit expression +# TODO: consider surely executed children of "short circuited"s for CSE. i.e. 
`a = 0`, `a = 2`, `a = 4` should be extracted +query TT +EXPLAIN SELECT + a = 0 AND b = 0 AS c1, + a = 0 AND b = 1 AS c2, + b = 2 AND a = 1 AS c3, + b = 3 AND a = 1 AS c4, + a = 2 OR b = 4 AS c5, + a = 2 OR b = 5 AS c6, + b = 6 OR a = 3 AS c7, + b = 7 OR a = 3 AS c8, + CASE WHEN (a = 4) THEN 0 ELSE 1 END AS c9, + CASE WHEN (a = 4) THEN 0 ELSE 2 END AS c10, + CASE WHEN (b = 8) THEN a + 1 ELSE 0 END AS c11, + CASE WHEN (b = 9) THEN a + 1 ELSE 0 END AS c12, + CASE WHEN (b = 10) THEN 0 ELSE a + 2 END AS c13, + CASE WHEN (b = 11) THEN 0 ELSE a + 2 END AS c14 +FROM t1 +---- +logical_plan +01)Projection: t1.a = Float64(0) AND t1.b = Float64(0) AS c1, t1.a = Float64(0) AND t1.b = Float64(1) AS c2, t1.b = Float64(2) AND t1.a = Float64(1) AS c3, t1.b = Float64(3) AND t1.a = Float64(1) AS c4, t1.a = Float64(2) OR t1.b = Float64(4) AS c5, t1.a = Float64(2) OR t1.b = Float64(5) AS c6, t1.b = Float64(6) OR t1.a = Float64(3) AS c7, t1.b = Float64(7) OR t1.a = Float64(3) AS c8, CASE WHEN t1.a = Float64(4) THEN Int64(0) ELSE Int64(1) END AS c9, CASE WHEN t1.a = Float64(4) THEN Int64(0) ELSE Int64(2) END AS c10, CASE WHEN t1.b = Float64(8) THEN t1.a + Float64(1) ELSE Float64(0) END AS c11, CASE WHEN t1.b = Float64(9) THEN t1.a + Float64(1) ELSE Float64(0) END AS c12, CASE WHEN t1.b = Float64(10) THEN Float64(0) ELSE t1.a + Float64(2) END AS c13, CASE WHEN t1.b = Float64(11) THEN Float64(0) ELSE t1.a + Float64(2) END AS c14 +02)--TableScan: t1 projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@0 = 0 AND b@1 = 0 as c1, a@0 = 0 AND b@1 = 1 as c2, b@1 = 2 AND a@0 = 1 as c3, b@1 = 3 AND a@0 = 1 as c4, a@0 = 2 OR b@1 = 4 as c5, a@0 = 2 OR b@1 = 5 as c6, b@1 = 6 OR a@0 = 3 as c7, b@1 = 7 OR a@0 = 3 as c8, CASE WHEN a@0 = 4 THEN 0 ELSE 1 END as c9, CASE WHEN a@0 = 4 THEN 0 ELSE 2 END as c10, CASE WHEN b@1 = 8 THEN a@0 + 1 ELSE 0 END as c11, CASE WHEN b@1 = 9 THEN a@0 + 1 ELSE 0 END as c12, CASE WHEN b@1 = 10 THEN 0 ELSE a@0 + 2 END as c13, CASE WHEN b@1 = 11 THEN 0 ELSE a@0 + 2 END as c14] +02)--MemoryExec: partitions=1, partition_sizes=[0] + +# Common children of volatile, short-circuit expression +# TODO: consider surely executed children of "short circuited"s for CSE. i.e. 
`a = 0`, `a = 2`, `a = 4` should be extracted +query TT +EXPLAIN SELECT + a = 0 AND b = random() AS c1, + a = 0 AND b = 1 + random() AS c2, + b = 2 + random() AND a = 1 AS c3, + b = 3 + random() AND a = 1 AS c4, + a = 2 OR b = 4 + random() AS c5, + a = 2 OR b = 5 + random() AS c6, + b = 6 + random() OR a = 3 AS c7, + b = 7 + random() OR a = 3 AS c8, + CASE WHEN (a = 4) THEN random() ELSE 1 END AS c9, + CASE WHEN (a = 4) THEN random() ELSE 2 END AS c10, + CASE WHEN (b = 8 + random()) THEN a + 1 ELSE 0 END AS c11, + CASE WHEN (b = 9 + random()) THEN a + 1 ELSE 0 END AS c12, + CASE WHEN (b = 10 + random()) THEN 0 ELSE a + 2 END AS c13, + CASE WHEN (b = 11 + random()) THEN 0 ELSE a + 2 END AS c14 +FROM t1 +---- +logical_plan +01)Projection: t1.a = Float64(0) AND t1.b = random() AS c1, t1.a = Float64(0) AND t1.b = Float64(1) + random() AS c2, t1.b = Float64(2) + random() AND t1.a = Float64(1) AS c3, t1.b = Float64(3) + random() AND t1.a = Float64(1) AS c4, t1.a = Float64(2) OR t1.b = Float64(4) + random() AS c5, t1.a = Float64(2) OR t1.b = Float64(5) + random() AS c6, t1.b = Float64(6) + random() OR t1.a = Float64(3) AS c7, t1.b = Float64(7) + random() OR t1.a = Float64(3) AS c8, CASE WHEN t1.a = Float64(4) THEN random() ELSE Float64(1) END AS c9, CASE WHEN t1.a = Float64(4) THEN random() ELSE Float64(2) END AS c10, CASE WHEN t1.b = Float64(8) + random() THEN t1.a + Float64(1) ELSE Float64(0) END AS c11, CASE WHEN t1.b = Float64(9) + random() THEN t1.a + Float64(1) ELSE Float64(0) END AS c12, CASE WHEN t1.b = Float64(10) + random() THEN Float64(0) ELSE t1.a + Float64(2) END AS c13, CASE WHEN t1.b = Float64(11) + random() THEN Float64(0) ELSE t1.a + Float64(2) END AS c14 +02)--TableScan: t1 projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@0 = 0 AND b@1 = random() as c1, a@0 = 0 AND b@1 = 1 + random() as c2, b@1 = 2 + random() AND a@0 = 1 as c3, b@1 = 3 + random() AND a@0 = 1 as c4, a@0 = 2 OR b@1 = 4 + random() as c5, a@0 = 2 OR b@1 = 5 + random() as c6, b@1 = 6 + random() OR a@0 = 3 as c7, b@1 = 7 + random() OR a@0 = 3 as c8, CASE WHEN a@0 = 4 THEN random() ELSE 1 END as c9, CASE WHEN a@0 = 4 THEN random() ELSE 2 END as c10, CASE WHEN b@1 = 8 + random() THEN a@0 + 1 ELSE 0 END as c11, CASE WHEN b@1 = 9 + random() THEN a@0 + 1 ELSE 0 END as c12, CASE WHEN b@1 = 10 + random() THEN 0 ELSE a@0 + 2 END as c13, CASE WHEN b@1 = 11 + random() THEN 0 ELSE a@0 + 2 END as c14] +02)--MemoryExec: partitions=1, partition_sizes=[0] + +# Common volatile children of short-circuit expression +query TT +EXPLAIN SELECT + a = random() AND b = 0 AS c1, + a = random() AND b = 1 AS c2, + a = 2 + random() OR b = 4 AS c3, + a = 2 + random() OR b = 5 AS c4, + CASE WHEN (a = 4 + random()) THEN 0 ELSE 1 END AS c5, + CASE WHEN (a = 4 + random()) THEN 0 ELSE 2 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: t1.a = random() AND t1.b = Float64(0) AS c1, t1.a = random() AND t1.b = Float64(1) AS c2, t1.a = Float64(2) + random() OR t1.b = Float64(4) AS c3, t1.a = Float64(2) + random() OR t1.b = Float64(5) AS c4, CASE WHEN t1.a = Float64(4) + random() THEN Int64(0) ELSE Int64(1) END AS c5, CASE WHEN t1.a = Float64(4) + random() THEN Int64(0) ELSE Int64(2) END AS c6 +02)--TableScan: t1 projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@0 = random() AND b@1 = 0 as c1, a@0 = random() AND b@1 = 1 as c2, a@0 = 2 + random() OR b@1 = 4 as c3, a@0 = 2 + random() OR b@1 = 5 as c4, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 1 END as c5, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 2 END as c6] +02)--MemoryExec: 
partitions=1, partition_sizes=[0] diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 6732d3e9108b..3c89109145d7 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -793,3 +793,196 @@ DROP TABLE companies statement ok DROP TABLE leads + +#### +## Test ON clause predicates are not pushed past join for OUTER JOINs +#### + + +# create tables +statement ok +CREATE TABLE employees(emp_id INT, name VARCHAR); + +statement ok +CREATE TABLE department(emp_id INT, dept_name VARCHAR); + +statement ok +INSERT INTO employees (emp_id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Carol'); + +statement ok +INSERT INTO department (emp_id, dept_name) VALUES (1, 'HR'), (3, 'Engineering'), (4, 'Sales'); + +# Can not push the ON filter below an OUTER JOIN +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +logical_plan +01)Left Join: Filter: e.name = Utf8("Alice") OR e.name = Utf8("Bob") +02)--SubqueryAlias: e +03)----TableScan: employees projection=[emp_id, name] +04)--SubqueryAlias: d +05)----TableScan: department projection=[dept_name] +physical_plan +01)ProjectionExec: expr=[emp_id@1 as emp_id, name@2 as name, dept_name@0 as dept_name] +02)--NestedLoopJoinExec: join_type=Right, filter=name@0 = Alice OR name@0 = Bob +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +2 Bob HR +1 Alice Engineering +2 Bob Engineering +1 Alice Sales +2 Bob Sales +3 Carol NULL + +# neither RIGHT OUTER JOIN +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM department AS d +RIGHT JOIN employees AS e +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +2 Bob HR +1 Alice Engineering +2 Bob Engineering +1 Alice Sales +2 Bob Sales +3 Carol NULL + +# neither FULL OUTER JOIN +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM department AS d +FULL JOIN employees AS e +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +2 Bob HR +1 Alice Engineering +2 Bob Engineering +1 Alice Sales +2 Bob Sales +3 Carol NULL + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees e +LEFT JOIN department d +ON (e.name = 'NotExist1' OR e.name = 'NotExist2'); +---- +1 Alice NULL +2 Bob NULL +3 Carol NULL + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees e +LEFT JOIN department d +ON (e.name = 'Alice' OR e.name = 'NotExist'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob NULL +3 Carol NULL + +# Can push the ON filter below the JOIN for INNER JOIN (expect to see a filter below the join) +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +logical_plan +01)CrossJoin: +02)--SubqueryAlias: e +03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") +04)------TableScan: employees projection=[emp_id, name] +05)--SubqueryAlias: d +06)----TableScan: department projection=[dept_name] +physical_plan +01)CrossJoinExec +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: name@1 = Alice OR name@1 = Bob +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)--MemoryExec: partitions=1, partition_sizes=[1] + +# expect no row for Carol +query ITT 
+SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob HR +2 Bob Engineering +2 Bob Sales + +# OR conditions on Filter (not join filter) +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE (e.name = 'Alice' OR e.name = 'Carol'); +---- +1 Alice HR +3 Carol Engineering + +# Push down OR conditions on Filter through LEFT JOIN if possible +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); +---- +logical_plan +01)Filter: d.dept_name != Utf8("Engineering") AND e.name = Utf8("Alice") OR e.name != Utf8("Alice") AND e.name = Utf8("Carol") +02)--Projection: e.emp_id, e.name, d.dept_name +03)----Left Join: e.emp_id = d.emp_id +04)------SubqueryAlias: e +05)--------Filter: employees.name = Utf8("Alice") OR employees.name != Utf8("Alice") AND employees.name = Utf8("Carol") +06)----------TableScan: employees projection=[emp_id, name] +07)------SubqueryAlias: d +08)--------TableScan: department projection=[emp_id, dept_name] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 != Alice AND name@1 = Carol +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------HashJoinExec: mode=CollectLeft, join_type=Left, on=[(emp_id@0, emp_id@0)], projection=[emp_id@0, name@1, dept_name@3] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: name@1 = Alice OR name@1 != Alice AND name@1 = Carol +08)--------------MemoryExec: partitions=1, partition_sizes=[1] +09)----------MemoryExec: partitions=1, partition_sizes=[1] + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); +---- +1 Alice HR +3 Carol Engineering + +statement ok +DROP TABLE employees + +statement ok +DROP TABLE department diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 3cbeea0f9222..df66bffab8e8 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -53,6 +53,20 @@ AS VALUES (44, 'x', 3), (55, 'w', 3); +statement ok +CREATE TABLE join_t3(s3 struct) + AS VALUES + (NULL), + (struct(1)), + (struct(2)); + +statement ok +CREATE TABLE join_t4(s4 struct) + AS VALUES + (NULL), + (struct(2)), + (struct(3)); + # Left semi anti join statement ok @@ -1336,6 +1350,44 @@ physical_plan 10)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 11)------------MemoryExec: partitions=1, partition_sizes=[1] +# Join on struct +query TT +explain select join_t3.s3, join_t4.s4 +from join_t3 +inner join join_t4 on join_t3.s3 = join_t4.s4 +---- +logical_plan +01)Inner Join: join_t3.s3 = join_t4.s4 +02)--TableScan: join_t3 projection=[s3] +03)--TableScan: join_t4 projection=[s4] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s3@0, s4@0)] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: 
partitioning=Hash([s3@0], 2), input_partitions=2 +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------MemoryExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=2 +08)------RepartitionExec: partitioning=Hash([s4@0], 2), input_partitions=2 +09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)----------MemoryExec: partitions=1, partition_sizes=[1] + +query ?? +select join_t3.s3, join_t4.s4 +from join_t3 +inner join join_t4 on join_t3.s3 = join_t4.s4 +---- +{id: 2} {id: 2} + +# join with struct key and nulls +# Note that intersect or except applies `null_equals_null` as true for Join. +query ? +SELECT * FROM join_t3 +EXCEPT +SELECT * FROM join_t4 +---- +{id: 1} + query TT EXPLAIN select count(*) @@ -3813,3 +3865,182 @@ logical_plan 01)SubqueryAlias: b 02)--Projection: Int64(1) AS a 03)----EmptyRelation + + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.execution.batch_size = 3; + +# Right Hash Joins preserve the right ordering +# No nulls on build side: +statement ok +CREATE TABLE left_table_no_nulls(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(11, 1), +(12, 3), +(13, 5), +(14, 2), +(15, 4); + +statement ok +CREATE TABLE right_table_no_nulls(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(21, 1), +(22, 2), +(23, 3), +(24, 4); + +query IIII +SELECT * FROM ( + SELECT * from left_table_no_nulls +) as lhs RIGHT JOIN ( + SELECT * from right_table_no_nulls + ORDER BY b +) AS rhs ON lhs.b=rhs.b +---- +11 1 21 1 +14 2 22 2 +12 3 23 3 +15 4 24 4 + +query TT +EXPLAIN SELECT * FROM ( + SELECT * from left_table_no_nulls +) as lhs RIGHT JOIN ( + SELECT * from right_table_no_nulls + ORDER BY b +) AS rhs ON lhs.b=rhs.b +---- +logical_plan +01)Right Join: lhs.b = rhs.b +02)--SubqueryAlias: lhs +03)----TableScan: left_table_no_nulls projection=[a, b] +04)--SubqueryAlias: rhs +05)----Sort: right_table_no_nulls.b ASC NULLS LAST +06)------TableScan: right_table_no_nulls projection=[a, b] +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----SortExec: expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)------MemoryExec: partitions=1, partition_sizes=[1] + + +# Missing probe index in the middle of the batch: +statement ok +CREATE TABLE left_table_missing_probe(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(11, 1), +(12, 2), +(13, 3), +(14, 6), +(15, 8); + +statement ok +CREATE TABLE right_table_missing_probe(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(21, 1), +(22, 4), +(23, 6), +(24, 7), +(25, 8); + +query IIII +SELECT * FROM ( + SELECT * from left_table_missing_probe +) as lhs RIGHT JOIN ( + SELECT * from right_table_missing_probe + ORDER BY b +) AS rhs ON lhs.b=rhs.b +---- +11 1 21 1 +NULL NULL 22 4 +14 6 23 6 +NULL NULL 24 7 +15 8 25 8 + +query TT +EXPLAIN SELECT * FROM ( + SELECT * from left_table_no_nulls +) as lhs RIGHT JOIN ( + SELECT * from right_table_no_nulls + ORDER BY b +) AS rhs ON lhs.b=rhs.b +---- +logical_plan +01)Right Join: lhs.b = rhs.b +02)--SubqueryAlias: lhs +03)----TableScan: left_table_no_nulls projection=[a, b] +04)--SubqueryAlias: rhs +05)----Sort: right_table_no_nulls.b ASC NULLS LAST +06)------TableScan: right_table_no_nulls projection=[a, b] +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: 
mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----SortExec: expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)------MemoryExec: partitions=1, partition_sizes=[1] + + +# Null build indices: +statement ok +CREATE TABLE left_table_append_null_build(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(11, 1), +(12, 1), +(13, 5), +(14, 5), +(15, 3); + +statement ok +CREATE TABLE right_table_append_null_build(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(21, 4), +(22, 5), +(23, 6), +(24, 7), +(25, 8); + +query IIII +SELECT * FROM ( + SELECT * from left_table_append_null_build +) as lhs RIGHT JOIN ( + SELECT * from right_table_append_null_build + ORDER BY b +) AS rhs ON lhs.b=rhs.b +---- +NULL NULL 21 4 +13 5 22 5 +14 5 22 5 +NULL NULL 23 6 +NULL NULL 24 7 +NULL NULL 25 8 + + +query TT +EXPLAIN SELECT * FROM ( + SELECT * from left_table_no_nulls +) as lhs RIGHT JOIN ( + SELECT * from right_table_no_nulls + ORDER BY b +) AS rhs ON lhs.b=rhs.b +---- +logical_plan +01)Right Join: lhs.b = rhs.b +02)--SubqueryAlias: lhs +03)----TableScan: left_table_no_nulls projection=[a, b] +04)--SubqueryAlias: rhs +05)----Sort: right_table_no_nulls.b ASC NULLS LAST +06)------TableScan: right_table_no_nulls projection=[a, b] +physical_plan +01)CoalesceBatchesExec: target_batch_size=3 +02)--HashJoinExec: mode=CollectLeft, join_type=Right, on=[(b@1, b@1)] +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----SortExec: expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)------MemoryExec: partitions=1, partition_sizes=[1] + diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 749daa7e20e7..fd6e25ea749d 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -162,6 +162,13 @@ select named_struct('scalar', 27, 'array', values.a, 'null', NULL) from values; {scalar: 27, array: 2, null: } {scalar: 27, array: 3, null: } +query ? +select {'scalar': 27, 'array': values.a, 'null': NULL} from values; +---- +{scalar: 27, array: 1, null: } +{scalar: 27, array: 2, null: } +{scalar: 27, array: 3, null: } + # named_struct with mixed scalar and array values #2 query ? select named_struct('array', values.a, 'scalar', 27, 'null', NULL) from values; @@ -170,6 +177,13 @@ select named_struct('array', values.a, 'scalar', 27, 'null', NULL) from values; {array: 2, scalar: 27, null: } {array: 3, scalar: 27, null: } +query ? +select {'array': values.a, 'scalar': 27, 'null': NULL} from values; +---- +{array: 1, scalar: 27, null: } +{array: 2, scalar: 27, null: } +{array: 3, scalar: 27, null: } + # named_struct with mixed scalar and array values #3 query ? 
select named_struct('null', NULL, 'array', values.a, 'scalar', 27) from values; @@ -207,3 +221,14 @@ query T select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3)); ---- Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +query T +select arrow_typeof({'first': 1, 'second': 2, 'third': 3}); +---- +Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +# test nested struct literal +query ? +select {'animal': {'cat': 1, 'dog': 2, 'bird': {'parrot': 3, 'canary': 1}}, 'genre': {'fiction': ['mystery', 'sci-fi', 'fantasy'], 'non-fiction': {'biography': 5, 'history': 7, 'science': {'physics': 2, 'biology': 3}}}, 'vehicle': {'car': {'sedan': 4, 'suv': 2}, 'bicycle': 3, 'boat': ['sailboat', 'motorboat']}, 'weather': {'sunny': True, 'temperature': 25.5, 'wind': {'speed': 10, 'direction': 'NW'}}}; +---- +{animal: {cat: 1, dog: 2, bird: {parrot: 3, canary: 1}}, genre: {fiction: [mystery, sci-fi, fantasy], non-fiction: {biography: 5, history: 7, science: {physics: 2, biology: 3}}}, vehicle: {car: {sedan: 4, suv: 2}, bicycle: 3, boat: [sailboat, motorboat]}, weather: {sunny: true, temperature: 25.5, wind: {speed: 10, direction: NW}}} diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index ba07b4ed0a87..7f2e766aab91 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4830,6 +4830,8 @@ NULL 3 NULL 2 NULL 1 +statement ok +drop table t ### Test for window functions with arrays statement ok @@ -4852,3 +4854,38 @@ c [4, 5, 6] NULL statement ok drop table array_data + +# Test for non-i64 offsets for NTILE, LAG, LEAD, NTH_VALUE +statement ok +CREATE TABLE t AS VALUES (3, 3), (4, 4), (5, 5), (6, 6); + +query IIIIIIIII +SELECT + column1, + ntile(2) OVER (order by column1), + ntile(arrow_cast(2, 'Int32')) OVER (order by column1), + lag(column2, -1) OVER (order by column1), + lag(column2, arrow_cast(-1, 'Int32')) OVER (order by column1), + lead(column2, -1) OVER (order by column1), + lead(column2, arrow_cast(-1, 'Int32')) OVER (order by column1), + nth_value(column2, 2) OVER (order by column1), + nth_value(column2, arrow_cast(2, 'Int32')) OVER (order by column1) +FROM t; +---- +3 1 1 4 4 NULL NULL NULL NULL +4 1 1 5 5 3 3 4 4 +5 2 2 6 6 4 4 4 4 +6 2 2 NULL NULL 5 5 4 4 + +# NTILE specifies the argument types so the error is different +query error +SELECT ntile(1.1) OVER (order by column1) FROM t; + +query error DataFusion error: Execution error: Expected an integer value +SELECT lag(column2, 1.1) OVER (order by column1) FROM t; + +query error DataFusion error: Execution error: Expected an integer value +SELECT lead(column2, 1.1) OVER (order by column1) FROM t; + +query error DataFusion error: Execution error: Expected an integer value +SELECT nth_value(column2, 1.1) OVER (order by column1) FROM t; diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index f3f8f6e3abca..9e7ef9632ad3 100644 --- 
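The brace-literal tests above have a direct Rust-API counterpart; a small end-to-end sketch (assumes a tokio runtime, which is not part of this diff):

use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // `{'first': 1, 'second': 2}` is planned through the new planner hooks and
    // yields the same struct as named_struct('first', 1, 'second', 2).
    let df = ctx.sql("SELECT {'first': 1, 'second': 2} AS s").await?;
    df.show().await?;
    Ok(())
}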
a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -40,7 +40,7 @@ itertools = { workspace = true } object_store = { workspace = true } pbjson-types = "0.6" prost = "0.12" -substrait = { version = "0.34.0", features = ["serde"] } +substrait = { version = "0.36.0", features = ["serde"] } url = { workspace = true } [dev-dependencies] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index af8dd60f6566..77fd5fe44d44 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -15,25 +15,39 @@ // specific language governing permissions and limitations // under the License. +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use async_recursion::async_recursion; +use datafusion::arrow::array::GenericListArray; use datafusion::arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; +use datafusion::common::plan_err; use datafusion::common::{ - not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, - substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, + not_impl_datafusion_err, not_impl_err, plan_datafusion_err, substrait_datafusion_err, + substrait_err, DFSchema, DFSchemaRef, }; -use substrait::proto::expression::literal::IntervalDayToSecond; -use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; -use url::Url; - -use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; + use datafusion::logical_expr::{ aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values, }; +use substrait::proto::expression::subquery::set_predicate::PredicateOp; +use url::Url; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, + INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, + UNSIGNED_INTEGER_TYPE_VARIATION_REF, +}; +use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, @@ -41,15 +55,20 @@ use datafusion::logical_expr::{ use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, logical_expr::utils::split_conjunction, prelude::{Column, SessionContext}, scalar::ScalarValue, }; +use std::collections::{HashMap, HashSet}; +use std::str::FromStr; +use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::literal::user_defined::Val; +use substrait::proto::expression::literal::IntervalDayToSecond; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction}; +use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; use 
substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -70,54 +89,36 @@ use substrait::proto::{ }; use substrait::proto::{FunctionArgument, SortField}; -use datafusion::arrow::array::GenericListArray; -use datafusion::common::scalar::ScalarStructBuilder; -use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::Arc; - -use crate::variation_const::{ - DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, - DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, - DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, - INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF, - INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, - TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, - TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, - UNSIGNED_INTEGER_TYPE_VARIATION_REF, -}; - -pub fn name_to_op(name: &str) -> Result { +pub fn name_to_op(name: &str) -> Option { match name { - "equal" => Ok(Operator::Eq), - "not_equal" => Ok(Operator::NotEq), - "lt" => Ok(Operator::Lt), - "lte" => Ok(Operator::LtEq), - "gt" => Ok(Operator::Gt), - "gte" => Ok(Operator::GtEq), - "add" => Ok(Operator::Plus), - "subtract" => Ok(Operator::Minus), - "multiply" => Ok(Operator::Multiply), - "divide" => Ok(Operator::Divide), - "mod" => Ok(Operator::Modulo), - "and" => Ok(Operator::And), - "or" => Ok(Operator::Or), - "is_distinct_from" => Ok(Operator::IsDistinctFrom), - "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom), - "regex_match" => Ok(Operator::RegexMatch), - "regex_imatch" => Ok(Operator::RegexIMatch), - "regex_not_match" => Ok(Operator::RegexNotMatch), - "regex_not_imatch" => Ok(Operator::RegexNotIMatch), - "bitwise_and" => Ok(Operator::BitwiseAnd), - "bitwise_or" => Ok(Operator::BitwiseOr), - "str_concat" => Ok(Operator::StringConcat), - "at_arrow" => Ok(Operator::AtArrow), - "arrow_at" => Ok(Operator::ArrowAt), - "bitwise_xor" => Ok(Operator::BitwiseXor), - "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight), - "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft), - _ => not_impl_err!("Unsupported function name: {name:?}"), + "equal" => Some(Operator::Eq), + "not_equal" => Some(Operator::NotEq), + "lt" => Some(Operator::Lt), + "lte" => Some(Operator::LtEq), + "gt" => Some(Operator::Gt), + "gte" => Some(Operator::GtEq), + "add" => Some(Operator::Plus), + "subtract" => Some(Operator::Minus), + "multiply" => Some(Operator::Multiply), + "divide" => Some(Operator::Divide), + "mod" => Some(Operator::Modulo), + "and" => Some(Operator::And), + "or" => Some(Operator::Or), + "is_distinct_from" => Some(Operator::IsDistinctFrom), + "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), + "regex_match" => Some(Operator::RegexMatch), + "regex_imatch" => Some(Operator::RegexIMatch), + "regex_not_match" => Some(Operator::RegexNotMatch), + "regex_not_imatch" => Some(Operator::RegexNotIMatch), + "bitwise_and" => Some(Operator::BitwiseAnd), + "bitwise_or" => Some(Operator::BitwiseOr), + "str_concat" => Some(Operator::StringConcat), + "at_arrow" => Some(Operator::AtArrow), + "arrow_at" => Some(Operator::ArrowAt), + "bitwise_xor" => Some(Operator::BitwiseXor), + "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), + "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), + _ => None, } } @@ -403,22 +404,33 @@ pub async fn from_substrait_rel( let mut input = LogicalPlanBuilder::from( from_substrait_rel(ctx, input, 
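name_to_op returning Option instead of Result lets the consumer keep trying other interpretations (UDFs, builtin expression builders) before giving up. A tiny usage sketch, assuming the module path stays publicly reachable:

use datafusion::logical_expr::Operator;
use datafusion_substrait::logical_plan::consumer::name_to_op;

fn main() {
    // Known names map to an Operator; unknown names now yield None instead of
    // a not-implemented error, so the caller decides how to fall back.
    assert_eq!(name_to_op("add"), Some(Operator::Plus));
    assert_eq!(name_to_op("no_such_function"), None);
}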
extensions).await?, ); + let mut names: HashSet = HashSet::new(); let mut exprs: Vec = vec![]; for e in &p.expressions { let x = from_substrait_rex(ctx, e, input.clone().schema(), extensions) .await?; // if the expression is WindowFunction, wrap in a Window relation - // before returning and do not add to list of this Projection's expression list - // otherwise, add expression to the Projection's expression list - match &*x { - Expr::WindowFunction(_) => { - input = input.window(vec![x.as_ref().clone()])?; - exprs.push(x.as_ref().clone()); - } - _ => { - exprs.push(x.as_ref().clone()); - } + if let Expr::WindowFunction(_) = x.as_ref() { + // Adding the same expression here and in the project below + // works because the project's builder uses columnize_expr(..) + // to transform it into a column reference + input = input.window(vec![x.as_ref().clone()])? + } + // Ensure the expression has a unique display name, so that project's + // validate_unique_names doesn't fail + let name = x.display_name()?; + let mut new_name = name.clone(); + let mut i = 0; + while names.contains(&new_name) { + new_name = format!("{}__temp__{}", name, i); + i += 1; + } + names.insert(new_name.clone()); + if new_name != name { + exprs.push(x.as_ref().clone().alias(new_name.clone())); + } else { + exprs.push(x.as_ref().clone()); } } input.project(exprs)?.build() @@ -1124,18 +1136,33 @@ pub async fn from_substrait_rex( Ok(Arc::new(Expr::ScalarFunction( expr::ScalarFunction::new_udf(func.to_owned(), args), ))) - } else if let Ok(op) = name_to_op(fn_name) { - if args.len() != 2 { + } else if let Some(op) = name_to_op(fn_name) { + if f.arguments.len() < 2 { return not_impl_err!( - "Expect two arguments for binary operator {op:?}" + "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", + f.arguments.len() ); } + // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. 
+ // In those cases we iterate through all the arguments, applying the binary expression against them all + let combined_expr = args + .into_iter() + .fold(None, |combined_expr: Option>, arg: Expr| { + Some(match combined_expr { + Some(expr) => Arc::new(Expr::BinaryExpr(BinaryExpr { + left: Box::new( + Arc::try_unwrap(expr) + .unwrap_or_else(|arc: Arc| (*arc).clone()), + ), // Avoid cloning if possible + op: op.clone(), + right: Box::new(arg), + })), + None => Arc::new(arg), + }) + }) + .unwrap(); - Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { - left: Box::new(args[0].to_owned()), - op, - right: Box::new(args[1].to_owned()), - }))) + Ok(combined_expr) } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { builder.build(ctx, f, input_schema, extensions).await } else { @@ -1234,10 +1261,7 @@ pub async fn from_substrait_rex( Some(subquery_type) => match subquery_type { SubqueryType::InPredicate(in_predicate) => { if in_predicate.needles.len() != 1 { - Err(DataFusionError::Substrait( - "InPredicate Subquery type must have exactly one Needle expression" - .to_string(), - )) + substrait_err!("InPredicate Subquery type must have exactly one Needle expression") } else { let needle_expr = &in_predicate.needles[0]; let haystack_expr = &in_predicate.haystack; @@ -1269,7 +1293,48 @@ pub async fn from_substrait_rex( } } } - _ => substrait_err!("Subquery type not implemented"), + SubqueryType::Scalar(query) => { + let plan = from_substrait_rel( + ctx, + &(query.input.clone()).unwrap_or_default(), + extensions, + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Arc::new(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + }))) + } + SubqueryType::SetPredicate(predicate) => { + match predicate.predicate_op() { + // exist + PredicateOp::Exists => { + let relation = &predicate.tuples; + let plan = from_substrait_rel( + ctx, + &relation.clone().unwrap_or_default(), + extensions, + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Arc::new(Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + }, + false, + )))) + } + other_type => substrait_err!( + "unimplemented type {:?} for set predicate", + other_type + ), + } + } + other_type => { + substrait_err!("Subquery type {:?} not implemented", other_type) + } }, None => { substrait_err!("Subquery experssion without SubqueryType is not allowed") @@ -1514,12 +1579,22 @@ fn from_substrait_bound( BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { Ok(WindowFrameBound::CurrentRow) } - BoundKind::Preceding(SubstraitBound::Preceding { offset }) => Ok( - WindowFrameBound::Preceding(ScalarValue::Int64(Some(*offset))), - ), - BoundKind::Following(SubstraitBound::Following { offset }) => Ok( - WindowFrameBound::Following(ScalarValue::Int64(Some(*offset))), - ), + BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { + if *offset <= 0 { + return plan_err!("Preceding bound must be positive"); + } + Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Following(SubstraitBound::Following { offset }) => { + if *offset <= 0 { + return plan_err!("Following bound must be positive"); + } + Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { if is_lower { Ok(WindowFrameBound::Preceding(ScalarValue::Null)) @@ -1699,6 +1774,7 @@ fn from_substrait_literal( })) => { ScalarValue::new_interval_dt(*days, (seconds * 
1000) + (microseconds / 1000)) } + Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { match user_defined.type_reference { INTERVAL_YEAR_MONTH_TYPE_REF => { @@ -1988,8 +2064,8 @@ impl BuiltinExprBuilder { extensions: &HashMap, ) -> Result> { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 3 { - return substrait_err!("Expect three arguments for `{fn_name}` expr"); + if f.arguments.len() != 2 && f.arguments.len() != 3 { + return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { @@ -2007,25 +2083,40 @@ impl BuiltinExprBuilder { .await? .as_ref() .clone(); - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let escape_char_expr = - from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { - return substrait_err!( - "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" - ); + + // Default case: escape character is Literal(Utf8(None)) + let escape_char = if f.arguments.len() == 3 { + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type + else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + + let escape_char_expr = + from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) + .await? + .as_ref() + .clone(); + + match escape_char_expr { + Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { + // Convert Option to Option + escape_char_string.and_then(|s| s.chars().next()) + } + _ => { + return substrait_err!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" + ) + } + } + } else { + None }; Ok(Arc::new(Expr::Like(Like { negated: false, expr: Box::new(expr), pattern: Box::new(pattern), - escape_char: escape_char.map(|c| c.chars().next().unwrap()), + escape_char, case_insensitive, }))) } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c3bef1689d14..959542080161 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -818,7 +818,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } - let function_anchor = _register_function(fun.to_string(), extension_info); + let function_anchor = register_function(fun.to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -849,7 +849,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); } - let function_anchor = _register_function(fun.name().to_string(), extension_info); + let function_anchor = register_function(fun.name().to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -917,7 +917,7 @@ fn to_substrait_sort_field( } } -fn _register_function( +fn register_function( function_name: String, extension_info: &mut ( Vec, @@ -926,6 +926,14 @@ fn _register_function( ) -> u32 { let (function_extensions, 
function_set) = extension_info; let function_name = function_name.to_lowercase(); + + // Some functions are named differently in Substrait default extensions than in DF + // Rename those to match the Substrait extensions for interoperability + let function_name = match function_name.as_str() { + "substr" => "substring".to_string(), + _ => function_name, + }; + // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, // a plan-relative identifier starting from 0 is used as the function_anchor. // The consumer is responsible for correctly registering @@ -969,7 +977,7 @@ pub fn make_binary_op_scalar_func( ), ) -> Expression { let function_anchor = - _register_function(operator_to_name(op).to_string(), extension_info); + register_function(operator_to_name(op).to_string(), extension_info); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1044,7 +1052,7 @@ pub fn to_substrait_rex( if *negated { let function_anchor = - _register_function("not".to_string(), extension_info); + register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1076,7 +1084,7 @@ pub fn to_substrait_rex( } let function_anchor = - _register_function(fun.name().to_string(), extension_info); + register_function(fun.name().to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1252,7 +1260,7 @@ pub fn to_substrait_rex( null_treatment: _, }) => { // function reference - let function_anchor = _register_function(fun.to_string(), extension_info); + let function_anchor = register_function(fun.to_string(), extension_info); // arguments let mut arguments: Vec = vec![]; for arg in args { @@ -1330,7 +1338,7 @@ pub fn to_substrait_rex( }; if *negated { let function_anchor = - _register_function("not".to_string(), extension_info); + register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1727,9 +1735,9 @@ fn make_substrait_like_expr( ), ) -> Result { let function_anchor = if ignore_case { - _register_function("ilike".to_string(), extension_info) + register_function("ilike".to_string(), extension_info) } else { - _register_function("like".to_string(), extension_info) + register_function("like".to_string(), extension_info) }; let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; @@ -1759,7 +1767,7 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = _register_function("not".to_string(), extension_info); + let function_anchor = register_function("not".to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2128,7 +2136,7 @@ fn to_substrait_unary_scalar_fn( HashMap, ), ) -> Result { - let function_anchor = _register_function(fn_name.to_string(), extension_info); + let function_anchor = register_function(fn_name.to_string(), extension_info); let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; @@ -2222,6 +2230,7 @@ mod test { use crate::logical_plan::consumer::{ from_substrait_literal_without_names, from_substrait_type_without_names, }; + use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::GenericListArray; use datafusion::arrow::datatypes::Field; use 
datafusion::common::scalar::ScalarStructBuilder; @@ -2301,6 +2310,14 @@ mod test { )?; round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; + round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; + round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano::new(17, 25, 1234567890), + )))?; + round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( + 57, 123456, + ))))?; + Ok(()) } @@ -2368,6 +2385,10 @@ mod test { .into(), ))?; + round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; + round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; + round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; + Ok(()) } diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index 8ea3a69cab61..5d565c037852 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -32,9 +32,101 @@ mod tests { use std::io::BufReader; use substrait::proto::Plan; + async fn register_csv( + ctx: &SessionContext, + table_name: &str, + file_path: &str, + ) -> Result<()> { + ctx.register_csv(table_name, file_path, CsvReadOptions::default()) + .await + } + + async fn create_context_tpch1() -> Result { + let ctx = SessionContext::new(); + register_csv( + &ctx, + "FILENAME_PLACEHOLDER_0", + "tests/testdata/tpch/lineitem.csv", + ) + .await?; + Ok(ctx) + } + + async fn create_context_tpch2() -> Result { + let ctx = SessionContext::new(); + + let registrations = vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/part.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/region.csv"), + ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/partsupp.csv"), + ("FILENAME_PLACEHOLDER_6", "tests/testdata/tpch/supplier.csv"), + ("FILENAME_PLACEHOLDER_7", "tests/testdata/tpch/nation.csv"), + ("FILENAME_PLACEHOLDER_8", "tests/testdata/tpch/region.csv"), + ]; + + for (table_name, file_path) in registrations { + register_csv(&ctx, table_name, file_path).await?; + } + + Ok(ctx) + } + + async fn create_context_tpch3() -> Result { + let ctx = SessionContext::new(); + + let registrations = vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ]; + + for (table_name, file_path) in registrations { + register_csv(&ctx, table_name, file_path).await?; + } + + Ok(ctx) + } + + async fn create_context_tpch4() -> Result { + let ctx = SessionContext::new(); + + let registrations = vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/lineitem.csv"), + ]; + + for (table_name, file_path) in registrations { + register_csv(&ctx, table_name, file_path).await?; + } + + Ok(ctx) + } + + async fn create_context_tpch5() -> Result { + let ctx = SessionContext::new(); + + let registrations = vec![ + ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"), + ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"), + ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"), + ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/supplier.csv"), + ("NATION", "tests/testdata/tpch/nation.csv"), + ("REGION", 
"tests/testdata/tpch/region.csv"), + ]; + + for (table_name, file_path) in registrations { + register_csv(&ctx, table_name, file_path).await?; + } + + Ok(ctx) + } + #[tokio::test] async fn tpch_test_1() -> Result<()> { - let ctx = create_context().await?; + let ctx = create_context_tpch1().await?; let path = "tests/testdata/tpch_substrait_plans/query_1.json"; let proto = serde_json::from_reader::<_, Plan>(BufReader::new( File::open(path).expect("file not found"), @@ -56,14 +148,122 @@ mod tests { Ok(()) } - async fn create_context() -> datafusion::common::Result { - let ctx = SessionContext::new(); - ctx.register_csv( - "FILENAME_PLACEHOLDER_0", - "tests/testdata/tpch/lineitem.csv", - CsvReadOptions::default(), - ) - .await?; - Ok(ctx) + #[tokio::test] + async fn tpch_test_2() -> Result<()> { + let ctx = create_context_tpch2().await?; + let path = "tests/testdata/tpch_substrait_plans/query_2.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!( + plan_str, + "Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\ + \n Limit: skip=0, fetch=100\ + \n Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.n_name ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\ + \n Projection: FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\ + \n Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_regionkey = FILENAME_PLACEHOLDER_4.r_regionkey AND FILENAME_PLACEHOLDER_4.r_name = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = ()\ + \n Subquery:\ + \n Aggregate: groupBy=[[]], aggr=[[MIN(FILENAME_PLACEHOLDER_5.ps_supplycost)]]\ + \n Projection: FILENAME_PLACEHOLDER_5.ps_supplycost\ + \n Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.n_nationkey AND FILENAME_PLACEHOLDER_7.n_regionkey = FILENAME_PLACEHOLDER_8.r_regionkey AND FILENAME_PLACEHOLDER_8.r_name = CAST(Utf8(\"EUROPE\") AS Utf8)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_5 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_6 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_7 
projection=[n_nationkey, n_name, n_regionkey, n_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_8 projection=[r_regionkey, r_name, r_comment]\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_4 projection=[r_regionkey, r_name, r_comment]" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_3() -> Result<()> { + let ctx = create_context_tpch3().await?; + let path = "tests/testdata/tpch_substrait_plans/query_3.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_2.l_orderkey AS L_ORDERKEY, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_1.o_orderdate AS O_ORDERDATE, FILENAME_PLACEHOLDER_1.o_shippriority AS O_SHIPPRIORITY\ + \n Limit: skip=0, fetch=10\ + \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST, FILENAME_PLACEHOLDER_1.o_orderdate ASC NULLS LAST\ + \n Projection: FILENAME_PLACEHOLDER_2.l_orderkey, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_1.o_orderdate, FILENAME_PLACEHOLDER_1.o_shippriority\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_2.l_orderkey, FILENAME_PLACEHOLDER_1.o_orderdate, FILENAME_PLACEHOLDER_1.o_shippriority]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: FILENAME_PLACEHOLDER_2.l_orderkey, FILENAME_PLACEHOLDER_1.o_orderdate, FILENAME_PLACEHOLDER_1.o_shippriority, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_mktsegment = CAST(Utf8(\"HOUSEHOLD\") AS Utf8) AND FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate < Date32(\"1995-03-25\") AND FILENAME_PLACEHOLDER_2.l_shipdate > Date32(\"1995-03-25\")\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"); + Ok(()) + } + + #[tokio::test] + async fn 
tpch_test_4() -> Result<()> { + let ctx = create_context_tpch4().await?; + let path = "tests/testdata/tpch_substrait_plans/query_4.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: FILENAME_PLACEHOLDER_0.o_orderpriority AS O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT\ + \n Sort: FILENAME_PLACEHOLDER_0.o_orderpriority ASC NULLS LAST\ + \n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.o_orderpriority]], aggr=[[count(Int64(1))]]\ + \n Projection: FILENAME_PLACEHOLDER_0.o_orderpriority\ + \n Filter: FILENAME_PLACEHOLDER_0.o_orderdate >= CAST(Utf8(\"1993-07-01\") AS Date32) AND FILENAME_PLACEHOLDER_0.o_orderdate < CAST(Utf8(\"1993-10-01\") AS Date32) AND EXISTS ()\ + \n Subquery:\ + \n Filter: FILENAME_PLACEHOLDER_1.l_orderkey = FILENAME_PLACEHOLDER_1.l_orderkey AND FILENAME_PLACEHOLDER_1.l_commitdate < FILENAME_PLACEHOLDER_1.l_receiptdate\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]"); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_5() -> Result<()> { + let ctx = create_context_tpch5().await?; + let path = "tests/testdata/tpch_substrait_plans/query_5.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{:?}", plan); + assert_eq!(plan_str, "Projection: NATION.n_name AS N_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE\ + \n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST\ + \n Aggregate: groupBy=[[NATION.n_name]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\ + \n Projection: NATION.n_name, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\ + \n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_2.l_suppkey = FILENAME_PLACEHOLDER_3.s_suppkey AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.s_nationkey AND FILENAME_PLACEHOLDER_3.s_nationkey = NATION.n_nationkey AND NATION.n_regionkey = REGION.r_regionkey AND REGION.r_name = CAST(Utf8(\"ASIA\") AS Utf8) AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1995-01-01\") AS Date32)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n Inner Join: Filter: Boolean(true)\ + \n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, 
o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\ + \n TableScan: FILENAME_PLACEHOLDER_3 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\ + \n TableScan: NATION projection=[n_nationkey, n_name, n_regionkey, n_comment]\ + \n TableScan: REGION projection=[r_regionkey, r_name, r_comment]"); + Ok(()) } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7ed376f62ba0..2893b1a31a26 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -327,7 +327,7 @@ async fn simple_scalar_function_pow() -> Result<()> { #[tokio::test] async fn simple_scalar_function_substr() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await + roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await } #[tokio::test] @@ -523,24 +523,6 @@ async fn roundtrip_arithmetic_ops() -> Result<()> { Ok(()) } -#[tokio::test] -async fn roundtrip_interval_literal() -> Result<()> { - roundtrip( - "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(YearMonth)')", - ) - .await?; - roundtrip( - "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(DayTime)')", - ) - .await?; - roundtrip( - "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(MonthDayNano)')", - ) - .await?; - - Ok(()) -} - #[tokio::test] async fn roundtrip_like() -> Result<()> { roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await @@ -650,6 +632,17 @@ async fn simple_window_function() -> Result<()> { roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, sum(b) OVER (PARTITION BY a) FROM data;").await } +#[tokio::test] +async fn window_with_rows() -> Result<()> { + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 FOLLOWING AND 4 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 4 PRECEDING AND 2 PRECEDING) FROM data;").await +} + #[tokio::test] async fn qualified_schema_table_reference() -> Result<()> { roundtrip("SELECT * FROM public.data;").await @@ -758,6 +751,22 @@ async fn roundtrip_values_duplicate_column_join() -> Result<()> { .await } +#[tokio::test] +async fn duplicate_column() -> Result<()> { + // Substrait does not keep column names (aliases) in the plan, rather it operates on column indices + // only. DataFusion however, is strict about not having duplicate column names appear in the plan. + // This test confirms that we generate aliases for columns in the plan which would otherwise have + // colliding names. 
+ assert_expected_plan( + "SELECT a + 1 as sum_a, a + 1 as sum_a_2 FROM data", + "Projection: data.a + Int64(1) AS sum_a, data.a + Int64(1) AS data.a + Int64(1)__temp__0 AS sum_a_2\ + \n Projection: data.a + Int64(1)\ + \n TableScan: data projection=[a]", + true, + ) + .await +} + /// Construct a plan that cast columns. Only those SQL types are supported for now. #[tokio::test] async fn new_test_grammar() -> Result<()> { @@ -810,23 +819,20 @@ async fn roundtrip_aggregate_udf() -> Result<()> { struct Dummy {} impl Accumulator for Dummy { - fn state(&mut self) -> datafusion::error::Result> { - Ok(vec![]) + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Float64(None), ScalarValue::UInt32(None)]) } - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { Ok(()) } - fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { Ok(()) } - fn evaluate(&mut self) -> datafusion::error::Result { - Ok(ScalarValue::Float64(None)) + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(None)) } fn size(&self) -> usize { @@ -1060,6 +1066,8 @@ async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { assert_eq!(plan1str, plan2str); assert_eq!(plan.schema(), plan2.schema()); + + DataFrame::new(ctx.state(), plan2).show().await?; Ok(()) } @@ -1132,7 +1140,6 @@ async fn create_context() -> Result { Field::new("d", DataType::Boolean, true), Field::new("e", DataType::UInt32, true), Field::new("f", DataType::Utf8, true), - Field::new("g", DataType::Interval(IntervalUnit::DayTime), true), ]; let schema = Schema::new(fields); explicit_options.schema = Some(&schema); @@ -1195,6 +1202,11 @@ async fn create_all_type_context() -> Result { ), Field::new("decimal_128_col", DataType::Decimal128(10, 2), true), Field::new("decimal_256_col", DataType::Decimal256(10, 2), true), + Field::new( + "interval_day_time_col", + DataType::Interval(IntervalUnit::DayTime), + true, + ), ]); explicit_options.schema = Some(&schema); explicit_options.has_header = false; diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv index 1b85b166b1df..ef2766d29565 100644 --- a/datafusion/substrait/tests/testdata/data.csv +++ b/datafusion/substrait/tests/testdata/data.csv @@ -1,3 +1,3 @@ a,b,c,d,e,f -1,2.0,2020-01-01,false,4294967296,'a' -3,4.5,2020-01-01,true,2147483648,'b' \ No newline at end of file +1,2.0,2020-01-01,false,4294967295,'a' +3,4.5,2020-01-01,true,2147483648,'b' diff --git a/datafusion/substrait/tests/testdata/tpch/customer.csv b/datafusion/substrait/tests/testdata/tpch/customer.csv new file mode 100644 index 000000000000..ed15da17d47d --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/customer.csv @@ -0,0 +1,2 @@ +c_custkey,c_name,c_address,c_nationkey,c_phone,c_acctbal,c_mktsegment,c_comment +1,Customer#000000001,Address1,1,123-456-7890,5000.00,BUILDING,No comment \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/nation.csv b/datafusion/substrait/tests/testdata/tpch/nation.csv new file mode 100644 index 000000000000..fdf7421467d3 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/nation.csv @@ -0,0 +1,2 @@ +n_nationkey,n_name,n_regionkey,n_comment +0,ALGERIA,0, haggle. 
carefully final deposits detect slyly agai \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/orders.csv b/datafusion/substrait/tests/testdata/tpch/orders.csv new file mode 100644 index 000000000000..b9abea3cbb5b --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/orders.csv @@ -0,0 +1,2 @@ +o_orderkey,o_custkey,o_orderstatus,o_totalprice,o_orderdate,o_orderpriority,o_clerk,o_shippriority,o_comment +1,1,O,1000.00,2023-01-01,5-LOW,Clerk#000000001,0,No comment \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/part.csv b/datafusion/substrait/tests/testdata/tpch/part.csv new file mode 100644 index 000000000000..ef6d04271117 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/part.csv @@ -0,0 +1,2 @@ +p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,p_retailprice,p_comment +1,pink powder puff,Manufacturer#1,Brand#13,SMALL PLATED COPPER,7,JUMBO PKG,901.00,ly final dependencies: slyly bold \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/partsupp.csv b/datafusion/substrait/tests/testdata/tpch/partsupp.csv new file mode 100644 index 000000000000..5c585abc7733 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/partsupp.csv @@ -0,0 +1,2 @@ +ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment +1,1,1000,50.00,slyly final packages boost against the slyly regular \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/region.csv b/datafusion/substrait/tests/testdata/tpch/region.csv new file mode 100644 index 000000000000..6c3fb4524355 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/region.csv @@ -0,0 +1,2 @@ +r_regionkey,r_name,r_comment +0,AFRICA,lar deposits. blithely final packages cajole. regular waters are final requests. 
regular accounts are according to \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch/supplier.csv b/datafusion/substrait/tests/testdata/tpch/supplier.csv new file mode 100644 index 000000000000..f73d2cbeaf91 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch/supplier.csv @@ -0,0 +1,2 @@ +s_suppkey,s_name,s_address,s_nationkey,s_phone,s_acctbal,s_comment +1,Supplier#1,123 Main St,0,555-1234,1000.00,No comments \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_2.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_2.json new file mode 100644 index 000000000000..dd570ca06d45 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_2.json @@ -0,0 +1,1582 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, + { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "like:vchar_vchar" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 3, + "name": "min:decimal" + } + } + ], + "relations": [ + { + "root": { + "input": { + "fetch": { + "common": { + "direct": {} + }, + "input": { + "sort": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "P_PARTKEY", + "P_NAME", + "P_MFGR", + "P_BRAND", + "P_TYPE", + "P_SIZE", + "P_CONTAINER", + "P_RETAILPRICE", + "P_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 55, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 23, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + 
"local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "S_SUPPKEY", + "S_NAME", + "S_ADDRESS", + "S_NATIONKEY", + "S_PHONE", + "S_ACCTBAL", + "S_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "PS_PARTKEY", + "PS_SUPPKEY", + "PS_AVAILQTY", + "PS_SUPPLYCOST", + "PS_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 199, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "N_NATIONKEY", + "N_NAME", + "N_REGIONKEY", + "N_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_3", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { 
+ "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "R_REGIONKEY", + "R_NAME", + "R_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_4", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "i32": 15, + "nullable": false, + "typeVariationReference": 0 + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "fixedChar": "%BRASS", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + 
"structField": { + "field": 21 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 25 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 26 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "EUROPE", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "subquery": { + "scalar": { + "input": { + "aggregate": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 19 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "PS_PARTKEY", + "PS_SUPPKEY", + "PS_AVAILQTY", + "PS_SUPPLYCOST", + "PS_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 199, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_5", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "S_SUPPKEY", + "S_NAME", + "S_ADDRESS", + "S_NATIONKEY", + "S_PHONE", + "S_ACCTBAL", + "S_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { 
+ "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_6", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "N_NATIONKEY", + "N_NAME", + "N_REGIONKEY", + "N_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_7", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "R_REGIONKEY", + "R_NAME", + "R_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_8", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": 
{ + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "EUROPE", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [] + } + ], + "measures": [ + { + "measure": { + "functionReference": 3, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + } + } + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": {} + } + }, + { + "selection": 
{ + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": {} + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ] + } + }, + "offset": "0", + "count": "100" + } + }, + "names": [ + "S_ACCTBAL", + "S_NAME", + "N_NAME", + "P_PARTKEY", + "P_MFGR", + "S_ADDRESS", + "S_PHONE", + "S_COMMENT" + ] + } + } + ], + "expectedTypeUrls": [] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_3.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_3.json new file mode 100644 index 000000000000..4ca074d2e8cf --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_3.json @@ -0,0 +1,851 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "gt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "multiply:opt_decimal_decimal" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "subtract:opt_decimal_decimal" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "sum:opt_decimal" + } + }], + "relations": [{ + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6, 7] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [33, 34, 35, 36] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "varchar": { + "length": 25, + "typeVariationReference": 
0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 117, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + 
"typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "HOUSEHOLD", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + 
"field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "date": 9214, + "nullable": false, + "typeVariationReference": 0 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 27 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "date": 9214, + "nullable": false, + "typeVariationReference": 0 + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 6, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, + "sorts": [{ + "expr": { + 
"selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "offset": "0", + "count": "10" + } + }, + "names": ["L_ORDERKEY", "REVENUE", "O_ORDERDATE", "O_SHIPPRIORITY"] + } + }], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_4.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_4.json new file mode 100644 index 000000000000..6e946cefdd13 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_4.json @@ -0,0 +1,540 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "equal:any1_any1" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "count:opt" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [9] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": 
{ + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1993-07-01", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1993-10-01", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }] + } + } + }, { + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": 
"NULLABILITY_NULLABLE" + } + }, { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + } + } + } + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 4, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["O_ORDERPRIORITY", "ORDER_COUNT"] + } + }], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_5.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_5.json new file mode 100644 index 000000000000..75b82a305eb3 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_5.json @@ -0,0 +1,1254 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any1_any1" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gte:date_date" + } + }, + { + "extensionFunction": { + 
"extensionUriReference": 3, + "functionAnchor": 3, + "name": "lt:date_date" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "multiply:opt_decimal_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "subtract:opt_decimal_decimal" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "sum:opt_decimal" + } + } + ], + "relations": [ + { + "root": { + "input": { + "sort": { + "common": { + "direct": {} + }, + "input": { + "aggregate": { + "common": { + "direct": {} + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 47, + 48 + ] + } + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "join": { + "common": { + "direct": {} + }, + "left": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "C_CUSTKEY", + "C_NAME", + "C_ADDRESS", + "C_NATIONKEY", + "C_PHONE", + "C_ACCTBAL", + "C_MKTSEGMENT", + "C_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 117, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_0", + "parquet": {} + } + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "O_ORDERKEY", + "O_CUSTKEY", + "O_ORDERSTATUS", + "O_TOTALPRICE", + "O_ORDERDATE", + "O_ORDERPRIORITY", + "O_CLERK", + "O_SHIPPRIORITY", + "O_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" 
+ } + }, + { + "varchar": { + "length": 79, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_1", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "L_ORDERKEY", + "L_PARTKEY", + "L_SUPPKEY", + "L_LINENUMBER", + "L_QUANTITY", + "L_EXTENDEDPRICE", + "L_DISCOUNT", + "L_TAX", + "L_RETURNFLAG", + "L_LINESTATUS", + "L_SHIPDATE", + "L_COMMITDATE", + "L_RECEIPTDATE", + "L_SHIPINSTRUCT", + "L_SHIPMODE", + "L_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 1, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fixedChar": { + "length": 10, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 44, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_2", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "S_SUPPKEY", + "S_NAME", + "S_ADDRESS", + "S_NATIONKEY", + "S_PHONE", + "S_ACCTBAL", + "S_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 40, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + 
"i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 15, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 101, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "local_files": { + "items": [ + { + "uri_file": "file://FILENAME_PLACEHOLDER_3", + "parquet": {} + } + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "N_NATIONKEY", + "N_NAME", + "N_REGIONKEY", + "N_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "NATION" + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "right": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "R_REGIONKEY", + "R_NAME", + "R_COMMENT" + ], + "struct": { + "types": [ + { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "varchar": { + "length": 152, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "REGION" + ] + } + } + }, + "expression": { + "literal": { + "boolean": true, + "nullable": false, + "typeVariationReference": 0 + } + }, + "type": "JOIN_TYPE_INNER" + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": 
{ + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 33 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 36 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 36 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 40 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 42 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 44 + } + }, + "rootReference": {} + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 45 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "fixedChar": { + "length": 25, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "ASIA", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1994-01-01", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 3, + "args": [], + "outputType": { + "bool": { + 
"typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "cast": { + "type": { + "date": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-01-01", + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + } + ] + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 41 + } + }, + "rootReference": {} + } + }, + { + "scalarFunction": { + "functionReference": 4, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 5, + "args": [], + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_UNSPECIFIED" + } + } + }, + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": {} + } + } + } + ] + } + } + } + ] + } + } + ] + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 6, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 0, + "precision": 19, + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + }, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + } + ] + } + }, + "names": [ + "N_NAME", + "REVENUE" + ] + } + } + ], + "expectedTypeUrls": [] +} \ No newline at end of file diff --git a/docs/source/contributor-guide/inviting.md b/docs/source/contributor-guide/inviting.md index be189b95f684..967f417e6e9a 100644 --- a/docs/source/contributor-guide/inviting.md +++ b/docs/source/contributor-guide/inviting.md @@ -294,7 +294,7 @@ Subject: [DISCUSS] $NEW_PMC_MEMBER for PMC I would like to propose adding $NEW_PMC_MEMBER[1] to the DataFusion PMC. -$NEW_PMC_MEMBMER has been a committer since $COMMITER_MONTH [2], has a +$NEW_PMC_MEMBMER has been a committer since $COMMITTER_MONTH [2], has a strong and sustained contribution record for more than a year, and focused on helping the community and the project grow[3]. 
diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index f805f0a99292..fe3990b90c3c 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -268,7 +268,7 @@ impl PartitionEvaluator for MyPartitionEvaluator { } } -/// Create a `PartitionEvalutor` to evaluate this function on a new +/// Create a `PartitionEvaluator` to evaluate this function on a new /// partition. fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> { Ok(Box::new(MyPartitionEvaluator::new())) @@ -474,7 +474,7 @@ impl Accumulator for GeometricMean { ### registering an Aggregate UDF -To register a Aggreate UDF, you need to wrap the function implementation in a [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper functions to make this easier. +To register an Aggregate UDF, you need to wrap the function implementation in a [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper functions to make this easier. There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udaf.rs`]. ```rust diff --git a/docs/source/library-user-guide/using-the-sql-api.md b/docs/source/library-user-guide/using-the-sql-api.md index f4e85ee4e3a9..1a25f078cc2e 100644 --- a/docs/source/library-user-guide/using-the-sql-api.md +++ b/docs/source/library-user-guide/using-the-sql-api.md @@ -19,4 +19,199 @@ # Using the SQL API +DataFusion has a full SQL API that allows you to interact with DataFusion using +SQL query strings. The simplest way to use the SQL API is to use the +[`SessionContext`] struct which provides the highest-level API for executing SQL +queries. + +To use SQL, you first register your data as a table and then run queries +using the [`SessionContext::sql`] method. For lower level control such as +preventing DDL, you can use [`SessionContext::sql_with_options`] or the +[`SessionState`] APIs. + +[`sessioncontext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html +[`sessioncontext::sql`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.sql +[`sessioncontext::sql_with_options`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.sql_with_options +[`sessionstate`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html + +## Registering Data Sources using `SessionContext::register*` + +The `SessionContext::register*` methods tell DataFusion the name of +the source and how to read data. Once registered, you can execute SQL queries +using the `SessionContext::sql` method referring to your data source as a table.
+ +### Read a CSV File + +```rust +use datafusion::error::Result; +use datafusion::prelude::*; +use arrow::record_batch::RecordBatch; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // register the "example" table + ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; + // create a plan to run a SQL query + let df = ctx.sql("SELECT a, MIN(b) FROM example WHERE a <= b GROUP BY a LIMIT 100").await?; + // execute the plan and collect the results as Vec<RecordBatch> + let results: Vec<RecordBatch> = df.collect().await?; + // Use the assert_batches_eq macro to compare the results with expected output + datafusion::assert_batches_eq!(vec![ + "+---+----------------+", + "| a | MIN(example.b) |", + "+---+----------------+", + "| 1 | 2 |", + "+---+----------------+", + ], + &results + ); + Ok(()) +} +``` + +### Read an Apache Parquet file + +Similarly to CSV, you can register a Parquet file as a table using the `register_parquet` method. + +```rust +use datafusion::error::Result; +use datafusion::prelude::*; +#[tokio::main] +async fn main() -> Result<()> { + // create local session context + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + + // register parquet file with the execution context + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + // execute the query + let df = ctx.sql( + "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ + FROM alltypes_plain \ + WHERE id > 1 AND tinyint_col < double_col", + ).await?; + + // execute the plan, and compare to the expected results + let results = df.collect().await?; + datafusion::assert_batches_eq!(vec![ + "+---------+------------+--------------------------------+", + "| int_col | double_col | alltypes_plain.date_string_col |", + "+---------+------------+--------------------------------+", + "| 1 | 10.1 | 03/01/09 |", + "| 1 | 10.1 | 04/01/09 |", + "| 1 | 10.1 | 02/01/09 |", + "+---------+------------+--------------------------------+", + ], + &results + ); + Ok(()) +} +``` + +### Read an Apache Avro file + +DataFusion can also read Avro files using the `register_avro` method.
+ +```rust +use datafusion::arrow::util::pretty; +use datafusion::error::Result; +use datafusion::prelude::*; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // find the path to the avro test files + let testdata = datafusion::test_util::arrow_test_data(); + // register avro file with the execution context + let avro_file = &format!("{testdata}/avro/alltypes_plain.avro"); + ctx.register_avro("alltypes_plain", avro_file, AvroReadOptions::default()).await?; + + // execute the query + let df = ctx.sql( + "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ + FROM alltypes_plain \ + WHERE id > 1 AND tinyint_col < double_col" + ).await?; + + // execute the plan, and compare to the expected results + let results = df.collect().await?; + datafusion::assert_batches_eq!(vec![ + "+---------+------------+--------------------------------+", + "| int_col | double_col | alltypes_plain.date_string_col |", + "+---------+------------+--------------------------------+", + "| 1 | 10.1 | 03/01/09 |", + "| 1 | 10.1 | 04/01/09 |", + "| 1 | 10.1 | 02/01/09 |", + "+---------+------------+--------------------------------+", + ], + &results + ); + Ok(()) +} +``` + +## Reading Multiple Files as a table + +It is also possible to read multiple files as a single table. This is done +with the ListingTableProvider which takes a list of file paths and reads them +as a single table, matching schemas as appropriate + Coming Soon + +```rust + +``` + +## Using `CREATE EXTERNAL TABLE` to register data sources via SQL + +You can also register files using SQL using the [`CREATE EXTERNAL TABLE`] +statement. + +[`create external table`]: ../user-guide/sql/ddl.md#create-external-table + +```rust +use datafusion::error::Result; +use datafusion::prelude::*; +#[tokio::main] +async fn main() -> Result<()> { + // create local session context + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + + // register parquet file using SQL + let ddl = format!( + "CREATE EXTERNAL TABLE alltypes_plain \ + STORED AS PARQUET LOCATION '{testdata}/alltypes_plain.parquet'" + ); + ctx.sql(&ddl).await?; + + // execute the query referring to the alltypes_plain table we just registered + let df = ctx.sql( + "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ + FROM alltypes_plain \ + WHERE id > 1 AND tinyint_col < double_col", + ).await?; + + // execute the plan, and compare to the expected results + let results = df.collect().await?; + datafusion::assert_batches_eq!(vec![ + "+---------+------------+--------------------------------+", + "| int_col | double_col | alltypes_plain.date_string_col |", + "+---------+------------+--------------------------------+", + "| 1 | 10.1 | 03/01/09 |", + "| 1 | 10.1 | 04/01/09 |", + "| 1 | 10.1 | 02/01/09 |", + "+---------+------------+--------------------------------+", + ], + &results + ); + Ok(()) +} +``` diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index e0c9e69eb6ed..e0b6f434a032 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -80,7 +80,11 @@ If you'd like to learn more about `Expr`s, before we get into the details of cre ## Rewriting `Expr`s -[rewrite_expr.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) contains example code for rewriting `Expr`s. 
+There are several examples of rewriting and working with `Exprs`: + +- [expr_api.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs) +- [analyzer_rule.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/analyzer_rule.rs) +- [optimizer_rule.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/optimizer_rule.rs) Rewriting Expressions is the process of taking an `Expr` and transforming it into another `Expr`. This is useful for a number of reasons, including: diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 71a614313e8a..7dbd4045e75b 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -56,7 +56,7 @@ datafusion = { git = "https://github.com/apache/datafusion", branch = "main", de More on [Cargo dependencies](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies) -## Run a SQL query against data stored in a CSV: +## Run a SQL query against data stored in a CSV ```rust use datafusion::prelude::*; @@ -76,7 +76,10 @@ async fn main() -> datafusion::error::Result<()> { } ``` -## Use the DataFrame API to process data stored in a CSV: +See [the SQL API](../library-user-guide/using-the-sql-api.md) section of the +library guide for more information on the SQL API. + +## Use the DataFrame API to process data stored in a CSV ```rust use datafusion::prelude::*; @@ -261,8 +264,8 @@ Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#e ```bash RUST_BACKTRACE=1 ./target/debug/datafusion-cli DataFusion CLI v31.0.0 -> select row_numer() over (partition by a order by a) from (select 1 a); -Error during planning: Invalid function 'row_numer'. +> select row_number() over (partition by a order by a) from (select 1 a); +Error during planning: Invalid function 'row_number'. Did you mean 'ROW_NUMBER'? backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace @@ -287,7 +290,7 @@ async fn test_get_backtrace_for_failed_code() -> Result<()> { let ctx = SessionContext::new(); let sql = " - select row_numer() over (partition by a order by a) from (select 1 a); + select row_number() over (partition by a order by a) from (select 1 a); "; let _ = ctx.sql(sql).await?.collect().await?; diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 427a7bf130a7..edb0e1d0c9f0 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -123,11 +123,9 @@ bool_or(expression) ### `count` -Returns the number of rows in the specified column. +Returns the number of non-null values in the specified column. -Count includes _null_ values in the total count. -To exclude _null_ values from the total count, include ` IS NOT NULL` -in the `WHERE` clause. +To include _null_ values in the total count, use `count(*)`. ``` count(expression) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index ec34dbf9ba6c..d636726b45fe 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1132,6 +1132,14 @@ substr(str, start_pos[, length]) - **length**: Number of characters to extract. If not specified, returns the rest of the string after the start position. 
+#### Aliases + +- substring + +### `substring` + +_Alias of [substr](#substr)._ + ### `translate` Translates characters in a string to specified translation characters.