diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 5600dab98b55..c2b7d626d1fd 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -8,13 +8,21 @@ assignees: '' --- **Describe the bug** + **To Reproduce** + **Expected behavior** + **Additional context** + \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index d9883dd454b7..d7aad5e7761a 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -8,14 +8,22 @@ assignees: '' --- **Is your feature request related to a problem or challenge? Please describe what you are trying to do.** + **Describe the solution you'd like** + **Describe alternatives you've considered** + **Additional context** + diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md index 75c7308046b6..aafac7cb86c1 100644 --- a/.github/ISSUE_TEMPLATE/question.md +++ b/.github/ISSUE_TEMPLATE/question.md @@ -8,10 +8,16 @@ assignees: '' --- **Which part is this question about** + **Describe your question** + **Additional context** + diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 0157caf8c296..0ef6532da477 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -42,14 +42,6 @@ runs: - name: Generate lockfile shell: bash run: cargo fetch - - name: Cache Rust dependencies - uses: actions/cache@v3 - with: - # these represent compiled steps of both dependencies and arrow - # and thus are specific for a particular OS, arch and rust version. - path: /github/home/target - key: ${{ runner.os }}-${{ runner.arch }}-target-cache3-${{ inputs.rust-version }}-${{ hashFiles('**/Cargo.lock') }} - restore-keys: ${{ runner.os }}-${{ runner.arch }}-target-cache3-${{ inputs.rust-version }}- - name: Install Build Dependencies shell: bash run: | diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9bd42dbaa0d6..9c4cda5d034d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -7,3 +7,9 @@ updates: open-pull-requests-limit: 10 target-branch: master labels: [auto-dependencies] + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + open-pull-requests-limit: 10 + labels: [auto-dependencies] diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index db170e360ce5..7c51452c54b8 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,6 +1,6 @@ # Which issue does this PR close? - @@ -8,21 +8,21 @@ Closes #. # Rationale for this change - # What changes are included in this PR? - # Are there any user-facing changes? - diff --git a/.github/workflows/README.md b/.github/workflows/README.md new file mode 100644 index 000000000000..679ccc956a20 --- /dev/null +++ b/.github/workflows/README.md @@ -0,0 +1,25 @@ + + +The CI is structured so most tests are run in specific workflows: +`arrow.yml` for `arrow`, `parquet.yml` for `parquet` and so on. + +The basic idea is to run all tests on pushes to master (to ensure we +keep master green) but run only the individual workflows on PRs that +change files that could affect them. 
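As a concrete sketch of the setup the README describes, each per-crate workflow added in this PR declares a trigger of roughly the following shape (illustrative only; the `paths` filters shown here are the ones from `arrow.yml` below, and each crate's workflow lists its own):

```yaml
# Illustrative trigger shared by the per-crate workflows in this PR:
# run on every push to master, but on pull requests only when files
# relevant to the crate (or the CI itself) change.
on:
  push:
    branches:
      - master
  pull_request:
    paths:
      - arrow/**      # crate-specific paths; varies per workflow
      - .github/**    # CI changes always re-run the workflow
```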
diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml new file mode 100644 index 000000000000..d34ee3b49b5c --- /dev/null +++ b/.github/workflows/arrow.yml @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# tests for arrow crate +name: arrow + +on: + # always trigger + push: + branches: + - master + pull_request: + paths: + - arrow/** + - .github/** + +jobs: + + # test the crate + linux-test: + name: Test + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Test + run: | + cargo test -p arrow + - name: Test --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict + run: | + cargo test -p arrow --features=force_validate,prettyprint,ipc_compression,ffi,dyn_cmp_dict + - name: Run examples + run: | + # Test arrow examples + cargo run --example builders + cargo run --example dynamic_types + cargo run --example read_csv + cargo run --example read_csv_infer_schema + - name: Run non-archery based integration-tests + run: cargo test -p arrow-integration-testing + + # test compilation features + linux-features: + name: Check Compilation + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Check compilation + run: | + cargo check -p arrow + - name: Check compilation --no-default-features + run: | + cargo check -p arrow --no-default-features + - name: Check compilation --all-targets + run: | + cargo check -p arrow --all-targets + - name: Check compilation --no-default-features --all-targets + run: | + cargo check -p arrow --no-default-features --all-targets + - name: Check compilation --no-default-features --all-targets --features test_utils + run: | + cargo check -p arrow --no-default-features --all-targets --features test_utils + + # test the --features "simd" of the arrow crate. This requires nightly Rust.
+ linux-test-simd: + name: Test SIMD on AMD64 Rust ${{ matrix.rust }} + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: nightly + - name: Run tests --features "simd" + run: | + cargo test -p arrow --features "simd" + - name: Check compilation --features "simd" + run: | + cargo check -p arrow --features simd + - name: Check compilation --features simd --all-targets + run: | + cargo check -p arrow --features simd --all-targets + + + # test the arrow crate builds against wasm32 in stable rust + wasm32-build: + name: Build wasm32 + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Cache Cargo + uses: actions/cache@v3 + with: + path: /github/home/.cargo + key: cargo-wasm32-cache3- + - name: Setup Rust toolchain for WASM + run: | + rustup toolchain install nightly + rustup override set nightly + rustup target add wasm32-unknown-unknown + rustup target add wasm32-wasi + - name: Build + run: | + cd arrow + cargo build --no-default-features --features=json,csv,ipc,simd,ffi --target wasm32-unknown-unknown + cargo build --no-default-features --features=json,csv,ipc,simd,ffi --target wasm32-wasi + + clippy: + name: Clippy + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v3 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Setup Clippy + run: | + rustup component add clippy + - name: Run clippy + run: | + cargo clippy -p arrow --features=prettyprint,csv,ipc,test_utils,ffi,ipc_compression,dyn_cmp_dict --all-targets -- -D warnings diff --git a/.github/workflows/arrow_flight.yml b/.github/workflows/arrow_flight.yml new file mode 100644 index 000000000000..86a67ff9a6a4 --- /dev/null +++ b/.github/workflows/arrow_flight.yml @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +# tests for arrow_flight crate +name: arrow_flight + + +# trigger for all PRs that touch certain files and changes to master +on: + push: + branches: + - master + pull_request: + paths: + - arrow/** + - arrow-flight/** + - .github/** + +jobs: + # test the crate + linux-test: + name: Test + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Test + run: | + cargo test -p arrow-flight + - name: Test --all-features + run: | + cargo test -p arrow-flight --all-features + + clippy: + name: Clippy + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v3 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Setup Clippy + run: | + rustup component add clippy + - name: Run clippy + run: | + cargo clippy -p arrow-flight --all-features -- -D warnings diff --git a/.github/workflows/cancel.yml b/.github/workflows/cancel.yml index b4fb904842e6..a98c8ee5d225 100644 --- a/.github/workflows/cancel.yml +++ b/.github/workflows/cancel.yml @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -name: Cancel stale runs +# Attempt to cancel stale workflow runs to save github actions runner time +name: cancel on: workflow_run: diff --git a/.github/workflows/comment_bot.yml b/.github/workflows/comment_bot.yml deleted file mode 100644 index 6ca095328af1..000000000000 --- a/.github/workflows/comment_bot.yml +++ /dev/null @@ -1,72 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: Comment Bot - -on: - # TODO(kszucs): support pull_request_review_comment - issue_comment: - types: - - created - - edited - -jobs: - crossbow: - name: Listen! 
- if: startsWith(github.event.comment.body, '@github-actions crossbow') - runs-on: ubuntu-latest - steps: - - name: Checkout Arrow - uses: actions/checkout@v2 - with: - repository: apache/arrow - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - name: Install Archery and Crossbow dependencies - run: pip install -e dev/archery[bot] - - name: Handle Github comment event - env: - ARROW_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - CROSSBOW_GITHUB_TOKEN: ${{ secrets.CROSSBOW_GITHUB_TOKEN }} - run: | - archery trigger-bot \ - --event-name ${{ github.event_name }} \ - --event-payload ${{ github.event_path }} - - rebase: - name: "Rebase" - if: startsWith(github.event.comment.body, '@github-actions rebase') - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: r-lib/actions/pr-fetch@master - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Rebase on ${{ github.repository }} master - run: | - set -ex - git config user.name "$(git log -1 --pretty=format:%an)" - git config user.email "$(git log -1 --pretty=format:%ae)" - git remote add upstream https://github.com/${{ github.repository }} - git fetch --unshallow upstream master - git rebase upstream/master - - uses: r-lib/actions/pr-push@master - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - args: "--force" diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml new file mode 100644 index 000000000000..e688428e187c --- /dev/null +++ b/.github/workflows/coverage.yml @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: coverage + +# Trigger only on pushes to master, not pull requests +on: + push: + branches: + - master + +jobs: + + coverage: + name: Coverage + runs-on: ubuntu-latest + # Note runs outside of a container + # otherwise we get this error: + # Failed to run tests: ASLR disable failed: EPERM: Operation not permitted + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Setup Rust toolchain + run: | + rustup toolchain install stable + rustup default stable + - name: Install protobuf compiler in /protoc + run: | + sudo mkdir /protoc + sudo chmod a+rwx /protoc + cd /protoc + curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/protoc-21.4-linux-x86_64.zip + unzip protoc-21.4-linux-x86_64.zip + - name: Cache Cargo + uses: actions/cache@v3 + with: + path: /home/runner/.cargo + key: cargo-coverage-cache3- + - name: Run coverage + run: | + export PATH=$PATH:/protoc/bin + rustup toolchain install stable + rustup default stable + cargo install --version 0.18.2 cargo-tarpaulin + cargo tarpaulin --all --out Xml + - name: Report coverage + continue-on-error: true + run: bash <(curl -s https://codecov.io/bash) diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 51569c0029a7..57dc19482761 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. -name: Dev +name: dev +# trigger for all PRs and changes to master on: - # always trigger push: + branches: + - master pull_request: env: @@ -32,24 +34,24 @@ jobs: name: Release Audit Tool (RAT) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Setup Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: 3.8 - name: Audit licenses run: ./dev/release/run-rat.sh . prettier: - name: Use prettier to check formatting of documents + name: Markdown format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-node@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-node@v3 with: node-version: "14" - name: Prettier check run: | - # if you encounter error, try rerun the command below with --write instead of --check - # and commit the changes - npx prettier@2.3.0 --check {arrow,arrow-flight,dev,integration-testing,parquet}/**/*.md README.md CODE_OF_CONDUCT.md CONTRIBUTING.md + # if you encounter error, run the command below and commit the changes + npx prettier@2.3.2 --write {arrow,arrow-flight,dev,integration-testing,parquet}/**/*.md README.md CODE_OF_CONDUCT.md CONTRIBUTING.md + git diff --exit-code diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 78fe37ba8a5b..38bb39390097 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. 
-name: Dev PR +name: dev_pr +# Trigger whenever a PR is changed (title as well as new / changed commits) on: pull_request_target: types: @@ -29,21 +30,15 @@ jobs: name: Process runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@2.2.0 + uses: actions/labeler@v4.0.1 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml sync-labels: true - - #- name: Checks if PR needs rebase - # uses: eps1lon/actions-label-merge-conflict@releases/2.x - # with: - # dirtyLabel: "needs-rebase" - # repoToken: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index a3d27cabb8ea..aadf9c377c64 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -26,3 +26,6 @@ parquet: parquet-derive: - parquet_derive/**/* + +object-store: + - object_store/**/* diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 000000000000..5e82d76febe6 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: docs + +# trigger for all PRs and changes to master +on: + push: + branches: + - master + pull_request: + +jobs: + + # test doc links still work + docs: + name: Rustdocs are clean + runs-on: ubuntu-latest + strategy: + matrix: + arch: [ amd64 ] + rust: [ nightly ] + container: + image: ${{ matrix.arch }}/rust + env: + RUSTDOCFLAGS: "-Dwarnings" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Install python dev + run: | + apt update + apt install -y libpython3.9-dev + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ matrix.rust }} + - name: Run cargo doc + run: | + cargo doc --document-private-items --no-deps --workspace --all-features diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 7eed6b8e94c9..10a8e30212a9 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -15,47 +15,103 @@ # specific language governing permissions and limitations # under the License. 
-name: Integration +name: integration +# trigger for all PRs that touch certain files and changes to master on: push: + branches: + - master pull_request: + paths: + - arrow/** + - arrow-pyarrow-integration-testing/** + - integration-testing/** + - .github/** jobs: integration: - name: Integration Test + name: Archery test With other arrows runs-on: ubuntu-latest + container: + image: apache/arrow-dev:amd64-conda-integration + env: + ARROW_USE_CCACHE: OFF + ARROW_CPP_EXE_PATH: /build/cpp/debug + BUILD_DOCS_CPP: OFF + # These are necessary because the github runner overrides $HOME + # https://github.com/actions/runner/issues/863 + RUSTUP_HOME: /root/.rustup + CARGO_HOME: /root/.cargo + defaults: + run: + shell: bash steps: + # This is necessary so that actions/checkout can find git + - name: Export conda path + run: echo "/opt/conda/envs/arrow/bin" >> $GITHUB_PATH + # This is necessary so that Rust can find cargo + - name: Export cargo path + run: echo "/root/.cargo/bin" >> $GITHUB_PATH + - name: Check rustup + run: which rustup + - name: Check cmake + run: which cmake - name: Checkout Arrow - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: repository: apache/arrow submodules: true fetch-depth: 0 - name: Checkout Arrow Rust - uses: actions/checkout@v2 + uses: actions/checkout@v3 with: path: rust fetch-depth: 0 - - name: Setup Python - uses: actions/setup-python@v3 - with: - python-version: 3.8 - - name: Setup Archery - run: pip install -e dev/archery[docker] - - name: Execute Docker Build - run: archery docker run -e ARCHERY_INTEGRATION_WITH_RUST=1 conda-integration + - name: Make build directory + run: mkdir /build + - name: Build Rust + run: conda run --no-capture-output ci/scripts/rust_build.sh $PWD /build + - name: Build C++ + run: conda run --no-capture-output ci/scripts/cpp_build.sh $PWD /build + - name: Build C# + run: conda run --no-capture-output ci/scripts/csharp_build.sh $PWD /build + - name: Build Go + run: conda run --no-capture-output ci/scripts/go_build.sh $PWD + - name: Build Java + run: conda run --no-capture-output ci/scripts/java_build.sh $PWD /build + # Temporarily disable JS https://issues.apache.org/jira/browse/ARROW-17410 + # - name: Build JS + # run: conda run --no-capture-output ci/scripts/js_build.sh $PWD /build + - name: Install archery + run: conda run --no-capture-output pip install -e dev/archery + - name: Run integration tests + run: | + conda run --no-capture-output archery integration \ + --run-flight \ + --with-cpp=1 \ + --with-csharp=1 \ + --with-java=1 \ + --with-js=0 \ + --with-go=1 \ + --with-rust=1 \ + --gold-dirs=testing/data/arrow-ipc-stream/integration/0.14.1 \ + --gold-dirs=testing/data/arrow-ipc-stream/integration/0.17.1 \ + --gold-dirs=testing/data/arrow-ipc-stream/integration/1.0.0-bigendian \ + --gold-dirs=testing/data/arrow-ipc-stream/integration/1.0.0-littleendian \ + --gold-dirs=testing/data/arrow-ipc-stream/integration/2.0.0-compression \ + --gold-dirs=testing/data/arrow-ipc-stream/integration/4.0.0-shareddict # test FFI against the C-Data interface exposed by pyarrow pyarrow-integration-test: - name: Test Pyarrow C Data Interface + name: Pyarrow C Data Interface runs-on: ubuntu-latest strategy: matrix: - rust: [stable] + rust: [ stable ] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: submodules: true - name: Setup Rust toolchain @@ -74,7 +130,7 @@ jobs: path: /home/runner/target # this key is not equal because maturin uses different compilation flags. 
key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- - - uses: actions/setup-python@v3 + - uses: actions/setup-python@v4 with: python-version: '3.7' - name: Upgrade pip and setuptools diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml index 7feacc07dd73..b4669bbcccc0 100644 --- a/.github/workflows/miri.yaml +++ b/.github/workflows/miri.yaml @@ -15,19 +15,24 @@ # specific language governing permissions and limitations # under the License. -name: MIRI +name: miri +# trigger for all PRs that touch certain files and changes to master on: - # always trigger push: + branches: + - master pull_request: + paths: + - arrow/** + - .github/** jobs: miri-checks: name: MIRI runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: submodules: true - name: Setup Rust toolchain diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml new file mode 100644 index 000000000000..6996aa706636 --- /dev/null +++ b/.github/workflows/object_store.yml @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +# tests for `object_store` crate +name: object_store + +# trigger for all PRs that touch certain files and changes to master +on: + push: + branches: + - master + pull_request: + paths: + - object_store/** + - .github/** + +jobs: + clippy: + name: Clippy + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v3 + - name: Setup Rust toolchain with clippy + run: | + rustup toolchain install stable + rustup default stable + rustup component add clippy + # Run different tests for the library on its own as well as + # all targets to ensure that it still works in the absence of + # features that might be enabled by dev-dependencies of other + # targets. 
+ - name: Run clippy with default features + run: cargo clippy -p object_store -- -D warnings + - name: Run clippy with aws feature + run: cargo clippy -p object_store --features aws -- -D warnings + - name: Run clippy with gcp feature + run: cargo clippy -p object_store --features gcp -- -D warnings + - name: Run clippy with azure feature + run: cargo clippy -p object_store --features azure -- -D warnings + - name: Run clippy with all features + run: cargo clippy -p object_store --all-features -- -D warnings + - name: Run clippy with all features and all targets + run: cargo clippy -p object_store --all-features --all-targets -- -D warnings + + # test the crate + # This runs outside a container to workaround lack of support for passing arguments + # to service containers - https://github.com/orgs/community/discussions/26688 + linux-test: + name: Emulator Tests + runs-on: ubuntu-latest + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + # https://github.com/rust-lang/cargo/issues/10280 + CARGO_NET_GIT_FETCH_WITH_CLI: "true" + RUST_BACKTRACE: "1" + # Run integration tests + TEST_INTEGRATION: 1 + EC2_METADATA_ENDPOINT: http://localhost:1338 + AZURE_USE_EMULATOR: "1" + AZURITE_BLOB_STORAGE_URL: "http://localhost:10000" + AZURITE_QUEUE_STORAGE_URL: "http://localhost:10001" + GOOGLE_SERVICE_ACCOUNT: "/tmp/gcs.json" + OBJECT_STORE_BUCKET: test-bucket + + steps: + - uses: actions/checkout@v3 + + - name: Configure Fake GCS Server (GCP emulation) + run: | + docker run -d -p 4443:4443 fsouza/fake-gcs-server -scheme http + curl -v -X POST --data-binary '{"name":"test-bucket"}' -H "Content-Type: application/json" "http://localhost:4443/storage/v1/b" + echo '{"gcs_base_url": "http://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": ""}' > "$GOOGLE_SERVICE_ACCOUNT" + + - name: Setup LocalStack (AWS emulation) + env: + AWS_DEFAULT_REGION: "us-east-1" + AWS_ACCESS_KEY_ID: test + AWS_SECRET_ACCESS_KEY: test + AWS_ENDPOINT: http://localhost:4566 + run: | + docker run -d -p 4566:4566 localstack/localstack:0.14.4 + docker run -d -p 1338:1338 amazon/amazon-ec2-metadata-mock:v1.9.2 --imdsv2 + aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket + + - name: Configure Azurite (Azure emulation) + # the magical connection string is from + # https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio#http-connection-strings + run: | + docker run -d -p 10000:10000 -p 10001:10001 -p 10002:10002 mcr.microsoft.com/azure-storage/azurite + az storage container create -n test-bucket --connection-string 'DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://localhost:10000/devstoreaccount1;QueueEndpoint=http://localhost:10001/devstoreaccount1;' + + - name: Setup Rust toolchain + run: | + rustup toolchain install stable + rustup default stable + + - name: Run object_store tests + env: + OBJECT_STORE_AWS_DEFAULT_REGION: "us-east-1" + OBJECT_STORE_AWS_ACCESS_KEY_ID: test + OBJECT_STORE_AWS_SECRET_ACCESS_KEY: test + OBJECT_STORE_AWS_ENDPOINT: http://localhost:4566 + run: | + # run tests + cargo test -p object_store --features=aws,azure,gcp diff --git a/.github/workflows/parquet.yml b/.github/workflows/parquet.yml new file mode 100644 index 000000000000..42cb06bb0a86 --- /dev/null +++ 
b/.github/workflows/parquet.yml @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +# tests for parquet crate +name: "parquet" + + +# trigger for all PRs that touch certain files and changes to master +on: + push: + branches: + - master + pull_request: + paths: + - arrow/** + - parquet/** + - .github/** + +jobs: + # test the crate + linux-test: + name: Test + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Test + run: | + cargo test -p parquet + - name: Test --all-features + run: | + cargo test -p parquet --all-features + + + # test compilation + linux-features: + name: Check Compilation + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks. + RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + + # Run different tests for the library on its own as well as + # all targets to ensure that it still works in the absence of + # features that might be enabled by dev-dependencies of other + # targets. + # + # This, for each of (library and all-targets), checks that it: + # 1. compiles with default features + # 2. compiles with no default features + # 3. compiles with just the arrow feature + # 4. compiles with all features
+ - name: Check compilation + run: | + cargo check -p parquet + - name: Check compilation --no-default-features + run: | + cargo check -p parquet --no-default-features + - name: Check compilation --no-default-features --features arrow + run: | + cargo check -p parquet --no-default-features --features arrow + - name: Check compilation --all-features + run: | + cargo check -p parquet --all-features + - name: Check compilation --all-targets + run: | + cargo check -p parquet --all-targets + - name: Check compilation --all-targets --no-default-features + run: | + cargo check -p parquet --all-targets --no-default-features + - name: Check compilation --all-targets --no-default-features --features arrow + run: | + cargo check -p parquet --all-targets --no-default-features --features arrow + - name: Check compilation --all-targets --all-features + run: | + cargo check -p parquet --all-targets --all-features + + clippy: + name: Clippy + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v3 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Setup Clippy + run: | + rustup component add clippy + - name: Run clippy + run: | + cargo clippy -p parquet --all-targets --all-features -- -D warnings diff --git a/.github/workflows/parquet_derive.yml b/.github/workflows/parquet_derive.yml new file mode 100644 index 000000000000..bd70fc30d1c5 --- /dev/null +++ b/.github/workflows/parquet_derive.yml @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +# tests for parquet_derive crate +name: parquet_derive + + +# trigger for all PRs that touch certain files and changes to master +on: + push: + branches: + - master + pull_request: + paths: + - parquet/** + - parquet_derive/** + - parquet_derive_test/** + - .github/** + +jobs: + # test the crate + linux-test: + name: Test + runs-on: ubuntu-latest + container: + image: amd64/rust + env: + # Disable full debug symbol generation to speed up CI build and keep memory down + # "1" means line tables only, which is useful for panic tracebacks.
+ RUSTFLAGS: "-C debuginfo=1" + steps: + - uses: actions/checkout@v3 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Test + run: | + cargo test -p parquet_derive + + clippy: + name: Clippy + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v3 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Setup Clippy + run: | + rustup component add clippy + - name: Run clippy + run: | + cargo clippy -p parquet_derive --all-features -- -D warnings diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index d0102c609f24..c04d5643b49a 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -15,226 +15,82 @@ # specific language governing permissions and limitations # under the License. -name: Rust +# workspace wide tests +name: rust +# trigger for all PRs and changes to master on: - # always trigger push: + branches: + - master pull_request: jobs: - # build the library, a compilation step used by multiple steps below - linux-build-lib: - name: Build Libraries on AMD64 Rust ${{ matrix.rust }} - runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ stable ] - container: - image: ${{ matrix.arch }}/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" - steps: - - uses: actions/checkout@v2 - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: ${{ matrix.rust }} - - name: Build Workspace - run: | - cargo build - - # test the crate - linux-test: - name: Test Workspace on AMD64 Rust ${{ matrix.rust }} - needs: [ linux-build-lib ] - runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ stable ] - container: - image: ${{ matrix.arch }}/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. 
- RUSTFLAGS: "-C debuginfo=1" - ARROW_TEST_DATA: /__w/arrow-rs/arrow-rs/testing/data - PARQUET_TEST_DATA: /__w/arrow-rs/arrow-rs/parquet-testing/data + # Check workspace wide compile and test with default features for + # mac + macos: + name: Test on Mac + runs-on: macos-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: submodules: true + - name: Install protoc with brew + run: | + brew install protobuf - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: ${{ matrix.rust }} + run: | + rustup toolchain install stable --no-self-update + rustup default stable - name: Run tests + shell: bash run: | - # run tests on all workspace members with default feature list + # do not produce debug symbols to keep memory usage down + export RUSTFLAGS="-C debuginfo=0" cargo test - - name: Re-run tests with all supported features - run: | - cargo test -p arrow --features=force_validate,prettyprint - - name: Run examples - run: | - # Test arrow examples - cargo run --example builders - cargo run --example dynamic_types - cargo run --example read_csv - cargo run --example read_csv_infer_schema - - name: Test compilation of arrow library crate with different feature combinations - run: | - cargo check -p arrow - cargo check -p arrow --no-default-features - - name: Test compilation of arrow targets with different feature combinations - run: | - cargo check -p arrow --all-targets - cargo check -p arrow --no-default-features --all-targets - cargo check -p arrow --no-default-features --all-targets --features test_utils - - name: Re-run tests on arrow-flight with all features - run: | - cargo test -p arrow-flight --all-features - - name: Re-run tests on parquet crate with all features - run: | - cargo test -p parquet --all-features - - name: Test compilation of parquet library crate with different feature combinations - run: | - cargo check -p parquet - cargo check -p parquet --no-default-features - cargo check -p parquet --no-default-features --features arrow - cargo check -p parquet --all-features - - name: Test compilation of parquet targets with different feature combinations - run: | - cargo check -p parquet --all-targets - cargo check -p parquet --no-default-features --all-targets - cargo check -p parquet --no-default-features --features arrow --all-targets - - name: Test compilation of parquet_derive macro with different feature combinations - run: | - cargo check -p parquet_derive - # test the --features "simd" of the arrow crate. This requires nightly. - linux-test-simd: - name: Test SIMD on AMD64 Rust ${{ matrix.rust }} - runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ nightly ] - container: - image: ${{ matrix.arch }}/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. 
- RUSTFLAGS: "-C debuginfo=1" - ARROW_TEST_DATA: /__w/arrow-rs/arrow-rs/testing/data + + # Check workspace wide compile and test with default features for + # windows + windows: + name: Test on Windows + runs-on: windows-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: submodules: true - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: ${{ matrix.rust }} - - name: Run tests - run: | - cargo test -p arrow --features "simd" - - name: Check compilation with simd features + - name: Install protobuf compiler in /d/protoc + shell: bash run: | - cargo check -p arrow --features simd - cargo check -p arrow --features simd --all-targets + mkdir /d/protoc + cd /d/protoc + curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v21.4/protoc-21.4-win64.zip + unzip protoc-21.4-win64.zip + export PATH=$PATH:/d/protoc/bin + protoc --version - windows-and-macos: - name: Test on ${{ matrix.os }} Rust ${{ matrix.rust }} - runs-on: ${{ matrix.os }} - strategy: - matrix: - os: [ windows-latest, macos-latest ] - rust: [ stable ] - steps: - - uses: actions/checkout@v2 - with: - submodules: true - # TODO: this won't cache anything, which is expensive. Setup this action - # with a OS-dependent path. - name: Setup Rust toolchain run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} + rustup toolchain install stable --no-self-update + rustup default stable - name: Run tests shell: bash run: | - export ARROW_TEST_DATA=$(pwd)/testing/data - export PARQUET_TEST_DATA=$(pwd)/parquet-testing/data # do not produce debug symbols to keep memory usage down export RUSTFLAGS="-C debuginfo=0" + export PATH=$PATH:/d/protoc/bin cargo test - clippy: - name: Clippy - needs: [ linux-build-lib ] - runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ stable ] - container: - image: ${{ matrix.arch }}/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: ${{ matrix.rust }} - - name: Setup Clippy - run: | - rustup component add clippy - - name: Run clippy - run: | - cargo clippy --features test_common --features prettyprint --features=async --all-targets --workspace -- -D warnings - - check_benches: - name: Check Benchmarks (but don't run them) - runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ stable ] - container: - image: ${{ matrix.arch }}/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. 
- RUSTFLAGS: "-C debuginfo=1" - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: ${{ matrix.rust }} - - name: Check benchmarks - run: | - cargo check --benches --workspace --features test_common,prettyprint,async,experimental + # Run cargo fmt for all crates lint: name: Lint (cargo fmt) runs-on: ubuntu-latest container: image: amd64/rust steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Setup toolchain run: | rustup toolchain install stable @@ -242,119 +98,3 @@ jobs: rustup component add rustfmt - name: Run run: cargo fmt --all -- --check - - coverage: - name: Coverage - runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ stable ] - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Setup Rust toolchain - run: | - rustup toolchain install ${{ matrix.rust }} - rustup default ${{ matrix.rust }} - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /home/runner/.cargo - # this key is not equal because the user is different than on a container (runner vs github) - key: cargo-coverage-cache3- - - name: Cache Rust dependencies - uses: actions/cache@v3 - with: - path: /home/runner/target - # this key is not equal because coverage uses different compilation flags. - key: ${{ runner.os }}-${{ matrix.arch }}-target-coverage-cache3-${{ matrix.rust }}- - - name: Run coverage - run: | - export CARGO_HOME="/home/runner/.cargo" - export CARGO_TARGET_DIR="/home/runner/target" - - export ARROW_TEST_DATA=$(pwd)/testing/data - export PARQUET_TEST_DATA=$(pwd)/parquet-testing/data - - rustup toolchain install stable - rustup default stable - cargo install --version 0.18.2 cargo-tarpaulin - cargo tarpaulin --all --out Xml - - name: Report coverage - continue-on-error: true - run: bash <(curl -s https://codecov.io/bash) - - # test the arrow crate builds against wasm32 in stable rust - wasm32-build: - name: Build wasm32 on AMD64 Rust ${{ matrix.rust }} - runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ nightly ] - container: - image: ${{ matrix.arch }}/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. 
- RUSTFLAGS: "-C debuginfo=1" - ARROW_TEST_DATA: /__w/arrow-rs/arrow-rs/testing/data - PARQUET_TEST_DATA: /__w/arrow/arrow/parquet-testing/data - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Cache Cargo - uses: actions/cache@v3 - with: - path: /github/home/.cargo - key: cargo-wasm32-cache3- - - name: Cache Rust dependencies - uses: actions/cache@v3 - with: - path: /github/home/target - key: ${{ runner.os }}-${{ matrix.arch }}-target-wasm32-cache3-${{ matrix.rust }} - - name: Setup Rust toolchain for WASM - run: | - rustup toolchain install ${{ matrix.rust }} - rustup override set ${{ matrix.rust }} - rustup target add wasm32-unknown-unknown - rustup target add wasm32-wasi - - name: Build arrow crate - run: | - cd arrow - cargo build --no-default-features --features=csv,ipc,simd --target wasm32-unknown-unknown - cargo build --no-default-features --features=csv,ipc,simd --target wasm32-wasi - - # test doc links still work - docs: - name: Docs are clean on AMD64 Rust ${{ matrix.rust }} - runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ nightly ] - container: - image: ${{ matrix.arch }}/rust - env: - # Disable full debug symbol generation to speed up CI build and keep memory down - # "1" means line tables only, which is useful for panic tracebacks. - RUSTFLAGS: "-C debuginfo=1" - RUSTDOCFLAGS: "-Dwarnings" - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Install python dev - run: | - apt update - apt install -y libpython3.9-dev - - name: Setup Rust toolchain - uses: ./.github/actions/setup-builder - with: - rust-version: ${{ matrix.rust }} - - name: Run cargo doc - run: | - cargo doc --document-private-items --no-deps --workspace --all-features diff --git a/.github_changelog_generator b/.github_changelog_generator index cc23a6332d60..9a9a84344866 100644 --- a/.github_changelog_generator +++ b/.github_changelog_generator @@ -24,5 +24,5 @@ add-sections={"documentation":{"prefix":"**Documentation updates:**","labels":[" #pull-requests=false # so that the component is shown associated with the issue issue-line-labels=arrow,parquet,arrow-flight -exclude-labels=development-process,invalid +exclude-labels=development-process,invalid,object-store breaking_labels=api-change diff --git a/.gitignore b/.gitignore index 2088dd5d2068..2a21776aa545 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,85 @@ venv/* parquet/data.parquet # release notes cache .githubchangeloggenerator.cache -.githubchangeloggenerator.cache.log \ No newline at end of file +.githubchangeloggenerator.cache.log +justfile +.prettierignore +.env +# local azurite file +__azurite* +__blobstorage__ + +# .bak files +*.bak + +# OS-specific .gitignores + +# Mac .gitignore +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +# Linux .gitignore +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +# Windows .gitignore +# 
Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md index b4923bfb0d8e..70322b5cfd1d 100644 --- a/CHANGELOG-old.md +++ b/CHANGELOG-old.md @@ -20,6 +20,450 @@ # Historical Changelog +## [21.0.0](https://github.com/apache/arrow-rs/tree/21.0.0) (2022-08-18) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/20.0.0...21.0.0) + +**Breaking changes:** + +- Return structured `ColumnCloseResult` \(\#2465\) [\#2466](https://github.com/apache/arrow-rs/pull/2466) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Push `ChunkReader` into `SerializedPageReader` \(\#2463\) [\#2464](https://github.com/apache/arrow-rs/pull/2464) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Revise FromIterator for Decimal128Array to use Into instead of Borrow [\#2442](https://github.com/apache/arrow-rs/pull/2442) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Use Fixed-Length Array in BasicDecimal new and raw\_value [\#2405](https://github.com/apache/arrow-rs/pull/2405) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove deprecated ParquetWriter [\#2380](https://github.com/apache/arrow-rs/pull/2380) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove deprecated SliceableCursor and InMemoryWriteableCursor [\#2378](https://github.com/apache/arrow-rs/pull/2378) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- add into\_inner method to ArrowWriter [\#2491](https://github.com/apache/arrow-rs/issues/2491) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Remove byteorder dependency [\#2472](https://github.com/apache/arrow-rs/issues/2472) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Return Structured ColumnCloseResult from GenericColumnWriter::close [\#2465](https://github.com/apache/arrow-rs/issues/2465) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Push `ChunkReader` into `SerializedPageReader` [\#2463](https://github.com/apache/arrow-rs/issues/2463) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support SerializedPageReader::skip\_page without OffsetIndex [\#2459](https://github.com/apache/arrow-rs/issues/2459) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support Time64/Time32 comparison [\#2457](https://github.com/apache/arrow-rs/issues/2457) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Revise FromIterator for Decimal128Array to use Into instead of Borrow [\#2441](https://github.com/apache/arrow-rs/issues/2441) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support `RowFilter` within`ParquetRecordBatchReader` [\#2431](https://github.com/apache/arrow-rs/issues/2431) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Remove the field `StructBuilder::len` [\#2429](https://github.com/apache/arrow-rs/issues/2429) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Standardize creation and configuration of parquet --\> Arrow readers \( `ParquetRecordBatchReaderBuilder`\) [\#2427](https://github.com/apache/arrow-rs/issues/2427) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Use `OffsetIndex` to Prune IO in `ParquetRecordBatchStream` [\#2426](https://github.com/apache/arrow-rs/issues/2426) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support `peek_next_page` and `skip_next_page` in `InMemoryPageReader` [\#2406](https://github.com/apache/arrow-rs/issues/2406) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support casting from `Utf8`/`LargeUtf8` to `Binary`/`LargeBinary` [\#2402](https://github.com/apache/arrow-rs/issues/2402) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support casting between `Decimal128` and `Decimal256` arrays [\#2375](https://github.com/apache/arrow-rs/issues/2375) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Combine multiple selections into the same batch size in `skip_records` [\#2358](https://github.com/apache/arrow-rs/issues/2358) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add API to change timezone for timestamp array [\#2346](https://github.com/apache/arrow-rs/issues/2346) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Change the output of `read_buffer` Arrow IPC API to return `Result<_>` [\#2342](https://github.com/apache/arrow-rs/issues/2342) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow `skip_records` in `GenericColumnReader` to skip across row groups [\#2331](https://github.com/apache/arrow-rs/issues/2331) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Optimize the validation of `Decimal256` [\#2320](https://github.com/apache/arrow-rs/issues/2320) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement Skip for `DeltaBitPackDecoder` [\#2281](https://github.com/apache/arrow-rs/issues/2281) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Changes to `ParquetRecordBatchStream` to support row filtering in DataFusion [\#2270](https://github.com/apache/arrow-rs/issues/2270) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add `ArrayReader::skip_records` API [\#2197](https://github.com/apache/arrow-rs/issues/2197) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- Panic in SerializedPageReader without offset index [\#2503](https://github.com/apache/arrow-rs/issues/2503) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- MapArray columns don't handle null values correctly [\#2484](https://github.com/apache/arrow-rs/issues/2484) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- There is no compiler error when using an invalid Decimal type. 
[\#2440](https://github.com/apache/arrow-rs/issues/2440) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Flight SQL Server sends incorrect response for `DoPutUpdateResult` [\#2403](https://github.com/apache/arrow-rs/issues/2403) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- `AsyncFileReader`No Longer Object-Safe [\#2372](https://github.com/apache/arrow-rs/issues/2372) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- StructBuilder Does not Verify Child Lengths [\#2252](https://github.com/apache/arrow-rs/issues/2252) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Combine `DecimalArray` validation [\#2447](https://github.com/apache/arrow-rs/issues/2447) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- Fix bug in page skipping [\#2504](https://github.com/apache/arrow-rs/pull/2504) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Fix `MapArrayReader` \(\#2484\) \(\#1699\) \(\#1561\) [\#2500](https://github.com/apache/arrow-rs/pull/2500) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add API to Retrieve Finished Writer from Parquet Writer [\#2498](https://github.com/apache/arrow-rs/pull/2498) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jiacai2050](https://github.com/jiacai2050)) +- Derive Copy,Clone for BasicDecimal [\#2495](https://github.com/apache/arrow-rs/pull/2495) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- remove byteorder dependency from parquet [\#2486](https://github.com/apache/arrow-rs/pull/2486) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([psvri](https://github.com/psvri)) +- parquet-read: add support to read parquet data from stdin [\#2482](https://github.com/apache/arrow-rs/pull/2482) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nvartolomei](https://github.com/nvartolomei)) +- Remove Position trait \(\#1163\) [\#2479](https://github.com/apache/arrow-rs/pull/2479) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add ChunkReader::get\_bytes [\#2478](https://github.com/apache/arrow-rs/pull/2478) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- RFC: Simplify decimal \(\#2440\) [\#2477](https://github.com/apache/arrow-rs/pull/2477) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Use Parquet OffsetIndex to prune IO with RowSelection [\#2473](https://github.com/apache/arrow-rs/pull/2473) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Remove unnecessary Option from Int96 [\#2471](https://github.com/apache/arrow-rs/pull/2471) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- remove len field from StructBuilder [\#2468](https://github.com/apache/arrow-rs/pull/2468) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Make Parquet reader filter APIs public \(\#1792\) [\#2467](https://github.com/apache/arrow-rs/pull/2467) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- enable ipc compression feature 
for integration test [\#2462](https://github.com/apache/arrow-rs/pull/2462) ([liukun4515](https://github.com/liukun4515)) +- Simplify implementation of Schema [\#2461](https://github.com/apache/arrow-rs/pull/2461) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Support skip\_page missing OffsetIndex Fallback in SerializedPageReader [\#2460](https://github.com/apache/arrow-rs/pull/2460) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- support time32/time64 comparison [\#2458](https://github.com/apache/arrow-rs/pull/2458) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([waitingkuo](https://github.com/waitingkuo)) +- Utf8array casting [\#2456](https://github.com/apache/arrow-rs/pull/2456) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Remove outdated license text [\#2455](https://github.com/apache/arrow-rs/pull/2455) ([alamb](https://github.com/alamb)) +- Support RowFilter within ParquetRecordBatchReader \(\#2431\) [\#2452](https://github.com/apache/arrow-rs/pull/2452) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- benchmark: decimal builder and vec to decimal array [\#2450](https://github.com/apache/arrow-rs/pull/2450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Collocate Decimal Array Validation Logic [\#2446](https://github.com/apache/arrow-rs/pull/2446) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Minor: Move From trait for Decimal256 impl to decimal.rs [\#2443](https://github.com/apache/arrow-rs/pull/2443) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- decimal benchmark: arrow reader decimal from parquet int32 and int64 [\#2438](https://github.com/apache/arrow-rs/pull/2438) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- MINOR: Simplify `split_second` function [\#2436](https://github.com/apache/arrow-rs/pull/2436) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add ParquetRecordBatchReaderBuilder \(\#2427\) [\#2435](https://github.com/apache/arrow-rs/pull/2435) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- refactor: refine validation for decimal128 array [\#2428](https://github.com/apache/arrow-rs/pull/2428) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Benchmark of casting decimal arrays [\#2424](https://github.com/apache/arrow-rs/pull/2424) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Test non-annotated repeated fields \(\#2394\) [\#2422](https://github.com/apache/arrow-rs/pull/2422) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix \#2416 Automatic version updates for github actions with dependabot [\#2417](https://github.com/apache/arrow-rs/pull/2417) ([iemejia](https://github.com/iemejia)) +- Add validation logic for StructBuilder::finish [\#2413](https://github.com/apache/arrow-rs/pull/2413) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- test: add test for 
reading decimal value from primitive array reader [\#2411](https://github.com/apache/arrow-rs/pull/2411) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- Upgrade ahash to 0.8 [\#2410](https://github.com/apache/arrow-rs/pull/2410) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Support peek\_next\_page and skip\_next\_page in InMemoryPageReader [\#2407](https://github.com/apache/arrow-rs/pull/2407) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Fix DoPutUpdateResult [\#2404](https://github.com/apache/arrow-rs/pull/2404) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Implement Skip for DeltaBitPackDecoder [\#2393](https://github.com/apache/arrow-rs/pull/2393) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- fix: Don't instantiate the scalar composition code quadratically for dictionaries [\#2391](https://github.com/apache/arrow-rs/pull/2391) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Marwes](https://github.com/Marwes)) +- MINOR: Remove unused trait and some cleanup [\#2389](https://github.com/apache/arrow-rs/pull/2389) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Decouple parquet fuzz tests from converter \(\#1661\) [\#2386](https://github.com/apache/arrow-rs/pull/2386) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Rewrite `Decimal` and `DecimalArray` using `const_generic` [\#2383](https://github.com/apache/arrow-rs/pull/2383) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Simplify BitReader \(~5-10% faster\) [\#2381](https://github.com/apache/arrow-rs/pull/2381) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix parquet clippy lints \(\#1254\) [\#2377](https://github.com/apache/arrow-rs/pull/2377) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Cast between `Decimal128` and `Decimal256` arrays [\#2376](https://github.com/apache/arrow-rs/pull/2376) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- support compression for IPC with revamped feature flags [\#2369](https://github.com/apache/arrow-rs/pull/2369) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Implement AsyncFileReader for `Box` [\#2368](https://github.com/apache/arrow-rs/pull/2368) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove get\_byte\_ranges where bound [\#2366](https://github.com/apache/arrow-rs/pull/2366) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- refactor: Make read\_num\_bytes a function instead of a macro [\#2364](https://github.com/apache/arrow-rs/pull/2364) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Marwes](https://github.com/Marwes)) +- refactor: Group metrics into page and column metrics structs 
[\#2363](https://github.com/apache/arrow-rs/pull/2363) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Marwes](https://github.com/Marwes)) +- Speed up `Decimal256` validation based on bytes comparison and add benchmark test [\#2360](https://github.com/apache/arrow-rs/pull/2360) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Combine multiple selections into the same batch size in skip\_records [\#2359](https://github.com/apache/arrow-rs/pull/2359) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Add API to change timezone for timestamp array [\#2347](https://github.com/apache/arrow-rs/pull/2347) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Clean the code in `field.rs` and add more tests [\#2345](https://github.com/apache/arrow-rs/pull/2345) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add Parquet RowFilter API [\#2335](https://github.com/apache/arrow-rs/pull/2335) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Make skip\_records in complex\_object\_array can skip cross row groups [\#2332](https://github.com/apache/arrow-rs/pull/2332) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Integrate Record Skipping into Column Reader Fuzz Test [\#2315](https://github.com/apache/arrow-rs/pull/2315) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) + +## [20.0.0](https://github.com/apache/arrow-rs/tree/20.0.0) (2022-08-05) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/19.0.0...20.0.0) + +**Breaking changes:** + +- Add more const evaluation for `GenericBinaryArray` and `GenericListArray`: add `PREFIX` and data type constructor [\#2327](https://github.com/apache/arrow-rs/pull/2327) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Make FFI support optional, change APIs to be `safe` \(\#2302\) [\#2303](https://github.com/apache/arrow-rs/pull/2303) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove `test_utils` from default features \(\#2298\) [\#2299](https://github.com/apache/arrow-rs/pull/2299) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Rename `DataType::Decimal` to `DataType::Decimal128` [\#2229](https://github.com/apache/arrow-rs/pull/2229) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add `Decimal128Iter` and `Decimal256Iter` and do maximum precision/scale check [\#2140](https://github.com/apache/arrow-rs/pull/2140) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) + +**Implemented enhancements:** + +- Add the constant data type constructors for `ListArray` [\#2311](https://github.com/apache/arrow-rs/issues/2311) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update `FlightSqlService` trait to pass session info along [\#2308](https://github.com/apache/arrow-rs/issues/2308) 
[[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Optimize `take_bits` for non-null indices [\#2306](https://github.com/apache/arrow-rs/issues/2306) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make FFI support optional via Feature Flag `ffi` [\#2302](https://github.com/apache/arrow-rs/issues/2302) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Mark `ffi::ArrowArray::try_new` is safe [\#2301](https://github.com/apache/arrow-rs/issues/2301) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove test\_utils from default arrow-rs features [\#2298](https://github.com/apache/arrow-rs/issues/2298) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Remove `JsonEqual` trait [\#2296](https://github.com/apache/arrow-rs/issues/2296) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Move `with_precision_and_scale` to `Decimal` array traits [\#2291](https://github.com/apache/arrow-rs/issues/2291) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve readability and maybe performance of string --\> numeric/time/date/timestamp cast kernels [\#2285](https://github.com/apache/arrow-rs/issues/2285) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add vectorized unpacking for 8, 16, and 64 bit integers [\#2276](https://github.com/apache/arrow-rs/issues/2276) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Use initial capacity for interner hashmap [\#2273](https://github.com/apache/arrow-rs/issues/2273) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Impl FromIterator for Decimal256Array [\#2248](https://github.com/apache/arrow-rs/issues/2248) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Separate `ArrayReader::next_batch` with `ArrayReader::read_records` and `ArrayReader::consume_batch` [\#2236](https://github.com/apache/arrow-rs/issues/2236) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Rename `DataType::Decimal` to `DataType::Decimal128` [\#2228](https://github.com/apache/arrow-rs/issues/2228) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Automatically Grow Parquet BitWriter Buffer [\#2226](https://github.com/apache/arrow-rs/issues/2226) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add `append_option` support to `Decimal128Builder` and `Decimal256Builder` [\#2224](https://github.com/apache/arrow-rs/issues/2224) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Split the `FixedSizeBinaryArray` and `FixedSizeListArray` from `array_binary.rs` and `array_list.rs` [\#2217](https://github.com/apache/arrow-rs/issues/2217) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Don't `Box` Values in `PrimitiveDictionaryBuilder` [\#2215](https://github.com/apache/arrow-rs/issues/2215) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use BitChunks in equal\_bits [\#2186](https://github.com/apache/arrow-rs/issues/2186) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `Hash` for `Schema` [\#2182](https://github.com/apache/arrow-rs/issues/2182) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- read decimal data type from parquet file with binary physical type [\#2159](https://github.com/apache/arrow-rs/issues/2159) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- The `GenericStringBuilder` should use `GenericBinaryBuilder` [\#2156](https://github.com/apache/arrow-rs/issues/2156) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update Rust version to 1.62 [\#2143](https://github.com/apache/arrow-rs/issues/2143) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Check precision and scale against maximum value when constructing `Decimal128` and `Decimal256` [\#2139](https://github.com/apache/arrow-rs/issues/2139) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `ArrayAccessor` in `Decimal128Iter` and `Decimal256Iter` [\#2138](https://github.com/apache/arrow-rs/issues/2138) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `ArrayAccessor` and `FromIterator` in Cast Kernels [\#2137](https://github.com/apache/arrow-rs/issues/2137) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `TypedDictionaryArray` for more ergonomic interaction with `DictionaryArray` [\#2136](https://github.com/apache/arrow-rs/issues/2136) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `ArrayAccessor` in Comparison Kernels [\#2135](https://github.com/apache/arrow-rs/issues/2135) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `peek_next_page()` and s`kip_next_page` in `InMemoryColumnChunkReader` [\#2129](https://github.com/apache/arrow-rs/issues/2129) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Lazily materialize the null buffer builder for all array builders. [\#2125](https://github.com/apache/arrow-rs/issues/2125) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Do value validation for `Decimal256` [\#2112](https://github.com/apache/arrow-rs/issues/2112) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `skip_def_levels` for `ColumnLevelDecoder` [\#2107](https://github.com/apache/arrow-rs/issues/2107) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add integration test for scan rows with selection [\#2106](https://github.com/apache/arrow-rs/issues/2106) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support for casting from Utf8/String to `Time32` / `Time64` [\#2053](https://github.com/apache/arrow-rs/issues/2053) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Update prost and tonic related crates [\#2268](https://github.com/apache/arrow-rs/pull/2268) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([carols10cents](https://github.com/carols10cents)) + +**Fixed bugs:** + +- temporal conversion functions cannot work on negative input properly [\#2325](https://github.com/apache/arrow-rs/issues/2325) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- IPC writer should truncate string array with all empty string [\#2312](https://github.com/apache/arrow-rs/issues/2312) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Error order for comparing `Decimal128` or `Decimal256` [\#2256](https://github.com/apache/arrow-rs/issues/2256) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix maximum and minimum for decimal values for precision greater than 38 [\#2246](https://github.com/apache/arrow-rs/issues/2246) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `IntervalMonthDayNanoType::make_value()` does not match C implementation [\#2234](https://github.com/apache/arrow-rs/issues/2234) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `FlightSqlService` trait does not allow 
`impl`s to do handshake [\#2210](https://github.com/apache/arrow-rs/issues/2210) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- `EnabledStatistics::None` not working [\#2185](https://github.com/apache/arrow-rs/issues/2185) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Boolean ArrayData Equality Incorrect Slice Handling [\#2184](https://github.com/apache/arrow-rs/issues/2184) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Publicly export MapFieldNames [\#2118](https://github.com/apache/arrow-rs/issues/2118) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- Update instructions on How to join the slack \#arrow-rust channel -- or maybe try to switch to discord?? [\#2192](https://github.com/apache/arrow-rs/issues/2192) +- \[Minor\] Improve arrow and parquet READMEs, document parquet feature flags [\#2324](https://github.com/apache/arrow-rs/pull/2324) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- Improve speed of writing string dictionaries to parquet by skipping a copy\(\#1764\) [\#2322](https://github.com/apache/arrow-rs/pull/2322) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) + +**Closed issues:** + +- Fix wrong logic in calculate\_row\_count when skipping values [\#2328](https://github.com/apache/arrow-rs/issues/2328) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support filter for parquet data type [\#2126](https://github.com/apache/arrow-rs/issues/2126) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make skip value in ByteArrayDecoderDictionary avoid decoding [\#2088](https://github.com/apache/arrow-rs/issues/2088) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- fix: Fix skip error in calculate\_row\_count. 
[\#2329](https://github.com/apache/arrow-rs/pull/2329) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- temporal conversion functions should work on negative input properly [\#2326](https://github.com/apache/arrow-rs/pull/2326) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Increase DeltaBitPackEncoder miniblock size to 64 for 64-bit integers \(\#2282\) [\#2319](https://github.com/apache/arrow-rs/pull/2319) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove JsonEqual [\#2317](https://github.com/apache/arrow-rs/pull/2317) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- fix: IPC writer should truncate string array with all empty string [\#2314](https://github.com/apache/arrow-rs/pull/2314) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JasonLi-cn](https://github.com/JasonLi-cn)) +- Pass pull `Request` to `FlightSqlService` `impl`s [\#2309](https://github.com/apache/arrow-rs/pull/2309) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Speedup take\_boolean / take\_bits for non-null indices \(~4 - 5x speedup\) [\#2307](https://github.com/apache/arrow-rs/pull/2307) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Add typed dictionary \(\#2136\) [\#2297](https://github.com/apache/arrow-rs/pull/2297) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- \[Minor\] Improve types shown in cast error messages [\#2295](https://github.com/apache/arrow-rs/pull/2295) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Move `with_precision_and_scale` to `BasicDecimalArray` trait [\#2292](https://github.com/apache/arrow-rs/pull/2292) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Replace the `fn get_data_type` by `const DATA_TYPE` in BinaryArray and StringArray [\#2289](https://github.com/apache/arrow-rs/pull/2289) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Clean up string casts and improve performance [\#2284](https://github.com/apache/arrow-rs/pull/2284) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Minor\] Add tests for temporal cast error paths [\#2283](https://github.com/apache/arrow-rs/pull/2283) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add unpack8, unpack16, unpack64 \(\#2276\) ~10-50% faster [\#2278](https://github.com/apache/arrow-rs/pull/2278) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix bugs in the `from_list` function. 
[\#2277](https://github.com/apache/arrow-rs/pull/2277) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- fix: use signed comparator to compare decimal128 and decimal256 [\#2275](https://github.com/apache/arrow-rs/pull/2275) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Use initial capacity for interner hashmap [\#2272](https://github.com/apache/arrow-rs/pull/2272) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Dandandan](https://github.com/Dandandan)) +- Remove fallibility from parquet RleEncoder \(\#2226\) [\#2259](https://github.com/apache/arrow-rs/pull/2259) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix escaped like wildcards in `like_utf8` / `nlike_utf8` kernels [\#2258](https://github.com/apache/arrow-rs/pull/2258) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([daniel-martinez-maqueda-sap](https://github.com/daniel-martinez-maqueda-sap)) +- Add tests for reading nested decimal arrays from parquet [\#2254](https://github.com/apache/arrow-rs/pull/2254) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- feat: Implement string cast operations for Time32 and Time64 [\#2251](https://github.com/apache/arrow-rs/pull/2251) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([stuartcarnie](https://github.com/stuartcarnie)) +- move `FixedSizeList` to `array_fixed_size_list.rs` [\#2250](https://github.com/apache/arrow-rs/pull/2250) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Impl FromIterator for Decimal256Array [\#2247](https://github.com/apache/arrow-rs/pull/2247) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix max and min value for decimal precision greater than 38 [\#2245](https://github.com/apache/arrow-rs/pull/2245) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Make `Schema::fields` and `Schema::metadata` `pub` \(public\) [\#2239](https://github.com/apache/arrow-rs/pull/2239) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Minor\] Improve Schema metadata mismatch error [\#2238](https://github.com/apache/arrow-rs/pull/2238) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Separate ArrayReader::next\_batch with read\_records and consume\_batch [\#2237](https://github.com/apache/arrow-rs/pull/2237) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Update `IntervalMonthDayNanoType::make_value()` to conform to specifications [\#2235](https://github.com/apache/arrow-rs/pull/2235) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([avantgardnerio](https://github.com/avantgardnerio)) +- Disable value validation for Decimal256 case [\#2232](https://github.com/apache/arrow-rs/pull/2232) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Automatically grow parquet BitWriter \(\#2226\) \(~10% faster\) [\#2231](https://github.com/apache/arrow-rs/pull/2231) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Only trigger `arrow` CI on changes to arrow 
[\#2227](https://github.com/apache/arrow-rs/pull/2227) ([alamb](https://github.com/alamb)) +- Add append\_option support to decimal builders [\#2225](https://github.com/apache/arrow-rs/pull/2225) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([bphillips-exos](https://github.com/bphillips-exos)) +- Optimized writing of byte array to parquet \(\#1764\) \(2x faster\) [\#2221](https://github.com/apache/arrow-rs/pull/2221) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Increase test coverage of ArrowWriter [\#2220](https://github.com/apache/arrow-rs/pull/2220) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Update instructions on how to join the Slack channel [\#2219](https://github.com/apache/arrow-rs/pull/2219) ([HaoYang670](https://github.com/HaoYang670)) +- Move `FixedSizeBinaryArray` to `array_fixed_size_binary.rs` [\#2218](https://github.com/apache/arrow-rs/pull/2218) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Avoid boxing in PrimitiveDictionaryBuilder [\#2216](https://github.com/apache/arrow-rs/pull/2216) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- remove redundant CI benchmark check, cleanups [\#2212](https://github.com/apache/arrow-rs/pull/2212) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Update `FlightSqlService` trait to proxy handshake [\#2211](https://github.com/apache/arrow-rs/pull/2211) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) +- parquet: export json api with `serde_json` feature name [\#2209](https://github.com/apache/arrow-rs/pull/2209) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([flisky](https://github.com/flisky)) +- Cleanup record skipping logic and tests \(\#2158\) [\#2199](https://github.com/apache/arrow-rs/pull/2199) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use BitChunks in equal\_bits [\#2194](https://github.com/apache/arrow-rs/pull/2194) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Fix disabling parquet statistics \(\#2185\) [\#2191](https://github.com/apache/arrow-rs/pull/2191) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Change CI names to match crate names [\#2189](https://github.com/apache/arrow-rs/pull/2189) ([alamb](https://github.com/alamb)) +- Fix offset handling in boolean\_equal \(\#2184\) [\#2187](https://github.com/apache/arrow-rs/pull/2187) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Implement `Hash` for `Schema` [\#2183](https://github.com/apache/arrow-rs/pull/2183) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Let the `StringBuilder` use `BinaryBuilder` [\#2181](https://github.com/apache/arrow-rs/pull/2181) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Use ArrayAccessor and FromIterator in Cast Kernels [\#2169](https://github.com/apache/arrow-rs/pull/2169) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Split most arrow 
specific CI checks into their own workflows \(reduce common CI time to 21 minutes\) [\#2168](https://github.com/apache/arrow-rs/pull/2168) ([alamb](https://github.com/alamb)) +- Remove another attempt to cache target directory in action.yaml [\#2167](https://github.com/apache/arrow-rs/pull/2167) ([alamb](https://github.com/alamb)) +- Run actions on push to master, pull requests [\#2166](https://github.com/apache/arrow-rs/pull/2166) ([alamb](https://github.com/alamb)) +- Break parquet\_derive and arrow\_flight tests into their own workflows [\#2165](https://github.com/apache/arrow-rs/pull/2165) ([alamb](https://github.com/alamb)) +- \[minor\] use type aliases refine code. [\#2161](https://github.com/apache/arrow-rs/pull/2161) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- parquet reader: Support reading decimals from parquet `BYTE_ARRAY` type [\#2160](https://github.com/apache/arrow-rs/pull/2160) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- Add integration test for scan rows with selection [\#2158](https://github.com/apache/arrow-rs/pull/2158) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Use ArrayAccessor in Comparison Kernels [\#2157](https://github.com/apache/arrow-rs/pull/2157) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Implement `peek\_next\_page` and `skip\_next\_page` for `InMemoryColumnCh… [\#2155](https://github.com/apache/arrow-rs/pull/2155) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Avoid decoding unneeded values in ByteArrayDecoderDictionary [\#2154](https://github.com/apache/arrow-rs/pull/2154) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Only run integration tests when `arrow` changes [\#2152](https://github.com/apache/arrow-rs/pull/2152) ([alamb](https://github.com/alamb)) +- Break out docs CI job to its own github action [\#2151](https://github.com/apache/arrow-rs/pull/2151) ([alamb](https://github.com/alamb)) +- Do not pretend to cache rust build artifacts, speed up CI by ~20% [\#2150](https://github.com/apache/arrow-rs/pull/2150) ([alamb](https://github.com/alamb)) +- Update rust version to 1.62 [\#2144](https://github.com/apache/arrow-rs/pull/2144) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Make MapFieldNames public \(\#2118\) [\#2134](https://github.com/apache/arrow-rs/pull/2134) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add ArrayAccessor trait, remove duplication in array iterators \(\#1948\) [\#2133](https://github.com/apache/arrow-rs/pull/2133) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Lazily materialize the null buffer builder for all array builders. 
[\#2127](https://github.com/apache/arrow-rs/pull/2127) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Faster parquet DictEncoder \(~20%\) [\#2123](https://github.com/apache/arrow-rs/pull/2123) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add validation for Decimal256 [\#2113](https://github.com/apache/arrow-rs/pull/2113) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support skip\_def\_levels for ColumnLevelDecoder [\#2111](https://github.com/apache/arrow-rs/pull/2111) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Donate `object_store` code from object\_store\_rs to arrow-rs [\#2081](https://github.com/apache/arrow-rs/pull/2081) ([alamb](https://github.com/alamb)) +- Improve `validate_utf8` performance [\#2048](https://github.com/apache/arrow-rs/pull/2048) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tfeda](https://github.com/tfeda)) + +## [19.0.0](https://github.com/apache/arrow-rs/tree/19.0.0) (2022-07-22) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/18.0.0...19.0.0) + +**Breaking changes:** + +- Rename `DecimalArray``/DecimalBuilder` to `Decimal128Array`/`Decimal128Builder` [\#2101](https://github.com/apache/arrow-rs/issues/2101) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Change builder `append` methods to be infallible where possible [\#2103](https://github.com/apache/arrow-rs/pull/2103) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Return reference from `UnionArray::child` \(\#2035\) [\#2099](https://github.com/apache/arrow-rs/pull/2099) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove `preserve_order` feature from `serde_json` dependency \(\#2095\) [\#2098](https://github.com/apache/arrow-rs/pull/2098) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Rename `weekday` and `weekday0` kernels to to `num_days_from_monday` and `num_days_since_sunday` [\#2066](https://github.com/apache/arrow-rs/pull/2066) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Remove `null_count` from `write_batch_with_statistics` [\#2047](https://github.com/apache/arrow-rs/pull/2047) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) + +**Implemented enhancements:** + +- Use `total_cmp` from std [\#2130](https://github.com/apache/arrow-rs/issues/2130) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Permit parallel fetching of column chunks in `ParquetRecordBatchStream` [\#2110](https://github.com/apache/arrow-rs/issues/2110) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- The `GenericBinaryBuilder` should use buffer builders directly. 
[\#2104](https://github.com/apache/arrow-rs/issues/2104) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Pass `generate_decimal256_case` arrow integration test [\#2093](https://github.com/apache/arrow-rs/issues/2093) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Rename `weekday` and `weekday0` kernels to to `num_days_from_monday` and `days_since_sunday` [\#2065](https://github.com/apache/arrow-rs/issues/2065) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of `filter_dict` [\#2062](https://github.com/apache/arrow-rs/issues/2062) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of `set_bits` [\#2060](https://github.com/apache/arrow-rs/issues/2060) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Lazily materialize the null buffer builder of `BooleanBuilder` [\#2058](https://github.com/apache/arrow-rs/issues/2058) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `BooleanArray::from_iter` should omit validity buffer if all values are valid [\#2055](https://github.com/apache/arrow-rs/issues/2055) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- FFI\_ArrowSchema should set `DICTIONARY_ORDERED` flag if a field's dictionary is ordered [\#2049](https://github.com/apache/arrow-rs/issues/2049) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `peek_next_page()` and `skip_next_page` in `SerializedPageReader` [\#2043](https://github.com/apache/arrow-rs/issues/2043) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support FFI / C Data Interface for `MapType` [\#2037](https://github.com/apache/arrow-rs/issues/2037) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- The `DecimalArrayBuilder` should use `FixedSizedBinaryBuilder` [\#2026](https://github.com/apache/arrow-rs/issues/2026) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Enable `serialized_reader` read specific Page by passing row ranges. 
[\#1976](https://github.com/apache/arrow-rs/issues/1976) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- `type_id` and `value_offset` are incorrect for sliced `UnionArray` [\#2086](https://github.com/apache/arrow-rs/issues/2086) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Boolean `take` kernel does not handle null indices correctly [\#2057](https://github.com/apache/arrow-rs/issues/2057) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Don't double-count nulls in `write_batch_with_statistics` [\#2046](https://github.com/apache/arrow-rs/issues/2046) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet Writer Ignores Statistics specification in `WriterProperties` [\#2014](https://github.com/apache/arrow-rs/issues/2014) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Improve docstrings + examples for `as_primitive_array` cast functions [\#2114](https://github.com/apache/arrow-rs/pull/2114) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Why does `serde_json` specify the `preserve_order` feature in `arrow` package [\#2095](https://github.com/apache/arrow-rs/issues/2095) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `skip_values` in DictionaryDecoder [\#2079](https://github.com/apache/arrow-rs/issues/2079) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support skip\_values in ColumnValueDecoderImpl [\#2078](https://github.com/apache/arrow-rs/issues/2078) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support `skip_values` in `ByteArrayColumnValueDecoder` [\#2072](https://github.com/apache/arrow-rs/issues/2072) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Several `Builder::append` methods returning results even though they are infallible [\#2071](https://github.com/apache/arrow-rs/issues/2071) +- Improve formatting of logical plans containing subqueries [\#2059](https://github.com/apache/arrow-rs/issues/2059) +- Return reference from `UnionArray::child` [\#2035](https://github.com/apache/arrow-rs/issues/2035) +- support write page index [\#1777](https://github.com/apache/arrow-rs/issues/1777) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Use `total_cmp` from std [\#2131](https://github.com/apache/arrow-rs/pull/2131) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- fix clippy [\#2124](https://github.com/apache/arrow-rs/pull/2124) ([alamb](https://github.com/alamb)) +- Fix logical merge conflict: `match` arms have incompatible types [\#2121](https://github.com/apache/arrow-rs/pull/2121) ([alamb](https://github.com/alamb)) +- Update `GenericBinaryBuilder` to use buffer builders directly. 
[\#2117](https://github.com/apache/arrow-rs/pull/2117) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Simplify null mask preservation in parquet reader [\#2116](https://github.com/apache/arrow-rs/pull/2116) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add get\_byte\_ranges method to AsyncFileReader trait [\#2115](https://github.com/apache/arrow-rs/pull/2115) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- add test for skip\_values in DictionaryDecoder and fix it [\#2105](https://github.com/apache/arrow-rs/pull/2105) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Define Decimal128Builder and Decimal128Array [\#2102](https://github.com/apache/arrow-rs/pull/2102) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support skip\_values in DictionaryDecoder [\#2100](https://github.com/apache/arrow-rs/pull/2100) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Pass generate\_decimal256\_case integration test, add `DataType::Decimal256` [\#2094](https://github.com/apache/arrow-rs/pull/2094) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- `DecimalBuilder` should use `FixedSizeBinaryBuilder` [\#2092](https://github.com/apache/arrow-rs/pull/2092) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Array writer indirection [\#2091](https://github.com/apache/arrow-rs/pull/2091) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Remove doc hidden from GenericColumnReader [\#2090](https://github.com/apache/arrow-rs/pull/2090) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support skip\_values in ColumnValueDecoderImpl [\#2089](https://github.com/apache/arrow-rs/pull/2089) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- type\_id and value\_offset are incorrect for sliced UnionArray [\#2087](https://github.com/apache/arrow-rs/pull/2087) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add IPC truncation test case for StructArray [\#2083](https://github.com/apache/arrow-rs/pull/2083) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Improve performance of set\_bits by using copy\_from\_slice instead of setting individual bytes [\#2077](https://github.com/apache/arrow-rs/pull/2077) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Support skip\_values in ByteArrayColumnValueDecoder [\#2076](https://github.com/apache/arrow-rs/pull/2076) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Lazily materialize the null buffer builder of boolean builder [\#2073](https://github.com/apache/arrow-rs/pull/2073) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- 
Fix windows CI \(\#2069\) [\#2070](https://github.com/apache/arrow-rs/pull/2070) ([tustvold](https://github.com/tustvold)) +- Test utf8\_validation checks char boundaries [\#2068](https://github.com/apache/arrow-rs/pull/2068) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- feat\(compute\): Support doy \(day of year\) for temporal [\#2067](https://github.com/apache/arrow-rs/pull/2067) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ovr](https://github.com/ovr)) +- Support nullable indices in boolean take kernel and some optimizations [\#2064](https://github.com/apache/arrow-rs/pull/2064) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Improve performance of filter\_dict [\#2063](https://github.com/apache/arrow-rs/pull/2063) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Ignore null buffer when creating ArrayData if null count is zero [\#2056](https://github.com/apache/arrow-rs/pull/2056) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- feat\(compute\): Support week0 \(PostgreSQL behaviour\) for temporal [\#2052](https://github.com/apache/arrow-rs/pull/2052) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ovr](https://github.com/ovr)) +- Set DICTIONARY\_ORDERED flag for FFI\_ArrowSchema [\#2050](https://github.com/apache/arrow-rs/pull/2050) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Generify parquet write path \(\#1764\) [\#2045](https://github.com/apache/arrow-rs/pull/2045) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Support peek\_next\_page\(\) and skip\_next\_page in serialized\_reader. 
[\#2044](https://github.com/apache/arrow-rs/pull/2044) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Support MapType in FFI [\#2042](https://github.com/apache/arrow-rs/pull/2042) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add support of converting `FixedSizeBinaryArray` to `DecimalArray` [\#2041](https://github.com/apache/arrow-rs/pull/2041) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Truncate IPC record batch [\#2040](https://github.com/apache/arrow-rs/pull/2040) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Refine the List builder [\#2034](https://github.com/apache/arrow-rs/pull/2034) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add more tests of RecordReader Batch Size Edge Cases \(\#2025\) [\#2032](https://github.com/apache/arrow-rs/pull/2032) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add support for adding intervals to dates [\#2031](https://github.com/apache/arrow-rs/pull/2031) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([avantgardnerio](https://github.com/avantgardnerio)) + +## [18.0.0](https://github.com/apache/arrow-rs/tree/18.0.0) (2022-07-08) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/17.0.0...18.0.0) + +**Breaking changes:** + +- Fix several bugs in parquet writer statistics generation, add `EnabledStatistics` to control level of statistics generated [\#2022](https://github.com/apache/arrow-rs/pull/2022) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add page index reader test for all types and support empty index. 
[\#2012](https://github.com/apache/arrow-rs/pull/2012) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Add `Decimal256Builder` and `Decimal256Array`; Decimal arrays now implement `BasicDecimalArray` trait [\#2000](https://github.com/apache/arrow-rs/pull/2000) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Simplify `ColumnReader::read_batch` [\#1995](https://github.com/apache/arrow-rs/pull/1995) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Remove `PrimitiveBuilder::finish_dict` \(\#1978\) [\#1980](https://github.com/apache/arrow-rs/pull/1980) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Disallow cast from other datatypes to `NullType` [\#1942](https://github.com/apache/arrow-rs/pull/1942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) +- Add column index writer for parquet [\#1935](https://github.com/apache/arrow-rs/pull/1935) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) + +**Implemented enhancements:** + +- Add `DataType::Dictionary` support to `subtract_scalar`, `multiply_scalar`, `divide_scalar` [\#2019](https://github.com/apache/arrow-rs/issues/2019) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support DictionaryArray in `add_scalar` kernel [\#2017](https://github.com/apache/arrow-rs/issues/2017) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Enable column page index read test for all types [\#2010](https://github.com/apache/arrow-rs/issues/2010) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Simplify `FixedSizeBinaryBuilder` [\#2007](https://github.com/apache/arrow-rs/issues/2007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `Decimal256Builder` and `Decimal256Array` [\#1999](https://github.com/apache/arrow-rs/issues/1999) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `DictionaryArray` in `unary` kernel [\#1989](https://github.com/apache/arrow-rs/issues/1989) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add kernel to quickly compute comparisons on `Array`s [\#1987](https://github.com/apache/arrow-rs/issues/1987) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `DictionaryArray` in `divide` kernel [\#1982](https://github.com/apache/arrow-rs/issues/1982) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `Into` for `T: Array` [\#1979](https://github.com/apache/arrow-rs/issues/1979) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `DictionaryArray` in `multiply` kernel [\#1972](https://github.com/apache/arrow-rs/issues/1972) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `DictionaryArray` in `subtract` kernel [\#1970](https://github.com/apache/arrow-rs/issues/1970) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Declare `DecimalArray::length` as a constant [\#1967](https://github.com/apache/arrow-rs/issues/1967) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `DictionaryArray` in `add` kernel [\#1950](https://github.com/apache/arrow-rs/issues/1950) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add builder style methods to `Field` [\#1934](https://github.com/apache/arrow-rs/issues/1934) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Make `StringDictionaryBuilder` faster [\#1851](https://github.com/apache/arrow-rs/issues/1851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `concat_elements_utf8` should accept arbitrary number of input arrays [\#1748](https://github.com/apache/arrow-rs/issues/1748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Array reader for list columns fails to decode if batches fall on row group boundaries [\#2025](https://github.com/apache/arrow-rs/issues/2025) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `ColumnWriterImpl::write_batch_with_statistics` incorrect distinct count in statistics [\#2016](https://github.com/apache/arrow-rs/issues/2016) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `ColumnWriterImpl::write_batch_with_statistics` can write incorrect page statistics [\#2015](https://github.com/apache/arrow-rs/issues/2015) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `RowFormatter` is not part of the public api [\#2008](https://github.com/apache/arrow-rs/issues/2008) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Infinite Loop possible in `ColumnReader::read_batch` For Corrupted Files [\#1997](https://github.com/apache/arrow-rs/issues/1997) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `PrimitiveBuilder::finish_dict` does not validate dictionary offsets [\#1978](https://github.com/apache/arrow-rs/issues/1978) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect `n_buffers` in `FFI_ArrowArray` [\#1959](https://github.com/apache/arrow-rs/issues/1959) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `DecimalArray::from_fixed_size_list_array` fails when `offset > 0` [\#1958](https://github.com/apache/arrow-rs/issues/1958) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect \(but ignored\) metadata written after ColumnChunk [\#1946](https://github.com/apache/arrow-rs/issues/1946) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `Send` + `Sync` impl for `Allocation` may not be sound unless `Allocation` is `Send` + `Sync` as well [\#1944](https://github.com/apache/arrow-rs/issues/1944) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Disallow cast from other datatypes to `NullType` [\#1923](https://github.com/apache/arrow-rs/issues/1923) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- The doc of `FixedSizeListArray::value_length` is incorrect. [\#1908](https://github.com/apache/arrow-rs/issues/1908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Closed issues:** + +- Column chunk statistics of `min_bytes` and `max_bytes` return wrong size [\#2021](https://github.com/apache/arrow-rs/issues/2021) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Discussion\] Refactor the `Decimal`s by using constant generic. 
[\#2001](https://github.com/apache/arrow-rs/issues/2001) +- Move `DecimalArray` to a new file [\#1985](https://github.com/apache/arrow-rs/issues/1985) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `DictionaryArray` in `multiply` kernel [\#1974](https://github.com/apache/arrow-rs/issues/1974) +- close function instead of mutable reference [\#1969](https://github.com/apache/arrow-rs/issues/1969) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Incorrect `null_count` of DictionaryArray [\#1962](https://github.com/apache/arrow-rs/issues/1962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support multi diskRanges for ChunkReader [\#1955](https://github.com/apache/arrow-rs/issues/1955) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Persisting Arrow timestamps with Parquet produces missing `TIMESTAMP` in schema [\#1920](https://github.com/apache/arrow-rs/issues/1920) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Separate get\_next\_page\_header from get\_next\_page in PageReader [\#1834](https://github.com/apache/arrow-rs/issues/1834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Consistent case in Index enumeration [\#2029](https://github.com/apache/arrow-rs/pull/2029) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Fix record delimiting on row group boundaries \(\#2025\) [\#2027](https://github.com/apache/arrow-rs/pull/2027) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Add builder style APIs For `Field`: `with_name`, `with_data_type` and `with_nullable` [\#2024](https://github.com/apache/arrow-rs/pull/2024) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add dictionary support to subtract\_scalar, multiply\_scalar, divide\_scalar [\#2020](https://github.com/apache/arrow-rs/pull/2020) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support DictionaryArray in add\_scalar kernel [\#2018](https://github.com/apache/arrow-rs/pull/2018) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Refine the `FixedSizeBinaryBuilder` [\#2013](https://github.com/apache/arrow-rs/pull/2013) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Add RowFormatter to record public API [\#2009](https://github.com/apache/arrow-rs/pull/2009) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([FabioBatSilva](https://github.com/FabioBatSilva)) +- Fix parquet test\_common feature flags [\#2003](https://github.com/apache/arrow-rs/pull/2003) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Stub out Skip Records API \(\#1792\) [\#1998](https://github.com/apache/arrow-rs/pull/1998) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) +- Implement `Into` for `T: Array` [\#1992](https://github.com/apache/arrow-rs/pull/1992) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([heyrutvik](https://github.com/heyrutvik)) +- Add unary\_cmp [\#1991](https://github.com/apache/arrow-rs/pull/1991) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support DictionaryArray in unary kernel [\#1990](https://github.com/apache/arrow-rs/pull/1990) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Refine `FixedSizeListBuilder` [\#1988](https://github.com/apache/arrow-rs/pull/1988) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Move `DecimalArray` to array\_decimal.rs [\#1986](https://github.com/apache/arrow-rs/pull/1986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- MINOR: Fix clippy error after updating rust toolchain [\#1984](https://github.com/apache/arrow-rs/pull/1984) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) +- Support dictionary array for divide kernel [\#1983](https://github.com/apache/arrow-rs/pull/1983) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support dictionary array for subtract and multiply kernel [\#1971](https://github.com/apache/arrow-rs/pull/1971) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Declare the value\_length of decimal array as a `const` [\#1968](https://github.com/apache/arrow-rs/pull/1968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Fix the behavior of `from_fixed_size_list` when offset \> 0 [\#1964](https://github.com/apache/arrow-rs/pull/1964) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Calculate n\_buffers in FFI\_ArrowArray by data layout [\#1960](https://github.com/apache/arrow-rs/pull/1960) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Fix the doc of `FixedSizeListArray::value_length` [\#1957](https://github.com/apache/arrow-rs/pull/1957) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Use InMemoryColumnChunkReader \(~20% faster\) [\#1956](https://github.com/apache/arrow-rs/pull/1956) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Unpin clap \(\#1867\) [\#1954](https://github.com/apache/arrow-rs/pull/1954) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Set is\_adjusted\_to\_utc if any timezone set \(\#1932\) [\#1953](https://github.com/apache/arrow-rs/pull/1953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add add\_dyn for DictionaryArray support [\#1951](https://github.com/apache/arrow-rs/pull/1951) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- write `ColumnMetadata` after the column chunk data, not the `ColumnChunk` [\#1947](https://github.com/apache/arrow-rs/pull/1947) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- Require Send+Sync bounds for Allocation trait [\#1945](https://github.com/apache/arrow-rs/pull/1945) 
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Faster StringDictionaryBuilder \(~60% faster\) \(\#1851\) [\#1861](https://github.com/apache/arrow-rs/pull/1861) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Arbitrary size concat elements utf8 [\#1787](https://github.com/apache/arrow-rs/pull/1787) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ismail-Maj](https://github.com/Ismail-Maj)) + ## [17.0.0](https://github.com/apache/arrow-rs/tree/17.0.0) (2022-06-24) [Full Changelog](https://github.com/apache/arrow-rs/compare/16.0.0...17.0.0) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7954e07a4c8a..69f2b8af6cf8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,100 +19,119 @@ # Changelog -## [18.0.0](https://github.com/apache/arrow-rs/tree/18.0.0) (2022-07-08) +## [22.0.0](https://github.com/apache/arrow-rs/tree/22.0.0) (2022-09-02) -[Full Changelog](https://github.com/apache/arrow-rs/compare/17.0.0...18.0.0) +[Full Changelog](https://github.com/apache/arrow-rs/compare/21.0.0...22.0.0) **Breaking changes:** -- Fix several bugs in parquet writer statistics generation, add `EnabledStatistics` to control level of statistics generated [\#2022](https://github.com/apache/arrow-rs/pull/2022) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Add page index reader test for all types and support empty index. [\#2012](https://github.com/apache/arrow-rs/pull/2012) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) -- Add `Decimal256Builder` and `Decimal256Array`; Decimal arrays now implement `BasicDecimalArray` trait [\#2000](https://github.com/apache/arrow-rs/pull/2000) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Simplify `ColumnReader::read_batch` [\#1995](https://github.com/apache/arrow-rs/pull/1995) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Remove `PrimitiveBuilder::finish_dict` \(\#1978\) [\#1980](https://github.com/apache/arrow-rs/pull/1980) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Disallow cast from other datatypes to `NullType` [\#1942](https://github.com/apache/arrow-rs/pull/1942) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liukun4515](https://github.com/liukun4515)) -- Add column index writer for parquet [\#1935](https://github.com/apache/arrow-rs/pull/1935) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) +- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2614](https://github.com/apache/arrow-rs/pull/2614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Gate dyn comparison of dictionary arrays behind `dyn_cmp_dict` [\#2597](https://github.com/apache/arrow-rs/pull/2597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Move JsonSerializable to json module \(\#2300\) [\#2595](https://github.com/apache/arrow-rs/pull/2595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Decimal 
precision scale datatype change [\#2532](https://github.com/apache/arrow-rs/pull/2532) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor PrimitiveBuilder Constructors [\#2518](https://github.com/apache/arrow-rs/pull/2518) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactoring DecimalBuilder constructors [\#2517](https://github.com/apache/arrow-rs/pull/2517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor FixedSizeBinaryBuilder Constructors [\#2516](https://github.com/apache/arrow-rs/pull/2516) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor BooleanBuilder Constructors [\#2515](https://github.com/apache/arrow-rs/pull/2515) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Refactor UnionBuilder Constructors [\#2488](https://github.com/apache/arrow-rs/pull/2488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) **Implemented enhancements:** -- Add `DataType::Dictionary` support to `subtract_scalar`, `multiply_scalar`, `divide_scalar` [\#2019](https://github.com/apache/arrow-rs/issues/2019) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support DictionaryArray in `add_scalar` kernel [\#2017](https://github.com/apache/arrow-rs/issues/2017) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Enable column page index read test for all types [\#2010](https://github.com/apache/arrow-rs/issues/2010) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Simplify `FixedSizeBinaryBuilder` [\#2007](https://github.com/apache/arrow-rs/issues/2007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support `Decimal256Builder` and `Decimal256Array` [\#1999](https://github.com/apache/arrow-rs/issues/1999) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support `DictionaryArray` in `unary` kernel [\#1989](https://github.com/apache/arrow-rs/issues/1989) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add kernel to quickly compute comparisons on `Array`s [\#1987](https://github.com/apache/arrow-rs/issues/1987) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support `DictionaryArray` in `divide` kernel [\#1982](https://github.com/apache/arrow-rs/issues/1982) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Implement `Into` for `T: Array` [\#1979](https://github.com/apache/arrow-rs/issues/1979) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support `DictionaryArray` in `multiply` kernel [\#1972](https://github.com/apache/arrow-rs/issues/1972) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support `DictionaryArray` in `subtract` kernel [\#1970](https://github.com/apache/arrow-rs/issues/1970) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Declare `DecimalArray::length` as a constant [\#1967](https://github.com/apache/arrow-rs/issues/1967) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support `DictionaryArray` in `add` kernel [\#1950](https://github.com/apache/arrow-rs/issues/1950) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add 
builder style methods to `Field` [\#1934](https://github.com/apache/arrow-rs/issues/1934) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Make `StringDictionaryBuilder` faster [\#1851](https://github.com/apache/arrow-rs/issues/1851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- `concat_elements_utf8` should accept arbitrary number of input arrays [\#1748](https://github.com/apache/arrow-rs/issues/1748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add Macros to assist with static dispatch [\#2635](https://github.com/apache/arrow-rs/issues/2635) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support comparison between DictionaryArray and BooleanArray [\#2617](https://github.com/apache/arrow-rs/issues/2617) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use `total_cmp` for floating value ordering and remove `nan_ordering` feature flag [\#2613](https://github.com/apache/arrow-rs/issues/2613) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support empty projection in CSV, JSON readers [\#2603](https://github.com/apache/arrow-rs/issues/2603) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support SQL-compliant NaN ordering between for DictionaryArray and non-DictionaryArray [\#2599](https://github.com/apache/arrow-rs/issues/2599) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `dyn_cmp_dict` feature flag to gate dyn comparison of dictionary arrays [\#2596](https://github.com/apache/arrow-rs/issues/2596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2584](https://github.com/apache/arrow-rs/issues/2584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow FlightSQL implementers to extend `do_get()` [\#2581](https://github.com/apache/arrow-rs/issues/2581) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Support SQL-compliant behavior on `eq_dyn`, `neq_dyn`, `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2569](https://github.com/apache/arrow-rs/issues/2569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add sql-compliant feature for enabling sql-compliant kernel behavior [\#2568](https://github.com/apache/arrow-rs/issues/2568) +- Calculate `sum` for dictionary array [\#2565](https://github.com/apache/arrow-rs/issues/2565) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add test for float nan comparison [\#2556](https://github.com/apache/arrow-rs/issues/2556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with string array [\#2548](https://github.com/apache/arrow-rs/issues/2548) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with primitive array in `lt_dyn`, `lt_eq_dyn`, `gt_dyn`, `gt_eq_dyn` [\#2538](https://github.com/apache/arrow-rs/issues/2538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2535](https://github.com/apache/arrow-rs/issues/2535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- UnionBuilder Create Children With Capacity [\#2523](https://github.com/apache/arrow-rs/issues/2523) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speed up `like_utf8_scalar` for `%pat%` [\#2519](https://github.com/apache/arrow-rs/issues/2519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Replace macro with TypedDictionaryArray in comparison kernels 
[\#2513](https://github.com/apache/arrow-rs/issues/2513) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Use same codebase for boolean kernels [\#2507](https://github.com/apache/arrow-rs/issues/2507) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Use u8 for Decimal Precision and Scale [\#2496](https://github.com/apache/arrow-rs/issues/2496) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Integrate skip row without pageIndex in SerializedPageReader in Fuzz Test [\#2475](https://github.com/apache/arrow-rs/issues/2475) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
+- Avoid unnecessary copies in Arrow IPC reader [\#2437](https://github.com/apache/arrow-rs/issues/2437) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Add GenericColumnReader::skip\_records Missing OffsetIndex Fallback [\#2433](https://github.com/apache/arrow-rs/issues/2433) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
+- Support Reading PageIndex with ParquetRecordBatchStream [\#2430](https://github.com/apache/arrow-rs/issues/2430) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
+- Specialize FixedLenByteArrayReader for Parquet [\#2318](https://github.com/apache/arrow-rs/issues/2318) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
+- Make JSON support Optional via Feature Flag [\#2300](https://github.com/apache/arrow-rs/issues/2300) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
 
 **Fixed bugs:**
 
-- Array reader for list columns fails to decode if batches fall on row group boundaries [\#2025](https://github.com/apache/arrow-rs/issues/2025) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- `ColumnWriterImpl::write_batch_with_statistics` incorrect distinct count in statistics [\#2016](https://github.com/apache/arrow-rs/issues/2016) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- `ColumnWriterImpl::write_batch_with_statistics` can write incorrect page statistics [\#2015](https://github.com/apache/arrow-rs/issues/2015) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- `RowFormatter` is not part of the public api [\#2008](https://github.com/apache/arrow-rs/issues/2008) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- Infinite Loop possible in `ColumnReader::read_batch` For Corrupted Files [\#1997](https://github.com/apache/arrow-rs/issues/1997) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- `PrimitiveBuilder::finish_dict` does not validate dictionary offsets [\#1978](https://github.com/apache/arrow-rs/issues/1978) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
-- Incorrect `n_buffers` in `FFI_ArrowArray` [\#1959](https://github.com/apache/arrow-rs/issues/1959) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
-- `DecimalArray::from_fixed_size_list_array` fails when `offset > 0` [\#1958](https://github.com/apache/arrow-rs/issues/1958) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
-- Incorrect \(but ignored\) metadata written after ColumnChunk [\#1946](https://github.com/apache/arrow-rs/issues/1946) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- `Send` + `Sync` impl for `Allocation` may not be sound unless `Allocation` is `Send` + `Sync` as well [\#1944](https://github.com/apache/arrow-rs/issues/1944) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
-- Disallow cast from other datatypes to `NullType` [\#1923](https://github.com/apache/arrow-rs/issues/1923)
[[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Casting timestamp array to string should not ignore timezone [\#2607](https://github.com/apache/arrow-rs/issues/2607) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Ilike\_utf8\_scalar kernels have incorrect logic [\#2544](https://github.com/apache/arrow-rs/issues/2544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Always validate the array data when creating array in IPC reader [\#2541](https://github.com/apache/arrow-rs/issues/2541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Int96Converter Truncates Timestamps [\#2480](https://github.com/apache/arrow-rs/issues/2480) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
+- Error Reading Page Index When Not Available [\#2434](https://github.com/apache/arrow-rs/issues/2434) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
+- `ParquetFileArrowReader::get_record_reader[_by_colum]` `batch_size` overallocates [\#2321](https://github.com/apache/arrow-rs/issues/2321) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
 
 **Documentation updates:**
 
-- The doc of `FixedSizeListArray::value_length` is incorrect. [\#1908](https://github.com/apache/arrow-rs/issues/1908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Document All Arrow Features in docs.rs [\#2633](https://github.com/apache/arrow-rs/issues/2633) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
 
 **Closed issues:**
 
-- Column chunk statistics of `min_bytes` and `max_bytes` return wrong size [\#2021](https://github.com/apache/arrow-rs/issues/2021) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- \[Discussion\] Refactor the `Decimal`s by using constant generic. [\#2001](https://github.com/apache/arrow-rs/issues/2001)
-- Move `DecimalArray` to a new file [\#1985](https://github.com/apache/arrow-rs/issues/1985) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
-- Support `DictionaryArray` in `multiply` kernel [\#1974](https://github.com/apache/arrow-rs/issues/1974)
-- close function instead of mutable reference [\#1969](https://github.com/apache/arrow-rs/issues/1969) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- Incorrect `null_count` of DictionaryArray [\#1962](https://github.com/apache/arrow-rs/issues/1962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
-- Support multi diskRanges for ChunkReader [\#1955](https://github.com/apache/arrow-rs/issues/1955) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- Persisting Arrow timestamps with Parquet produces missing `TIMESTAMP` in schema [\#1920](https://github.com/apache/arrow-rs/issues/1920) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
-- Sperate get\_next\_page\_header from get\_next\_page in PageReader [\#1834](https://github.com/apache/arrow-rs/issues/1834) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)]
+- Add support for CAST from `Interval(DayTime)` to `Timestamp(Nanosecond, None)` [\#2606](https://github.com/apache/arrow-rs/issues/2606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Why do we check for null in TypedDictionaryArray value function [\#2564](https://github.com/apache/arrow-rs/issues/2564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Add the `length` field for `Buffer` [\#2524](https://github.com/apache/arrow-rs/issues/2524) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)]
+- Avoid large over allocate buffer in async
reader [\#2512](https://github.com/apache/arrow-rs/issues/2512) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Rewriting Decimal Builders using `const_generic`. [\#2390](https://github.com/apache/arrow-rs/issues/2390) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Rewrite Decimal Array using `const_generic` [\#2384](https://github.com/apache/arrow-rs/issues/2384) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Merged pull requests:** -- Consistent case in Index enumeration [\#2029](https://github.com/apache/arrow-rs/pull/2029) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Fix record delimiting on row group boundaries \(\#2025\) [\#2027](https://github.com/apache/arrow-rs/pull/2027) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Add builder style APIs For `Field`: `with_name`, `with_data_type` and `with_nullable` [\#2024](https://github.com/apache/arrow-rs/pull/2024) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Add dictionary support to subtract\_scalar, multiply\_scalar, divide\_scalar [\#2020](https://github.com/apache/arrow-rs/pull/2020) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Support DictionaryArray in add\_scalar kernel [\#2018](https://github.com/apache/arrow-rs/pull/2018) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Refine the `FixedSizeBinaryBuilder` [\#2013](https://github.com/apache/arrow-rs/pull/2013) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Add RowFormatter to record public API [\#2009](https://github.com/apache/arrow-rs/pull/2009) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([FabioBatSilva](https://github.com/FabioBatSilva)) -- Fix parquet test\_common feature flags [\#2003](https://github.com/apache/arrow-rs/pull/2003) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Stub out Skip Records API \(\#1792\) [\#1998](https://github.com/apache/arrow-rs/pull/1998) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([tustvold](https://github.com/tustvold)) -- Implement `Into` for `T: Array` [\#1992](https://github.com/apache/arrow-rs/pull/1992) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([heyrutvik](https://github.com/heyrutvik)) -- Add unary\_cmp [\#1991](https://github.com/apache/arrow-rs/pull/1991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Support DictionaryArray in unary kernel [\#1990](https://github.com/apache/arrow-rs/pull/1990) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Refine `FixedSizeListBuilder` [\#1988](https://github.com/apache/arrow-rs/pull/1988) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Move `DecimalArray` to array\_decimal.rs [\#1986](https://github.com/apache/arrow-rs/pull/1986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- MINOR: Fix clippy error after updating rust toolchain 
[\#1984](https://github.com/apache/arrow-rs/pull/1984) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([viirya](https://github.com/viirya)) -- Support dictionary array for divide kernel [\#1983](https://github.com/apache/arrow-rs/pull/1983) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Support dictionary array for subtract and multiply kernel [\#1971](https://github.com/apache/arrow-rs/pull/1971) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Declare the value\_length of decimal array as a `const` [\#1968](https://github.com/apache/arrow-rs/pull/1968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Fix the behavior of `from_fixed_size_list` when offset \> 0 [\#1964](https://github.com/apache/arrow-rs/pull/1964) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Calculate n\_buffers in FFI\_ArrowArray by data layout [\#1960](https://github.com/apache/arrow-rs/pull/1960) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- Fix the doc of `FixedSizeListArray::value_length` [\#1957](https://github.com/apache/arrow-rs/pull/1957) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) -- Use InMemoryColumnChunkReader \(~20% faster\) [\#1956](https://github.com/apache/arrow-rs/pull/1956) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Unpin clap \(\#1867\) [\#1954](https://github.com/apache/arrow-rs/pull/1954) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) -- Set is\_adjusted\_to\_utc if any timezone set \(\#1932\) [\#1953](https://github.com/apache/arrow-rs/pull/1953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Add add\_dyn for DictionaryArray support [\#1951](https://github.com/apache/arrow-rs/pull/1951) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) -- write `ColumnMetadata` after the column chunk data, not the `ColumnChunk` [\#1947](https://github.com/apache/arrow-rs/pull/1947) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515)) -- Require Send+Sync bounds for Allocation trait [\#1945](https://github.com/apache/arrow-rs/pull/1945) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) -- Faster StringDictionaryBuilder \(~60% faster\) \(\#1851\) [\#1861](https://github.com/apache/arrow-rs/pull/1861) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) -- Arbitrary size concat elements utf8 [\#1787](https://github.com/apache/arrow-rs/pull/1787) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Ismail-Maj](https://github.com/Ismail-Maj)) +- Add downcast macros \(\#2635\) [\#2636](https://github.com/apache/arrow-rs/pull/2636) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Document all arrow features in docs.rs \(\#2633\) 
[\#2634](https://github.com/apache/arrow-rs/pull/2634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Document dyn\_cmp\_dict [\#2624](https://github.com/apache/arrow-rs/pull/2624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support comparison between DictionaryArray and BooleanArray [\#2618](https://github.com/apache/arrow-rs/pull/2618) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Cast timestamp array to string array with timezone [\#2608](https://github.com/apache/arrow-rs/pull/2608) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Support empty projection in CSV and JSON readers [\#2604](https://github.com/apache/arrow-rs/pull/2604) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Make JSON support optional via a feature flag \(\#2300\) [\#2601](https://github.com/apache/arrow-rs/pull/2601) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Support SQL-compliant NaN ordering for DictionaryArray and non-DictionaryArray [\#2600](https://github.com/apache/arrow-rs/pull/2600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Split out integration test plumbing \(\#2594\) \(\#2300\) [\#2598](https://github.com/apache/arrow-rs/pull/2598) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Refactor Binary Builder and String Builder Constructors [\#2592](https://github.com/apache/arrow-rs/pull/2592) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Dictionary like scalar kernels [\#2591](https://github.com/apache/arrow-rs/pull/2591) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Validate dictionary key in TypedDictionaryArray \(\#2578\) [\#2589](https://github.com/apache/arrow-rs/pull/2589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- Add max\_dyn and min\_dyn for max/min for dictionary array [\#2585](https://github.com/apache/arrow-rs/pull/2585) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Code cleanup of array value functions [\#2583](https://github.com/apache/arrow-rs/pull/2583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Allow overriding of do\_get & export useful macro [\#2582](https://github.com/apache/arrow-rs/pull/2582) [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([avantgardnerio](https://github.com/avantgardnerio)) +- MINOR: Upgrade to pyo3 0.17 [\#2576](https://github.com/apache/arrow-rs/pull/2576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([andygrove](https://github.com/andygrove)) +- Support SQL-compliant NaN behavior on eq\_dyn, neq\_dyn, lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn [\#2570](https://github.com/apache/arrow-rs/pull/2570) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Add sum\_dyn to calculate sum for dictionary array 
[\#2566](https://github.com/apache/arrow-rs/pull/2566) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya))
+- struct UnionBuilder will create child buffers with capacity [\#2560](https://github.com/apache/arrow-rs/pull/2560) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kastolars](https://github.com/kastolars))
+- Don't panic on RleValueEncoder::flush\_buffer if empty \(\#2558\) [\#2559](https://github.com/apache/arrow-rs/pull/2559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold))
+- Add the `length` field for Buffer and use more `Buffer` in IPC reader to avoid memory copy. [\#2557](https://github.com/apache/arrow-rs/pull/2557) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([HaoYang670](https://github.com/HaoYang670))
+- Add test for float nan comparison [\#2555](https://github.com/apache/arrow-rs/pull/2555) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya))
+- Compare dictionary array with string array [\#2549](https://github.com/apache/arrow-rs/pull/2549) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya))
+- Always validate the array data \(except the `Decimal`\) when creating array in IPC reader [\#2547](https://github.com/apache/arrow-rs/pull/2547) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670))
+- MINOR: Fix test\_row\_type\_validation test [\#2546](https://github.com/apache/arrow-rs/pull/2546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya))
+- Fix ilike\_utf8\_scalar kernels [\#2545](https://github.com/apache/arrow-rs/pull/2545) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri))
+- fix typo [\#2540](https://github.com/apache/arrow-rs/pull/2540) ([00Masato](https://github.com/00Masato))
+- Compare dictionary array and primitive array in lt\_dyn, lt\_eq\_dyn, gt\_dyn, gt\_eq\_dyn kernels [\#2539](https://github.com/apache/arrow-rs/pull/2539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya))
+- \[MINOR\]Avoid large over allocate buffer in async reader [\#2537](https://github.com/apache/arrow-rs/pull/2537) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang))
+- Compare dictionary with primitive array in `eq_dyn` and `neq_dyn` [\#2533](https://github.com/apache/arrow-rs/pull/2533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya))
+- Add iterator for FixedSizeBinaryArray [\#2531](https://github.com/apache/arrow-rs/pull/2531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold))
+- add bench: decimal with byte array and fixed length byte array [\#2529](https://github.com/apache/arrow-rs/pull/2529) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liukun4515](https://github.com/liukun4515))
+- Add FixedLengthByteArrayReader Remove ComplexObjectArrayReader [\#2528](https://github.com/apache/arrow-rs/pull/2528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold))
+- Split out byte array decoders \(\#2318\) [\#2527](https://github.com/apache/arrow-rs/pull/2527)
[[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Use offset index in ParquetRecordBatchStream [\#2526](https://github.com/apache/arrow-rs/pull/2526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([thinkharderdev](https://github.com/thinkharderdev)) +- Clean the `create_array` in IPC reader. [\#2525](https://github.com/apache/arrow-rs/pull/2525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Remove DecimalByteArrayConvert \(\#2480\) [\#2522](https://github.com/apache/arrow-rs/pull/2522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) +- Improve performance of `%pat%` \(\>3x speedup\) [\#2521](https://github.com/apache/arrow-rs/pull/2521) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- remove len field from MapBuilder [\#2520](https://github.com/apache/arrow-rs/pull/2520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([psvri](https://github.com/psvri)) +- Replace macro with TypedDictionaryArray in comparison kernels [\#2514](https://github.com/apache/arrow-rs/pull/2514) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Avoid large over allocate buffer in sync reader [\#2511](https://github.com/apache/arrow-rs/pull/2511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Ted-Jiang](https://github.com/Ted-Jiang)) +- Avoid useless memory copies in IPC reader. [\#2510](https://github.com/apache/arrow-rs/pull/2510) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([HaoYang670](https://github.com/HaoYang670)) +- Refactor boolean kernels to use same codebase [\#2508](https://github.com/apache/arrow-rs/pull/2508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove Int96Converter \(\#2480\) [\#2481](https://github.com/apache/arrow-rs/pull/2481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([tustvold](https://github.com/tustvold)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4e4c53e5e2bd..67121f6cd5a3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -78,7 +78,7 @@ git submodule update --init This populates data in two git submodules: -- `../parquet_testing/data` (sourced from https://github.com/apache/parquet-testing.git) +- `../parquet-testing/data` (sourced from https://github.com/apache/parquet-testing.git) - `../testing` (sourced from https://github.com/apache/arrow-testing) By default, `cargo test` will look for these directories at their diff --git a/Cargo.toml b/Cargo.toml index 2837f028e8c4..9bf55c0f2360 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ members = [ "parquet_derive_test", "arrow-flight", "integration-testing", + "object_store", ] # Enable the version 2 feature resolver, which avoids unifying features for targets that are not being built # diff --git a/LICENSE.txt b/LICENSE.txt index 4cec07fd0c99..d74c6b599d2a 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -201,674 +201,6 @@ See the License for the specific language governing permissions and limitations under the License. 
--------------------------------------------------------------------------------- - -src/plasma/fling.cc and src/plasma/fling.h: Apache 2.0 - -Copyright 2013 Sharvil Nanavati - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- - -src/plasma/thirdparty/ae: Modified / 3-Clause BSD - -Copyright (c) 2006-2010, Salvatore Sanfilippo -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of Redis nor the names of its contributors may be used - to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -src/plasma/thirdparty/dlmalloc.c: CC0 - -This is a version (aka dlmalloc) of malloc/free/realloc written by -Doug Lea and released to the public domain, as explained at -http://creativecommons.org/publicdomain/zero/1.0/ Send questions, -comments, complaints, performance data, etc to dl@cs.oswego.edu - --------------------------------------------------------------------------------- - -src/plasma/common.cc (some portions) - -Copyright (c) Austin Appleby (aappleby (AT) gmail) - -Some portions of this file are derived from code in the MurmurHash project - -All code is released to the public domain. For business purposes, Murmurhash is -under the MIT license. 
- -https://sites.google.com/site/murmurhash/ - --------------------------------------------------------------------------------- - -src/arrow/util (some portions): Apache 2.0, and 3-clause BSD - -Some portions of this module are derived from code in the Chromium project, -copyright (c) Google inc and (c) The Chromium Authors and licensed under the -Apache 2.0 License or the under the 3-clause BSD license: - - Copyright (c) 2013 The Chromium Authors. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following disclaimer - in the documentation and/or other materials provided with the - distribution. - * Neither the name of Google Inc. nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -This project includes code from Daniel Lemire's FrameOfReference project. - -https://github.com/lemire/FrameOfReference/blob/6ccaf9e97160f9a3b299e23a8ef739e711ef0c71/src/bpacking.cpp - -Copyright: 2013 Daniel Lemire -Home page: http://lemire.me/en/ -Project page: https://github.com/lemire/FrameOfReference -License: Apache License Version 2.0 http://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This project includes code from the TensorFlow project - -Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- - -This project includes code from the NumPy project. 
- -https://github.com/numpy/numpy/blob/e1f191c46f2eebd6cb892a4bfe14d9dd43a06c4e/numpy/core/src/multiarray/multiarraymodule.c#L2910 - -https://github.com/numpy/numpy/blob/68fd82271b9ea5a9e50d4e761061dfcca851382a/numpy/core/src/multiarray/datetime.c - -Copyright (c) 2005-2017, NumPy Developers. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - - * Neither the name of the NumPy Developers nor the names of any - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -This project includes code from the Boost project - -Boost Software License - Version 1.0 - August 17th, 2003 - -Permission is hereby granted, free of charge, to any person or organization -obtaining a copy of the software and accompanying documentation covered by -this license (the "Software") to use, reproduce, display, distribute, -execute, and transmit the Software, and to prepare derivative works of the -Software, and to permit third-parties to whom the Software is furnished to -do so, all subject to the following: - -The copyright notices in the Software and this entire statement, including -the above license grant, this restriction and the following disclaimer, -must be included in all copies of the Software, in whole or in part, and -all derivative works of the Software, unless such copies or derivative -works are solely in the form of machine-executable object code generated by -a source language processor. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT -SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE -FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. - --------------------------------------------------------------------------------- - -This project includes code from the FlatBuffers project - -Copyright 2014 Google Inc. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- - -This project includes code from the tslib project - -Copyright 2015 Microsoft Corporation. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- - -This project includes code from the jemalloc project - -https://github.com/jemalloc/jemalloc - -Copyright (C) 2002-2017 Jason Evans . -All rights reserved. -Copyright (C) 2007-2012 Mozilla Foundation. All rights reserved. -Copyright (C) 2009-2017 Facebook, Inc. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: -1. Redistributions of source code must retain the above copyright notice(s), - this list of conditions and the following disclaimer. -2. Redistributions in binary form must reproduce the above copyright notice(s), - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER(S) ``AS IS'' AND ANY EXPRESS -OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO -EVENT SHALL THE COPYRIGHT HOLDER(S) BE LIABLE FOR ANY DIRECT, INDIRECT, -INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE -OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF -ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. --------------------------------------------------------------------------------- - -This project includes code from the Go project, BSD 3-clause license + PATENTS -weak patent termination clause -(https://github.com/golang/go/blob/master/PATENTS). - -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. 
- * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -This project includes code from the hs2client - -https://github.com/cloudera/hs2client - -Copyright 2016 Cloudera Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- - -The script ci/scripts/util_wait_for_it.sh has the following license - -Copyright (c) 2016 Giles Hall - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
- --------------------------------------------------------------------------------- - -The script r/configure has the following license (MIT) - -Copyright (c) 2017, Jeroen Ooms and Jim Hester - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - --------------------------------------------------------------------------------- - -cpp/src/arrow/util/logging.cc, cpp/src/arrow/util/logging.h and -cpp/src/arrow/util/logging-test.cc are adapted from -Ray Project (https://github.com/ray-project/ray) (Apache 2.0). - -Copyright (c) 2016 Ray Project (https://github.com/ray-project/ray) - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- -The files cpp/src/arrow/vendored/datetime/date.h, cpp/src/arrow/vendored/datetime/tz.h, -cpp/src/arrow/vendored/datetime/tz_private.h, cpp/src/arrow/vendored/datetime/ios.h, -cpp/src/arrow/vendored/datetime/ios.mm, -cpp/src/arrow/vendored/datetime/tz.cpp are adapted from -Howard Hinnant's date library (https://github.com/HowardHinnant/date) -It is licensed under MIT license. - -The MIT License (MIT) -Copyright (c) 2015, 2016, 2017 Howard Hinnant -Copyright (c) 2016 Adrian Colomitchi -Copyright (c) 2017 Florian Dang -Copyright (c) 2017 Paul Thompson -Copyright (c) 2018 Tomasz Kamiński - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - --------------------------------------------------------------------------------- - -The file cpp/src/arrow/util/utf8.h includes code adapted from the page - https://bjoern.hoehrmann.de/utf-8/decoder/dfa/ -with the following license (MIT) - -Copyright (c) 2008-2009 Bjoern Hoehrmann - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - --------------------------------------------------------------------------------- - -The file cpp/src/arrow/vendored/string_view.hpp has the following license - -Boost Software License - Version 1.0 - August 17th, 2003 - -Permission is hereby granted, free of charge, to any person or organization -obtaining a copy of the software and accompanying documentation covered by -this license (the "Software") to use, reproduce, display, distribute, -execute, and transmit the Software, and to prepare derivative works of the -Software, and to permit third-parties to whom the Software is furnished to -do so, all subject to the following: - -The copyright notices in the Software and this entire statement, including -the above license grant, this restriction and the following disclaimer, -must be included in all copies of the Software, in whole or in part, and -all derivative works of the Software, unless such copies or derivative -works are solely in the form of machine-executable object code generated by -a source language processor. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT -SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE -FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. 
- --------------------------------------------------------------------------------- - -The files in cpp/src/arrow/vendored/xxhash/ have the following license -(BSD 2-Clause License) - -xxHash Library -Copyright (c) 2012-2014, Yann Collet -All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, this - list of conditions and the following disclaimer in the documentation and/or - other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -You can contact the author at : -- xxHash homepage: http://www.xxhash.com -- xxHash source repository : https://github.com/Cyan4973/xxHash - --------------------------------------------------------------------------------- - -The files in cpp/src/arrow/vendored/double-conversion/ have the following license -(BSD 3-Clause License) - -Copyright 2006-2011, the V8 project authors. All rights reserved. -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - * Neither the name of Google Inc. nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- --------------------------------------------------------------------------------- - -The files in cpp/src/arrow/vendored/uriparser/ have the following license -(BSD 3-Clause License) - -uriparser - RFC 3986 URI parsing library - -Copyright (C) 2007, Weijia Song -Copyright (C) 2007, Sebastian Pipping -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - - * Redistributions of source code must retain the above - copyright notice, this list of conditions and the following - disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials - provided with the distribution. - - * Neither the name of the nor the names of its - contributors may be used to endorse or promote products - derived from this software without specific prior written - permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -The files under dev/tasks/conda-recipes have the following license - -BSD 3-clause license -Copyright (c) 2015-2018, conda-forge -All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its contributors - may be used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF -THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- --------------------------------------------------------------------------------- - -The files in cpp/src/arrow/vendored/utfcpp/ have the following license - -Copyright 2006-2018 Nemanja Trifunovic - -Permission is hereby granted, free of charge, to any person or organization -obtaining a copy of the software and accompanying documentation covered by -this license (the "Software") to use, reproduce, display, distribute, -execute, and transmit the Software, and to prepare derivative works of the -Software, and to permit third-parties to whom the Software is furnished to -do so, all subject to the following: - -The copyright notices in the Software and this entire statement, including -the above license grant, this restriction and the following disclaimer, -must be included in all copies of the Software, in whole or in part, and -all derivative works of the Software, unless such copies or derivative -works are solely in the form of machine-executable object code generated by -a source language processor. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT -SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE -FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. - --------------------------------------------------------------------------------- - -This project includes code from Apache Kudu. - - * cpp/cmake_modules/CompilerInfo.cmake is based on Kudu's cmake_modules/CompilerInfo.cmake - -Copyright: 2016 The Apache Software Foundation. -Home page: https://kudu.apache.org/ -License: http://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This project includes code from Apache Impala (incubating), formerly -Impala. The Impala code and rights were donated to the ASF as part of the -Incubator process after the initial code imports into Apache Parquet. - -Copyright: 2012 Cloudera, Inc. -Copyright: 2016 The Apache Software Foundation. -Home page: http://impala.apache.org/ -License: http://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- This project includes code from Apache Aurora. @@ -878,1343 +210,3 @@ This project includes code from Apache Aurora. Copyright: 2016 The Apache Software Foundation. Home page: https://aurora.apache.org/ License: http://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This project includes code from the Google styleguide. - -* cpp/build-support/cpplint.py is based on the scripts from the Google styleguide. - -Copyright: 2009 Google Inc. All rights reserved. -Homepage: https://github.com/google/styleguide -License: 3-clause BSD - --------------------------------------------------------------------------------- - -This project includes code from Snappy. - -* cpp/cmake_modules/{SnappyCMakeLists.txt,SnappyConfig.h} are based on code - from Google's Snappy project. - -Copyright: 2009 Google Inc. All rights reserved. -Homepage: https://github.com/google/snappy -License: 3-clause BSD - --------------------------------------------------------------------------------- - -This project includes code from the manylinux project. 
- -* python/manylinux1/scripts/{build_python.sh,python-tag-abi-tag.py, - requirements.txt} are based on code from the manylinux project. - -Copyright: 2016 manylinux -Homepage: https://github.com/pypa/manylinux -License: The MIT License (MIT) - --------------------------------------------------------------------------------- - -This project includes code from the cymove project: - -* python/pyarrow/includes/common.pxd includes code from the cymove project - -The MIT License (MIT) -Copyright (c) 2019 Omer Ozarslan - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR -OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE -OR OTHER DEALINGS IN THE SOFTWARE. - --------------------------------------------------------------------------------- - -The projects includes code from the Ursabot project under the dev/archery -directory. - -License: BSD 2-Clause - -Copyright 2019 RStudio, Inc. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -This project include code from mingw-w64. - -* cpp/src/arrow/util/cpu-info.cc has a polyfill for mingw-w64 < 5 - -Copyright (c) 2009 - 2013 by the mingw-w64 project -Homepage: https://mingw-w64.org -License: Zope Public License (ZPL) Version 2.1. - ---------------------------------------------------------------------------------- - -This project include code from Google's Asylo project. 
- -* cpp/src/arrow/result.h is based on status_or.h - -Copyright (c) Copyright 2017 Asylo authors -Homepage: https://asylo.dev/ -License: Apache 2.0 - --------------------------------------------------------------------------------- - -This project includes code from Google's protobuf project - -* cpp/src/arrow/result.h ARROW_ASSIGN_OR_RAISE is based off ASSIGN_OR_RETURN - -Copyright 2008 Google Inc. All rights reserved. -Homepage: https://developers.google.com/protocol-buffers/ -License: - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -Code generated by the Protocol Buffer compiler is owned by the owner -of the input file used when generating it. This code is not -standalone and requires a support library to be linked with it. This -support library is itself covered by the above license. - --------------------------------------------------------------------------------- - -3rdparty dependency LLVM is statically linked in certain binary distributions. -Additionally some sections of source code have been derived from sources in LLVM -and have been clearly labeled as such. LLVM has the following license: - -============================================================================== -The LLVM Project is under the Apache License v2.0 with LLVM Exceptions: -============================================================================== - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. 
For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- - ----- LLVM Exceptions to the Apache 2.0 License ---- - -As an exception, if, as a result of your compiling your source code, portions -of this Software are embedded into an Object form of such source code, you -may redistribute such embedded portions in such Object form without complying -with the conditions of Sections 4(a), 4(b) and 4(d) of the License. - -In addition, if you combine or link compiled forms of this Software with -software that is licensed under the GPLv2 ("Combined Software") and if a -court of competent jurisdiction determines that the patent provision (Section -3), the indemnity provision (Section 9) or other Section of the License -conflicts with the conditions of the GPLv2, you may retroactively and -prospectively choose to deem waived or otherwise exclude such Section(s) of -the License, but only in their entirety and only with respect to the Combined -Software. - -============================================================================== -Software from third parties included in the LLVM Project: -============================================================================== -The LLVM Project contains third party software which is under different license -terms. All such code will be identified clearly using at least one of two -mechanisms: -1) It will be in a separate directory tree with its own `LICENSE.txt` or - `LICENSE` file at the top containing the specific license and restrictions - which apply to that software, or -2) It will contain specific license and restriction terms at the top of every - file. - --------------------------------------------------------------------------------- - -3rdparty dependency gRPC is statically linked in certain binary -distributions, like the python wheels. gRPC has the following license: - -Copyright 2014 gRPC authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- - -3rdparty dependency Apache Thrift is statically linked in certain binary -distributions, like the python wheels. Apache Thrift has the following license: - -Apache Thrift -Copyright (C) 2006 - 2019, The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- - -3rdparty dependency Apache ORC is statically linked in certain binary -distributions, like the python wheels. 
Apache ORC has the following license: - -Apache ORC -Copyright 2013-2019 The Apache Software Foundation - -This product includes software developed by The Apache Software -Foundation (http://www.apache.org/). - -This product includes software developed by Hewlett-Packard: -(c) Copyright [2014-2015] Hewlett-Packard Development Company, L.P - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - --------------------------------------------------------------------------------- - -3rdparty dependency zstd is statically linked in certain binary -distributions, like the python wheels. ZSTD has the following license: - -BSD License - -For Zstandard software - -Copyright (c) 2016-present, Facebook, Inc. All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - - * Neither the name Facebook nor the names of its contributors may be used to - endorse or promote products derived from this software without specific - prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -3rdparty dependency lz4 is statically linked in certain binary -distributions, like the python wheels. lz4 has the following license: - -LZ4 Library -Copyright (c) 2011-2016, Yann Collet -All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, this - list of conditions and the following disclaimer in the documentation and/or - other materials provided with the distribution. 
- -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -3rdparty dependency Brotli is statically linked in certain binary -distributions, like the python wheels. Brotli has the following license: - -Copyright (c) 2009, 2010, 2013-2016 by the Brotli Authors. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. - --------------------------------------------------------------------------------- - -3rdparty dependency rapidjson is statically linked in certain binary -distributions, like the python wheels. rapidjson and its dependencies have the -following licenses: - -Tencent is pleased to support the open source community by making RapidJSON -available. - -Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip. -All rights reserved. - -If you have downloaded a copy of the RapidJSON binary from Tencent, please note -that the RapidJSON binary is licensed under the MIT License. -If you have downloaded a copy of the RapidJSON source code from Tencent, please -note that RapidJSON source code is licensed under the MIT License, except for -the third-party components listed below which are subject to different license -terms. Your integration of RapidJSON into your own projects may require -compliance with the MIT License, as well as the other licenses applicable to -the third-party components included within RapidJSON. To avoid the problematic -JSON license in your own projects, it's sufficient to exclude the -bin/jsonchecker/ directory, as it's the only code under the JSON license. -A copy of the MIT License is included in this file. 
- -Other dependencies and licenses: - - Open Source Software Licensed Under the BSD License: - -------------------------------------------------------------------- - - The msinttypes r29 - Copyright (c) 2006-2013 Alexander Chemeris - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - * Neither the name of copyright holder nor the names of its contributors - may be used to endorse or promote products derived from this software - without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND ANY - EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE REGENTS AND CONTRIBUTORS BE LIABLE FOR - ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT - LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY - OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH - DAMAGE. - - Open Source Software Licensed Under the JSON License: - -------------------------------------------------------------------- - - json.org - Copyright (c) 2002 JSON.org - All Rights Reserved. - - JSON_checker - Copyright (c) 2002 JSON.org - All Rights Reserved. - - - Terms of the JSON License: - --------------------------------------------------- - - Permission is hereby granted, free of charge, to any person obtaining a - copy of this software and associated documentation files (the "Software"), - to deal in the Software without restriction, including without limitation - the rights to use, copy, modify, merge, publish, distribute, sublicense, - and/or sell copies of the Software, and to permit persons to whom the - Software is furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - The Software shall be used for Good, not Evil. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - DEALINGS IN THE SOFTWARE. 
- - - Terms of the MIT License: - -------------------------------------------------------------------- - - Permission is hereby granted, free of charge, to any person obtaining a - copy of this software and associated documentation files (the "Software"), - to deal in the Software without restriction, including without limitation - the rights to use, copy, modify, merge, publish, distribute, sublicense, - and/or sell copies of the Software, and to permit persons to whom the - Software is furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included - in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER - DEALINGS IN THE SOFTWARE. - --------------------------------------------------------------------------------- - -3rdparty dependency snappy is statically linked in certain binary -distributions, like the python wheels. snappy has the following license: - -Copyright 2011, Google Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - * Neither the name of Google Inc. nor the names of its contributors may be - used to endorse or promote products derived from this software without - specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -=== - -Some of the benchmark data in testdata/ is licensed differently: - - - fireworks.jpeg is Copyright 2013 Steinar H. Gunderson, and - is licensed under the Creative Commons Attribution 3.0 license - (CC-BY-3.0). See https://creativecommons.org/licenses/by/3.0/ - for more information. - - - kppkn.gtb is taken from the Gaviota chess tablebase set, and - is licensed under the MIT License. See - https://sites.google.com/site/gaviotachessengine/Home/endgame-tablebases-1 - for more information. 
- - - paper-100k.pdf is an excerpt (bytes 92160 to 194560) from the paper - “Combinatorial Modeling of Chromatin Features Quantitatively Predicts DNA - Replication Timing in _Drosophila_” by Federico Comoglio and Renato Paro, - which is licensed under the CC-BY license. See - http://www.ploscompbiol.org/static/license for more ifnormation. - - - alice29.txt, asyoulik.txt, plrabn12.txt and lcet10.txt are from Project - Gutenberg. The first three have expired copyrights and are in the public - domain; the latter does not have expired copyright, but is still in the - public domain according to the license information - (http://www.gutenberg.org/ebooks/53). - --------------------------------------------------------------------------------- - -3rdparty dependency gflags is statically linked in certain binary -distributions, like the python wheels. gflags has the following license: - -Copyright (c) 2006, Google Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -3rdparty dependency glog is statically linked in certain binary -distributions, like the python wheels. glog has the following license: - -Copyright (c) 2008, Google Inc. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. 
- -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -A function gettimeofday in utilities.cc is based on - -http://www.google.com/codesearch/p?hl=en#dR3YEbitojA/COPYING&q=GetSystemTimeAsFileTime%20license:bsd - -The license of this code is: - -Copyright (c) 2003-2008, Jouni Malinen and contributors -All Rights Reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - -1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -3. Neither the name(s) of the above-listed copyright holder(s) nor the - names of its contributors may be used to endorse or promote products - derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -3rdparty dependency re2 is statically linked in certain binary -distributions, like the python wheels. re2 has the following license: - -Copyright (c) 2009 The RE2 Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - * Neither the name of Google Inc. nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. 
- -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -3rdparty dependency c-ares is statically linked in certain binary -distributions, like the python wheels. c-ares has the following license: - -# c-ares license - -Copyright (c) 2007 - 2018, Daniel Stenberg with many contributors, see AUTHORS -file. - -Copyright 1998 by the Massachusetts Institute of Technology. - -Permission to use, copy, modify, and distribute this software and its -documentation for any purpose and without fee is hereby granted, provided that -the above copyright notice appear in all copies and that both that copyright -notice and this permission notice appear in supporting documentation, and that -the name of M.I.T. not be used in advertising or publicity pertaining to -distribution of the software without specific, written prior permission. -M.I.T. makes no representations about the suitability of this software for any -purpose. It is provided "as is" without express or implied warranty. - --------------------------------------------------------------------------------- - -3rdparty dependency zlib is redistributed as a dynamically linked shared -library in certain binary distributions, like the python wheels. In the future -this will likely change to static linkage. zlib has the following license: - -zlib.h -- interface of the 'zlib' general purpose compression library - version 1.2.11, January 15th, 2017 - - Copyright (C) 1995-2017 Jean-loup Gailly and Mark Adler - - This software is provided 'as-is', without any express or implied - warranty. In no event will the authors be held liable for any damages - arising from the use of this software. - - Permission is granted to anyone to use this software for any purpose, - including commercial applications, and to alter it and redistribute it - freely, subject to the following restrictions: - - 1. The origin of this software must not be misrepresented; you must not - claim that you wrote the original software. If you use this software - in a product, an acknowledgment in the product documentation would be - appreciated but is not required. - 2. Altered source versions must be plainly marked as such, and must not be - misrepresented as being the original software. - 3. This notice may not be removed or altered from any source distribution. - - Jean-loup Gailly Mark Adler - jloup@gzip.org madler@alumni.caltech.edu - --------------------------------------------------------------------------------- - -3rdparty dependency openssl is redistributed as a dynamically linked shared -library in certain binary distributions, like the python wheels. openssl -preceding version 3 has the following license: - - LICENSE ISSUES - ============== - - The OpenSSL toolkit stays under a double license, i.e. 
both the conditions of - the OpenSSL License and the original SSLeay license apply to the toolkit. - See below for the actual license texts. - - OpenSSL License - --------------- - -/* ==================================================================== - * Copyright (c) 1998-2019 The OpenSSL Project. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in - * the documentation and/or other materials provided with the - * distribution. - * - * 3. All advertising materials mentioning features or use of this - * software must display the following acknowledgment: - * "This product includes software developed by the OpenSSL Project - * for use in the OpenSSL Toolkit. (http://www.openssl.org/)" - * - * 4. The names "OpenSSL Toolkit" and "OpenSSL Project" must not be used to - * endorse or promote products derived from this software without - * prior written permission. For written permission, please contact - * openssl-core@openssl.org. - * - * 5. Products derived from this software may not be called "OpenSSL" - * nor may "OpenSSL" appear in their names without prior written - * permission of the OpenSSL Project. - * - * 6. Redistributions of any form whatsoever must retain the following - * acknowledgment: - * "This product includes software developed by the OpenSSL Project - * for use in the OpenSSL Toolkit (http://www.openssl.org/)" - * - * THIS SOFTWARE IS PROVIDED BY THE OpenSSL PROJECT ``AS IS'' AND ANY - * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE OpenSSL PROJECT OR - * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT - * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) - * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED - * OF THE POSSIBILITY OF SUCH DAMAGE. - * ==================================================================== - * - * This product includes cryptographic software written by Eric Young - * (eay@cryptsoft.com). This product includes software written by Tim - * Hudson (tjh@cryptsoft.com). - * - */ - - Original SSLeay License - ----------------------- - -/* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com) - * All rights reserved. - * - * This package is an SSL implementation written - * by Eric Young (eay@cryptsoft.com). - * The implementation was written so as to conform with Netscapes SSL. - * - * This library is free for commercial and non-commercial use as long as - * the following conditions are aheared to. The following conditions - * apply to all code found in this distribution, be it the RC4, RSA, - * lhash, DES, etc., code; not just the SSL code. The SSL documentation - * included with this distribution is covered by the same copyright terms - * except that the holder is Tim Hudson (tjh@cryptsoft.com). 
- * - * Copyright remains Eric Young's, and as such any Copyright notices in - * the code are not to be removed. - * If this package is used in a product, Eric Young should be given attribution - * as the author of the parts of the library used. - * This can be in the form of a textual message at program startup or - * in documentation (online or textual) provided with the package. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions - * are met: - * 1. Redistributions of source code must retain the copyright - * notice, this list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * 3. All advertising materials mentioning features or use of this software - * must display the following acknowledgement: - * "This product includes cryptographic software written by - * Eric Young (eay@cryptsoft.com)" - * The word 'cryptographic' can be left out if the rouines from the library - * being used are not cryptographic related :-). - * 4. If you include any Windows specific code (or a derivative thereof) from - * the apps directory (application code) you must include an acknowledgement: - * "This product includes software written by Tim Hudson (tjh@cryptsoft.com)" - * - * THIS SOFTWARE IS PROVIDED BY ERIC YOUNG ``AS IS'' AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS - * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) - * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT - * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY - * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF - * SUCH DAMAGE. - * - * The licence and distribution terms for any publically available version or - * derivative of this code cannot be changed. i.e. this code cannot simply be - * copied and put under another distribution licence - * [including the GNU Public Licence.] - */ - --------------------------------------------------------------------------------- - -This project includes code from the rtools-backports project. - -* ci/scripts/PKGBUILD and ci/scripts/r_windows_build.sh are based on code - from the rtools-backports project. - -Copyright: Copyright (c) 2013 - 2019, Алексей and Jeroen Ooms. -All rights reserved. -Homepage: https://github.com/r-windows/rtools-backports -License: 3-clause BSD - --------------------------------------------------------------------------------- - -Some code from pandas has been adapted for the pyarrow codebase. pandas is -available under the 3-clause BSD license, which follows: - -pandas license -============== - -Copyright (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team -All rights reserved. - -Copyright (c) 2008-2011 AQR Capital Management, LLC -All rights reserved. 
- -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - - * Neither the name of the copyright holder nor the names of any - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - --------------------------------------------------------------------------------- - -Some bits from DyND, in particular aspects of the build system, have been -adapted from libdynd and dynd-python under the terms of the BSD 2-clause -license - -The BSD 2-Clause License - - Copyright (C) 2011-12, Dynamic NDArray Developers - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials provided - with the distribution. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -Dynamic NDArray Developers list: - - * Mark Wiebe - * Continuum Analytics - --------------------------------------------------------------------------------- - -Some source code from Ibis (https://github.com/cloudera/ibis) has been adapted -for PyArrow. Ibis is released under the Apache License, Version 2.0. - --------------------------------------------------------------------------------- - -This project includes code from the autobrew project. 
- -* r/tools/autobrew and dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb - are based on code from the autobrew project. - -Copyright (c) 2019, Jeroen Ooms -License: MIT -Homepage: https://github.com/jeroen/autobrew - --------------------------------------------------------------------------------- - -dev/tasks/homebrew-formulae/apache-arrow.rb has the following license: - -BSD 2-Clause License - -Copyright (c) 2009-present, Homebrew contributors -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - ----------------------------------------------------------------------- - -cpp/src/arrow/vendored/base64.cpp has the following license - -ZLIB License - -Copyright (C) 2004-2017 René Nyffenegger - -This source code is provided 'as-is', without any express or implied -warranty. In no event will the author be held liable for any damages arising -from the use of this software. - -Permission is granted to anyone to use this software for any purpose, including -commercial applications, and to alter it and redistribute it freely, subject to -the following restrictions: - -1. The origin of this source code must not be misrepresented; you must not - claim that you wrote the original source code. If you use this source code - in a product, an acknowledgment in the product documentation would be - appreciated but is not required. - -2. Altered source versions must be plainly marked as such, and must not be - misrepresented as being the original source code. - -3. This notice may not be removed or altered from any source distribution. 
- -René Nyffenegger rene.nyffenegger@adp-gmbh.ch - --------------------------------------------------------------------------------- - -The file cpp/src/arrow/vendored/optional.hpp has the following license - -Boost Software License - Version 1.0 - August 17th, 2003 - -Permission is hereby granted, free of charge, to any person or organization -obtaining a copy of the software and accompanying documentation covered by -this license (the "Software") to use, reproduce, display, distribute, -execute, and transmit the Software, and to prepare derivative works of the -Software, and to permit third-parties to whom the Software is furnished to -do so, all subject to the following: - -The copyright notices in the Software and this entire statement, including -the above license grant, this restriction and the following disclaimer, -must be included in all copies of the Software, in whole or in part, and -all derivative works of the Software, unless such copies or derivative -works are solely in the form of machine-executable object code generated by -a source language processor. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT -SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE -FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. - --------------------------------------------------------------------------------- - -This project includes code from Folly. - - * cpp/src/arrow/vendored/ProducerConsumerQueue.h - -is based on Folly's - - * folly/Portability.h - * folly/lang/Align.h - * folly/ProducerConsumerQueue.h - -Copyright: Copyright (c) Facebook, Inc. and its affiliates. -Home page: https://github.com/facebook/folly -License: http://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -The file cpp/src/arrow/vendored/musl/strptime.c has the following license - -Copyright © 2005-2020 Rich Felker, et al. - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- --------------------------------------------------------------------------------- - -The file cpp/cmake_modules/BuildUtils.cmake contains code from - -https://gist.github.com/cristianadam/ef920342939a89fae3e8a85ca9459b49 - -which is made available under the MIT license - -Copyright (c) 2019 Cristian Adam - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - --------------------------------------------------------------------------------- - -The files in cpp/src/arrow/vendored/portable-snippets/ contain code from - -https://github.com/nemequ/portable-snippets - -and have the following copyright notice: - -Each source file contains a preamble explaining the license situation -for that file, which takes priority over this file. With the -exception of some code pulled in from other repositories (such as -µnit, an MIT-licensed project which is used for testing), the code is -public domain, released using the CC0 1.0 Universal dedication (*). - -(*) https://creativecommons.org/publicdomain/zero/1.0/legalcode - --------------------------------------------------------------------------------- - -The files in cpp/src/arrow/vendored/fast_float/ contain code from - -https://github.com/lemire/fast_float - -which is made available under the Apache License 2.0. - --------------------------------------------------------------------------------- - -The file python/pyarrow/vendored/version.py contains code from - -https://github.com/pypa/packaging/ - -which is made available under both the Apache license v2.0 and the -BSD 2-clause license. 
diff --git a/README.md b/README.md index 08385fb6c15d..55bdad6cb55c 100644 --- a/README.md +++ b/README.md @@ -25,11 +25,12 @@ Welcome to the implementation of Arrow, the popular in-memory columnar format, i This repo contains the following main components: -| Crate | Description | Documentation | -| ------------ | ------------------------------------------------------------------ | -------------------------- | -| arrow | Core functionality (memory layout, arrays, low level computations) | [(README)][arrow-readme] | -| parquet | Support for Parquet columnar file format | [(README)][parquet-readme] | -| arrow-flight | Support for Arrow-Flight IPC protocol | [(README)][flight-readme] | +| Crate | Description | Documentation | +| ------------ | ------------------------------------------------------------------------- | ------------------------------ | +| arrow | Core functionality (memory layout, arrays, low level computations) | [(README)][arrow-readme] | +| parquet | Support for Parquet columnar file format | [(README)][parquet-readme] | +| arrow-flight | Support for Arrow-Flight IPC protocol | [(README)][flight-readme] | +| object-store | Support for object store interactions (aws, azure, gcp, local, in-memory) | [(README)][objectstore-readme] | There are two related crates in a different repository @@ -51,7 +52,11 @@ You can find more details about each crate in their respective READMEs. The `dev@arrow.apache.org` mailing list serves as the core communication channel for the Arrow community. Instructions for signing up and links to the archives can be found at the [Arrow Community](https://arrow.apache.org/community/) page. All major announcements and communications happen there. The Rust Arrow community also uses the official [ASF Slack](https://s.apache.org/slack-invite) for informal discussions and coordination. This is -a great place to meet other contributors and get guidance on where to contribute. Join us in the `#arrow-rust` channel. +a great place to meet other contributors and get guidance on where to contribute. Join us in the `#arrow-rust` channel and feel free to ask for an invite via: + +1. the `dev@arrow.apache.org` mailing list +2. the [GitHub Discussions][discussions] +3. the [Discord channel](https://discord.gg/YAb2TdazKQ) Unlike other parts of the Arrow ecosystem, the Rust implementation uses [GitHub issues][issues] as the system of record for new features and bug fixes and this plays a critical role in the release process. @@ -67,4 +72,6 @@ There is more information in the [contributing] guide. 
[flight-readme]: arrow-flight/README.md [datafusion-readme]: https://github.com/apache/arrow-datafusion/blob/master/README.md [ballista-readme]: https://github.com/apache/arrow-ballista/blob/master/README.md +[objectstore-readme]: https://github.com/apache/arrow-rs/blob/master/object_store/README.md [issues]: https://github.com/apache/arrow-rs/issues +[discussions]: https://github.com/apache/arrow-rs/discussions diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index eb8374156e7f..ecf02625c9d3 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -18,22 +18,22 @@ [package] name = "arrow-flight" description = "Apache Arrow Flight" -version = "18.0.0" +version = "22.0.0" edition = "2021" -rust-version = "1.57" +rust-version = "1.62" authors = ["Apache Arrow "] homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" license = "Apache-2.0" [dependencies] -arrow = { path = "../arrow", version = "18.0.0", default-features = false, features = ["ipc"] } +arrow = { path = "../arrow", version = "22.0.0", default-features = false, features = ["ipc"] } base64 = { version = "0.13", default-features = false } -tonic = { version = "0.7", default-features = false, features = ["transport", "codegen", "prost"] } +tonic = { version = "0.8", default-features = false, features = ["transport", "codegen", "prost"] } bytes = { version = "1", default-features = false } -prost = { version = "0.10", default-features = false } -prost-types = { version = "0.10.0", default-features = false, optional = true } -prost-derive = { version = "0.10", default-features = false } +prost = { version = "0.11", default-features = false } +prost-types = { version = "0.11.0", default-features = false, optional = true } +prost-derive = { version = "0.11", default-features = false } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] } futures = { version = "0.3", default-features = false, features = ["alloc"]} @@ -44,7 +44,7 @@ flight-sql-experimental = ["prost-types"] [dev-dependencies] [build-dependencies] -tonic-build = { version = "0.7", default-features = false, features = ["transport", "prost"] } +tonic-build = { version = "0.8", default-features = false, features = ["transport", "prost"] } # Pin specific version of the tonic-build dependencies to avoid auto-generated # (and checked in) arrow.flight.protocol.rs from changing proc-macro2 = { version = ">1.0.30", default-features = false } diff --git a/arrow-flight/README.md b/arrow-flight/README.md index 9f835a8dc357..9e9a18ad4789 100644 --- a/arrow-flight/README.md +++ b/arrow-flight/README.md @@ -27,7 +27,7 @@ Add this to your Cargo.toml: ```toml [dependencies] -arrow-flight = "18.0.0" +arrow-flight = "22.0.0" ``` Apache Arrow Flight is a gRPC based protocol for exchanging Arrow data between processes. See the blog post [Introducing Apache Arrow Flight: A Framework for Fast Data Transport](https://arrow.apache.org/blog/2019/10/13/introducing-arrow-flight/) for more information. diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 8b4fe477b868..aa0d407113d7 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -16,9 +16,11 @@ // under the License. 
use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; -use arrow_flight::FlightData; +use arrow_flight::{Action, FlightData, HandshakeRequest, HandshakeResponse, Ticket}; +use futures::Stream; +use std::pin::Pin; use tonic::transport::Server; -use tonic::{Response, Status, Streaming}; +use tonic::{Request, Response, Status, Streaming}; use arrow_flight::{ flight_service_server::FlightService, @@ -41,183 +43,303 @@ pub struct FlightSqlServiceImpl {} #[tonic::async_trait] impl FlightSqlService for FlightSqlServiceImpl { type FlightService = FlightSqlServiceImpl; + + async fn do_handshake( + &self, + request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + let basic = "Basic "; + let authorization = request + .metadata() + .get("authorization") + .ok_or(Status::invalid_argument("authorization field not present"))? + .to_str() + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + if !authorization.starts_with(basic) { + Err(Status::invalid_argument(format!( + "Auth type not implemented: {}", + authorization + )))?; + } + let base64 = &authorization[basic.len()..]; + let bytes = base64::decode(base64) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let str = String::from_utf8(bytes) + .map_err(|_| Status::invalid_argument("authorization not parsable"))?; + let parts: Vec<_> = str.split(":").collect(); + if parts.len() != 2 { + Err(Status::invalid_argument(format!( + "Invalid authorization header" + )))?; + } + let user = parts[0]; + let pass = parts[1]; + if user != "admin" || pass != "password" { + Err(Status::unauthenticated("Invalid credentials!"))? + } + let result = HandshakeResponse { + protocol_version: 0, + payload: "random_uuid_token".as_bytes().to_vec(), + }; + let result = Ok(result); + let output = futures::stream::iter(vec![result]); + return Ok(Response::new(Box::pin(output))); + } + // get_flight_info async fn get_flight_info_statement( &self, _query: CommandStatementQuery, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_statement not implemented", + )) } + async fn get_flight_info_prepared_statement( &self, _query: CommandPreparedStatementQuery, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_prepared_statement not implemented", + )) } + async fn get_flight_info_catalogs( &self, _query: CommandGetCatalogs, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_catalogs not implemented", + )) } + async fn get_flight_info_schemas( &self, _query: CommandGetDbSchemas, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_schemas not implemented", + )) } + async fn get_flight_info_tables( &self, _query: CommandGetTables, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_tables not implemented", + )) } + async fn get_flight_info_table_types( &self, _query: CommandGetTableTypes, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet 
implemented")) + Err(Status::unimplemented( + "get_flight_info_table_types not implemented", + )) } + async fn get_flight_info_sql_info( &self, _query: CommandGetSqlInfo, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_sql_info not implemented", + )) } + async fn get_flight_info_primary_keys( &self, _query: CommandGetPrimaryKeys, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_primary_keys not implemented", + )) } + async fn get_flight_info_exported_keys( &self, _query: CommandGetExportedKeys, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_exported_keys not implemented", + )) } + async fn get_flight_info_imported_keys( &self, _query: CommandGetImportedKeys, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_imported_keys not implemented", + )) } + async fn get_flight_info_cross_reference( &self, _query: CommandGetCrossReference, - _request: FlightDescriptor, + _request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "get_flight_info_imported_keys not implemented", + )) } + // do_get async fn do_get_statement( &self, _ticket: TicketStatementQuery, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_statement not implemented")) } async fn do_get_prepared_statement( &self, _query: CommandPreparedStatementQuery, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_get_prepared_statement not implemented", + )) } + async fn do_get_catalogs( &self, _query: CommandGetCatalogs, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_catalogs not implemented")) } + async fn do_get_schemas( &self, _query: CommandGetDbSchemas, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_schemas not implemented")) } + async fn do_get_tables( &self, _query: CommandGetTables, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_tables not implemented")) } + async fn do_get_table_types( &self, _query: CommandGetTableTypes, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_table_types not implemented")) } + async fn do_get_sql_info( &self, _query: CommandGetSqlInfo, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_sql_info not implemented")) } + async fn do_get_primary_keys( &self, _query: CommandGetPrimaryKeys, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented("do_get_primary_keys not implemented")) } + async fn 
do_get_exported_keys( &self, _query: CommandGetExportedKeys, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_get_exported_keys not implemented", + )) } + async fn do_get_imported_keys( &self, _query: CommandGetImportedKeys, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_get_imported_keys not implemented", + )) } + async fn do_get_cross_reference( &self, _query: CommandGetCrossReference, + _request: Request, ) -> Result::DoGetStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_get_cross_reference not implemented", + )) } + // do_put async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, + _request: Request>, ) -> Result { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_put_statement_update not implemented", + )) } + async fn do_put_prepared_statement_query( &self, _query: CommandPreparedStatementQuery, - _request: Streaming, + _request: Request>, ) -> Result::DoPutStream>, Status> { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_put_prepared_statement_query not implemented", + )) } + async fn do_put_prepared_statement_update( &self, _query: CommandPreparedStatementUpdate, - _request: Streaming, + _request: Request>, ) -> Result { - Err(Status::unimplemented("Not yet implemented")) + Err(Status::unimplemented( + "do_put_prepared_statement_update not implemented", + )) } + // do_action async fn do_action_create_prepared_statement( &self, _query: ActionCreatePreparedStatementRequest, + _request: Request, ) -> Result { Err(Status::unimplemented("Not yet implemented")) } async fn do_action_close_prepared_statement( &self, _query: ActionClosePreparedStatementRequest, + _request: Request, ) { unimplemented!("Not yet implemented") } diff --git a/arrow-flight/src/arrow.flight.protocol.rs b/arrow-flight/src/arrow.flight.protocol.rs index c76469b39ce7..2b085d6d1f6b 100644 --- a/arrow-flight/src/arrow.flight.protocol.rs +++ b/arrow-flight/src/arrow.flight.protocol.rs @@ -1,31 +1,31 @@ // This file was automatically generated through the build.rs script, and should not be edited. /// -/// The request that a client provides to a server on handshake. +/// The request that a client provides to a server on handshake. #[derive(Clone, PartialEq, ::prost::Message)] pub struct HandshakeRequest { /// - /// A defined protocol version + /// A defined protocol version #[prost(uint64, tag="1")] pub protocol_version: u64, /// - /// Arbitrary auth/handshake info. + /// Arbitrary auth/handshake info. #[prost(bytes="vec", tag="2")] pub payload: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct HandshakeResponse { /// - /// A defined protocol version + /// A defined protocol version #[prost(uint64, tag="1")] pub protocol_version: u64, /// - /// Arbitrary auth/handshake info. + /// Arbitrary auth/handshake info. #[prost(bytes="vec", tag="2")] pub payload: ::prost::alloc::vec::Vec, } /// -/// A message for doing simple auth. +/// A message for doing simple auth. 
#[derive(Clone, PartialEq, ::prost::Message)] pub struct BasicAuth { #[prost(string, tag="2")] @@ -37,8 +37,8 @@ pub struct BasicAuth { pub struct Empty { } /// -/// Describes an available action, including both the name used for execution -/// along with a short description of the purpose of the action. +/// Describes an available action, including both the name used for execution +/// along with a short description of the purpose of the action. #[derive(Clone, PartialEq, ::prost::Message)] pub struct ActionType { #[prost(string, tag="1")] @@ -47,15 +47,15 @@ pub struct ActionType { pub description: ::prost::alloc::string::String, } /// -/// A service specific expression that can be used to return a limited set -/// of available Arrow Flight streams. +/// A service specific expression that can be used to return a limited set +/// of available Arrow Flight streams. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Criteria { #[prost(bytes="vec", tag="1")] pub expression: ::prost::alloc::vec::Vec, } /// -/// An opaque action specific for the service. +/// An opaque action specific for the service. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Action { #[prost(string, tag="1")] @@ -64,138 +64,151 @@ pub struct Action { pub body: ::prost::alloc::vec::Vec, } /// -/// An opaque result returned after executing an action. +/// An opaque result returned after executing an action. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Result { #[prost(bytes="vec", tag="1")] pub body: ::prost::alloc::vec::Vec, } /// -/// Wrap the result of a getSchema call +/// Wrap the result of a getSchema call #[derive(Clone, PartialEq, ::prost::Message)] pub struct SchemaResult { - /// schema of the dataset as described in Schema.fbs::Schema. + /// schema of the dataset as described in Schema.fbs::Schema. #[prost(bytes="vec", tag="1")] pub schema: ::prost::alloc::vec::Vec, } /// -/// The name or tag for a Flight. May be used as a way to retrieve or generate -/// a flight or be used to expose a set of previously defined flights. +/// The name or tag for a Flight. May be used as a way to retrieve or generate +/// a flight or be used to expose a set of previously defined flights. #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightDescriptor { #[prost(enumeration="flight_descriptor::DescriptorType", tag="1")] pub r#type: i32, /// - /// Opaque value used to express a command. Should only be defined when - /// type = CMD. + /// Opaque value used to express a command. Should only be defined when + /// type = CMD. #[prost(bytes="vec", tag="2")] pub cmd: ::prost::alloc::vec::Vec, /// - /// List of strings identifying a particular dataset. Should only be defined - /// when type = PATH. + /// List of strings identifying a particular dataset. Should only be defined + /// when type = PATH. #[prost(string, repeated, tag="3")] pub path: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } /// Nested message and enum types in `FlightDescriptor`. pub mod flight_descriptor { /// - /// Describes what type of descriptor is defined. + /// Describes what type of descriptor is defined. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum DescriptorType { - /// Protobuf pattern, not used. + /// Protobuf pattern, not used. Unknown = 0, /// - /// A named path that identifies a dataset. A path is composed of a string - /// or list of strings describing a particular dataset. This is conceptually - /// similar to a path inside a filesystem. 
+ /// A named path that identifies a dataset. A path is composed of a string + /// or list of strings describing a particular dataset. This is conceptually + /// similar to a path inside a filesystem. Path = 1, /// - /// An opaque command to generate a dataset. + /// An opaque command to generate a dataset. Cmd = 2, } + impl DescriptorType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + DescriptorType::Unknown => "UNKNOWN", + DescriptorType::Path => "PATH", + DescriptorType::Cmd => "CMD", + } + } + } } /// -/// The access coordinates for retrieval of a dataset. With a FlightInfo, a -/// consumer is able to determine how to retrieve a dataset. +/// The access coordinates for retrieval of a dataset. With a FlightInfo, a +/// consumer is able to determine how to retrieve a dataset. #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightInfo { - /// schema of the dataset as described in Schema.fbs::Schema. + /// schema of the dataset as described in Schema.fbs::Schema. #[prost(bytes="vec", tag="1")] pub schema: ::prost::alloc::vec::Vec, /// - /// The descriptor associated with this info. + /// The descriptor associated with this info. #[prost(message, optional, tag="2")] pub flight_descriptor: ::core::option::Option, /// - /// A list of endpoints associated with the flight. To consume the whole - /// flight, all endpoints must be consumed. + /// A list of endpoints associated with the flight. To consume the whole + /// flight, all endpoints must be consumed. #[prost(message, repeated, tag="3")] pub endpoint: ::prost::alloc::vec::Vec, - /// Set these to -1 if unknown. + /// Set these to -1 if unknown. #[prost(int64, tag="4")] pub total_records: i64, #[prost(int64, tag="5")] pub total_bytes: i64, } /// -/// A particular stream or split associated with a flight. +/// A particular stream or split associated with a flight. #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightEndpoint { /// - /// Token used to retrieve this stream. + /// Token used to retrieve this stream. #[prost(message, optional, tag="1")] pub ticket: ::core::option::Option, /// - /// A list of URIs where this ticket can be redeemed. If the list is - /// empty, the expectation is that the ticket can only be redeemed on the - /// current service where the ticket was generated. + /// A list of URIs where this ticket can be redeemed. If the list is + /// empty, the expectation is that the ticket can only be redeemed on the + /// current service where the ticket was generated. #[prost(message, repeated, tag="2")] pub location: ::prost::alloc::vec::Vec, } /// -/// A location where a Flight service will accept retrieval of a particular -/// stream given a ticket. +/// A location where a Flight service will accept retrieval of a particular +/// stream given a ticket. #[derive(Clone, PartialEq, ::prost::Message)] pub struct Location { #[prost(string, tag="1")] pub uri: ::prost::alloc::string::String, } /// -/// An opaque identifier that the service can use to retrieve a particular -/// portion of a stream. +/// An opaque identifier that the service can use to retrieve a particular +/// portion of a stream. 
#[derive(Clone, PartialEq, ::prost::Message)] pub struct Ticket { #[prost(bytes="vec", tag="1")] pub ticket: ::prost::alloc::vec::Vec, } /// -/// A batch of Arrow data as part of a stream of batches. +/// A batch of Arrow data as part of a stream of batches. #[derive(Clone, PartialEq, ::prost::Message)] pub struct FlightData { /// - /// The descriptor of the data. This is only relevant when a client is - /// starting a new DoPut stream. + /// The descriptor of the data. This is only relevant when a client is + /// starting a new DoPut stream. #[prost(message, optional, tag="1")] pub flight_descriptor: ::core::option::Option, /// - /// Header for message data as described in Message.fbs::Message. + /// Header for message data as described in Message.fbs::Message. #[prost(bytes="vec", tag="2")] pub data_header: ::prost::alloc::vec::Vec, /// - /// Application-defined metadata. + /// Application-defined metadata. #[prost(bytes="vec", tag="3")] pub app_metadata: ::prost::alloc::vec::Vec, /// - /// The actual batch of Arrow data. Preferably handled with minimal-copies - /// coming last in the definition to help with sidecar patterns (it is - /// expected that some implementations will fetch this field off the wire - /// with specialized code to avoid extra memory copies). + /// The actual batch of Arrow data. Preferably handled with minimal-copies + /// coming last in the definition to help with sidecar patterns (it is + /// expected that some implementations will fetch this field off the wire + /// with specialized code to avoid extra memory copies). #[prost(bytes="vec", tag="1000")] pub data_body: ::prost::alloc::vec::Vec, } -///* -/// The response message associated with the submission of a DoPut. +/// * +/// The response message associated with the submission of a DoPut. #[derive(Clone, PartialEq, ::prost::Message)] pub struct PutResult { #[prost(bytes="vec", tag="1")] @@ -205,6 +218,7 @@ pub struct PutResult { pub mod flight_service_client { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] use tonic::codegen::*; + use tonic::codegen::http::Uri; /// /// A flight service is an endpoint for retrieving or storing Arrow data. A /// flight service can expose one or more predefined endpoints that can be @@ -236,6 +250,10 @@ pub mod flight_service_client { let inner = tonic::client::Grpc::new(inner); Self { inner } } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } pub fn with_interceptor( inner: T, interceptor: F, @@ -255,19 +273,19 @@ pub mod flight_service_client { { FlightServiceClient::new(InterceptedService::new(inner, interceptor)) } - /// Compress requests with `gzip`. + /// Compress requests with the given encoding. /// /// This requires the server to support it otherwise it might respond with an /// error. #[must_use] - pub fn send_gzip(mut self) -> Self { - self.inner = self.inner.send_gzip(); + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); self } - /// Enable decompressing responses with `gzip`. + /// Enable decompressing responses. 
#[must_use] - pub fn accept_gzip(mut self) -> Self { - self.inner = self.inner.accept_gzip(); + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); self } /// @@ -672,8 +690,8 @@ pub mod flight_service_server { #[derive(Debug)] pub struct FlightServiceServer { inner: _Inner, - accept_compression_encodings: (), - send_compression_encodings: (), + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, } struct _Inner(Arc); impl FlightServiceServer { @@ -697,6 +715,18 @@ pub mod flight_service_server { { InterceptedService::new(Self::new(inner), interceptor) } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } } impl tonic::codegen::Service> for FlightServiceServer where @@ -1108,7 +1138,7 @@ pub mod flight_service_server { write!(f, "{:?}", self.0) } } - impl tonic::transport::NamedService for FlightServiceServer { + impl tonic::server::NamedService for FlightServiceServer { const NAME: &'static str = "arrow.flight.protocol.FlightService"; } } diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 5cfbd3f60657..3f4f09855353 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -28,6 +28,7 @@ use std::{ ops::Deref, }; +#[allow(clippy::derive_partial_eq_without_eq)] mod gen { include!("arrow.flight.protocol.rs"); } diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs index ea378a0a2577..77221dd1a489 100644 --- a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs +++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs @@ -1,1008 +1,1099 @@ // This file was automatically generated through the build.rs script, and should not be edited. /// -/// Represents a metadata request. Used in the command member of FlightDescriptor -/// for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the metadata request. +/// Represents a metadata request. Used in the command member of FlightDescriptor +/// for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the metadata request. /// -/// The returned Arrow schema will be: -/// < -/// info_name: uint32 not null, -/// value: dense_union< -/// string_value: utf8, -/// bool_value: bool, -/// bigint_value: int64, -/// int32_bitmask: int32, -/// string_list: list -/// int32_to_int32_list_map: map> -/// > -/// where there is one row per requested piece of metadata information. +/// The returned Arrow schema will be: +/// < +/// info_name: uint32 not null, +/// value: dense_union< +/// string_value: utf8, +/// bool_value: bool, +/// bigint_value: int64, +/// int32_bitmask: int32, +/// string_list: list +/// int32_to_int32_list_map: map> +/// > +/// where there is one row per requested piece of metadata information. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetSqlInfo { /// - /// Values are modelled after ODBC's SQLGetInfo() function. 
This information is intended to provide - /// Flight SQL clients with basic, SQL syntax and SQL functions related information. - /// More information types can be added in future releases. - /// E.g. more SQL syntax support types, scalar functions support, type conversion support etc. + /// Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide + /// Flight SQL clients with basic, SQL syntax and SQL functions related information. + /// More information types can be added in future releases. + /// E.g. more SQL syntax support types, scalar functions support, type conversion support etc. /// - /// Note that the set of metadata may expand. + /// Note that the set of metadata may expand. /// - /// Initially, Flight SQL will support the following information types: - /// - Server Information - Range [0-500) - /// - Syntax Information - Range [500-1000) - /// Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). - /// Custom options should start at 10,000. + /// Initially, Flight SQL will support the following information types: + /// - Server Information - Range [0-500) + /// - Syntax Information - Range [500-1000) + /// Range [0-10,000) is reserved for defaults (see SqlInfo enum for default options). + /// Custom options should start at 10,000. /// - /// If omitted, then all metadata will be retrieved. - /// Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must - /// at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. - /// If additional metadata is included, the metadata IDs should start from 10,000. + /// If omitted, then all metadata will be retrieved. + /// Flight SQL Servers may choose to include additional metadata above and beyond the specified set, however they must + /// at least return the specified set. IDs ranging from 0 to 10,000 (exclusive) are reserved for future use. + /// If additional metadata is included, the metadata IDs should start from 10,000. #[prost(uint32, repeated, tag="1")] pub info: ::prost::alloc::vec::Vec, } /// -/// Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. -/// The definition of a catalog depends on vendor/implementation. It is usually the database itself -/// Used in the command member of FlightDescriptor for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the catalog metadata request. +/// Represents a request to retrieve the list of catalogs on a Flight SQL enabled backend. +/// The definition of a catalog depends on vendor/implementation. It is usually the database itself +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < -/// catalog_name: utf8 not null -/// > -/// The returned data should be ordered by catalog_name. +/// The returned Arrow schema will be: +/// < +/// catalog_name: utf8 not null +/// > +/// The returned data should be ordered by catalog_name. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetCatalogs { } /// -/// Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. -/// The definition of a database schema depends on vendor/implementation. It is usually a collection of tables. 
-/// Used in the command member of FlightDescriptor for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the catalog metadata request. +/// Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. +/// The definition of a database schema depends on vendor/implementation. It is usually a collection of tables. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < -/// catalog_name: utf8, -/// db_schema_name: utf8 not null -/// > -/// The returned data should be ordered by catalog_name, then db_schema_name. +/// The returned Arrow schema will be: +/// < +/// catalog_name: utf8, +/// db_schema_name: utf8 not null +/// > +/// The returned data should be ordered by catalog_name, then db_schema_name. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetDbSchemas { /// - /// Specifies the Catalog to search for the tables. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. + /// Specifies the Catalog to search for the tables. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. #[prost(string, optional, tag="1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for schemas to search for. - /// When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. - /// In the pattern string, two special characters can be used to denote matching rules: - /// - "%" means to match any substring with 0 or more characters. - /// - "_" means to match any one character. + /// Specifies a filter pattern for schemas to search for. + /// When no db_schema_filter_pattern is provided, the pattern will not be used to narrow the search. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. #[prost(string, optional, tag="2")] pub db_schema_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, } /// -/// Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the catalog metadata request. +/// Represents a request to retrieve the list of tables, and optionally their schemas, on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < -/// catalog_name: utf8, -/// db_schema_name: utf8, -/// table_name: utf8 not null, -/// table_type: utf8 not null, -/// \[optional\] table_schema: bytes not null (schema of the table as described in Schema.fbs::Schema, -/// it is serialized as an IPC message.) -/// > -/// The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. 
+/// The returned Arrow schema will be: +/// < +/// catalog_name: utf8, +/// db_schema_name: utf8, +/// table_name: utf8 not null, +/// table_type: utf8 not null, +/// \[optional\] table_schema: bytes not null (schema of the table as described in Schema.fbs::Schema, +/// it is serialized as an IPC message.) +/// > +/// The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetTables { /// - /// Specifies the Catalog to search for the tables. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. + /// Specifies the Catalog to search for the tables. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. #[prost(string, optional, tag="1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for schemas to search for. - /// When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. - /// In the pattern string, two special characters can be used to denote matching rules: - /// - "%" means to match any substring with 0 or more characters. - /// - "_" means to match any one character. + /// Specifies a filter pattern for schemas to search for. + /// When no db_schema_filter_pattern is provided, all schemas matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. #[prost(string, optional, tag="2")] pub db_schema_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter pattern for tables to search for. - /// When no table_name_filter_pattern is provided, all tables matching other filters are searched. - /// In the pattern string, two special characters can be used to denote matching rules: - /// - "%" means to match any substring with 0 or more characters. - /// - "_" means to match any one character. + /// Specifies a filter pattern for tables to search for. + /// When no table_name_filter_pattern is provided, all tables matching other filters are searched. + /// In the pattern string, two special characters can be used to denote matching rules: + /// - "%" means to match any substring with 0 or more characters. + /// - "_" means to match any one character. #[prost(string, optional, tag="3")] pub table_name_filter_pattern: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies a filter of table types which must match. - /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. - /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. + /// Specifies a filter of table types which must match. + /// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. + /// TABLE, VIEW, and SYSTEM TABLE are commonly supported. #[prost(string, repeated, tag="4")] pub table_types: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// Specifies if the Arrow schema should be returned for found tables. + /// Specifies if the Arrow schema should be returned for found tables. 
#[prost(bool, tag="5")] pub include_schema: bool, } /// -/// Represents a request to retrieve the list of table types on a Flight SQL enabled backend. -/// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. -/// TABLE, VIEW, and SYSTEM TABLE are commonly supported. -/// Used in the command member of FlightDescriptor for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the catalog metadata request. +/// Represents a request to retrieve the list of table types on a Flight SQL enabled backend. +/// The table types depend on vendor/implementation. It is usually used to separate tables from views or system tables. +/// TABLE, VIEW, and SYSTEM TABLE are commonly supported. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < -/// table_type: utf8 not null -/// > -/// The returned data should be ordered by table_type. +/// The returned Arrow schema will be: +/// < +/// table_type: utf8 not null +/// > +/// The returned data should be ordered by table_type. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetTableTypes { } /// -/// Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the catalog metadata request. +/// Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < -/// catalog_name: utf8, -/// db_schema_name: utf8, -/// table_name: utf8 not null, -/// column_name: utf8 not null, -/// key_name: utf8, -/// key_sequence: int not null -/// > -/// The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. +/// The returned Arrow schema will be: +/// < +/// catalog_name: utf8, +/// db_schema_name: utf8, +/// table_name: utf8 not null, +/// column_name: utf8 not null, +/// key_name: utf8, +/// key_sequence: int not null +/// > +/// The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetPrimaryKeys { /// - /// Specifies the catalog to search for the table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. + /// Specifies the catalog to search for the table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. #[prost(string, optional, tag="1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. + /// Specifies the schema to search for the table. + /// An empty string retrieves those without a schema. 
+ /// If omitted the schema name should not be used to narrow the search. #[prost(string, optional, tag="2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the table to get the primary keys for. + /// Specifies the table to get the primary keys for. #[prost(string, tag="3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve a description of the foreign key columns that reference the given table's -/// primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the catalog metadata request. +/// Represents a request to retrieve a description of the foreign key columns that reference the given table's +/// primary key columns (the foreign keys exported by a table) of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < -/// pk_catalog_name: utf8, -/// pk_db_schema_name: utf8, -/// pk_table_name: utf8 not null, -/// pk_column_name: utf8 not null, -/// fk_catalog_name: utf8, -/// fk_db_schema_name: utf8, -/// fk_table_name: utf8 not null, -/// fk_column_name: utf8 not null, -/// key_sequence: int not null, -/// fk_key_name: utf8, -/// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. +/// The returned Arrow schema will be: +/// < +/// pk_catalog_name: utf8, +/// pk_db_schema_name: utf8, +/// pk_table_name: utf8 not null, +/// pk_column_name: utf8 not null, +/// fk_catalog_name: utf8, +/// fk_db_schema_name: utf8, +/// fk_table_name: utf8 not null, +/// fk_column_name: utf8 not null, +/// key_sequence: int not null, +/// fk_key_name: utf8, +/// pk_key_name: utf8, +/// update_rule: uint1 not null, +/// delete_rule: uint1 not null +/// > +/// The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetExportedKeys { /// - /// Specifies the catalog to search for the foreign key table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. + /// Specifies the catalog to search for the foreign key table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. #[prost(string, optional, tag="1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the foreign key table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. + /// Specifies the schema to search for the foreign key table. + /// An empty string retrieves those without a schema. 
+ /// If omitted the schema name should not be used to narrow the search. #[prost(string, optional, tag="2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the foreign key table to get the foreign keys for. + /// Specifies the foreign key table to get the foreign keys for. #[prost(string, tag="3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the catalog metadata request. +/// Represents a request to retrieve the foreign keys of a table on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < -/// pk_catalog_name: utf8, -/// pk_db_schema_name: utf8, -/// pk_table_name: utf8 not null, -/// pk_column_name: utf8 not null, -/// fk_catalog_name: utf8, -/// fk_db_schema_name: utf8, -/// fk_table_name: utf8 not null, -/// fk_column_name: utf8 not null, -/// key_sequence: int not null, -/// fk_key_name: utf8, -/// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions: -/// - 0 = CASCADE -/// - 1 = RESTRICT -/// - 2 = SET NULL -/// - 3 = NO ACTION -/// - 4 = SET DEFAULT +/// The returned Arrow schema will be: +/// < +/// pk_catalog_name: utf8, +/// pk_db_schema_name: utf8, +/// pk_table_name: utf8 not null, +/// pk_column_name: utf8 not null, +/// fk_catalog_name: utf8, +/// fk_db_schema_name: utf8, +/// fk_table_name: utf8 not null, +/// fk_column_name: utf8 not null, +/// key_sequence: int not null, +/// fk_key_name: utf8, +/// pk_key_name: utf8, +/// update_rule: uint1 not null, +/// delete_rule: uint1 not null +/// > +/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions: +/// - 0 = CASCADE +/// - 1 = RESTRICT +/// - 2 = SET NULL +/// - 3 = NO ACTION +/// - 4 = SET DEFAULT #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetImportedKeys { /// - /// Specifies the catalog to search for the primary key table. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. + /// Specifies the catalog to search for the primary key table. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. #[prost(string, optional, tag="1")] pub catalog: ::core::option::Option<::prost::alloc::string::String>, /// - /// Specifies the schema to search for the primary key table. - /// An empty string retrieves those without a schema. - /// If omitted the schema name should not be used to narrow the search. + /// Specifies the schema to search for the primary key table. + /// An empty string retrieves those without a schema. + /// If omitted the schema name should not be used to narrow the search. 
#[prost(string, optional, tag="2")] pub db_schema: ::core::option::Option<::prost::alloc::string::String>, - /// Specifies the primary key table to get the foreign keys for. + /// Specifies the primary key table to get the foreign keys for. #[prost(string, tag="3")] pub table: ::prost::alloc::string::String, } /// -/// Represents a request to retrieve a description of the foreign key columns in the given foreign key table that -/// reference the primary key or the columns representing a unique constraint of the parent table (could be the same -/// or a different table) on a Flight SQL enabled backend. -/// Used in the command member of FlightDescriptor for the following RPC calls: -/// - GetSchema: return the Arrow schema of the query. -/// - GetFlightInfo: execute the catalog metadata request. +/// Represents a request to retrieve a description of the foreign key columns in the given foreign key table that +/// reference the primary key or the columns representing a unique constraint of the parent table (could be the same +/// or a different table) on a Flight SQL enabled backend. +/// Used in the command member of FlightDescriptor for the following RPC calls: +/// - GetSchema: return the Arrow schema of the query. +/// - GetFlightInfo: execute the catalog metadata request. /// -/// The returned Arrow schema will be: -/// < -/// pk_catalog_name: utf8, -/// pk_db_schema_name: utf8, -/// pk_table_name: utf8 not null, -/// pk_column_name: utf8 not null, -/// fk_catalog_name: utf8, -/// fk_db_schema_name: utf8, -/// fk_table_name: utf8 not null, -/// fk_column_name: utf8 not null, -/// key_sequence: int not null, -/// fk_key_name: utf8, -/// pk_key_name: utf8, -/// update_rule: uint1 not null, -/// delete_rule: uint1 not null -/// > -/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. -/// update_rule and delete_rule returns a byte that is equivalent to actions: -/// - 0 = CASCADE -/// - 1 = RESTRICT -/// - 2 = SET NULL -/// - 3 = NO ACTION -/// - 4 = SET DEFAULT +/// The returned Arrow schema will be: +/// < +/// pk_catalog_name: utf8, +/// pk_db_schema_name: utf8, +/// pk_table_name: utf8 not null, +/// pk_column_name: utf8 not null, +/// fk_catalog_name: utf8, +/// fk_db_schema_name: utf8, +/// fk_table_name: utf8 not null, +/// fk_column_name: utf8 not null, +/// key_sequence: int not null, +/// fk_key_name: utf8, +/// pk_key_name: utf8, +/// update_rule: uint1 not null, +/// delete_rule: uint1 not null +/// > +/// The returned data should be ordered by pk_catalog_name, pk_db_schema_name, pk_table_name, pk_key_name, then key_sequence. +/// update_rule and delete_rule returns a byte that is equivalent to actions: +/// - 0 = CASCADE +/// - 1 = RESTRICT +/// - 2 = SET NULL +/// - 3 = NO ACTION +/// - 4 = SET DEFAULT #[derive(Clone, PartialEq, ::prost::Message)] pub struct CommandGetCrossReference { - ///* - /// The catalog name where the parent table is. - /// An empty string retrieves those without a catalog. - /// If omitted the catalog name should not be used to narrow the search. + /// * + /// The catalog name where the parent table is. + /// An empty string retrieves those without a catalog. + /// If omitted the catalog name should not be used to narrow the search. #[prost(string, optional, tag="1")] pub pk_catalog: ::core::option::Option<::prost::alloc::string::String>, - ///* - /// The Schema name where the parent table is. - /// An empty string retrieves those without a schema. 
- /// If omitted the schema name should not be used to narrow the search.
+ /// *
+ /// The Schema name where the parent table is.
+ /// An empty string retrieves those without a schema.
+ /// If omitted the schema name should not be used to narrow the search.
#[prost(string, optional, tag="2")]
pub pk_db_schema: ::core::option::Option<::prost::alloc::string::String>,
- ///*
- /// The parent table name. It cannot be null.
+ /// *
+ /// The parent table name. It cannot be null.
#[prost(string, tag="3")]
pub pk_table: ::prost::alloc::string::String,
- ///*
- /// The catalog name where the foreign table is.
- /// An empty string retrieves those without a catalog.
- /// If omitted the catalog name should not be used to narrow the search.
+ /// *
+ /// The catalog name where the foreign table is.
+ /// An empty string retrieves those without a catalog.
+ /// If omitted the catalog name should not be used to narrow the search.
#[prost(string, optional, tag="4")]
pub fk_catalog: ::core::option::Option<::prost::alloc::string::String>,
- ///*
- /// The schema name where the foreign table is.
- /// An empty string retrieves those without a schema.
- /// If omitted the schema name should not be used to narrow the search.
+ /// *
+ /// The schema name where the foreign table is.
+ /// An empty string retrieves those without a schema.
+ /// If omitted the schema name should not be used to narrow the search.
#[prost(string, optional, tag="5")]
pub fk_db_schema: ::core::option::Option<::prost::alloc::string::String>,
- ///*
- /// The foreign table name. It cannot be null.
+ /// *
+ /// The foreign table name. It cannot be null.
#[prost(string, tag="6")]
pub fk_table: ::prost::alloc::string::String,
}
-// SQL Execution Action Messages
+// SQL Execution Action Messages
///
-/// Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend.
+/// Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ActionCreatePreparedStatementRequest {
- /// The valid SQL string to create a prepared statement for.
+ /// The valid SQL string to create a prepared statement for.
#[prost(string, tag="1")]
pub query: ::prost::alloc::string::String,
}
///
-/// Wrap the result of a "GetPreparedStatement" action.
+/// Wrap the result of a "CreatePreparedStatement" action.
///
-/// The resultant PreparedStatement can be closed either:
-/// - Manually, through the "ClosePreparedStatement" action;
-/// - Automatically, by a server timeout.
+/// The resultant PreparedStatement can be closed either:
+/// - Manually, through the "ClosePreparedStatement" action;
+/// - Automatically, by a server timeout.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ActionCreatePreparedStatementResult {
- /// Opaque handle for the prepared statement on the server.
+ /// Opaque handle for the prepared statement on the server.
#[prost(bytes="vec", tag="1")]
pub prepared_statement_handle: ::prost::alloc::vec::Vec<u8>,
- /// If a result set generating query was provided, dataset_schema contains the
- /// schema of the dataset as described in Schema.fbs::Schema, it is serialized as an IPC message.
+ /// If a result set generating query was provided, dataset_schema contains the
+ /// schema of the dataset as described in Schema.fbs::Schema, it is serialized as an IPC message.
#[prost(bytes="vec", tag="2")]
pub dataset_schema: ::prost::alloc::vec::Vec<u8>,
- /// If the query provided contained parameters, parameter_schema contains the
- /// schema of the expected parameters as described in Schema.fbs::Schema, it is serialized as an IPC message.
+ /// If the query provided contained parameters, parameter_schema contains the
+ /// schema of the expected parameters as described in Schema.fbs::Schema, it is serialized as an IPC message.
#[prost(bytes="vec", tag="3")]
pub parameter_schema: ::prost::alloc::vec::Vec<u8>,
}
///
-/// Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend.
-/// Closes server resources associated with the prepared statement handle.
+/// Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend.
+/// Closes server resources associated with the prepared statement handle.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ActionClosePreparedStatementRequest {
- /// Opaque handle for the prepared statement on the server.
+ /// Opaque handle for the prepared statement on the server.
#[prost(bytes="vec", tag="1")]
pub prepared_statement_handle: ::prost::alloc::vec::Vec<u8>,
}
-// SQL Execution Messages.
+// SQL Execution Messages.
///
-/// Represents a SQL query. Used in the command member of FlightDescriptor
-/// for the following RPC calls:
-/// - GetSchema: return the Arrow schema of the query.
-/// - GetFlightInfo: execute the query.
+/// Represents a SQL query. Used in the command member of FlightDescriptor
+/// for the following RPC calls:
+/// - GetSchema: return the Arrow schema of the query.
+/// - GetFlightInfo: execute the query.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CommandStatementQuery {
- /// The SQL syntax.
+ /// The SQL syntax.
#[prost(string, tag="1")]
pub query: ::prost::alloc::string::String,
}
-///*
-/// Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery.
-/// This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this.
+/// *
+/// Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery.
+/// This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TicketStatementQuery {
- /// Unique identifier for the instance of the statement to execute.
+ /// Unique identifier for the instance of the statement to execute.
#[prost(bytes="vec", tag="1")]
pub statement_handle: ::prost::alloc::vec::Vec<u8>,
}
///
-/// Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for
-/// the following RPC calls:
-/// - DoPut: bind parameter values. All of the bound parameter sets will be executed as a single atomic execution.
-/// - GetFlightInfo: execute the prepared statement instance.
+/// Represents an instance of executing a prepared statement. Used in the command member of FlightDescriptor for
+/// the following RPC calls:
+/// - DoPut: bind parameter values. All of the bound parameter sets will be executed as a single atomic execution.
+/// - GetFlightInfo: execute the prepared statement instance.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CommandPreparedStatementQuery {
- /// Opaque handle for the prepared statement on the server.
+ /// Opaque handle for the prepared statement on the server.
#[prost(bytes="vec", tag="1")]
pub prepared_statement_handle: ::prost::alloc::vec::Vec<u8>,
}
///
-/// Represents a SQL update query. Used in the command member of FlightDescriptor
-/// for the the RPC call DoPut to cause the server to execute the included SQL update.
+/// Represents a SQL update query. Used in the command member of FlightDescriptor
+/// for the RPC call DoPut to cause the server to execute the included SQL update.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CommandStatementUpdate {
- /// The SQL syntax.
+ /// The SQL syntax.
#[prost(string, tag="1")]
pub query: ::prost::alloc::string::String,
}
///
-/// Represents a SQL update query. Used in the command member of FlightDescriptor
-/// for the the RPC call DoPut to cause the server to execute the included
-/// prepared statement handle as an update.
+/// Represents a SQL update query. Used in the command member of FlightDescriptor
+/// for the RPC call DoPut to cause the server to execute the included
+/// prepared statement handle as an update.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CommandPreparedStatementUpdate {
- /// Opaque handle for the prepared statement on the server.
+ /// Opaque handle for the prepared statement on the server.
#[prost(bytes="vec", tag="1")]
pub prepared_statement_handle: ::prost::alloc::vec::Vec<u8>,
}
///
-/// Returned from the RPC call DoPut when a CommandStatementUpdate
-/// CommandPreparedStatementUpdate was in the request, containing
-/// results from the update.
+/// Returned from the RPC call DoPut when a CommandStatementUpdate
+/// or CommandPreparedStatementUpdate was in the request, containing
+/// results from the update.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DoPutUpdateResult {
- /// The number of records updated. A return value of -1 represents
- /// an unknown updated record count.
+ /// The number of records updated. A return value of -1 represents
+ /// an unknown updated record count.
#[prost(int64, tag="1")]
pub record_count: i64,
}
-/// Options for CommandGetSqlInfo.
+/// Options for CommandGetSqlInfo.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
pub enum SqlInfo {
- // Server Information [0-500): Provides basic information about the Flight SQL Server.
+ // Server Information [0-500): Provides basic information about the Flight SQL Server.
- /// Retrieves a UTF-8 string with the name of the Flight SQL Server.
+ /// Retrieves a UTF-8 string with the name of the Flight SQL Server.
FlightSqlServerName = 0,
- /// Retrieves a UTF-8 string with the native version of the Flight SQL Server.
+ /// Retrieves a UTF-8 string with the native version of the Flight SQL Server.
FlightSqlServerVersion = 1,
- /// Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server.
+ /// Retrieves a UTF-8 string with the Arrow format version of the Flight SQL Server.
FlightSqlServerArrowVersion = 2,
- ///
- /// Retrieves a boolean value indicating whether the Flight SQL Server is read only.
+ ///
+ /// Retrieves a boolean value indicating whether the Flight SQL Server is read only.
///
- /// Returns:
- /// - false: if read-write
- /// - true: if read only
+ /// Returns:
+ /// - false: if read-write
+ /// - true: if read only
FlightSqlServerReadOnly = 3,
- // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server.
+ // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server.
///
- /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs.
+ /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of catalogs.
///
- /// Returns:
- /// - false: if it doesn't support CREATE and DROP of catalogs.
- /// - true: if it supports CREATE and DROP of catalogs.
+ /// Returns:
+ /// - false: if it doesn't support CREATE and DROP of catalogs.
+ /// - true: if it supports CREATE and DROP of catalogs.
SqlDdlCatalog = 500,
///
- /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas.
+ /// Retrieves a boolean value indicating whether the Flight SQL Server supports CREATE and DROP of schemas.
///
- /// Returns:
- /// - false: if it doesn't support CREATE and DROP of schemas.
- /// - true: if it supports CREATE and DROP of schemas.
+ /// Returns:
+ /// - false: if it doesn't support CREATE and DROP of schemas.
+ /// - true: if it supports CREATE and DROP of schemas.
SqlDdlSchema = 501,
///
- /// Indicates whether the Flight SQL Server supports CREATE and DROP of tables.
+ /// Indicates whether the Flight SQL Server supports CREATE and DROP of tables.
///
- /// Returns:
- /// - false: if it doesn't support CREATE and DROP of tables.
- /// - true: if it supports CREATE and DROP of tables.
+ /// Returns:
+ /// - false: if it doesn't support CREATE and DROP of tables.
+ /// - true: if it supports CREATE and DROP of tables.
SqlDdlTable = 502,
///
- /// Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of catalog, table, schema and table names.
+ /// Retrieves a uint32 value representing the enum uint32 ordinal for the case sensitivity of catalog, table, schema and table names.
///
- /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`.
+ /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`.
SqlIdentifierCase = 503,
- /// Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier.
+ /// Retrieves a UTF-8 string with the supported character(s) used to surround a delimited identifier.
SqlIdentifierQuoteChar = 504,
///
- /// Retrieves a uint32 value representing the enu uint32 ordinal for the case sensitivity of quoted identifiers.
+ /// Retrieves a uint32 value representing the enum uint32 ordinal for the case sensitivity of quoted identifiers.
///
- /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`.
+ /// The possible values are listed in `arrow.flight.protocol.sql.SqlSupportedCaseSensitivity`.
SqlQuotedIdentifierCase = 505,
///
- /// Retrieves a boolean value indicating whether all tables are selectable.
+ /// Retrieves a boolean value indicating whether all tables are selectable.
///
- /// Returns:
- /// - false: if not all tables are selectable or if none are;
- /// - true: if all tables are selectable.
+ /// Returns:
+ /// - false: if not all tables are selectable or if none are;
+ /// - true: if all tables are selectable.
SqlAllTablesAreSelectable = 506,
///
- /// Retrieves the null ordering.
+ /// Retrieves the null ordering.
///
- /// Returns a uint32 ordinal for the null ordering being used, as described in
- /// `arrow.flight.protocol.sql.SqlNullOrdering`.
+ /// Returns a uint32 ordinal for the null ordering being used, as described in
+ /// `arrow.flight.protocol.sql.SqlNullOrdering`.
SqlNullOrdering = 507, - /// Retrieves a UTF-8 string list with values of the supported keywords. + /// Retrieves a UTF-8 string list with values of the supported keywords. SqlKeywords = 508, - /// Retrieves a UTF-8 string list with values of the supported numeric functions. + /// Retrieves a UTF-8 string list with values of the supported numeric functions. SqlNumericFunctions = 509, - /// Retrieves a UTF-8 string list with values of the supported string functions. + /// Retrieves a UTF-8 string list with values of the supported string functions. SqlStringFunctions = 510, - /// Retrieves a UTF-8 string list with values of the supported system functions. + /// Retrieves a UTF-8 string list with values of the supported system functions. SqlSystemFunctions = 511, - /// Retrieves a UTF-8 string list with values of the supported datetime functions. + /// Retrieves a UTF-8 string list with values of the supported datetime functions. SqlDatetimeFunctions = 512, /// - /// Retrieves the UTF-8 string that can be used to escape wildcard characters. - /// This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern - /// (and therefore use one of the wildcard characters). - /// The '_' character represents any single character; the '%' character represents any sequence of zero or more - /// characters. + /// Retrieves the UTF-8 string that can be used to escape wildcard characters. + /// This is the string that can be used to escape '_' or '%' in the catalog search parameters that are a pattern + /// (and therefore use one of the wildcard characters). + /// The '_' character represents any single character; the '%' character represents any sequence of zero or more + /// characters. SqlSearchStringEscape = 513, /// - /// Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names - /// (those beyond a-z, A-Z, 0-9 and _). + /// Retrieves a UTF-8 string with all the "extra" characters that can be used in unquoted identifier names + /// (those beyond a-z, A-Z, 0-9 and _). SqlExtraNameCharacters = 514, /// - /// Retrieves a boolean value indicating whether column aliasing is supported. - /// If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns - /// as required. + /// Retrieves a boolean value indicating whether column aliasing is supported. + /// If so, the SQL AS clause can be used to provide names for computed columns or to provide alias names for columns + /// as required. /// - /// Returns: - /// - false: if column aliasing is unsupported; - /// - true: if column aliasing is supported. + /// Returns: + /// - false: if column aliasing is unsupported; + /// - true: if column aliasing is supported. SqlSupportsColumnAliasing = 515, /// - /// Retrieves a boolean value indicating whether concatenations between null and non-null values being - /// null are supported. + /// Retrieves a boolean value indicating whether concatenations between null and non-null values being + /// null are supported. /// - /// - Returns: - /// - false: if concatenations between null and non-null values being null are unsupported; - /// - true: if concatenations between null and non-null values being null are supported. + /// - Returns: + /// - false: if concatenations between null and non-null values being null are unsupported; + /// - true: if concatenations between null and non-null values being null are supported. 
SqlNullPlusNullIsNull = 516,
///
- /// Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to,
- /// indicating the supported conversions. Each key and each item on the list value is a value to a predefined type on
- /// SqlSupportsConvert enum.
- /// The returned map will be: map<int32, list<int32>>
+ /// Retrieves a map where the key is the type to convert from and the value is a list with the types to convert to,
+ /// indicating the supported conversions. Each key and each item on the list value is a value to a predefined type on
+ /// SqlSupportsConvert enum.
+ /// The returned map will be: map<int32, list<int32>>
SqlSupportsConvert = 517,
///
- /// Retrieves a boolean value indicating whether, when table correlation names are supported,
- /// they are restricted to being different from the names of the tables.
+ /// Retrieves a boolean value indicating whether, when table correlation names are supported,
+ /// they are restricted to being different from the names of the tables.
///
- /// Returns:
- /// - false: if table correlation names are unsupported;
- /// - true: if table correlation names are supported.
+ /// Returns:
+ /// - false: if table correlation names are unsupported;
+ /// - true: if table correlation names are supported.
SqlSupportsTableCorrelationNames = 518,
///
- /// Retrieves a boolean value indicating whether, when table correlation names are supported,
- /// they are restricted to being different from the names of the tables.
+ /// Retrieves a boolean value indicating whether, when table correlation names are supported,
+ /// they are restricted to being different from the names of the tables.
///
- /// Returns:
- /// - false: if different table correlation names are unsupported;
- /// - true: if different table correlation names are supported
+ /// Returns:
+ /// - false: if different table correlation names are unsupported;
+ /// - true: if different table correlation names are supported
SqlSupportsDifferentTableCorrelationNames = 519,
///
- /// Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported.
+ /// Retrieves a boolean value indicating whether expressions in ORDER BY lists are supported.
///
- /// Returns:
- /// - false: if expressions in ORDER BY are unsupported;
- /// - true: if expressions in ORDER BY are supported;
+ /// Returns:
+ /// - false: if expressions in ORDER BY are unsupported;
+ /// - true: if expressions in ORDER BY are supported;
SqlSupportsExpressionsInOrderBy = 520,
///
- /// Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY
- /// clause is supported.
+ /// Retrieves a boolean value indicating whether using a column that is not in the SELECT statement in a GROUP BY
+ /// clause is supported.
///
- /// Returns:
- /// - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported;
- /// - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported.
+ /// Returns:
+ /// - false: if using a column that is not in the SELECT statement in a GROUP BY clause is unsupported;
+ /// - true: if using a column that is not in the SELECT statement in a GROUP BY clause is supported.
SqlSupportsOrderByUnrelated = 521,
///
- /// Retrieves the supported GROUP BY commands;
+ /// Retrieves the supported GROUP BY commands;
///
- /// Returns an int32 bitmask value representing the supported commands.
- /// The returned bitmask should be parsed in order to retrieve the supported commands. + /// Returns an int32 bitmask value representing the supported commands. + /// The returned bitmask should be parsed in order to retrieve the supported commands. /// - /// For instance: - /// - return 0 (\b0) => [] (GROUP BY is unsupported); - /// - return 1 (\b1) => \[SQL_GROUP_BY_UNRELATED\]; - /// - return 2 (\b10) => \[SQL_GROUP_BY_BEYOND_SELECT\]; - /// - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. - /// Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. + /// For instance: + /// - return 0 (\b0) => [] (GROUP BY is unsupported); + /// - return 1 (\b1) => \[SQL_GROUP_BY_UNRELATED\]; + /// - return 2 (\b10) => \[SQL_GROUP_BY_BEYOND_SELECT\]; + /// - return 3 (\b11) => [SQL_GROUP_BY_UNRELATED, SQL_GROUP_BY_BEYOND_SELECT]. + /// Valid GROUP BY types are described under `arrow.flight.protocol.sql.SqlSupportedGroupBy`. SqlSupportedGroupBy = 522, /// - /// Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. + /// Retrieves a boolean value indicating whether specifying a LIKE escape clause is supported. /// - /// Returns: - /// - false: if specifying a LIKE escape clause is unsupported; - /// - true: if specifying a LIKE escape clause is supported. + /// Returns: + /// - false: if specifying a LIKE escape clause is unsupported; + /// - true: if specifying a LIKE escape clause is supported. SqlSupportsLikeEscapeClause = 523, /// - /// Retrieves a boolean value indicating whether columns may be defined as non-nullable. + /// Retrieves a boolean value indicating whether columns may be defined as non-nullable. /// - /// Returns: - /// - false: if columns cannot be defined as non-nullable; - /// - true: if columns may be defined as non-nullable. + /// Returns: + /// - false: if columns cannot be defined as non-nullable; + /// - true: if columns may be defined as non-nullable. SqlSupportsNonNullableColumns = 524, /// - /// Retrieves the supported SQL grammar level as per the ODBC specification. - /// - /// Returns an int32 bitmask value representing the supported SQL grammar level. - /// The returned bitmask should be parsed in order to retrieve the supported grammar levels. - /// - /// For instance: - /// - return 0 (\b0) => [] (SQL grammar is unsupported); - /// - return 1 (\b1) => \[SQL_MINIMUM_GRAMMAR\]; - /// - return 2 (\b10) => \[SQL_CORE_GRAMMAR\]; - /// - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR]; - /// - return 4 (\b100) => \[SQL_EXTENDED_GRAMMAR\]; - /// - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR]; - /// - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]; - /// - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR]. - /// Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`. + /// Retrieves the supported SQL grammar level as per the ODBC specification. + /// + /// Returns an int32 bitmask value representing the supported SQL grammar level. + /// The returned bitmask should be parsed in order to retrieve the supported grammar levels. 
+ ///
+ /// For instance:
+ /// - return 0 (\b0) => [] (SQL grammar is unsupported);
+ /// - return 1 (\b1) => \[SQL_MINIMUM_GRAMMAR\];
+ /// - return 2 (\b10) => \[SQL_CORE_GRAMMAR\];
+ /// - return 3 (\b11) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR];
+ /// - return 4 (\b100) => \[SQL_EXTENDED_GRAMMAR\];
+ /// - return 5 (\b101) => [SQL_MINIMUM_GRAMMAR, SQL_EXTENDED_GRAMMAR];
+ /// - return 6 (\b110) => [SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR];
+ /// - return 7 (\b111) => [SQL_MINIMUM_GRAMMAR, SQL_CORE_GRAMMAR, SQL_EXTENDED_GRAMMAR].
+ /// Valid SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedSqlGrammar`.
SqlSupportedGrammar = 525,
///
- /// Retrieves the supported ANSI92 SQL grammar level.
- ///
- /// Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level.
- /// The returned bitmask should be parsed in order to retrieve the supported commands.
- ///
- /// For instance:
- /// - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported);
- /// - return 1 (\b1) => \[ANSI92_ENTRY_SQL\];
- /// - return 2 (\b10) => \[ANSI92_INTERMEDIATE_SQL\];
- /// - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL];
- /// - return 4 (\b100) => \[ANSI92_FULL_SQL\];
- /// - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL];
- /// - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL];
- /// - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL].
- /// Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`.
+ /// Retrieves the supported ANSI92 SQL grammar level.
+ ///
+ /// Returns an int32 bitmask value representing the supported ANSI92 SQL grammar level.
+ /// The returned bitmask should be parsed in order to retrieve the supported commands.
+ ///
+ /// For instance:
+ /// - return 0 (\b0) => [] (ANSI92 SQL grammar is unsupported);
+ /// - return 1 (\b1) => \[ANSI92_ENTRY_SQL\];
+ /// - return 2 (\b10) => \[ANSI92_INTERMEDIATE_SQL\];
+ /// - return 3 (\b11) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL];
+ /// - return 4 (\b100) => \[ANSI92_FULL_SQL\];
+ /// - return 5 (\b101) => [ANSI92_ENTRY_SQL, ANSI92_FULL_SQL];
+ /// - return 6 (\b110) => [ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL];
+ /// - return 7 (\b111) => [ANSI92_ENTRY_SQL, ANSI92_INTERMEDIATE_SQL, ANSI92_FULL_SQL].
+ /// Valid ANSI92 SQL grammar levels are described under `arrow.flight.protocol.sql.SupportedAnsi92SqlGrammarLevel`.
SqlAnsi92SupportedLevel = 526,
///
- /// Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported.
+ /// Retrieves a boolean value indicating whether the SQL Integrity Enhancement Facility is supported.
///
- /// Returns:
- /// - false: if the SQL Integrity Enhancement Facility is supported;
- /// - true: if the SQL Integrity Enhancement Facility is supported.
+ /// Returns:
+ /// - false: if the SQL Integrity Enhancement Facility is unsupported;
+ /// - true: if the SQL Integrity Enhancement Facility is supported.
SqlSupportsIntegrityEnhancementFacility = 527,
///
- /// Retrieves the support level for SQL OUTER JOINs.
+ /// Retrieves the support level for SQL OUTER JOINs.
///
- /// Returns a uint3 uint32 ordinal for the SQL ordering being used, as described in
- /// `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`.
+ /// Returns a uint32 ordinal for the SQL ordering being used, as described in
+ /// `arrow.flight.protocol.sql.SqlOuterJoinsSupportLevel`.
SqlOuterJoinsSupportLevel = 528, - /// Retrieves a UTF-8 string with the preferred term for "schema". + /// Retrieves a UTF-8 string with the preferred term for "schema". SqlSchemaTerm = 529, - /// Retrieves a UTF-8 string with the preferred term for "procedure". + /// Retrieves a UTF-8 string with the preferred term for "procedure". SqlProcedureTerm = 530, - /// Retrieves a UTF-8 string with the preferred term for "catalog". + /// Retrieves a UTF-8 string with the preferred term for "catalog". SqlCatalogTerm = 531, /// - /// Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. + /// Retrieves a boolean value indicating whether a catalog appears at the start of a fully qualified table name. /// - /// - false: if a catalog does not appear at the start of a fully qualified table name; - /// - true: if a catalog appears at the start of a fully qualified table name. + /// - false: if a catalog does not appear at the start of a fully qualified table name; + /// - true: if a catalog appears at the start of a fully qualified table name. SqlCatalogAtStart = 532, /// - /// Retrieves the supported actions for a SQL schema. - /// - /// Returns an int32 bitmask value representing the supported actions for a SQL schema. - /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported actions for SQL schema); - /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; - /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; - /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; - /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; - /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. - /// Valid actions for a SQL schema described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + /// Retrieves the supported actions for a SQL schema. + /// + /// Returns an int32 bitmask value representing the supported actions for a SQL schema. + /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL schema. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported actions for SQL schema); + /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; + /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + /// Valid actions for a SQL schema described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. SqlSchemasSupportedActions = 533, /// - /// Retrieves the supported actions for a SQL schema. - /// - /// Returns an int32 bitmask value representing the supported actions for a SQL catalog. 
- /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported actions for SQL catalog); - /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; - /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; - /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; - /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; - /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; - /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. - /// Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. + /// Retrieves the supported actions for a SQL schema. + /// + /// Returns an int32 bitmask value representing the supported actions for a SQL catalog. + /// The returned bitmask should be parsed in order to retrieve the supported actions for a SQL catalog. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported actions for SQL catalog); + /// - return 1 (\b1) => \[SQL_ELEMENT_IN_PROCEDURE_CALLS\]; + /// - return 2 (\b10) => \[SQL_ELEMENT_IN_INDEX_DEFINITIONS\]; + /// - return 3 (\b11) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS]; + /// - return 4 (\b100) => \[SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS\]; + /// - return 5 (\b101) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + /// - return 6 (\b110) => [SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]; + /// - return 7 (\b111) => [SQL_ELEMENT_IN_PROCEDURE_CALLS, SQL_ELEMENT_IN_INDEX_DEFINITIONS, SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS]. + /// Valid actions for a SQL catalog are described under `arrow.flight.protocol.sql.SqlSupportedElementActions`. SqlCatalogsSupportedActions = 534, /// - /// Retrieves the supported SQL positioned commands. + /// Retrieves the supported SQL positioned commands. /// - /// Returns an int32 bitmask value representing the supported SQL positioned commands. - /// The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. + /// Returns an int32 bitmask value representing the supported SQL positioned commands. + /// The returned bitmask should be parsed in order to retrieve the supported SQL positioned commands. /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL positioned commands); - /// - return 1 (\b1) => \[SQL_POSITIONED_DELETE\]; - /// - return 2 (\b10) => \[SQL_POSITIONED_UPDATE\]; - /// - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. + /// For instance: + /// - return 0 (\b0) => [] (no supported SQL positioned commands); + /// - return 1 (\b1) => \[SQL_POSITIONED_DELETE\]; + /// - return 2 (\b10) => \[SQL_POSITIONED_UPDATE\]; + /// - return 3 (\b11) => [SQL_POSITIONED_DELETE, SQL_POSITIONED_UPDATE]. + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedPositionedCommands`. SqlSupportedPositionedCommands = 535, /// - /// Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. 
+ /// Retrieves a boolean value indicating whether SELECT FOR UPDATE statements are supported. /// - /// Returns: - /// - false: if SELECT FOR UPDATE statements are unsupported; - /// - true: if SELECT FOR UPDATE statements are supported. + /// Returns: + /// - false: if SELECT FOR UPDATE statements are unsupported; + /// - true: if SELECT FOR UPDATE statements are supported. SqlSelectForUpdateSupported = 536, /// - /// Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax - /// are supported. + /// Retrieves a boolean value indicating whether stored procedure calls that use the stored procedure escape syntax + /// are supported. /// - /// Returns: - /// - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; - /// - true: if stored procedure calls that use the stored procedure escape syntax are supported. + /// Returns: + /// - false: if stored procedure calls that use the stored procedure escape syntax are unsupported; + /// - true: if stored procedure calls that use the stored procedure escape syntax are supported. SqlStoredProceduresSupported = 537, /// - /// Retrieves the supported SQL subqueries. - /// - /// Returns an int32 bitmask value representing the supported SQL subqueries. - /// The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL subqueries); - /// - return 1 (\b1) => \[SQL_SUBQUERIES_IN_COMPARISONS\]; - /// - return 2 (\b10) => \[SQL_SUBQUERIES_IN_EXISTS\]; - /// - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS]; - /// - return 4 (\b100) => \[SQL_SUBQUERIES_IN_INS\]; - /// - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; - /// - return 6 (\b110) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS]; - /// - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS]; - /// - return 8 (\b1000) => \[SQL_SUBQUERIES_IN_QUANTIFIEDS\]; - /// - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; - /// - ... - /// Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. + /// Retrieves the supported SQL subqueries. + /// + /// Returns an int32 bitmask value representing the supported SQL subqueries. + /// The returned bitmask should be parsed in order to retrieve the supported SQL subqueries. 
+ /// + /// For instance: + /// - return 0 (\b0) => [] (no supported SQL subqueries); + /// - return 1 (\b1) => \[SQL_SUBQUERIES_IN_COMPARISONS\]; + /// - return 2 (\b10) => \[SQL_SUBQUERIES_IN_EXISTS\]; + /// - return 3 (\b11) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS]; + /// - return 4 (\b100) => \[SQL_SUBQUERIES_IN_INS\]; + /// - return 5 (\b101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS]; + /// - return 6 (\b110) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_EXISTS]; + /// - return 7 (\b111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS]; + /// - return 8 (\b1000) => \[SQL_SUBQUERIES_IN_QUANTIFIEDS\]; + /// - return 9 (\b1001) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 10 (\b1010) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 11 (\b1011) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 12 (\b1100) => [SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 13 (\b1101) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 14 (\b1110) => [SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - return 15 (\b1111) => [SQL_SUBQUERIES_IN_COMPARISONS, SQL_SUBQUERIES_IN_EXISTS, SQL_SUBQUERIES_IN_INS, SQL_SUBQUERIES_IN_QUANTIFIEDS]; + /// - ... + /// Valid SQL subqueries are described under `arrow.flight.protocol.sql.SqlSupportedSubqueries`. SqlSupportedSubqueries = 538, /// - /// Retrieves a boolean value indicating whether correlated subqueries are supported. + /// Retrieves a boolean value indicating whether correlated subqueries are supported. /// - /// Returns: - /// - false: if correlated subqueries are unsupported; - /// - true: if correlated subqueries are supported. + /// Returns: + /// - false: if correlated subqueries are unsupported; + /// - true: if correlated subqueries are supported. SqlCorrelatedSubqueriesSupported = 539, /// - /// Retrieves the supported SQL UNIONs. + /// Retrieves the supported SQL UNIONs. /// - /// Returns an int32 bitmask value representing the supported SQL UNIONs. - /// The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. + /// Returns an int32 bitmask value representing the supported SQL UNIONs. + /// The returned bitmask should be parsed in order to retrieve the supported SQL UNIONs. /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL positioned commands); - /// - return 1 (\b1) => \[SQL_UNION\]; - /// - return 2 (\b10) => \[SQL_UNION_ALL\]; - /// - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. + /// For instance: + /// - return 0 (\b0) => [] (no supported SQL positioned commands); + /// - return 1 (\b1) => \[SQL_UNION\]; + /// - return 2 (\b10) => \[SQL_UNION_ALL\]; + /// - return 3 (\b11) => [SQL_UNION, SQL_UNION_ALL]. + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlSupportedUnions`. SqlSupportedUnions = 540, - /// Retrieves a uint32 value representing the maximum number of hex characters allowed in an inline binary literal. + /// Retrieves a uint32 value representing the maximum number of hex characters allowed in an inline binary literal. 
SqlMaxBinaryLiteralLength = 541,
- /// Retrieves a uint32 value representing the maximum number of characters allowed for a character literal.
+ /// Retrieves a uint32 value representing the maximum number of characters allowed for a character literal.
SqlMaxCharLiteralLength = 542,
- /// Retrieves a uint32 value representing the maximum number of characters allowed for a column name.
+ /// Retrieves a uint32 value representing the maximum number of characters allowed for a column name.
SqlMaxColumnNameLength = 543,
- /// Retrieves a uint32 value representing the the maximum number of columns allowed in a GROUP BY clause.
+ /// Retrieves a uint32 value representing the maximum number of columns allowed in a GROUP BY clause.
SqlMaxColumnsInGroupBy = 544,
- /// Retrieves a uint32 value representing the maximum number of columns allowed in an index.
+ /// Retrieves a uint32 value representing the maximum number of columns allowed in an index.
SqlMaxColumnsInIndex = 545,
- /// Retrieves a uint32 value representing the maximum number of columns allowed in an ORDER BY clause.
+ /// Retrieves a uint32 value representing the maximum number of columns allowed in an ORDER BY clause.
SqlMaxColumnsInOrderBy = 546,
- /// Retrieves a uint32 value representing the maximum number of columns allowed in a SELECT list.
+ /// Retrieves a uint32 value representing the maximum number of columns allowed in a SELECT list.
SqlMaxColumnsInSelect = 547,
- /// Retrieves a uint32 value representing the maximum number of columns allowed in a table.
+ /// Retrieves a uint32 value representing the maximum number of columns allowed in a table.
SqlMaxColumnsInTable = 548,
- /// Retrieves a uint32 value representing the maximum number of concurrent connections possible.
+ /// Retrieves a uint32 value representing the maximum number of concurrent connections possible.
SqlMaxConnections = 549,
- /// Retrieves a uint32 value the maximum number of characters allowed in a cursor name.
+ /// Retrieves a uint32 value representing the maximum number of characters allowed in a cursor name.
SqlMaxCursorNameLength = 550,
///
- /// Retrieves a uint32 value representing the maximum number of bytes allowed for an index,
- /// including all of the parts of the index.
+ /// Retrieves a uint32 value representing the maximum number of bytes allowed for an index,
+ /// including all of the parts of the index.
SqlMaxIndexLength = 551,
- /// Retrieves a uint32 value representing the maximum number of characters allowed in a schema name.
+ /// Retrieves a uint32 value representing the maximum number of characters allowed in a schema name.
SqlDbSchemaNameLength = 552,
- /// Retrieves a uint32 value representing the maximum number of characters allowed in a procedure name.
+ /// Retrieves a uint32 value representing the maximum number of characters allowed in a procedure name.
SqlMaxProcedureNameLength = 553,
- /// Retrieves a uint32 value representing the maximum number of characters allowed in a catalog name.
+ /// Retrieves a uint32 value representing the maximum number of characters allowed in a catalog name.
SqlMaxCatalogNameLength = 554,
- /// Retrieves a uint32 value representing the maximum number of bytes allowed in a single row.
+ /// Retrieves a uint32 value representing the maximum number of bytes allowed in a single row.
SqlMaxRowSize = 555,
///
- /// Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL
- /// data types LONGVARCHAR and LONGVARBINARY.
+ /// Retrieves a boolean indicating whether the return value for the JDBC method getMaxRowSize includes the SQL + /// data types LONGVARCHAR and LONGVARBINARY. /// - /// Returns: - /// - false: if return value for the JDBC method getMaxRowSize does - /// not include the SQL data types LONGVARCHAR and LONGVARBINARY; - /// - true: if return value for the JDBC method getMaxRowSize includes - /// the SQL data types LONGVARCHAR and LONGVARBINARY. + /// Returns: + /// - false: if return value for the JDBC method getMaxRowSize does + /// not include the SQL data types LONGVARCHAR and LONGVARBINARY; + /// - true: if return value for the JDBC method getMaxRowSize includes + /// the SQL data types LONGVARCHAR and LONGVARBINARY. SqlMaxRowSizeIncludesBlobs = 556, /// - /// Retrieves a uint32 value representing the maximum number of characters allowed for an SQL statement; - /// a result of 0 (zero) means that there is no limit or the limit is not known. + /// Retrieves a uint32 value representing the maximum number of characters allowed for an SQL statement; + /// a result of 0 (zero) means that there is no limit or the limit is not known. SqlMaxStatementLength = 557, - /// Retrieves a uint32 value representing the maximum number of active statements that can be open at the same time. + /// Retrieves a uint32 value representing the maximum number of active statements that can be open at the same time. SqlMaxStatements = 558, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a table name. + /// Retrieves a uint32 value representing the maximum number of characters allowed in a table name. SqlMaxTableNameLength = 559, - /// Retrieves a uint32 value representing the maximum number of tables allowed in a SELECT statement. + /// Retrieves a uint32 value representing the maximum number of tables allowed in a SELECT statement. SqlMaxTablesInSelect = 560, - /// Retrieves a uint32 value representing the maximum number of characters allowed in a user name. + /// Retrieves a uint32 value representing the maximum number of characters allowed in a user name. SqlMaxUsernameLength = 561, /// - /// Retrieves this database's default transaction isolation level as described in - /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + /// Retrieves this database's default transaction isolation level as described in + /// `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. /// - /// Returns a uint32 ordinal for the SQL transaction isolation level. + /// Returns a uint32 ordinal for the SQL transaction isolation level. SqlDefaultTransactionIsolation = 562, /// - /// Retrieves a boolean value indicating whether transactions are supported. If not, invoking the method commit is a - /// noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. + /// Retrieves a boolean value indicating whether transactions are supported. If not, invoking the method commit is a + /// noop, and the isolation level is `arrow.flight.protocol.sql.SqlTransactionIsolationLevel.TRANSACTION_NONE`. /// - /// Returns: - /// - false: if transactions are unsupported; - /// - true: if transactions are supported. + /// Returns: + /// - false: if transactions are unsupported; + /// - true: if transactions are supported. SqlTransactionsSupported = 563, /// - /// Retrieves the supported transactions isolation levels. - /// - /// Returns an int32 bitmask value representing the supported transactions isolation levels. 
- /// The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported SQL transactions isolation levels); - /// - return 1 (\b1) => \[SQL_TRANSACTION_NONE\]; - /// - return 2 (\b10) => \[SQL_TRANSACTION_READ_UNCOMMITTED\]; - /// - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; - /// - return 4 (\b100) => \[SQL_TRANSACTION_REPEATABLE_READ\]; - /// - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 8 (\b1000) => \[SQL_TRANSACTION_REPEATABLE_READ\]; - /// - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; - /// - return 16 (\b10000) => \[SQL_TRANSACTION_SERIALIZABLE\]; - /// - ... - /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. + /// Retrieves the supported transactions isolation levels. + /// + /// Returns an int32 bitmask value representing the supported transactions isolation levels. + /// The returned bitmask should be parsed in order to retrieve the supported transactions isolation levels. 
+ /// + /// For instance: + /// - return 0 (\b0) => [] (no supported SQL transactions isolation levels); + /// - return 1 (\b1) => \[SQL_TRANSACTION_NONE\]; + /// - return 2 (\b10) => \[SQL_TRANSACTION_READ_UNCOMMITTED\]; + /// - return 3 (\b11) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED]; + /// - return 4 (\b100) => \[SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 5 (\b101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 6 (\b110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 7 (\b111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 8 (\b1000) => \[SQL_TRANSACTION_REPEATABLE_READ\]; + /// - return 9 (\b1001) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 10 (\b1010) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 11 (\b1011) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 12 (\b1100) => [SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 13 (\b1101) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 14 (\b1110) => [SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 15 (\b1111) => [SQL_TRANSACTION_NONE, SQL_TRANSACTION_READ_UNCOMMITTED, SQL_TRANSACTION_REPEATABLE_READ, SQL_TRANSACTION_REPEATABLE_READ]; + /// - return 16 (\b10000) => \[SQL_TRANSACTION_SERIALIZABLE\]; + /// - ... + /// Valid SQL positioned commands are described under `arrow.flight.protocol.sql.SqlTransactionIsolationLevel`. SqlSupportedTransactionsIsolationLevels = 564, /// - /// Retrieves a boolean value indicating whether a data definition statement within a transaction forces - /// the transaction to commit. + /// Retrieves a boolean value indicating whether a data definition statement within a transaction forces + /// the transaction to commit. /// - /// Returns: - /// - false: if a data definition statement within a transaction does not force the transaction to commit; - /// - true: if a data definition statement within a transaction forces the transaction to commit. + /// Returns: + /// - false: if a data definition statement within a transaction does not force the transaction to commit; + /// - true: if a data definition statement within a transaction forces the transaction to commit. SqlDataDefinitionCausesTransactionCommit = 565, /// - /// Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. + /// Retrieves a boolean value indicating whether a data definition statement within a transaction is ignored. /// - /// Returns: - /// - false: if a data definition statement within a transaction is taken into account; - /// - true: a data definition statement within a transaction is ignored. + /// Returns: + /// - false: if a data definition statement within a transaction is taken into account; + /// - true: a data definition statement within a transaction is ignored. SqlDataDefinitionsInTransactionsIgnored = 566, /// - /// Retrieves an int32 bitmask value representing the supported result set types. - /// The returned bitmask should be parsed in order to retrieve the supported result set types. 
- /// - /// For instance: - /// - return 0 (\b0) => [] (no supported result set types); - /// - return 1 (\b1) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED\]; - /// - return 2 (\b10) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; - /// - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY]; - /// - return 4 (\b100) => \[SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; - /// - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; - /// - return 8 (\b1000) => \[SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE\]; - /// - ... - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. + /// Retrieves an int32 bitmask value representing the supported result set types. + /// The returned bitmask should be parsed in order to retrieve the supported result set types. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported result set types); + /// - return 1 (\b1) => \[SQL_RESULT_SET_TYPE_UNSPECIFIED\]; + /// - return 2 (\b10) => \[SQL_RESULT_SET_TYPE_FORWARD_ONLY\]; + /// - return 3 (\b11) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY]; + /// - return 4 (\b100) => \[SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE\]; + /// - return 5 (\b101) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + /// - return 6 (\b110) => [SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + /// - return 7 (\b111) => [SQL_RESULT_SET_TYPE_UNSPECIFIED, SQL_RESULT_SET_TYPE_FORWARD_ONLY, SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE]; + /// - return 8 (\b1000) => \[SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE\]; + /// - ... + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetType`. SqlSupportedResultSetTypes = 567, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_UNSPECIFIED`. 
+ /// + /// For instance: + /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetUnspecified = 568, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY`. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetForwardOnly = 569, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. 
- /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE`. + /// + /// For instance: + /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetScrollSensitive = 570, /// - /// Returns an int32 bitmask value concurrency types supported for - /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. - /// - /// For instance: - /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) - /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] - /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] - /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] - /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] - /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] - /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. + /// Returns an int32 bitmask value concurrency types supported for + /// `arrow.flight.protocol.sql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE`. 
+ /// + /// For instance: + /// - return 0 (\b0) => [] (no supported concurrency types for this result set type) + /// - return 1 (\b1) => \[SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED\] + /// - return 2 (\b10) => \[SQL_RESULT_SET_CONCURRENCY_READ_ONLY\] + /// - return 3 (\b11) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY] + /// - return 4 (\b100) => \[SQL_RESULT_SET_CONCURRENCY_UPDATABLE\] + /// - return 5 (\b101) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 6 (\b110) => [SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// - return 7 (\b111) => [SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED, SQL_RESULT_SET_CONCURRENCY_READ_ONLY, SQL_RESULT_SET_CONCURRENCY_UPDATABLE] + /// Valid result set types are described under `arrow.flight.protocol.sql.SqlSupportedResultSetConcurrency`. SqlSupportedConcurrenciesForResultSetScrollInsensitive = 571, /// - /// Retrieves a boolean value indicating whether this database supports batch updates. + /// Retrieves a boolean value indicating whether this database supports batch updates. /// - /// - false: if this database does not support batch updates; - /// - true: if this database supports batch updates. + /// - false: if this database does not support batch updates; + /// - true: if this database supports batch updates. SqlBatchUpdatesSupported = 572, /// - /// Retrieves a boolean value indicating whether this database supports savepoints. + /// Retrieves a boolean value indicating whether this database supports savepoints. /// - /// Returns: - /// - false: if this database does not support savepoints; - /// - true: if this database supports savepoints. + /// Returns: + /// - false: if this database does not support savepoints; + /// - true: if this database supports savepoints. SqlSavepointsSupported = 573, /// - /// Retrieves a boolean value indicating whether named parameters are supported in callable statements. + /// Retrieves a boolean value indicating whether named parameters are supported in callable statements. /// - /// Returns: - /// - false: if named parameters in callable statements are unsupported; - /// - true: if named parameters in callable statements are supported. + /// Returns: + /// - false: if named parameters in callable statements are unsupported; + /// - true: if named parameters in callable statements are supported. SqlNamedParametersSupported = 574, /// - /// Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. + /// Retrieves a boolean value indicating whether updates made to a LOB are made on a copy or directly to the LOB. /// - /// Returns: - /// - false: if updates made to a LOB are made directly to the LOB; - /// - true: if updates made to a LOB are made on a copy. + /// Returns: + /// - false: if updates made to a LOB are made directly to the LOB; + /// - true: if updates made to a LOB are made on a copy. SqlLocatorsUpdateCopy = 575, /// - /// Retrieves a boolean value indicating whether invoking user-defined or vendor functions - /// using the stored procedure escape syntax is supported. + /// Retrieves a boolean value indicating whether invoking user-defined or vendor functions + /// using the stored procedure escape syntax is supported. 
/// - /// Returns: - /// - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; - /// - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. + /// Returns: + /// - false: if invoking user-defined or vendor functions using the stored procedure escape syntax is unsupported; + /// - true: if invoking user-defined or vendor functions using the stored procedure escape syntax is supported. SqlStoredFunctionsUsingCallSyntaxSupported = 576, } +impl SqlInfo { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlInfo::FlightSqlServerName => "FLIGHT_SQL_SERVER_NAME", + SqlInfo::FlightSqlServerVersion => "FLIGHT_SQL_SERVER_VERSION", + SqlInfo::FlightSqlServerArrowVersion => "FLIGHT_SQL_SERVER_ARROW_VERSION", + SqlInfo::FlightSqlServerReadOnly => "FLIGHT_SQL_SERVER_READ_ONLY", + SqlInfo::SqlDdlCatalog => "SQL_DDL_CATALOG", + SqlInfo::SqlDdlSchema => "SQL_DDL_SCHEMA", + SqlInfo::SqlDdlTable => "SQL_DDL_TABLE", + SqlInfo::SqlIdentifierCase => "SQL_IDENTIFIER_CASE", + SqlInfo::SqlIdentifierQuoteChar => "SQL_IDENTIFIER_QUOTE_CHAR", + SqlInfo::SqlQuotedIdentifierCase => "SQL_QUOTED_IDENTIFIER_CASE", + SqlInfo::SqlAllTablesAreSelectable => "SQL_ALL_TABLES_ARE_SELECTABLE", + SqlInfo::SqlNullOrdering => "SQL_NULL_ORDERING", + SqlInfo::SqlKeywords => "SQL_KEYWORDS", + SqlInfo::SqlNumericFunctions => "SQL_NUMERIC_FUNCTIONS", + SqlInfo::SqlStringFunctions => "SQL_STRING_FUNCTIONS", + SqlInfo::SqlSystemFunctions => "SQL_SYSTEM_FUNCTIONS", + SqlInfo::SqlDatetimeFunctions => "SQL_DATETIME_FUNCTIONS", + SqlInfo::SqlSearchStringEscape => "SQL_SEARCH_STRING_ESCAPE", + SqlInfo::SqlExtraNameCharacters => "SQL_EXTRA_NAME_CHARACTERS", + SqlInfo::SqlSupportsColumnAliasing => "SQL_SUPPORTS_COLUMN_ALIASING", + SqlInfo::SqlNullPlusNullIsNull => "SQL_NULL_PLUS_NULL_IS_NULL", + SqlInfo::SqlSupportsConvert => "SQL_SUPPORTS_CONVERT", + SqlInfo::SqlSupportsTableCorrelationNames => "SQL_SUPPORTS_TABLE_CORRELATION_NAMES", + SqlInfo::SqlSupportsDifferentTableCorrelationNames => "SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES", + SqlInfo::SqlSupportsExpressionsInOrderBy => "SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY", + SqlInfo::SqlSupportsOrderByUnrelated => "SQL_SUPPORTS_ORDER_BY_UNRELATED", + SqlInfo::SqlSupportedGroupBy => "SQL_SUPPORTED_GROUP_BY", + SqlInfo::SqlSupportsLikeEscapeClause => "SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE", + SqlInfo::SqlSupportsNonNullableColumns => "SQL_SUPPORTS_NON_NULLABLE_COLUMNS", + SqlInfo::SqlSupportedGrammar => "SQL_SUPPORTED_GRAMMAR", + SqlInfo::SqlAnsi92SupportedLevel => "SQL_ANSI92_SUPPORTED_LEVEL", + SqlInfo::SqlSupportsIntegrityEnhancementFacility => "SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY", + SqlInfo::SqlOuterJoinsSupportLevel => "SQL_OUTER_JOINS_SUPPORT_LEVEL", + SqlInfo::SqlSchemaTerm => "SQL_SCHEMA_TERM", + SqlInfo::SqlProcedureTerm => "SQL_PROCEDURE_TERM", + SqlInfo::SqlCatalogTerm => "SQL_CATALOG_TERM", + SqlInfo::SqlCatalogAtStart => "SQL_CATALOG_AT_START", + SqlInfo::SqlSchemasSupportedActions => "SQL_SCHEMAS_SUPPORTED_ACTIONS", + SqlInfo::SqlCatalogsSupportedActions => "SQL_CATALOGS_SUPPORTED_ACTIONS", + SqlInfo::SqlSupportedPositionedCommands => "SQL_SUPPORTED_POSITIONED_COMMANDS", + SqlInfo::SqlSelectForUpdateSupported => 
"SQL_SELECT_FOR_UPDATE_SUPPORTED", + SqlInfo::SqlStoredProceduresSupported => "SQL_STORED_PROCEDURES_SUPPORTED", + SqlInfo::SqlSupportedSubqueries => "SQL_SUPPORTED_SUBQUERIES", + SqlInfo::SqlCorrelatedSubqueriesSupported => "SQL_CORRELATED_SUBQUERIES_SUPPORTED", + SqlInfo::SqlSupportedUnions => "SQL_SUPPORTED_UNIONS", + SqlInfo::SqlMaxBinaryLiteralLength => "SQL_MAX_BINARY_LITERAL_LENGTH", + SqlInfo::SqlMaxCharLiteralLength => "SQL_MAX_CHAR_LITERAL_LENGTH", + SqlInfo::SqlMaxColumnNameLength => "SQL_MAX_COLUMN_NAME_LENGTH", + SqlInfo::SqlMaxColumnsInGroupBy => "SQL_MAX_COLUMNS_IN_GROUP_BY", + SqlInfo::SqlMaxColumnsInIndex => "SQL_MAX_COLUMNS_IN_INDEX", + SqlInfo::SqlMaxColumnsInOrderBy => "SQL_MAX_COLUMNS_IN_ORDER_BY", + SqlInfo::SqlMaxColumnsInSelect => "SQL_MAX_COLUMNS_IN_SELECT", + SqlInfo::SqlMaxColumnsInTable => "SQL_MAX_COLUMNS_IN_TABLE", + SqlInfo::SqlMaxConnections => "SQL_MAX_CONNECTIONS", + SqlInfo::SqlMaxCursorNameLength => "SQL_MAX_CURSOR_NAME_LENGTH", + SqlInfo::SqlMaxIndexLength => "SQL_MAX_INDEX_LENGTH", + SqlInfo::SqlDbSchemaNameLength => "SQL_DB_SCHEMA_NAME_LENGTH", + SqlInfo::SqlMaxProcedureNameLength => "SQL_MAX_PROCEDURE_NAME_LENGTH", + SqlInfo::SqlMaxCatalogNameLength => "SQL_MAX_CATALOG_NAME_LENGTH", + SqlInfo::SqlMaxRowSize => "SQL_MAX_ROW_SIZE", + SqlInfo::SqlMaxRowSizeIncludesBlobs => "SQL_MAX_ROW_SIZE_INCLUDES_BLOBS", + SqlInfo::SqlMaxStatementLength => "SQL_MAX_STATEMENT_LENGTH", + SqlInfo::SqlMaxStatements => "SQL_MAX_STATEMENTS", + SqlInfo::SqlMaxTableNameLength => "SQL_MAX_TABLE_NAME_LENGTH", + SqlInfo::SqlMaxTablesInSelect => "SQL_MAX_TABLES_IN_SELECT", + SqlInfo::SqlMaxUsernameLength => "SQL_MAX_USERNAME_LENGTH", + SqlInfo::SqlDefaultTransactionIsolation => "SQL_DEFAULT_TRANSACTION_ISOLATION", + SqlInfo::SqlTransactionsSupported => "SQL_TRANSACTIONS_SUPPORTED", + SqlInfo::SqlSupportedTransactionsIsolationLevels => "SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS", + SqlInfo::SqlDataDefinitionCausesTransactionCommit => "SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT", + SqlInfo::SqlDataDefinitionsInTransactionsIgnored => "SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED", + SqlInfo::SqlSupportedResultSetTypes => "SQL_SUPPORTED_RESULT_SET_TYPES", + SqlInfo::SqlSupportedConcurrenciesForResultSetUnspecified => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_UNSPECIFIED", + SqlInfo::SqlSupportedConcurrenciesForResultSetForwardOnly => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_FORWARD_ONLY", + SqlInfo::SqlSupportedConcurrenciesForResultSetScrollSensitive => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_SENSITIVE", + SqlInfo::SqlSupportedConcurrenciesForResultSetScrollInsensitive => "SQL_SUPPORTED_CONCURRENCIES_FOR_RESULT_SET_SCROLL_INSENSITIVE", + SqlInfo::SqlBatchUpdatesSupported => "SQL_BATCH_UPDATES_SUPPORTED", + SqlInfo::SqlSavepointsSupported => "SQL_SAVEPOINTS_SUPPORTED", + SqlInfo::SqlNamedParametersSupported => "SQL_NAMED_PARAMETERS_SUPPORTED", + SqlInfo::SqlLocatorsUpdateCopy => "SQL_LOCATORS_UPDATE_COPY", + SqlInfo::SqlStoredFunctionsUsingCallSyntaxSupported => "SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedCaseSensitivity { @@ -1011,6 +1102,20 @@ pub enum SqlSupportedCaseSensitivity { SqlCaseSensitivityUppercase = 2, SqlCaseSensitivityLowercase = 3, } +impl SqlSupportedCaseSensitivity { + /// String value of the enum field names used in the ProtoBuf definition. 
+ /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedCaseSensitivity::SqlCaseSensitivityUnknown => "SQL_CASE_SENSITIVITY_UNKNOWN", + SqlSupportedCaseSensitivity::SqlCaseSensitivityCaseInsensitive => "SQL_CASE_SENSITIVITY_CASE_INSENSITIVE", + SqlSupportedCaseSensitivity::SqlCaseSensitivityUppercase => "SQL_CASE_SENSITIVITY_UPPERCASE", + SqlSupportedCaseSensitivity::SqlCaseSensitivityLowercase => "SQL_CASE_SENSITIVITY_LOWERCASE", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlNullOrdering { @@ -1019,6 +1124,20 @@ pub enum SqlNullOrdering { SqlNullsSortedAtStart = 2, SqlNullsSortedAtEnd = 3, } +impl SqlNullOrdering { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlNullOrdering::SqlNullsSortedHigh => "SQL_NULLS_SORTED_HIGH", + SqlNullOrdering::SqlNullsSortedLow => "SQL_NULLS_SORTED_LOW", + SqlNullOrdering::SqlNullsSortedAtStart => "SQL_NULLS_SORTED_AT_START", + SqlNullOrdering::SqlNullsSortedAtEnd => "SQL_NULLS_SORTED_AT_END", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SupportedSqlGrammar { @@ -1026,6 +1145,19 @@ pub enum SupportedSqlGrammar { SqlCoreGrammar = 1, SqlExtendedGrammar = 2, } +impl SupportedSqlGrammar { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SupportedSqlGrammar::SqlMinimumGrammar => "SQL_MINIMUM_GRAMMAR", + SupportedSqlGrammar::SqlCoreGrammar => "SQL_CORE_GRAMMAR", + SupportedSqlGrammar::SqlExtendedGrammar => "SQL_EXTENDED_GRAMMAR", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SupportedAnsi92SqlGrammarLevel { @@ -1033,6 +1165,19 @@ pub enum SupportedAnsi92SqlGrammarLevel { Ansi92IntermediateSql = 1, Ansi92FullSql = 2, } +impl SupportedAnsi92SqlGrammarLevel { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SupportedAnsi92SqlGrammarLevel::Ansi92EntrySql => "ANSI92_ENTRY_SQL", + SupportedAnsi92SqlGrammarLevel::Ansi92IntermediateSql => "ANSI92_INTERMEDIATE_SQL", + SupportedAnsi92SqlGrammarLevel::Ansi92FullSql => "ANSI92_FULL_SQL", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlOuterJoinsSupportLevel { @@ -1040,12 +1185,37 @@ pub enum SqlOuterJoinsSupportLevel { SqlLimitedOuterJoins = 1, SqlFullOuterJoins = 2, } +impl SqlOuterJoinsSupportLevel { + /// String value of the enum field names used in the ProtoBuf definition. 
+ /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlOuterJoinsSupportLevel::SqlJoinsUnsupported => "SQL_JOINS_UNSUPPORTED", + SqlOuterJoinsSupportLevel::SqlLimitedOuterJoins => "SQL_LIMITED_OUTER_JOINS", + SqlOuterJoinsSupportLevel::SqlFullOuterJoins => "SQL_FULL_OUTER_JOINS", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedGroupBy { SqlGroupByUnrelated = 0, SqlGroupByBeyondSelect = 1, } +impl SqlSupportedGroupBy { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedGroupBy::SqlGroupByUnrelated => "SQL_GROUP_BY_UNRELATED", + SqlSupportedGroupBy::SqlGroupByBeyondSelect => "SQL_GROUP_BY_BEYOND_SELECT", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedElementActions { @@ -1053,12 +1223,37 @@ pub enum SqlSupportedElementActions { SqlElementInIndexDefinitions = 1, SqlElementInPrivilegeDefinitions = 2, } +impl SqlSupportedElementActions { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedElementActions::SqlElementInProcedureCalls => "SQL_ELEMENT_IN_PROCEDURE_CALLS", + SqlSupportedElementActions::SqlElementInIndexDefinitions => "SQL_ELEMENT_IN_INDEX_DEFINITIONS", + SqlSupportedElementActions::SqlElementInPrivilegeDefinitions => "SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedPositionedCommands { SqlPositionedDelete = 0, SqlPositionedUpdate = 1, } +impl SqlSupportedPositionedCommands { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedPositionedCommands::SqlPositionedDelete => "SQL_POSITIONED_DELETE", + SqlSupportedPositionedCommands::SqlPositionedUpdate => "SQL_POSITIONED_UPDATE", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedSubqueries { @@ -1067,12 +1262,38 @@ pub enum SqlSupportedSubqueries { SqlSubqueriesInIns = 2, SqlSubqueriesInQuantifieds = 3, } +impl SqlSupportedSubqueries { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
+ pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedSubqueries::SqlSubqueriesInComparisons => "SQL_SUBQUERIES_IN_COMPARISONS", + SqlSupportedSubqueries::SqlSubqueriesInExists => "SQL_SUBQUERIES_IN_EXISTS", + SqlSupportedSubqueries::SqlSubqueriesInIns => "SQL_SUBQUERIES_IN_INS", + SqlSupportedSubqueries::SqlSubqueriesInQuantifieds => "SQL_SUBQUERIES_IN_QUANTIFIEDS", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedUnions { SqlUnion = 0, SqlUnionAll = 1, } +impl SqlSupportedUnions { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedUnions::SqlUnion => "SQL_UNION", + SqlSupportedUnions::SqlUnionAll => "SQL_UNION_ALL", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlTransactionIsolationLevel { @@ -1082,6 +1303,21 @@ pub enum SqlTransactionIsolationLevel { SqlTransactionRepeatableRead = 3, SqlTransactionSerializable = 4, } +impl SqlTransactionIsolationLevel { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlTransactionIsolationLevel::SqlTransactionNone => "SQL_TRANSACTION_NONE", + SqlTransactionIsolationLevel::SqlTransactionReadUncommitted => "SQL_TRANSACTION_READ_UNCOMMITTED", + SqlTransactionIsolationLevel::SqlTransactionReadCommitted => "SQL_TRANSACTION_READ_COMMITTED", + SqlTransactionIsolationLevel::SqlTransactionRepeatableRead => "SQL_TRANSACTION_REPEATABLE_READ", + SqlTransactionIsolationLevel::SqlTransactionSerializable => "SQL_TRANSACTION_SERIALIZABLE", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedTransactions { @@ -1089,6 +1325,19 @@ pub enum SqlSupportedTransactions { SqlDataDefinitionTransactions = 1, SqlDataManipulationTransactions = 2, } +impl SqlSupportedTransactions { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedTransactions::SqlTransactionUnspecified => "SQL_TRANSACTION_UNSPECIFIED", + SqlSupportedTransactions::SqlDataDefinitionTransactions => "SQL_DATA_DEFINITION_TRANSACTIONS", + SqlSupportedTransactions::SqlDataManipulationTransactions => "SQL_DATA_MANIPULATION_TRANSACTIONS", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedResultSetType { @@ -1097,6 +1346,20 @@ pub enum SqlSupportedResultSetType { SqlResultSetTypeScrollInsensitive = 2, SqlResultSetTypeScrollSensitive = 3, } +impl SqlSupportedResultSetType { + /// String value of the enum field names used in the ProtoBuf definition. 
+ /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedResultSetType::SqlResultSetTypeUnspecified => "SQL_RESULT_SET_TYPE_UNSPECIFIED", + SqlSupportedResultSetType::SqlResultSetTypeForwardOnly => "SQL_RESULT_SET_TYPE_FORWARD_ONLY", + SqlSupportedResultSetType::SqlResultSetTypeScrollInsensitive => "SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE", + SqlSupportedResultSetType::SqlResultSetTypeScrollSensitive => "SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportedResultSetConcurrency { @@ -1104,6 +1367,19 @@ pub enum SqlSupportedResultSetConcurrency { SqlResultSetConcurrencyReadOnly = 1, SqlResultSetConcurrencyUpdatable = 2, } +impl SqlSupportedResultSetConcurrency { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyUnspecified => "SQL_RESULT_SET_CONCURRENCY_UNSPECIFIED", + SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyReadOnly => "SQL_RESULT_SET_CONCURRENCY_READ_ONLY", + SqlSupportedResultSetConcurrency::SqlResultSetConcurrencyUpdatable => "SQL_RESULT_SET_CONCURRENCY_UPDATABLE", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum SqlSupportsConvert { @@ -1128,6 +1404,36 @@ pub enum SqlSupportsConvert { SqlConvertVarbinary = 18, SqlConvertVarchar = 19, } +impl SqlSupportsConvert { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. 
+ pub fn as_str_name(&self) -> &'static str { + match self { + SqlSupportsConvert::SqlConvertBigint => "SQL_CONVERT_BIGINT", + SqlSupportsConvert::SqlConvertBinary => "SQL_CONVERT_BINARY", + SqlSupportsConvert::SqlConvertBit => "SQL_CONVERT_BIT", + SqlSupportsConvert::SqlConvertChar => "SQL_CONVERT_CHAR", + SqlSupportsConvert::SqlConvertDate => "SQL_CONVERT_DATE", + SqlSupportsConvert::SqlConvertDecimal => "SQL_CONVERT_DECIMAL", + SqlSupportsConvert::SqlConvertFloat => "SQL_CONVERT_FLOAT", + SqlSupportsConvert::SqlConvertInteger => "SQL_CONVERT_INTEGER", + SqlSupportsConvert::SqlConvertIntervalDayTime => "SQL_CONVERT_INTERVAL_DAY_TIME", + SqlSupportsConvert::SqlConvertIntervalYearMonth => "SQL_CONVERT_INTERVAL_YEAR_MONTH", + SqlSupportsConvert::SqlConvertLongvarbinary => "SQL_CONVERT_LONGVARBINARY", + SqlSupportsConvert::SqlConvertLongvarchar => "SQL_CONVERT_LONGVARCHAR", + SqlSupportsConvert::SqlConvertNumeric => "SQL_CONVERT_NUMERIC", + SqlSupportsConvert::SqlConvertReal => "SQL_CONVERT_REAL", + SqlSupportsConvert::SqlConvertSmallint => "SQL_CONVERT_SMALLINT", + SqlSupportsConvert::SqlConvertTime => "SQL_CONVERT_TIME", + SqlSupportsConvert::SqlConvertTimestamp => "SQL_CONVERT_TIMESTAMP", + SqlSupportsConvert::SqlConvertTinyint => "SQL_CONVERT_TINYINT", + SqlSupportsConvert::SqlConvertVarbinary => "SQL_CONVERT_VARBINARY", + SqlSupportsConvert::SqlConvertVarchar => "SQL_CONVERT_VARCHAR", + } + } +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum UpdateDeleteRules { @@ -1137,3 +1443,18 @@ pub enum UpdateDeleteRules { NoAction = 3, SetDefault = 4, } +impl UpdateDeleteRules { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + UpdateDeleteRules::Cascade => "CASCADE", + UpdateDeleteRules::Restrict => "RESTRICT", + UpdateDeleteRules::SetNull => "SET_NULL", + UpdateDeleteRules::NoAction => "NO_ACTION", + UpdateDeleteRules::SetDefault => "SET_DEFAULT", + } + } +} diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 87e282b103b7..f3208d376497 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -47,81 +47,107 @@ pub trait FlightSqlService: /// When impl FlightSqlService, you can always set FlightService to Self type FlightService: FlightService; + /// Accept authentication and return a token + /// + async fn do_handshake( + &self, + _request: Request>, + ) -> Result< + Response> + Send>>>, + Status, + > { + Err(Status::unimplemented( + "Handshake has no default implementation", + )) + } + + /// Implementors may override to handle additional calls to do_get() + async fn do_get_fallback( + &self, + _request: Request, + message: prost_types::Any, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented(format!( + "do_get: The defined request is invalid: {}", + message.type_url + ))) + } + /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( &self, query: CommandStatementQuery, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for executing an already created prepared statement. 
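The `as_str_name` helpers added for each generated enum above return the exact ProtoBuf field names, which is handy for logging or serializing SqlInfo-style metadata. A small usage sketch, assuming these generated enums are re-exported from `arrow_flight::sql` like the other generated types (adjust the import path if the crate exposes them elsewhere):

```rust
use arrow_flight::sql::SqlInfo;

fn main() {
    // The returned strings match the ProtoBuf enum field names verbatim,
    // so they are stable as long as the .proto definition does not change.
    assert_eq!(SqlInfo::SqlDdlCatalog.as_str_name(), "SQL_DDL_CATALOG");
    assert_eq!(
        SqlInfo::FlightSqlServerName.as_str_name(),
        "FLIGHT_SQL_SERVER_NAME"
    );
}
```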
async fn get_flight_info_prepared_statement( &self, query: CommandPreparedStatementQuery, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for listing catalogs. async fn get_flight_info_catalogs( &self, query: CommandGetCatalogs, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for listing schemas. async fn get_flight_info_schemas( &self, query: CommandGetDbSchemas, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for listing tables. async fn get_flight_info_tables( &self, query: CommandGetTables, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about the table types. async fn get_flight_info_table_types( &self, query: CommandGetTableTypes, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo for retrieving other information (See SqlInfo). async fn get_flight_info_sql_info( &self, query: CommandGetSqlInfo, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about primary and foreign keys. async fn get_flight_info_primary_keys( &self, query: CommandGetPrimaryKeys, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about exported keys. async fn get_flight_info_exported_keys( &self, query: CommandGetExportedKeys, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about imported keys. async fn get_flight_info_imported_keys( &self, query: CommandGetImportedKeys, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; /// Get a FlightInfo to extract information about cross reference. async fn get_flight_info_cross_reference( &self, query: CommandGetCrossReference, - request: FlightDescriptor, + request: Request, ) -> Result, Status>; // do_get @@ -130,66 +156,77 @@ pub trait FlightSqlService: async fn do_get_statement( &self, ticket: TicketStatementQuery, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the prepared statement query results. async fn do_get_prepared_statement( &self, query: CommandPreparedStatementQuery, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of catalogs. async fn do_get_catalogs( &self, query: CommandGetCatalogs, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of schemas. async fn do_get_schemas( &self, query: CommandGetDbSchemas, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of tables. async fn do_get_tables( &self, query: CommandGetTables, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the table types. async fn do_get_table_types( &self, query: CommandGetTableTypes, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of SqlInfo results. async fn do_get_sql_info( &self, query: CommandGetSqlInfo, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the primary and foreign keys. 
async fn do_get_primary_keys( &self, query: CommandGetPrimaryKeys, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the exported keys. async fn do_get_exported_keys( &self, query: CommandGetExportedKeys, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the imported keys. async fn do_get_imported_keys( &self, query: CommandGetImportedKeys, + request: Request, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the cross reference. async fn do_get_cross_reference( &self, query: CommandGetCrossReference, + request: Request, ) -> Result::DoGetStream>, Status>; // do_put @@ -198,20 +235,21 @@ pub trait FlightSqlService: async fn do_put_statement_update( &self, ticket: CommandStatementUpdate, + request: Request>, ) -> Result; /// Bind parameters to given prepared statement. async fn do_put_prepared_statement_query( &self, query: CommandPreparedStatementQuery, - request: Streaming, + request: Request>, ) -> Result::DoPutStream>, Status>; /// Execute an update SQL prepared statement. async fn do_put_prepared_statement_update( &self, query: CommandPreparedStatementUpdate, - request: Streaming, + request: Request>, ) -> Result; // do_action @@ -220,12 +258,14 @@ pub trait FlightSqlService: async fn do_action_create_prepared_statement( &self, query: ActionCreatePreparedStatementRequest, + request: Request, ) -> Result; /// Close a prepared statement. async fn do_action_close_prepared_statement( &self, query: ActionClosePreparedStatementRequest, + request: Request, ); /// Register a new SqlInfo result, making it available when calling GetSqlInfo. @@ -256,9 +296,10 @@ where async fn handshake( &self, - _request: Request>, + request: Request>, ) -> Result, Status> { - Err(Status::unimplemented("Not yet implemented")) + let res = self.do_handshake(request).await?; + Ok(res) } async fn list_flights( @@ -272,124 +313,92 @@ where &self, request: Request, ) -> Result, Status> { - let request = request.into_inner(); - let any: prost_types::Any = - prost::Message::decode(&*request.cmd).map_err(decode_error_to_status)?; + let message: prost_types::Any = + Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; - if any.is::() { - return self - .get_flight_info_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_statement(token, request).await; } - if any.is::() { + if message.is::() { + let handle = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); return self - .get_flight_info_prepared_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) + .get_flight_info_prepared_statement(handle, request) .await; } - if any.is::() { - return self - .get_flight_info_catalogs( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_catalogs(token, request).await; } - if any.is::() { - return self - .get_flight_info_schemas( - any.unpack() - .map_err(arrow_error_to_status)? 
- .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_schemas(token, request).await; } - if any.is::() { - return self - .get_flight_info_tables( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_tables(token, request).await; } - if any.is::() { - return self - .get_flight_info_table_types( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_table_types(token, request).await; } - if any.is::() { - return self - .get_flight_info_sql_info( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_sql_info(token, request).await; } - if any.is::() { - return self - .get_flight_info_primary_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_primary_keys(token, request).await; } - if any.is::() { - return self - .get_flight_info_exported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_exported_keys(token, request).await; } - if any.is::() { - return self - .get_flight_info_imported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_imported_keys(token, request).await; } - if any.is::() { - return self - .get_flight_info_cross_reference( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_cross_reference(token, request).await; } Err(Status::unimplemented(format!( - "get_flight_info: The defined request is invalid: {:?}", - String::from_utf8(any.encode_to_vec()).unwrap() + "get_flight_info: The defined request is invalid: {}", + message.type_url ))) } @@ -404,168 +413,97 @@ where &self, request: Request, ) -> Result, Status> { - let request = request.into_inner(); - let any: prost_types::Any = - prost::Message::decode(&*request.ticket).map_err(decode_error_to_status)?; + let msg: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket) + .map_err(decode_error_to_status)?; - if any.is::() { - return self - .do_get_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + fn unpack(msg: prost_types::Any) -> Result { + msg.unpack() + .map_err(arrow_error_to_status)? 
+ .ok_or_else(|| Status::internal("Expected a command, but found none.")) } - if any.is::() { - return self - .do_get_prepared_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + + if msg.is::() { + return self.do_get_statement(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_catalogs( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_prepared_statement(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_schemas( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_catalogs(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_tables( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_schemas(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_table_types( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_tables(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_sql_info( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_table_types(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_primary_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_sql_info(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_exported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_primary_keys(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_imported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_exported_keys(unpack(msg)?, request).await; } - if any.is::() { - return self - .do_get_cross_reference( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + if msg.is::() { + return self.do_get_imported_keys(unpack(msg)?, request).await; + } + if msg.is::() { + return self.do_get_cross_reference(unpack(msg)?, request).await; } - Err(Status::unimplemented(format!( - "do_get: The defined request is invalid: {:?}", - String::from_utf8(request.ticket).unwrap() - ))) + self.do_get_fallback(request, msg).await } async fn do_put( &self, - request: Request>, + mut request: Request>, ) -> Result, Status> { - let mut request = request.into_inner(); - let cmd = request.message().await?.unwrap(); - let any: prost_types::Any = + let cmd = request.get_mut().message().await?.unwrap(); + let message: prost_types::Any = prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; - if any.is::() { - let record_count = self - .do_put_statement_update( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await?; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? 
+ .expect("unreachable"); + let record_count = self.do_put_statement_update(token, request).await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { - app_metadata: result.as_any().encode_to_vec(), + app_metadata: result.encode_to_vec(), })]); return Ok(Response::new(Box::pin(output))); } - if any.is::() { - return self - .do_put_prepared_statement_query( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + if message.is::() { + let token = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_put_prepared_statement_query(token, request).await; } - if any.is::() { + if message.is::() { + let handle = message + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); let record_count = self - .do_put_prepared_statement_update( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) + .do_put_prepared_statement_update(handle, request) .await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { - app_metadata: result.as_any().encode_to_vec(), + app_metadata: result.encode_to_vec(), })]); return Ok(Response::new(Box::pin(output))); } Err(Status::invalid_argument(format!( - "do_put: The defined request is invalid: {:?}", - String::from_utf8(any.encode_to_vec()).unwrap() + "do_put: The defined request is invalid: {}", + message.type_url ))) } @@ -599,11 +537,9 @@ where &self, request: Request, ) -> Result, Status> { - let request = request.into_inner(); - - if request.r#type == CREATE_PREPARED_STATEMENT { - let any: prost_types::Any = - prost::Message::decode(&*request.body).map_err(decode_error_to_status)?; + if request.get_ref().r#type == CREATE_PREPARED_STATEMENT { + let any: prost_types::Any = Message::decode(&*request.get_ref().body) + .map_err(decode_error_to_status)?; let cmd: ActionCreatePreparedStatementRequest = any .unpack() @@ -613,15 +549,17 @@ where "Unable to unpack ActionCreatePreparedStatementRequest.", ) })?; - let stmt = self.do_action_create_prepared_statement(cmd).await?; + let stmt = self + .do_action_create_prepared_statement(cmd, request) + .await?; let output = futures::stream::iter(vec![Ok(super::super::gen::Result { body: stmt.as_any().encode_to_vec(), })]); return Ok(Response::new(Box::pin(output))); } - if request.r#type == CLOSE_PREPARED_STATEMENT { - let any: prost_types::Any = - prost::Message::decode(&*request.body).map_err(decode_error_to_status)?; + if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT { + let any: prost_types::Any = Message::decode(&*request.get_ref().body) + .map_err(decode_error_to_status)?; let cmd: ActionClosePreparedStatementRequest = any .unpack() @@ -631,13 +569,13 @@ where "Unable to unpack ActionClosePreparedStatementRequest.", ) })?; - self.do_action_close_prepared_statement(cmd).await; + self.do_action_close_prepared_statement(cmd, request).await; return Ok(Response::new(Box::pin(futures::stream::empty()))); } Err(Status::invalid_argument(format!( "do_action: The defined request is invalid: {:?}", - request.r#type + request.get_ref().r#type ))) } diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index dda3fc7fe3db..21a5a8572246 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -21,6 +21,7 @@ use crate::{FlightData, IpcMessage, SchemaAsIpc, SchemaResult}; use std::collections::HashMap; use 
arrow::array::ArrayRef; +use arrow::buffer::Buffer; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::{ArrowError, Result}; use arrow::ipc::{reader, writer, writer::IpcWriteOptions}; @@ -66,7 +67,7 @@ pub fn flight_data_to_arrow_batch( }) .map(|batch| { reader::read_record_batch( - &data.data_body, + &Buffer::from(&data.data_body), batch, schema, dictionaries_by_id, diff --git a/arrow-pyarrow-integration-testing/Cargo.toml b/arrow-pyarrow-integration-testing/Cargo.toml index 60b06efb95a8..9aef5a0570a3 100644 --- a/arrow-pyarrow-integration-testing/Cargo.toml +++ b/arrow-pyarrow-integration-testing/Cargo.toml @@ -18,22 +18,22 @@ [package] name = "arrow-pyarrow-integration-testing" description = "" -version = "18.0.0" +version = "22.0.0" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] license = "Apache-2.0" keywords = [ "arrow" ] edition = "2021" -rust-version = "1.57" +rust-version = "1.62" [lib] name = "arrow_pyarrow_integration_testing" crate-type = ["cdylib"] [dependencies] -arrow = { path = "../arrow", version = "18.0.0", features = ["pyarrow"] } -pyo3 = { version = "0.16", features = ["extension-module"] } +arrow = { path = "../arrow", version = "22.0.0", features = ["pyarrow"] } +pyo3 = { version = "0.17", features = ["extension-module"] } [package.metadata.maturin] requires-dist = ["pyarrow>=1"] diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 7b3d4c64ad71..cedd48e4d313 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -17,54 +17,65 @@ [package] name = "arrow" -version = "18.0.0" +version = "22.0.0" description = "Rust implementation of Apache Arrow" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] license = "Apache-2.0" -keywords = [ "arrow" ] +keywords = ["arrow"] include = [ "benches/*.rs", "src/**/*.rs", "Cargo.toml", ] edition = "2021" -rust-version = "1.57" +rust-version = "1.62" [lib] name = "arrow" path = "src/lib.rs" bench = false +[target.'cfg(target_arch = "wasm32")'.dependencies] +ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } + [dependencies] -ahash = { version = "0.7", default-features = false } -serde = { version = "1.0", default-features = false } -serde_derive = { version = "1.0", default-features = false } -serde_json = { version = "1.0", default-features = false, features = ["preserve_order"] } +serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } +serde_json = { version = "1.0", default-features = false, features = ["std"], optional = true } indexmap = { version = "1.9", default-features = false, features = ["std"] } -rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } num = { version = "0.4", default-features = false, features = ["std"] } half = { version = "2.0", default-features = false } hashbrown = { version = "0.12", default-features = false } -csv_crate = { version = "1.1", default-features = false, optional = true, package="csv" } +csv_crate = { version = "1.1", default-features = false, optional = true, package = "csv" } regex = { version = "1.5.6", default-features = false, features = ["std", "unicode"] } 
+regex-syntax = { version = "0.6.27", default-features = false, features = ["unicode"] } lazy_static = { version = "1.4", default-features = false } +lz4 = { version = "1.23", default-features = false, optional = true } packed_simd = { version = "0.3", default-features = false, optional = true, package = "packed_simd_2" } chrono = { version = "0.4", default-features = false, features = ["clock"] } -chrono-tz = {version = "0.6", default-features = false, optional = true} +chrono-tz = { version = "0.6", default-features = false, optional = true } flatbuffers = { version = "2.1.2", default-features = false, features = ["thiserror"], optional = true } -hex = { version = "0.4", default-features = false, features = ["std"] } comfy-table = { version = "6.0", optional = true, default-features = false } -pyo3 = { version = "0.16", default-features = false, optional = true } +pyo3 = { version = "0.17", default-features = false, optional = true } lexical-core = { version = "^0.8", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } multiversion = { version = "0.6.1", default-features = false } bitflags = { version = "1.2.1", default-features = false } +zstd = { version = "0.11.1", default-features = false, optional = true } + +[package.metadata.docs.rs] +features = ["prettyprint", "ipc_compression", "dyn_cmp_dict", "ffi", "pyarrow"] [features] -default = ["csv", "ipc", "test_utils"] +default = ["csv", "ipc", "json"] +ipc_compression = ["ipc", "zstd", "lz4"] csv = ["csv_crate"] ipc = ["flatbuffers"] +json = ["serde", "serde_json"] simd = ["packed_simd"] prettyprint = ["comfy-table"] # The test utils feature enables code used in benchmarks and tests but @@ -72,20 +83,40 @@ prettyprint = ["comfy-table"] # an optional dependency for supporting compile to wasm32-unknown-unknown # target without assuming an environment containing JavaScript. 
test_utils = ["rand"] -pyarrow = ["pyo3"] +pyarrow = ["pyo3", "ffi"] # force_validate runs full data validation for all arrays that are created # this is not enabled by default as it is too computationally expensive # but is run as part of our CI checks force_validate = [] +# Enable ffi support +ffi = [] +# Enable dyn-comparison of dictionary arrays with other arrays +# Note: this does not impact comparison against scalars +dyn_cmp_dict = [] [dev-dependencies] -rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } criterion = { version = "0.3", default-features = false } flate2 = { version = "1", default-features = false, features = ["rust_backend"] } tempfile = { version = "3", default-features = false } [build-dependencies] +[[example]] +name = "dynamic_types" +required-features = ["prettyprint"] +path = "./examples/dynamic_types.rs" + +[[example]] +name = "read_csv" +required-features = ["prettyprint", "csv"] +path = "./examples/read_csv.rs" + +[[example]] +name = "read_csv_infer_schema" +required-features = ["prettyprint", "csv"] +path = "./examples/read_csv_infer_schema.rs" + [[bench]] name = "aggregate_kernels" harness = false @@ -126,7 +157,7 @@ required-features = ["test_utils"] [[bench]] name = "comparison_kernels" harness = false -required-features = ["test_utils"] +required-features = ["test_utils", "dyn_cmp_dict"] [[bench]] name = "filter_kernels" @@ -159,10 +190,12 @@ required-features = ["test_utils"] [[bench]] name = "csv_writer" harness = false +required-features = ["csv"] [[bench]] name = "json_reader" harness = false +required-features = ["json"] [[bench]] name = "equal" @@ -196,3 +229,7 @@ required-features = ["test_utils"] [[bench]] name = "array_data_validate" harness = false + +[[bench]] +name = "decimal_validate" +harness = false diff --git a/arrow/README.md b/arrow/README.md index 7507ff11cd19..7a95df0f2252 100644 --- a/arrow/README.md +++ b/arrow/README.md @@ -22,7 +22,10 @@ [![crates.io](https://img.shields.io/crates/v/arrow.svg)](https://crates.io/crates/arrow) [![docs.rs](https://img.shields.io/docsrs/arrow.svg)](https://docs.rs/arrow/latest/arrow/) -This crate contains the official Native Rust implementation of [Apache Arrow][arrow] in memory format, governed by the Apache Software Foundation. Additional details can be found on [crates.io](https://crates.io/crates/arrow), [docs.rs](https://docs.rs/arrow/latest/arrow/) and [examples](https://github.com/apache/arrow-rs/tree/master/arrow/examples). +This crate contains the official Native Rust implementation of [Apache Arrow][arrow] in memory format, governed by the Apache Software Foundation. + +The [crate documentation](https://docs.rs/arrow/latest/arrow/) contains examples and full API. +There are several [examples](https://github.com/apache/arrow-rs/tree/master/arrow/examples) to start from as well. ## Rust Version Compatibility @@ -32,48 +35,57 @@ This crate is tested with the latest stable version of Rust. We do not currently The arrow crate follows the [SemVer standard](https://doc.rust-lang.org/cargo/reference/semver.html) defined by Cargo and works well within the Rust crate ecosystem. -However, for historical reasons, this crate uses versions with major numbers greater than `0.x` (e.g. `18.0.0`), unlike many other crates in the Rust ecosystem which spend extended time releasing versions `0.x` to signal planned ongoing API changes. 
Minor arrow releases contain only compatible changes, while major releases may contain breaking API changes. +However, for historical reasons, this crate uses versions with major numbers greater than `0.x` (e.g. `22.0.0`), unlike many other crates in the Rust ecosystem which spend extended time releasing versions `0.x` to signal planned ongoing API changes. Minor arrow releases contain only compatible changes, while major releases may contain breaking API changes. -## Features +## Feature Flags -The arrow crate provides the following features which may be enabled: +The `arrow` crate provides the following features which may be enabled in your `Cargo.toml`: - `csv` (default) - support for reading and writing Arrow arrays to/from csv files -- `ipc` (default) - support for the [arrow-flight](https://crates.io/crates/arrow-flight) IPC and wire format +- `json` (default) - support for reading and writing Arrow array to/from json files +- `ipc` (default) - support for reading [Arrow IPC Format](https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc), also used as the wire protocol in [arrow-flight](https://crates.io/crates/arrow-flight) +- `ipc_compression` - Enables reading and writing compressed IPC streams (also enables `ipc`) - `prettyprint` - support for formatting record batches as textual columns - `js` - support for building arrow for WebAssembly / JavaScript -- `simd` - (_Requires Nightly Rust_) alternate optimized +- `simd` - (_Requires Nightly Rust_) Use alternate hand optimized implementations of some [compute](https://github.com/apache/arrow-rs/tree/master/arrow/src/compute/kernels) - kernels using explicit SIMD instructions available through [packed_simd_2](https://docs.rs/packed_simd_2/latest/packed_simd_2/). + kernels using explicit SIMD instructions via [packed_simd_2](https://docs.rs/packed_simd_2/latest/packed_simd_2/). - `chrono-tz` - support of parsing timezone using [chrono-tz](https://docs.rs/chrono-tz/0.6.0/chrono_tz/) +- `ffi` - bindings for the Arrow C [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) +- `pyarrow` - bindings for pyo3 to call arrow-rs from python +- `dyn_cmp_dict` - enables comparison of dictionary arrays within dyn comparison kernels + +## Arrow Feature Status + +The [Apache Arrow Status](https://arrow.apache.org/docs/status.html) page lists which features of Arrow this crate supports. ## Safety -Arrow seeks to uphold the Rust Soundness Pledge as articulated eloquently [here](https://raphlinus.github.io/rust/2020/01/18/soundness-pledge.html). Specifically: +Arrow seeks to uphold the Rust Soundness Pledge as articulated eloquently [here](https://raphlinus.github.io/rust/22.0.01/18/soundness-pledge.html). Specifically: > The intent of this crate is to be free of soundness bugs. The developers will do their best to avoid them, and welcome help in analyzing and fixing them Where soundness in turn is defined as: -> Code is unable to trigger undefined behaviour using safe APIs +> Code is unable to trigger undefined behavior using safe APIs -One way to ensure this would be to not use `unsafe`, however, as described in the opening chapter of the [Rustonomicon](https://doc.rust-lang.org/nomicon/meet-safe-and-unsafe.html) this is not a requirement, and flexibility in this regard is actually one of Rust's great strengths. 
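To make the validated-versus-unchecked trade-off discussed in this Safety section concrete, here is a minimal sketch (not part of the patch) built only from APIs that appear elsewhere in this diff: `ArrayData::builder(...).build()` runs the validation logic, while the `unsafe` `build_unchecked()` lets a caller opt out of it; the comment about `force_validate` restates what this README says about that feature.

```rust
use arrow::array::{ArrayData, Int32Array};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

fn main() {
    let values = Buffer::from_slice_ref(&[1i32, 2, 3]);

    // Safe path: `build()` validates buffer lengths and contents against the
    // declared `DataType`, returning an error rather than risking UB.
    let checked = ArrayData::builder(DataType::Int32)
        .len(3)
        .add_buffer(values.clone())
        .build()
        .expect("validation failed");

    // Unsafe path: the caller vouches for the invariants and skips the
    // potentially expensive checks (per this README, the optional
    // `force_validate` feature re-enables full validation in test/CI builds).
    let unchecked = unsafe {
        ArrayData::builder(DataType::Int32)
            .len(3)
            .add_buffer(values)
            .build_unchecked()
    };

    assert_eq!(
        Int32Array::from(checked).values(),
        Int32Array::from(unchecked).values()
    );
}
```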
+One way to ensure this would be to not use `unsafe`, however, as described in the opening chapter of the [Rustonomicon](https://doc.rust-lang.org/nomicon/meet-safe-and-unsafe.html) this is not a requirement, and flexibility in this regard is one of Rust's great strengths. In particular there are a number of scenarios where `unsafe` is largely unavoidable: -* Invariants that cannot be statically verified by the compiler and unlock non-trivial performance wins, e.g. values in a StringArray are UTF-8, [TrustedLen](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html) iterators, etc... -* FFI -* SIMD +- Invariants that cannot be statically verified by the compiler and unlock non-trivial performance wins, e.g. values in a StringArray are UTF-8, [TrustedLen](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html) iterators, etc... +- FFI +- SIMD -Additionally, this crate exposes a number of `unsafe` APIs, allowing downstream crates to explicitly opt-out of potentially expensive invariant checking where appropriate. +Additionally, this crate exposes a number of `unsafe` APIs, allowing downstream crates to explicitly opt-out of potentially expensive invariant checking where appropriate. We have a number of strategies to help reduce this risk: -* Provide strongly-typed `Array` and `ArrayBuilder` APIs to safely and efficiently interact with arrays -* Extensive validation logic to safely construct `ArrayData` from untrusted sources -* All commits are verified using [MIRI](https://github.com/rust-lang/miri) to detect undefined behaviour -* We provide a `force_validate` feature that enables additional validation checks for use in test/debug builds -* There is ongoing work to reduce and better document the use of unsafe, and we welcome contributions in this space +- Provide strongly-typed `Array` and `ArrayBuilder` APIs to safely and efficiently interact with arrays +- Extensive validation logic to safely construct `ArrayData` from untrusted sources +- All commits are verified using [MIRI](https://github.com/rust-lang/miri) to detect undefined behaviour +- Use a `force_validate` feature that enables additional validation checks for use in test/debug builds +- There is ongoing work to reduce and better document the use of unsafe, and we welcome contributions in this space ## Building for WASM @@ -101,16 +113,38 @@ cargo run --example read_csv [arrow]: https://arrow.apache.org/ +## Performance Tips -## Performance +Arrow aims to be as fast as possible out of the box, whilst not compromising on safety. However, +it relies heavily on LLVM auto-vectorisation to achieve this. Unfortunately the LLVM defaults, +particularly for x86_64, favour portability over performance, and LLVM will consequently avoid +using more recent instructions that would result in errors on older CPUs. -Most of the compute kernels benefit a lot from being optimized for a specific CPU target. -This is especially so on x86-64 since without specifying a target the compiler can only assume support for SSE2 vector instructions. -One of the following values as `-Ctarget-cpu=value` in `RUSTFLAGS` can therefore improve performance significantly: +To address this it is recommended that you override the LLVM defaults either +by setting the `RUSTFLAGS` environment variable, or by setting `rustflags` in your +[Cargo configuration](https://doc.rust-lang.org/cargo/reference/config.html) - - `native`: Target the exact features of the cpu that the build is running on. 
- This should give the best performance when building and running locally, but should be used carefully for example when building in a CI pipeline or when shipping pre-compiled software. - - `x86-64-v3`: Includes AVX2 support and is close to the intel `haswell` architecture released in 2013 and should be supported by any recent Intel or Amd cpu. - - `x86-64-v4`: Includes AVX512 support available on intel `skylake` server and `icelake`/`tigerlake`/`rocketlake` laptop and desktop processors. +Enable all features supported by the current CPU -These flags should be used in addition to the `simd` feature, since they will also affect the code generated by the simd library. \ No newline at end of file +```ignore +RUSTFLAGS="-C target-cpu=native" +``` + +Enable all features supported by the current CPU, and enable full use of AVX512 + +```ignore +RUSTFLAGS="-C target-cpu=native -C target-feature=-prefer-256-bit" +``` + +Enable all features supported by CPUs more recent than haswell (2013) + +```ignore +RUSTFLAGS="-C target-cpu=haswell" +``` + +For a full list of features and target CPUs use + +```shell +$ rustc --print target-cpus +$ rustc --print target-features +``` diff --git a/arrow/benches/array_data_validate.rs b/arrow/benches/array_data_validate.rs index c46252bececd..3cd13c09c58a 100644 --- a/arrow/benches/array_data_validate.rs +++ b/arrow/benches/array_data_validate.rs @@ -37,11 +37,22 @@ fn create_binary_array_data(length: i32) -> ArrayData { .unwrap() } -fn array_slice_benchmark(c: &mut Criterion) { +fn validate_utf8_array(arr: &StringArray) { + arr.data().validate_values().unwrap(); +} + +fn validate_benchmark(c: &mut Criterion) { + //Binary Array c.bench_function("validate_binary_array_data 20000", |b| { b.iter(|| create_binary_array_data(20000)) }); + + //Utf8 Array + let str_arr = StringArray::from(vec!["test"; 20000]); + c.bench_function("validate_utf8_array_data 20000", |b| { + b.iter(|| validate_utf8_array(&str_arr)) + }); } -criterion_group!(benches, array_slice_benchmark); +criterion_group!(benches, validate_benchmark); criterion_main!(benches); diff --git a/arrow/benches/array_from_vec.rs b/arrow/benches/array_from_vec.rs index 3f82beb6f534..59bef65a18c6 100644 --- a/arrow/benches/array_from_vec.rs +++ b/arrow/benches/array_from_vec.rs @@ -17,11 +17,15 @@ #[macro_use] extern crate criterion; + use criterion::Criterion; extern crate arrow; use arrow::array::*; +use arrow::util::decimal::Decimal256; +use num::BigInt; +use rand::Rng; use std::{convert::TryFrom, sync::Arc}; fn array_from_vec(n: usize) { @@ -72,6 +76,58 @@ fn struct_array_from_vec( ); } +fn decimal128_array_from_vec(array: &[Option]) { + criterion::black_box( + array + .iter() + .copied() + .collect::() + .with_precision_and_scale(34, 2) + .unwrap(), + ); +} + +fn decimal256_array_from_vec(array: &[Option]) { + criterion::black_box( + array + .iter() + .copied() + .collect::() + .with_precision_and_scale(70, 2) + .unwrap(), + ); +} + +fn decimal_benchmark(c: &mut Criterion) { + // bench decimal128 array + // create option array + let size: usize = 1 << 15; + let mut rng = rand::thread_rng(); + let mut array = vec![]; + for _ in 0..size { + array.push(Some(rng.gen_range::(0..9999999999))); + } + c.bench_function("decimal128_array_from_vec 32768", |b| { + b.iter(|| decimal128_array_from_vec(array.as_slice())) + }); + + // bench decimal256array + // create option> array + let size = 1 << 10; + let mut array = vec![]; + let mut rng = rand::thread_rng(); + for _ in 0..size { + let decimal = + 
Decimal256::from(BigInt::from(rng.gen_range::(0..9999999999999))); + array.push(Some(decimal)); + } + + // bench decimal256 array + c.bench_function("decimal256_array_from_vec 32768", |b| { + b.iter(|| decimal256_array_from_vec(array.as_slice())) + }); +} + fn criterion_benchmark(c: &mut Criterion) { c.bench_function("array_from_vec 128", |b| b.iter(|| array_from_vec(128))); c.bench_function("array_from_vec 256", |b| b.iter(|| array_from_vec(256))); @@ -108,5 +164,5 @@ fn criterion_benchmark(c: &mut Criterion) { }); } -criterion_group!(benches, criterion_benchmark); +criterion_group!(benches, criterion_benchmark, decimal_benchmark); criterion_main!(benches); diff --git a/arrow/benches/builder.rs b/arrow/benches/builder.rs index fd9f319e3976..c2ebcb3daa50 100644 --- a/arrow/benches/builder.rs +++ b/arrow/benches/builder.rs @@ -22,9 +22,11 @@ extern crate rand; use std::mem::size_of; use criterion::*; +use num::BigInt; use rand::distributions::Standard; use arrow::array::*; +use arrow::util::decimal::Decimal256; use arrow::util::test_util::seedable_rng; use rand::Rng; @@ -41,9 +43,9 @@ fn bench_primitive(c: &mut Criterion) { )); group.bench_function("bench_primitive", |b| { b.iter(|| { - let mut builder = Int64Builder::new(64); + let mut builder = Int64Builder::with_capacity(64); for _ in 0..NUM_BATCHES { - let _ = black_box(builder.append_slice(&data[..])); + builder.append_slice(&data[..]); } black_box(builder.finish()); }) @@ -55,9 +57,9 @@ fn bench_primitive_nulls(c: &mut Criterion) { let mut group = c.benchmark_group("bench_primitive_nulls"); group.bench_function("bench_primitive_nulls", |b| { b.iter(|| { - let mut builder = UInt8Builder::new(64); + let mut builder = UInt8Builder::with_capacity(64); for _ in 0..NUM_BATCHES * BATCH_SIZE { - let _ = black_box(builder.append_null()); + builder.append_null(); } black_box(builder.finish()); }) @@ -78,9 +80,9 @@ fn bench_bool(c: &mut Criterion) { )); group.bench_function("bench_bool", |b| { b.iter(|| { - let mut builder = BooleanBuilder::new(64); + let mut builder = BooleanBuilder::with_capacity(64); for _ in 0..NUM_BATCHES { - let _ = black_box(builder.append_slice(&data[..])); + builder.append_slice(&data[..]); } black_box(builder.finish()); }) @@ -96,9 +98,9 @@ fn bench_string(c: &mut Criterion) { )); group.bench_function("bench_string", |b| { b.iter(|| { - let mut builder = StringBuilder::new(64); + let mut builder = StringBuilder::new(); for _ in 0..NUM_BATCHES * BATCH_SIZE { - let _ = black_box(builder.append_value(SAMPLE_STRING)); + builder.append_value(SAMPLE_STRING); } black_box(builder.finish()); }) @@ -106,11 +108,46 @@ fn bench_string(c: &mut Criterion) { group.finish(); } +fn bench_decimal128(c: &mut Criterion) { + c.bench_function("bench_decimal128_builder", |b| { + b.iter(|| { + let mut rng = rand::thread_rng(); + let mut decimal_builder = Decimal128Builder::with_capacity(BATCH_SIZE, 38, 0); + for _ in 0..BATCH_SIZE { + decimal_builder + .append_value(rng.gen_range::(0..9999999999)) + .unwrap(); + } + black_box(decimal_builder.finish()); + }) + }); +} + +fn bench_decimal256(c: &mut Criterion) { + c.bench_function("bench_decimal128_builder", |b| { + b.iter(|| { + let mut rng = rand::thread_rng(); + let mut decimal_builder = + Decimal256Builder::with_capacity(BATCH_SIZE, 76, 10); + for _ in 0..BATCH_SIZE { + decimal_builder + .append_value(&Decimal256::from(BigInt::from( + rng.gen_range::(0..99999999999), + ))) + .unwrap() + } + black_box(decimal_builder.finish()); + }) + }); +} + criterion_group!( benches, bench_primitive, 
bench_primitive_nulls, bench_bool, - bench_string + bench_string, + bench_decimal128, + bench_decimal256, ); criterion_main!(benches); diff --git a/arrow/benches/cast_kernels.rs b/arrow/benches/cast_kernels.rs index d164e1facfd3..ac8fc08d9210 100644 --- a/arrow/benches/cast_kernels.rs +++ b/arrow/benches/cast_kernels.rs @@ -29,6 +29,7 @@ use arrow::array::*; use arrow::compute::cast; use arrow::datatypes::*; use arrow::util::bench_util::*; +use arrow::util::decimal::Decimal256; use arrow::util::test_util::seedable_rng; fn build_array(size: usize) -> ArrayRef @@ -44,17 +45,17 @@ fn build_utf8_date_array(size: usize, with_nulls: bool) -> ArrayRef { // use random numbers to avoid spurious compiler optimizations wrt to branching let mut rng = seedable_rng(); - let mut builder = StringBuilder::new(size); + let mut builder = StringBuilder::new(); let range = Uniform::new(0, 737776); for _ in 0..size { if with_nulls && rng.gen::() > 0.8 { - builder.append_null().unwrap(); + builder.append_null(); } else { let string = NaiveDate::from_num_days_from_ce(rng.sample(range)) .format("%Y-%m-%d") .to_string(); - builder.append_value(&string).unwrap(); + builder.append_value(&string); } } Arc::new(builder.finish()) @@ -65,22 +66,44 @@ fn build_utf8_date_time_array(size: usize, with_nulls: bool) -> ArrayRef { // use random numbers to avoid spurious compiler optimizations wrt to branching let mut rng = seedable_rng(); - let mut builder = StringBuilder::new(size); + let mut builder = StringBuilder::new(); let range = Uniform::new(0, 1608071414123); for _ in 0..size { if with_nulls && rng.gen::() > 0.8 { - builder.append_null().unwrap(); + builder.append_null(); } else { let string = NaiveDateTime::from_timestamp(rng.sample(range), 0) .format("%Y-%m-%dT%H:%M:%S") .to_string(); - builder.append_value(&string).unwrap(); + builder.append_value(&string); } } Arc::new(builder.finish()) } +fn build_decimal128_array(size: usize, precision: u8, scale: u8) -> ArrayRef { + let mut rng = seedable_rng(); + let mut builder = Decimal128Builder::with_capacity(size, precision, scale); + + for _ in 0..size { + let _ = builder.append_value(rng.gen_range::(0..1000000000)); + } + Arc::new(builder.finish()) +} + +fn build_decimal256_array(size: usize, precision: u8, scale: u8) -> ArrayRef { + let mut rng = seedable_rng(); + let mut builder = Decimal256Builder::with_capacity(size, precision, scale); + let mut bytes = [0; 32]; + for _ in 0..size { + let num = rng.gen_range::(0..1000000000); + bytes[0..16].clone_from_slice(&num.to_le_bytes()); + let _ = builder.append_value(&Decimal256::new(precision, scale, &bytes)); + } + Arc::new(builder.finish()) +} + // cast array from specified primitive array type to desired data type fn cast_array(array: &ArrayRef, to_type: DataType) { criterion::black_box(cast(array, &to_type).unwrap()); @@ -102,6 +125,9 @@ fn add_benchmark(c: &mut Criterion) { let utf8_date_array = build_utf8_date_array(512, true); let utf8_date_time_array = build_utf8_date_time_array(512, true); + let decimal128_array = build_decimal128_array(512, 10, 3); + let decimal256_array = build_decimal256_array(512, 50, 3); + c.bench_function("cast int32 to int32 512", |b| { b.iter(|| cast_array(&i32_array, DataType::Int32)) }); @@ -179,6 +205,19 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("cast utf8 to date64 512", |b| { b.iter(|| cast_array(&utf8_date_time_array, DataType::Date64)) }); + + c.bench_function("cast decimal128 to decimal128 512", |b| { + b.iter(|| cast_array(&decimal128_array, DataType::Decimal128(30, 
5))) + }); + c.bench_function("cast decimal128 to decimal256 512", |b| { + b.iter(|| cast_array(&decimal128_array, DataType::Decimal256(50, 5))) + }); + c.bench_function("cast decimal256 to decimal128 512", |b| { + b.iter(|| cast_array(&decimal256_array, DataType::Decimal128(38, 2))) + }); + c.bench_function("cast decimal256 to decimal256 512", |b| { + b.iter(|| cast_array(&decimal256_array, DataType::Decimal256(50, 5))) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/benches/csv_writer.rs b/arrow/benches/csv_writer.rs index 3ecf514ad6db..05c6c226c464 100644 --- a/arrow/benches/csv_writer.rs +++ b/arrow/benches/csv_writer.rs @@ -21,7 +21,6 @@ extern crate criterion; use criterion::*; use arrow::array::*; -#[cfg(feature = "csv")] use arrow::csv; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; @@ -30,47 +29,44 @@ use std::fs::File; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { - #[cfg(feature = "csv")] - { - let schema = Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Float64, true), - Field::new("c3", DataType::UInt32, false), - Field::new("c4", DataType::Boolean, true), - ]); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::UInt32, false), + Field::new("c4", DataType::Boolean, true), + ]); - let c1 = StringArray::from(vec![ - "Lorem ipsum dolor sit amet", - "consectetur adipiscing elit", - "sed do eiusmod tempor", - ]); - let c2 = PrimitiveArray::::from(vec![ - Some(123.564532), - None, - Some(-556132.25), - ]); - let c3 = PrimitiveArray::::from(vec![3, 2, 1]); - let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); + let c1 = StringArray::from(vec![ + "Lorem ipsum dolor sit amet", + "consectetur adipiscing elit", + "sed do eiusmod tempor", + ]); + let c2 = PrimitiveArray::::from(vec![ + Some(123.564532), + None, + Some(-556132.25), + ]); + let c3 = PrimitiveArray::::from(vec![3, 2, 1]); + let c4 = BooleanArray::from(vec![Some(true), Some(false), None]); - let b = RecordBatch::try_new( - Arc::new(schema), - vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], - ) - .unwrap(); - let path = env::temp_dir().join("bench_write_csv.csv"); - let file = File::create(path).unwrap(); - let mut writer = csv::Writer::new(file); - let batches = vec![&b, &b, &b, &b, &b, &b, &b, &b, &b, &b, &b]; + let b = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], + ) + .unwrap(); + let path = env::temp_dir().join("bench_write_csv.csv"); + let file = File::create(path).unwrap(); + let mut writer = csv::Writer::new(file); + let batches = vec![&b, &b, &b, &b, &b, &b, &b, &b, &b, &b, &b]; - c.bench_function("record_batches_to_csv", |b| { - b.iter(|| { - #[allow(clippy::unit_arg)] - criterion::black_box(for batch in &batches { - writer.write(batch).unwrap() - }); + c.bench_function("record_batches_to_csv", |b| { + b.iter(|| { + #[allow(clippy::unit_arg)] + criterion::black_box(for batch in &batches { + writer.write(batch).unwrap() }); }); - } + }); } criterion_group!(benches, criterion_benchmark); diff --git a/arrow/benches/decimal_validate.rs b/arrow/benches/decimal_validate.rs new file mode 100644 index 000000000000..555373e4a634 --- /dev/null +++ b/arrow/benches/decimal_validate.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; + +use arrow::array::{ + Array, Decimal128Array, Decimal128Builder, Decimal256Array, Decimal256Builder, +}; +use criterion::Criterion; +use num::BigInt; +use rand::Rng; + +extern crate arrow; + +use arrow::util::decimal::Decimal256; + +fn validate_decimal128_array(array: Decimal128Array) { + array.with_precision_and_scale(35, 0).unwrap(); +} + +fn validate_decimal256_array(array: Decimal256Array) { + array.with_precision_and_scale(35, 0).unwrap(); +} + +fn validate_decimal128_benchmark(c: &mut Criterion) { + let mut rng = rand::thread_rng(); + let size: i128 = 20000; + let mut decimal_builder = Decimal128Builder::with_capacity(size as usize, 38, 0); + for _ in 0..size { + decimal_builder + .append_value(rng.gen_range::(0..999999999999)) + .unwrap(); + } + let decimal_array = decimal_builder.finish(); + let data = decimal_array.into_data(); + c.bench_function("validate_decimal128_array 20000", |b| { + b.iter(|| { + let array = Decimal128Array::from(data.clone()); + validate_decimal128_array(array); + }) + }); +} + +fn validate_decimal256_benchmark(c: &mut Criterion) { + let mut rng = rand::thread_rng(); + let size: i128 = 20000; + let mut decimal_builder = Decimal256Builder::with_capacity(size as usize, 76, 0); + for _ in 0..size { + let v = rng.gen_range::(0..999999999999999); + let decimal = Decimal256::from_big_int(&BigInt::from(v), 76, 0).unwrap(); + decimal_builder.append_value(&decimal).unwrap(); + } + let decimal_array256_data = decimal_builder.finish(); + let data = decimal_array256_data.into_data(); + c.bench_function("validate_decimal256_array 20000", |b| { + b.iter(|| { + let array = Decimal256Array::from(data.clone()); + validate_decimal256_array(array); + }) + }); +} + +criterion_group!( + benches, + validate_decimal128_benchmark, + validate_decimal256_benchmark, +); +criterion_main!(benches); diff --git a/arrow/benches/string_dictionary_builder.rs b/arrow/benches/string_dictionary_builder.rs index bc014bec155d..1a3b95917207 100644 --- a/arrow/benches/string_dictionary_builder.rs +++ b/arrow/benches/string_dictionary_builder.rs @@ -43,8 +43,11 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { let strings = build_strings(dict_size, total_size, key_len); b.iter(|| { - let keys = Int32Builder::new(strings.len()); - let values = StringBuilder::new((key_len + 1) * dict_size); + let keys = Int32Builder::with_capacity(strings.len()); + let values = StringBuilder::with_capacity( + key_len + 1, + (key_len + 1) * dict_size, + ); let mut builder = StringDictionaryBuilder::new(keys, values); for val in &strings { diff --git a/arrow/benches/take_kernels.rs b/arrow/benches/take_kernels.rs index dc9799b9a733..c4677cc72616 100644 --- a/arrow/benches/take_kernels.rs +++ b/arrow/benches/take_kernels.rs @@ -30,13 +30,13 @@ use arrow::{array::*, 
util::bench_util::*}; fn create_random_index(size: usize, null_density: f32) -> UInt32Array { let mut rng = seedable_rng(); - let mut builder = UInt32Builder::new(size); + let mut builder = UInt32Builder::with_capacity(size); for _ in 0..size { if rng.gen::() < null_density { - builder.append_null().unwrap() + builder.append_null(); } else { let value = rng.gen_range::(0u32..size as u32); - builder.append_value(value).unwrap(); + builder.append_value(value); } } builder.finish() diff --git a/arrow/examples/builders.rs b/arrow/examples/builders.rs index d35cb5ab744d..bacd550bdfde 100644 --- a/arrow/examples/builders.rs +++ b/arrow/examples/builders.rs @@ -34,22 +34,20 @@ fn main() { // u64, i8, i16, i32, i64, f32, f64) // Create a new builder with a capacity of 100 - let mut primitive_array_builder = Int32Builder::new(100); + let mut primitive_array_builder = Int32Builder::with_capacity(100); // Append an individual primitive value - primitive_array_builder.append_value(55).unwrap(); + primitive_array_builder.append_value(55); // Append a null value - primitive_array_builder.append_null().unwrap(); + primitive_array_builder.append_null(); // Append a slice of primitive values - primitive_array_builder.append_slice(&[39, 89, 12]).unwrap(); + primitive_array_builder.append_slice(&[39, 89, 12]); // Append lots of values - primitive_array_builder.append_null().unwrap(); - primitive_array_builder - .append_slice(&(25..50).collect::>()) - .unwrap(); + primitive_array_builder.append_null(); + primitive_array_builder.append_slice(&(25..50).collect::>()); // Build the `PrimitiveArray` let primitive_array = primitive_array_builder.finish(); diff --git a/arrow/examples/dynamic_types.rs b/arrow/examples/dynamic_types.rs index f98596f2e777..eefbf6dcd4ff 100644 --- a/arrow/examples/dynamic_types.rs +++ b/arrow/examples/dynamic_types.rs @@ -65,10 +65,7 @@ fn main() -> Result<()> { let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)])?; - #[cfg(feature = "prettyprint")] - { - print_batches(&[batch.clone()]).unwrap(); - } + print_batches(&[batch.clone()]).unwrap(); process(&batch); Ok(()) @@ -108,8 +105,5 @@ fn process(batch: &RecordBatch) { ) .unwrap(); - #[cfg(feature = "prettyprint")] - { - print_batches(&[projection]).unwrap(); - } + print_batches(&[projection]).unwrap(); } diff --git a/arrow/examples/read_csv.rs b/arrow/examples/read_csv.rs index 5ccf0c58a797..a1a592134eba 100644 --- a/arrow/examples/read_csv.rs +++ b/arrow/examples/read_csv.rs @@ -20,30 +20,22 @@ extern crate arrow; use std::fs::File; use std::sync::Arc; -#[cfg(feature = "csv")] use arrow::csv; use arrow::datatypes::{DataType, Field, Schema}; -#[cfg(feature = "prettyprint")] use arrow::util::pretty::print_batches; fn main() { - #[cfg(feature = "csv")] - { - let schema = Schema::new(vec![ - Field::new("city", DataType::Utf8, false), - Field::new("lat", DataType::Float64, false), - Field::new("lng", DataType::Float64, false), - ]); + let schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Float64, false), + Field::new("lng", DataType::Float64, false), + ]); - let path = format!("{}/test/data/uk_cities.csv", env!("CARGO_MANIFEST_DIR")); - let file = File::open(path).unwrap(); + let path = format!("{}/test/data/uk_cities.csv", env!("CARGO_MANIFEST_DIR")); + let file = File::open(path).unwrap(); - let mut csv = - csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None, None); - let _batch = csv.next().unwrap().unwrap(); - #[cfg(feature = 
"prettyprint")] - { - print_batches(&[_batch]).unwrap(); - } - } + let mut csv = + csv::Reader::new(file, Arc::new(schema), false, None, 1024, None, None, None); + let batch = csv.next().unwrap().unwrap(); + print_batches(&[batch]).unwrap(); } diff --git a/arrow/examples/read_csv_infer_schema.rs b/arrow/examples/read_csv_infer_schema.rs index e9f5ff650706..120a7b81910b 100644 --- a/arrow/examples/read_csv_infer_schema.rs +++ b/arrow/examples/read_csv_infer_schema.rs @@ -17,28 +17,20 @@ extern crate arrow; -#[cfg(feature = "csv")] use arrow::csv; -#[cfg(feature = "prettyprint")] use arrow::util::pretty::print_batches; use std::fs::File; fn main() { - #[cfg(feature = "csv")] - { - let path = format!( - "{}/test/data/uk_cities_with_headers.csv", - env!("CARGO_MANIFEST_DIR") - ); - let file = File::open(path).unwrap(); - let builder = csv::ReaderBuilder::new() - .has_header(true) - .infer_schema(Some(100)); - let mut csv = builder.build(file).unwrap(); - let _batch = csv.next().unwrap().unwrap(); - #[cfg(feature = "prettyprint")] - { - print_batches(&[_batch]).unwrap(); - } - } + let path = format!( + "{}/test/data/uk_cities_with_headers.csv", + env!("CARGO_MANIFEST_DIR") + ); + let file = File::open(path).unwrap(); + let builder = csv::ReaderBuilder::new() + .has_header(true) + .infer_schema(Some(100)); + let mut csv = builder.build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + print_batches(&[batch]).unwrap(); } diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index f01fa5cc91db..38ba2025a2e3 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -16,19 +16,16 @@ // under the License. use std::any::Any; -use std::convert::{From, TryFrom}; +use std::convert::From; use std::fmt; use std::sync::Arc; use super::*; -use crate::array::equal_json::JsonEqual; use crate::buffer::{Buffer, MutableBuffer}; -use crate::error::Result; -use crate::ffi; /// Trait for dealing with different types of array at runtime when the type of the /// array is not known in advance. -pub trait Array: fmt::Debug + Send + Sync + JsonEqual { +pub trait Array: fmt::Debug + Send + Sync { /// Returns the array as [`Any`](std::any::Any) so that it can be /// downcasted to a specific implementation. /// @@ -216,15 +213,6 @@ pub trait Array: fmt::Debug + Send + Sync + JsonEqual { self.data_ref().get_array_memory_size() + std::mem::size_of_val(self) - std::mem::size_of::() } - - /// returns two pointers that represent this array in the C Data Interface (FFI) - fn to_raw( - &self, - ) -> Result<(*const ffi::FFI_ArrowArray, *const ffi::FFI_ArrowSchema)> { - let data = self.data().clone(); - let array = ffi::ArrowArray::try_from(data)?; - Ok(ffi::ArrowArray::into_raw(array)) - } } /// A reference-counted reference to a generic `Array`. 
@@ -287,16 +275,89 @@ impl Array for ArrayRef { fn get_array_memory_size(&self) -> usize { self.as_ref().get_array_memory_size() } +} + +impl<'a, T: Array> Array for &'a T { + fn as_any(&self) -> &dyn Any { + T::as_any(self) + } + + fn data(&self) -> &ArrayData { + T::data(self) + } + + fn into_data(self) -> ArrayData { + self.data().clone() + } + + fn data_ref(&self) -> &ArrayData { + T::data_ref(self) + } - fn to_raw( - &self, - ) -> Result<(*const ffi::FFI_ArrowArray, *const ffi::FFI_ArrowSchema)> { - let data = self.data().clone(); - let array = ffi::ArrowArray::try_from(data)?; - Ok(ffi::ArrowArray::into_raw(array)) + fn data_type(&self) -> &DataType { + T::data_type(self) + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + T::slice(self, offset, length) + } + + fn len(&self) -> usize { + T::len(self) + } + + fn is_empty(&self) -> bool { + T::is_empty(self) + } + + fn offset(&self) -> usize { + T::offset(self) + } + + fn is_null(&self, index: usize) -> bool { + T::is_null(self, index) + } + + fn is_valid(&self, index: usize) -> bool { + T::is_valid(self, index) + } + + fn null_count(&self) -> usize { + T::null_count(self) + } + + fn get_buffer_memory_size(&self) -> usize { + T::get_buffer_memory_size(self) + } + + fn get_array_memory_size(&self) -> usize { + T::get_array_memory_size(self) } } +/// A generic trait for accessing the values of an [`Array`] +/// +/// # Validity +/// +/// An [`ArrayAccessor`] must always return a well-defined value for an index that is +/// within the bounds `0..Array::len`, including for null indexes where [`Array::is_null`] is true. +/// +/// The value at null indexes is unspecified, and implementations must not rely on a specific +/// value such as [`Default::default`] being returned, however, it must not be undefined +pub trait ArrayAccessor: Array { + type Item: Send + Sync; + + /// Returns the element at index `i` + /// # Panics + /// Panics if the value is outside the bounds of the array + fn value(&self, index: usize) -> Self::Item; + + /// Returns the element at index `i` + /// # Safety + /// Caller is responsible for ensuring that the index is within the bounds of the array + unsafe fn value_unchecked(&self, index: usize) -> Self::Item; +} + /// Constructs an array using the input `data`. /// Returns a reference-counted `Array` instance. 
pub fn make_array(data: ArrayData) -> ArrayRef { @@ -403,7 +464,8 @@ pub fn make_array(data: ArrayData) -> ArrayRef { dt => panic!("Unexpected dictionary key type {:?}", dt), }, DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef, - DataType::Decimal(_, _) => Arc::new(DecimalArray::from(data)) as ArrayRef, + DataType::Decimal128(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef, + DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef, dt => panic!("Unexpected data type {:?}", dt), } } @@ -567,7 +629,10 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { ) }) } - DataType::Decimal(_, _) => new_null_sized_decimal(data_type, length), + DataType::Decimal128(_, _) => { + new_null_sized_decimal(data_type, length, std::mem::size_of::()) + } + DataType::Decimal256(_, _) => new_null_sized_decimal(data_type, length, 32), } } @@ -632,7 +697,11 @@ fn new_null_sized_array( } #[inline] -fn new_null_sized_decimal(data_type: &DataType, length: usize) -> ArrayRef { +fn new_null_sized_decimal( + data_type: &DataType, + length: usize, + byte_width: usize, +) -> ArrayRef { make_array(unsafe { ArrayData::new_unchecked( data_type.clone(), @@ -640,51 +709,12 @@ fn new_null_sized_decimal(data_type: &DataType, length: usize) -> ArrayRef { Some(length), Some(MutableBuffer::new_null(length).into()), 0, - vec![Buffer::from(vec![ - 0u8; - length * std::mem::size_of::() - ])], + vec![Buffer::from(vec![0u8; length * byte_width])], vec![], ) }) } -/// Creates a new array from two FFI pointers. Used to import arrays from the C Data Interface -/// # Safety -/// Assumes that these pointers represent valid C Data Interfaces, both in memory -/// representation and lifetime via the `release` mechanism. -pub unsafe fn make_array_from_raw( - array: *const ffi::FFI_ArrowArray, - schema: *const ffi::FFI_ArrowSchema, -) -> Result { - let array = ffi::ArrowArray::try_from_raw(array, schema)?; - let data = ArrayData::try_from(array)?; - Ok(make_array(data)) -} - -/// Exports an array to raw pointers of the C Data Interface provided by the consumer. -/// # Safety -/// Assumes that these pointers represent valid C Data Interfaces, both in memory -/// representation and lifetime via the `release` mechanism. -/// -/// This function copies the content of two FFI structs [ffi::FFI_ArrowArray] and -/// [ffi::FFI_ArrowSchema] in the array to the location pointed by the raw pointers. -/// Usually the raw pointers are provided by the array data consumer. -pub unsafe fn export_array_into_raw( - src: ArrayRef, - out_array: *mut ffi::FFI_ArrowArray, - out_schema: *mut ffi::FFI_ArrowSchema, -) -> Result<()> { - let data = src.data(); - let array = ffi::FFI_ArrowArray::new(data); - let schema = ffi::FFI_ArrowSchema::try_from(data.data_type())?; - - std::ptr::write_unaligned(out_array, array); - std::ptr::write_unaligned(out_schema, schema); - - Ok(()) -} - // Helper function for printing potentially long arrays. 
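Not part of the diff: a small usage sketch tying together the APIs this region touches, combining the infallible builder calls from the `examples/builders.rs` hunk above with the dynamically typed `Array`/`ArrayRef` handling and `new_null_array` (which now covers the `Decimal128`/`Decimal256` variants). Only functions visible in this patch are used; the variable names are illustrative.

```rust
use std::sync::Arc;

use arrow::array::{new_null_array, Array, ArrayRef, Int32Array, Int32Builder};
use arrow::datatypes::DataType;

fn main() {
    // Builder API after this change: `with_capacity` plus infallible appends.
    let mut builder = Int32Builder::with_capacity(8);
    builder.append_value(1);
    builder.append_null();
    builder.append_slice(&[2, 3, 4]);

    // Erase the concrete type behind the `Array` trait object.
    let array: ArrayRef = Arc::new(builder.finish());
    assert_eq!(array.len(), 5);
    assert_eq!(array.null_count(), 1);
    assert_eq!(array.data_type(), &DataType::Int32);

    // Recover the concrete type when element access is needed.
    let ints = array
        .as_any()
        .downcast_ref::<Int32Array>()
        .expect("expected an Int32Array");
    assert_eq!(ints.value(0), 1);

    // Generic construction from a `DataType`, including the decimal
    // variants handled by the updated `new_null_array`.
    let nulls = new_null_array(&DataType::Decimal128(10, 2), 4);
    assert_eq!(nulls.null_count(), 4);
}
```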
pub(super) fn print_long_array( array: &A, diff --git a/arrow/src/array/array_binary.rs b/arrow/src/array/array_binary.rs index d9cad1cce661..1c63e8e24b29 100644 --- a/arrow/src/array/array_binary.rs +++ b/arrow/src/array/array_binary.rs @@ -20,12 +20,11 @@ use std::fmt; use std::{any::Any, iter::FromIterator}; use super::{ - array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, - FixedSizeListArray, GenericBinaryIter, GenericListArray, OffsetSizeTrait, + array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, GenericBinaryIter, + GenericListArray, OffsetSizeTrait, }; -pub use crate::array::DecimalIter; +use crate::array::array::ArrayAccessor; use crate::buffer::Buffer; -use crate::error::{ArrowError, Result}; use crate::util::bit_util; use crate::{buffer::MutableBuffer, datatypes::DataType}; @@ -38,15 +37,17 @@ pub struct GenericBinaryArray { } impl GenericBinaryArray { + /// Data type of the array. + pub const DATA_TYPE: DataType = if OffsetSize::IS_LARGE { + DataType::LargeBinary + } else { + DataType::Binary + }; + /// Get the data type of the array. - // Declare this function as `pub const fn` after - // https://github.com/rust-lang/rust/issues/93706 is merged. - pub fn get_data_type() -> DataType { - if OffsetSize::IS_LARGE { - DataType::LargeBinary - } else { - DataType::Binary - } + #[deprecated(note = "please use `Self::DATA_TYPE` instead")] + pub const fn get_data_type() -> DataType { + Self::DATA_TYPE } /// Returns the length for value at index `i`. @@ -98,8 +99,15 @@ impl GenericBinaryArray { } /// Returns the element at index `i` as bytes slice + /// # Panics + /// Panics if index `i` is out of bounds. pub fn value(&self, i: usize) -> &[u8] { - assert!(i < self.data.len(), "BinaryArray out of bounds access"); + assert!( + i < self.data.len(), + "Trying to access an element at index {} from a BinaryArray of length {}", + i, + self.len() + ); //Soundness: length checked above, offset buffer length is 1 larger than logical array length let end = unsafe { self.value_offsets().get_unchecked(i + 1) }; let start = unsafe { self.value_offsets().get_unchecked(i) }; @@ -135,21 +143,35 @@ impl GenericBinaryArray { fn from_list(v: GenericListArray) -> Self { assert_eq!( - v.data_ref().child_data()[0].child_data().len(), + v.data_ref().child_data().len(), + 1, + "BinaryArray can only be created from list array of u8 values \ + (i.e. List>)." + ); + let child_data = &v.data_ref().child_data()[0]; + + assert_eq!( + child_data.child_data().len(), 0, "BinaryArray can only be created from list array of u8 values \ (i.e. List>)." ); assert_eq!( - v.data_ref().child_data()[0].data_type(), + child_data.data_type(), &DataType::UInt8, "BinaryArray can only be created from List arrays, mismatched data types." ); + assert_eq!( + child_data.null_count(), + 0, + "The child array cannot contain null values." 
+ ); - let builder = ArrayData::builder(Self::get_data_type()) + let builder = ArrayData::builder(Self::DATA_TYPE) .len(v.len()) + .offset(v.offset()) .add_buffer(v.data_ref().buffers()[0].clone()) - .add_buffer(v.data_ref().child_data()[0].buffers()[0].clone()) + .add_buffer(child_data.buffers()[0].slice(child_data.offset())) .null_bit_buffer(v.data_ref().null_buffer().cloned()); let data = unsafe { builder.build_unchecked() }; @@ -184,7 +206,7 @@ impl GenericBinaryArray { assert!(!offsets.is_empty()); // wrote at least one let actual_len = (offsets.len() / std::mem::size_of::()) - 1; - let array_data = ArrayData::builder(Self::get_data_type()) + let array_data = ArrayData::builder(Self::DATA_TYPE) .len(actual_len) .add_buffer(offsets.into()) .add_buffer(values.into()); @@ -210,18 +232,16 @@ impl GenericBinaryArray { ) -> impl Iterator> + 'a { indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) } -} -impl<'a, T: OffsetSizeTrait> GenericBinaryArray { /// constructs a new iterator - pub fn iter(&'a self) -> GenericBinaryIter<'a, T> { - GenericBinaryIter::<'a, T>::new(self) + pub fn iter(&self) -> GenericBinaryIter<'_, OffsetSize> { + GenericBinaryIter::<'_, OffsetSize>::new(self) } } impl fmt::Debug for GenericBinaryArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = if OffsetSize::IS_LARGE { "Large" } else { "" }; + let prefix = OffsetSize::PREFIX; write!(f, "{}BinaryArray\n[\n", prefix)?; print_long_array(self, f, |array, index, f| { @@ -245,11 +265,25 @@ impl Array for GenericBinaryArray { } } +impl<'a, OffsetSize: OffsetSizeTrait> ArrayAccessor + for &'a GenericBinaryArray +{ + type Item = &'a [u8]; + + fn value(&self, index: usize) -> Self::Item { + GenericBinaryArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + GenericBinaryArray::value_unchecked(self, index) + } +} + impl From for GenericBinaryArray { fn from(data: ArrayData) -> Self { assert_eq!( data.data_type(), - &Self::get_data_type(), + &Self::DATA_TYPE, "[Large]BinaryArray expects Datatype::[Large]Binary" ); assert_eq!( @@ -273,6 +307,26 @@ impl From> for Array } } +impl From>> + for GenericBinaryArray +{ + fn from(v: Vec>) -> Self { + Self::from_opt_vec(v) + } +} + +impl From> for GenericBinaryArray { + fn from(v: Vec<&[u8]>) -> Self { + Self::from_iter_values(v) + } +} + +impl From> for GenericBinaryArray { + fn from(v: GenericListArray) -> Self { + Self::from_list(v) + } +} + impl FromIterator> for GenericBinaryArray where @@ -306,7 +360,7 @@ where // calculate actual data_len, which may be different from the iterator's upper bound let data_len = offsets.len() - 1; - let array_data = ArrayData::builder(Self::get_data_type()) + let array_data = ArrayData::builder(Self::DATA_TYPE) .len(data_len) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_buffer(Buffer::from_slice_ref(&values)) @@ -316,6 +370,15 @@ where } } +impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = GenericBinaryIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + GenericBinaryIter::<'a, T>::new(self) + } +} + /// An array where each element contains 0 or more bytes. /// The byte length of each element is represented by an i32. 
/// @@ -396,367 +459,10 @@ pub type BinaryArray = GenericBinaryArray; /// pub type LargeBinaryArray = GenericBinaryArray; -impl<'a, T: OffsetSizeTrait> IntoIterator for &'a GenericBinaryArray { - type Item = Option<&'a [u8]>; - type IntoIter = GenericBinaryIter<'a, T>; - - fn into_iter(self) -> Self::IntoIter { - GenericBinaryIter::<'a, T>::new(self) - } -} - -impl From>> - for GenericBinaryArray -{ - fn from(v: Vec>) -> Self { - Self::from_opt_vec(v) - } -} - -impl From> for GenericBinaryArray { - fn from(v: Vec<&[u8]>) -> Self { - Self::from_iter_values(v) - } -} - -impl From> for GenericBinaryArray { - fn from(v: GenericListArray) -> Self { - Self::from_list(v) - } -} - -/// An array where each element is a fixed-size sequence of bytes. -/// -/// # Examples -/// -/// Create an array from an iterable argument of byte slices. -/// -/// ``` -/// use arrow::array::{Array, FixedSizeBinaryArray}; -/// let input_arg = vec![ vec![1, 2], vec![3, 4], vec![5, 6] ]; -/// let arr = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); -/// -/// assert_eq!(3, arr.len()); -/// -/// ``` -/// Create an array from an iterable argument of sparse byte slices. -/// Sparsity means that the input argument can contain `None` items. -/// ``` -/// use arrow::array::{Array, FixedSizeBinaryArray}; -/// let input_arg = vec![ None, Some(vec![7, 8]), Some(vec![9, 10]), None, Some(vec![13, 14]) ]; -/// let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); -/// assert_eq!(5, arr.len()) -/// -/// ``` -/// -pub struct FixedSizeBinaryArray { - data: ArrayData, - value_data: RawPtrBox, - length: i32, -} - -impl FixedSizeBinaryArray { - /// Returns the element at index `i` as a byte slice. - pub fn value(&self, i: usize) -> &[u8] { - assert!( - i < self.data.len(), - "FixedSizeBinaryArray out of bounds access" - ); - let offset = i + self.data.offset(); - unsafe { - let pos = self.value_offset_at(offset); - std::slice::from_raw_parts( - self.value_data.as_ptr().offset(pos as isize), - (self.value_offset_at(offset + 1) - pos) as usize, - ) - } - } - - /// Returns the element at index `i` as a byte slice. - /// # Safety - /// Caller is responsible for ensuring that the index is within the bounds of the array - pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { - let offset = i + self.data.offset(); - let pos = self.value_offset_at(offset); - std::slice::from_raw_parts( - self.value_data.as_ptr().offset(pos as isize), - (self.value_offset_at(offset + 1) - pos) as usize, - ) - } - - /// Returns the offset for the element at index `i`. - /// - /// Note this doesn't do any bound checking, for performance reason. - #[inline] - pub fn value_offset(&self, i: usize) -> i32 { - self.value_offset_at(self.data.offset() + i) - } - - /// Returns the length for an element. - /// - /// All elements have the same length as the array is a fixed size. - #[inline] - pub fn value_length(&self) -> i32 { - self.length - } - - /// Returns a clone of the value data buffer - pub fn value_data(&self) -> Buffer { - self.data.buffers()[0].clone() - } - - /// Create an array from an iterable argument of sparse byte slices. - /// Sparsity means that items returned by the iterator are optional, i.e input argument can - /// contain `None` items. 
- /// - /// # Examples - /// - /// ``` - /// use arrow::array::FixedSizeBinaryArray; - /// let input_arg = vec![ - /// None, - /// Some(vec![7, 8]), - /// Some(vec![9, 10]), - /// None, - /// Some(vec![13, 14]), - /// None, - /// ]; - /// let array = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); - /// ``` - /// - /// # Errors - /// - /// Returns error if argument has length zero, or sizes of nested slices don't match. - pub fn try_from_sparse_iter(mut iter: T) -> Result - where - T: Iterator>, - U: AsRef<[u8]>, - { - let mut len = 0; - let mut size = None; - let mut byte = 0; - let mut null_buf = MutableBuffer::from_len_zeroed(0); - let mut buffer = MutableBuffer::from_len_zeroed(0); - let mut prepend = 0; - iter.try_for_each(|item| -> Result<()> { - // extend null bitmask by one byte per each 8 items - if byte == 0 { - null_buf.push(0u8); - byte = 8; - } - byte -= 1; - - if let Some(slice) = item { - let slice = slice.as_ref(); - if let Some(size) = size { - if size != slice.len() { - return Err(ArrowError::InvalidArgumentError(format!( - "Nested array size mismatch: one is {}, and the other is {}", - size, - slice.len() - ))); - } - } else { - size = Some(slice.len()); - buffer.extend_zeros(slice.len() * prepend); - } - bit_util::set_bit(null_buf.as_slice_mut(), len); - buffer.extend_from_slice(slice); - } else if let Some(size) = size { - buffer.extend_zeros(size); - } else { - prepend += 1; - } - - len += 1; - - Ok(()) - })?; - - if len == 0 { - return Err(ArrowError::InvalidArgumentError( - "Input iterable argument has no data".to_owned(), - )); - } - - let size = size.unwrap_or(0); - let array_data = unsafe { - ArrayData::new_unchecked( - DataType::FixedSizeBinary(size as i32), - len, - None, - Some(null_buf.into()), - 0, - vec![buffer.into()], - vec![], - ) - }; - Ok(FixedSizeBinaryArray::from(array_data)) - } - - /// Create an array from an iterable argument of byte slices. - /// - /// # Examples - /// - /// ``` - /// use arrow::array::FixedSizeBinaryArray; - /// let input_arg = vec![ - /// vec![1, 2], - /// vec![3, 4], - /// vec![5, 6], - /// ]; - /// let array = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); - /// ``` - /// - /// # Errors - /// - /// Returns error if argument has length zero, or sizes of nested slices don't match. 
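As a rough sketch of the error contract documented above (illustrative, not part of the patch), `try_from_iter` rejects inputs whose nested slices differ in length:

```rust
use arrow::array::FixedSizeBinaryArray;

fn main() {
    // Matching widths build successfully.
    let ok = FixedSizeBinaryArray::try_from_iter(vec![vec![1u8, 2], vec![3, 4]].into_iter());
    assert!(ok.is_ok());

    // A width mismatch is reported as an `ArrowError` instead of panicking.
    let err = FixedSizeBinaryArray::try_from_iter(vec![vec![1u8, 2], vec![3, 4, 5]].into_iter());
    assert!(err.is_err());
}
```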
- pub fn try_from_iter(mut iter: T) -> Result - where - T: Iterator, - U: AsRef<[u8]>, - { - let mut len = 0; - let mut size = None; - let mut buffer = MutableBuffer::from_len_zeroed(0); - iter.try_for_each(|item| -> Result<()> { - let slice = item.as_ref(); - if let Some(size) = size { - if size != slice.len() { - return Err(ArrowError::InvalidArgumentError(format!( - "Nested array size mismatch: one is {}, and the other is {}", - size, - slice.len() - ))); - } - } else { - size = Some(slice.len()); - } - buffer.extend_from_slice(slice); - - len += 1; - - Ok(()) - })?; - - if len == 0 { - return Err(ArrowError::InvalidArgumentError( - "Input iterable argument has no data".to_owned(), - )); - } - - let size = size.unwrap_or(0); - let array_data = ArrayData::builder(DataType::FixedSizeBinary(size as i32)) - .len(len) - .add_buffer(buffer.into()); - let array_data = unsafe { array_data.build_unchecked() }; - Ok(FixedSizeBinaryArray::from(array_data)) - } - - #[inline] - fn value_offset_at(&self, i: usize) -> i32 { - self.length * i as i32 - } -} - -impl From for FixedSizeBinaryArray { - fn from(data: ArrayData) -> Self { - assert_eq!( - data.buffers().len(), - 1, - "FixedSizeBinaryArray data should contain 1 buffer only (values)" - ); - let value_data = data.buffers()[0].as_ptr(); - let length = match data.data_type() { - DataType::FixedSizeBinary(len) => *len, - _ => panic!("Expected data type to be FixedSizeBinary"), - }; - Self { - data, - value_data: unsafe { RawPtrBox::new(value_data) }, - length, - } - } -} - -impl From for ArrayData { - fn from(array: FixedSizeBinaryArray) -> Self { - array.data - } -} - -/// Creates a `FixedSizeBinaryArray` from `FixedSizeList` array -impl From for FixedSizeBinaryArray { - fn from(v: FixedSizeListArray) -> Self { - assert_eq!( - v.data_ref().child_data()[0].child_data().len(), - 0, - "FixedSizeBinaryArray can only be created from list array of u8 values \ - (i.e. FixedSizeList>)." - ); - assert_eq!( - v.data_ref().child_data()[0].data_type(), - &DataType::UInt8, - "FixedSizeBinaryArray can only be created from FixedSizeList arrays, mismatched data types." 
- ); - - let builder = ArrayData::builder(DataType::FixedSizeBinary(v.value_length())) - .len(v.len()) - .add_buffer(v.data_ref().child_data()[0].buffers()[0].clone()) - .null_bit_buffer(v.data_ref().null_buffer().cloned()); - - let data = unsafe { builder.build_unchecked() }; - Self::from(data) - } -} - -impl From>> for FixedSizeBinaryArray { - fn from(v: Vec>) -> Self { - Self::try_from_sparse_iter(v.into_iter()).unwrap() - } -} - -impl From> for FixedSizeBinaryArray { - fn from(v: Vec<&[u8]>) -> Self { - Self::try_from_iter(v.into_iter()).unwrap() - } -} - -impl fmt::Debug for FixedSizeBinaryArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "FixedSizeBinaryArray<{}>\n[\n", self.value_length())?; - print_long_array(self, f, |array, index, f| { - fmt::Debug::fmt(&array.value(index), f) - })?; - write!(f, "]") - } -} - -impl Array for FixedSizeBinaryArray { - fn as_any(&self) -> &dyn Any { - self - } - - fn data(&self) -> &ArrayData { - &self.data - } - - fn into_data(self) -> ArrayData { - self.into() - } -} - #[cfg(test)] mod tests { - use std::sync::Arc; - - use crate::{ - array::{LargeListArray, ListArray}, - datatypes::{Field, Schema}, - record_batch::RecordBatch, - }; - use super::*; + use crate::{array::ListArray, datatypes::Field}; #[test] fn test_binary_array() { @@ -889,37 +595,36 @@ mod tests { assert_eq!(7, binary_array.value_length(1)); } - #[test] - fn test_binary_array_from_list_array() { - let values: [u8; 12] = [ - b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't', - ]; - let values_data = ArrayData::builder(DataType::UInt8) + fn _test_generic_binary_array_from_list_array() { + let values = b"helloparquet"; + let child_data = ArrayData::builder(DataType::UInt8) .len(12) .add_buffer(Buffer::from(&values[..])) .build() .unwrap(); - let offsets: [i32; 4] = [0, 5, 5, 12]; + let offsets = [0, 5, 5, 12].map(|n| O::from_usize(n).unwrap()); // Array data: ["hello", "", "parquet"] - let array_data1 = ArrayData::builder(DataType::Binary) + let array_data1 = ArrayData::builder(GenericBinaryArray::::DATA_TYPE) .len(3) .add_buffer(Buffer::from_slice_ref(&offsets)) .add_buffer(Buffer::from_slice_ref(&values)) .build() .unwrap(); - let binary_array1 = BinaryArray::from(array_data1); + let binary_array1 = GenericBinaryArray::::from(array_data1); + + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( + Field::new("item", DataType::UInt8, false), + )); - let data_type = - DataType::List(Box::new(Field::new("item", DataType::UInt8, false))); let array_data2 = ArrayData::builder(data_type) .len(3) .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_child_data(values_data) + .add_child_data(child_data) .build() .unwrap(); - let list_array = ListArray::from(array_data2); - let binary_array2 = BinaryArray::from(list_array); + let list_array = GenericListArray::::from(array_data2); + let binary_array2 = GenericBinaryArray::::from(list_array); assert_eq!(2, binary_array2.data().buffers().len()); assert_eq!(0, binary_array2.data().child_data().len()); @@ -936,51 +641,98 @@ mod tests { } } + #[test] + fn test_binary_array_from_list_array() { + _test_generic_binary_array_from_list_array::(); + } + #[test] fn test_large_binary_array_from_list_array() { - let values: [u8; 12] = [ - b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't', - ]; - let values_data = ArrayData::builder(DataType::UInt8) - .len(12) + _test_generic_binary_array_from_list_array::(); + } + + fn 
_test_generic_binary_array_from_list_array_with_offset() { + let values = b"HelloArrowAndParquet"; + // b"ArrowAndParquet" + let child_data = ArrayData::builder(DataType::UInt8) + .len(15) + .offset(5) .add_buffer(Buffer::from(&values[..])) .build() .unwrap(); - let offsets: [i64; 4] = [0, 5, 5, 12]; - // Array data: ["hello", "", "parquet"] - let array_data1 = ArrayData::builder(DataType::LargeBinary) - .len(3) + let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); + let null_buffer = Buffer::from_slice_ref(&[0b101]); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( + Field::new("item", DataType::UInt8, false), + )); + + // [None, Some(b"Parquet")] + let array_data = ArrayData::builder(data_type) + .len(2) + .offset(1) .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_buffer(Buffer::from_slice_ref(&values)) + .null_bit_buffer(Some(null_buffer)) + .add_child_data(child_data) .build() .unwrap(); - let binary_array1 = LargeBinaryArray::from(array_data1); + let list_array = GenericListArray::::from(array_data); + let binary_array = GenericBinaryArray::::from(list_array); - let data_type = - DataType::LargeList(Box::new(Field::new("item", DataType::UInt8, false))); - let array_data2 = ArrayData::builder(data_type) - .len(3) + assert_eq!(2, binary_array.len()); + assert_eq!(1, binary_array.null_count()); + assert!(binary_array.is_null(0)); + assert!(binary_array.is_valid(1)); + assert_eq!(b"Parquet", binary_array.value(1)); + } + + #[test] + fn test_binary_array_from_list_array_with_offset() { + _test_generic_binary_array_from_list_array_with_offset::(); + } + + #[test] + fn test_large_binary_array_from_list_array_with_offset() { + _test_generic_binary_array_from_list_array_with_offset::(); + } + + fn _test_generic_binary_array_from_list_array_with_child_nulls_failed< + O: OffsetSizeTrait, + >() { + let values = b"HelloArrow"; + let child_data = ArrayData::builder(DataType::UInt8) + .len(10) + .add_buffer(Buffer::from(&values[..])) + .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b1010101010]))) + .build() + .unwrap(); + + let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( + Field::new("item", DataType::UInt8, false), + )); + + // [None, Some(b"Parquet")] + let array_data = ArrayData::builder(data_type) + .len(2) .add_buffer(Buffer::from_slice_ref(&offsets)) - .add_child_data(values_data) + .add_child_data(child_data) .build() .unwrap(); - let list_array = LargeListArray::from(array_data2); - let binary_array2 = LargeBinaryArray::from(list_array); + let list_array = GenericListArray::::from(array_data); + drop(GenericBinaryArray::::from(list_array)); + } - assert_eq!(2, binary_array2.data().buffers().len()); - assert_eq!(0, binary_array2.data().child_data().len()); + #[test] + #[should_panic(expected = "The child array cannot contain null values.")] + fn test_binary_array_from_list_array_with_child_nulls_failed() { + _test_generic_binary_array_from_list_array_with_child_nulls_failed::(); + } - assert_eq!(binary_array1.len(), binary_array2.len()); - assert_eq!(binary_array1.null_count(), binary_array2.null_count()); - assert_eq!(binary_array1.value_offsets(), binary_array2.value_offsets()); - for i in 0..binary_array1.len() { - assert_eq!(binary_array1.value(i), binary_array2.value(i)); - assert_eq!(binary_array1.value(i), unsafe { - binary_array2.value_unchecked(i) - }); - assert_eq!(binary_array1.value_length(i), binary_array2.value_length(i)); - } + #[test] + 
#[should_panic(expected = "The child array cannot contain null values.")] + fn test_large_binary_array_from_list_array_with_child_nulls_failed() { + _test_generic_binary_array_from_list_array_with_child_nulls_failed::(); } fn test_generic_binary_array_from_opt_vec() { @@ -1060,89 +812,10 @@ mod tests { drop(BinaryArray::from(list_array)); } - #[test] - fn test_fixed_size_binary_array() { - let values: [u8; 15] = *b"hellotherearrow"; - - let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) - .len(3) - .add_buffer(Buffer::from(&values[..])) - .build() - .unwrap(); - let fixed_size_binary_array = FixedSizeBinaryArray::from(array_data); - assert_eq!(3, fixed_size_binary_array.len()); - assert_eq!(0, fixed_size_binary_array.null_count()); - assert_eq!( - [b'h', b'e', b'l', b'l', b'o'], - fixed_size_binary_array.value(0) - ); - assert_eq!( - [b't', b'h', b'e', b'r', b'e'], - fixed_size_binary_array.value(1) - ); - assert_eq!( - [b'a', b'r', b'r', b'o', b'w'], - fixed_size_binary_array.value(2) - ); - assert_eq!(5, fixed_size_binary_array.value_length()); - assert_eq!(10, fixed_size_binary_array.value_offset(2)); - for i in 0..3 { - assert!(fixed_size_binary_array.is_valid(i)); - assert!(!fixed_size_binary_array.is_null(i)); - } - - // Test binary array with offset - let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) - .len(2) - .offset(1) - .add_buffer(Buffer::from(&values[..])) - .build() - .unwrap(); - let fixed_size_binary_array = FixedSizeBinaryArray::from(array_data); - assert_eq!( - [b't', b'h', b'e', b'r', b'e'], - fixed_size_binary_array.value(0) - ); - assert_eq!( - [b'a', b'r', b'r', b'o', b'w'], - fixed_size_binary_array.value(1) - ); - assert_eq!(2, fixed_size_binary_array.len()); - assert_eq!(5, fixed_size_binary_array.value_offset(0)); - assert_eq!(5, fixed_size_binary_array.value_length()); - assert_eq!(10, fixed_size_binary_array.value_offset(1)); - } - #[test] #[should_panic( - expected = "FixedSizeBinaryArray can only be created from FixedSizeList arrays" + expected = "Trying to access an element at index 4 from a BinaryArray of length 3" )] - // Different error messages, so skip for now - // https://github.com/apache/arrow-rs/issues/1545 - #[cfg(not(feature = "force_validate"))] - fn test_fixed_size_binary_array_from_incorrect_list_array() { - let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; - let values_data = ArrayData::builder(DataType::UInt32) - .len(12) - .add_buffer(Buffer::from_slice_ref(&values)) - .build() - .unwrap(); - - let array_data = unsafe { - ArrayData::builder(DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Binary, false)), - 4, - )) - .len(3) - .add_child_data(values_data) - .build_unchecked() - }; - let list_array = FixedSizeListArray::from(array_data); - drop(FixedSizeBinaryArray::from(list_array)); - } - - #[test] - #[should_panic(expected = "BinaryArray out of bounds access")] fn test_binary_array_get_value_index_out_of_bound() { let values: [u8; 12] = [104, 101, 108, 108, 111, 112, 97, 114, 113, 117, 101, 116]; @@ -1157,114 +830,6 @@ mod tests { binary_array.value(4); } - #[test] - fn test_binary_array_fmt_debug() { - let values: [u8; 15] = *b"hellotherearrow"; - - let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) - .len(3) - .add_buffer(Buffer::from(&values[..])) - .build() - .unwrap(); - let arr = FixedSizeBinaryArray::from(array_data); - assert_eq!( - "FixedSizeBinaryArray<5>\n[\n [104, 101, 108, 108, 111],\n [116, 104, 101, 114, 101],\n [97, 114, 114, 111, 119],\n]", - 
format!("{:?}", arr) - ); - } - - #[test] - fn test_fixed_size_binary_array_from_iter() { - let input_arg = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; - let arr = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); - - assert_eq!(2, arr.value_length()); - assert_eq!(3, arr.len()) - } - - #[test] - fn test_all_none_fixed_size_binary_array_from_sparse_iter() { - let none_option: Option<[u8; 32]> = None; - let input_arg = vec![none_option, none_option, none_option]; - let arr = - FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); - assert_eq!(0, arr.value_length()); - assert_eq!(3, arr.len()) - } - - #[test] - fn test_fixed_size_binary_array_from_sparse_iter() { - let input_arg = vec![ - None, - Some(vec![7, 8]), - Some(vec![9, 10]), - None, - Some(vec![13, 14]), - ]; - let arr = - FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); - assert_eq!(2, arr.value_length()); - assert_eq!(5, arr.len()) - } - - #[test] - fn test_fixed_size_binary_array_from_vec() { - let values = vec!["one".as_bytes(), b"two", b"six", b"ten"]; - let array = FixedSizeBinaryArray::from(values); - assert_eq!(array.len(), 4); - assert_eq!(array.null_count(), 0); - assert_eq!(array.value(0), b"one"); - assert_eq!(array.value(1), b"two"); - assert_eq!(array.value(2), b"six"); - assert_eq!(array.value(3), b"ten"); - assert!(!array.is_null(0)); - assert!(!array.is_null(1)); - assert!(!array.is_null(2)); - assert!(!array.is_null(3)); - } - - #[test] - #[should_panic(expected = "Nested array size mismatch: one is 3, and the other is 5")] - fn test_fixed_size_binary_array_from_vec_incorrect_length() { - let values = vec!["one".as_bytes(), b"two", b"three", b"four"]; - let _ = FixedSizeBinaryArray::from(values); - } - - #[test] - fn test_fixed_size_binary_array_from_opt_vec() { - let values = vec![ - Some("one".as_bytes()), - Some(b"two"), - None, - Some(b"six"), - Some(b"ten"), - ]; - let array = FixedSizeBinaryArray::from(values); - assert_eq!(array.len(), 5); - assert_eq!(array.value(0), b"one"); - assert_eq!(array.value(1), b"two"); - assert_eq!(array.value(3), b"six"); - assert_eq!(array.value(4), b"ten"); - assert!(!array.is_null(0)); - assert!(!array.is_null(1)); - assert!(array.is_null(2)); - assert!(!array.is_null(3)); - assert!(!array.is_null(4)); - } - - #[test] - #[should_panic(expected = "Nested array size mismatch: one is 3, and the other is 5")] - fn test_fixed_size_binary_array_from_opt_vec_incorrect_length() { - let values = vec![ - Some("one".as_bytes()), - Some(b"two"), - None, - Some(b"three"), - Some(b"four"), - ]; - let _ = FixedSizeBinaryArray::from(values); - } - #[test] fn test_binary_array_all_null() { let data = vec![None]; @@ -1284,33 +849,4 @@ mod tests { .validate_full() .expect("All null array has valid array data"); } - - #[test] - fn fixed_size_binary_array_all_null() { - let data = vec![None] as Vec>; - let array = FixedSizeBinaryArray::try_from_sparse_iter(data.into_iter()).unwrap(); - array - .data() - .validate_full() - .expect("All null array has valid array data"); - } - - #[test] - // Test for https://github.com/apache/arrow-rs/issues/1390 - #[should_panic( - expected = "column types must match schema types, expected FixedSizeBinary(2) but found FixedSizeBinary(0) at column index 0" - )] - fn fixed_size_binary_array_all_null_in_batch_with_schema() { - let schema = - Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), true)]); - - let none_option: Option<[u8; 2]> = None; - let item = 
FixedSizeBinaryArray::try_from_sparse_iter( - vec![none_option, none_option, none_option].into_iter(), - ) - .unwrap(); - - // Should not panic - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(item)]).unwrap(); - } } diff --git a/arrow/src/array/array_boolean.rs b/arrow/src/array/array_boolean.rs index 279db3253d53..7ea18ea62036 100644 --- a/arrow/src/array/array_boolean.rs +++ b/arrow/src/array/array_boolean.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::array::array::ArrayAccessor; use std::borrow::Borrow; use std::convert::From; use std::iter::{FromIterator, IntoIterator}; @@ -94,7 +95,7 @@ impl BooleanArray { // Returns a new boolean array builder pub fn builder(capacity: usize) -> BooleanBuilder { - BooleanBuilder::new(capacity) + BooleanBuilder::with_capacity(capacity) } /// Returns a `Buffer` holding all the values of this array. @@ -114,10 +115,15 @@ impl BooleanArray { } /// Returns the boolean value at index `i`. - /// - /// Panics of offset `i` is out of bounds + /// # Panics + /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> bool { - assert!(i < self.len()); + assert!( + i < self.len(), + "Trying to access an element at index {} from a BooleanArray of length {}", + i, + self.len() + ); // Safety: // `i < self.len() unsafe { self.value_unchecked(i) } @@ -157,6 +163,18 @@ impl Array for BooleanArray { } } +impl<'a> ArrayAccessor for &'a BooleanArray { + type Item = bool; + + fn value(&self, index: usize) -> Self::Item { + BooleanArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + BooleanArray::value_unchecked(self, index) + } +} + impl From> for BooleanArray { fn from(data: Vec) -> Self { let mut mut_buf = MutableBuffer::new_null(data.len()); @@ -227,12 +245,12 @@ impl>> FromIterator for BooleanArray { let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. 
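A small sketch of the `BooleanArray` behaviour touched above, covering the `FromIterator` path and the new descriptive out-of-bounds panic message (illustrative only):

```rust
use arrow::array::{Array, BooleanArray};

fn main() {
    // Collecting `Option<bool>` sets the null bitmap for `None` entries.
    let array: BooleanArray = vec![Some(true), None, Some(false)].into_iter().collect();

    assert_eq!(array.len(), 3);
    assert_eq!(array.null_count(), 1);
    assert!(array.value(0));
    assert!(!array.value(2));

    // `array.value(3)` would now panic with
    // "Trying to access an element at index 3 from a BooleanArray of length 3".
}
```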
let num_bytes = bit_util::ceil(data_len, 8); - let mut null_buf = MutableBuffer::from_len_zeroed(num_bytes); - let mut val_buf = MutableBuffer::from_len_zeroed(num_bytes); + let mut null_builder = MutableBuffer::from_len_zeroed(num_bytes); + let mut val_builder = MutableBuffer::from_len_zeroed(num_bytes); - let data = val_buf.as_slice_mut(); + let data = val_builder.as_slice_mut(); - let null_slice = null_buf.as_slice_mut(); + let null_slice = null_builder.as_slice_mut(); iter.enumerate().for_each(|(i, item)| { if let Some(a) = item.borrow() { bit_util::set_bit(null_slice, i); @@ -247,9 +265,9 @@ impl>> FromIterator for BooleanArray { DataType::Boolean, data_len, None, - Some(null_buf.into()), + Some(null_builder.into()), 0, - vec![val_buf.into()], + vec![val_builder.into()], vec![], ) }; @@ -276,9 +294,9 @@ mod tests { #[test] fn test_boolean_with_null_fmt_debug() { let mut builder = BooleanArray::builder(3); - builder.append_value(true).unwrap(); - builder.append_null().unwrap(); - builder.append_value(false).unwrap(); + builder.append_value(true); + builder.append_null(); + builder.append_value(false); let arr = builder.finish(); assert_eq!( "BooleanArray\n[\n true,\n null,\n false,\n]", @@ -328,6 +346,7 @@ mod tests { assert_eq!(4, arr.len()); assert_eq!(0, arr.offset()); assert_eq!(0, arr.null_count()); + assert!(arr.data().null_buffer().is_none()); for i in 0..3 { assert!(!arr.is_null(i)); assert!(arr.is_valid(i)); @@ -335,6 +354,24 @@ mod tests { } } + #[test] + fn test_boolean_array_from_nullable_iter() { + let v = vec![Some(true), None, Some(false), None]; + let arr = v.into_iter().collect::(); + assert_eq!(4, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(2, arr.null_count()); + assert!(arr.data().null_buffer().is_some()); + + assert!(arr.is_valid(0)); + assert!(arr.is_null(1)); + assert!(arr.is_valid(2)); + assert!(arr.is_null(3)); + + assert!(arr.value(0)); + assert!(!arr.value(2)); + } + #[test] fn test_boolean_array_builder() { // Test building a boolean array with ArrayData builder and offset @@ -357,6 +394,17 @@ mod tests { } } + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a BooleanArray of length 3" + )] + fn test_fixed_size_binary_array_get_value_index_out_of_bound() { + let v = vec![Some(true), None, Some(false)]; + let array = v.into_iter().collect::(); + + array.value(4); + } + #[test] #[should_panic(expected = "BooleanArray data should contain a single buffer only \ (values buffer)")] diff --git a/arrow/src/array/array_decimal.rs b/arrow/src/array/array_decimal.rs index ccb1fe052845..543fda1b1a8a 100644 --- a/arrow/src/array/array_decimal.rs +++ b/arrow/src/array/array_decimal.rs @@ -15,39 +15,40 @@ // specific language governing permissions and limitations // under the License. 
-use std::borrow::Borrow; +use crate::array::ArrayAccessor; use std::convert::From; use std::fmt; +use std::marker::PhantomData; use std::{any::Any, iter::FromIterator}; use super::{ array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, FixedSizeListArray, }; -use super::{BooleanBufferBuilder, FixedSizeBinaryArray}; -pub use crate::array::DecimalIter; -use crate::buffer::Buffer; -use crate::datatypes::DataType; +use super::{BooleanBufferBuilder, DecimalIter, FixedSizeBinaryArray}; +#[allow(deprecated)] +use crate::buffer::{Buffer, MutableBuffer}; +use crate::datatypes::validate_decimal_precision; use crate::datatypes::{ - validate_decimal_precision, DECIMAL_DEFAULT_SCALE, DECIMAL_MAX_PRECISION, - DECIMAL_MAX_SCALE, + validate_decimal256_precision_with_lt_bytes, DataType, Decimal128Type, + Decimal256Type, DecimalType, NativeDecimalType, }; use crate::error::{ArrowError, Result}; -use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256}; +use crate::util::decimal::{Decimal, Decimal256}; -/// `DecimalArray` stores fixed width decimal numbers, +/// `Decimal128Array` stores fixed width decimal numbers, /// with a fixed precision and scale. /// /// # Examples /// /// ``` -/// use arrow::array::{Array, BasicDecimalArray, DecimalArray}; +/// use arrow::array::{Array, DecimalArray, Decimal128Array}; /// use arrow::datatypes::DataType; /// /// // Create a DecimalArray with the default precision and scale -/// let decimal_array: DecimalArray = vec![ -/// Some(8_887_000_000), +/// let decimal_array: Decimal128Array = vec![ +/// Some(8_887_000_000_i128), /// None, -/// Some(-8_887_000_000), +/// Some(-8_887_000_000_i128), /// ] /// .into_iter().collect(); /// @@ -57,7 +58,7 @@ use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256}; /// .with_precision_and_scale(23, 6) /// .unwrap(); /// -/// assert_eq!(&DataType::Decimal(23, 6), decimal_array.data_type()); +/// assert_eq!(&DataType::Decimal128(23, 6), decimal_array.data_type()); /// assert_eq!(8_887_000_000_i128, decimal_array.value(0).as_i128()); /// assert_eq!("8887.000000", decimal_array.value_as_string(0)); /// assert_eq!(3, decimal_array.len()); @@ -68,60 +69,79 @@ use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256}; /// assert_eq!(6, decimal_array.scale()); /// ``` /// -pub struct DecimalArray { - data: ArrayData, - value_data: RawPtrBox, - precision: usize, - scale: usize, -} +pub type Decimal128Array = DecimalArray; -pub struct Decimal256Array { +/// `Decimal256Array` stores fixed width decimal numbers, +/// with a fixed precision and scale +pub type Decimal256Array = DecimalArray; + +/// A generic [`Array`] for fixed width decimal numbers +/// +/// See [`Decimal128Array`] and [`Decimal256Array`] +pub struct DecimalArray { data: ArrayData, value_data: RawPtrBox, - precision: usize, - scale: usize, + precision: u8, + scale: u8, + _phantom: PhantomData, } -mod private_decimal { - pub trait DecimalArrayPrivate { - fn raw_value_data_ptr(&self) -> *const u8; - } -} +impl DecimalArray { + pub const VALUE_LENGTH: i32 = T::BYTE_LENGTH as i32; + const DEFAULT_TYPE: DataType = T::DEFAULT_TYPE; + pub const MAX_PRECISION: u8 = T::MAX_PRECISION; + pub const MAX_SCALE: u8 = T::MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = T::TYPE_CONSTRUCTOR; -pub trait BasicDecimalArray>: - private_decimal::DecimalArrayPrivate -{ - const VALUE_LENGTH: i32; - - fn data(&self) -> &ArrayData; + pub fn data(&self) -> &ArrayData { + &self.data + } /// Return the precision (total digits) that can be stored by this array 
- fn precision(&self) -> usize; + pub fn precision(&self) -> u8 { + self.precision + } /// Return the scale (digits after the decimal) that can be stored by this array - fn scale(&self) -> usize; + pub fn scale(&self) -> u8 { + self.scale + } /// Returns the element at index `i`. - fn value(&self, i: usize) -> T { - let data = self.data(); - assert!(i < data.len(), "Out of bounds access"); + /// # Panics + /// Panics if index `i` is out of bounds. + pub fn value(&self, i: usize) -> Decimal { + assert!( + i < self.data().len(), + "Trying to access an element at index {} from a DecimalArray of length {}", + i, + self.len() + ); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i`. + /// # Safety + /// Caller is responsible for ensuring that the index is within the bounds of the array + pub unsafe fn value_unchecked(&self, i: usize) -> Decimal { + let data = self.data(); let offset = i + data.offset(); - let raw_val = unsafe { + let raw_val = { let pos = self.value_offset_at(offset); - std::slice::from_raw_parts( + T::Native::from_slice(std::slice::from_raw_parts( self.raw_value_data_ptr().offset(pos as isize), Self::VALUE_LENGTH as usize, - ) + )) }; - T::new(self.precision(), self.scale(), raw_val) + Decimal::new(self.precision(), self.scale(), &raw_val) } /// Returns the offset for the element at index `i`. /// /// Note this doesn't do any bound checking, for performance reason. #[inline] - fn value_offset(&self, i: usize) -> i32 { + pub fn value_offset(&self, i: usize) -> i32 { self.value_offset_at(self.data().offset() + i) } @@ -129,22 +149,22 @@ pub trait BasicDecimalArray>: /// /// All elements have the same length as the array is a fixed size. #[inline] - fn value_length(&self) -> i32 { + pub fn value_length(&self) -> i32 { Self::VALUE_LENGTH } /// Returns a clone of the value data buffer - fn value_data(&self) -> Buffer { + pub fn value_data(&self) -> Buffer { self.data().buffers()[0].clone() } #[inline] - fn value_offset_at(&self, i: usize) -> i32 { + pub fn value_offset_at(&self, i: usize) -> i32 { Self::VALUE_LENGTH * i as i32 } #[inline] - fn value_as_string(&self, row: usize) -> String { + pub fn value_as_string(&self, row: usize) -> String { self.value(row).to_string() } @@ -152,32 +172,46 @@ pub trait BasicDecimalArray>: /// /// NB: This function does not validate that each value is in the permissible /// range for a decimal - fn from_fixed_size_binary_array( + pub fn from_fixed_size_binary_array( v: FixedSizeBinaryArray, - precision: usize, - scale: usize, - ) -> U { + precision: u8, + scale: u8, + ) -> Self { assert!( v.value_length() == Self::VALUE_LENGTH, "Value length of the array ({}) must equal to the byte width of the decimal ({})", v.value_length(), Self::VALUE_LENGTH, ); - let builder = v - .into_data() - .into_builder() - .data_type(DataType::Decimal(precision, scale)); + let data_type = if Self::VALUE_LENGTH == 16 { + DataType::Decimal128(precision, scale) + } else { + DataType::Decimal256(precision, scale) + }; + let builder = v.into_data().into_builder().data_type(data_type); let array_data = unsafe { builder.build_unchecked() }; - U::from(array_data) + Self::from(array_data) } - fn from_fixed_size_list_array( + /// Build a decimal array from [`FixedSizeListArray`]. + /// + /// NB: This function does not validate that each value is in the permissible + /// range for a decimal. 
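Since `Decimal128Array` and `Decimal256Array` are now aliases of a single generic `DecimalArray<T>`, helper code can be written once over `DecimalType`. A hypothetical sketch (the `describe` helper is not part of this change and assumes `DecimalType` is usable as a public bound):

```rust
use arrow::array::{Array, DecimalArray};
use arrow::datatypes::DecimalType;

// Hypothetical helper: works for both 128-bit and 256-bit decimal arrays.
fn describe<T: DecimalType>(array: &DecimalArray<T>) -> String {
    format!(
        "{} values, byte width {}, precision {}, scale {}",
        array.len(),
        DecimalArray::<T>::VALUE_LENGTH,
        array.precision(),
        array.scale()
    )
}
```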
+ #[deprecated(note = "please use `from_fixed_size_binary_array` instead")] + pub fn from_fixed_size_list_array( v: FixedSizeListArray, - precision: usize, - scale: usize, - ) -> U { + precision: u8, + scale: u8, + ) -> Self { + assert_eq!( + v.data_ref().child_data().len(), + 1, + "DecimalArray can only be created from list array of u8 values \ + (i.e. FixedSizeList>)." + ); let child_data = &v.data_ref().child_data()[0]; + assert_eq!( child_data.child_data().len(), 0, @@ -189,54 +223,121 @@ pub trait BasicDecimalArray>: &DataType::UInt8, "DecimalArray can only be created from FixedSizeList arrays, mismatched data types." ); + assert!( + v.value_length() == Self::VALUE_LENGTH, + "Value length of the array ({}) must equal to the byte width of the decimal ({})", + v.value_length(), + Self::VALUE_LENGTH, + ); + assert_eq!( + v.data_ref().child_data()[0].null_count(), + 0, + "The child array cannot contain null values." + ); let list_offset = v.offset(); let child_offset = child_data.offset(); - let builder = ArrayData::builder(DataType::Decimal(precision, scale)) + let data_type = if Self::VALUE_LENGTH == 16 { + DataType::Decimal128(precision, scale) + } else { + DataType::Decimal256(precision, scale) + }; + let builder = ArrayData::builder(data_type) .len(v.len()) .add_buffer(child_data.buffers()[0].slice(child_offset)) .null_bit_buffer(v.data_ref().null_buffer().cloned()) .offset(list_offset); let array_data = unsafe { builder.build_unchecked() }; - U::from(array_data) + Self::from(array_data) } -} -impl BasicDecimalArray for DecimalArray { - const VALUE_LENGTH: i32 = 16; - - fn data(&self) -> &ArrayData { - &self.data + /// The default precision and scale used when not specified. + pub const fn default_type() -> DataType { + Self::DEFAULT_TYPE } - fn precision(&self) -> usize { - self.precision + fn raw_value_data_ptr(&self) -> *const u8 { + self.value_data.as_ptr() } - fn scale(&self) -> usize { - self.scale - } -} - -impl BasicDecimalArray for Decimal256Array { - const VALUE_LENGTH: i32 = 32; + /// Returns a Decimal array with the same data as self, with the + /// specified precision. + /// + /// Returns an Error if: + /// 1. `precision` is larger than [`Self::MAX_PRECISION`] + /// 2. `scale` is larger than [`Self::MAX_SCALE`]; + /// 3. `scale` is > `precision` + pub fn with_precision_and_scale(self, precision: u8, scale: u8) -> Result + where + Self: Sized, + { + // validate precision and scale + self.validate_precision_scale(precision, scale)?; - fn data(&self) -> &ArrayData { - &self.data - } + // Ensure that all values are within the requested + // precision. 
For performance, only check if the precision is + // decreased + if precision < self.precision { + self.validate_data(precision)?; + } - fn precision(&self) -> usize { - self.precision + // safety: self.data is valid DataType::Decimal as checked above + let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale); + Ok(self.data().clone().with_data_type(new_data_type).into()) } - fn scale(&self) -> usize { - self.scale + // validate that the new precision and scale are valid or not + fn validate_precision_scale(&self, precision: u8, scale: u8) -> Result<()> { + if precision > Self::MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "precision {} is greater than max {}", + precision, + Self::MAX_PRECISION + ))); + } + if scale > Self::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than max {}", + scale, + Self::MAX_SCALE + ))); + } + if scale > precision { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than precision {}", + scale, precision + ))); + } + let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale); + assert_eq!(self.data().data_type(), &data_type); + + Ok(()) + } + + // validate all the data in the array are valid within the new precision or not + fn validate_data(&self, precision: u8) -> Result<()> { + // TODO: Move into DecimalType + match Self::VALUE_LENGTH { + 16 => self + .as_any() + .downcast_ref::() + .unwrap() + .validate_decimal_precision(precision), + 32 => self + .as_any() + .downcast_ref::() + .unwrap() + .validate_decimal_precision(precision), + other_width => { + panic!("invalid byte width {}", other_width); + } + } } } -impl DecimalArray { - /// Creates a [DecimalArray] with default precision and scale, +impl Decimal128Array { + /// Creates a [Decimal128Array] with default precision and scale, /// based on an iterator of `i128` values without nulls pub fn from_iter_values>(iter: I) -> Self { let val_buf: Buffer = iter.into_iter().collect(); @@ -251,70 +352,45 @@ impl DecimalArray { vec![], ) }; - DecimalArray::from(data) - } - - /// Returns a DecimalArray with the same data as self, with the - /// specified precision. - /// - /// Returns an Error if: - /// 1. `precision` is larger than [`DECIMAL_MAX_PRECISION`] - /// 2. `scale` is larger than [`DECIMAL_MAX_SCALE`]; - /// 3. `scale` is > `precision` - pub fn with_precision_and_scale( - mut self, - precision: usize, - scale: usize, - ) -> Result { - if precision > DECIMAL_MAX_PRECISION { - return Err(ArrowError::InvalidArgumentError(format!( - "precision {} is greater than max {}", - precision, DECIMAL_MAX_PRECISION - ))); - } - if scale > DECIMAL_MAX_SCALE { - return Err(ArrowError::InvalidArgumentError(format!( - "scale {} is greater than max {}", - scale, DECIMAL_MAX_SCALE - ))); - } - if scale > precision { - return Err(ArrowError::InvalidArgumentError(format!( - "scale {} is greater than precision {}", - scale, precision - ))); - } - - // Ensure that all values are within the requested - // precision. For performance, only check if the precision is - // decreased - if precision < self.precision { - for v in self.iter().flatten() { - validate_decimal_precision(v, precision)?; + Decimal128Array::from(data) + } + + // Validates decimal128 values in this array can be properly interpreted + // with the specified precision. 
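A minimal end-to-end sketch of the `Decimal128Array` construction path above (illustrative; type names follow the aliases introduced in this file):

```rust
use arrow::array::{Array, Decimal128Array};
use arrow::datatypes::DataType;

fn main() -> arrow::error::Result<()> {
    // Build from raw i128 representations, then narrow precision/scale;
    // shrinking the precision triggers `validate_decimal_precision`.
    let array = Decimal128Array::from_iter_values([12345_i128, -67890])
        .with_precision_and_scale(10, 3)?;

    assert_eq!(array.data_type(), &DataType::Decimal128(10, 3));
    assert_eq!(array.value_as_string(0), "12.345");
    assert_eq!(array.value_as_string(1), "-67.890");
    Ok(())
}
```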
+ fn validate_decimal_precision(&self, precision: u8) -> Result<()> { + (0..self.len()).try_for_each(|idx| { + if self.is_valid(idx) { + let decimal = unsafe { self.value_unchecked(idx) }; + validate_decimal_precision(decimal.as_i128(), precision) + } else { + Ok(()) } - } - - assert_eq!( - self.data.data_type(), - &DataType::Decimal(self.precision, self.scale) - ); - - // safety: self.data is valid DataType::Decimal as checked above - let new_data_type = DataType::Decimal(precision, scale); - self.precision = precision; - self.scale = scale; - self.data = self.data.with_data_type(new_data_type); - Ok(self) + }) } +} - /// The default precision and scale used when not specified. - pub fn default_type() -> DataType { - // Keep maximum precision - DataType::Decimal(DECIMAL_MAX_PRECISION, DECIMAL_DEFAULT_SCALE) +impl Decimal256Array { + // Validates decimal256 values in this array can be properly interpreted + // with the specified precision. + fn validate_decimal_precision(&self, precision: u8) -> Result<()> { + (0..self.len()).try_for_each(|idx| { + if self.is_valid(idx) { + let raw_val = unsafe { + let pos = self.value_offset(idx); + std::slice::from_raw_parts( + self.raw_value_data_ptr().offset(pos as isize), + Self::VALUE_LENGTH as usize, + ) + }; + validate_decimal256_precision_with_lt_bytes(raw_val, precision) + } else { + Ok(()) + } + }) } } -impl From for DecimalArray { +impl From for DecimalArray { fn from(data: ArrayData) -> Self { assert_eq!( data.buffers().len(), @@ -322,8 +398,9 @@ impl From for DecimalArray { "DecimalArray data should contain 1 buffer only (values)" ); let values = data.buffers()[0].as_ptr(); - let (precision, scale) = match data.data_type() { - DataType::Decimal(precision, scale) => (*precision, *scale), + let (precision, scale) = match (data.data_type(), Self::VALUE_LENGTH) { + (DataType::Decimal128(precision, scale), 16) + | (DataType::Decimal256(precision, scale), 32) => (*precision, *scale), _ => panic!("Expected data type to be Decimal"), }; Self { @@ -331,49 +408,55 @@ impl From for DecimalArray { value_data: unsafe { RawPtrBox::new(values) }, precision, scale, + _phantom: Default::default(), } } } -impl From for Decimal256Array { - fn from(data: ArrayData) -> Self { - assert_eq!( - data.buffers().len(), - 1, - "DecimalArray data should contain 1 buffer only (values)" - ); - let values = data.buffers()[0].as_ptr(); - let (precision, scale) = match data.data_type() { - DataType::Decimal(precision, scale) => (*precision, *scale), - _ => panic!("Expected data type to be Decimal"), - }; - Self { - data, - value_data: unsafe { RawPtrBox::new(values) }, - precision, - scale, - } - } +fn build_decimal_array_from( + null_buf: BooleanBufferBuilder, + buffer: Buffer, +) -> DecimalArray { + let data = unsafe { + ArrayData::new_unchecked( + DecimalArray::::default_type(), + null_buf.len(), + None, + Some(null_buf.into()), + 0, + vec![buffer], + vec![], + ) + }; + DecimalArray::from(data) } -impl<'a> IntoIterator for &'a DecimalArray { - type Item = Option; - type IntoIter = DecimalIter<'a>; +impl> FromIterator> for Decimal256Array { + fn from_iter>>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, upper) = iter.size_hint(); + let size_hint = upper.unwrap_or(lower); - fn into_iter(self) -> Self::IntoIter { - DecimalIter::<'a>::new(self) - } -} + let mut null_buf = BooleanBufferBuilder::new(size_hint); -impl<'a> DecimalArray { - /// constructs a new iterator - pub fn iter(&'a self) -> DecimalIter<'a> { - DecimalIter::new(self) + let mut buffer = 
MutableBuffer::with_capacity(size_hint); + + iter.for_each(|item| { + if let Some(a) = item { + null_buf.append(true); + buffer.extend_from_slice(Into::into(a).raw_value()); + } else { + null_buf.append(false); + buffer.extend_zeros(32); + } + }); + + build_decimal_array_from(null_buf, buffer.into()) } } -impl>> FromIterator for DecimalArray { - fn from_iter>(iter: I) -> Self { +impl> FromIterator> for Decimal128Array { + fn from_iter>>(iter: I) -> Self { let iter = iter.into_iter(); let (lower, upper) = iter.size_hint(); let size_hint = upper.unwrap_or(lower); @@ -382,9 +465,9 @@ impl>> FromIterator for DecimalArray { let buffer: Buffer = iter .map(|item| { - if let Some(a) = item.borrow() { + if let Some(a) = item { null_buf.append(true); - *a + a.into() } else { null_buf.append(false); // arbitrary value for NULL @@ -393,73 +476,83 @@ impl>> FromIterator for DecimalArray { }) .collect(); - let data = unsafe { - ArrayData::new_unchecked( - Self::default_type(), - null_buf.len(), - None, - Some(null_buf.into()), - 0, - vec![buffer], - vec![], - ) - }; - DecimalArray::from(data) + build_decimal_array_from(null_buf, buffer) } } -macro_rules! def_decimal_array { - ($ty:ident, $array_name:expr) => { - impl private_decimal::DecimalArrayPrivate for $ty { - fn raw_value_data_ptr(&self) -> *const u8 { - self.value_data.as_ptr() - } - } +impl Array for DecimalArray { + fn as_any(&self) -> &dyn Any { + self + } - impl Array for $ty { - fn as_any(&self) -> &dyn Any { - self - } + fn data(&self) -> &ArrayData { + &self.data + } - fn data(&self) -> &ArrayData { - &self.data - } + fn into_data(self) -> ArrayData { + self.into() + } +} - fn into_data(self) -> ArrayData { - self.into() - } - } +impl From> for ArrayData { + fn from(array: DecimalArray) -> Self { + array.data + } +} - impl From<$ty> for ArrayData { - fn from(array: $ty) -> Self { - array.data - } - } +impl fmt::Debug for DecimalArray { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Decimal{}Array<{}, {}>\n[\n", + T::BYTE_LENGTH * 8, + self.precision, + self.scale + )?; + print_long_array(self, f, |array, index, f| { + let formatted_decimal = array.value_as_string(index); + + write!(f, "{}", formatted_decimal) + })?; + write!(f, "]") + } +} - impl fmt::Debug for $ty { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{}<{}, {}>\n[\n", - $array_name, self.precision, self.scale - )?; - print_long_array(self, f, |array, index, f| { - let formatted_decimal = array.value_as_string(index); - - write!(f, "{}", formatted_decimal) - })?; - write!(f, "]") - } - } - }; +impl<'a, T: DecimalType> ArrayAccessor for &'a DecimalArray { + type Item = Decimal; + + fn value(&self, index: usize) -> Self::Item { + DecimalArray::::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + DecimalArray::::value_unchecked(self, index) + } +} + +impl<'a, T: DecimalType> IntoIterator for &'a DecimalArray { + type Item = Option>; + type IntoIter = DecimalIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + DecimalIter::<'a, T>::new(self) + } } -def_decimal_array!(DecimalArray, "DecimalArray"); -def_decimal_array!(Decimal256Array, "Decimal256Array"); +impl<'a, T: DecimalType> DecimalArray { + /// constructs a new iterator + pub fn iter(&'a self) -> DecimalIter<'a, T> { + DecimalIter::<'a, T>::new(self) + } +} #[cfg(test)] mod tests { - use crate::{array::DecimalBuilder, datatypes::Field}; + use crate::array::Decimal256Builder; + use crate::datatypes::{DECIMAL256_MAX_PRECISION, 
DECIMAL_DEFAULT_SCALE}; + use crate::util::decimal::Decimal128; + use crate::{array::Decimal128Builder, datatypes::Field}; + use num::{BigInt, Num}; use super::*; @@ -471,12 +564,12 @@ mod tests { 192, 219, 180, 17, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 64, 36, 75, 238, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, ]; - let array_data = ArrayData::builder(DataType::Decimal(38, 6)) + let array_data = ArrayData::builder(DataType::Decimal128(38, 6)) .len(2) .add_buffer(Buffer::from(&values[..])) .build() .unwrap(); - let decimal_array = DecimalArray::from(array_data); + let decimal_array = Decimal128Array::from(array_data); assert_eq!(8_887_000_000_i128, decimal_array.value(0).into()); assert_eq!(-8_887_000_000_i128, decimal_array.value(1).into()); assert_eq!(16, decimal_array.value_length()); @@ -485,11 +578,11 @@ mod tests { #[test] #[cfg(not(feature = "force_validate"))] fn test_decimal_append_error_value() { - let mut decimal_builder = DecimalBuilder::new(10, 5, 3); + let mut decimal_builder = Decimal128Builder::with_capacity(10, 5, 3); let mut result = decimal_builder.append_value(123456); let mut error = result.unwrap_err(); assert_eq!( - "Invalid argument error: 123456 is too large to store in a Decimal of precision 5. Max is 99999", + "Invalid argument error: 123456 is too large to store in a Decimal128 of precision 5. Max is 99999", error.to_string() ); @@ -502,11 +595,11 @@ mod tests { let arr = decimal_builder.finish(); assert_eq!("12.345", arr.value_as_string(1)); - decimal_builder = DecimalBuilder::new(10, 2, 1); + decimal_builder = Decimal128Builder::new(2, 1); result = decimal_builder.append_value(100); error = result.unwrap_err(); assert_eq!( - "Invalid argument error: 100 is too large to store in a Decimal of precision 2. Max is 99", + "Invalid argument error: 100 is too large to store in a Decimal128 of precision 2. 
Max is 99", error.to_string() ); @@ -526,9 +619,9 @@ mod tests { #[test] fn test_decimal_from_iter_values() { - let array = DecimalArray::from_iter_values(vec![-100, 0, 101].into_iter()); + let array = Decimal128Array::from_iter_values(vec![-100, 0, 101].into_iter()); assert_eq!(array.len(), 3); - assert_eq!(array.data_type(), &DataType::Decimal(38, 10)); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); assert_eq!(-100_i128, array.value(0).into()); assert!(!array.is_null(0)); assert_eq!(0_i128, array.value(1).into()); @@ -539,9 +632,10 @@ mod tests { #[test] fn test_decimal_from_iter() { - let array: DecimalArray = vec![Some(-100), None, Some(101)].into_iter().collect(); + let array: Decimal128Array = + vec![Some(-100), None, Some(101)].into_iter().collect(); assert_eq!(array.len(), 3); - assert_eq!(array.data_type(), &DataType::Decimal(38, 10)); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); assert_eq!(-100_i128, array.value(0).into()); assert!(!array.is_null(0)); assert!(array.is_null(1)); @@ -552,25 +646,26 @@ mod tests { #[test] fn test_decimal_iter() { let data = vec![Some(-100), None, Some(101)]; - let array: DecimalArray = data.clone().into_iter().collect(); + let array: Decimal128Array = data.clone().into_iter().collect(); - let collected: Vec<_> = array.iter().collect(); + let collected: Vec<_> = array.iter().map(|d| d.map(|v| v.as_i128())).collect(); assert_eq!(data, collected); } #[test] fn test_decimal_into_iter() { let data = vec![Some(-100), None, Some(101)]; - let array: DecimalArray = data.clone().into_iter().collect(); + let array: Decimal128Array = data.clone().into_iter().collect(); - let collected: Vec<_> = array.into_iter().collect(); + let collected: Vec<_> = + array.into_iter().map(|d| d.map(|v| v.as_i128())).collect(); assert_eq!(data, collected); } #[test] fn test_decimal_iter_sized() { let data = vec![Some(-100), None, Some(101)]; - let array: DecimalArray = data.into_iter().collect(); + let array: Decimal128Array = data.into_iter().collect(); let mut iter = array.into_iter(); // is exact sized @@ -592,7 +687,7 @@ mod tests { let arr = [123450, -123450, 100, -100, 10, -10, 0] .into_iter() .map(Some) - .collect::() + .collect::() .with_precision_and_scale(6, 3) .unwrap(); @@ -607,11 +702,11 @@ mod tests { #[test] fn test_decimal_array_with_precision_and_scale() { - let arr = DecimalArray::from_iter_values([12345, 456, 7890, -123223423432432]) + let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432]) .with_precision_and_scale(20, 2) .unwrap(); - assert_eq!(arr.data_type(), &DataType::Decimal(20, 2)); + assert_eq!(arr.data_type(), &DataType::Decimal128(20, 2)); assert_eq!(arr.precision(), 20); assert_eq!(arr.scale(), 2); @@ -623,10 +718,10 @@ mod tests { #[test] #[should_panic( - expected = "-123223423432432 is too small to store in a Decimal of precision 5. Min is -99999" + expected = "-123223423432432 is too small to store in a Decimal128 of precision 5. 
Min is -99999" )] fn test_decimal_array_with_precision_and_scale_out_of_range() { - DecimalArray::from_iter_values([12345, 456, 7890, -123223423432432]) + Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432]) // precision is too small to hold value .with_precision_and_scale(5, 2) .unwrap(); @@ -635,7 +730,7 @@ mod tests { #[test] #[should_panic(expected = "precision 40 is greater than max 38")] fn test_decimal_array_with_precision_and_scale_invalid_precision() { - DecimalArray::from_iter_values([12345, 456]) + Decimal128Array::from_iter_values([12345, 456]) .with_precision_and_scale(40, 2) .unwrap(); } @@ -643,7 +738,7 @@ mod tests { #[test] #[should_panic(expected = "scale 40 is greater than max 38")] fn test_decimal_array_with_precision_and_scale_invalid_scale() { - DecimalArray::from_iter_values([12345, 456]) + Decimal128Array::from_iter_values([12345, 456]) .with_precision_and_scale(20, 40) .unwrap(); } @@ -651,21 +746,21 @@ mod tests { #[test] #[should_panic(expected = "scale 10 is greater than precision 4")] fn test_decimal_array_with_precision_and_scale_invalid_precision_and_scale() { - DecimalArray::from_iter_values([12345, 456]) + Decimal128Array::from_iter_values([12345, 456]) .with_precision_and_scale(4, 10) .unwrap(); } #[test] fn test_decimal_array_fmt_debug() { - let arr = [Some(8887000000), Some(-8887000000), None] - .iter() - .collect::() + let arr = [Some(8887000000_i128), Some(-8887000000_i128), None] + .into_iter() + .collect::() .with_precision_and_scale(23, 6) .unwrap(); assert_eq!( - "DecimalArray<23, 6>\n[\n 8887.000000,\n -8887.000000,\n null,\n]", + "Decimal128Array<23, 6>\n[\n 8887.000000,\n -8887.000000,\n null,\n]", format!("{:?}", arr) ); } @@ -681,7 +776,7 @@ mod tests { .unwrap(); let binary_array = FixedSizeBinaryArray::from(value_data); - let decimal = DecimalArray::from_fixed_size_binary_array(binary_array, 38, 1); + let decimal = Decimal128Array::from_fixed_size_binary_array(binary_array, 38, 1); assert_eq!(decimal.len(), 3); assert_eq!(decimal.value_as_string(0), "0.2".to_string()); @@ -703,10 +798,11 @@ mod tests { .unwrap(); let binary_array = FixedSizeBinaryArray::from(value_data); - let _ = DecimalArray::from_fixed_size_binary_array(binary_array, 38, 1); + let _ = Decimal128Array::from_fixed_size_binary_array(binary_array, 38, 1); } #[test] + #[allow(deprecated)] fn test_decimal_array_from_fixed_size_list() { let value_data = ArrayData::builder(DataType::UInt8) .offset(16) @@ -730,10 +826,147 @@ mod tests { .build() .unwrap(); let list_array = FixedSizeListArray::from(list_data); - let decimal = DecimalArray::from_fixed_size_list_array(list_array, 38, 0); + let decimal = Decimal128Array::from_fixed_size_list_array(list_array, 38, 0); assert_eq!(decimal.len(), 2); assert!(decimal.is_null(0)); assert_eq!(decimal.value_as_string(1), "56".to_string()); } + + #[test] + #[allow(deprecated)] + #[should_panic(expected = "The child array cannot contain null values.")] + fn test_decimal_array_from_fixed_size_list_with_child_nulls_failed() { + let value_data = ArrayData::builder(DataType::UInt8) + .len(16) + .add_buffer(Buffer::from_slice_ref(&[12_i128])) + .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b1010101010101010]))) + .build() + .unwrap(); + + // Construct a list array from the above two + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::UInt8, false)), + 16, + ); + let list_data = ArrayData::builder(list_data_type) + .len(1) + .add_child_data(value_data) + .build() + .unwrap(); + let list_array = 
FixedSizeListArray::from(list_data); + drop(Decimal128Array::from_fixed_size_list_array( + list_array, 38, 0, + )); + } + + #[test] + #[allow(deprecated)] + #[should_panic( + expected = "Value length of the array (8) must equal to the byte width of the decimal (16)" + )] + fn test_decimal_array_from_fixed_size_list_with_wrong_length() { + let value_data = ArrayData::builder(DataType::UInt8) + .len(16) + .add_buffer(Buffer::from_slice_ref(&[12_i128])) + .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b1010101010101010]))) + .build() + .unwrap(); + + // Construct a list array from the above two + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::UInt8, false)), + 8, + ); + let list_data = ArrayData::builder(list_data_type) + .len(2) + .add_child_data(value_data) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + drop(Decimal128Array::from_fixed_size_list_array( + list_array, 38, 0, + )); + } + + #[test] + fn test_decimal256_iter() { + let mut builder = Decimal256Builder::with_capacity(30, 76, 6); + let value = BigInt::from_str_radix("12345", 10).unwrap(); + let decimal1 = Decimal256::from_big_int(&value, 76, 6).unwrap(); + builder.append_value(&decimal1).unwrap(); + + builder.append_null(); + + let value = BigInt::from_str_radix("56789", 10).unwrap(); + let decimal2 = Decimal256::from_big_int(&value, 76, 6).unwrap(); + builder.append_value(&decimal2).unwrap(); + + let array: Decimal256Array = builder.finish(); + + let collected: Vec<_> = array.iter().collect(); + assert_eq!(vec![Some(decimal1), None, Some(decimal2)], collected); + } + + #[test] + fn test_from_iter_decimal256array() { + let value1 = BigInt::from_str_radix("12345", 10).unwrap(); + let value2 = BigInt::from_str_radix("56789", 10).unwrap(); + + let array: Decimal256Array = + vec![Some(value1.clone()), None, Some(value2.clone())] + .into_iter() + .collect(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal256(76, 10)); + assert_eq!( + Decimal256::from_big_int( + &value1, + DECIMAL256_MAX_PRECISION, + DECIMAL_DEFAULT_SCALE, + ) + .unwrap(), + array.value(0) + ); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!( + Decimal256::from_big_int( + &value2, + DECIMAL256_MAX_PRECISION, + DECIMAL_DEFAULT_SCALE, + ) + .unwrap(), + array.value(2) + ); + assert!(!array.is_null(2)); + } + + #[test] + fn test_from_iter_decimal128array() { + let array: Decimal128Array = vec![ + Some(Decimal128::new_from_i128(38, 10, -100)), + None, + Some(Decimal128::new_from_i128(38, 10, 101)), + ] + .into_iter() + .collect(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal128(38, 10)); + assert_eq!(-100_i128, array.value(0).into()); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(101_i128, array.value(2).into()); + assert!(!array.is_null(2)); + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a DecimalArray of length 3" + )] + fn test_fixed_size_binary_array_get_value_index_out_of_bound() { + let array = Decimal128Array::from_iter_values(vec![-100, 0, 101].into_iter()); + + array.value(4); + } } diff --git a/arrow/src/array/array_dictionary.rs b/arrow/src/array/array_dictionary.rs index 9350daae53e1..79f2969df688 100644 --- a/arrow/src/array/array_dictionary.rs +++ b/arrow/src/array/array_dictionary.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use crate::array::{ArrayAccessor, ArrayIter}; use std::any::Any; use std::fmt; use std::iter::IntoIterator; @@ -234,6 +235,28 @@ impl DictionaryArray { .expect("Dictionary index not usize") }) } + + /// Downcast this dictionary to a [`TypedDictionaryArray`] + /// + /// ``` + /// use arrow::array::{Array, ArrayAccessor, DictionaryArray, StringArray}; + /// use arrow::datatypes::Int32Type; + /// + /// let orig = [Some("a"), Some("b"), None]; + /// let dictionary = DictionaryArray::::from_iter(orig); + /// let typed = dictionary.downcast_dict::().unwrap(); + /// assert_eq!(typed.value(0), "a"); + /// assert_eq!(typed.value(1), "b"); + /// assert!(typed.is_null(2)); + /// ``` + /// + pub fn downcast_dict(&self) -> Option> { + let values = self.values.as_any().downcast_ref()?; + Some(TypedDictionaryArray { + dictionary: self, + values, + }) + } } /// Constructs a `DictionaryArray` from an array data reference. @@ -302,14 +325,12 @@ impl From> for ArrayData { /// format!("{:?}", array) /// ); /// ``` -impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator> - for DictionaryArray -{ +impl<'a, T: ArrowDictionaryKeyType> FromIterator> for DictionaryArray { fn from_iter>>(iter: I) -> Self { let it = iter.into_iter(); let (lower, _) = it.size_hint(); - let key_builder = PrimitiveBuilder::::new(lower); - let value_builder = StringBuilder::new(256); + let key_builder = PrimitiveBuilder::::with_capacity(lower); + let value_builder = StringBuilder::with_capacity(256, 1024); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); it.for_each(|i| { if let Some(i) = i { @@ -319,9 +340,7 @@ impl<'a, T: ArrowPrimitiveType + ArrowDictionaryKeyType> FromIterator FromIterator FromIterator<&'a str> - for DictionaryArray -{ +impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray { fn from_iter>(iter: I) -> Self { let it = iter.into_iter(); let (lower, _) = it.size_hint(); - let key_builder = PrimitiveBuilder::::new(lower); - let value_builder = StringBuilder::new(256); + let key_builder = PrimitiveBuilder::::with_capacity(lower); + let value_builder = StringBuilder::with_capacity(256, 1024); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); it.for_each(|i| { builder @@ -387,6 +404,119 @@ impl fmt::Debug for DictionaryArray { } } +/// A strongly-typed wrapper around a [`DictionaryArray`] that implements [`ArrayAccessor`] +/// allowing fast access to its elements +/// +/// ``` +/// use arrow::array::{ArrayIter, DictionaryArray, StringArray}; +/// use arrow::datatypes::Int32Type; +/// +/// let orig = ["a", "b", "a", "b"]; +/// let dictionary = DictionaryArray::::from_iter(orig); +/// +/// // `TypedDictionaryArray` allows you to access the values directly +/// let typed = dictionary.downcast_dict::().unwrap(); +/// +/// for (maybe_val, orig) in typed.into_iter().zip(orig) { +/// assert_eq!(maybe_val.unwrap(), orig) +/// } +/// ``` +pub struct TypedDictionaryArray<'a, K: ArrowPrimitiveType, V> { + /// The dictionary array + dictionary: &'a DictionaryArray, + /// The values of the dictionary + values: &'a V, +} + +// Manually implement `Clone` to avoid `V: Clone` type constraint +impl<'a, K: ArrowPrimitiveType, V> Clone for TypedDictionaryArray<'a, K, V> { + fn clone(&self) -> Self { + Self { + dictionary: self.dictionary, + values: self.values, + } + } +} + +impl<'a, K: ArrowPrimitiveType, V> Copy for TypedDictionaryArray<'a, K, V> {} + +impl<'a, K: ArrowPrimitiveType, V> fmt::Debug for TypedDictionaryArray<'a, K, V> { + fn fmt(&self, 
f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "TypedDictionaryArray({:?})", self.dictionary) + } +} + +impl<'a, K: ArrowPrimitiveType, V> TypedDictionaryArray<'a, K, V> { + /// Returns the keys of this [`TypedDictionaryArray`] + pub fn keys(&self) -> &'a PrimitiveArray { + self.dictionary.keys() + } + + /// Returns the values of this [`TypedDictionaryArray`] + pub fn values(&self) -> &'a V { + self.values + } +} + +impl<'a, K: ArrowPrimitiveType, V: Sync> Array for TypedDictionaryArray<'a, K, V> { + fn as_any(&self) -> &dyn Any { + self.dictionary + } + + fn data(&self) -> &ArrayData { + &self.dictionary.data + } + + fn into_data(self) -> ArrayData { + self.dictionary.into_data() + } +} + +impl<'a, K, V> IntoIterator for TypedDictionaryArray<'a, K, V> +where + K: ArrowPrimitiveType, + Self: ArrayAccessor, +{ + type Item = Option<::Item>; + type IntoIter = ArrayIter; + + fn into_iter(self) -> Self::IntoIter { + ArrayIter::new(self) + } +} + +impl<'a, K, V> ArrayAccessor for TypedDictionaryArray<'a, K, V> +where + K: ArrowPrimitiveType, + V: Sync + Send, + &'a V: ArrayAccessor, + <&'a V as ArrayAccessor>::Item: Default, +{ + type Item = <&'a V as ArrayAccessor>::Item; + + fn value(&self, index: usize) -> Self::Item { + assert!( + index < self.len(), + "Trying to access an element at index {} from a TypedDictionaryArray of length {}", + index, + self.len() + ); + unsafe { self.value_unchecked(index) } + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + let val = self.dictionary.keys.value_unchecked(index); + let value_idx = val.to_usize().unwrap(); + + // As dictionary keys are only verified for non-null indexes + // we must check the value is within bounds + match value_idx < self.values.len() { + true => self.values.value_unchecked(value_idx), + false => Default::default(), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -459,11 +589,11 @@ mod tests { #[test] fn test_dictionary_array_fmt_debug() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(12345678).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(22345678).unwrap(); let array = builder.finish(); assert_eq!( @@ -471,8 +601,8 @@ mod tests { format!("{:?}", array) ); - let key_builder = PrimitiveBuilder::::new(20); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(20); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); for _ in 0..20 { builder.append(1).unwrap(); diff --git a/arrow/src/array/array_fixed_size_binary.rs b/arrow/src/array/array_fixed_size_binary.rs new file mode 100644 index 000000000000..22eac1435a8d --- /dev/null +++ b/arrow/src/array/array_fixed_size_binary.rs @@ -0,0 +1,690 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::convert::From; +use std::fmt; + +use super::{ + array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, FixedSizeListArray, +}; +use crate::array::{ArrayAccessor, FixedSizeBinaryIter}; +use crate::buffer::Buffer; +use crate::error::{ArrowError, Result}; +use crate::util::bit_util; +use crate::{buffer::MutableBuffer, datatypes::DataType}; + +/// An array where each element is a fixed-size sequence of bytes. +/// +/// # Examples +/// +/// Create an array from an iterable argument of byte slices. +/// +/// ``` +/// use arrow::array::{Array, FixedSizeBinaryArray}; +/// let input_arg = vec![ vec![1, 2], vec![3, 4], vec![5, 6] ]; +/// let arr = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); +/// +/// assert_eq!(3, arr.len()); +/// +/// ``` +/// Create an array from an iterable argument of sparse byte slices. +/// Sparsity means that the input argument can contain `None` items. +/// ``` +/// use arrow::array::{Array, FixedSizeBinaryArray}; +/// let input_arg = vec![ None, Some(vec![7, 8]), Some(vec![9, 10]), None, Some(vec![13, 14]) ]; +/// let arr = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); +/// assert_eq!(5, arr.len()) +/// +/// ``` +/// +pub struct FixedSizeBinaryArray { + data: ArrayData, + value_data: RawPtrBox, + length: i32, +} + +impl FixedSizeBinaryArray { + /// Returns the element at index `i` as a byte slice. + /// # Panics + /// Panics if index `i` is out of bounds. + pub fn value(&self, i: usize) -> &[u8] { + assert!( + i < self.data.len(), + "Trying to access an element at index {} from a FixedSizeBinaryArray of length {}", + i, + self.len() + ); + let offset = i + self.data.offset(); + unsafe { + let pos = self.value_offset_at(offset); + std::slice::from_raw_parts( + self.value_data.as_ptr().offset(pos as isize), + (self.value_offset_at(offset + 1) - pos) as usize, + ) + } + } + + /// Returns the element at index `i` as a byte slice. + /// # Safety + /// Caller is responsible for ensuring that the index is within the bounds of the array + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + let offset = i + self.data.offset(); + let pos = self.value_offset_at(offset); + std::slice::from_raw_parts( + self.value_data.as_ptr().offset(pos as isize), + (self.value_offset_at(offset + 1) - pos) as usize, + ) + } + + /// Returns the offset for the element at index `i`. + /// + /// Note this doesn't do any bound checking, for performance reason. + #[inline] + pub fn value_offset(&self, i: usize) -> i32 { + self.value_offset_at(self.data.offset() + i) + } + + /// Returns the length for an element. + /// + /// All elements have the same length as the array is a fixed size. + #[inline] + pub fn value_length(&self) -> i32 { + self.length + } + + /// Returns a clone of the value data buffer + pub fn value_data(&self) -> Buffer { + self.data.buffers()[0].clone() + } + + /// Create an array from an iterable argument of sparse byte slices. + /// Sparsity means that items returned by the iterator are optional, i.e input argument can + /// contain `None` items. 
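+ /// `None` items become null slots in the resulting array; the element size is inferred from the first `Some` value (or zero if every item is `None`).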
+ /// + /// # Examples + /// + /// ``` + /// use arrow::array::FixedSizeBinaryArray; + /// let input_arg = vec![ + /// None, + /// Some(vec![7, 8]), + /// Some(vec![9, 10]), + /// None, + /// Some(vec![13, 14]), + /// None, + /// ]; + /// let array = FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + /// ``` + /// + /// # Errors + /// + /// Returns error if argument has length zero, or sizes of nested slices don't match. + pub fn try_from_sparse_iter(mut iter: T) -> Result + where + T: Iterator>, + U: AsRef<[u8]>, + { + let mut len = 0; + let mut size = None; + let mut byte = 0; + let mut null_buf = MutableBuffer::from_len_zeroed(0); + let mut buffer = MutableBuffer::from_len_zeroed(0); + let mut prepend = 0; + iter.try_for_each(|item| -> Result<()> { + // extend null bitmask by one byte per each 8 items + if byte == 0 { + null_buf.push(0u8); + byte = 8; + } + byte -= 1; + + if let Some(slice) = item { + let slice = slice.as_ref(); + if let Some(size) = size { + if size != slice.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Nested array size mismatch: one is {}, and the other is {}", + size, + slice.len() + ))); + } + } else { + size = Some(slice.len()); + buffer.extend_zeros(slice.len() * prepend); + } + bit_util::set_bit(null_buf.as_slice_mut(), len); + buffer.extend_from_slice(slice); + } else if let Some(size) = size { + buffer.extend_zeros(size); + } else { + prepend += 1; + } + + len += 1; + + Ok(()) + })?; + + if len == 0 { + return Err(ArrowError::InvalidArgumentError( + "Input iterable argument has no data".to_owned(), + )); + } + + let size = size.unwrap_or(0); + let array_data = unsafe { + ArrayData::new_unchecked( + DataType::FixedSizeBinary(size as i32), + len, + None, + Some(null_buf.into()), + 0, + vec![buffer.into()], + vec![], + ) + }; + Ok(FixedSizeBinaryArray::from(array_data)) + } + + /// Create an array from an iterable argument of byte slices. + /// + /// # Examples + /// + /// ``` + /// use arrow::array::FixedSizeBinaryArray; + /// let input_arg = vec![ + /// vec![1, 2], + /// vec![3, 4], + /// vec![5, 6], + /// ]; + /// let array = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); + /// ``` + /// + /// # Errors + /// + /// Returns error if argument has length zero, or sizes of nested slices don't match. 
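+ /// The element size is inferred from the first item, so an iterator that yields no items is rejected rather than producing an empty array.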
+ pub fn try_from_iter(mut iter: T) -> Result + where + T: Iterator, + U: AsRef<[u8]>, + { + let mut len = 0; + let mut size = None; + let mut buffer = MutableBuffer::from_len_zeroed(0); + iter.try_for_each(|item| -> Result<()> { + let slice = item.as_ref(); + if let Some(size) = size { + if size != slice.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Nested array size mismatch: one is {}, and the other is {}", + size, + slice.len() + ))); + } + } else { + size = Some(slice.len()); + } + buffer.extend_from_slice(slice); + + len += 1; + + Ok(()) + })?; + + if len == 0 { + return Err(ArrowError::InvalidArgumentError( + "Input iterable argument has no data".to_owned(), + )); + } + + let size = size.unwrap_or(0); + let array_data = ArrayData::builder(DataType::FixedSizeBinary(size as i32)) + .len(len) + .add_buffer(buffer.into()); + let array_data = unsafe { array_data.build_unchecked() }; + Ok(FixedSizeBinaryArray::from(array_data)) + } + + #[inline] + fn value_offset_at(&self, i: usize) -> i32 { + self.length * i as i32 + } + + /// constructs a new iterator + pub fn iter(&self) -> FixedSizeBinaryIter<'_> { + FixedSizeBinaryIter::new(self) + } +} + +impl From for FixedSizeBinaryArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.buffers().len(), + 1, + "FixedSizeBinaryArray data should contain 1 buffer only (values)" + ); + let value_data = data.buffers()[0].as_ptr(); + let length = match data.data_type() { + DataType::FixedSizeBinary(len) => *len, + _ => panic!("Expected data type to be FixedSizeBinary"), + }; + Self { + data, + value_data: unsafe { RawPtrBox::new(value_data) }, + length, + } + } +} + +impl From for ArrayData { + fn from(array: FixedSizeBinaryArray) -> Self { + array.data + } +} + +/// Creates a `FixedSizeBinaryArray` from `FixedSizeList` array +impl From for FixedSizeBinaryArray { + fn from(v: FixedSizeListArray) -> Self { + assert_eq!( + v.data_ref().child_data().len(), + 1, + "FixedSizeBinaryArray can only be created from list array of u8 values \ + (i.e. FixedSizeList>)." + ); + let child_data = &v.data_ref().child_data()[0]; + + assert_eq!( + child_data.child_data().len(), + 0, + "FixedSizeBinaryArray can only be created from list array of u8 values \ + (i.e. FixedSizeList>)." + ); + assert_eq!( + child_data.data_type(), + &DataType::UInt8, + "FixedSizeBinaryArray can only be created from FixedSizeList arrays, mismatched data types." + ); + assert_eq!( + child_data.null_count(), + 0, + "The child array cannot contain null values." 
+ ); + + let builder = ArrayData::builder(DataType::FixedSizeBinary(v.value_length())) + .len(v.len()) + .offset(v.offset()) + .add_buffer(child_data.buffers()[0].slice(child_data.offset())) + .null_bit_buffer(v.data_ref().null_buffer().cloned()); + + let data = unsafe { builder.build_unchecked() }; + Self::from(data) + } +} + +impl From>> for FixedSizeBinaryArray { + fn from(v: Vec>) -> Self { + Self::try_from_sparse_iter(v.into_iter()).unwrap() + } +} + +impl From> for FixedSizeBinaryArray { + fn from(v: Vec<&[u8]>) -> Self { + Self::try_from_iter(v.into_iter()).unwrap() + } +} + +impl fmt::Debug for FixedSizeBinaryArray { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "FixedSizeBinaryArray<{}>\n[\n", self.value_length())?; + print_long_array(self, f, |array, index, f| { + fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +impl Array for FixedSizeBinaryArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn data(&self) -> &ArrayData { + &self.data + } + + fn into_data(self) -> ArrayData { + self.into() + } +} + +impl<'a> ArrayAccessor for &'a FixedSizeBinaryArray { + type Item = &'a [u8]; + + fn value(&self, index: usize) -> Self::Item { + FixedSizeBinaryArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + FixedSizeBinaryArray::value_unchecked(self, index) + } +} + +impl<'a> IntoIterator for &'a FixedSizeBinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = FixedSizeBinaryIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + FixedSizeBinaryIter::<'a>::new(self) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + datatypes::{Field, Schema}, + record_batch::RecordBatch, + }; + + use super::*; + + #[test] + fn test_fixed_size_binary_array() { + let values: [u8; 15] = *b"hellotherearrow"; + + let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) + .len(3) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + let fixed_size_binary_array = FixedSizeBinaryArray::from(array_data); + assert_eq!(3, fixed_size_binary_array.len()); + assert_eq!(0, fixed_size_binary_array.null_count()); + assert_eq!( + [b'h', b'e', b'l', b'l', b'o'], + fixed_size_binary_array.value(0) + ); + assert_eq!( + [b't', b'h', b'e', b'r', b'e'], + fixed_size_binary_array.value(1) + ); + assert_eq!( + [b'a', b'r', b'r', b'o', b'w'], + fixed_size_binary_array.value(2) + ); + assert_eq!(5, fixed_size_binary_array.value_length()); + assert_eq!(10, fixed_size_binary_array.value_offset(2)); + for i in 0..3 { + assert!(fixed_size_binary_array.is_valid(i)); + assert!(!fixed_size_binary_array.is_null(i)); + } + + // Test binary array with offset + let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) + .len(2) + .offset(1) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + let fixed_size_binary_array = FixedSizeBinaryArray::from(array_data); + assert_eq!( + [b't', b'h', b'e', b'r', b'e'], + fixed_size_binary_array.value(0) + ); + assert_eq!( + [b'a', b'r', b'r', b'o', b'w'], + fixed_size_binary_array.value(1) + ); + assert_eq!(2, fixed_size_binary_array.len()); + assert_eq!(5, fixed_size_binary_array.value_offset(0)); + assert_eq!(5, fixed_size_binary_array.value_length()); + assert_eq!(10, fixed_size_binary_array.value_offset(1)); + } + + #[test] + fn test_fixed_size_binary_array_from_fixed_size_list_array() { + let values = [0_u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]; + let values_data = ArrayData::builder(DataType::UInt8) + .len(12) + .offset(2) 
+ .add_buffer(Buffer::from_slice_ref(&values)) + .build() + .unwrap(); + // [null, [10, 11, 12, 13]] + let array_data = unsafe { + ArrayData::builder(DataType::FixedSizeList( + Box::new(Field::new("item", DataType::UInt8, false)), + 4, + )) + .len(2) + .offset(1) + .add_child_data(values_data) + .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b101]))) + .build_unchecked() + }; + let list_array = FixedSizeListArray::from(array_data); + let binary_array = FixedSizeBinaryArray::from(list_array); + + assert_eq!(2, binary_array.len()); + assert_eq!(1, binary_array.null_count()); + assert!(binary_array.is_null(0)); + assert!(binary_array.is_valid(1)); + assert_eq!(&[10, 11, 12, 13], binary_array.value(1)); + } + + #[test] + #[should_panic( + expected = "FixedSizeBinaryArray can only be created from FixedSizeList arrays" + )] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_fixed_size_binary_array_from_incorrect_fixed_size_list_array() { + let values: [u32; 12] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; + let values_data = ArrayData::builder(DataType::UInt32) + .len(12) + .add_buffer(Buffer::from_slice_ref(&values)) + .build() + .unwrap(); + + let array_data = unsafe { + ArrayData::builder(DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Binary, false)), + 4, + )) + .len(3) + .add_child_data(values_data) + .build_unchecked() + }; + let list_array = FixedSizeListArray::from(array_data); + drop(FixedSizeBinaryArray::from(list_array)); + } + + #[test] + #[should_panic(expected = "The child array cannot contain null values.")] + fn test_fixed_size_binary_array_from_fixed_size_list_array_with_child_nulls_failed() { + let values = [0_u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; + let values_data = ArrayData::builder(DataType::UInt8) + .len(12) + .add_buffer(Buffer::from_slice_ref(&values)) + .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b101010101010]))) + .build() + .unwrap(); + + let array_data = unsafe { + ArrayData::builder(DataType::FixedSizeList( + Box::new(Field::new("item", DataType::UInt8, false)), + 4, + )) + .len(3) + .add_child_data(values_data) + .build_unchecked() + }; + let list_array = FixedSizeListArray::from(array_data); + drop(FixedSizeBinaryArray::from(list_array)); + } + + #[test] + fn test_fixed_size_binary_array_fmt_debug() { + let values: [u8; 15] = *b"hellotherearrow"; + + let array_data = ArrayData::builder(DataType::FixedSizeBinary(5)) + .len(3) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + let arr = FixedSizeBinaryArray::from(array_data); + assert_eq!( + "FixedSizeBinaryArray<5>\n[\n [104, 101, 108, 108, 111],\n [116, 104, 101, 114, 101],\n [97, 114, 114, 111, 119],\n]", + format!("{:?}", arr) + ); + } + + #[test] + fn test_fixed_size_binary_array_from_iter() { + let input_arg = vec![vec![1, 2], vec![3, 4], vec![5, 6]]; + let arr = FixedSizeBinaryArray::try_from_iter(input_arg.into_iter()).unwrap(); + + assert_eq!(2, arr.value_length()); + assert_eq!(3, arr.len()) + } + + #[test] + fn test_all_none_fixed_size_binary_array_from_sparse_iter() { + let none_option: Option<[u8; 32]> = None; + let input_arg = vec![none_option, none_option, none_option]; + let arr = + FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + assert_eq!(0, arr.value_length()); + assert_eq!(3, arr.len()) + } + + #[test] + fn test_fixed_size_binary_array_from_sparse_iter() { + let input_arg = vec![ + None, + Some(vec![7, 8]), + Some(vec![9, 10]), + 
None, + Some(vec![13, 14]), + ]; + let arr = + FixedSizeBinaryArray::try_from_sparse_iter(input_arg.into_iter()).unwrap(); + assert_eq!(2, arr.value_length()); + assert_eq!(5, arr.len()) + } + + #[test] + fn test_fixed_size_binary_array_from_vec() { + let values = vec!["one".as_bytes(), b"two", b"six", b"ten"]; + let array = FixedSizeBinaryArray::from(values); + assert_eq!(array.len(), 4); + assert_eq!(array.null_count(), 0); + assert_eq!(array.value(0), b"one"); + assert_eq!(array.value(1), b"two"); + assert_eq!(array.value(2), b"six"); + assert_eq!(array.value(3), b"ten"); + assert!(!array.is_null(0)); + assert!(!array.is_null(1)); + assert!(!array.is_null(2)); + assert!(!array.is_null(3)); + } + + #[test] + #[should_panic(expected = "Nested array size mismatch: one is 3, and the other is 5")] + fn test_fixed_size_binary_array_from_vec_incorrect_length() { + let values = vec!["one".as_bytes(), b"two", b"three", b"four"]; + let _ = FixedSizeBinaryArray::from(values); + } + + #[test] + fn test_fixed_size_binary_array_from_opt_vec() { + let values = vec![ + Some("one".as_bytes()), + Some(b"two"), + None, + Some(b"six"), + Some(b"ten"), + ]; + let array = FixedSizeBinaryArray::from(values); + assert_eq!(array.len(), 5); + assert_eq!(array.value(0), b"one"); + assert_eq!(array.value(1), b"two"); + assert_eq!(array.value(3), b"six"); + assert_eq!(array.value(4), b"ten"); + assert!(!array.is_null(0)); + assert!(!array.is_null(1)); + assert!(array.is_null(2)); + assert!(!array.is_null(3)); + assert!(!array.is_null(4)); + } + + #[test] + #[should_panic(expected = "Nested array size mismatch: one is 3, and the other is 5")] + fn test_fixed_size_binary_array_from_opt_vec_incorrect_length() { + let values = vec![ + Some("one".as_bytes()), + Some(b"two"), + None, + Some(b"three"), + Some(b"four"), + ]; + let _ = FixedSizeBinaryArray::from(values); + } + + #[test] + fn fixed_size_binary_array_all_null() { + let data = vec![None] as Vec>; + let array = FixedSizeBinaryArray::try_from_sparse_iter(data.into_iter()).unwrap(); + array + .data() + .validate_full() + .expect("All null array has valid array data"); + } + + #[test] + // Test for https://github.com/apache/arrow-rs/issues/1390 + #[should_panic( + expected = "column types must match schema types, expected FixedSizeBinary(2) but found FixedSizeBinary(0) at column index 0" + )] + fn fixed_size_binary_array_all_null_in_batch_with_schema() { + let schema = + Schema::new(vec![Field::new("a", DataType::FixedSizeBinary(2), true)]); + + let none_option: Option<[u8; 2]> = None; + let item = FixedSizeBinaryArray::try_from_sparse_iter( + vec![none_option, none_option, none_option].into_iter(), + ) + .unwrap(); + + // Should not panic + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(item)]).unwrap(); + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a FixedSizeBinaryArray of length 3" + )] + fn test_fixed_size_binary_array_get_value_index_out_of_bound() { + let values = vec![Some("one".as_bytes()), Some(b"two"), None]; + let array = FixedSizeBinaryArray::from(values); + + array.value(4); + } +} diff --git a/arrow/src/array/array_fixed_size_list.rs b/arrow/src/array/array_fixed_size_list.rs new file mode 100644 index 000000000000..fc568d54a831 --- /dev/null +++ b/arrow/src/array/array_fixed_size_list.rs @@ -0,0 +1,388 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt; + +use super::{array::print_long_array, make_array, Array, ArrayData, ArrayRef}; +use crate::array::array::ArrayAccessor; +use crate::datatypes::DataType; + +/// A list array where each element is a fixed-size sequence of values with the same +/// type whose maximum length is represented by a i32. +/// +/// # Example +/// +/// ``` +/// # use arrow::array::{Array, ArrayData, FixedSizeListArray, Int32Array}; +/// # use arrow::datatypes::{DataType, Field}; +/// # use arrow::buffer::Buffer; +/// // Construct a value array +/// let value_data = ArrayData::builder(DataType::Int32) +/// .len(9) +/// .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) +/// .build() +/// .unwrap(); +/// let list_data_type = DataType::FixedSizeList( +/// Box::new(Field::new("item", DataType::Int32, false)), +/// 3, +/// ); +/// let list_data = ArrayData::builder(list_data_type.clone()) +/// .len(3) +/// .add_child_data(value_data.clone()) +/// .build() +/// .unwrap(); +/// let list_array = FixedSizeListArray::from(list_data); +/// let list0 = list_array.value(0); +/// let list1 = list_array.value(1); +/// let list2 = list_array.value(2); +/// +/// assert_eq!( &[0, 1, 2], list0.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!( &[3, 4, 5], list1.as_any().downcast_ref::().unwrap().values()); +/// assert_eq!( &[6, 7, 8], list2.as_any().downcast_ref::().unwrap().values()); +/// ``` +/// +/// For non generic lists, you may wish to consider using +/// [crate::array::FixedSizeBinaryArray] +pub struct FixedSizeListArray { + data: ArrayData, + values: ArrayRef, + length: i32, +} + +impl FixedSizeListArray { + /// Returns a reference to the values of this list. + pub fn values(&self) -> ArrayRef { + self.values.clone() + } + + /// Returns a clone of the value type of this list. + pub fn value_type(&self) -> DataType { + self.values.data_ref().data_type().clone() + } + + /// Returns ith value of this list array. + pub fn value(&self, i: usize) -> ArrayRef { + self.values + .slice(self.value_offset(i) as usize, self.value_length() as usize) + } + + /// Returns the offset for value at index `i`. + /// + /// Note this doesn't do any bound checking, for performance reason. + #[inline] + pub fn value_offset(&self, i: usize) -> i32 { + self.value_offset_at(self.data.offset() + i) + } + + /// Returns the length for an element. + /// + /// All elements have the same length as the array is a fixed size. 
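+ /// For a `FixedSizeList` of list size `N`, this returns `N` for every element.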
+ #[inline] + pub const fn value_length(&self) -> i32 { + self.length + } + + #[inline] + const fn value_offset_at(&self, i: usize) -> i32 { + i as i32 * self.length + } +} + +impl From for FixedSizeListArray { + fn from(data: ArrayData) -> Self { + assert_eq!( + data.buffers().len(), + 0, + "FixedSizeListArray data should not contain a buffer for value offsets" + ); + assert_eq!( + data.child_data().len(), + 1, + "FixedSizeListArray should contain a single child array (values array)" + ); + let values = make_array(data.child_data()[0].clone()); + let length = match data.data_type() { + DataType::FixedSizeList(_, len) => { + if *len > 0 { + // check that child data is multiple of length + assert_eq!( + values.len() % *len as usize, + 0, + "FixedSizeListArray child array length should be a multiple of {}", + len + ); + } + + *len + } + _ => { + panic!("FixedSizeListArray data should contain a FixedSizeList data type") + } + }; + Self { + data, + values, + length, + } + } +} + +impl From for ArrayData { + fn from(array: FixedSizeListArray) -> Self { + array.data + } +} + +impl Array for FixedSizeListArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn data(&self) -> &ArrayData { + &self.data + } + + fn into_data(self) -> ArrayData { + self.into() + } +} + +impl ArrayAccessor for FixedSizeListArray { + type Item = ArrayRef; + + fn value(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + FixedSizeListArray::value(self, index) + } +} + +impl fmt::Debug for FixedSizeListArray { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "FixedSizeListArray<{}>\n[\n", self.value_length())?; + print_long_array(self, f, |array, index, f| { + fmt::Debug::fmt(&array.value(index), f) + })?; + write!(f, "]") + } +} + +#[cfg(test)] +mod tests { + use crate::{ + array::ArrayData, array::Int32Array, buffer::Buffer, datatypes::Field, + util::bit_util, + }; + + use super::*; + + #[test] + fn test_fixed_size_list_array() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(9) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) + .build() + .unwrap(); + + // Construct a list array from the above two + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, false)), + 3, + ); + let list_data = ArrayData::builder(list_data_type.clone()) + .len(3) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(&value_data, values.data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!(6, list_array.value_offset(2)); + assert_eq!(3, list_array.value_length()); + assert_eq!( + 0, + list_array + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0) + ); + for i in 0..3 { + assert!(list_array.is_valid(i)); + assert!(!list_array.is_null(i)); + } + + // Now test with a non-zero offset + let list_data = ArrayData::builder(list_data_type) + .len(3) + .offset(1) + .add_child_data(value_data.clone()) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(&value_data, values.data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(3, list_array.len()); + assert_eq!(0, list_array.null_count()); + assert_eq!( + 3, 
+ list_array + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0) + ); + assert_eq!(6, list_array.value_offset(1)); + assert_eq!(3, list_array.value_length()); + } + + #[test] + #[should_panic( + expected = "FixedSizeListArray child array length should be a multiple of 3" + )] + // Different error messages, so skip for now + // https://github.com/apache/arrow-rs/issues/1545 + #[cfg(not(feature = "force_validate"))] + fn test_fixed_size_list_array_unequal_children() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(8) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a list array from the above two + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, false)), + 3, + ); + let list_data = unsafe { + ArrayData::builder(list_data_type) + .len(3) + .add_child_data(value_data) + .build_unchecked() + }; + drop(FixedSizeListArray::from(list_data)); + } + + #[test] + fn test_fixed_size_list_array_slice() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Set null buts for the nested array: + // [[0, 1], null, null, [6, 7], [8, 9]] + // 01011001 00000001 + let mut null_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + + // Construct a fixed size list array from the above two + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, false)), + 2, + ); + let list_data = ArrayData::builder(list_data_type) + .len(5) + .add_child_data(value_data.clone()) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + let values = list_array.values(); + assert_eq!(&value_data, values.data()); + assert_eq!(DataType::Int32, list_array.value_type()); + assert_eq!(5, list_array.len()); + assert_eq!(2, list_array.null_count()); + assert_eq!(6, list_array.value_offset(3)); + assert_eq!(2, list_array.value_length()); + + let sliced_array = list_array.slice(1, 4); + assert_eq!(4, sliced_array.len()); + assert_eq!(1, sliced_array.offset()); + assert_eq!(2, sliced_array.null_count()); + + for i in 0..sliced_array.len() { + if bit_util::get_bit(&null_bits, sliced_array.offset() + i) { + assert!(sliced_array.is_valid(i)); + } else { + assert!(sliced_array.is_null(i)); + } + } + + // Check offset and length for each non-null value. 
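+ // With a slice offset of 1 and a fixed list size of 2, logical index 2 maps to value offset (1 + 2) * 2 = 6 and index 3 maps to (1 + 3) * 2 = 8.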
+ let sliced_list_array = sliced_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(2, sliced_list_array.value_length()); + assert_eq!(6, sliced_list_array.value_offset(2)); + assert_eq!(8, sliced_list_array.value_offset(3)); + } + + #[test] + #[should_panic(expected = "assertion failed: (offset + length) <= self.len()")] + fn test_fixed_size_list_array_index_out_of_bound() { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int32) + .len(10) + .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) + .build() + .unwrap(); + + // Set null buts for the nested array: + // [[0, 1], null, null, [6, 7], [8, 9]] + // 01011001 00000001 + let mut null_bits: [u8; 1] = [0; 1]; + bit_util::set_bit(&mut null_bits, 0); + bit_util::set_bit(&mut null_bits, 3); + bit_util::set_bit(&mut null_bits, 4); + + // Construct a fixed size list array from the above two + let list_data_type = DataType::FixedSizeList( + Box::new(Field::new("item", DataType::Int32, false)), + 2, + ); + let list_data = ArrayData::builder(list_data_type) + .len(5) + .add_child_data(value_data) + .null_bit_buffer(Some(Buffer::from(null_bits))) + .build() + .unwrap(); + let list_array = FixedSizeListArray::from(list_data); + + list_array.value(10); + } +} diff --git a/arrow/src/array/array_list.rs b/arrow/src/array/array_list.rs index ac37754e9bf4..b9c05014c3f7 100644 --- a/arrow/src/array/array_list.rs +++ b/arrow/src/array/array_list.rs @@ -24,6 +24,7 @@ use super::{ array::print_long_array, make_array, raw_pointer::RawPtrBox, Array, ArrayData, ArrayRef, BooleanBufferBuilder, GenericListArrayIter, PrimitiveArray, }; +use crate::array::array::ArrayAccessor; use crate::{ buffer::MutableBuffer, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType, Field}, @@ -33,14 +34,17 @@ use crate::{ /// trait declaring an offset size, relevant for i32 vs i64 array types. pub trait OffsetSizeTrait: ArrowNativeType + std::ops::AddAssign + Integer { const IS_LARGE: bool; + const PREFIX: &'static str; } impl OffsetSizeTrait for i32 { const IS_LARGE: bool = false; + const PREFIX: &'static str = ""; } impl OffsetSizeTrait for i64 { const IS_LARGE: bool = true; + const PREFIX: &'static str = "Large"; } /// Generic struct for a variable-size list array. @@ -56,6 +60,16 @@ pub struct GenericListArray { } impl GenericListArray { + /// The data type constructor of list array. + /// The input is the schema of the child array and + /// the output is the [`DataType`], List or LargeList. + pub const DATA_TYPE_CONSTRUCTOR: fn(Box) -> DataType = if OffsetSize::IS_LARGE + { + DataType::LargeList + } else { + DataType::List + }; + /// Returns a reference to the values of this list. 
pub fn values(&self) -> ArrayRef { self.values.clone() @@ -169,11 +183,7 @@ impl GenericListArray { .collect(); let field = Box::new(Field::new("item", T::DATA_TYPE, true)); - let data_type = if OffsetSize::IS_LARGE { - DataType::LargeList(field) - } else { - DataType::List(field) - }; + let data_type = Self::DATA_TYPE_CONSTRUCTOR(field); let array_data = ArrayData::builder(data_type) .len(null_buf.len()) .add_buffer(offsets.into()) @@ -245,7 +255,7 @@ impl GenericListArray { } } -impl Array for GenericListArray { +impl Array for GenericListArray { fn as_any(&self) -> &dyn Any { self } @@ -259,9 +269,21 @@ impl Array for GenericListArray ArrayAccessor for &'a GenericListArray { + type Item = ArrayRef; + + fn value(&self, index: usize) -> Self::Item { + GenericListArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + GenericListArray::value(self, index) + } +} + impl fmt::Debug for GenericListArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = if OffsetSize::IS_LARGE { "Large" } else { "" }; + let prefix = OffsetSize::PREFIX; write!(f, "{}ListArray\n[\n", prefix)?; print_long_array(self, f, |array, index, f| { @@ -326,156 +348,6 @@ pub type ListArray = GenericListArray; /// ``` pub type LargeListArray = GenericListArray; -/// A list array where each element is a fixed-size sequence of values with the same -/// type whose maximum length is represented by a i32. -/// -/// # Example -/// -/// ``` -/// # use arrow::array::{Array, ArrayData, FixedSizeListArray, Int32Array}; -/// # use arrow::datatypes::{DataType, Field}; -/// # use arrow::buffer::Buffer; -/// // Construct a value array -/// let value_data = ArrayData::builder(DataType::Int32) -/// .len(9) -/// .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) -/// .build() -/// .unwrap(); -/// let list_data_type = DataType::FixedSizeList( -/// Box::new(Field::new("item", DataType::Int32, false)), -/// 3, -/// ); -/// let list_data = ArrayData::builder(list_data_type.clone()) -/// .len(3) -/// .add_child_data(value_data.clone()) -/// .build() -/// .unwrap(); -/// let list_array = FixedSizeListArray::from(list_data); -/// let list0 = list_array.value(0); -/// let list1 = list_array.value(1); -/// let list2 = list_array.value(2); -/// -/// assert_eq!( &[0, 1, 2], list0.as_any().downcast_ref::().unwrap().values()); -/// assert_eq!( &[3, 4, 5], list1.as_any().downcast_ref::().unwrap().values()); -/// assert_eq!( &[6, 7, 8], list2.as_any().downcast_ref::().unwrap().values()); -/// ``` -/// -/// For non generic lists, you may wish to consider using -/// [crate::array::FixedSizeBinaryArray] -pub struct FixedSizeListArray { - data: ArrayData, - values: ArrayRef, - length: i32, -} - -impl FixedSizeListArray { - /// Returns a reference to the values of this list. - pub fn values(&self) -> ArrayRef { - self.values.clone() - } - - /// Returns a clone of the value type of this list. - pub fn value_type(&self) -> DataType { - self.values.data_ref().data_type().clone() - } - - /// Returns ith value of this list array. - pub fn value(&self, i: usize) -> ArrayRef { - self.values - .slice(self.value_offset(i) as usize, self.value_length() as usize) - } - - /// Returns the offset for value at index `i`. - /// - /// Note this doesn't do any bound checking, for performance reason. - #[inline] - pub fn value_offset(&self, i: usize) -> i32 { - self.value_offset_at(self.data.offset() + i) - } - - /// Returns the length for an element. 
- /// - /// All elements have the same length as the array is a fixed size. - #[inline] - pub const fn value_length(&self) -> i32 { - self.length - } - - #[inline] - const fn value_offset_at(&self, i: usize) -> i32 { - i as i32 * self.length - } -} - -impl From for FixedSizeListArray { - fn from(data: ArrayData) -> Self { - assert_eq!( - data.buffers().len(), - 0, - "FixedSizeListArray data should not contain a buffer for value offsets" - ); - assert_eq!( - data.child_data().len(), - 1, - "FixedSizeListArray should contain a single child array (values array)" - ); - let values = make_array(data.child_data()[0].clone()); - let length = match data.data_type() { - DataType::FixedSizeList(_, len) => { - if *len > 0 { - // check that child data is multiple of length - assert_eq!( - values.len() % *len as usize, - 0, - "FixedSizeListArray child array length should be a multiple of {}", - len - ); - } - - *len - } - _ => { - panic!("FixedSizeListArray data should contain a FixedSizeList data type") - } - }; - Self { - data, - values, - length, - } - } -} - -impl From for ArrayData { - fn from(array: FixedSizeListArray) -> Self { - array.data - } -} - -impl Array for FixedSizeListArray { - fn as_any(&self) -> &dyn Any { - self - } - - fn data(&self) -> &ArrayData { - &self.data - } - - fn into_data(self) -> ArrayData { - self.into() - } -} - -impl fmt::Debug for FixedSizeListArray { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "FixedSizeListArray<{}>\n[\n", self.value_length())?; - print_long_array(self, f, |array, index, f| { - fmt::Debug::fmt(&array.value(index), f) - })?; - write!(f, "]") - } -} - #[cfg(test)] mod tests { use crate::{ @@ -733,104 +605,6 @@ mod tests { ); } - #[test] - fn test_fixed_size_list_array() { - // Construct a value array - let value_data = ArrayData::builder(DataType::Int32) - .len(9) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8])) - .build() - .unwrap(); - - // Construct a list array from the above two - let list_data_type = DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Int32, false)), - 3, - ); - let list_data = ArrayData::builder(list_data_type.clone()) - .len(3) - .add_child_data(value_data.clone()) - .build() - .unwrap(); - let list_array = FixedSizeListArray::from(list_data); - - let values = list_array.values(); - assert_eq!(&value_data, values.data()); - assert_eq!(DataType::Int32, list_array.value_type()); - assert_eq!(3, list_array.len()); - assert_eq!(0, list_array.null_count()); - assert_eq!(6, list_array.value_offset(2)); - assert_eq!(3, list_array.value_length()); - assert_eq!( - 0, - list_array - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - ); - for i in 0..3 { - assert!(list_array.is_valid(i)); - assert!(!list_array.is_null(i)); - } - - // Now test with a non-zero offset - let list_data = ArrayData::builder(list_data_type) - .len(3) - .offset(1) - .add_child_data(value_data.clone()) - .build() - .unwrap(); - let list_array = FixedSizeListArray::from(list_data); - - let values = list_array.values(); - assert_eq!(&value_data, values.data()); - assert_eq!(DataType::Int32, list_array.value_type()); - assert_eq!(3, list_array.len()); - assert_eq!(0, list_array.null_count()); - assert_eq!( - 3, - list_array - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .value(0) - ); - assert_eq!(6, list_array.value_offset(1)); - assert_eq!(3, list_array.value_length()); - } - - #[test] - #[should_panic( - expected = "FixedSizeListArray child array length should be a multiple of 
3" - )] - // Different error messages, so skip for now - // https://github.com/apache/arrow-rs/issues/1545 - #[cfg(not(feature = "force_validate"))] - fn test_fixed_size_list_array_unequal_children() { - // Construct a value array - let value_data = ArrayData::builder(DataType::Int32) - .len(8) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7])) - .build() - .unwrap(); - - // Construct a list array from the above two - let list_data_type = DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Int32, false)), - 3, - ); - let list_data = unsafe { - ArrayData::builder(list_data_type) - .len(3) - .add_child_data(value_data) - .build_unchecked() - }; - drop(FixedSizeListArray::from(list_data)); - } - #[test] fn test_list_array_slice() { // Construct a value array @@ -997,102 +771,6 @@ mod tests { list_array.value(10); } - - #[test] - fn test_fixed_size_list_array_slice() { - // Construct a value array - let value_data = ArrayData::builder(DataType::Int32) - .len(10) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) - .build() - .unwrap(); - - // Set null buts for the nested array: - // [[0, 1], null, null, [6, 7], [8, 9]] - // 01011001 00000001 - let mut null_bits: [u8; 1] = [0; 1]; - bit_util::set_bit(&mut null_bits, 0); - bit_util::set_bit(&mut null_bits, 3); - bit_util::set_bit(&mut null_bits, 4); - - // Construct a fixed size list array from the above two - let list_data_type = DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Int32, false)), - 2, - ); - let list_data = ArrayData::builder(list_data_type) - .len(5) - .add_child_data(value_data.clone()) - .null_bit_buffer(Some(Buffer::from(null_bits))) - .build() - .unwrap(); - let list_array = FixedSizeListArray::from(list_data); - - let values = list_array.values(); - assert_eq!(&value_data, values.data()); - assert_eq!(DataType::Int32, list_array.value_type()); - assert_eq!(5, list_array.len()); - assert_eq!(2, list_array.null_count()); - assert_eq!(6, list_array.value_offset(3)); - assert_eq!(2, list_array.value_length()); - - let sliced_array = list_array.slice(1, 4); - assert_eq!(4, sliced_array.len()); - assert_eq!(1, sliced_array.offset()); - assert_eq!(2, sliced_array.null_count()); - - for i in 0..sliced_array.len() { - if bit_util::get_bit(&null_bits, sliced_array.offset() + i) { - assert!(sliced_array.is_valid(i)); - } else { - assert!(sliced_array.is_null(i)); - } - } - - // Check offset and length for each non-null value. 
- let sliced_list_array = sliced_array - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(2, sliced_list_array.value_length()); - assert_eq!(6, sliced_list_array.value_offset(2)); - assert_eq!(8, sliced_list_array.value_offset(3)); - } - - #[test] - #[should_panic(expected = "assertion failed: (offset + length) <= self.len()")] - fn test_fixed_size_list_array_index_out_of_bound() { - // Construct a value array - let value_data = ArrayData::builder(DataType::Int32) - .len(10) - .add_buffer(Buffer::from_slice_ref(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) - .build() - .unwrap(); - - // Set null buts for the nested array: - // [[0, 1], null, null, [6, 7], [8, 9]] - // 01011001 00000001 - let mut null_bits: [u8; 1] = [0; 1]; - bit_util::set_bit(&mut null_bits, 0); - bit_util::set_bit(&mut null_bits, 3); - bit_util::set_bit(&mut null_bits, 4); - - // Construct a fixed size list array from the above two - let list_data_type = DataType::FixedSizeList( - Box::new(Field::new("item", DataType::Int32, false)), - 2, - ); - let list_data = ArrayData::builder(list_data_type) - .len(5) - .add_child_data(value_data) - .null_bit_buffer(Some(Buffer::from(null_bits))) - .build() - .unwrap(); - let list_array = FixedSizeListArray::from(list_data); - - list_array.value(10); - } - #[test] #[should_panic( expected = "ListArray data should contain a single buffer only (value offsets)" diff --git a/arrow/src/array/array_primitive.rs b/arrow/src/array/array_primitive.rs index efac5a60cb32..7818e6ff01d5 100644 --- a/arrow/src/array/array_primitive.rs +++ b/arrow/src/array/array_primitive.rs @@ -33,6 +33,7 @@ use crate::{ util::trusted_len_unzip, }; +use crate::array::array::ArrayAccessor; use half::f16; /// Array whose elements are of primitive types. @@ -90,7 +91,7 @@ impl PrimitiveArray { // Returns a new primitive array builder pub fn builder(capacity: usize) -> PrimitiveBuilder { - PrimitiveBuilder::::new(capacity) + PrimitiveBuilder::::with_capacity(capacity) } /// Returns the primitive value at index `i`. @@ -105,11 +106,16 @@ impl PrimitiveArray { } /// Returns the primitive value at index `i`. 
- /// - /// Panics of offset `i` is out of bounds + /// # Panics + /// Panics if index `i` is out of bounds #[inline] pub fn value(&self, i: usize) -> T::Native { - assert!(i < self.len()); + assert!( + i < self.len(), + "Trying to access an element at index {} from a PrimitiveArray of length {}", + i, + self.len() + ); unsafe { self.value_unchecked(i) } } @@ -188,6 +194,18 @@ impl Array for PrimitiveArray { } } +impl<'a, T: ArrowPrimitiveType> ArrayAccessor for &'a PrimitiveArray { + type Item = T::Native; + + fn value(&self, index: usize) -> Self::Item { + PrimitiveArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + PrimitiveArray::value_unchecked(self, index) + } +} + fn as_datetime(v: i64) -> Option { match T::DATA_TYPE { DataType::Date32 => Some(temporal_conversions::date32_to_datetime(v as i32)), @@ -425,17 +443,13 @@ impl>> FromIterator .collect(); let len = null_builder.len(); - let null_buf: Buffer = null_builder.into(); - let valid_count = null_buf.count_set_bits(); - let null_count = len - valid_count; - let opt_null_buf = (null_count != 0).then(|| null_buf); let data = unsafe { ArrayData::new_unchecked( T::DATA_TYPE, len, - Some(null_count), - opt_null_buf, + None, + Some(null_builder.into()), 0, vec![buffer], vec![], @@ -540,6 +554,18 @@ impl PrimitiveArray { let array_data = unsafe { array_data.build_unchecked() }; PrimitiveArray::from(array_data) } + + /// Construct a timestamp array with new timezone + pub fn with_timezone(&self, timezone: String) -> Self { + let array_data = unsafe { + self.data + .clone() + .into_builder() + .data_type(DataType::Timestamp(T::get_time_unit(), Some(timezone))) + .build_unchecked() + }; + PrimitiveArray::from(array_data) + } } impl PrimitiveArray { @@ -937,9 +963,9 @@ mod tests { #[test] fn test_int32_with_null_fmt_debug() { let mut builder = Int32Array::builder(3); - builder.append_slice(&[0, 1]).unwrap(); - builder.append_null().unwrap(); - builder.append_slice(&[3, 4]).unwrap(); + builder.append_slice(&[0, 1]); + builder.append_null(); + builder.append_slice(&[3, 4]); let arr = builder.finish(); assert_eq!( "PrimitiveArray\n[\n 0,\n 1,\n null,\n 3,\n 4,\n]", @@ -1090,4 +1116,31 @@ mod tests { BooleanArray::from(vec![true, true, true, true, true]) ); } + + #[cfg(feature = "chrono-tz")] + #[test] + fn test_with_timezone() { + use crate::compute::hour; + let a: TimestampMicrosecondArray = vec![37800000000, 86339000000].into(); + + let b = hour(&a).unwrap(); + assert_eq!(10, b.value(0)); + assert_eq!(23, b.value(1)); + + let a = a.with_timezone(String::from("America/Los_Angeles")); + + let b = hour(&a).unwrap(); + assert_eq!(2, b.value(0)); + assert_eq!(15, b.value(1)); + } + + #[test] + #[should_panic( + expected = "Trying to access an element at index 4 from a PrimitiveArray of length 3" + )] + fn test_string_array_get_value_index_out_of_bound() { + let array: Int8Array = [10_i8, 11, 12].into_iter().collect(); + + array.value(4); + } } diff --git a/arrow/src/array/array_string.rs b/arrow/src/array/array_string.rs index b48f058cf0cf..62743a20a119 100644 --- a/arrow/src/array/array_string.rs +++ b/arrow/src/array/array_string.rs @@ -20,9 +20,10 @@ use std::fmt; use std::{any::Any, iter::FromIterator}; use super::{ - array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, GenericListArray, - GenericStringIter, OffsetSizeTrait, + array::print_long_array, raw_pointer::RawPtrBox, Array, ArrayData, + GenericBinaryArray, GenericListArray, GenericStringIter, OffsetSizeTrait, }; +use 
crate::array::array::ArrayAccessor; use crate::buffer::Buffer; use crate::util::bit_util; use crate::{buffer::MutableBuffer, datatypes::DataType}; @@ -38,15 +39,17 @@ pub struct GenericStringArray { } impl GenericStringArray { + /// Data type of the array. + pub const DATA_TYPE: DataType = if OffsetSize::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + /// Get the data type of the array. - // Declare this function as `pub const fn` after - // https://github.com/rust-lang/rust/issues/93706 is merged. - pub fn get_data_type() -> DataType { - if OffsetSize::IS_LARGE { - DataType::LargeUtf8 - } else { - DataType::Utf8 - } + #[deprecated(note = "please use `Self::DATA_TYPE` instead")] + pub const fn get_data_type() -> DataType { + Self::DATA_TYPE } /// Returns the length for the element at index `i`. @@ -110,31 +113,55 @@ impl GenericStringArray { } /// Returns the element at index `i` as &str + /// # Panics + /// Panics if index `i` is out of bounds. #[inline] pub fn value(&self, i: usize) -> &str { - assert!(i < self.data.len(), "StringArray out of bounds access"); + assert!( + i < self.data.len(), + "Trying to access an element at index {} from a StringArray of length {}", + i, + self.len() + ); // Safety: // `i < self.data.len() unsafe { self.value_unchecked(i) } } + /// Convert a list array to a string array. + /// This method is unsound because it does + /// not check the utf-8 validation for each element. fn from_list(v: GenericListArray) -> Self { assert_eq!( - v.data().child_data()[0].child_data().len(), + v.data_ref().child_data().len(), + 1, + "StringArray can only be created from list array of u8 values \ + (i.e. List>)." + ); + let child_data = &v.data_ref().child_data()[0]; + + assert_eq!( + child_data.child_data().len(), 0, "StringArray can only be created from list array of u8 values \ (i.e. List>)." ); assert_eq!( - v.data().child_data()[0].data_type(), + child_data.data_type(), &DataType::UInt8, "StringArray can only be created from List arrays, mismatched data types." ); + assert_eq!( + child_data.null_count(), + 0, + "The child array cannot contain null values." 
+ ); - let builder = ArrayData::builder(Self::get_data_type()) + let builder = ArrayData::builder(Self::DATA_TYPE) .len(v.len()) + .offset(v.offset()) .add_buffer(v.data().buffers()[0].clone()) - .add_buffer(v.data().child_data()[0].buffers()[0].clone()) + .add_buffer(child_data.buffers()[0].slice(child_data.offset())) .null_bit_buffer(v.data().null_buffer().cloned()); let array_data = unsafe { builder.build_unchecked() }; @@ -169,7 +196,7 @@ impl GenericStringArray { assert!(!offsets.is_empty()); // wrote at least one let actual_len = (offsets.len() / std::mem::size_of::()) - 1; - let array_data = ArrayData::builder(Self::get_data_type()) + let array_data = ArrayData::builder(Self::DATA_TYPE) .len(actual_len) .add_buffer(offsets.into()) .add_buffer(values.into()); @@ -246,7 +273,7 @@ where // calculate actual data_len, which may be different from the iterator's upper bound let data_len = (offsets.len() / offset_size) - 1; - let array_data = ArrayData::builder(Self::get_data_type()) + let array_data = ArrayData::builder(Self::DATA_TYPE) .len(data_len) .add_buffer(offsets.into()) .add_buffer(values.into()) @@ -274,7 +301,7 @@ impl<'a, T: OffsetSizeTrait> GenericStringArray { impl fmt::Debug for GenericStringArray { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let prefix = if OffsetSize::IS_LARGE { "Large" } else { "" }; + let prefix = OffsetSize::PREFIX; write!(f, "{}StringArray\n[\n", prefix)?; print_long_array(self, f, |array, index, f| { @@ -298,11 +325,43 @@ impl Array for GenericStringArray { } } +impl<'a, OffsetSize: OffsetSizeTrait> ArrayAccessor + for &'a GenericStringArray +{ + type Item = &'a str; + + fn value(&self, index: usize) -> Self::Item { + GenericStringArray::value(self, index) + } + + unsafe fn value_unchecked(&self, index: usize) -> Self::Item { + GenericStringArray::value_unchecked(self, index) + } +} + +impl From> + for GenericStringArray +{ + fn from(v: GenericListArray) -> Self { + GenericStringArray::::from_list(v) + } +} + +impl From> + for GenericStringArray +{ + fn from(v: GenericBinaryArray) -> Self { + let builder = v.into_data().into_builder().data_type(Self::DATA_TYPE); + let data = unsafe { builder.build_unchecked() }; + Self::from(data) + } +} + impl From for GenericStringArray { fn from(data: ArrayData) -> Self { assert_eq!( data.data_type(), - &Self::get_data_type(), + &Self::DATA_TYPE, "[Large]StringArray expects Datatype::[Large]Utf8" ); assert_eq!( @@ -370,16 +429,13 @@ pub type StringArray = GenericStringArray; /// ``` pub type LargeStringArray = GenericStringArray; -impl From> for GenericStringArray { - fn from(v: GenericListArray) -> Self { - GenericStringArray::::from_list(v) - } -} - #[cfg(test)] mod tests { - use crate::array::{ListBuilder, StringBuilder}; + use crate::{ + array::{ListBuilder, StringBuilder}, + datatypes::Field, + }; use super::*; @@ -443,18 +499,15 @@ mod tests { #[test] fn test_nested_string_array() { - let string_builder = StringBuilder::new(3); + let string_builder = StringBuilder::with_capacity(3, 10); let mut list_of_string_builder = ListBuilder::new(string_builder); - list_of_string_builder.values().append_value("foo").unwrap(); - list_of_string_builder.values().append_value("bar").unwrap(); - list_of_string_builder.append(true).unwrap(); + list_of_string_builder.values().append_value("foo"); + list_of_string_builder.values().append_value("bar"); + list_of_string_builder.append(true); - list_of_string_builder - .values() - .append_value("foobar") - .unwrap(); - list_of_string_builder.append(true).unwrap(); + 
list_of_string_builder.values().append_value("foobar"); + list_of_string_builder.append(true); let list_of_strings = list_of_string_builder.finish(); assert_eq!(list_of_strings.len(), 2); @@ -475,7 +528,9 @@ mod tests { } #[test] - #[should_panic(expected = "StringArray out of bounds access")] + #[should_panic( + expected = "Trying to access an element at index 4 from a StringArray of length 3" + )] fn test_string_array_get_value_index_out_of_bound() { let values: [u8; 12] = [ b'h', b'e', b'l', b'l', b'o', b'p', b'a', b'r', b'q', b'u', b'e', b't', @@ -648,4 +703,127 @@ mod tests { LargeStringArray::from_iter_values(BadIterator::new(3, 1, data.clone())); assert_eq!(expected, arr); } + + fn _test_generic_string_array_from_list_array() { + let values = b"HelloArrowAndParquet"; + // "ArrowAndParquet" + let child_data = ArrayData::builder(DataType::UInt8) + .len(15) + .offset(5) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + + let offsets = [0, 5, 8, 15].map(|n| O::from_usize(n).unwrap()); + let null_buffer = Buffer::from_slice_ref(&[0b101]); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( + Field::new("item", DataType::UInt8, false), + )); + + // [None, Some("Parquet")] + let array_data = ArrayData::builder(data_type) + .len(2) + .offset(1) + .add_buffer(Buffer::from_slice_ref(&offsets)) + .null_bit_buffer(Some(null_buffer)) + .add_child_data(child_data) + .build() + .unwrap(); + let list_array = GenericListArray::::from(array_data); + let string_array = GenericStringArray::::from(list_array); + + assert_eq!(2, string_array.len()); + assert_eq!(1, string_array.null_count()); + assert!(string_array.is_null(0)); + assert!(string_array.is_valid(1)); + assert_eq!("Parquet", string_array.value(1)); + } + + #[test] + fn test_string_array_from_list_array() { + _test_generic_string_array_from_list_array::(); + } + + #[test] + fn test_large_string_array_from_list_array() { + _test_generic_string_array_from_list_array::(); + } + + fn _test_generic_string_array_from_list_array_with_child_nulls_failed< + O: OffsetSizeTrait, + >() { + let values = b"HelloArrow"; + let child_data = ArrayData::builder(DataType::UInt8) + .len(10) + .add_buffer(Buffer::from(&values[..])) + .null_bit_buffer(Some(Buffer::from_slice_ref(&[0b1010101010]))) + .build() + .unwrap(); + + let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( + Field::new("item", DataType::UInt8, false), + )); + + // [None, Some(b"Parquet")] + let array_data = ArrayData::builder(data_type) + .len(2) + .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_child_data(child_data) + .build() + .unwrap(); + let list_array = GenericListArray::::from(array_data); + drop(GenericStringArray::::from(list_array)); + } + + #[test] + #[should_panic(expected = "The child array cannot contain null values.")] + fn test_stirng_array_from_list_array_with_child_nulls_failed() { + _test_generic_string_array_from_list_array_with_child_nulls_failed::(); + } + + #[test] + #[should_panic(expected = "The child array cannot contain null values.")] + fn test_large_string_array_from_list_array_with_child_nulls_failed() { + _test_generic_string_array_from_list_array_with_child_nulls_failed::(); + } + + fn _test_generic_string_array_from_list_array_wrong_type() { + let values = b"HelloArrow"; + let child_data = ArrayData::builder(DataType::UInt16) + .len(5) + .add_buffer(Buffer::from(&values[..])) + .build() + .unwrap(); + + let offsets = [0, 2, 3].map(|n| 
O::from_usize(n).unwrap()); + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( + Field::new("item", DataType::UInt16, false), + )); + + let array_data = ArrayData::builder(data_type) + .len(2) + .add_buffer(Buffer::from_slice_ref(&offsets)) + .add_child_data(child_data) + .build() + .unwrap(); + let list_array = GenericListArray::::from(array_data); + drop(GenericStringArray::::from(list_array)); + } + + #[test] + #[should_panic( + expected = "StringArray can only be created from List arrays, mismatched data types." + )] + fn test_string_array_from_list_array_wrong_type() { + _test_generic_string_array_from_list_array_wrong_type::(); + } + + #[test] + #[should_panic( + expected = "StringArray can only be created from List arrays, mismatched data types." + )] + fn test_large_string_array_from_list_array_wrong_type() { + _test_generic_string_array_from_list_array_wrong_type::(); + } } diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs index 639b82ae9806..b221239b2dbe 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow/src/array/array_union.rs @@ -231,10 +231,10 @@ impl UnionArray { /// /// Panics if the `type_id` provided is less than zero or greater than the number of types /// in the `Union`. - pub fn child(&self, type_id: i8) -> ArrayRef { + pub fn child(&self, type_id: i8) -> &ArrayRef { assert!(0 <= type_id); assert!((type_id as usize) < self.boxed_fields.len()); - self.boxed_fields[type_id as usize].clone() + &self.boxed_fields[type_id as usize] } /// Returns the `type_id` for the array slot at `index`. @@ -243,8 +243,8 @@ impl UnionArray { /// /// Panics if `index` is greater than the length of the array. pub fn type_id(&self, index: usize) -> i8 { - assert!(index - self.offset() < self.len()); - self.data().buffers()[0].as_slice()[index] as i8 + assert!(index < self.len()); + self.data().buffers()[0].as_slice()[self.offset() + index] as i8 } /// Returns the offset into the underlying values array for the array slot at `index`. @@ -253,22 +253,20 @@ impl UnionArray { /// /// Panics if `index` is greater than the length of the array. pub fn value_offset(&self, index: usize) -> i32 { - assert!(index - self.offset() < self.len()); + assert!(index < self.len()); if self.is_dense() { - self.data().buffers()[1].typed_data::()[index] + self.data().buffers()[1].typed_data::()[self.offset() + index] } else { - index as i32 + (self.offset() + index) as i32 } } - /// Returns the array's value at `index`. - /// + /// Returns the array's value at index `i`. /// # Panics - /// - /// Panics if `index` is greater than the length of the array. 
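A minimal usage sketch (not part of the patch, assuming the crate's usual public re-exports) of the reworked UnionArray accessors in this hunk: `child` now borrows the boxed field instead of cloning it, and `type_id`/`value_offset`/`value` index relative to the logical slot, so sliced unions can be read without manual offset arithmetic.

use arrow::array::{Array, ArrayRef, Int32Array, UnionBuilder};
use arrow::datatypes::Int32Type;
use arrow::error::Result;

fn union_accessors() -> Result<()> {
    // capacity argument removed from the builder constructors
    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1)?;
    builder.append::<Int32Type>("a", 4)?;
    let union = builder.build()?;

    // `child` now returns a borrow rather than a cloned Arc
    let child: &ArrayRef = union.child(0);
    assert_eq!(child.len(), 2);

    // slot-relative access: type id and value for logical index 1
    assert_eq!(union.type_id(1), 0);
    let slot = union.value(1);
    let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
    assert_eq!(ints.value(0), 4);
    Ok(())
}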
- pub fn value(&self, index: usize) -> ArrayRef { - let type_id = self.type_id(self.offset() + index); - let value_offset = self.value_offset(self.offset() + index) as usize; + /// Panics if index `i` is out of bounds + pub fn value(&self, i: usize) -> ArrayRef { + let type_id = self.type_id(i); + let value_offset = self.value_offset(i) as usize; let child_data = self.boxed_fields[type_id as usize].clone(); child_data.slice(value_offset, 1) } @@ -383,10 +381,11 @@ mod tests { use crate::array::*; use crate::buffer::Buffer; use crate::datatypes::{DataType, Field}; + use crate::record_batch::RecordBatch; #[test] fn test_dense_i32() { - let mut builder = UnionBuilder::new_dense(7); + let mut builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); @@ -446,7 +445,7 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] fn test_dense_i32_large() { - let mut builder = UnionBuilder::new_dense(1024); + let mut builder = UnionBuilder::new_dense(); let expected_type_ids = vec![0_i8; 1024]; let expected_value_offsets: Vec<_> = (0..1024).collect(); @@ -488,7 +487,7 @@ mod tests { #[test] fn test_dense_mixed() { - let mut builder = UnionBuilder::new_dense(7); + let mut builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 4).unwrap(); @@ -538,7 +537,7 @@ mod tests { #[test] fn test_dense_mixed_with_nulls() { - let mut builder = UnionBuilder::new_dense(7); + let mut builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 10).unwrap(); @@ -586,7 +585,7 @@ mod tests { #[test] fn test_dense_mixed_with_nulls_and_offset() { - let mut builder = UnionBuilder::new_dense(7); + let mut builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 10).unwrap(); @@ -713,7 +712,7 @@ mod tests { #[test] fn test_sparse_i32() { - let mut builder = UnionBuilder::new_sparse(7); + let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); @@ -765,7 +764,7 @@ mod tests { #[test] fn test_sparse_mixed() { - let mut builder = UnionBuilder::new_sparse(5); + let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("c", 3.0).unwrap(); builder.append::("a", 4).unwrap(); @@ -828,7 +827,7 @@ mod tests { #[test] fn test_sparse_mixed_with_nulls() { - let mut builder = UnionBuilder::new_sparse(5); + let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); @@ -881,7 +880,7 @@ mod tests { #[test] fn test_sparse_mixed_with_nulls_and_offset() { - let mut builder = UnionBuilder::new_sparse(5); + let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); @@ -928,7 +927,7 @@ mod tests { #[test] fn test_union_array_validaty() { - let mut builder = UnionBuilder::new_sparse(5); + let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); @@ -938,7 +937,7 @@ mod tests { test_union_validity(&union); - let mut builder = UnionBuilder::new_dense(5); + let mut builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); 
builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); @@ -951,9 +950,74 @@ mod tests { #[test] fn test_type_check() { - let mut builder = UnionBuilder::new_sparse(2); + let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1.0).unwrap(); let err = builder.append::("a", 1).unwrap_err().to_string(); assert!(err.contains("Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"), "{}", err); } + + #[test] + fn slice_union_array() { + // [1, null, 3.0, null, 4] + fn create_union(mut builder: UnionBuilder) -> UnionArray { + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("c", 3.0).unwrap(); + builder.append_null::("c").unwrap(); + builder.append::("a", 4).unwrap(); + builder.build().unwrap() + } + + fn create_batch(union: UnionArray) -> RecordBatch { + let schema = Schema::new(vec![Field::new( + "struct_array", + union.data_type().clone(), + true, + )]); + + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap() + } + + fn test_slice_union(record_batch_slice: RecordBatch) { + let union_slice = record_batch_slice + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(union_slice.type_id(0), 0); + assert_eq!(union_slice.type_id(1), 1); + assert_eq!(union_slice.type_id(2), 1); + + let slot = union_slice.value(0); + let array = slot.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + + let slot = union_slice.value(1); + let array = slot.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_valid(0)); + assert_eq!(array.value(0), 3.0); + + let slot = union_slice.value(2); + let array = slot.as_any().downcast_ref::().unwrap(); + assert_eq!(array.len(), 1); + assert!(array.is_null(0)); + } + + // Sparse Union + let builder = UnionBuilder::new_sparse(); + let record_batch = create_batch(create_union(builder)); + // [null, 3.0, null] + let record_batch_slice = record_batch.slice(1, 3); + test_slice_union(record_batch_slice); + + // Dense Union + let builder = UnionBuilder::new_dense(); + let record_batch = create_batch(create_union(builder)); + // [null, 3.0, null] + let record_batch_slice = record_batch.slice(1, 3); + test_slice_union(record_batch_slice); + } } diff --git a/arrow/src/array/builder/boolean_builder.rs b/arrow/src/array/builder/boolean_builder.rs index d0063e56629f..eed14a55fd91 100644 --- a/arrow/src/array/builder/boolean_builder.rs +++ b/arrow/src/array/builder/boolean_builder.rs @@ -23,9 +23,12 @@ use crate::array::ArrayData; use crate::array::ArrayRef; use crate::array::BooleanArray; use crate::datatypes::DataType; -use crate::error::{ArrowError, Result}; + +use crate::error::ArrowError; +use crate::error::Result; use super::BooleanBufferBuilder; +use super::NullBufferBuilder; /// Array builder for fixed-width primitive types /// @@ -36,7 +39,7 @@ use super::BooleanBufferBuilder; /// ``` /// use arrow::array::{Array, BooleanArray, BooleanBuilder}; /// -/// let mut b = BooleanBuilder::new(4); +/// let mut b = BooleanBuilder::new(); /// b.append_value(true); /// b.append_null(); /// b.append_value(false); @@ -60,17 +63,26 @@ use super::BooleanBufferBuilder; #[derive(Debug)] pub struct BooleanBuilder { values_builder: BooleanBufferBuilder, - /// We only materialize the builder when we add `false`. - /// This optimization is **very** important for the performance. 
- bitmap_builder: Option, + null_buffer_builder: NullBufferBuilder, +} + +impl Default for BooleanBuilder { + fn default() -> Self { + Self::new() + } } impl BooleanBuilder { - /// Creates a new primitive array builder - pub fn new(capacity: usize) -> Self { + /// Creates a new boolean builder + pub fn new() -> Self { + Self::with_capacity(1024) + } + + /// Creates a new boolean builder with space for `capacity` elements without re-allocating + pub fn with_capacity(capacity: usize) -> Self { Self { values_builder: BooleanBufferBuilder::new(capacity), - bitmap_builder: None, + null_buffer_builder: NullBufferBuilder::new(capacity), } } @@ -81,44 +93,44 @@ impl BooleanBuilder { /// Appends a value of type `T` into the builder #[inline] - pub fn append_value(&mut self, v: bool) -> Result<()> { + pub fn append_value(&mut self, v: bool) { self.values_builder.append(v); - if let Some(b) = self.bitmap_builder.as_mut() { - b.append(true) - } - Ok(()) + self.null_buffer_builder.append_non_null(); } /// Appends a null slot into the builder #[inline] - pub fn append_null(&mut self) -> Result<()> { - self.materialize_bitmap_builder(); - self.bitmap_builder.as_mut().unwrap().append(false); + pub fn append_null(&mut self) { + self.null_buffer_builder.append_null(); self.values_builder.advance(1); - Ok(()) + } + + /// Appends `n` `null`s into the builder. + #[inline] + pub fn append_nulls(&mut self, n: usize) { + self.null_buffer_builder.append_n_nulls(n); + self.values_builder.advance(n); } /// Appends an `Option` into the builder #[inline] - pub fn append_option(&mut self, v: Option) -> Result<()> { + pub fn append_option(&mut self, v: Option) { match v { - None => self.append_null()?, - Some(v) => self.append_value(v)?, + None => self.append_null(), + Some(v) => self.append_value(v), }; - Ok(()) } /// Appends a slice of type `T` into the builder #[inline] - pub fn append_slice(&mut self, v: &[bool]) -> Result<()> { - if let Some(b) = self.bitmap_builder.as_mut() { - b.append_n(v.len(), true) - } + pub fn append_slice(&mut self, v: &[bool]) { self.values_builder.append_slice(v); - Ok(()) + self.null_buffer_builder.append_n_non_nulls(v.len()); } - /// Appends values from a slice of type `T` and a validity boolean slice + /// Appends values from a slice of type `T` and a validity boolean slice. + /// + /// Returns an error if the slices are of different lengths #[inline] pub fn append_values(&mut self, values: &[bool], is_valid: &[bool]) -> Result<()> { if values.len() != is_valid.len() { @@ -126,13 +138,7 @@ impl BooleanBuilder { "Value and validity lengths must be equal".to_string(), )) } else { - is_valid - .iter() - .any(|v| !*v) - .then(|| self.materialize_bitmap_builder()); - if let Some(b) = self.bitmap_builder.as_mut() { - b.append_slice(is_valid) - } + self.null_buffer_builder.append_slice(is_valid); self.values_builder.append_slice(values); Ok(()) } @@ -141,7 +147,7 @@ impl BooleanBuilder { /// Builds the [BooleanArray] and reset this builder. 
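A minimal usage sketch of the reworked BooleanBuilder (assumed usage, mirroring the doc example above): appends are now infallible because the null mask is tracked by NullBufferBuilder, and only `append_values`, which checks that the value and validity slices have equal length, still returns a Result.

use arrow::array::{Array, BooleanBuilder};

fn build_booleans() {
    let mut builder = BooleanBuilder::with_capacity(8);
    builder.append_value(true);
    builder.append_null();
    builder.append_nulls(2);
    builder.append_option(Some(false));
    builder.append_slice(&[true, false]);
    // still fallible: value and validity slices must have the same length
    builder.append_values(&[false, true], &[true, false]).unwrap();

    let array = builder.finish();
    assert_eq!(array.len(), 9);
    assert_eq!(array.null_count(), 4);
}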
pub fn finish(&mut self) -> BooleanArray { let len = self.len(); - let null_bit_buffer = self.bitmap_builder.as_mut().map(|b| b.finish()); + let null_bit_buffer = self.null_buffer_builder.finish(); let builder = ArrayData::builder(DataType::Boolean) .len(len) .add_buffer(self.values_builder.finish()) @@ -150,15 +156,6 @@ impl BooleanBuilder { let array_data = unsafe { builder.build_unchecked() }; BooleanArray::from(array_data) } - - fn materialize_bitmap_builder(&mut self) { - if self.bitmap_builder.is_none() { - let mut b = BooleanBufferBuilder::new(0); - b.reserve(self.values_builder.capacity()); - b.append_n(self.values_builder.len(), true); - self.bitmap_builder = Some(b); - } - } } impl ArrayBuilder for BooleanBuilder { @@ -205,9 +202,9 @@ mod tests { let mut builder = BooleanArray::builder(10); for i in 0..10 { if i == 3 || i == 6 || i == 9 { - builder.append_value(true).unwrap(); + builder.append_value(true); } else { - builder.append_value(false).unwrap(); + builder.append_value(false); } } @@ -229,10 +226,10 @@ mod tests { BooleanArray::from(vec![Some(true), Some(false), None, None, Some(false)]); let mut builder = BooleanArray::builder(0); - builder.append_slice(&[true, false]).unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); - builder.append_value(false).unwrap(); + builder.append_slice(&[true, false]); + builder.append_null(); + builder.append_null(); + builder.append_value(false); let arr2 = builder.finish(); assert_eq!(arr1, arr2); @@ -243,7 +240,7 @@ mod tests { let arr1 = BooleanArray::from(vec![true; 513]); let mut builder = BooleanArray::builder(512); - builder.append_slice(&[true; 513]).unwrap(); + builder.append_slice(&[true; 513]); let arr2 = builder.finish(); assert_eq!(arr1, arr2); @@ -252,9 +249,9 @@ mod tests { #[test] fn test_boolean_array_builder_no_null() { let mut builder = BooleanArray::builder(0); - builder.append_option(Some(true)).unwrap(); - builder.append_value(false).unwrap(); - builder.append_slice(&[true, false, true]).unwrap(); + builder.append_option(Some(true)); + builder.append_value(false); + builder.append_slice(&[true, false, true]); builder .append_values(&[false, false, true], &[true, true, true]) .unwrap(); diff --git a/arrow/src/array/builder/buffer_builder.rs b/arrow/src/array/builder/buffer_builder.rs index 9dd13839800c..a6a81dfd6c0e 100644 --- a/arrow/src/array/builder/buffer_builder.rs +++ b/arrow/src/array/builder/buffer_builder.rs @@ -362,7 +362,6 @@ mod tests { use crate::array::Int32BufferBuilder; use crate::array::Int8Builder; use crate::array::UInt8BufferBuilder; - use crate::error::Result; #[test] fn test_builder_i32_empty() { @@ -457,17 +456,17 @@ mod tests { } #[test] - fn test_append_values() -> Result<()> { - let mut a = Int8Builder::new(0); - a.append_value(1)?; - a.append_null()?; - a.append_value(-2)?; + fn test_append_values() { + let mut a = Int8Builder::new(); + a.append_value(1); + a.append_null(); + a.append_value(-2); assert_eq!(a.len(), 3); // append values let values = &[1, 2, 3, 4]; let is_valid = &[true, true, false, true]; - a.append_values(values, is_valid)?; + a.append_values(values, is_valid); assert_eq!(a.len(), 7); let array = a.finish(); @@ -478,7 +477,5 @@ mod tests { assert_eq!(array.value(4), 2); assert!(array.is_null(5)); assert_eq!(array.value(6), 4); - - Ok(()) } } diff --git a/arrow/src/array/builder/decimal_builder.rs b/arrow/src/array/builder/decimal_builder.rs index 033de8976e34..daa30eebed92 100644 --- a/arrow/src/array/builder/decimal_builder.rs +++ 
b/arrow/src/array/builder/decimal_builder.rs @@ -18,26 +18,27 @@ use std::any::Any; use std::sync::Arc; -use crate::array::array_decimal::{BasicDecimalArray, Decimal256Array}; +use crate::array::array_decimal::Decimal256Array; use crate::array::ArrayRef; -use crate::array::DecimalArray; -use crate::array::UInt8Builder; -use crate::array::{ArrayBuilder, FixedSizeListBuilder}; +use crate::array::Decimal128Array; +use crate::array::{ArrayBuilder, FixedSizeBinaryBuilder}; use crate::error::{ArrowError, Result}; -use crate::datatypes::validate_decimal_precision; -use crate::util::decimal::{BasicDecimal, Decimal256}; +use crate::datatypes::{ + validate_decimal256_precision_with_lt_bytes, validate_decimal_precision, +}; +use crate::util::decimal::Decimal256; -/// Array Builder for [`DecimalArray`] +/// Array Builder for [`Decimal128Array`] /// -/// See [`DecimalArray`] for example. +/// See [`Decimal128Array`] for example. /// #[derive(Debug)] -pub struct DecimalBuilder { - builder: FixedSizeListBuilder, - precision: usize, - scale: usize, +pub struct Decimal128Builder { + builder: FixedSizeBinaryBuilder, + precision: u8, + scale: u8, /// Should i128 values be validated for compatibility with scale and precision? /// defaults to true @@ -49,19 +50,28 @@ pub struct DecimalBuilder { /// See [`Decimal256Array`] for example. #[derive(Debug)] pub struct Decimal256Builder { - builder: FixedSizeListBuilder, - precision: usize, - scale: usize, + builder: FixedSizeBinaryBuilder, + precision: u8, + scale: u8, + + /// Should decimal values be validated for compatibility with scale and precision? + /// defaults to true + value_validation: bool, } -impl DecimalBuilder { - /// Creates a new `DecimalBuilder`, `capacity` is the number of bytes in the values - /// array - pub fn new(capacity: usize, precision: usize, scale: usize) -> Self { - let values_builder = UInt8Builder::new(capacity); - let byte_width = 16; +impl Decimal128Builder { + const BYTE_LENGTH: i32 = 16; + + /// Creates a new [`Decimal128Builder`] + pub fn new(precision: u8, scale: u8) -> Self { + Self::with_capacity(1024, precision, scale) + } + + /// Creates a new [`Decimal128Builder`], `capacity` is the number of decimal values + /// that can be appended without reallocating + pub fn with_capacity(capacity: usize, precision: u8, scale: u8) -> Self { Self { - builder: FixedSizeListBuilder::new(values_builder, byte_width), + builder: FixedSizeBinaryBuilder::with_capacity(capacity, Self::BYTE_LENGTH), precision, scale, value_validation: true, @@ -78,55 +88,38 @@ impl DecimalBuilder { self.value_validation = false; } - /// Appends a byte slice into the builder. - /// - /// Automatically calls the `append` method to delimit the slice appended in as a - /// distinct array element. + /// Appends a decimal value into the builder. #[inline] pub fn append_value(&mut self, value: impl Into) -> Result<()> { - let value = if self.value_validation { - validate_decimal_precision(value.into(), self.precision)? - } else { - value.into() - }; - - let value_as_bytes = Self::from_i128_to_fixed_size_bytes( - value, - self.builder.value_length() as usize, - )?; - if self.builder.value_length() != value_as_bytes.len() as i32 { - return Err(ArrowError::InvalidArgumentError( - "Byte slice does not have the same length as DecimalBuilder value lengths".to_string() - )); + let value = value.into(); + if self.value_validation { + validate_decimal_precision(value, self.precision)? 
} - self.builder - .values() - .append_slice(value_as_bytes.as_slice())?; - self.builder.append(true) + let value_as_bytes: [u8; 16] = value.to_le_bytes(); + self.builder.append_value(value_as_bytes.as_slice()) } - pub(crate) fn from_i128_to_fixed_size_bytes(v: i128, size: usize) -> Result> { - if size > 16 { - return Err(ArrowError::InvalidArgumentError( - "DecimalBuilder only supports values up to 16 bytes.".to_string(), - )); - } - let res = v.to_le_bytes(); - let start_byte = 16 - size; - Ok(res[start_byte..16].to_vec()) + /// Append a null value to the array. + #[inline] + pub fn append_null(&mut self) { + self.builder.append_null() } - /// Append a null value to the array. + /// Appends an `Option>` into the builder. #[inline] - pub fn append_null(&mut self) -> Result<()> { - let length: usize = self.builder.value_length() as usize; - self.builder.values().append_slice(&vec![0u8; length][..])?; - self.builder.append(false) + pub fn append_option(&mut self, value: Option>) -> Result<()> { + match value { + None => { + self.append_null(); + Ok(()) + } + Some(value) => self.append_value(value), + } } - /// Builds the `DecimalArray` and reset this builder. - pub fn finish(&mut self) -> DecimalArray { - DecimalArray::from_fixed_size_list_array( + /// Builds the `Decimal128Array` and reset this builder. + pub fn finish(&mut self) -> Decimal128Array { + Decimal128Array::from_fixed_size_binary_array( self.builder.finish(), self.precision, self.scale, @@ -134,7 +127,7 @@ impl DecimalBuilder { } } -impl ArrayBuilder for DecimalBuilder { +impl ArrayBuilder for Decimal128Builder { /// Returns the builder as a non-mutable `Any` reference. fn as_any(&self) -> &dyn Any { self @@ -167,25 +160,46 @@ impl ArrayBuilder for DecimalBuilder { } impl Decimal256Builder { - /// Creates a new `Decimal256Builder`, `capacity` is the number of bytes in the values - /// array - pub fn new(capacity: usize, precision: usize, scale: usize) -> Self { - let values_builder = UInt8Builder::new(capacity); - let byte_width = 32; + const BYTE_LENGTH: i32 = 32; + + /// Creates a new [`Decimal256Builder`] + pub fn new(precision: u8, scale: u8) -> Self { + Self::with_capacity(1024, precision, scale) + } + + /// Creates a new [`Decimal256Builder`], `capacity` is the number of decimal values + /// that can be appended without reallocating + pub fn with_capacity(capacity: usize, precision: u8, scale: u8) -> Self { Self { - builder: FixedSizeListBuilder::new(values_builder, byte_width), + builder: FixedSizeBinaryBuilder::with_capacity(capacity, Self::BYTE_LENGTH), precision, scale, + value_validation: true, } } - /// Appends a byte slice into the builder. + /// Disable validation + /// + /// # Safety + /// + /// After disabling validation, caller must ensure that appended values are compatible + /// for the specified precision and scale. + pub unsafe fn disable_value_validation(&mut self) { + self.value_validation = false; + } + + /// Appends a [`Decimal256`] number into the builder. /// - /// Automatically calls the `append` method to delimit the slice appended in as a - /// distinct array element. 
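A usage sketch of the renamed Decimal128Builder (formerly DecimalBuilder), assuming the crate's public re-export: precision and scale are now `u8`, values are stored through a FixedSizeBinaryBuilder, and `append_value` still validates against the precision unless validation is explicitly disabled.

use arrow::array::{Array, Decimal128Builder};
use arrow::error::Result;

fn build_decimals() -> Result<()> {
    // precision 38, scale 6
    let mut builder = Decimal128Builder::new(38, 6);
    builder.append_value(8_887_000_000_i128)?;
    builder.append_null(); // infallible now
    builder.append_option(Some(1_i128))?;

    let array = builder.finish();
    assert_eq!(array.len(), 3);
    assert_eq!(array.null_count(), 1);
    Ok(())
}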
+ /// Returns an error if `value` has different precision, scale or length in bytes than this builder #[inline] pub fn append_value(&mut self, value: &Decimal256) -> Result<()> { - let value_as_bytes = value.raw_value(); + let value = if self.value_validation { + let raw_bytes = value.raw_value(); + validate_decimal256_precision_with_lt_bytes(raw_bytes, self.precision)?; + value + } else { + value + }; if self.precision != value.precision() || self.scale != value.scale() { return Err(ArrowError::InvalidArgumentError( @@ -193,26 +207,37 @@ impl Decimal256Builder { )); } - if self.builder.value_length() != value_as_bytes.len() as i32 { + let value_as_bytes = value.raw_value(); + + if Self::BYTE_LENGTH != value_as_bytes.len() as i32 { return Err(ArrowError::InvalidArgumentError( "Byte slice does not have the same length as Decimal256Builder value lengths".to_string() )); } - self.builder.values().append_slice(value_as_bytes)?; - self.builder.append(true) + self.builder.append_value(value_as_bytes) } /// Append a null value to the array. #[inline] - pub fn append_null(&mut self) -> Result<()> { - let length: usize = self.builder.value_length() as usize; - self.builder.values().append_slice(&vec![0u8; length][..])?; - self.builder.append(false) + pub fn append_null(&mut self) { + self.builder.append_null() } - /// Builds the `Decimal256Array` and reset this builder. + /// Appends an `Option<&Decimal256>` into the builder. + #[inline] + pub fn append_option(&mut self, value: Option<&Decimal256>) -> Result<()> { + match value { + None => { + self.append_null(); + Ok(()) + } + Some(value) => self.append_value(value), + } + } + + /// Builds the [`Decimal256Array`] and reset this builder. pub fn finish(&mut self) -> Decimal256Array { - Decimal256Array::from_fixed_size_list_array( + Decimal256Array::from_fixed_size_binary_array( self.builder.finish(), self.precision, self.scale, @@ -223,42 +248,45 @@ impl Decimal256Builder { #[cfg(test)] mod tests { use super::*; + use num::{BigInt, Num}; - use crate::array::array_decimal::BasicDecimalArray; - use crate::array::Array; + use crate::array::array_decimal::Decimal128Array; + use crate::array::{array_decimal, Array}; use crate::datatypes::DataType; - use crate::util::decimal::Decimal128; + use crate::util::decimal::{Decimal128, Decimal256}; #[test] fn test_decimal_builder() { - let mut builder = DecimalBuilder::new(30, 38, 6); + let mut builder = Decimal128Builder::new(38, 6); builder.append_value(8_887_000_000_i128).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append_value(-8_887_000_000_i128).unwrap(); - let decimal_array: DecimalArray = builder.finish(); + builder.append_option(None::).unwrap(); + builder.append_option(Some(8_887_000_000_i128)).unwrap(); + let decimal_array: Decimal128Array = builder.finish(); - assert_eq!(&DataType::Decimal(38, 6), decimal_array.data_type()); - assert_eq!(3, decimal_array.len()); - assert_eq!(1, decimal_array.null_count()); + assert_eq!(&DataType::Decimal128(38, 6), decimal_array.data_type()); + assert_eq!(5, decimal_array.len()); + assert_eq!(2, decimal_array.null_count()); assert_eq!(32, decimal_array.value_offset(2)); assert_eq!(16, decimal_array.value_length()); } #[test] fn test_decimal_builder_with_decimal128() { - let mut builder = DecimalBuilder::new(30, 38, 6); + let mut builder = Decimal128Builder::new(38, 6); builder .append_value(Decimal128::new_from_i128(30, 38, 8_887_000_000_i128)) .unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder 
.append_value(Decimal128::new_from_i128(30, 38, -8_887_000_000_i128)) .unwrap(); - let decimal_array: DecimalArray = builder.finish(); + let decimal_array: Decimal128Array = builder.finish(); - assert_eq!(&DataType::Decimal(38, 6), decimal_array.data_type()); + assert_eq!(&DataType::Decimal128(38, 6), decimal_array.data_type()); assert_eq!(3, decimal_array.len()); assert_eq!(1, decimal_array.null_count()); assert_eq!(32, decimal_array.value_offset(2)); @@ -267,30 +295,33 @@ mod tests { #[test] fn test_decimal256_builder() { - let mut builder = Decimal256Builder::new(30, 40, 6); + let mut builder = Decimal256Builder::new(40, 6); - let mut bytes = vec![0; 32]; + let mut bytes = [0_u8; 32]; bytes[0..16].clone_from_slice(&8_887_000_000_i128.to_le_bytes()); - let value = Decimal256::try_new_from_bytes(40, 6, bytes.as_slice()).unwrap(); + let value = Decimal256::try_new_from_bytes(40, 6, &bytes).unwrap(); builder.append_value(&value).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); - bytes = vec![255; 32]; - let value = Decimal256::try_new_from_bytes(40, 6, bytes.as_slice()).unwrap(); + bytes = [255; 32]; + let value = Decimal256::try_new_from_bytes(40, 6, &bytes).unwrap(); builder.append_value(&value).unwrap(); - bytes = vec![0; 32]; + bytes = [0; 32]; bytes[0..16].clone_from_slice(&0_i128.to_le_bytes()); bytes[15] = 128; - let value = Decimal256::try_new_from_bytes(40, 6, bytes.as_slice()).unwrap(); + let value = Decimal256::try_new_from_bytes(40, 6, &bytes).unwrap(); builder.append_value(&value).unwrap(); + builder.append_option(None::<&Decimal256>).unwrap(); + builder.append_option(Some(&value)).unwrap(); + let decimal_array: Decimal256Array = builder.finish(); - assert_eq!(&DataType::Decimal(40, 6), decimal_array.data_type()); - assert_eq!(4, decimal_array.len()); - assert_eq!(1, decimal_array.null_count()); + assert_eq!(&DataType::Decimal256(40, 6), decimal_array.data_type()); + assert_eq!(6, decimal_array.len()); + assert_eq!(2, decimal_array.null_count()); assert_eq!(64, decimal_array.value_offset(2)); assert_eq!(32, decimal_array.value_length()); @@ -308,11 +339,45 @@ mod tests { expected = "Decimal value does not have the same precision or scale as Decimal256Builder" )] fn test_decimal256_builder_unmatched_precision_scale() { - let mut builder = Decimal256Builder::new(30, 10, 6); + let mut builder = Decimal256Builder::with_capacity(30, 10, 6); - let mut bytes = vec![0; 32]; + let mut bytes = [0_u8; 32]; bytes[0..16].clone_from_slice(&8_887_000_000_i128.to_le_bytes()); - let value = Decimal256::try_new_from_bytes(40, 6, bytes.as_slice()).unwrap(); + let value = Decimal256::try_new_from_bytes(40, 6, &bytes).unwrap(); builder.append_value(&value).unwrap(); } + + #[test] + #[should_panic( + expected = "9999999999999999999999999999999999999999999999999999999999999999999999999999 is too large to store in a Decimal256 of precision 75. Max is 999999999999999999999999999999999999999999999999999999999999999999999999999" + )] + fn test_decimal256_builder_out_of_range_precision_scale() { + let mut builder = Decimal256Builder::new(75, 6); + + let big_value = BigInt::from_str_radix("9999999999999999999999999999999999999999999999999999999999999999999999999999", 10).unwrap(); + let value = Decimal256::from_big_int(&big_value, 75, 6).unwrap(); + builder.append_value(&value).unwrap(); + } + + #[test] + #[should_panic( + expected = "9999999999999999999999999999999999999999999999999999999999999999999999999999 is too large to store in a Decimal256 of precision 75. 
Max is 999999999999999999999999999999999999999999999999999999999999999999999999999" + )] + fn test_decimal256_data_validation() { + let mut builder = Decimal256Builder::new(75, 6); + // Disable validation at builder + unsafe { + builder.disable_value_validation(); + } + + let big_value = BigInt::from_str_radix("9999999999999999999999999999999999999999999999999999999999999999999999999999", 10).unwrap(); + let value = Decimal256::from_big_int(&big_value, 75, 6).unwrap(); + builder + .append_value(&value) + .expect("should not validate invalid value at builder"); + + let array = builder.finish(); + let array_data = array_decimal::DecimalArray::data(&array); + array_data.validate_values().unwrap(); + } } diff --git a/arrow/src/array/builder/fixed_size_binary_builder.rs b/arrow/src/array/builder/fixed_size_binary_builder.rs index e62aa8fa60cf..30c25e0a62b9 100644 --- a/arrow/src/array/builder/fixed_size_binary_builder.rs +++ b/arrow/src/array/builder/fixed_size_binary_builder.rs @@ -23,31 +23,32 @@ use crate::error::{ArrowError, Result}; use std::any::Any; use std::sync::Arc; -use super::BooleanBufferBuilder; +use super::NullBufferBuilder; #[derive(Debug)] pub struct FixedSizeBinaryBuilder { values_builder: UInt8BufferBuilder, - bitmap_builder: BooleanBufferBuilder, + null_buffer_builder: NullBufferBuilder, value_length: i32, } impl FixedSizeBinaryBuilder { - /// Creates a new [`FixedSizeBinaryBuilder`], `capacity` is the number of bytes in the values - /// buffer - pub fn new(capacity: usize, byte_width: i32) -> Self { + /// Creates a new [`FixedSizeBinaryBuilder`] + pub fn new(byte_width: i32) -> Self { + Self::with_capacity(1024, byte_width) + } + + /// Creates a new [`FixedSizeBinaryBuilder`], `capacity` is the number of byte slices + /// that can be appended without reallocating + pub fn with_capacity(capacity: usize, byte_width: i32) -> Self { assert!( byte_width >= 0, "value length ({}) of the array must >= 0", byte_width ); Self { - values_builder: UInt8BufferBuilder::new(capacity), - bitmap_builder: BooleanBufferBuilder::new(if byte_width > 0 { - capacity / byte_width as usize - } else { - 0 - }), + values_builder: UInt8BufferBuilder::new(capacity * byte_width as usize), + null_buffer_builder: NullBufferBuilder::new(capacity), value_length: byte_width, } } @@ -64,18 +65,17 @@ impl FixedSizeBinaryBuilder { )) } else { self.values_builder.append_slice(value.as_ref()); - self.bitmap_builder.append(true); + self.null_buffer_builder.append_non_null(); Ok(()) } } /// Append a null value to the array. #[inline] - pub fn append_null(&mut self) -> Result<()> { + pub fn append_null(&mut self) { self.values_builder .append_slice(&vec![0u8; self.value_length as usize][..]); - self.bitmap_builder.append(false); - Ok(()) + self.null_buffer_builder.append_null(); } /// Builds the [`FixedSizeBinaryArray`] and reset this builder. 
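A usage sketch of the updated FixedSizeBinaryBuilder constructors (assumed usage): `with_capacity` now takes a number of value slots rather than a number of raw bytes, and `append_null` no longer returns a Result.

use arrow::array::{Array, FixedSizeBinaryBuilder};
use arrow::error::Result;

fn build_fixed_size_binary() -> Result<()> {
    // room for 3 values of 5 bytes each
    let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 5);
    builder.append_value(b"hello")?;
    builder.append_null(); // pads the slot with zeroed bytes and marks it null
    builder.append_value(b"arrow")?;

    let array = builder.finish();
    assert_eq!(array.len(), 3);
    assert_eq!(array.value_length(), 5);
    Ok(())
}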
@@ -84,7 +84,7 @@ impl FixedSizeBinaryBuilder { let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) .add_buffer(self.values_builder.finish()) - .null_bit_buffer(Some(self.bitmap_builder.finish())) + .null_bit_buffer(self.null_buffer_builder.finish()) .len(array_length); let array_data = unsafe { array_data_builder.build_unchecked() }; FixedSizeBinaryArray::from(array_data) @@ -109,12 +109,12 @@ impl ArrayBuilder for FixedSizeBinaryBuilder { /// Returns the number of array slots in the builder fn len(&self) -> usize { - self.bitmap_builder.len() + self.null_buffer_builder.len() } /// Returns whether the number of array slots is zero fn is_empty(&self) -> bool { - self.bitmap_builder.is_empty() + self.null_buffer_builder.is_empty() } /// Builds the array and reset this builder. @@ -133,11 +133,11 @@ mod tests { #[test] fn test_fixed_size_binary_builder() { - let mut builder = FixedSizeBinaryBuilder::new(15, 5); + let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 5); // [b"hello", null, "arrow"] builder.append_value(b"hello").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append_value(b"arrow").unwrap(); let array: FixedSizeBinaryArray = builder.finish(); @@ -150,10 +150,10 @@ mod tests { #[test] fn test_fixed_size_binary_builder_with_zero_value_length() { - let mut builder = FixedSizeBinaryBuilder::new(0, 0); + let mut builder = FixedSizeBinaryBuilder::new(0); builder.append_value(b"").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append_value(b"").unwrap(); assert!(!builder.is_empty()); @@ -172,12 +172,12 @@ mod tests { expected = "Byte slice does not have the same length as FixedSizeBinaryBuilder value lengths" )] fn test_fixed_size_binary_builder_with_inconsistent_value_length() { - let mut builder = FixedSizeBinaryBuilder::new(15, 4); + let mut builder = FixedSizeBinaryBuilder::with_capacity(1, 4); builder.append_value(b"hello").unwrap(); } #[test] fn test_fixed_size_binary_builder_empty() { - let mut builder = FixedSizeBinaryBuilder::new(15, 5); + let mut builder = FixedSizeBinaryBuilder::new(5); assert!(builder.is_empty()); let fixed_size_binary_array = builder.finish(); @@ -191,6 +191,6 @@ mod tests { #[test] #[should_panic(expected = "value length (-1) of the array must >= 0")] fn test_fixed_size_binary_builder_invalid_value_length() { - let _ = FixedSizeBinaryBuilder::new(15, -1); + let _ = FixedSizeBinaryBuilder::with_capacity(15, -1); } } diff --git a/arrow/src/array/builder/fixed_size_list_builder.rs b/arrow/src/array/builder/fixed_size_list_builder.rs index 91c20d2a5ace..da850d156243 100644 --- a/arrow/src/array/builder/fixed_size_list_builder.rs +++ b/arrow/src/array/builder/fixed_size_list_builder.rs @@ -23,15 +23,14 @@ use crate::array::ArrayRef; use crate::array::FixedSizeListArray; use crate::datatypes::DataType; use crate::datatypes::Field; -use crate::error::Result; use super::ArrayBuilder; -use super::BooleanBufferBuilder; +use super::NullBufferBuilder; /// Array builder for [`FixedSizeListArray`] #[derive(Debug)] pub struct FixedSizeListBuilder { - bitmap_builder: BooleanBufferBuilder, + null_buffer_builder: NullBufferBuilder, values_builder: T, list_len: i32, } @@ -49,7 +48,7 @@ impl FixedSizeListBuilder { /// `capacity` is the number of items to pre-allocate space for in this builder pub fn with_capacity(values_builder: T, value_length: i32, capacity: usize) -> Self { Self { - bitmap_builder: BooleanBufferBuilder::new(capacity), + null_buffer_builder: 
NullBufferBuilder::new(capacity), values_builder, list_len: value_length, } @@ -77,12 +76,12 @@ where /// Returns the number of array slots in the builder fn len(&self) -> usize { - self.bitmap_builder.len() + self.null_buffer_builder.len() } /// Returns whether the number of array slots is zero fn is_empty(&self) -> bool { - self.bitmap_builder.is_empty() + self.null_buffer_builder.is_empty() } /// Builds the array and reset this builder. @@ -109,9 +108,8 @@ where /// Finish the current fixed-length list array slot #[inline] - pub fn append(&mut self, is_valid: bool) -> Result<()> { - self.bitmap_builder.append(is_valid); - Ok(()) + pub fn append(&mut self, is_valid: bool) { + self.null_buffer_builder.append(is_valid); } /// Builds the [`FixedSizeListBuilder`] and reset this builder. @@ -133,14 +131,14 @@ where len, ); - let null_bit_buffer = self.bitmap_builder.finish(); + let null_bit_buffer = self.null_buffer_builder.finish(); let array_data = ArrayData::builder(DataType::FixedSizeList( Box::new(Field::new("item", values_data.data_type().clone(), true)), self.list_len, )) .len(len) .add_child_data(values_data.clone()) - .null_bit_buffer(Some(null_bit_buffer)); + .null_bit_buffer(null_bit_buffer); let array_data = unsafe { array_data.build_unchecked() }; @@ -158,26 +156,26 @@ mod tests { #[test] fn test_fixed_size_list_array_builder() { - let values_builder = Int32Builder::new(10); + let values_builder = Int32Builder::new(); let mut builder = FixedSizeListBuilder::new(values_builder, 3); // [[0, 1, 2], null, [3, null, 5], [6, 7, null]] - builder.values().append_value(0).unwrap(); - builder.values().append_value(1).unwrap(); - builder.values().append_value(2).unwrap(); - builder.append(true).unwrap(); - builder.values().append_null().unwrap(); - builder.values().append_null().unwrap(); - builder.values().append_null().unwrap(); - builder.append(false).unwrap(); - builder.values().append_value(3).unwrap(); - builder.values().append_null().unwrap(); - builder.values().append_value(5).unwrap(); - builder.append(true).unwrap(); - builder.values().append_value(6).unwrap(); - builder.values().append_value(7).unwrap(); - builder.values().append_null().unwrap(); - builder.append(true).unwrap(); + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_null(); + builder.append(false); + builder.values().append_value(3); + builder.values().append_null(); + builder.values().append_value(5); + builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.values().append_null(); + builder.append(true); let list_array = builder.finish(); assert_eq!(DataType::Int32, list_array.value_type()); @@ -202,17 +200,17 @@ mod tests { let values_builder = Int32Array::builder(5); let mut builder = FixedSizeListBuilder::new(values_builder, 3); - builder.values().append_slice(&[1, 2, 3]).unwrap(); - builder.append(true).unwrap(); - builder.values().append_slice(&[4, 5, 6]).unwrap(); - builder.append(true).unwrap(); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); let mut arr = builder.finish(); assert_eq!(2, arr.len()); assert_eq!(0, builder.len()); - builder.values().append_slice(&[7, 8, 9]).unwrap(); - builder.append(true).unwrap(); + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); arr = 
builder.finish(); assert_eq!(1, arr.len()); assert_eq!(0, builder.len()); @@ -226,12 +224,12 @@ mod tests { let values_builder = Int32Array::builder(5); let mut builder = FixedSizeListBuilder::new(values_builder, 3); - builder.values().append_slice(&[1, 2, 3]).unwrap(); - builder.append(true).unwrap(); - builder.values().append_slice(&[4, 5, 6]).unwrap(); - builder.append(true).unwrap(); - builder.values().append_slice(&[7, 8, 9, 10]).unwrap(); - builder.append(true).unwrap(); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); + builder.values().append_slice(&[7, 8, 9, 10]); + builder.append(true); builder.finish(); } diff --git a/arrow/src/array/builder/generic_binary_builder.rs b/arrow/src/array/builder/generic_binary_builder.rs index 8b7a05854a62..26501ba099da 100644 --- a/arrow/src/array/builder/generic_binary_builder.rs +++ b/arrow/src/array/builder/generic_binary_builder.rs @@ -16,65 +16,86 @@ // under the License. use crate::array::{ - ArrayBuilder, ArrayRef, GenericBinaryArray, GenericListBuilder, OffsetSizeTrait, - UInt8Builder, + ArrayBuilder, ArrayDataBuilder, ArrayRef, GenericBinaryArray, OffsetSizeTrait, + UInt8BufferBuilder, }; -use crate::error::Result; use std::any::Any; use std::sync::Arc; -/// Array builder for `BinaryArray` +use super::{BufferBuilder, NullBufferBuilder}; + +/// Array builder for [`GenericBinaryArray`] #[derive(Debug)] pub struct GenericBinaryBuilder { - builder: GenericListBuilder, + value_builder: UInt8BufferBuilder, + offsets_builder: BufferBuilder, + null_buffer_builder: NullBufferBuilder, } impl GenericBinaryBuilder { - /// Creates a new `GenericBinaryBuilder`, `capacity` is the number of bytes in the values - /// array - pub fn new(capacity: usize) -> Self { - let values_builder = UInt8Builder::new(capacity); + /// Creates a new [`GenericBinaryBuilder`]. + pub fn new() -> Self { + Self::with_capacity(1024, 1024) + } + + /// Creates a new [`GenericBinaryBuilder`], + /// `item_capacity` is the number of items to pre-allocate space for in this builder + /// `data_capacity` is the number of bytes to pre-allocate space for in this builder + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let mut offsets_builder = BufferBuilder::::new(item_capacity + 1); + offsets_builder.append(OffsetSize::zero()); Self { - builder: GenericListBuilder::new(values_builder), + value_builder: UInt8BufferBuilder::new(data_capacity), + offsets_builder, + null_buffer_builder: NullBufferBuilder::new(item_capacity), } } - /// Appends a single byte value into the builder's values array. - /// - /// Note, when appending individual byte values you must call `append` to delimit each - /// distinct list value. + /// Appends a byte slice into the builder. #[inline] - pub fn append_byte(&mut self, value: u8) -> Result<()> { - self.builder.values().append_value(value)?; - Ok(()) + pub fn append_value(&mut self, value: impl AsRef<[u8]>) { + self.value_builder.append_slice(value.as_ref()); + self.null_buffer_builder.append(true); + self.offsets_builder + .append(OffsetSize::from_usize(self.value_builder.len()).unwrap()); } - /// Appends a byte slice into the builder. - /// - /// Automatically calls the `append` method to delimit the slice appended in as a - /// distinct array element. + /// Append a null value to the array. 
#[inline] - pub fn append_value(&mut self, value: impl AsRef<[u8]>) -> Result<()> { - self.builder.values().append_slice(value.as_ref())?; - self.builder.append(true)?; - Ok(()) + pub fn append_null(&mut self) { + self.null_buffer_builder.append(false); + self.offsets_builder + .append(OffsetSize::from_usize(self.value_builder.len()).unwrap()); } - /// Finish the current variable-length list array slot. - #[inline] - pub fn append(&mut self, is_valid: bool) -> Result<()> { - self.builder.append(is_valid) + /// Builds the [`GenericBinaryArray`] and reset this builder. + pub fn finish(&mut self) -> GenericBinaryArray { + let array_type = GenericBinaryArray::::DATA_TYPE; + let array_builder = ArrayDataBuilder::new(array_type) + .len(self.len()) + .add_buffer(self.offsets_builder.finish()) + .add_buffer(self.value_builder.finish()) + .null_bit_buffer(self.null_buffer_builder.finish()); + + self.offsets_builder.append(OffsetSize::zero()); + let array_data = unsafe { array_builder.build_unchecked() }; + GenericBinaryArray::::from(array_data) } - /// Append a null value to the array. - #[inline] - pub fn append_null(&mut self) -> Result<()> { - self.append(false) + /// Returns the current values buffer as a slice + pub fn values_slice(&self) -> &[u8] { + self.value_builder.as_slice() } - /// Builds the `BinaryArray` and reset this builder. - pub fn finish(&mut self) -> GenericBinaryArray { - GenericBinaryArray::::from(self.builder.finish()) + /// Returns the current offsets buffer as a slice + pub fn offsets_slice(&self) -> &[OffsetSize] { + self.offsets_builder.as_slice() + } +} + +impl Default for GenericBinaryBuilder { + fn default() -> Self { + Self::new() } } @@ -94,14 +115,14 @@ impl ArrayBuilder for GenericBinaryBuilder usize { - self.builder.len() + self.null_buffer_builder.len() } - /// Returns whether the number of array slots is zero + /// Returns whether the number of binary slots is zero fn is_empty(&self) -> bool { - self.builder.is_empty() + self.null_buffer_builder.is_empty() } /// Builds the array and reset this builder. 
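A usage sketch of the flattened GenericBinaryBuilder through its `BinaryBuilder` (i32 offset) alias, assuming the usual re-export: values and offsets are now written directly into buffers, and appends no longer return a Result.

use arrow::array::{Array, BinaryBuilder};

fn build_binary() {
    let mut builder = BinaryBuilder::new();
    builder.append_value(b"parquet");
    builder.append_null();
    builder.append_value(b"arrow");

    let array = builder.finish();
    assert_eq!(array.len(), 3);
    assert_eq!(array.null_count(), 1);
    assert_eq!(b"arrow", array.value(2));
}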
@@ -112,64 +133,100 @@ impl ArrayBuilder for GenericBinaryBuilder() { + let mut builder = GenericBinaryBuilder::::new(); + + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"rust"); + + let array = builder.finish(); + + assert_eq!(4, array.len()); + assert_eq!(1, array.null_count()); + assert_eq!(b"hello", array.value(0)); + assert_eq!([] as [u8; 0], array.value(1)); + assert!(array.is_null(2)); + assert_eq!(b"rust", array.value(3)); + assert_eq!(O::from_usize(5).unwrap(), array.value_offsets()[2]); + assert_eq!(O::from_usize(4).unwrap(), array.value_length(3)); + } + + #[test] + fn test_binary_builder() { + _test_generic_binary_builder::() + } + + #[test] + fn test_large_binary_builder() { + _test_generic_binary_builder::() + } + + fn _test_generic_binary_builder_all_nulls() { + let mut builder = GenericBinaryBuilder::::new(); + builder.append_null(); + builder.append_null(); + builder.append_null(); + assert_eq!(3, builder.len()); + assert!(!builder.is_empty()); + + let array = builder.finish(); + assert_eq!(3, array.null_count()); + assert_eq!(3, array.len()); + assert!(array.is_null(0)); + assert!(array.is_null(1)); + assert!(array.is_null(2)); + } + + #[test] + fn test_binary_builder_all_nulls() { + _test_generic_binary_builder_all_nulls::() + } + + #[test] + fn test_large_binary_builder_all_nulls() { + _test_generic_binary_builder_all_nulls::() + } + + fn _test_generic_binary_builder_reset() { + let mut builder = GenericBinaryBuilder::::new(); + + builder.append_value(b"hello"); + builder.append_value(b""); + builder.append_null(); + builder.append_value(b"rust"); + builder.finish(); + + assert!(builder.is_empty()); + + builder.append_value(b"parquet"); + builder.append_null(); + builder.append_value(b"arrow"); + builder.append_value(b""); + let array = builder.finish(); + + assert_eq!(4, array.len()); + assert_eq!(1, array.null_count()); + assert_eq!(b"parquet", array.value(0)); + assert!(array.is_null(1)); + assert_eq!(b"arrow", array.value(2)); + assert_eq!(b"", array.value(1)); + assert_eq!(O::zero(), array.value_offsets()[0]); + assert_eq!(O::from_usize(7).unwrap(), array.value_offsets()[2]); + assert_eq!(O::from_usize(5).unwrap(), array.value_length(2)); + } #[test] - fn test_binary_array_builder() { - let mut builder = BinaryBuilder::new(20); - - builder.append_byte(b'h').unwrap(); - builder.append_byte(b'e').unwrap(); - builder.append_byte(b'l').unwrap(); - builder.append_byte(b'l').unwrap(); - builder.append_byte(b'o').unwrap(); - builder.append(true).unwrap(); - builder.append(true).unwrap(); - builder.append_byte(b'w').unwrap(); - builder.append_byte(b'o').unwrap(); - builder.append_byte(b'r').unwrap(); - builder.append_byte(b'l').unwrap(); - builder.append_byte(b'd').unwrap(); - builder.append(true).unwrap(); - - let binary_array = builder.finish(); - - assert_eq!(3, binary_array.len()); - assert_eq!(0, binary_array.null_count()); - assert_eq!([b'h', b'e', b'l', b'l', b'o'], binary_array.value(0)); - assert_eq!([] as [u8; 0], binary_array.value(1)); - assert_eq!([b'w', b'o', b'r', b'l', b'd'], binary_array.value(2)); - assert_eq!(5, binary_array.value_offsets()[2]); - assert_eq!(5, binary_array.value_length(2)); + fn test_binary_builder_reset() { + _test_generic_binary_builder_reset::() } #[test] - fn test_large_binary_array_builder() { - let mut builder = LargeBinaryBuilder::new(20); - - builder.append_byte(b'h').unwrap(); - builder.append_byte(b'e').unwrap(); - builder.append_byte(b'l').unwrap(); - 
builder.append_byte(b'l').unwrap(); - builder.append_byte(b'o').unwrap(); - builder.append(true).unwrap(); - builder.append(true).unwrap(); - builder.append_byte(b'w').unwrap(); - builder.append_byte(b'o').unwrap(); - builder.append_byte(b'r').unwrap(); - builder.append_byte(b'l').unwrap(); - builder.append_byte(b'd').unwrap(); - builder.append(true).unwrap(); - - let binary_array = builder.finish(); - - assert_eq!(3, binary_array.len()); - assert_eq!(0, binary_array.null_count()); - assert_eq!([b'h', b'e', b'l', b'l', b'o'], binary_array.value(0)); - assert_eq!([] as [u8; 0], binary_array.value(1)); - assert_eq!([b'w', b'o', b'r', b'l', b'd'], binary_array.value(2)); - assert_eq!(5, binary_array.value_offsets()[2]); - assert_eq!(5, binary_array.value_length(2)); + fn test_large_binary_builder_reset() { + _test_generic_binary_builder_reset::() } } diff --git a/arrow/src/array/builder/generic_list_builder.rs b/arrow/src/array/builder/generic_list_builder.rs index cc39aad699e3..1beda7114171 100644 --- a/arrow/src/array/builder/generic_list_builder.rs +++ b/arrow/src/array/builder/generic_list_builder.rs @@ -22,17 +22,15 @@ use crate::array::ArrayData; use crate::array::ArrayRef; use crate::array::GenericListArray; use crate::array::OffsetSizeTrait; -use crate::datatypes::DataType; use crate::datatypes::Field; -use crate::error::Result; -use super::{ArrayBuilder, BooleanBufferBuilder, BufferBuilder}; +use super::{ArrayBuilder, BufferBuilder, NullBufferBuilder}; /// Array builder for [`GenericListArray`] #[derive(Debug)] pub struct GenericListBuilder { offsets_builder: BufferBuilder, - bitmap_builder: BooleanBufferBuilder, + null_buffer_builder: NullBufferBuilder, values_builder: T, } @@ -50,7 +48,7 @@ impl GenericListBuilder usize { - self.bitmap_builder.len() + self.null_buffer_builder.len() } /// Returns whether the number of array slots is zero fn is_empty(&self) -> bool { - self.bitmap_builder.is_empty() + self.null_buffer_builder.is_empty() } /// Builds the array and reset this builder. @@ -111,11 +109,10 @@ where /// Finish the current variable-length list array slot #[inline] - pub fn append(&mut self, is_valid: bool) -> Result<()> { + pub fn append(&mut self, is_valid: bool) { self.offsets_builder .append(OffsetSize::from_usize(self.values_builder.len()).unwrap()); - self.bitmap_builder.append(is_valid); - Ok(()) + self.null_buffer_builder.append(is_valid); } /// Builds the [`GenericListArray`] and reset this builder. 
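A usage sketch of GenericListBuilder after this change, via the `ListBuilder` alias (assumed): closing a list slot with `append(is_valid)` is now infallible, as are the appends on the child Int32Builder.

use arrow::array::{Array, Int32Builder, ListBuilder};

fn build_lists() {
    let mut builder = ListBuilder::new(Int32Builder::new());
    builder.values().append_value(1);
    builder.values().append_value(2);
    builder.append(true); // closes [1, 2]
    builder.append(false); // null list slot

    let list = builder.finish();
    assert_eq!(list.len(), 2);
    assert_eq!(list.null_count(), 1);
}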
@@ -130,23 +127,19 @@ where let values_data = values_arr.data(); let offset_buffer = self.offsets_builder.finish(); - let null_bit_buffer = self.bitmap_builder.finish(); + let null_bit_buffer = self.null_buffer_builder.finish(); self.offsets_builder.append(OffsetSize::zero()); let field = Box::new(Field::new( "item", values_data.data_type().clone(), true, // TODO: find a consistent way of getting this )); - let data_type = if OffsetSize::IS_LARGE { - DataType::LargeList(field) - } else { - DataType::List(field) - }; + let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(field); let array_data_builder = ArrayData::builder(data_type) .len(len) .add_buffer(offset_buffer) .add_child_data(values_data.clone()) - .null_bit_buffer(Some(null_bit_buffer)); + .null_bit_buffer(null_bit_buffer); let array_data = unsafe { array_data_builder.build_unchecked() }; @@ -165,23 +158,24 @@ mod tests { use crate::array::builder::ListBuilder; use crate::array::{Array, Int32Array, Int32Builder}; use crate::buffer::Buffer; + use crate::datatypes::DataType; fn _test_generic_list_array_builder() { - let values_builder = Int32Builder::new(10); + let values_builder = Int32Builder::with_capacity(10); let mut builder = GenericListBuilder::::new(values_builder); // [[0, 1, 2], [3, 4, 5], [6, 7]] - builder.values().append_value(0).unwrap(); - builder.values().append_value(1).unwrap(); - builder.values().append_value(2).unwrap(); - builder.append(true).unwrap(); - builder.values().append_value(3).unwrap(); - builder.values().append_value(4).unwrap(); - builder.values().append_value(5).unwrap(); - builder.append(true).unwrap(); - builder.values().append_value(6).unwrap(); - builder.values().append_value(7).unwrap(); - builder.append(true).unwrap(); + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.values().append_value(3); + builder.values().append_value(4); + builder.values().append_value(5); + builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.append(true); let list_array = builder.finish(); let values = list_array.values().data().buffers()[0].clone(); @@ -212,22 +206,23 @@ mod tests { } fn _test_generic_list_array_builder_nulls() { - let values_builder = Int32Builder::new(10); + let values_builder = Int32Builder::with_capacity(10); let mut builder = GenericListBuilder::::new(values_builder); // [[0, 1, 2], null, [3, null, 5], [6, 7]] - builder.values().append_value(0).unwrap(); - builder.values().append_value(1).unwrap(); - builder.values().append_value(2).unwrap(); - builder.append(true).unwrap(); - builder.append(false).unwrap(); - builder.values().append_value(3).unwrap(); - builder.values().append_null().unwrap(); - builder.values().append_value(5).unwrap(); - builder.append(true).unwrap(); - builder.values().append_value(6).unwrap(); - builder.values().append_value(7).unwrap(); - builder.append(true).unwrap(); + builder.values().append_value(0); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true); + builder.append(false); + builder.values().append_value(3); + builder.values().append_null(); + builder.values().append_value(5); + builder.append(true); + builder.values().append_value(6); + builder.values().append_value(7); + builder.append(true); + let list_array = builder.finish(); assert_eq!(DataType::Int32, list_array.value_type()); @@ -252,17 +247,17 @@ mod tests { let values_builder = Int32Array::builder(5); let mut builder 
= ListBuilder::new(values_builder); - builder.values().append_slice(&[1, 2, 3]).unwrap(); - builder.append(true).unwrap(); - builder.values().append_slice(&[4, 5, 6]).unwrap(); - builder.append(true).unwrap(); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(true); let mut arr = builder.finish(); assert_eq!(2, arr.len()); assert!(builder.is_empty()); - builder.values().append_slice(&[7, 8, 9]).unwrap(); - builder.append(true).unwrap(); + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); arr = builder.finish(); assert_eq!(1, arr.len()); assert!(builder.is_empty()); @@ -270,34 +265,34 @@ mod tests { #[test] fn test_list_list_array_builder() { - let primitive_builder = Int32Builder::new(10); + let primitive_builder = Int32Builder::with_capacity(10); let values_builder = ListBuilder::new(primitive_builder); let mut builder = ListBuilder::new(values_builder); // [[[1, 2], [3, 4]], [[5, 6, 7], null, [8]], null, [[9, 10]]] - builder.values().values().append_value(1).unwrap(); - builder.values().values().append_value(2).unwrap(); - builder.values().append(true).unwrap(); - builder.values().values().append_value(3).unwrap(); - builder.values().values().append_value(4).unwrap(); - builder.values().append(true).unwrap(); - builder.append(true).unwrap(); - - builder.values().values().append_value(5).unwrap(); - builder.values().values().append_value(6).unwrap(); - builder.values().values().append_value(7).unwrap(); - builder.values().append(true).unwrap(); - builder.values().append(false).unwrap(); - builder.values().values().append_value(8).unwrap(); - builder.values().append(true).unwrap(); - builder.append(true).unwrap(); - - builder.append(false).unwrap(); - - builder.values().values().append_value(9).unwrap(); - builder.values().values().append_value(10).unwrap(); - builder.values().append(true).unwrap(); - builder.append(true).unwrap(); + builder.values().values().append_value(1); + builder.values().values().append_value(2); + builder.values().append(true); + builder.values().values().append_value(3); + builder.values().values().append_value(4); + builder.values().append(true); + builder.append(true); + + builder.values().values().append_value(5); + builder.values().values().append_value(6); + builder.values().values().append_value(7); + builder.values().append(true); + builder.values().append(false); + builder.values().values().append_value(8); + builder.values().append(true); + builder.append(true); + + builder.append(false); + + builder.values().values().append_value(9); + builder.values().values().append_value(10); + builder.values().append(true); + builder.append(true); let list_array = builder.finish(); diff --git a/arrow/src/array/builder/generic_string_builder.rs b/arrow/src/array/builder/generic_string_builder.rs index 04205f87865b..8f69f5d9c7be 100644 --- a/arrow/src/array/builder/generic_string_builder.rs +++ b/arrow/src/array/builder/generic_string_builder.rs @@ -15,82 +15,64 @@ // specific language governing permissions and limitations // under the License. 
-use crate::array::{ - ArrayBuilder, ArrayRef, GenericListBuilder, GenericStringArray, OffsetSizeTrait, - UInt8Builder, -}; -use crate::error::Result; +use crate::array::{ArrayBuilder, ArrayRef, GenericStringArray, OffsetSizeTrait}; use std::any::Any; use std::sync::Arc; +use super::GenericBinaryBuilder; + +/// Array builder for [`GenericStringArray`] #[derive(Debug)] pub struct GenericStringBuilder { - builder: GenericListBuilder, + builder: GenericBinaryBuilder, } impl GenericStringBuilder { - /// Creates a new `StringBuilder`, - /// `capacity` is the number of bytes of string data to pre-allocate space for in this builder - pub fn new(capacity: usize) -> Self { - let values_builder = UInt8Builder::new(capacity); + /// Creates a new [`GenericStringBuilder`], + pub fn new() -> Self { Self { - builder: GenericListBuilder::new(values_builder), + builder: GenericBinaryBuilder::new(), } } - /// Creates a new `StringBuilder`, + /// Creates a new [`GenericStringBuilder`], /// `data_capacity` is the number of bytes of string data to pre-allocate space for in this builder /// `item_capacity` is the number of items to pre-allocate space for in this builder pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let values_builder = UInt8Builder::new(data_capacity); Self { - builder: GenericListBuilder::with_capacity(values_builder, item_capacity), + builder: GenericBinaryBuilder::with_capacity(item_capacity, data_capacity), } } /// Appends a string into the builder. - /// - /// Automatically calls the `append` method to delimit the string appended in as a - /// distinct array element. #[inline] - pub fn append_value(&mut self, value: impl AsRef) -> Result<()> { - self.builder - .values() - .append_slice(value.as_ref().as_bytes())?; - self.builder.append(true)?; - Ok(()) - } - - /// Finish the current variable-length list array slot. - #[inline] - pub fn append(&mut self, is_valid: bool) -> Result<()> { - self.builder.append(is_valid) + pub fn append_value(&mut self, value: impl AsRef) { + self.builder.append_value(value.as_ref().as_bytes()); } /// Append a null value to the array. #[inline] - pub fn append_null(&mut self) -> Result<()> { - self.append(false) + pub fn append_null(&mut self) { + self.builder.append_null() } /// Append an `Option` value to the array. #[inline] - pub fn append_option(&mut self, value: Option>) -> Result<()> { + pub fn append_option(&mut self, value: Option>) { match value { - None => self.append_null()?, - Some(v) => self.append_value(v)?, + None => self.append_null(), + Some(v) => self.append_value(v), }; - Ok(()) } - /// Builds the `StringArray` and reset this builder. + /// Builds the [`GenericStringArray`] and reset this builder. pub fn finish(&mut self) -> GenericStringArray { GenericStringArray::::from(self.builder.finish()) } /// Returns the current values buffer as a slice pub fn values_slice(&self) -> &[u8] { - self.builder.values_ref().values_slice() + self.builder.values_slice() } /// Returns the current offsets buffer as a slice @@ -99,6 +81,12 @@ impl GenericStringBuilder { } } +impl Default for GenericStringBuilder { + fn default() -> Self { + Self::new() + } +} + impl ArrayBuilder for GenericStringBuilder { /// Returns the builder as a non-mutable `Any` reference. 
fn as_any(&self) -> &dyn Any { @@ -134,79 +122,72 @@ impl ArrayBuilder for GenericStringBuilder() { + let mut builder = GenericStringBuilder::::new(); + let owned = "arrow".to_owned(); + + builder.append_value("hello"); + builder.append_value(""); + builder.append_value(&owned); + builder.append_null(); + builder.append_option(Some("rust")); + builder.append_option(None::<&str>); + builder.append_option(None::); + assert_eq!(7, builder.len()); + + assert_eq!( + GenericStringArray::::from(vec![ + Some("hello"), + Some(""), + Some("arrow"), + None, + Some("rust"), + None, + None + ]), + builder.finish() + ); + } #[test] fn test_string_array_builder() { - let mut builder = StringBuilder::new(20); - - builder.append_value("hello").unwrap(); - builder.append(true).unwrap(); - builder.append_value("world").unwrap(); - - let string_array = builder.finish(); - - assert_eq!(3, string_array.len()); - assert_eq!(0, string_array.null_count()); - assert_eq!("hello", string_array.value(0)); - assert_eq!("", string_array.value(1)); - assert_eq!("world", string_array.value(2)); - assert_eq!(5, string_array.value_offsets()[2]); - assert_eq!(5, string_array.value_length(2)); + _test_generic_string_array_builder::() } #[test] - fn test_string_array_builder_finish() { - let mut builder = StringBuilder::new(10); + fn test_large_string_array_builder() { + _test_generic_string_array_builder::() + } + + fn _test_generic_string_array_builder_finish() { + let mut builder = GenericStringBuilder::::with_capacity(3, 11); - builder.append_value("hello").unwrap(); - builder.append_value("world").unwrap(); + builder.append_value("hello"); + builder.append_value("rust"); + builder.append_null(); - let mut arr = builder.finish(); - assert_eq!(2, arr.len()); - assert_eq!(0, builder.len()); + builder.finish(); + assert!(builder.is_empty()); + assert_eq!(&[O::zero()], builder.offsets_slice()); - builder.append_value("arrow").unwrap(); - arr = builder.finish(); - assert_eq!(1, arr.len()); - assert_eq!(0, builder.len()); + builder.append_value("arrow"); + builder.append_value("parquet"); + let arr = builder.finish(); + // array should not have null buffer because there is not `null` value. 
+ assert_eq!(None, arr.data().null_buffer()); + assert_eq!(GenericStringArray::::from(vec!["arrow", "parquet"]), arr,) } #[test] - fn test_string_array_builder_append_string() { - let mut builder = StringBuilder::new(20); - - let var = "hello".to_owned(); - builder.append_value(&var).unwrap(); - builder.append(true).unwrap(); - builder.append_value("world").unwrap(); - - let string_array = builder.finish(); - - assert_eq!(3, string_array.len()); - assert_eq!(0, string_array.null_count()); - assert_eq!("hello", string_array.value(0)); - assert_eq!("", string_array.value(1)); - assert_eq!("world", string_array.value(2)); - assert_eq!(5, string_array.value_offsets()[2]); - assert_eq!(5, string_array.value_length(2)); + fn test_string_array_builder_finish() { + _test_generic_string_array_builder_finish::() } #[test] - fn test_string_array_builder_append_option() { - let mut builder = StringBuilder::new(20); - builder.append_option(Some("hello")).unwrap(); - builder.append_option(None::<&str>).unwrap(); - builder.append_option(None::).unwrap(); - builder.append_option(Some("world")).unwrap(); - - let string_array = builder.finish(); - - assert_eq!(4, string_array.len()); - assert_eq!("hello", string_array.value(0)); - assert!(string_array.is_null(1)); - assert!(string_array.is_null(2)); - assert_eq!("world", string_array.value(3)); + fn test_large_string_array_builder_finish() { + _test_generic_string_array_builder_finish::() } } diff --git a/arrow/src/array/builder/map_builder.rs b/arrow/src/array/builder/map_builder.rs index 7c30218972e5..766e8a56b387 100644 --- a/arrow/src/array/builder/map_builder.rs +++ b/arrow/src/array/builder/map_builder.rs @@ -18,6 +18,7 @@ use std::any::Any; use std::sync::Arc; +use super::{ArrayBuilder, BufferBuilder, NullBufferBuilder}; use crate::array::array::Array; use crate::array::ArrayData; use crate::array::ArrayRef; @@ -25,18 +26,16 @@ use crate::array::MapArray; use crate::array::StructArray; use crate::datatypes::DataType; use crate::datatypes::Field; -use crate::error::{ArrowError, Result}; - -use super::{ArrayBuilder, BooleanBufferBuilder, BufferBuilder}; +use crate::error::ArrowError; +use crate::error::Result; #[derive(Debug)] pub struct MapBuilder { offsets_builder: BufferBuilder, - bitmap_builder: BooleanBufferBuilder, + null_buffer_builder: NullBufferBuilder, field_names: MapFieldNames, key_builder: K, value_builder: V, - len: i32, } #[derive(Debug, Clone)] @@ -78,11 +77,10 @@ impl MapBuilder { offsets_builder.append(len); Self { offsets_builder, - bitmap_builder: BooleanBufferBuilder::new(capacity), + null_buffer_builder: NullBufferBuilder::new(capacity), field_names: field_names.unwrap_or_default(), key_builder, value_builder, - len, } } @@ -95,6 +93,8 @@ impl MapBuilder { } /// Finish the current map array slot + /// + /// Returns an error if the key and values builders are in an inconsistent state. 
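The doc comment above describes the one remaining fallible path in `MapBuilder`: `append` returns an error when the key and value builders have diverged. A short sketch of that contract, using the builder types shown in this diff:

```rust
use arrow::array::{Int32Builder, MapBuilder, StringBuilder};

let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

// One well-formed entry: {"joe": 1}
builder.keys().append_value("joe");
builder.values().append_value(1);
builder.append(true).unwrap();

// A key without a matching value leaves the builders inconsistent,
// so the next `append` returns an error rather than panicking.
builder.keys().append_value("dangling");
assert!(builder.append(true).is_err());
```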
#[inline] pub fn append(&mut self, is_valid: bool) -> Result<()> { if self.key_builder.len() != self.value_builder.len() { @@ -105,14 +105,12 @@ impl MapBuilder { ))); } self.offsets_builder.append(self.key_builder.len() as i32); - self.bitmap_builder.append(is_valid); - self.len += 1; + self.null_buffer_builder.append(is_valid); Ok(()) } pub fn finish(&mut self) -> MapArray { let len = self.len(); - self.len = 0; // Build the keys let keys_arr = self @@ -143,8 +141,8 @@ impl MapBuilder { StructArray::from(vec![(keys_field, keys_arr), (values_field, values_arr)]); let offset_buffer = self.offsets_builder.finish(); - let null_bit_buffer = self.bitmap_builder.finish(); - self.offsets_builder.append(self.len); + let null_bit_buffer = self.null_buffer_builder.finish(); + self.offsets_builder.append(0); let map_field = Box::new(Field::new( self.field_names.entry.as_str(), struct_array.data_type().clone(), @@ -154,7 +152,7 @@ impl MapBuilder { .len(len) .add_buffer(offset_buffer) .add_child_data(struct_array.into_data()) - .null_bit_buffer(Some(null_bit_buffer)); + .null_bit_buffer(null_bit_buffer); let array_data = unsafe { array_data.build_unchecked() }; @@ -164,11 +162,11 @@ impl MapBuilder { impl ArrayBuilder for MapBuilder { fn len(&self) -> usize { - self.len as usize + self.null_buffer_builder.len() } fn is_empty(&self) -> bool { - self.len == 0 + self.len() == 0 } fn finish(&mut self) -> ArrayRef { @@ -203,22 +201,22 @@ mod tests { #[test] fn test_map_array_builder() { - let string_builder = StringBuilder::new(4); - let int_builder = Int32Builder::new(4); + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); let mut builder = MapBuilder::new(None, string_builder, int_builder); let string_builder = builder.keys(); - string_builder.append_value("joe").unwrap(); - string_builder.append_null().unwrap(); - string_builder.append_null().unwrap(); - string_builder.append_value("mark").unwrap(); + string_builder.append_value("joe"); + string_builder.append_null(); + string_builder.append_null(); + string_builder.append_value("mark"); let int_builder = builder.values(); - int_builder.append_value(1).unwrap(); - int_builder.append_value(2).unwrap(); - int_builder.append_null().unwrap(); - int_builder.append_value(4).unwrap(); + int_builder.append_value(1); + int_builder.append_value(2); + int_builder.append_null(); + int_builder.append_value(4); builder.append(true).unwrap(); builder.append(false).unwrap(); diff --git a/arrow/src/array/builder/mod.rs b/arrow/src/array/builder/mod.rs index 045a11648d52..c02acb32653f 100644 --- a/arrow/src/array/builder/mod.rs +++ b/arrow/src/array/builder/mod.rs @@ -30,6 +30,7 @@ mod generic_binary_builder; mod generic_list_builder; mod generic_string_builder; mod map_builder; +mod null_buffer_builder; mod primitive_builder; mod primitive_dictionary_builder; mod string_dictionary_builder; @@ -45,14 +46,15 @@ use super::ArrayRef; pub use boolean_buffer_builder::BooleanBufferBuilder; pub use boolean_builder::BooleanBuilder; pub use buffer_builder::BufferBuilder; +pub use decimal_builder::Decimal128Builder; pub use decimal_builder::Decimal256Builder; -pub use decimal_builder::DecimalBuilder; pub use fixed_size_binary_builder::FixedSizeBinaryBuilder; pub use fixed_size_list_builder::FixedSizeListBuilder; pub use generic_binary_builder::GenericBinaryBuilder; pub use generic_list_builder::GenericListBuilder; pub use generic_string_builder::GenericStringBuilder; -pub use map_builder::MapBuilder; +pub use map_builder::{MapBuilder, 
MapFieldNames}; +use null_buffer_builder::NullBufferBuilder; pub use primitive_builder::PrimitiveBuilder; pub use primitive_dictionary_builder::PrimitiveDictionaryBuilder; pub use string_dictionary_builder::StringDictionaryBuilder; @@ -71,9 +73,9 @@ pub use union_builder::UnionBuilder; /// # fn main() -> std::result::Result<(), ArrowError> { /// // Create /// let mut data_builders: Vec> = vec![ -/// Box::new(Float64Builder::new(1024)), -/// Box::new(Int64Builder::new(1024)), -/// Box::new(StringBuilder::new(1024)), +/// Box::new(Float64Builder::new()), +/// Box::new(Int64Builder::new()), +/// Box::new(StringBuilder::new()), /// ]; /// /// // Fill @@ -81,17 +83,17 @@ pub use union_builder::UnionBuilder; /// .as_any_mut() /// .downcast_mut::() /// .unwrap() -/// .append_value(3.14)?; +/// .append_value(3.14); /// data_builders[1] /// .as_any_mut() /// .downcast_mut::() /// .unwrap() -/// .append_value(-1)?; +/// .append_value(-1); /// data_builders[2] /// .as_any_mut() /// .downcast_mut::() /// .unwrap() -/// .append_value("🍎")?; +/// .append_value("🍎"); /// /// // Finish /// let array_refs: Vec = data_builders diff --git a/arrow/src/array/builder/null_buffer_builder.rs b/arrow/src/array/builder/null_buffer_builder.rs new file mode 100644 index 000000000000..ef2e4c50ab9c --- /dev/null +++ b/arrow/src/array/builder/null_buffer_builder.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::buffer::Buffer; + +use super::BooleanBufferBuilder; + +/// Builder for creating the null bit buffer. +/// This builder only materializes the buffer when we append `false`. +/// If you only append `true`s to the builder, what you get will be +/// `None` when calling [`finish`](#method.finish). +/// This optimization is **very** important for the performance. +#[derive(Debug)] +pub(super) struct NullBufferBuilder { + bitmap_builder: Option, + /// Store the length of the buffer before materializing. + len: usize, + capacity: usize, +} + +impl NullBufferBuilder { + /// Creates a new empty builder. + /// `capacity` is the number of bits in the null buffer. + pub fn new(capacity: usize) -> Self { + Self { + bitmap_builder: None, + len: 0, + capacity, + } + } + + /// Appends `n` `true`s into the builder + /// to indicate that these `n` items are not nulls. + #[inline] + pub fn append_n_non_nulls(&mut self, n: usize) { + if let Some(buf) = self.bitmap_builder.as_mut() { + buf.append_n(n, true) + } else { + self.len += n; + } + } + + /// Appends a `true` into the builder + /// to indicate that this item is not null. 
+ #[inline] + pub fn append_non_null(&mut self) { + if let Some(buf) = self.bitmap_builder.as_mut() { + buf.append(true) + } else { + self.len += 1; + } + } + + /// Appends `n` `false`s into the builder + /// to indicate that these `n` items are nulls. + #[inline] + pub fn append_n_nulls(&mut self, n: usize) { + self.materialize_if_needed(); + self.bitmap_builder.as_mut().unwrap().append_n(n, false); + } + + /// Appends a `false` into the builder + /// to indicate that this item is null. + #[inline] + pub fn append_null(&mut self) { + self.materialize_if_needed(); + self.bitmap_builder.as_mut().unwrap().append(false); + } + + /// Appends a boolean value into the builder. + #[inline] + pub fn append(&mut self, not_null: bool) { + if not_null { + self.append_non_null() + } else { + self.append_null() + } + } + + /// Appends a boolean slice into the builder + /// to indicate the validations of these items. + pub fn append_slice(&mut self, slice: &[bool]) { + if slice.iter().any(|v| !v) { + self.materialize_if_needed() + } + if let Some(buf) = self.bitmap_builder.as_mut() { + buf.append_slice(slice) + } else { + self.len += slice.len(); + } + } + + /// Builds the null buffer and resets the builder. + /// Returns `None` if the builder only contains `true`s. + pub fn finish(&mut self) -> Option { + let buf = self.bitmap_builder.as_mut().map(|b| b.finish()); + self.bitmap_builder = None; + self.len = 0; + buf + } + + #[inline] + fn materialize_if_needed(&mut self) { + if self.bitmap_builder.is_none() { + self.materialize() + } + } + + #[cold] + fn materialize(&mut self) { + if self.bitmap_builder.is_none() { + let mut b = BooleanBufferBuilder::new(self.len.max(self.capacity)); + b.append_n(self.len, true); + self.bitmap_builder = Some(b); + } + } +} + +impl NullBufferBuilder { + pub fn len(&self) -> usize { + if let Some(b) = &self.bitmap_builder { + b.len() + } else { + self.len + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_null_buffer_builder() { + let mut builder = NullBufferBuilder::new(0); + builder.append_null(); + builder.append_non_null(); + builder.append_n_nulls(2); + builder.append_n_non_nulls(2); + assert_eq!(6, builder.len()); + + let buf = builder.finish().unwrap(); + assert_eq!(Buffer::from(&[0b110010_u8]), buf); + } + + #[test] + fn test_null_buffer_builder_all_nulls() { + let mut builder = NullBufferBuilder::new(0); + builder.append_null(); + builder.append_n_nulls(2); + builder.append_slice(&[false, false, false]); + assert_eq!(6, builder.len()); + + let buf = builder.finish().unwrap(); + assert_eq!(Buffer::from(&[0b0_u8]), buf); + } + + #[test] + fn test_null_buffer_builder_no_null() { + let mut builder = NullBufferBuilder::new(0); + builder.append_non_null(); + builder.append_n_non_nulls(2); + builder.append_slice(&[true, true, true]); + assert_eq!(6, builder.len()); + + let buf = builder.finish(); + assert!(buf.is_none()); + } + + #[test] + fn test_null_buffer_builder_reset() { + let mut builder = NullBufferBuilder::new(0); + builder.append_slice(&[true, false, true]); + builder.finish(); + assert!(builder.is_empty()); + + builder.append_slice(&[true, true, true]); + assert!(builder.finish().is_none()); + assert!(builder.is_empty()); + + builder.append_slice(&[true, true, false, true]); + + let buf = builder.finish().unwrap(); + assert_eq!(Buffer::from(&[0b1011_u8]), buf); + } +} diff --git a/arrow/src/array/builder/primitive_builder.rs b/arrow/src/array/builder/primitive_builder.rs 
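`NullBufferBuilder` above is crate-private, but its lazy materialization is observable through the public builders: an array built without any nulls ends up with no validity buffer at all. A sketch of that effect, assuming the updated infallible builder API from this diff:

```rust
use arrow::array::{Array, Int32Builder};

let mut builder = Int32Builder::new();
builder.append_slice(&[1, 2, 3]);
let all_valid = builder.finish();
// No null was ever appended, so no validity bitmap was materialized.
assert_eq!(0, all_valid.null_count());
assert!(all_valid.data().null_buffer().is_none());

builder.append_value(4);
builder.append_null();
let with_null = builder.finish();
// Appending a null forces the bitmap to be materialized.
assert_eq!(1, with_null.null_count());
assert!(with_null.data().null_buffer().is_some());
```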
index ec1b408edfd9..38c8b4471477 100644 --- a/arrow/src/array/builder/primitive_builder.rs +++ b/arrow/src/array/builder/primitive_builder.rs @@ -22,17 +22,14 @@ use crate::array::ArrayData; use crate::array::ArrayRef; use crate::array::PrimitiveArray; use crate::datatypes::ArrowPrimitiveType; -use crate::error::{ArrowError, Result}; -use super::{ArrayBuilder, BooleanBufferBuilder, BufferBuilder}; +use super::{ArrayBuilder, BufferBuilder, NullBufferBuilder}; /// Array builder for fixed-width primitive types #[derive(Debug)] pub struct PrimitiveBuilder { values_builder: BufferBuilder, - /// We only materialize the builder when we add `false`. - /// This optimization is **very** important for performance of `StringBuilder`. - bitmap_builder: Option, + null_buffer_builder: NullBufferBuilder, } impl ArrayBuilder for PrimitiveBuilder { @@ -67,12 +64,23 @@ impl ArrayBuilder for PrimitiveBuilder { } } +impl Default for PrimitiveBuilder { + fn default() -> Self { + Self::new() + } +} + impl PrimitiveBuilder { /// Creates a new primitive array builder - pub fn new(capacity: usize) -> Self { + pub fn new() -> Self { + Self::with_capacity(1024) + } + + /// Creates a new primitive array builder with capacity no of items + pub fn with_capacity(capacity: usize) -> Self { Self { values_builder: BufferBuilder::::new(capacity), - bitmap_builder: None, + null_buffer_builder: NullBufferBuilder::new(capacity), } } @@ -83,71 +91,50 @@ impl PrimitiveBuilder { /// Appends a value of type `T` into the builder #[inline] - pub fn append_value(&mut self, v: T::Native) -> Result<()> { - if let Some(b) = self.bitmap_builder.as_mut() { - b.append(true); - } + pub fn append_value(&mut self, v: T::Native) { + self.null_buffer_builder.append_non_null(); self.values_builder.append(v); - Ok(()) } /// Appends a null slot into the builder #[inline] - pub fn append_null(&mut self) -> Result<()> { - self.materialize_bitmap_builder(); - self.bitmap_builder.as_mut().unwrap().append(false); + pub fn append_null(&mut self) { + self.null_buffer_builder.append_null(); self.values_builder.advance(1); - Ok(()) } #[inline] - pub fn append_nulls(&mut self, n: usize) -> Result<()> { - self.materialize_bitmap_builder(); - self.bitmap_builder.as_mut().unwrap().append_n(n, false); + pub fn append_nulls(&mut self, n: usize) { + self.null_buffer_builder.append_n_nulls(n); self.values_builder.advance(n); - Ok(()) } /// Appends an `Option` into the builder #[inline] - pub fn append_option(&mut self, v: Option) -> Result<()> { + pub fn append_option(&mut self, v: Option) { match v { - None => self.append_null()?, - Some(v) => self.append_value(v)?, + None => self.append_null(), + Some(v) => self.append_value(v), }; - Ok(()) } /// Appends a slice of type `T` into the builder #[inline] - pub fn append_slice(&mut self, v: &[T::Native]) -> Result<()> { - if let Some(b) = self.bitmap_builder.as_mut() { - b.append_n(v.len(), true); - } + pub fn append_slice(&mut self, v: &[T::Native]) { + self.null_buffer_builder.append_n_non_nulls(v.len()); self.values_builder.append_slice(v); - Ok(()) } /// Appends values from a slice of type `T` and a validity boolean slice #[inline] - pub fn append_values( - &mut self, - values: &[T::Native], - is_valid: &[bool], - ) -> Result<()> { - if values.len() != is_valid.len() { - return Err(ArrowError::InvalidArgumentError( - "Value and validity lengths must be equal".to_string(), - )); - } - if is_valid.iter().any(|v| !*v) { - self.materialize_bitmap_builder(); - } - if let Some(b) = self.bitmap_builder.as_mut() { - 
b.append_slice(is_valid); - } + pub fn append_values(&mut self, values: &[T::Native], is_valid: &[bool]) { + assert_eq!( + values.len(), + is_valid.len(), + "Value and validity lengths must be equal" + ); + self.null_buffer_builder.append_slice(is_valid); self.values_builder.append_slice(values); - Ok(()) } /// Appends values from a trusted length iterator. @@ -159,52 +146,30 @@ impl PrimitiveBuilder { pub unsafe fn append_trusted_len_iter( &mut self, iter: impl IntoIterator, - ) -> Result<()> { + ) { let iter = iter.into_iter(); let len = iter .size_hint() .1 .expect("append_trusted_len_iter requires an upper bound"); - if let Some(b) = self.bitmap_builder.as_mut() { - b.append_n(len, true); - } + self.null_buffer_builder.append_n_non_nulls(len); self.values_builder.append_trusted_len_iter(iter); - Ok(()) } - /// Builds the `PrimitiveArray` and reset this builder. + /// Builds the [`PrimitiveArray`] and reset this builder. pub fn finish(&mut self) -> PrimitiveArray { let len = self.len(); - let null_bit_buffer = self.bitmap_builder.as_mut().map(|b| b.finish()); - let null_count = len - - null_bit_buffer - .as_ref() - .map(|b| b.count_set_bits()) - .unwrap_or(len); + let null_bit_buffer = self.null_buffer_builder.finish(); let builder = ArrayData::builder(T::DATA_TYPE) .len(len) .add_buffer(self.values_builder.finish()) - .null_bit_buffer(if null_count > 0 { - null_bit_buffer - } else { - None - }); + .null_bit_buffer(null_bit_buffer); let array_data = unsafe { builder.build_unchecked() }; PrimitiveArray::::from(array_data) } - fn materialize_bitmap_builder(&mut self) { - if self.bitmap_builder.is_some() { - return; - } - let mut b = BooleanBufferBuilder::new(0); - b.reserve(self.values_builder.capacity()); - b.append_n(self.values_builder.len(), true); - self.bitmap_builder = Some(b); - } - /// Returns the current values buffer as a slice pub fn values_slice(&self) -> &[T::Native] { self.values_builder.as_slice() @@ -216,16 +181,18 @@ mod tests { use super::*; use crate::array::Array; + use crate::array::BooleanArray; use crate::array::Date32Array; use crate::array::Int32Array; use crate::array::Int32Builder; use crate::array::TimestampSecondArray; + use crate::buffer::Buffer; #[test] fn test_primitive_array_builder_i32() { let mut builder = Int32Array::builder(5); for i in 0..5 { - builder.append_value(i).unwrap(); + builder.append_value(i); } let arr = builder.finish(); assert_eq!(5, arr.len()); @@ -241,7 +208,7 @@ mod tests { #[test] fn test_primitive_array_builder_i32_append_iter() { let mut builder = Int32Array::builder(5); - unsafe { builder.append_trusted_len_iter(0..5) }.unwrap(); + unsafe { builder.append_trusted_len_iter(0..5) }; let arr = builder.finish(); assert_eq!(5, arr.len()); assert_eq!(0, arr.offset()); @@ -256,7 +223,7 @@ mod tests { #[test] fn test_primitive_array_builder_i32_append_nulls() { let mut builder = Int32Array::builder(5); - builder.append_nulls(5).unwrap(); + builder.append_nulls(5); let arr = builder.finish(); assert_eq!(5, arr.len()); assert_eq!(0, arr.offset()); @@ -271,7 +238,7 @@ mod tests { fn test_primitive_array_builder_date32() { let mut builder = Date32Array::builder(5); for i in 0..5 { - builder.append_value(i).unwrap(); + builder.append_value(i); } let arr = builder.finish(); assert_eq!(5, arr.len()); @@ -288,7 +255,7 @@ mod tests { fn test_primitive_array_builder_timestamp_second() { let mut builder = TimestampSecondArray::builder(5); for i in 0..5 { - builder.append_value(i).unwrap(); + builder.append_value(i); } let arr = builder.finish(); 
assert_eq!(5, arr.len()); @@ -301,16 +268,41 @@ mod tests { } } + #[test] + fn test_primitive_array_builder_bool() { + // 00000010 01001000 + let buf = Buffer::from([72_u8, 2_u8]); + let mut builder = BooleanArray::builder(10); + for i in 0..10 { + if i == 3 || i == 6 || i == 9 { + builder.append_value(true); + } else { + builder.append_value(false); + } + } + + let arr = builder.finish(); + assert_eq!(&buf, arr.values()); + assert_eq!(10, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + for i in 0..10 { + assert!(!arr.is_null(i)); + assert!(arr.is_valid(i)); + assert_eq!(i == 3 || i == 6 || i == 9, arr.value(i), "failed at {}", i) + } + } + #[test] fn test_primitive_array_builder_append_option() { let arr1 = Int32Array::from(vec![Some(0), None, Some(2), None, Some(4)]); let mut builder = Int32Array::builder(5); - builder.append_option(Some(0)).unwrap(); - builder.append_option(None).unwrap(); - builder.append_option(Some(2)).unwrap(); - builder.append_option(None).unwrap(); - builder.append_option(Some(4)).unwrap(); + builder.append_option(Some(0)); + builder.append_option(None); + builder.append_option(Some(2)); + builder.append_option(None); + builder.append_option(Some(4)); let arr2 = builder.finish(); assert_eq!(arr1.len(), arr2.len()); @@ -330,11 +322,11 @@ mod tests { let arr1 = Int32Array::from(vec![Some(0), Some(2), None, None, Some(4)]); let mut builder = Int32Array::builder(5); - builder.append_value(0).unwrap(); - builder.append_value(2).unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); - builder.append_value(4).unwrap(); + builder.append_value(0); + builder.append_value(2); + builder.append_null(); + builder.append_null(); + builder.append_value(4); let arr2 = builder.finish(); assert_eq!(arr1.len(), arr2.len()); @@ -354,10 +346,10 @@ mod tests { let arr1 = Int32Array::from(vec![Some(0), Some(2), None, None, Some(4)]); let mut builder = Int32Array::builder(5); - builder.append_slice(&[0, 2]).unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); - builder.append_value(4).unwrap(); + builder.append_slice(&[0, 2]); + builder.append_null(); + builder.append_null(); + builder.append_value(4); let arr2 = builder.finish(); assert_eq!(arr1.len(), arr2.len()); @@ -374,13 +366,13 @@ mod tests { #[test] fn test_primitive_array_builder_finish() { - let mut builder = Int32Builder::new(5); - builder.append_slice(&[2, 4, 6, 8]).unwrap(); + let mut builder = Int32Builder::new(); + builder.append_slice(&[2, 4, 6, 8]); let mut arr = builder.finish(); assert_eq!(4, arr.len()); assert_eq!(0, builder.len()); - builder.append_slice(&[1, 3, 5, 7, 9]).unwrap(); + builder.append_slice(&[1, 3, 5, 7, 9]); arr = builder.finish(); assert_eq!(5, arr.len()); assert_eq!(0, builder.len()); diff --git a/arrow/src/array/builder/primitive_dictionary_builder.rs b/arrow/src/array/builder/primitive_dictionary_builder.rs index 5cbd81720a86..71223c688283 100644 --- a/arrow/src/array/builder/primitive_dictionary_builder.rs +++ b/arrow/src/array/builder/primitive_dictionary_builder.rs @@ -16,6 +16,7 @@ // under the License. 
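Summarizing the `PrimitiveBuilder` changes above: `new()` no longer takes a capacity (use `with_capacity`), the append methods are infallible, and `append_values` now asserts that the value and validity slices have equal length instead of returning an error. A minimal sketch:

```rust
use arrow::array::{Array, Int32Builder};

let mut builder = Int32Builder::with_capacity(8);
builder.append_value(1);
builder.append_option(None);
builder.append_slice(&[2, 3]);
// `append_values` takes values plus a validity mask; mismatched slice
// lengths now panic instead of returning an `Err`.
builder.append_values(&[4, 5], &[true, false]);

let array = builder.finish();
assert_eq!(6, array.len());
assert_eq!(2, array.null_count());
```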
use std::any::Any; +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::Arc; @@ -26,6 +27,26 @@ use crate::error::{ArrowError, Result}; use super::ArrayBuilder; use super::PrimitiveBuilder; +/// Wraps a type implementing `ToByteSlice` implementing `Hash` and `Eq` for it +/// +/// This is necessary to handle types such as f32, which don't natively implement these +#[derive(Debug)] +struct Value(T); + +impl std::hash::Hash for Value { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + self.0.to_byte_slice().eq(other.0.to_byte_slice()) + } +} + +impl Eq for Value {} + /// Array builder for `DictionaryArray`. For example to map a set of byte indices /// to f32 values. Note that the use of a `HashMap` here will not scale to very large /// arrays or result in an ordered dictionary. @@ -39,11 +60,11 @@ use super::PrimitiveBuilder; /// }; /// use arrow::datatypes::{UInt8Type, UInt32Type}; /// -/// let key_builder = PrimitiveBuilder::::new(3); -/// let value_builder = PrimitiveBuilder::::new(2); +/// let key_builder = PrimitiveBuilder::::with_capacity(3); +/// let value_builder = PrimitiveBuilder::::with_capacity(2); /// let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); /// builder.append(12345678).unwrap(); -/// builder.append_null().unwrap(); +/// builder.append_null(); /// builder.append(22345678).unwrap(); /// let array = builder.finish(); /// @@ -71,7 +92,7 @@ where { keys_builder: PrimitiveBuilder, values_builder: PrimitiveBuilder, - map: HashMap, K::Native>, + map: HashMap, K::Native>, } impl PrimitiveDictionaryBuilder @@ -138,23 +159,24 @@ where /// value is appended to the values array. #[inline] pub fn append(&mut self, value: V::Native) -> Result { - if let Some(&key) = self.map.get(value.to_byte_slice()) { - // Append existing value. - self.keys_builder.append_value(key)?; - Ok(key) - } else { - // Append new value. - let key = K::Native::from_usize(self.values_builder.len()) - .ok_or(ArrowError::DictionaryKeyOverflowError)?; - self.values_builder.append_value(value)?; - self.keys_builder.append_value(key as K::Native)?; - self.map.insert(value.to_byte_slice().into(), key); - Ok(key) - } + let key = match self.map.entry(Value(value)) { + Entry::Vacant(vacant) => { + // Append new value. 
+ let key = K::Native::from_usize(self.values_builder.len()) + .ok_or(ArrowError::DictionaryKeyOverflowError)?; + self.values_builder.append_value(value); + vacant.insert(key); + key + } + Entry::Occupied(o) => *o.get(), + }; + + self.keys_builder.append_value(key); + Ok(key) } #[inline] - pub fn append_null(&mut self) -> Result<()> { + pub fn append_null(&mut self) { self.keys_builder.append_null() } @@ -189,11 +211,11 @@ mod tests { #[test] fn test_primitive_dictionary_builder() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(12345678).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(22345678).unwrap(); let array = builder.finish(); @@ -217,8 +239,8 @@ mod tests { #[test] #[should_panic(expected = "DictionaryKeyOverflowError")] fn test_primitive_dictionary_overflow() { - let key_builder = PrimitiveBuilder::::new(257); - let value_builder = PrimitiveBuilder::::new(257); + let key_builder = PrimitiveBuilder::::with_capacity(257); + let value_builder = PrimitiveBuilder::::with_capacity(257); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); // 256 unique keys. for i in 0..256 { diff --git a/arrow/src/array/builder/string_dictionary_builder.rs b/arrow/src/array/builder/string_dictionary_builder.rs index 77b2b23160cc..6ad4e9075524 100644 --- a/arrow/src/array/builder/string_dictionary_builder.rs +++ b/arrow/src/array/builder/string_dictionary_builder.rs @@ -42,13 +42,13 @@ use std::sync::Arc; /// // Create a dictionary array indexed by bytes whose values are Strings. /// // It can thus hold up to 256 distinct string values. 
/// -/// let key_builder = PrimitiveBuilder::::new(100); -/// let value_builder = StringBuilder::new(100); +/// let key_builder = PrimitiveBuilder::::with_capacity(100); +/// let value_builder = StringBuilder::new(); /// let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); /// /// // The builder builds the dictionary value by value /// builder.append("abc").unwrap(); -/// builder.append_null().unwrap(); +/// builder.append_null(); /// builder.append("def").unwrap(); /// builder.append("def").unwrap(); /// builder.append("abc").unwrap(); @@ -111,9 +111,9 @@ where /// /// let dictionary_values = StringArray::from(vec![None, Some("abc"), Some("def")]); /// - /// let mut builder = StringDictionaryBuilder::new_with_dictionary(PrimitiveBuilder::::new(3), &dictionary_values).unwrap(); + /// let mut builder = StringDictionaryBuilder::new_with_dictionary(PrimitiveBuilder::::with_capacity(3), &dictionary_values).unwrap(); /// builder.append("def").unwrap(); - /// builder.append_null().unwrap(); + /// builder.append_null(); /// builder.append("abc").unwrap(); /// /// let dictionary_array = builder.finish(); @@ -137,7 +137,7 @@ where for (idx, maybe_value) in dictionary_values.iter().enumerate() { match maybe_value { Some(value) => { - let hash = compute_hash(&state, value.as_bytes()); + let hash = state.hash_one(value.as_bytes()); let key = K::Native::from_usize(idx) .ok_or(ArrowError::DictionaryKeyOverflowError)?; @@ -149,13 +149,13 @@ where if let RawEntryMut::Vacant(v) = entry { v.insert_with_hasher(hash, key, (), |key| { - compute_hash(&state, get_bytes(&values_builder, key)) + state.hash_one(get_bytes(&values_builder, key)) }); } - values_builder.append_value(value)?; + values_builder.append_value(value); } - None => values_builder.append_null()?, + None => values_builder.append_null(), } } @@ -210,12 +210,14 @@ where /// Append a primitive value to the array. Return an existing index /// if already present in the values array or a new index if the /// value is appended to the values array. + /// + /// Returns an error if the new index would overflow the key type. 
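A sketch of the overflow case documented above, assuming `UInt8Type` keys (which can address at most 256 distinct dictionary values); the error type follows the `DictionaryKeyOverflowError` used elsewhere in this diff:

```rust
use arrow::array::{PrimitiveBuilder, StringBuilder, StringDictionaryBuilder};
use arrow::datatypes::UInt8Type;

let key_builder = PrimitiveBuilder::<UInt8Type>::with_capacity(257);
let value_builder = StringBuilder::new();
let mut builder = StringDictionaryBuilder::new(key_builder, value_builder);

// A u8 key space can index at most 256 distinct dictionary values.
for i in 0..256 {
    builder.append(i.to_string()).unwrap();
}
// The 257th distinct value cannot be assigned a u8 key.
assert!(builder.append("one too many").is_err());
```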
pub fn append(&mut self, value: impl AsRef) -> Result { let value = value.as_ref(); let state = &self.state; let storage = &mut self.values_builder; - let hash = compute_hash(state, value.as_bytes()); + let hash = state.hash_one(value.as_bytes()); let entry = self .dedup @@ -226,24 +228,24 @@ where RawEntryMut::Occupied(entry) => *entry.into_key(), RawEntryMut::Vacant(entry) => { let index = storage.len(); - storage.append_value(value)?; + storage.append_value(value); let key = K::Native::from_usize(index) .ok_or(ArrowError::DictionaryKeyOverflowError)?; *entry .insert_with_hasher(hash, key, (), |key| { - compute_hash(state, get_bytes(storage, key)) + state.hash_one(get_bytes(storage, key)) }) .0 } }; - self.keys_builder.append_value(key)?; + self.keys_builder.append_value(key); Ok(key) } #[inline] - pub fn append_null(&mut self) -> Result<()> { + pub fn append_null(&mut self) { self.keys_builder.append_null() } @@ -266,13 +268,6 @@ where } } -fn compute_hash(hasher: &ahash::RandomState, value: &[u8]) -> u64 { - use std::hash::{BuildHasher, Hash, Hasher}; - let mut state = hasher.build_hasher(); - value.hash(&mut state); - state.finish() -} - fn get_bytes<'a, K: ArrowNativeType>(values: &'a StringBuilder, key: &K) -> &'a [u8] { let offsets = values.offsets_slice(); let values = values.values_slice(); @@ -295,11 +290,11 @@ mod tests { #[test] fn test_string_dictionary_builder() { - let key_builder = PrimitiveBuilder::::new(5); - let value_builder = StringBuilder::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(5); + let value_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("def").unwrap(); builder.append("def").unwrap(); builder.append("abc").unwrap(); @@ -322,12 +317,12 @@ mod tests { fn test_string_dictionary_builder_with_existing_dictionary() { let dictionary = StringArray::from(vec![None, Some("def"), Some("abc")]); - let key_builder = PrimitiveBuilder::::new(6); + let key_builder = PrimitiveBuilder::::with_capacity(6); let mut builder = StringDictionaryBuilder::new_with_dictionary(key_builder, &dictionary) .unwrap(); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("def").unwrap(); builder.append("def").unwrap(); builder.append("abc").unwrap(); @@ -354,12 +349,12 @@ mod tests { let dictionary: Vec> = vec![None]; let dictionary = StringArray::from(dictionary); - let key_builder = PrimitiveBuilder::::new(4); + let key_builder = PrimitiveBuilder::::with_capacity(4); let mut builder = StringDictionaryBuilder::new_with_dictionary(key_builder, &dictionary) .unwrap(); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("def").unwrap(); builder.append("abc").unwrap(); let array = builder.finish(); diff --git a/arrow/src/array/builder/struct_builder.rs b/arrow/src/array/builder/struct_builder.rs index 206eb17c242d..c5db09119e08 100644 --- a/arrow/src/array/builder/struct_builder.rs +++ b/arrow/src/array/builder/struct_builder.rs @@ -19,10 +19,12 @@ use std::any::Any; use std::fmt; use std::sync::Arc; +use crate::array::builder::decimal_builder::Decimal128Builder; use crate::array::*; use crate::datatypes::DataType; use crate::datatypes::Field; -use crate::error::Result; + +use super::NullBufferBuilder; /// Array builder for Struct types. 
/// @@ -31,16 +33,15 @@ use crate::error::Result; pub struct StructBuilder { fields: Vec, field_builders: Vec>, - bitmap_builder: BooleanBufferBuilder, - len: usize, + null_buffer_builder: NullBufferBuilder, } impl fmt::Debug for StructBuilder { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("StructBuilder") .field("fields", &self.fields) - .field("bitmap_builder", &self.bitmap_builder) - .field("len", &self.len) + .field("bitmap_builder", &self.null_buffer_builder) + .field("len", &self.len()) .finish() } } @@ -52,12 +53,12 @@ impl ArrayBuilder for StructBuilder { /// the caller's responsibility to maintain the consistency that all the child field /// builder should have the equal number of elements. fn len(&self) -> usize { - self.len + self.null_buffer_builder.len() } /// Returns whether the number of array slots is zero fn is_empty(&self) -> bool { - self.len == 0 + self.len() == 0 } /// Builds the array. @@ -95,71 +96,71 @@ impl ArrayBuilder for StructBuilder { pub fn make_builder(datatype: &DataType, capacity: usize) -> Box { match datatype { DataType::Null => unimplemented!(), - DataType::Boolean => Box::new(BooleanBuilder::new(capacity)), - DataType::Int8 => Box::new(Int8Builder::new(capacity)), - DataType::Int16 => Box::new(Int16Builder::new(capacity)), - DataType::Int32 => Box::new(Int32Builder::new(capacity)), - DataType::Int64 => Box::new(Int64Builder::new(capacity)), - DataType::UInt8 => Box::new(UInt8Builder::new(capacity)), - DataType::UInt16 => Box::new(UInt16Builder::new(capacity)), - DataType::UInt32 => Box::new(UInt32Builder::new(capacity)), - DataType::UInt64 => Box::new(UInt64Builder::new(capacity)), - DataType::Float32 => Box::new(Float32Builder::new(capacity)), - DataType::Float64 => Box::new(Float64Builder::new(capacity)), - DataType::Binary => Box::new(BinaryBuilder::new(capacity)), + DataType::Boolean => Box::new(BooleanBuilder::with_capacity(capacity)), + DataType::Int8 => Box::new(Int8Builder::with_capacity(capacity)), + DataType::Int16 => Box::new(Int16Builder::with_capacity(capacity)), + DataType::Int32 => Box::new(Int32Builder::with_capacity(capacity)), + DataType::Int64 => Box::new(Int64Builder::with_capacity(capacity)), + DataType::UInt8 => Box::new(UInt8Builder::with_capacity(capacity)), + DataType::UInt16 => Box::new(UInt16Builder::with_capacity(capacity)), + DataType::UInt32 => Box::new(UInt32Builder::with_capacity(capacity)), + DataType::UInt64 => Box::new(UInt64Builder::with_capacity(capacity)), + DataType::Float32 => Box::new(Float32Builder::with_capacity(capacity)), + DataType::Float64 => Box::new(Float64Builder::with_capacity(capacity)), + DataType::Binary => Box::new(BinaryBuilder::with_capacity(1024, capacity)), DataType::FixedSizeBinary(len) => { - Box::new(FixedSizeBinaryBuilder::new(capacity, *len)) - } - DataType::Decimal(precision, scale) => { - Box::new(DecimalBuilder::new(capacity, *precision, *scale)) + Box::new(FixedSizeBinaryBuilder::with_capacity(capacity, *len)) } - DataType::Utf8 => Box::new(StringBuilder::new(capacity)), - DataType::Date32 => Box::new(Date32Builder::new(capacity)), - DataType::Date64 => Box::new(Date64Builder::new(capacity)), + DataType::Decimal128(precision, scale) => Box::new( + Decimal128Builder::with_capacity(capacity, *precision, *scale), + ), + DataType::Utf8 => Box::new(StringBuilder::with_capacity(1024, capacity)), + DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)), + DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)), 
DataType::Time32(TimeUnit::Second) => { - Box::new(Time32SecondBuilder::new(capacity)) + Box::new(Time32SecondBuilder::with_capacity(capacity)) } DataType::Time32(TimeUnit::Millisecond) => { - Box::new(Time32MillisecondBuilder::new(capacity)) + Box::new(Time32MillisecondBuilder::with_capacity(capacity)) } DataType::Time64(TimeUnit::Microsecond) => { - Box::new(Time64MicrosecondBuilder::new(capacity)) + Box::new(Time64MicrosecondBuilder::with_capacity(capacity)) } DataType::Time64(TimeUnit::Nanosecond) => { - Box::new(Time64NanosecondBuilder::new(capacity)) + Box::new(Time64NanosecondBuilder::with_capacity(capacity)) } DataType::Timestamp(TimeUnit::Second, _) => { - Box::new(TimestampSecondBuilder::new(capacity)) + Box::new(TimestampSecondBuilder::with_capacity(capacity)) } DataType::Timestamp(TimeUnit::Millisecond, _) => { - Box::new(TimestampMillisecondBuilder::new(capacity)) + Box::new(TimestampMillisecondBuilder::with_capacity(capacity)) } DataType::Timestamp(TimeUnit::Microsecond, _) => { - Box::new(TimestampMicrosecondBuilder::new(capacity)) + Box::new(TimestampMicrosecondBuilder::with_capacity(capacity)) } DataType::Timestamp(TimeUnit::Nanosecond, _) => { - Box::new(TimestampNanosecondBuilder::new(capacity)) + Box::new(TimestampNanosecondBuilder::with_capacity(capacity)) } DataType::Interval(IntervalUnit::YearMonth) => { - Box::new(IntervalYearMonthBuilder::new(capacity)) + Box::new(IntervalYearMonthBuilder::with_capacity(capacity)) } DataType::Interval(IntervalUnit::DayTime) => { - Box::new(IntervalDayTimeBuilder::new(capacity)) + Box::new(IntervalDayTimeBuilder::with_capacity(capacity)) } DataType::Interval(IntervalUnit::MonthDayNano) => { - Box::new(IntervalMonthDayNanoBuilder::new(capacity)) + Box::new(IntervalMonthDayNanoBuilder::with_capacity(capacity)) } DataType::Duration(TimeUnit::Second) => { - Box::new(DurationSecondBuilder::new(capacity)) + Box::new(DurationSecondBuilder::with_capacity(capacity)) } DataType::Duration(TimeUnit::Millisecond) => { - Box::new(DurationMillisecondBuilder::new(capacity)) + Box::new(DurationMillisecondBuilder::with_capacity(capacity)) } DataType::Duration(TimeUnit::Microsecond) => { - Box::new(DurationMicrosecondBuilder::new(capacity)) + Box::new(DurationMicrosecondBuilder::with_capacity(capacity)) } DataType::Duration(TimeUnit::Nanosecond) => { - Box::new(DurationNanosecondBuilder::new(capacity)) + Box::new(DurationNanosecondBuilder::with_capacity(capacity)) } DataType::Struct(fields) => { Box::new(StructBuilder::from_fields(fields.clone(), capacity)) @@ -173,8 +174,7 @@ impl StructBuilder { Self { fields, field_builders, - bitmap_builder: BooleanBufferBuilder::new(0), - len: 0, + null_buffer_builder: NullBufferBuilder::new(0), } } @@ -201,40 +201,48 @@ impl StructBuilder { /// Appends an element (either null or non-null) to the struct. The actual elements /// should be appended for each child sub-array in a consistent way. #[inline] - pub fn append(&mut self, is_valid: bool) -> Result<()> { - self.bitmap_builder.append(is_valid); - self.len += 1; - Ok(()) + pub fn append(&mut self, is_valid: bool) { + self.null_buffer_builder.append(is_valid); } /// Appends a null element to the struct. #[inline] - pub fn append_null(&mut self) -> Result<()> { + pub fn append_null(&mut self) { self.append(false) } /// Builds the `StructArray` and reset this builder. 
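For the `StructBuilder` changes above (infallible `append`, validity tracked by the new `NullBufferBuilder`, and the consistency checks added in `validate_content` below), a brief usage sketch modeled on the tests in this diff:

```rust
use arrow::array::{Array, ArrayBuilder, Int32Builder, StringBuilder, StructBuilder};
use arrow::datatypes::{DataType, Field};

let fields = vec![
    Field::new("name", DataType::Utf8, true),
    Field::new("age", DataType::Int32, true),
];
let field_builders: Vec<Box<dyn ArrayBuilder>> = vec![
    Box::new(StringBuilder::new()),
    Box::new(Int32Builder::new()),
];
let mut builder = StructBuilder::new(fields, field_builders);

builder
    .field_builder::<StringBuilder>(0)
    .unwrap()
    .append_value("joe");
builder
    .field_builder::<Int32Builder>(1)
    .unwrap()
    .append_value(1);
// Child builders must stay in step with the number of struct slots,
// otherwise `finish` panics via `validate_content`.
builder.append(true);

let arr = builder.finish();
assert_eq!(1, arr.len());
```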
pub fn finish(&mut self) -> StructArray { + self.validate_content(); + let mut child_data = Vec::with_capacity(self.field_builders.len()); for f in &mut self.field_builders { let arr = f.finish(); child_data.push(arr.into_data()); } + let length = self.len(); + let null_bit_buffer = self.null_buffer_builder.finish(); - let null_bit_buffer = self.bitmap_builder.finish(); - let null_count = self.len - null_bit_buffer.count_set_bits(); - let mut builder = ArrayData::builder(DataType::Struct(self.fields.clone())) - .len(self.len) - .child_data(child_data); - if null_count > 0 { - builder = builder.null_bit_buffer(Some(null_bit_buffer)); - } - - self.len = 0; + let builder = ArrayData::builder(DataType::Struct(self.fields.clone())) + .len(length) + .child_data(child_data) + .null_bit_buffer(null_bit_buffer); let array_data = unsafe { builder.build_unchecked() }; StructArray::from(array_data) } + + /// Constructs and validates contents in the builder to ensure that + /// - fields and field_builders are of equal length + /// - the number of items in individual field_builders are equal to self.len() + fn validate_content(&self) { + if self.fields.len() != self.field_builders.len() { + panic!("Number of fields is not equal to the number of field_builders."); + } + if !self.field_builders.iter().all(|x| x.len() == self.len()) { + panic!("StructBuilder and field_builders are of unequal lengths."); + } + } } #[cfg(test)] @@ -247,8 +255,8 @@ mod tests { #[test] fn test_struct_array_builder() { - let string_builder = StringBuilder::new(4); - let int_builder = Int32Builder::new(4); + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); let mut fields = Vec::new(); let mut field_builders = Vec::new(); @@ -263,23 +271,23 @@ mod tests { let string_builder = builder .field_builder::(0) .expect("builder at field 0 should be string builder"); - string_builder.append_value("joe").unwrap(); - string_builder.append_null().unwrap(); - string_builder.append_null().unwrap(); - string_builder.append_value("mark").unwrap(); + string_builder.append_value("joe"); + string_builder.append_null(); + string_builder.append_null(); + string_builder.append_value("mark"); let int_builder = builder .field_builder::(1) .expect("builder at field 1 should be int builder"); - int_builder.append_value(1).unwrap(); - int_builder.append_value(2).unwrap(); - int_builder.append_null().unwrap(); - int_builder.append_value(4).unwrap(); + int_builder.append_value(1); + int_builder.append_value(2); + int_builder.append_null(); + int_builder.append_value(4); - builder.append(true).unwrap(); - builder.append(true).unwrap(); - builder.append_null().unwrap(); - builder.append(true).unwrap(); + builder.append(true); + builder.append(true); + builder.append_null(); + builder.append(true); let arr = builder.finish(); @@ -312,8 +320,8 @@ mod tests { #[test] fn test_struct_array_builder_finish() { - let int_builder = Int32Builder::new(10); - let bool_builder = BooleanBuilder::new(10); + let int_builder = Int32Builder::new(); + let bool_builder = BooleanBuilder::new(); let mut fields = Vec::new(); let mut field_builders = Vec::new(); @@ -326,19 +334,17 @@ mod tests { builder .field_builder::(0) .unwrap() - .append_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - .unwrap(); + .append_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); builder .field_builder::(1) .unwrap() .append_slice(&[ false, true, false, true, false, true, false, true, false, true, - ]) - .unwrap(); + ]); // Append slot values - all are valid. 
for _ in 0..10 { - assert!(builder.append(true).is_ok()) + builder.append(true); } assert_eq!(10, builder.len()); @@ -351,17 +357,15 @@ mod tests { builder .field_builder::(0) .unwrap() - .append_slice(&[1, 3, 5, 7, 9]) - .unwrap(); + .append_slice(&[1, 3, 5, 7, 9]); builder .field_builder::(1) .unwrap() - .append_slice(&[false, true, false, true, false]) - .unwrap(); + .append_slice(&[false, true, false, true, false]); // Append slot values - all are valid. for _ in 0..5 { - assert!(builder.append(true).is_ok()) + builder.append(true); } assert_eq!(5, builder.len()); @@ -407,7 +411,7 @@ mod tests { #[test] fn test_struct_array_builder_field_builder_type_mismatch() { - let int_builder = Int32Builder::new(10); + let int_builder = Int32Builder::with_capacity(10); let mut fields = Vec::new(); let mut field_builders = Vec::new(); @@ -417,4 +421,44 @@ mod tests { let mut builder = StructBuilder::new(fields, field_builders); assert!(builder.field_builder::(0).is_none()); } + + #[test] + #[should_panic(expected = "StructBuilder and field_builders are of unequal lengths.")] + fn test_struct_array_builder_unequal_field_builders_lengths() { + let mut int_builder = Int32Builder::with_capacity(10); + let mut bool_builder = BooleanBuilder::new(); + + int_builder.append_value(1); + int_builder.append_value(2); + bool_builder.append_value(true); + + let mut fields = Vec::new(); + let mut field_builders = Vec::new(); + fields.push(Field::new("f1", DataType::Int32, false)); + field_builders.push(Box::new(int_builder) as Box); + fields.push(Field::new("f2", DataType::Boolean, false)); + field_builders.push(Box::new(bool_builder) as Box); + + let mut builder = StructBuilder::new(fields, field_builders); + builder.append(true); + builder.append(true); + builder.finish(); + } + + #[test] + #[should_panic( + expected = "Number of fields is not equal to the number of field_builders." + )] + fn test_struct_array_builder_unequal_field_field_builders() { + let int_builder = Int32Builder::with_capacity(10); + + let mut fields = Vec::new(); + let mut field_builders = Vec::new(); + fields.push(Field::new("f1", DataType::Int32, false)); + field_builders.push(Box::new(int_builder) as Box); + fields.push(Field::new("f2", DataType::Boolean, false)); + + let mut builder = StructBuilder::new(fields, field_builders); + builder.finish(); + } } diff --git a/arrow/src/array/builder/union_builder.rs b/arrow/src/array/builder/union_builder.rs index 95d9ea40a3d8..c0ae76853dd2 100644 --- a/arrow/src/array/builder/union_builder.rs +++ b/arrow/src/array/builder/union_builder.rs @@ -29,7 +29,7 @@ use crate::datatypes::Field; use crate::datatypes::{ArrowNativeType, ArrowPrimitiveType}; use crate::error::{ArrowError, Result}; -use super::{BooleanBufferBuilder, BufferBuilder}; +use super::{BufferBuilder, NullBufferBuilder}; use crate::array::make_array; @@ -45,7 +45,7 @@ struct FieldData { /// The number of array slots represented by the buffer slots: usize, /// A builder for the null bitmap - bitmap_builder: BooleanBufferBuilder, + null_buffer_builder: NullBufferBuilder, } /// A type-erased [`BufferBuilder`] used by [`FieldData`] @@ -73,13 +73,17 @@ impl FieldDataValues for BufferBuilder { impl FieldData { /// Creates a new `FieldData`. 
- fn new(type_id: i8, data_type: DataType) -> Self { + fn new( + type_id: i8, + data_type: DataType, + capacity: usize, + ) -> Self { Self { type_id, data_type, slots: 0, - values_buffer: Box::new(BufferBuilder::::new(1)), - bitmap_builder: BooleanBufferBuilder::new(1), + values_buffer: Box::new(BufferBuilder::::new(capacity)), + null_buffer_builder: NullBufferBuilder::new(capacity), } } @@ -91,14 +95,14 @@ impl FieldData { .expect("Tried to append unexpected type") .append(v); - self.bitmap_builder.append(true); + self.null_buffer_builder.append(true); self.slots += 1; } /// Appends a null to this `FieldData`. fn append_null(&mut self) { self.values_buffer.append_null(); - self.bitmap_builder.append(false); + self.null_buffer_builder.append(false); self.slots += 1; } } @@ -111,7 +115,7 @@ impl FieldData { /// use arrow::array::UnionBuilder; /// use arrow::datatypes::{Float64Type, Int32Type}; /// -/// let mut builder = UnionBuilder::new_dense(3); +/// let mut builder = UnionBuilder::new_dense(); /// builder.append::("a", 1).unwrap(); /// builder.append::("b", 3.0).unwrap(); /// builder.append::("a", 4).unwrap(); @@ -131,7 +135,7 @@ impl FieldData { /// use arrow::array::UnionBuilder; /// use arrow::datatypes::{Float64Type, Int32Type}; /// -/// let mut builder = UnionBuilder::new_sparse(3); +/// let mut builder = UnionBuilder::new_sparse(); /// builder.append::("a", 1).unwrap(); /// builder.append::("b", 3.0).unwrap(); /// builder.append::("a", 4).unwrap(); @@ -155,26 +159,39 @@ pub struct UnionBuilder { type_id_builder: Int8BufferBuilder, /// Builder to keep track of offsets (`None` for sparse unions) value_offset_builder: Option, + initial_capacity: usize, } impl UnionBuilder { /// Creates a new dense array builder. - pub fn new_dense(capacity: usize) -> Self { + pub fn new_dense() -> Self { + Self::with_capacity_dense(1024) + } + + /// Creates a new sparse array builder. + pub fn new_sparse() -> Self { + Self::with_capacity_sparse(1024) + } + + /// Creates a new dense array builder with capacity. + pub fn with_capacity_dense(capacity: usize) -> Self { Self { len: 0, fields: HashMap::default(), type_id_builder: Int8BufferBuilder::new(capacity), value_offset_builder: Some(Int32BufferBuilder::new(capacity)), + initial_capacity: capacity, } } - /// Creates a new sparse array builder. - pub fn new_sparse(capacity: usize) -> Self { + /// Creates a new sparse array builder with capacity. + pub fn with_capacity_sparse(capacity: usize) -> Self { Self { len: 0, fields: HashMap::default(), type_id_builder: Int8BufferBuilder::new(capacity), value_offset_builder: None, + initial_capacity: capacity, } } @@ -215,10 +232,18 @@ impl UnionBuilder { data } None => match self.value_offset_builder { - Some(_) => FieldData::new::(self.fields.len() as i8, T::DATA_TYPE), + Some(_) => FieldData::new::( + self.fields.len() as i8, + T::DATA_TYPE, + self.initial_capacity, + ), + // In the case of a sparse union, we should pass the maximum of the currently length and the capacity. 
None => { - let mut fd = - FieldData::new::(self.fields.len() as i8, T::DATA_TYPE); + let mut fd = FieldData::new::( + self.fields.len() as i8, + T::DATA_TYPE, + self.len.max(self.initial_capacity), + ); for _ in 0..self.len { fd.append_null(); } @@ -264,7 +289,7 @@ impl UnionBuilder { data_type, mut values_buffer, slots, - mut bitmap_builder, + null_buffer_builder: mut bitmap_builder, }, ) in self.fields.into_iter() { @@ -272,7 +297,7 @@ impl UnionBuilder { let arr_data_builder = ArrayDataBuilder::new(data_type.clone()) .add_buffer(buffer) .len(slots) - .null_bit_buffer(Some(bitmap_builder.finish())); + .null_bit_buffer(bitmap_builder.finish()); let arr_data_ref = unsafe { arr_data_builder.build_unchecked() }; let array_ref = make_array(arr_data_ref); diff --git a/arrow/src/array/cast.rs b/arrow/src/array/cast.rs index d0b77a0d27b5..2b68cbbe6424 100644 --- a/arrow/src/array/cast.rs +++ b/arrow/src/array/cast.rs @@ -15,12 +15,242 @@ // specific language governing permissions and limitations // under the License. -//! Defines helper functions for force Array type downcast +//! Defines helper functions for force [`Array`] downcasts use crate::array::*; use crate::datatypes::*; -/// Force downcast ArrayRef to PrimitiveArray +/// Downcast an [`Array`] to a [`PrimitiveArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow::downcast_primitive_array; +/// # use arrow::array::Array; +/// # use arrow::datatypes::DataType; +/// # use arrow::array::as_string_array; +/// +/// fn print_primitive(array: &dyn Array) { +/// downcast_primitive_array!( +/// array => { +/// for v in array { +/// println!("{:?}", v); +/// } +/// } +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +/// +#[macro_export] +macro_rules! 
downcast_primitive_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + downcast_primitive_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + match $values.data_type() { + $crate::datatypes::DataType::Int8 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int8Type, + >($values); + $e + } + $crate::datatypes::DataType::Int16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int16Type, + >($values); + $e + } + $crate::datatypes::DataType::Int32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int32Type, + >($values); + $e + } + $crate::datatypes::DataType::Int64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Int64Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt8 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt8Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt16Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt32Type, + >($values); + $e + } + $crate::datatypes::DataType::UInt64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::UInt64Type, + >($values); + $e + } + $crate::datatypes::DataType::Float16 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float16Type, + >($values); + $e + } + $crate::datatypes::DataType::Float32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float32Type, + >($values); + $e + } + $crate::datatypes::DataType::Float64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Float64Type, + >($values); + $e + } + $crate::datatypes::DataType::Date32 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Date32Type, + >($values); + $e + } + $crate::datatypes::DataType::Date64 => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Date64Type, + >($values); + $e + } + $crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Second) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time32SecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time32($crate::datatypes::TimeUnit::Millisecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time32MillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Microsecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time64MicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Time64($crate::datatypes::TimeUnit::Nanosecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::Time64NanosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Second, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampSecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Millisecond, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampMillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Microsecond, _) => 
{ + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampMicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Timestamp($crate::datatypes::TimeUnit::Nanosecond, _) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::TimestampNanosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::YearMonth) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalYearMonthType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::DayTime) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalDayTimeType, + >($values); + $e + } + $crate::datatypes::DataType::Interval($crate::datatypes::IntervalUnit::MonthDayNano) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::IntervalMonthDayNanoType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Second) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationSecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Millisecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationMillisecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Microsecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationMicrosecondType, + >($values); + $e + } + $crate::datatypes::DataType::Duration($crate::datatypes::TimeUnit::Nanosecond) => { + let $values = $crate::array::as_primitive_array::< + $crate::datatypes::DurationNanosecondType, + >($values); + $e + } + $($p => $fallback,)* + } + }; +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`], to +/// [`PrimitiveArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use arrow::array::*; +/// # use arrow::datatypes::*; +/// # use std::sync::Arc; +/// let arr: ArrayRef = Arc::new(Int32Array::from(vec![Some(1)])); +/// +/// // Downcast an `ArrayRef` to Int32Array / PrimiveArray: +/// let primitive_array: &Int32Array = as_primitive_array(&arr); +/// +/// // Equivalently: +/// let primitive_array = as_primitive_array::(&arr); +/// +/// // This is the equivalent of: +/// let primitive_array = arr +/// .as_any() +/// .downcast_ref::() +/// .unwrap(); +/// ``` + pub fn as_primitive_array(arr: &dyn Array) -> &PrimitiveArray where T: ArrowPrimitiveType, @@ -30,7 +260,111 @@ where .expect("Unable to downcast to primitive array") } -/// Force downcast ArrayRef to DictionaryArray +/// Downcast an [`Array`] to a [`DictionaryArray`] based on its [`DataType`], accepts +/// a number of subsequent patterns to match the data type +/// +/// ``` +/// # use arrow::downcast_dictionary_array; +/// # use arrow::array::{Array, StringArray}; +/// # use arrow::datatypes::DataType; +/// # use arrow::array::as_string_array; +/// +/// fn print_strings(array: &dyn Array) { +/// downcast_dictionary_array!( +/// array => match array.values().data_type() { +/// DataType::Utf8 => { +/// for v in array.downcast_dict::().unwrap() { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported dictionary value type {}", t), +/// }, +/// DataType::Utf8 => { +/// for v in as_string_array(array) { +/// println!("{:?}", v); +/// } +/// } +/// t => println!("Unsupported datatype {}", t) +/// ) +/// } +/// ``` +#[macro_export] +macro_rules! 
downcast_dictionary_array { + ($values:ident => $e:expr, $($p:pat => $fallback:expr $(,)*)*) => { + downcast_dictionary_array!($values => {$e} $($p => $fallback)*) + }; + + ($values:ident => $e:block $($p:pat => $fallback:expr $(,)*)*) => { + match $values.data_type() { + $crate::datatypes::DataType::Dictionary(k, _) => match k.as_ref() { + $crate::datatypes::DataType::Int8 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int8Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int16 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int16Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int32 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int32Type, + >($values); + $e + }, + $crate::datatypes::DataType::Int64 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::Int64Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt8 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt8Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt16 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt16Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt32 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt32Type, + >($values); + $e + }, + $crate::datatypes::DataType::UInt64 => { + let $values = $crate::array::as_dictionary_array::< + $crate::datatypes::UInt64Type, + >($values); + $e + }, + k => unreachable!("unsupported dictionary key type: {}", k) + } + $($p => $fallback,)* + } + } +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`DictionaryArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use arrow::array::*; +/// # use arrow::datatypes::*; +/// # use std::sync::Arc; +/// let arr: DictionaryArray = vec![Some("foo")].into_iter().collect(); +/// let arr: ArrayRef = std::sync::Arc::new(arr); +/// let dict_array: &DictionaryArray = as_dictionary_array::(&arr); +/// ``` pub fn as_dictionary_array(arr: &dyn Array) -> &DictionaryArray where T: ArrowDictionaryKeyType, @@ -40,7 +374,8 @@ where .expect("Unable to downcast to dictionary array") } -#[doc = "Force downcast ArrayRef to GenericListArray"] +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`GenericListArray`], panic'ing on failure. pub fn as_generic_list_array( arr: &dyn Array, ) -> &GenericListArray { @@ -49,19 +384,22 @@ pub fn as_generic_list_array( .expect("Unable to downcast to list array") } -#[doc = "Force downcast ArrayRef to ListArray"] +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`ListArray`], panic'ing on failure. #[inline] pub fn as_list_array(arr: &dyn Array) -> &ListArray { as_generic_list_array::(arr) } -#[doc = "Force downcast ArrayRef to LargeListArray"] +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`LargeListArray`], panic'ing on failure. #[inline] pub fn as_large_list_array(arr: &dyn Array) -> &LargeListArray { as_generic_list_array::(arr) } -#[doc = "Force downcast ArrayRef to GenericBinaryArray"] +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`GenericBinaryArray`], panic'ing on failure. 
#[inline] pub fn as_generic_binary_array( arr: &dyn Array, @@ -71,9 +409,43 @@ pub fn as_generic_binary_array( .expect("Unable to downcast to binary array") } +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`StringArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use arrow::array::*; +/// # use std::sync::Arc; +/// let arr: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("foo")])); +/// let string_array = as_string_array(&arr); +/// ``` +pub fn as_string_array(arr: &dyn Array) -> &StringArray { + arr.as_any() + .downcast_ref::() + .expect("Unable to downcast to StringArray") +} + +/// Force downcast of an [`Array`], such as an [`ArrayRef`] to +/// [`BooleanArray`], panic'ing on failure. +/// +/// # Example +/// +/// ``` +/// # use arrow::array::*; +/// # use std::sync::Arc; +/// let arr: ArrayRef = Arc::new(BooleanArray::from_iter(vec![Some(true)])); +/// let boolean_array = as_boolean_array(&arr); +/// ``` +pub fn as_boolean_array(arr: &dyn Array) -> &BooleanArray { + arr.as_any() + .downcast_ref::() + .expect("Unable to downcast to BooleanArray") +} + macro_rules! array_downcast_fn { ($name: ident, $arrty: ty, $arrty_str:expr) => { - #[doc = "Force downcast ArrayRef to "] + #[doc = "Force downcast of an [`Array`], such as an [`ArrayRef`] to "] #[doc = $arrty_str] pub fn $name(arr: &dyn Array) -> &$arrty { arr.as_any().downcast_ref::<$arrty>().expect(concat!( @@ -85,18 +457,20 @@ macro_rules! array_downcast_fn { // use recursive macro to generate dynamic doc string for a given array type ($name: ident, $arrty: ty) => { - array_downcast_fn!($name, $arrty, stringify!($arrty)); + array_downcast_fn!( + $name, + $arrty, + concat!("[`", stringify!($arrty), "`], panic'ing on failure.") + ); }; } -array_downcast_fn!(as_string_array, StringArray); array_downcast_fn!(as_largestring_array, LargeStringArray); -array_downcast_fn!(as_boolean_array, BooleanArray); array_downcast_fn!(as_null_array, NullArray); array_downcast_fn!(as_struct_array, StructArray); array_downcast_fn!(as_union_array, UnionArray); array_downcast_fn!(as_map_array, MapArray); -array_downcast_fn!(as_decimal_array, DecimalArray); +array_downcast_fn!(as_decimal_array, Decimal128Array); #[cfg(test)] mod tests { @@ -106,9 +480,9 @@ mod tests { #[test] fn test_as_decimal_array_ref() { - let array: DecimalArray = vec![Some(123), None, Some(1111)] + let array: Decimal128Array = vec![Some(123), None, Some(1111)] .into_iter() - .collect::() + .collect::() .with_precision_and_scale(10, 2) .unwrap(); assert!(!as_decimal_array(&array).is_empty()); diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index eba496cbf09f..7571ba210d7d 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -18,8 +18,12 @@ //! Contains `ArrayData`, a generic representation of Arrow array data which encapsulates //! common attributes and operations for Arrow array. 
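For reference, a hedged sketch of the `DecimalArray` → `Decimal128Array` rename in action, mirroring the updated `test_as_decimal_array_ref` test above (assuming the `arrow` crate as of this change):

```rust
use arrow::array::{as_decimal_array, Array, Decimal128Array};

fn main() {
    // Collect i128 values into a Decimal128Array and attach precision/scale.
    let array: Decimal128Array = vec![Some(123), None, Some(1111)]
        .into_iter()
        .collect::<Decimal128Array>()
        .with_precision_and_scale(10, 2)
        .unwrap();

    // `as_decimal_array` now downcasts to Decimal128Array.
    let decimal = as_decimal_array(&array);
    assert_eq!(decimal.len(), 3);
    assert_eq!(decimal.null_count(), 1);
}
```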
-use crate::datatypes::{validate_decimal_precision, DataType, IntervalUnit, UnionMode}; +use crate::datatypes::{ + validate_decimal256_precision_with_lt_bytes, validate_decimal_precision, DataType, + IntervalUnit, UnionMode, +}; use crate::error::{ArrowError, Result}; +use crate::util::bit_iterator::BitSliceIterator; use crate::{bitmap::Bitmap, datatypes::ArrowNativeType}; use crate::{ buffer::{Buffer, MutableBuffer}, @@ -33,6 +37,21 @@ use std::sync::Arc; use super::equal::equal; +#[inline] +pub(crate) fn contains_nulls( + null_bit_buffer: Option<&Buffer>, + offset: usize, + len: usize, +) -> bool { + match null_bit_buffer { + Some(buffer) => match BitSliceIterator::new(buffer, offset, len).next() { + Some((start, end)) => start != 0 || end != len, + None => len != 0, // No non-null values + }, + None => false, // No null buffer + } +} + #[inline] pub(crate) fn count_nulls( null_bit_buffer: Option<&Buffer>, @@ -189,7 +208,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff DataType::FixedSizeList(_, _) | DataType::Struct(_) => { [empty_buffer, MutableBuffer::new(0)] } - DataType::Decimal(_, _) => [ + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => [ MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, ], @@ -267,7 +286,10 @@ impl ArrayData { /// Create a new ArrayData instance; /// /// If `null_count` is not specified, the number of nulls in - /// null_bit_buffer is calculated + /// null_bit_buffer is calculated. + /// + /// If the number of nulls is 0 then the null_bit_buffer + /// is set to `None`. /// /// # Safety /// @@ -291,7 +313,7 @@ impl ArrayData { None => count_nulls(null_bit_buffer.as_ref(), offset, len), Some(null_count) => null_count, }; - let null_bitmap = null_bit_buffer.map(Bitmap::from); + let null_bitmap = null_bit_buffer.filter(|_| null_count > 0).map(Bitmap::from); let new_self = Self { data_type, len, @@ -311,6 +333,9 @@ impl ArrayData { /// Create a new ArrayData, validating that the provided buffers /// form a valid Arrow array of the specified data type. /// + /// If the number of nulls in `null_bit_buffer` is 0 then the null_bit_buffer + /// is set to `None`. + /// /// Note: This is a low level API and most users of the arrow /// crate should create arrays using the methods in the `array` /// module. @@ -370,18 +395,24 @@ impl ArrayData { /// panic's if the new DataType is not compatible with the /// existing type. 
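The new `contains_nulls` helper answers a cheaper question than `count_nulls`: whether any position in a bit range is unset at all. Below is a minimal standalone sketch of the same semantics, assuming LSB-first bit packing as used by Arrow validity buffers; `contains_nulls_sketch` is an illustrative name and the naive per-bit loop stands in for the `BitSliceIterator`-based implementation above.

```rust
/// True if any bit in `validity[offset..offset + len)` is unset (LSB-first
/// packing), i.e. the range contains at least one null. No validity buffer
/// means no nulls.
fn contains_nulls_sketch(validity: Option<&[u8]>, offset: usize, len: usize) -> bool {
    match validity {
        None => false,
        Some(bytes) => {
            (offset..offset + len).any(|i| ((bytes[i / 8] >> (i % 8)) & 1) == 0)
        }
    }
}

fn main() {
    // Mirrors the `test_contains_nulls` case in this diff:
    // validity bits = [false, false, false, true, true, false]
    let buffer: &[u8] = &[0b0001_1000];
    assert!(contains_nulls_sketch(Some(buffer), 0, 6));
    assert!(contains_nulls_sketch(Some(buffer), 0, 3));
    assert!(!contains_nulls_sketch(Some(buffer), 3, 2));
    assert!(!contains_nulls_sketch(Some(buffer), 0, 0));
    assert!(!contains_nulls_sketch(None, 0, 6));
}
```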
/// - /// Note: currently only changing a [DataType::Decimal]s precision - /// and scale are supported + /// Note: currently only changing a [DataType::Decimal128]s or + /// [DataType::Decimal256]s precision and scale are supported #[inline] pub(crate) fn with_data_type(mut self, new_data_type: DataType) -> Self { - assert!( - matches!(self.data_type, DataType::Decimal(_, _)), - "only DecimalType is supported for existing type" - ); - assert!( - matches!(new_data_type, DataType::Decimal(_, _)), - "only DecimalType is supported for new datatype" - ); + if matches!(self.data_type, DataType::Decimal128(_, _)) { + assert!( + matches!(new_data_type, DataType::Decimal128(_, _)), + "only 128-bit DecimalType is supported for new datatype" + ); + } else if matches!(self.data_type, DataType::Decimal256(_, _)) { + assert!( + matches!(new_data_type, DataType::Decimal256(_, _)), + "only 256-bit DecimalType is supported for new datatype" + ); + } else { + panic!("only DecimalType is supported.") + } + self.data_type = new_data_type; self } @@ -572,7 +603,8 @@ impl ArrayData { | DataType::LargeBinary | DataType::Interval(_) | DataType::FixedSizeBinary(_) - | DataType::Decimal(_, _) => vec![], + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => vec![], DataType::List(field) => { vec![Self::new_empty(field.data_type())] } @@ -1004,13 +1036,22 @@ impl ArrayData { pub fn validate_values(&self) -> Result<()> { match &self.data_type { - DataType::Decimal(p, _) => { + DataType::Decimal128(p, _) => { let values_buffer: &[i128] = self.typed_buffer(0, self.len)?; for value in values_buffer { validate_decimal_precision(*value, *p)?; } Ok(()) } + DataType::Decimal256(p, _) => { + let values = self.buffers()[0].as_slice(); + for pos in 0..self.len() { + let offset = pos * 32; + let raw_bytes = &values[offset..offset + 32]; + validate_decimal256_precision_with_lt_bytes(raw_bytes, *p)?; + } + Ok(()) + } DataType::Utf8 => self.validate_utf8::(), DataType::LargeUtf8 => self.validate_utf8::(), DataType::Binary => self.validate_offsets_full::(self.buffers[1].len()), @@ -1119,16 +1160,37 @@ impl ArrayData { T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, { let values_buffer = &self.buffers[1].as_slice(); - - self.validate_each_offset::(values_buffer.len(), |string_index, range| { - std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { - ArrowError::InvalidArgumentError(format!( - "Invalid UTF8 sequence at string index {} ({:?}): {}", - string_index, range, e - )) - })?; - Ok(()) - }) + if let Ok(values_str) = std::str::from_utf8(values_buffer) { + // Validate Offsets are correct + self.validate_each_offset::( + values_buffer.len(), + |string_index, range| { + if !values_str.is_char_boundary(range.start) + || !values_str.is_char_boundary(range.end) + { + return Err(ArrowError::InvalidArgumentError(format!( + "incomplete utf-8 byte sequence from index {}", + string_index + ))); + } + Ok(()) + }, + ) + } else { + // find specific offset that failed utf8 validation + self.validate_each_offset::( + values_buffer.len(), + |string_index, range| { + std::str::from_utf8(&values_buffer[range.clone()]).map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Invalid UTF8 sequence at string index {} ({:?}): {}", + string_index, range, e + )) + })?; + Ok(()) + }, + ) + } } /// Ensures that all offsets in `buffers[0]` into `buffers[1]` are @@ -1302,11 +1364,15 @@ pub(crate) fn layout(data_type: &DataType) -> DataTypeLayout { } } DataType::Dictionary(key_type, _value_type) => layout(key_type), - 
DataType::Decimal(_, _) => { + DataType::Decimal128(_, _) => { // Decimals are always some fixed width; The rust implementation // always uses 16 bytes / size of i128 DataTypeLayout::new_fixed_width(size_of::()) } + DataType::Decimal256(_, _) => { + // Decimals are always some fixed width. + DataTypeLayout::new_fixed_width(32) + } DataType::Map(_, _) => { // same as ListType DataTypeLayout::new_fixed_width(size_of::()) @@ -1513,7 +1579,7 @@ mod tests { use std::ptr::NonNull; use crate::array::{ - make_array, Array, BooleanBuilder, DecimalBuilder, FixedSizeListBuilder, + make_array, Array, BooleanBuilder, Decimal128Builder, FixedSizeListBuilder, Int32Array, Int32Builder, Int64Array, StringArray, StructBuilder, UInt64Array, UInt8Builder, }; @@ -2610,8 +2676,8 @@ mod tests { Field::new("b", DataType::Boolean, true), ], vec![ - Box::new(Int32Builder::new(5)), - Box::new(BooleanBuilder::new(5)), + Box::new(Int32Builder::with_capacity(5)), + Box::new(BooleanBuilder::with_capacity(5)), ], ); @@ -2619,66 +2685,56 @@ mod tests { builder .field_builder::(0) .unwrap() - .append_option(Some(10)) - .unwrap(); + .append_option(Some(10)); builder .field_builder::(1) .unwrap() - .append_option(Some(true)) - .unwrap(); - builder.append(true).unwrap(); + .append_option(Some(true)); + builder.append(true); // struct[1] = null builder .field_builder::(0) .unwrap() - .append_option(None) - .unwrap(); + .append_option(None); builder .field_builder::(1) .unwrap() - .append_option(None) - .unwrap(); - builder.append(false).unwrap(); + .append_option(None); + builder.append(false); // struct[2] = { a: null, b: false } builder .field_builder::(0) .unwrap() - .append_option(None) - .unwrap(); + .append_option(None); builder .field_builder::(1) .unwrap() - .append_option(Some(false)) - .unwrap(); - builder.append(true).unwrap(); + .append_option(Some(false)); + builder.append(true); // struct[3] = { a: 21, b: null } builder .field_builder::(0) .unwrap() - .append_option(Some(21)) - .unwrap(); + .append_option(Some(21)); builder .field_builder::(1) .unwrap() - .append_option(None) - .unwrap(); - builder.append(true).unwrap(); + .append_option(None); + builder.append(true); // struct[4] = { a: 18, b: false } builder .field_builder::(0) .unwrap() - .append_option(Some(18)) - .unwrap(); + .append_option(Some(18)); builder .field_builder::(1) .unwrap() - .append_option(Some(false)) - .unwrap(); - builder.append(true).unwrap(); + .append_option(Some(false)); + builder.append(true); let struct_array = builder.finish(); let struct_array_slice = struct_array.slice(1, 3); @@ -2765,38 +2821,33 @@ mod tests { #[test] #[cfg(not(feature = "force_validate"))] fn test_decimal_full_validation() { - let values_builder = UInt8Builder::new(10); + let values_builder = UInt8Builder::with_capacity(10); let byte_width = 16; let mut fixed_size_builder = FixedSizeListBuilder::new(values_builder, byte_width); - let value_as_bytes = DecimalBuilder::from_i128_to_fixed_size_bytes( - 123456, - fixed_size_builder.value_length() as usize, - ) - .unwrap(); + let value_as_bytes = 123456_i128.to_le_bytes(); fixed_size_builder .values() - .append_slice(value_as_bytes.as_slice()) - .unwrap(); - fixed_size_builder.append(true).unwrap(); + .append_slice(value_as_bytes.as_slice()); + fixed_size_builder.append(true); let fixed_size_array = fixed_size_builder.finish(); // Build ArrayData for Decimal - let builder = ArrayData::builder(DataType::Decimal(5, 3)) + let builder = ArrayData::builder(DataType::Decimal128(5, 3)) .len(fixed_size_array.len()) 
.add_buffer(fixed_size_array.data_ref().child_data()[0].buffers()[0].clone()); let array_data = unsafe { builder.build_unchecked() }; let validation_result = array_data.validate_full(); let error = validation_result.unwrap_err(); assert_eq!( - "Invalid argument error: 123456 is too large to store in a Decimal of precision 5. Max is 99999", + "Invalid argument error: 123456 is too large to store in a Decimal128 of precision 5. Max is 99999", error.to_string() ); } #[test] fn test_decimal_validation() { - let mut builder = DecimalBuilder::new(4, 10, 4); + let mut builder = Decimal128Builder::with_capacity(4, 10, 4); builder.append_value(10000).unwrap(); builder.append_value(20000).unwrap(); let array = builder.finish(); @@ -2829,4 +2880,15 @@ mod tests { let err = data.validate_values().unwrap_err(); assert_eq!(err.to_string(), "Invalid argument error: Offset invariant failure: offset at position 1 out of bounds: 3 > 2"); } + + #[test] + fn test_contains_nulls() { + let buffer: Buffer = + MutableBuffer::from_iter([false, false, false, true, true, false]).into(); + + assert!(contains_nulls(Some(&buffer), 0, 6)); + assert!(contains_nulls(Some(&buffer), 0, 3)); + assert!(!contains_nulls(Some(&buffer), 3, 2)); + assert!(!contains_nulls(Some(&buffer), 0, 0)); + } } diff --git a/arrow/src/array/equal/boolean.rs b/arrow/src/array/equal/boolean.rs index de34d7fab189..fddf21b963ad 100644 --- a/arrow/src/array/equal/boolean.rs +++ b/arrow/src/array/equal/boolean.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::count_nulls, ArrayData}; +use crate::array::{data::contains_nulls, ArrayData}; +use crate::util::bit_iterator::BitIndexIterator; use crate::util::bit_util::get_bit; use super::utils::{equal_bits, equal_len}; +/// Returns true if the value data for the arrays is equal, assuming the null masks have +/// already been checked for equality pub(super) fn boolean_equal( lhs: &ArrayData, rhs: &ArrayData, @@ -30,19 +33,22 @@ pub(super) fn boolean_equal( let lhs_values = lhs.buffers()[0].as_slice(); let rhs_values = rhs.buffers()[0].as_slice(); - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + let contains_nulls = contains_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - if lhs_null_count == 0 && rhs_null_count == 0 { + if !contains_nulls { // Optimize performance for starting offset at u8 boundary. 
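A standalone sketch of the byte-aligned fast path referred to by the comment above, assuming LSB-first packed booleans: when every bit offset involved is a multiple of 8, `len` bits can be compared as `len / 8` whole bytes, which is why the next hunk also divides the array offsets by 8 when indexing bytes. `equal_len_bytes` is an illustrative helper, not the arrow implementation.

```rust
// When the bit offsets on both sides are multiples of 8, `num_bytes * 8` bits
// can be compared as `num_bytes` whole bytes.
fn equal_len_bytes(
    lhs: &[u8],
    rhs: &[u8],
    lhs_byte: usize,
    rhs_byte: usize,
    num_bytes: usize,
) -> bool {
    lhs[lhs_byte..lhs_byte + num_bytes] == rhs[rhs_byte..rhs_byte + num_bytes]
}

fn main() {
    let lhs = [0b1010_1010u8, 0b1111_0000];
    let rhs = [0b0000_0000u8, 0b1010_1010, 0b1111_0000];

    // Compare 16 bits of `lhs` starting at bit offset 0 with 16 bits of `rhs`
    // starting at bit offset 8: both offsets are byte aligned, so this is a
    // comparison of 2 bytes starting at byte 0/8 = 0 and byte 8/8 = 1.
    let (lhs_bit, rhs_bit, len_bits) = (0usize, 8usize, 16usize);
    assert!(lhs_bit % 8 == 0 && rhs_bit % 8 == 0);
    assert!(equal_len_bytes(&lhs, &rhs, lhs_bit / 8, rhs_bit / 8, len_bits / 8));
}
```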
- if lhs_start % 8 == 0 && rhs_start % 8 == 0 { + if lhs_start % 8 == 0 + && rhs_start % 8 == 0 + && lhs.offset() % 8 == 0 + && rhs.offset() % 8 == 0 + { let quot = len / 8; if quot > 0 && !equal_len( lhs_values, rhs_values, - lhs_start / 8 + lhs.offset(), - rhs_start / 8 + rhs.offset(), + lhs_start / 8 + lhs.offset() / 8, + rhs_start / 8 + rhs.offset() / 8, quot, ) { @@ -71,20 +77,41 @@ pub(super) fn boolean_equal( } else { // get a ref of the null buffer bytes, to use in testing for nullness let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); let lhs_start = lhs.offset() + lhs_start; let rhs_start = rhs.offset() + rhs_start; - (0..len).all(|i| { + BitIndexIterator::new(lhs_null_bytes, lhs_start, len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; - let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos); - let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos); - - lhs_is_null - || (lhs_is_null == rhs_is_null) - && equal_bits(lhs_values, rhs_values, lhs_pos, rhs_pos, 1) + get_bit(lhs_values, lhs_pos) == get_bit(rhs_values, rhs_pos) }) } } + +#[cfg(test)] +mod tests { + use crate::array::{Array, BooleanArray}; + + #[test] + fn test_boolean_slice() { + let array = BooleanArray::from(vec![true; 32]); + let slice = array.slice(4, 12); + assert_eq!(slice.data(), slice.data()); + + let slice = array.slice(8, 12); + assert_eq!(slice.data(), slice.data()); + + let slice = array.slice(8, 24); + assert_eq!(slice.data(), slice.data()); + } + + #[test] + fn test_sliced_nullable_boolean_array() { + let a = BooleanArray::from(vec![None; 32]); + let b = BooleanArray::from(vec![true; 32]); + let slice_a = a.slice(1, 12); + let slice_b = b.slice(1, 12); + assert_ne!(slice_a.data(), slice_b.data()); + } +} diff --git a/arrow/src/array/equal/decimal.rs b/arrow/src/array/equal/decimal.rs index e9879f3f281e..42a7d29e27d2 100644 --- a/arrow/src/array/equal/decimal.rs +++ b/arrow/src/array/equal/decimal.rs @@ -29,7 +29,8 @@ pub(super) fn decimal_equal( len: usize, ) -> bool { let size = match lhs.data_type() { - DataType::Decimal(_, _) => 16, + DataType::Decimal128(_, _) => 16, + DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; diff --git a/arrow/src/array/equal/list.rs b/arrow/src/array/equal/list.rs index 0feefa7aa11a..b3bca9a69228 100644 --- a/arrow/src/array/equal/list.rs +++ b/arrow/src/array/equal/list.rs @@ -160,22 +160,22 @@ mod tests { #[test] fn list_array_non_zero_nulls() { // Tests handling of list arrays with non-empty null ranges - let mut builder = ListBuilder::new(Int64Builder::new(10)); - builder.values().append_value(1).unwrap(); - builder.values().append_value(2).unwrap(); - builder.values().append_value(3).unwrap(); - builder.append(true).unwrap(); - builder.append(false).unwrap(); + let mut builder = ListBuilder::new(Int64Builder::with_capacity(10)); + builder.values().append_value(1); + builder.values().append_value(2); + builder.values().append_value(3); + builder.append(true); + builder.append(false); let array1 = builder.finish(); - let mut builder = ListBuilder::new(Int64Builder::new(10)); - builder.values().append_value(1).unwrap(); - builder.values().append_value(2).unwrap(); - builder.values().append_value(3).unwrap(); - builder.append(true).unwrap(); - builder.values().append_null().unwrap(); - builder.values().append_null().unwrap(); - builder.append(false).unwrap(); + let mut builder = ListBuilder::new(Int64Builder::with_capacity(10)); + 
builder.values().append_value(1); + builder.values().append_value(2); + builder.values().append_value(3); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.append(false); let array2 = builder.finish(); assert_eq!(array1, array2); diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 74599c2ed6a4..34df0bda0b1f 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -20,9 +20,10 @@ //! depend on dynamic casting of `Array`. use super::{ - Array, ArrayData, BooleanArray, DecimalArray, DictionaryArray, FixedSizeBinaryArray, - FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericStringArray, - MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StructArray, + Array, ArrayData, BooleanArray, Decimal128Array, DictionaryArray, + FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, + GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, + StructArray, }; use crate::datatypes::{ArrowPrimitiveType, DataType, IntervalUnit}; use half::f16; @@ -109,7 +110,7 @@ impl PartialEq for FixedSizeBinaryArray { } } -impl PartialEq for DecimalArray { +impl PartialEq for Decimal128Array { fn eq(&self, other: &Self) -> bool { equal(self.data(), other.data()) } @@ -186,7 +187,9 @@ fn equal_values( DataType::FixedSizeBinary(_) => { fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Decimal(_, _) => decimal_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { + decimal_equal(lhs, rhs, lhs_start, rhs_start, len) + } DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::FixedSizeList(_, _) => { @@ -607,13 +610,13 @@ mod tests { } fn create_list_array, T: AsRef<[Option]>>(data: T) -> ArrayData { - let mut builder = ListBuilder::new(Int32Builder::new(10)); + let mut builder = ListBuilder::new(Int32Builder::with_capacity(10)); for d in data.as_ref() { if let Some(v) = d { - builder.values().append_slice(v.as_ref()).unwrap(); - builder.append(true).unwrap() + builder.values().append_slice(v.as_ref()); + builder.append(true); } else { - builder.append(false).unwrap() + builder.append(false); } } builder.finish().into_data() @@ -765,13 +768,13 @@ mod tests { fn create_fixed_size_binary_array, T: AsRef<[Option]>>( data: T, ) -> ArrayData { - let mut builder = FixedSizeBinaryBuilder::new(15, 5); + let mut builder = FixedSizeBinaryBuilder::with_capacity(data.as_ref().len(), 5); for d in data.as_ref() { if let Some(v) = d { builder.append_value(v.as_ref()).unwrap(); } else { - builder.append_null().unwrap(); + builder.append_null(); } } builder.finish().into_data() @@ -838,9 +841,9 @@ mod tests { test_equal(&a_slice, &b_slice, false); } - fn create_decimal_array(data: &[Option]) -> ArrayData { - data.iter() - .collect::() + fn create_decimal_array(data: Vec>) -> ArrayData { + data.into_iter() + .collect::() .with_precision_and_scale(23, 6) .unwrap() .into() @@ -848,32 +851,36 @@ mod tests { #[test] fn test_decimal_equal() { - let a = create_decimal_array(&[Some(8_887_000_000), Some(-8_887_000_000)]); - let b = create_decimal_array(&[Some(8_887_000_000), Some(-8_887_000_000)]); + let a = create_decimal_array(vec![Some(8_887_000_000), Some(-8_887_000_000)]); + let b = create_decimal_array(vec![Some(8_887_000_000), Some(-8_887_000_000)]); test_equal(&a, &b, true); - let b = 
create_decimal_array(&[Some(15_887_000_000), Some(-8_887_000_000)]); + let b = create_decimal_array(vec![Some(15_887_000_000), Some(-8_887_000_000)]); test_equal(&a, &b, false); } // Test the case where null_count > 0 #[test] fn test_decimal_null() { - let a = create_decimal_array(&[Some(8_887_000_000), None, Some(-8_887_000_000)]); - let b = create_decimal_array(&[Some(8_887_000_000), None, Some(-8_887_000_000)]); + let a = + create_decimal_array(vec![Some(8_887_000_000), None, Some(-8_887_000_000)]); + let b = + create_decimal_array(vec![Some(8_887_000_000), None, Some(-8_887_000_000)]); test_equal(&a, &b, true); - let b = create_decimal_array(&[Some(8_887_000_000), Some(-8_887_000_000), None]); + let b = + create_decimal_array(vec![Some(8_887_000_000), Some(-8_887_000_000), None]); test_equal(&a, &b, false); - let b = create_decimal_array(&[Some(15_887_000_000), None, Some(-8_887_000_000)]); + let b = + create_decimal_array(vec![Some(15_887_000_000), None, Some(-8_887_000_000)]); test_equal(&a, &b, false); } #[test] fn test_decimal_offsets() { // Test the case where offset != 0 - let a = create_decimal_array(&[ + let a = create_decimal_array(vec![ Some(8_887_000_000), None, None, @@ -881,7 +888,7 @@ mod tests { None, None, ]); - let b = create_decimal_array(&[ + let b = create_decimal_array(vec![ None, Some(8_887_000_000), None, @@ -911,7 +918,7 @@ mod tests { let b_slice = b.slice(2, 3); test_equal(&a_slice, &b_slice, false); - let b = create_decimal_array(&[ + let b = create_decimal_array(vec![ None, None, None, @@ -928,17 +935,17 @@ mod tests { fn create_fixed_size_list_array, T: AsRef<[Option]>>( data: T, ) -> ArrayData { - let mut builder = FixedSizeListBuilder::new(Int32Builder::new(10), 3); + let mut builder = FixedSizeListBuilder::new(Int32Builder::with_capacity(10), 3); for d in data.as_ref() { if let Some(v) = d { - builder.values().append_slice(v.as_ref()).unwrap(); - builder.append(true).unwrap() + builder.values().append_slice(v.as_ref()); + builder.append(true); } else { for _ in 0..builder.value_length() { - builder.values().append_null().unwrap(); + builder.values().append_null(); } - builder.append(false).unwrap() + builder.append(false); } } builder.finish().into_data() @@ -1239,7 +1246,7 @@ mod tests { fn create_dictionary_array(values: &[&str], keys: &[Option<&str>]) -> ArrayData { let values = StringArray::from(values.to_vec()); let mut builder = StringDictionaryBuilder::new_with_dictionary( - PrimitiveBuilder::::new(3), + PrimitiveBuilder::::with_capacity(3), &values, ) .unwrap(); @@ -1247,7 +1254,7 @@ mod tests { if let Some(v) = key { builder.append(v).unwrap(); } else { - builder.append_null().unwrap() + builder.append_null() } } builder.finish().into_data() @@ -1363,7 +1370,7 @@ mod tests { #[test] fn test_union_equal_dense() { - let mut builder = UnionBuilder::new_dense(7); + let mut builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); @@ -1373,7 +1380,7 @@ mod tests { builder.append::("b", 7).unwrap(); let union1 = builder.build().unwrap(); - builder = UnionBuilder::new_dense(7); + builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); @@ -1383,7 +1390,7 @@ mod tests { builder.append::("b", 7).unwrap(); let union2 = builder.build().unwrap(); - builder = UnionBuilder::new_dense(7); + builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("b", 
2).unwrap(); builder.append::("c", 3).unwrap(); @@ -1393,7 +1400,7 @@ mod tests { builder.append::("b", 7).unwrap(); let union3 = builder.build().unwrap(); - builder = UnionBuilder::new_dense(7); + builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); @@ -1410,7 +1417,7 @@ mod tests { #[test] fn test_union_equal_sparse() { - let mut builder = UnionBuilder::new_sparse(7); + let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); @@ -1420,7 +1427,7 @@ mod tests { builder.append::("b", 7).unwrap(); let union1 = builder.build().unwrap(); - builder = UnionBuilder::new_sparse(7); + builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); @@ -1430,7 +1437,7 @@ mod tests { builder.append::("b", 7).unwrap(); let union2 = builder.build().unwrap(); - builder = UnionBuilder::new_sparse(7); + builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); @@ -1440,7 +1447,7 @@ mod tests { builder.append::("b", 7).unwrap(); let union3 = builder.build().unwrap(); - builder = UnionBuilder::new_sparse(7); + builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs index fed3933a0893..449055d366ec 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow/src/array/equal/utils.rs @@ -15,9 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::count_nulls, ArrayData}; +use crate::array::data::contains_nulls; +use crate::array::ArrayData; use crate::datatypes::DataType; -use crate::util::bit_util; +use crate::util::bit_chunk_iterator::BitChunks; // whether bits along the positions are equal // `lhs_start`, `rhs_start` and `len` are _measured in bits_. 
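The hunks below replace the per-bit loop in `equal_bits` with chunked comparison via `BitChunks`, and rewrite `equal_nulls` so that a side without a null buffer is treated as having no nulls (using the new `contains_nulls`) rather than requiring both sides to carry one. A minimal standalone sketch of the chunked idea, simplified to buffers that start at bit 0 and using bytes instead of u64 words; `bits_equal` is an illustrative name, not the arrow code.

```rust
// Compare `len_bits` packed bits of two buffers that both start at bit 0:
// whole bytes first, then the trailing partial byte under a mask.
fn bits_equal(lhs: &[u8], rhs: &[u8], len_bits: usize) -> bool {
    let full_bytes = len_bits / 8;
    let rem_bits = len_bits % 8;

    // Whole-byte comparison (stand-in for the u64 chunks used by `BitChunks`).
    if lhs[..full_bytes] != rhs[..full_bytes] {
        return false;
    }

    // Remaining bits, masked so bits outside the range are ignored
    // (the `remainder_bits` step in the real implementation).
    if rem_bits == 0 {
        return true;
    }
    let mask: u8 = (1 << rem_bits) - 1;
    (lhs[full_bytes] & mask) == (rhs[full_bytes] & mask)
}

fn main() {
    // The low 12 bits agree; the high bits of the second byte differ but fall
    // outside the compared range.
    let a: &[u8] = &[0b1100_1010, 0b1111_0101];
    let b: &[u8] = &[0b1100_1010, 0b0000_0101];
    assert!(bits_equal(a, b, 12));
    assert!(!bits_equal(a, b, 16));
}
```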
@@ -29,10 +30,16 @@ pub(super) fn equal_bits( rhs_start: usize, len: usize, ) -> bool { - (0..len).all(|i| { - bit_util::get_bit(lhs_values, lhs_start + i) - == bit_util::get_bit(rhs_values, rhs_start + i) - }) + let lhs = BitChunks::new(lhs_values, lhs_start, len); + let rhs = BitChunks::new(rhs_values, rhs_start, len); + + for (a, b) in lhs.iter().zip(rhs.iter()) { + if a != b { + return false; + } + } + + lhs.remainder_bits() == rhs.remainder_bits() } #[inline] @@ -43,25 +50,16 @@ pub(super) fn equal_nulls( rhs_start: usize, len: usize, ) -> bool { - let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); - let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + let lhs_offset = lhs_start + lhs.offset(); + let rhs_offset = rhs_start + rhs.offset(); - if lhs_null_count != rhs_null_count { - return false; - } - - if lhs_null_count > 0 || rhs_null_count > 0 { - let lhs_values = lhs.null_buffer().unwrap().as_slice(); - let rhs_values = rhs.null_buffer().unwrap().as_slice(); - equal_bits( - lhs_values, - rhs_values, - lhs_start + lhs.offset(), - rhs_start + rhs.offset(), - len, - ) - } else { - true + match (lhs.null_buffer(), rhs.null_buffer()) { + (Some(lhs), Some(rhs)) => { + equal_bits(lhs.as_slice(), rhs.as_slice(), lhs_offset, rhs_offset, len) + } + (Some(lhs), None) => !contains_nulls(Some(lhs), lhs_offset, len), + (None, Some(rhs)) => !contains_nulls(Some(rhs), rhs_offset, len), + (None, None) => true, } } diff --git a/arrow/src/array/equal_json.rs b/arrow/src/array/equal_json.rs deleted file mode 100644 index 3fc84a7e3ab4..000000000000 --- a/arrow/src/array/equal_json.rs +++ /dev/null @@ -1,1160 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use super::*; -use crate::array::BasicDecimalArray; -use crate::datatypes::*; -use crate::util::decimal::BasicDecimal; -use array::Array; -use hex::FromHex; -use serde_json::value::Value::{Null as JNull, Object, String as JString}; -use serde_json::Value; - -/// Trait for comparing arrow array with json array -pub trait JsonEqual { - /// Checks whether arrow array equals to json array. - fn equals_json(&self, json: &[&Value]) -> bool; - - /// Checks whether arrow array equals to json array. 
- fn equals_json_values(&self, json: &[Value]) -> bool { - let refs = json.iter().collect::>(); - - self.equals_json(&refs) - } -} - -/// Implement array equals for numeric type -impl JsonEqual for PrimitiveArray { - fn equals_json(&self, json: &[&Value]) -> bool { - self.len() == json.len() - && (0..self.len()).all(|i| match json[i] { - Value::Null => self.is_null(i), - v => { - self.is_valid(i) - && Some(v) == self.value(i).into_json_value().as_ref() - } - }) - } -} - -/// Implement array equals for numeric type -impl JsonEqual for BooleanArray { - fn equals_json(&self, json: &[&Value]) -> bool { - self.len() == json.len() - && (0..self.len()).all(|i| match json[i] { - Value::Null => self.is_null(i), - v => { - self.is_valid(i) - && Some(v) == self.value(i).into_json_value().as_ref() - } - }) - } -} - -impl PartialEq for PrimitiveArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(array) => self.equals_json_values(array), - _ => false, - } - } -} - -impl PartialEq> for Value { - fn eq(&self, arrow: &PrimitiveArray) -> bool { - match self { - Value::Array(array) => arrow.equals_json_values(array), - _ => false, - } - } -} - -impl JsonEqual for GenericListArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - Value::Array(v) => self.is_valid(i) && self.value(i).equals_json_values(v), - Value::Null => self.is_null(i) || self.value_length(i).is_zero(), - _ => false, - }) - } -} - -impl PartialEq for GenericListArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq> for Value { - fn eq(&self, arrow: &GenericListArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for DictionaryArray { - fn equals_json(&self, json: &[&Value]) -> bool { - // todo: this is wrong: we must test the values also - self.keys().equals_json(json) - } -} - -impl PartialEq for DictionaryArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq> for Value { - fn eq(&self, arrow: &DictionaryArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for FixedSizeListArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - Value::Array(v) => self.is_valid(i) && self.value(i).equals_json_values(v), - Value::Null => self.is_null(i) || self.value_length() == 0, - _ => false, - }) - } -} - -impl PartialEq for FixedSizeListArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &FixedSizeListArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for StructArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - let all_object = json.iter().all(|v| matches!(v, Object(_) | JNull)); - - if !all_object { - return false; - } - - for column_name in self.column_names() { - let json_values = json - .iter() - .map(|obj| 
obj.get(column_name).unwrap_or(&Value::Null)) - .collect::>(); - - if !self - .column_by_name(column_name) - .map(|arr| arr.equals_json(&json_values)) - .unwrap_or(false) - { - return false; - } - } - - true - } -} - -impl PartialEq for StructArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &StructArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for MapArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - Value::Array(v) => self.is_valid(i) && self.value(i).equals_json_values(v), - Value::Null => self.is_null(i) || self.value_length(i).eq(&0), - _ => false, - }) - } -} - -impl PartialEq for MapArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &MapArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for GenericBinaryArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - JString(s) => { - // binary data is sometimes hex encoded, this checks if bytes are equal, - // and if not converting to hex is attempted - self.is_valid(i) - && (s.as_str().as_bytes() == self.value(i) - || Vec::from_hex(s.as_str()) == Ok(self.value(i).to_vec())) - } - JNull => self.is_null(i), - _ => false, - }) - } -} - -impl PartialEq for GenericBinaryArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq> for Value { - fn eq(&self, arrow: &GenericBinaryArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for GenericStringArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - JString(s) => self.is_valid(i) && s.as_str() == self.value(i), - JNull => self.is_null(i), - _ => false, - }) - } -} - -impl PartialEq for GenericStringArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq> for Value { - fn eq(&self, arrow: &GenericStringArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for FixedSizeBinaryArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - JString(s) => { - // binary data is sometimes hex encoded, this checks if bytes are equal, - // and if not converting to hex is attempted - self.is_valid(i) - && (s.as_str().as_bytes() == self.value(i) - || Vec::from_hex(s.as_str()) == Ok(self.value(i).to_vec())) - } - JNull => self.is_null(i), - _ => false, - }) - } -} - -impl PartialEq for FixedSizeBinaryArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => 
self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &FixedSizeBinaryArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for DecimalArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - JString(s) => { - self.is_valid(i) - && (s - .parse::() - .map_or_else(|_| false, |v| v == self.value(i).as_i128())) - } - JNull => self.is_null(i), - _ => false, - }) - } -} - -impl JsonEqual for Decimal256Array { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - (0..self.len()).all(|i| match json[i] { - JString(s) => self.is_valid(i) && (s == &self.value(i).to_string()), - JNull => self.is_null(i), - _ => false, - }) - } -} - -impl PartialEq for DecimalArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &DecimalArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for UnionArray { - fn equals_json(&self, _json: &[&Value]) -> bool { - unimplemented!( - "Added to allow UnionArray to implement the Array trait: see ARROW-8547" - ) - } -} - -impl JsonEqual for NullArray { - fn equals_json(&self, json: &[&Value]) -> bool { - if self.len() != json.len() { - return false; - } - - // all JSON values must be nulls - json.iter().all(|&v| v == &JNull) - } -} - -impl PartialEq for Value { - fn eq(&self, arrow: &NullArray) -> bool { - match self { - Value::Array(json_array) => arrow.equals_json_values(json_array), - _ => false, - } - } -} - -impl PartialEq for NullArray { - fn eq(&self, json: &Value) -> bool { - match json { - Value::Array(json_array) => self.equals_json_values(json_array), - _ => false, - } - } -} - -impl JsonEqual for ArrayRef { - fn equals_json(&self, json: &[&Value]) -> bool { - self.as_ref().equals_json(json) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::error::Result; - use std::{convert::TryFrom, sync::Arc}; - - fn create_list_array, T: AsRef<[Option]>>( - builder: &mut ListBuilder, - data: T, - ) -> Result { - for d in data.as_ref() { - if let Some(v) = d { - builder.values().append_slice(v.as_ref())?; - builder.append(true)? - } else { - builder.append(false)? - } - } - Ok(builder.finish()) - } - - /// Create a fixed size list of 2 value lengths - fn create_fixed_size_list_array, T: AsRef<[Option]>>( - builder: &mut FixedSizeListBuilder, - data: T, - ) -> Result { - for d in data.as_ref() { - if let Some(v) = d { - builder.values().append_slice(v.as_ref())?; - builder.append(true)? - } else { - for _ in 0..builder.value_length() { - builder.values().append_null()?; - } - builder.append(false)? 
- } - } - Ok(builder.finish()) - } - - #[test] - fn test_primitive_json_equal() { - // Test equaled array - let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let json_array: Value = serde_json::from_str( - r#" - [ - 1, null, 2, 3 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequaled array - let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let json_array: Value = serde_json::from_str( - r#" - [ - 1, 1, 2, 3 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let json_array: Value = serde_json::from_str( - r#" - [ - 1, 1 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test not json array type case - let arrow_array = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_list_json_equal() { - // Test equal case - let arrow_array = create_list_array( - &mut ListBuilder::new(Int32Builder::new(10)), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - [1, 2, 3], - null, - [4, 5, 6] - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = create_list_array( - &mut ListBuilder::new(Int32Builder::new(10)), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - [1, 2, 3], - [7, 8], - [4, 5, 6] - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let arrow_array = create_list_array( - &mut ListBuilder::new(Int32Builder::new(10)), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_fixed_size_list_json_equal() { - // Test equal case - let arrow_array = create_fixed_size_list_array( - &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - [1, 2, 3], - null, - [4, 5, 6] - ] - "#, - ) - .unwrap(); - println!("{:?}", arrow_array); - println!("{:?}", json_array); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = create_fixed_size_list_array( - &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - [1, 2, 3], - [7, 8, 9], - [4, 5, 6] - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let arrow_array = create_fixed_size_list_array( - &mut FixedSizeListBuilder::new(Int32Builder::new(10), 3), - &[Some(&[1, 2, 3]), None, Some(&[4, 5, 6])], - ) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 
- } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_string_json_equal() { - // Test the equal case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "world", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "arrow", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "arrow", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect value type case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - 1, - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_binary_json_equal() { - // Test the equal case - let mut builder = BinaryBuilder::new(6); - builder.append_value(b"hello").unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); - builder.append_value(b"world").unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); - let arrow_array = builder.finish(); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "world", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None, None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "arrow", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "arrow", - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect value type case - let arrow_array = - StringArray::from(vec![Some("hello"), None, None, Some("world"), None]); - 
let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - 1, - null, - null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_fixed_size_binary_json_equal() { - // Test the equal case - let mut builder = FixedSizeBinaryBuilder::new(15, 5); - builder.append_value(b"hello").unwrap(); - builder.append_null().unwrap(); - builder.append_value(b"world").unwrap(); - let arrow_array: FixedSizeBinaryArray = builder.finish(); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - "world" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - builder.append_value(b"hello").unwrap(); - builder.append_null().unwrap(); - builder.append_value(b"world").unwrap(); - let arrow_array: FixedSizeBinaryArray = builder.finish(); - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - "arrow" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - null, - "world" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect value type case - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - 1 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_decimal_json_equal() { - // Test the equal case - let arrow_array = [Some(1_000), None, Some(-250)] - .iter() - .collect::() - .with_precision_and_scale(23, 6) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - "1000", - null, - "-250" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal case - let arrow_array = [Some(1_000), None, Some(55)] - .iter() - .collect::() - .with_precision_and_scale(23, 6) - .unwrap(); - let json_array: Value = serde_json::from_str( - r#" - [ - "1000", - null, - "-250" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test unequal length case - let json_array: Value = serde_json::from_str( - r#" - [ - "1000", - null, - null, - "55" - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let json_array: Value = serde_json::from_str( - r#" - { - "a": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect value type case - let json_array: Value = serde_json::from_str( - r#" - [ - "hello", - null, - 1 - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_struct_json_equal() { - let strings: ArrayRef = Arc::new(StringArray::from(vec![ - Some("joe"), - None, - None, - Some("mark"), - Some("doe"), - ])); - let ints: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(2), - None, - Some(4), - Some(5), - ])); - - let arrow_array = - StructArray::try_from(vec![("f1", strings.clone()), ("f2", 
ints.clone())]) - .unwrap(); - - let json_array: Value = serde_json::from_str( - r#" - [ - { - "f1": "joe", - "f2": 1 - }, - { - "f2": 2 - }, - null, - { - "f1": "mark", - "f2": 4 - }, - { - "f1": "doe", - "f2": 5 - } - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequal length case - let json_array: Value = serde_json::from_str( - r#" - [ - { - "f1": "joe", - "f2": 1 - }, - { - "f2": 2 - }, - null, - { - "f1": "mark", - "f2": 4 - } - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test incorrect type case - let json_array: Value = serde_json::from_str( - r#" - { - "f1": "joe", - "f2": 1 - } - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - - // Test not all object case - let json_array: Value = serde_json::from_str( - r#" - [ - { - "f1": "joe", - "f2": 1 - }, - 2, - null, - { - "f1": "mark", - "f2": 4 - } - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } - - #[test] - fn test_null_json_equal() { - // Test equaled array - let arrow_array = NullArray::new(4); - let json_array: Value = serde_json::from_str( - r#" - [ - null, null, null, null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.eq(&json_array)); - assert!(json_array.eq(&arrow_array)); - - // Test unequaled array - let arrow_array = NullArray::new(2); - let json_array: Value = serde_json::from_str( - r#" - [ - null, null, null - ] - "#, - ) - .unwrap(); - assert!(arrow_array.ne(&json_array)); - assert!(json_array.ne(&arrow_array)); - } -} diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs index 12d6f440b78d..72030f900a4e 100644 --- a/arrow/src/array/ffi.rs +++ b/arrow/src/array/ffi.rs @@ -25,7 +25,7 @@ use crate::{ ffi::ArrowArrayRef, }; -use super::ArrayData; +use super::{make_array, ArrayData, ArrayRef}; impl TryFrom for ArrayData { type Error = ArrowError; @@ -39,10 +39,46 @@ impl TryFrom for ffi::ArrowArray { type Error = ArrowError; fn try_from(value: ArrayData) -> Result { - unsafe { ffi::ArrowArray::try_new(value) } + ffi::ArrowArray::try_new(value) } } +/// Creates a new array from two FFI pointers. Used to import arrays from the C Data Interface +/// # Safety +/// Assumes that these pointers represent valid C Data Interfaces, both in memory +/// representation and lifetime via the `release` mechanism. +pub unsafe fn make_array_from_raw( + array: *const ffi::FFI_ArrowArray, + schema: *const ffi::FFI_ArrowSchema, +) -> Result { + let array = ffi::ArrowArray::try_from_raw(array, schema)?; + let data = ArrayData::try_from(array)?; + Ok(make_array(data)) +} + +/// Exports an array to raw pointers of the C Data Interface provided by the consumer. +/// # Safety +/// Assumes that these pointers represent valid C Data Interfaces, both in memory +/// representation and lifetime via the `release` mechanism. +/// +/// This function copies the content of two FFI structs [ffi::FFI_ArrowArray] and +/// [ffi::FFI_ArrowSchema] in the array to the location pointed by the raw pointers. +/// Usually the raw pointers are provided by the array data consumer. 
+pub unsafe fn export_array_into_raw( + src: ArrayRef, + out_array: *mut ffi::FFI_ArrowArray, + out_schema: *mut ffi::FFI_ArrowSchema, +) -> Result<()> { + let data = src.data(); + let array = ffi::FFI_ArrowArray::new(data); + let schema = ffi::FFI_ArrowSchema::try_from(data.data_type())?; + + std::ptr::write_unaligned(out_array, array); + std::ptr::write_unaligned(out_schema, schema); + + Ok(()) +} + #[cfg(test)] mod tests { use crate::array::{DictionaryArray, FixedSizeListArray, Int32Array, StringArray}; diff --git a/arrow/src/array/iterator.rs b/arrow/src/array/iterator.rs index 9ac2d0642d44..4269e99625b7 100644 --- a/arrow/src/array/iterator.rs +++ b/arrow/src/array/iterator.rs @@ -15,36 +15,38 @@ // specific language governing permissions and limitations // under the License. -use crate::array::BasicDecimalArray; -use crate::datatypes::ArrowPrimitiveType; +use crate::array::array::ArrayAccessor; +use crate::array::{DecimalArray, FixedSizeBinaryArray}; +use crate::datatypes::{Decimal128Type, Decimal256Type}; use super::{ - Array, ArrayRef, BooleanArray, DecimalArray, GenericBinaryArray, GenericListArray, - GenericStringArray, OffsetSizeTrait, PrimitiveArray, + BooleanArray, GenericBinaryArray, GenericListArray, GenericStringArray, + PrimitiveArray, }; -/// an iterator that returns Some(T) or None, that can be used on any PrimitiveArray +/// an iterator that returns Some(T) or None, that can be used on any [`ArrayAccessor`] // Note: This implementation is based on std's [Vec]s' [IntoIter]. #[derive(Debug)] -pub struct PrimitiveIter<'a, T: ArrowPrimitiveType> { - array: &'a PrimitiveArray, +pub struct ArrayIter { + array: T, current: usize, current_end: usize, } -impl<'a, T: ArrowPrimitiveType> PrimitiveIter<'a, T> { +impl ArrayIter { /// create a new iterator - pub fn new(array: &'a PrimitiveArray) -> Self { - PrimitiveIter:: { + pub fn new(array: T) -> Self { + let len = array.len(); + ArrayIter { array, current: 0, - current_end: array.len(), + current_end: len, } } } -impl<'a, T: ArrowPrimitiveType> std::iter::Iterator for PrimitiveIter<'a, T> { - type Item = Option; +impl Iterator for ArrayIter { + type Item = Option; #[inline] fn next(&mut self) -> Option { @@ -73,301 +75,7 @@ impl<'a, T: ArrowPrimitiveType> std::iter::Iterator for PrimitiveIter<'a, T> { } } -impl<'a, T: ArrowPrimitiveType> std::iter::DoubleEndedIterator for PrimitiveIter<'a, T> { - fn next_back(&mut self) -> Option { - if self.current_end == self.current { - None - } else { - self.current_end -= 1; - Some(if self.array.is_null(self.current_end) { - None - } else { - // Safety: - // we just checked bounds in `self.current_end == self.current` - // this is safe on the premise that this struct is initialized with - // current = array.len() - // and that current_end is ever only decremented - unsafe { Some(self.array.value_unchecked(self.current_end)) } - }) - } - } -} - -/// all arrays have known size. -impl<'a, T: ArrowPrimitiveType> std::iter::ExactSizeIterator for PrimitiveIter<'a, T> {} - -/// an iterator that returns Some(bool) or None. -// Note: This implementation is based on std's [Vec]s' [IntoIter]. 
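The two FFI helpers introduced above are meant to be used as a pair: a producer fills caller-provided structs with `export_array_into_raw`, and a consumer rebuilds a dynamically typed `ArrayRef` from those pointers with `make_array_from_raw`. The following is only a rough round-trip sketch (not part of the diff), assuming the `ffi` feature is enabled and that the `empty()` constructors on `FFI_ArrowArray` / `FFI_ArrowSchema` are available for allocating the destination structs:

```rust
use std::sync::Arc;
use arrow::array::{export_array_into_raw, make_array_from_raw, ArrayRef, Int32Array};
use arrow::error::Result;
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};

fn ffi_round_trip() -> Result<ArrayRef> {
    let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));

    // Destination structs that a consumer would normally hand to the producer.
    let mut out_array = FFI_ArrowArray::empty();
    let mut out_schema = FFI_ArrowSchema::empty();

    unsafe {
        // Producer side: write the array and its schema into the provided structs.
        export_array_into_raw(array, &mut out_array, &mut out_schema)?;
        // Consumer side: rebuild an ArrayRef from the same pointers.
        make_array_from_raw(&out_array, &out_schema)
    }
}
```

In a real C Data Interface exchange the structs live on the consumer side and ownership is handed over via the `release` callback; the sketch only illustrates the call shapes of the two helpers.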
-#[derive(Debug)] -pub struct BooleanIter<'a> { - array: &'a BooleanArray, - current: usize, - current_end: usize, -} - -impl<'a> BooleanIter<'a> { - /// create a new iterator - pub fn new(array: &'a BooleanArray) -> Self { - BooleanIter { - array, - current: 0, - current_end: array.len(), - } - } -} - -impl<'a> std::iter::Iterator for BooleanIter<'a> { - type Item = Option; - - fn next(&mut self) -> Option { - if self.current == self.current_end { - None - } else if self.array.is_null(self.current) { - self.current += 1; - Some(None) - } else { - let old = self.current; - self.current += 1; - // Safety: - // we just checked bounds in `self.current_end == self.current` - // this is safe on the premise that this struct is initialized with - // current = array.len() - // and that current_end is ever only decremented - unsafe { Some(Some(self.array.value_unchecked(old))) } - } - } - - fn size_hint(&self) -> (usize, Option) { - ( - self.array.len() - self.current, - Some(self.array.len() - self.current), - ) - } -} - -impl<'a> std::iter::DoubleEndedIterator for BooleanIter<'a> { - fn next_back(&mut self) -> Option { - if self.current_end == self.current { - None - } else { - self.current_end -= 1; - Some(if self.array.is_null(self.current_end) { - None - } else { - // Safety: - // we just checked bounds in `self.current_end == self.current` - // this is safe on the premise that this struct is initialized with - // current = array.len() - // and that current_end is ever only decremented - unsafe { Some(self.array.value_unchecked(self.current_end)) } - }) - } - } -} - -/// all arrays have known size. -impl<'a> std::iter::ExactSizeIterator for BooleanIter<'a> {} - -/// an iterator that returns `Some(&str)` or `None`, for string arrays -#[derive(Debug)] -pub struct GenericStringIter<'a, T> -where - T: OffsetSizeTrait, -{ - array: &'a GenericStringArray, - current: usize, - current_end: usize, -} - -impl<'a, T: OffsetSizeTrait> GenericStringIter<'a, T> { - /// create a new iterator - pub fn new(array: &'a GenericStringArray) -> Self { - GenericStringIter:: { - array, - current: 0, - current_end: array.len(), - } - } -} - -impl<'a, T: OffsetSizeTrait> std::iter::Iterator for GenericStringIter<'a, T> { - type Item = Option<&'a str>; - - fn next(&mut self) -> Option { - let i = self.current; - if i >= self.current_end { - None - } else if self.array.is_null(i) { - self.current += 1; - Some(None) - } else { - self.current += 1; - // Safety: - // we just checked bounds in `self.current_end == self.current` - // this is safe on the premise that this struct is initialized with - // current = array.len() - // and that current_end is ever only decremented - unsafe { Some(Some(self.array.value_unchecked(i))) } - } - } - - fn size_hint(&self) -> (usize, Option) { - ( - self.current_end - self.current, - Some(self.current_end - self.current), - ) - } -} - -impl<'a, T: OffsetSizeTrait> std::iter::DoubleEndedIterator for GenericStringIter<'a, T> { - fn next_back(&mut self) -> Option { - if self.current_end == self.current { - None - } else { - self.current_end -= 1; - Some(if self.array.is_null(self.current_end) { - None - } else { - // Safety: - // we just checked bounds in `self.current_end == self.current` - // this is safe on the premise that this struct is initialized with - // current = array.len() - // and that current_end is ever only decremented - unsafe { Some(self.array.value_unchecked(self.current_end)) } - }) - } - } -} - -/// all arrays have known size. 
-impl<'a, T: OffsetSizeTrait> std::iter::ExactSizeIterator for GenericStringIter<'a, T> {} - -/// an iterator that returns `Some(&[u8])` or `None`, for binary arrays -#[derive(Debug)] -pub struct GenericBinaryIter<'a, T> -where - T: OffsetSizeTrait, -{ - array: &'a GenericBinaryArray, - current: usize, - current_end: usize, -} - -impl<'a, T: OffsetSizeTrait> GenericBinaryIter<'a, T> { - /// create a new iterator - pub fn new(array: &'a GenericBinaryArray) -> Self { - GenericBinaryIter:: { - array, - current: 0, - current_end: array.len(), - } - } -} - -impl<'a, T: OffsetSizeTrait> std::iter::Iterator for GenericBinaryIter<'a, T> { - type Item = Option<&'a [u8]>; - - fn next(&mut self) -> Option { - let i = self.current; - if i >= self.current_end { - None - } else if self.array.is_null(i) { - self.current += 1; - Some(None) - } else { - self.current += 1; - // Safety: - // we just checked bounds in `self.current_end == self.current` - // this is safe on the premise that this struct is initialized with - // current = array.len() - // and that current_end is ever only decremented - unsafe { Some(Some(self.array.value_unchecked(i))) } - } - } - - fn size_hint(&self) -> (usize, Option) { - ( - self.current_end - self.current, - Some(self.current_end - self.current), - ) - } -} - -impl<'a, T: OffsetSizeTrait> std::iter::DoubleEndedIterator for GenericBinaryIter<'a, T> { - fn next_back(&mut self) -> Option { - if self.current_end == self.current { - None - } else { - self.current_end -= 1; - Some(if self.array.is_null(self.current_end) { - None - } else { - // Safety: - // we just checked bounds in `self.current_end == self.current` - // this is safe on the premise that this struct is initialized with - // current = array.len() - // and that current_end is ever only decremented - unsafe { Some(self.array.value_unchecked(self.current_end)) } - }) - } - } -} - -/// all arrays have known size. -impl<'a, T: OffsetSizeTrait> std::iter::ExactSizeIterator for GenericBinaryIter<'a, T> {} - -#[derive(Debug)] -pub struct GenericListArrayIter<'a, S> -where - S: OffsetSizeTrait, -{ - array: &'a GenericListArray, - current: usize, - current_end: usize, -} - -impl<'a, S: OffsetSizeTrait> GenericListArrayIter<'a, S> { - pub fn new(array: &'a GenericListArray) -> Self { - GenericListArrayIter:: { - array, - current: 0, - current_end: array.len(), - } - } -} - -impl<'a, S: OffsetSizeTrait> std::iter::Iterator for GenericListArrayIter<'a, S> { - type Item = Option; - - fn next(&mut self) -> Option { - let i = self.current; - if i >= self.current_end { - None - } else if self.array.is_null(i) { - self.current += 1; - Some(None) - } else { - self.current += 1; - // Safety: - // we just checked bounds in `self.current_end == self.current` - // this is safe on the premise that this struct is initialized with - // current = array.len() - // and that current_end is ever only decremented - unsafe { Some(Some(self.array.value_unchecked(i))) } - } - } - - fn size_hint(&self) -> (usize, Option) { - ( - self.current_end - self.current, - Some(self.current_end - self.current), - ) - } -} - -impl<'a, S: OffsetSizeTrait> std::iter::DoubleEndedIterator - for GenericListArrayIter<'a, S> -{ +impl DoubleEndedIterator for ArrayIter { fn next_back(&mut self) -> Option { if self.current_end == self.current { None @@ -388,58 +96,24 @@ impl<'a, S: OffsetSizeTrait> std::iter::DoubleEndedIterator } /// all arrays have known size. 
-impl<'a, S: OffsetSizeTrait> std::iter::ExactSizeIterator - for GenericListArrayIter<'a, S> -{ -} - -/// an iterator that returns `Some(i128)` or `None`, that can be used on a -/// [`DecimalArray`] -#[derive(Debug)] -pub struct DecimalIter<'a> { - array: &'a DecimalArray, - current: usize, - current_end: usize, -} +impl ExactSizeIterator for ArrayIter {} -impl<'a> DecimalIter<'a> { - pub fn new(array: &'a DecimalArray) -> Self { - Self { - array, - current: 0, - current_end: array.len(), - } - } -} - -impl<'a> std::iter::Iterator for DecimalIter<'a> { - type Item = Option; - - fn next(&mut self) -> Option { - if self.current == self.current_end { - None - } else { - let old = self.current; - self.current += 1; - // TODO: Improve performance by avoiding bounds check here - // (by using adding a `value_unchecked, for example) - if self.array.is_null(old) { - Some(None) - } else { - Some(Some(self.array.value(old).as_i128())) - } - } - } - - #[inline] - fn size_hint(&self) -> (usize, Option) { - let remain = self.current_end - self.current; - (remain, Some(remain)) - } -} - -/// iterator has known size. -impl<'a> std::iter::ExactSizeIterator for DecimalIter<'a> {} +/// an iterator that returns Some(T) or None, that can be used on any PrimitiveArray +pub type PrimitiveIter<'a, T> = ArrayIter<&'a PrimitiveArray>; +pub type BooleanIter<'a> = ArrayIter<&'a BooleanArray>; +pub type GenericStringIter<'a, T> = ArrayIter<&'a GenericStringArray>; +pub type GenericBinaryIter<'a, T> = ArrayIter<&'a GenericBinaryArray>; +pub type FixedSizeBinaryIter<'a> = ArrayIter<&'a FixedSizeBinaryArray>; +pub type GenericListArrayIter<'a, O> = ArrayIter<&'a GenericListArray>; + +pub type DecimalIter<'a, T> = ArrayIter<&'a DecimalArray>; +/// an iterator that returns `Some(Decimal128)` or `None`, that can be used on a +/// [`super::Decimal128Array`] +pub type Decimal128Iter<'a> = DecimalIter<'a, Decimal128Type>; + +/// an iterator that returns `Some(Decimal256)` or `None`, that can be used on a +/// [`super::Decimal256Array`] +pub type Decimal256Iter<'a> = DecimalIter<'a, Decimal256Type>; #[cfg(test)] mod tests { diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index 2f025f11c0f1..6ad2c26fee5d 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -24,9 +24,11 @@ //! Arrays are often passed around as a dynamically typed [`&dyn Array`] or [`ArrayRef`]. //! For example, [`RecordBatch`](`crate::record_batch::RecordBatch`) stores columns as [`ArrayRef`]. //! -//! Whilst these arrays can be passed directly to the [`compute`](crate::compute), -//! [`csv`](crate::csv), [`json`](crate::json), etc... APIs, it is often the case that you wish -//! to interact with the data directly. This requires downcasting to the concrete type of the array: +//! Whilst these arrays can be passed directly to the +//! [`compute`](crate::compute), [`csv`](crate::csv), +//! [`json`](crate::json), etc... APIs, it is often the case that you +//! wish to interact with the data directly. This requires downcasting +//! to the concrete type of the array: //! //! ``` //! # use arrow::array::{Array, Float32Array, Int32Array}; @@ -42,6 +44,19 @@ //! } //! ``` //! +//! Additionally, there are convenient functions to do this casting +//! such as [`as_primitive_array`] and [`as_string_array`]: +//! +//! ``` +//! # use arrow::array::*; +//! # use arrow::datatypes::*; +//! # +//! fn as_f32_slice(array: &dyn Array) -> &[f32] { +//! // use as_primtive_array +//! as_primitive_array::(array).values() +//! } +//! ``` + //! 
# Building an Array //! //! Most [`Array`] implementations can be constructed directly from iterators or [`Vec`] @@ -79,13 +94,13 @@ //! let mut builder = Int16Array::builder(100); //! //! // Append a single primitive value -//! builder.append_value(1).unwrap(); +//! builder.append_value(1); //! //! // Append a null value -//! builder.append_null().unwrap(); +//! builder.append_null(); //! //! // Append a slice of primitive values -//! builder.append_slice(&[2, 3, 4]).unwrap(); +//! builder.append_slice(&[2, 3, 4]); //! //! // Build the array //! let array = builder.finish(); @@ -148,6 +163,8 @@ mod array_binary; mod array_boolean; mod array_decimal; mod array_dictionary; +mod array_fixed_size_binary; +mod array_fixed_size_list; mod array_list; mod array_map; mod array_primitive; @@ -158,7 +175,7 @@ mod builder; mod cast; mod data; mod equal; -mod equal_json; +#[cfg(feature = "ffi")] mod ffi; mod iterator; mod null; @@ -171,22 +188,25 @@ use crate::datatypes::*; // --------------------- Array & ArrayData --------------------- pub use self::array::Array; +pub use self::array::ArrayAccessor; pub use self::array::ArrayRef; -pub(crate) use self::data::layout; pub use self::data::ArrayData; pub use self::data::ArrayDataBuilder; pub use self::data::ArrayDataRef; -pub(crate) use self::data::BufferSpec; + +#[cfg(feature = "ipc")] +pub(crate) use self::data::{layout, BufferSpec}; pub use self::array_binary::BinaryArray; -pub use self::array_binary::FixedSizeBinaryArray; pub use self::array_binary::LargeBinaryArray; pub use self::array_boolean::BooleanArray; -pub use self::array_decimal::BasicDecimalArray; +pub use self::array_decimal::Decimal128Array; pub use self::array_decimal::Decimal256Array; pub use self::array_decimal::DecimalArray; -pub use self::array_dictionary::DictionaryArray; -pub use self::array_list::FixedSizeListArray; +pub use self::array_fixed_size_binary::FixedSizeBinaryArray; +pub use self::array_fixed_size_list::FixedSizeListArray; + +pub use self::array_dictionary::{DictionaryArray, TypedDictionaryArray}; pub use self::array_list::LargeListArray; pub use self::array_list::ListArray; pub use self::array_map::MapArray; @@ -471,8 +491,12 @@ pub use self::builder::BinaryBuilder; pub use self::builder::BooleanBufferBuilder; pub use self::builder::BooleanBuilder; pub use self::builder::BufferBuilder; +pub use self::builder::Decimal128Builder; pub use self::builder::Decimal256Builder; -pub use self::builder::DecimalBuilder; + +#[deprecated(note = "Please use `Decimal128Builder` instead")] +pub type DecimalBuilder = Decimal128Builder; + pub use self::builder::FixedSizeBinaryBuilder; pub use self::builder::FixedSizeListBuilder; pub use self::builder::GenericListBuilder; @@ -570,10 +594,6 @@ pub use self::transform::{Capacities, MutableArrayData}; pub use self::iterator::*; -// --------------------- Array Equality --------------------- - -pub use self::equal_json::JsonEqual; - // --------------------- Array's values comparison --------------------- pub use self::ord::{build_compare, DynComparator}; @@ -589,7 +609,8 @@ pub use self::cast::{ // ------------------------------ C Data Interface --------------------------- -pub use self::array::{export_array_into_raw, make_array_from_raw}; +#[cfg(feature = "ffi")] +pub use self::ffi::{export_array_into_raw, make_array_from_raw}; #[cfg(test)] mod tests { diff --git a/arrow/src/array/ord.rs b/arrow/src/array/ord.rs index 019b1163b50a..dd6539589c13 100644 --- a/arrow/src/array/ord.rs +++ b/arrow/src/array/ord.rs @@ -19,7 +19,6 @@ use 
std::cmp::Ordering; -use crate::array::BasicDecimalArray; use crate::array::*; use crate::datatypes::TimeUnit; use crate::datatypes::*; @@ -226,9 +225,9 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { - let left: DecimalArray = DecimalArray::from(left.data().clone()); - let right: DecimalArray = DecimalArray::from(right.data().clone()); + (Decimal128(_, _), Decimal128(_, _)) => { + let left: Decimal128Array = Decimal128Array::from(left.data().clone()); + let right: Decimal128Array = Decimal128Array::from(right.data().clone()); Box::new(move |i, j| left.value(i).cmp(&right.value(j))) } (lhs, _) => { @@ -301,9 +300,9 @@ pub mod tests { #[test] fn test_decimal() -> Result<()> { - let array = vec![Some(5), Some(2), Some(3)] - .iter() - .collect::() + let array = vec![Some(5_i128), Some(2_i128), Some(3_i128)] + .into_iter() + .collect::() .with_precision_and_scale(23, 6) .unwrap(); diff --git a/arrow/src/array/transform/fixed_binary.rs b/arrow/src/array/transform/fixed_binary.rs index 36952d46a4d6..6d6262ca3c4e 100644 --- a/arrow/src/array/transform/fixed_binary.rs +++ b/arrow/src/array/transform/fixed_binary.rs @@ -22,6 +22,7 @@ use super::{Extend, _MutableArrayData}; pub(super) fn build_extend(array: &ArrayData) -> Extend { let size = match array.data_type() { DataType::FixedSizeBinary(i) => *i as usize, + DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; @@ -57,6 +58,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { let size = match mutable.data_type { DataType::FixedSizeBinary(i) => i as usize, + DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index a103c35e5067..48859922a26e 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -205,7 +205,7 @@ fn build_extend_dictionary( fn build_extend(array: &ArrayData) -> Extend { use crate::datatypes::*; match array.data_type() { - DataType::Decimal(_, _) => primitive::build_extend::(array), + DataType::Decimal128(_, _) => primitive::build_extend::(array), DataType::Null => null::build_extend(array), DataType::Boolean => boolean::build_extend(array), DataType::UInt8 => primitive::build_extend::(array), @@ -241,7 +241,9 @@ fn build_extend(array: &ArrayData) -> Extend { DataType::LargeList(_) => list::build_extend::(array), DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), - DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), + DataType::FixedSizeBinary(_) | DataType::Decimal256(_, _) => { + fixed_binary::build_extend(array) + } DataType::Float16 => primitive::build_extend::(array), DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array), DataType::Union(_, _, mode) => match mode { @@ -254,7 +256,7 @@ fn build_extend(array: &ArrayData) -> Extend { fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { use crate::datatypes::*; Box::new(match data_type { - DataType::Decimal(_, _) => primitive::extend_nulls::, + DataType::Decimal128(_, _) => primitive::extend_nulls::, DataType::Null => null::extend_nulls, DataType::Boolean => boolean::extend_nulls, DataType::UInt8 => primitive::extend_nulls::, @@ -292,7 +294,9 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { _ => unreachable!(), }, DataType::Struct(_) => structure::extend_nulls, - DataType::FixedSizeBinary(_) => 
fixed_binary::extend_nulls, + DataType::FixedSizeBinary(_) | DataType::Decimal256(_, _) => { + fixed_binary::extend_nulls + } DataType::Float16 => primitive::extend_nulls::, DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls, DataType::Union(_, _, mode) => match mode { @@ -309,11 +313,7 @@ fn preallocate_offset_and_binary_buffer( // offsets let mut buffer = MutableBuffer::new((1 + capacity) * mem::size_of::()); // safety: `unsafe` code assumes that this buffer is initialized with one element - if Offset::IS_LARGE { - buffer.push(0i64); - } else { - buffer.push(0i32) - } + buffer.push(Offset::zero()); [ buffer, @@ -406,7 +406,8 @@ impl<'a> MutableArrayData<'a> { }; let child_data = match &data_type { - DataType::Decimal(_, _) + DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) | DataType::Null | DataType::Boolean | DataType::UInt8 @@ -669,8 +670,7 @@ mod tests { use std::{convert::TryFrom, sync::Arc}; use super::*; - - use crate::array::DecimalArray; + use crate::array::Decimal128Array; use crate::{ array::{ Array, ArrayData, ArrayRef, BooleanArray, DictionaryArray, @@ -687,13 +687,13 @@ mod tests { }; fn create_decimal_array( - array: &[Option], - precision: usize, - scale: usize, - ) -> DecimalArray { + array: Vec>, + precision: u8, + scale: u8, + ) -> Decimal128Array { array - .iter() - .collect::() + .into_iter() + .collect::() .with_precision_and_scale(precision, scale) .unwrap() } @@ -702,28 +702,28 @@ mod tests { #[cfg(not(feature = "force_validate"))] fn test_decimal() { let decimal_array = - create_decimal_array(&[Some(1), Some(2), None, Some(3)], 10, 3); - let arrays = vec![decimal_array.data()]; + create_decimal_array(vec![Some(1), Some(2), None, Some(3)], 10, 3); + let arrays = vec![Array::data(&decimal_array)]; let mut a = MutableArrayData::new(arrays, true, 3); a.extend(0, 0, 3); a.extend(0, 2, 3); let result = a.freeze(); - let array = DecimalArray::from(result); - let expected = create_decimal_array(&[Some(1), Some(2), None, None], 10, 3); + let array = Decimal128Array::from(result); + let expected = create_decimal_array(vec![Some(1), Some(2), None, None], 10, 3); assert_eq!(array, expected); } #[test] #[cfg(not(feature = "force_validate"))] fn test_decimal_offset() { let decimal_array = - create_decimal_array(&[Some(1), Some(2), None, Some(3)], 10, 3); + create_decimal_array(vec![Some(1), Some(2), None, Some(3)], 10, 3); let decimal_array = decimal_array.slice(1, 3); // 2, null, 3 let arrays = vec![decimal_array.data()]; let mut a = MutableArrayData::new(arrays, true, 2); a.extend(0, 0, 2); // 2, null let result = a.freeze(); - let array = DecimalArray::from(result); - let expected = create_decimal_array(&[Some(2), None], 10, 3); + let array = Decimal128Array::from(result); + let expected = create_decimal_array(vec![Some(2), None], 10, 3); assert_eq!(array, expected); } @@ -731,7 +731,7 @@ mod tests { #[cfg(not(feature = "force_validate"))] fn test_decimal_null_offset_nulls() { let decimal_array = - create_decimal_array(&[Some(1), Some(2), None, Some(3)], 10, 3); + create_decimal_array(vec![Some(1), Some(2), None, Some(3)], 10, 3); let decimal_array = decimal_array.slice(1, 3); // 2, null, 3 let arrays = vec![decimal_array.data()]; let mut a = MutableArrayData::new(arrays, true, 2); @@ -739,9 +739,9 @@ mod tests { a.extend_nulls(3); // 2, null, null, null, null a.extend(0, 1, 3); //2, null, null, null, null, null, 3 let result = a.freeze(); - let array = DecimalArray::from(result); + let array = Decimal128Array::from(result); let expected = 
create_decimal_array( - &[Some(2), None, None, None, None, None, Some(3)], + vec![Some(2), None, None, None, None, None, Some(3)], 10, 3, ); @@ -806,15 +806,15 @@ mod tests { } #[test] - fn test_list_null_offset() -> Result<()> { - let int_builder = Int64Builder::new(24); + fn test_list_null_offset() { + let int_builder = Int64Builder::with_capacity(24); let mut builder = ListBuilder::::new(int_builder); - builder.values().append_slice(&[1, 2, 3])?; - builder.append(true)?; - builder.values().append_slice(&[4, 5])?; - builder.append(true)?; - builder.values().append_slice(&[6, 7, 8])?; - builder.append(true)?; + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5]); + builder.append(true); + builder.values().append_slice(&[6, 7, 8]); + builder.append(true); let array = builder.finish(); let arrays = vec![array.data()]; @@ -824,15 +824,13 @@ mod tests { let result = mutable.freeze(); let array = ListArray::from(result); - let int_builder = Int64Builder::new(24); + let int_builder = Int64Builder::with_capacity(24); let mut builder = ListBuilder::::new(int_builder); - builder.values().append_slice(&[1, 2, 3])?; - builder.append(true)?; + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); let expected = builder.finish(); assert_eq!(array, expected); - - Ok(()) } /// tests extending from a variable-sized (strings and binary) array w/ offset with nulls @@ -966,7 +964,7 @@ mod tests { fn create_dictionary_array(values: &[&str], keys: &[Option<&str>]) -> ArrayData { let values = StringArray::from(values.to_vec()); let mut builder = StringDictionaryBuilder::new_with_dictionary( - PrimitiveBuilder::::new(3), + PrimitiveBuilder::::with_capacity(3), &values, ) .unwrap(); @@ -974,7 +972,7 @@ mod tests { if let Some(v) = key { builder.append(v).unwrap(); } else { - builder.append_null().unwrap() + builder.append_null() } } builder.finish().into_data() @@ -1176,24 +1174,25 @@ mod tests { } #[test] - fn test_list_append() -> Result<()> { - let mut builder = ListBuilder::::new(Int64Builder::new(24)); - builder.values().append_slice(&[1, 2, 3])?; - builder.append(true)?; - builder.values().append_slice(&[4, 5])?; - builder.append(true)?; - builder.values().append_slice(&[6, 7, 8])?; - builder.values().append_slice(&[9, 10, 11])?; - builder.append(true)?; + fn test_list_append() { + let mut builder = + ListBuilder::::new(Int64Builder::with_capacity(24)); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5]); + builder.append(true); + builder.values().append_slice(&[6, 7, 8]); + builder.values().append_slice(&[9, 10, 11]); + builder.append(true); let a = builder.finish(); - let a_builder = Int64Builder::new(24); + let a_builder = Int64Builder::with_capacity(24); let mut a_builder = ListBuilder::::new(a_builder); - a_builder.values().append_slice(&[12, 13])?; - a_builder.append(true)?; - a_builder.append(true)?; - a_builder.values().append_slice(&[14, 15])?; - a_builder.append(true)?; + a_builder.values().append_slice(&[12, 13]); + a_builder.append(true); + a_builder.append(true); + a_builder.values().append_slice(&[14, 15]); + a_builder.append(true); let b = a_builder.finish(); let c = b.slice(1, 2); @@ -1239,35 +1238,35 @@ mod tests { ) .unwrap(); assert_eq!(finished, expected_list_data); - - Ok(()) } #[test] fn test_list_nulls_append() -> Result<()> { - let mut builder = ListBuilder::::new(Int64Builder::new(32)); - builder.values().append_slice(&[1, 2, 3])?; - builder.append(true)?; 
- builder.values().append_slice(&[4, 5])?; - builder.append(true)?; - builder.append(false)?; - builder.values().append_slice(&[6, 7, 8])?; - builder.values().append_null()?; - builder.values().append_null()?; - builder.values().append_slice(&[9, 10, 11])?; - builder.append(true)?; + let mut builder = + ListBuilder::::new(Int64Builder::with_capacity(32)); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5]); + builder.append(true); + builder.append(false); + builder.values().append_slice(&[6, 7, 8]); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_slice(&[9, 10, 11]); + builder.append(true); let a = builder.finish(); let a = a.data(); - let mut builder = ListBuilder::::new(Int64Builder::new(32)); - builder.values().append_slice(&[12, 13])?; - builder.append(true)?; - builder.append(false)?; - builder.append(true)?; - builder.values().append_null()?; - builder.values().append_null()?; - builder.values().append_slice(&[14, 15])?; - builder.append(true)?; + let mut builder = + ListBuilder::::new(Int64Builder::with_capacity(32)); + builder.values().append_slice(&[12, 13]); + builder.append(true); + builder.append(false); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_slice(&[14, 15]); + builder.append(true); let b = builder.finish(); let b = b.data(); let c = b.slice(1, 2); @@ -1325,24 +1324,25 @@ mod tests { } #[test] - fn test_list_append_with_capacities() -> Result<()> { - let mut builder = ListBuilder::::new(Int64Builder::new(24)); - builder.values().append_slice(&[1, 2, 3])?; - builder.append(true)?; - builder.values().append_slice(&[4, 5])?; - builder.append(true)?; - builder.values().append_slice(&[6, 7, 8])?; - builder.values().append_slice(&[9, 10, 11])?; - builder.append(true)?; + fn test_list_append_with_capacities() { + let mut builder = + ListBuilder::::new(Int64Builder::with_capacity(24)); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5]); + builder.append(true); + builder.values().append_slice(&[6, 7, 8]); + builder.values().append_slice(&[9, 10, 11]); + builder.append(true); let a = builder.finish(); - let a_builder = Int64Builder::new(24); + let a_builder = Int64Builder::with_capacity(24); let mut a_builder = ListBuilder::::new(a_builder); - a_builder.values().append_slice(&[12, 13])?; - a_builder.append(true)?; - a_builder.append(true)?; - a_builder.values().append_slice(&[14, 15, 16, 17])?; - a_builder.append(true)?; + a_builder.values().append_slice(&[12, 13]); + a_builder.append(true); + a_builder.append(true); + a_builder.values().append_slice(&[14, 15, 16, 17]); + a_builder.append(true); let b = a_builder.finish(); let mutable = MutableArrayData::with_capacities( @@ -1354,52 +1354,48 @@ mod tests { // capacities are rounded up to multiples of 64 by MutableBuffer assert_eq!(mutable.data.buffer1.capacity(), 64); assert_eq!(mutable.data.child_data[0].data.buffer1.capacity(), 192); - - Ok(()) } #[test] fn test_map_nulls_append() -> Result<()> { let mut builder = MapBuilder::::new( None, - Int64Builder::new(32), - Int64Builder::new(32), + Int64Builder::with_capacity(32), + Int64Builder::with_capacity(32), ); - builder.keys().append_slice(&[1, 2, 3])?; - builder.values().append_slice(&[1, 2, 3])?; - builder.append(true)?; - builder.keys().append_slice(&[4, 5])?; - builder.values().append_slice(&[4, 5])?; - builder.append(true)?; - 
builder.append(false)?; - builder - .keys() - .append_slice(&[6, 7, 8, 100, 101, 9, 10, 11])?; - builder.values().append_slice(&[6, 7, 8])?; - builder.values().append_null()?; - builder.values().append_null()?; - builder.values().append_slice(&[9, 10, 11])?; - builder.append(true)?; + builder.keys().append_slice(&[1, 2, 3]); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true).unwrap(); + builder.keys().append_slice(&[4, 5]); + builder.values().append_slice(&[4, 5]); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.keys().append_slice(&[6, 7, 8, 100, 101, 9, 10, 11]); + builder.values().append_slice(&[6, 7, 8]); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_slice(&[9, 10, 11]); + builder.append(true).unwrap(); let a = builder.finish(); let a = a.data(); let mut builder = MapBuilder::::new( None, - Int64Builder::new(32), - Int64Builder::new(32), + Int64Builder::with_capacity(32), + Int64Builder::with_capacity(32), ); - builder.keys().append_slice(&[12, 13])?; - builder.values().append_slice(&[12, 13])?; - builder.append(true)?; - builder.append(false)?; - builder.append(true)?; - builder.keys().append_slice(&[100, 101, 14, 15])?; - builder.values().append_null()?; - builder.values().append_null()?; - builder.values().append_slice(&[14, 15])?; - builder.append(true)?; + builder.keys().append_slice(&[12, 13]); + builder.values().append_slice(&[12, 13]); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.append(true).unwrap(); + builder.keys().append_slice(&[100, 101, 14, 15]); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_slice(&[14, 15]); + builder.append(true).unwrap(); let b = builder.finish(); let b = b.data(); @@ -1511,24 +1507,24 @@ mod tests { #[test] fn test_list_of_strings_append() -> Result<()> { // [["alpha", "beta", None]] - let mut builder = ListBuilder::new(StringBuilder::new(32)); - builder.values().append_value("Hello")?; - builder.values().append_value("Arrow")?; - builder.values().append_null()?; - builder.append(true)?; + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.values().append_value("Hello"); + builder.values().append_value("Arrow"); + builder.values().append_null(); + builder.append(true); let a = builder.finish(); // [["alpha", "beta"], [None], ["gamma", "delta", None]] - let mut builder = ListBuilder::new(StringBuilder::new(32)); - builder.values().append_value("alpha")?; - builder.values().append_value("beta")?; - builder.append(true)?; - builder.values().append_null()?; - builder.append(true)?; - builder.values().append_value("gamma")?; - builder.values().append_value("delta")?; - builder.values().append_null()?; - builder.append(true)?; + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.values().append_value("alpha"); + builder.values().append_value("beta"); + builder.append(true); + builder.values().append_null(); + builder.append(true); + builder.values().append_value("gamma"); + builder.values().append_value("delta"); + builder.values().append_null(); + builder.append(true); let b = builder.finish(); let mut mutable = MutableArrayData::new(vec![a.data(), b.data()], false, 10); diff --git a/arrow/src/buffer/immutable.rs b/arrow/src/buffer/immutable.rs index cb686bd8441c..28042a3817be 100644 --- a/arrow/src/buffer/immutable.rs +++ b/arrow/src/buffer/immutable.rs @@ -22,7 +22,6 @@ use std::sync::Arc; use std::{convert::AsRef, usize}; use 
crate::alloc::{Allocation, Deallocation}; -use crate::ffi::FFI_ArrowArray; use crate::util::bit_chunk_iterator::{BitChunks, UnalignedBitChunk}; use crate::{bytes::Bytes, datatypes::ArrowNativeType}; @@ -38,15 +37,20 @@ pub struct Buffer { /// The offset into the buffer. offset: usize, + + /// Byte length of the buffer. + length: usize, } impl Buffer { /// Auxiliary method to create a new Buffer #[inline] pub fn from_bytes(bytes: Bytes) -> Self { + let length = bytes.len(); Buffer { data: Arc::new(bytes), offset: 0, + length, } } @@ -77,30 +81,6 @@ impl Buffer { Buffer::build_with_arguments(ptr, len, Deallocation::Arrow(capacity)) } - /// Creates a buffer from an existing memory region (must already be byte-aligned), this - /// `Buffer` **does not** free this piece of memory when dropped. - /// - /// # Arguments - /// - /// * `ptr` - Pointer to raw parts - /// * `len` - Length of raw parts in **bytes** - /// * `data` - An [crate::ffi::FFI_ArrowArray] with the data - /// - /// # Safety - /// - /// This function is unsafe as there is no guarantee that the given pointer is valid for `len` - /// bytes and that the foreign deallocator frees the region. - #[deprecated( - note = "use from_custom_allocation instead which makes it clearer that the allocation is in fact owned" - )] - pub unsafe fn from_unowned( - ptr: NonNull, - len: usize, - data: Arc, - ) -> Self { - Self::from_custom_allocation(ptr, len, data) - } - /// Creates a buffer from an existing memory region. Ownership of the memory is tracked via reference counting /// and the memory will be freed using the `drop` method of [crate::alloc::Allocation] when the reference count reaches zero. /// @@ -131,28 +111,32 @@ impl Buffer { Buffer { data: Arc::new(bytes), offset: 0, + length: len, } } /// Returns the number of bytes in the buffer + #[inline] pub fn len(&self) -> usize { - self.data.len() - self.offset + self.length } /// Returns the capacity of this buffer. /// For externally owned buffers, this returns zero + #[inline] pub fn capacity(&self) -> usize { self.data.capacity() } /// Returns whether the buffer is empty. + #[inline] pub fn is_empty(&self) -> bool { - self.data.len() - self.offset == 0 + self.length == 0 } /// Returns the byte slice stored in this buffer pub fn as_slice(&self) -> &[u8] { - &self.data[self.offset..] + &self.data[self.offset..(self.offset + self.length)] } /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`. @@ -167,6 +151,24 @@ impl Buffer { Self { data: self.data.clone(), offset: self.offset + offset, + length: self.length - offset, + } + } + + /// Returns a new [Buffer] that is a slice of this buffer starting at `offset`, + /// with `length` bytes. + /// Doing so allows the same memory region to be shared between buffers. + /// # Panics + /// Panics iff `(offset + length)` is larger than the existing length. 
+ pub fn slice_with_length(&self, offset: usize, length: usize) -> Self { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + Self { + data: self.data.clone(), + offset: self.offset + offset, + length, } } @@ -344,10 +346,10 @@ mod tests { let buf2 = Buffer::from(&[0, 1, 2, 3, 4]); assert_eq!(buf1, buf2); - // slice with same offset should still preserve equality + // slice with same offset and same length should still preserve equality let buf3 = buf1.slice(2); assert_ne!(buf1, buf3); - let buf4 = buf2.slice(2); + let buf4 = buf2.slice_with_length(2, 3); assert_eq!(buf3, buf4); // Different capacities should still preserve equality @@ -401,7 +403,7 @@ mod tests { assert_eq!(3, buf2.len()); assert_eq!(unsafe { buf.as_ptr().offset(2) }, buf2.as_ptr()); - let buf3 = buf2.slice(1); + let buf3 = buf2.slice_with_length(1, 2); assert_eq!([8, 10], buf3.as_slice()); assert_eq!(2, buf3.len()); assert_eq!(unsafe { buf.as_ptr().offset(3) }, buf3.as_ptr()); @@ -411,7 +413,7 @@ mod tests { assert_eq!(empty_slice, buf4.as_slice()); assert_eq!(0, buf4.len()); assert!(buf4.is_empty()); - assert_eq!(buf2.slice(2).as_slice(), &[10]); + assert_eq!(buf2.slice_with_length(2, 1).as_slice(), &[10]); } #[test] @@ -482,7 +484,7 @@ mod tests { assert_eq!( 8, Buffer::from(&[0b11111111, 0b11111111]) - .slice(1) + .slice_with_length(1, 1) .count_set_bits() ); assert_eq!( @@ -494,7 +496,7 @@ mod tests { assert_eq!( 6, Buffer::from(&[0b11111111, 0b01001001, 0b01010010]) - .slice(1) + .slice_with_length(1, 2) .count_set_bits() ); assert_eq!( diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs index 11783b82da54..1c662ec23eef 100644 --- a/arrow/src/buffer/mutable.rs +++ b/arrow/src/buffer/mutable.rs @@ -377,7 +377,7 @@ impl MutableBuffer { /// # Safety /// `ptr` must be allocated for `old_capacity`. -#[inline] +#[cold] unsafe fn reallocate( ptr: NonNull, old_capacity: usize, diff --git a/arrow/src/buffer/ops.rs b/arrow/src/buffer/ops.rs index ea155c8d78e4..7000f39767cb 100644 --- a/arrow/src/buffer/ops.rs +++ b/arrow/src/buffer/ops.rs @@ -18,6 +18,53 @@ use super::{Buffer, MutableBuffer}; use crate::util::bit_util::ceil; +/// Apply a bitwise operation `op` to four inputs and return the result as a Buffer. +/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. 
+#[allow(clippy::too_many_arguments)] +pub(crate) fn bitwise_quaternary_op_helper( + first: &Buffer, + first_offset_in_bits: usize, + second: &Buffer, + second_offset_in_bits: usize, + third: &Buffer, + third_offset_in_bits: usize, + fourth: &Buffer, + fourth_offset_in_bits: usize, + len_in_bits: usize, + op: F, +) -> Buffer +where + F: Fn(u64, u64, u64, u64) -> u64, +{ + let first_chunks = first.bit_chunks(first_offset_in_bits, len_in_bits); + let second_chunks = second.bit_chunks(second_offset_in_bits, len_in_bits); + let third_chunks = third.bit_chunks(third_offset_in_bits, len_in_bits); + let fourth_chunks = fourth.bit_chunks(fourth_offset_in_bits, len_in_bits); + + let chunks = first_chunks + .iter() + .zip(second_chunks.iter()) + .zip(third_chunks.iter()) + .zip(fourth_chunks.iter()) + .map(|(((first, second), third), fourth)| op(first, second, third, fourth)); + // Soundness: `BitChunks` is a `BitChunks` iterator which + // correctly reports its upper bound + let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) }; + + let remainder_bytes = ceil(first_chunks.remainder_len(), 8); + let rem = op( + first_chunks.remainder_bits(), + second_chunks.remainder_bits(), + third_chunks.remainder_bits(), + fourth_chunks.remainder_bits(), + ); + // we are counting its starting from the least significant bit, to to_le_bytes should be correct + let rem = &rem.to_le_bytes()[0..remainder_bytes]; + buffer.extend_from_slice(rem); + + buffer.into() +} + /// Apply a bitwise operation `op` to two inputs and return the result as a Buffer. /// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. pub fn bitwise_bin_op_helper( diff --git a/arrow/src/compute/README.md b/arrow/src/compute/README.md index 761713a531b4..a5d15a83046f 100644 --- a/arrow/src/compute/README.md +++ b/arrow/src/compute/README.md @@ -33,16 +33,16 @@ We use the term "kernel" to refer to particular general operation that contains Types of functions -* Scalar functions: elementwise functions that perform scalar operations in a +- Scalar functions: elementwise functions that perform scalar operations in a vectorized manner. These functions are generally valid for SQL-like context. These are called "scalar" in that the functions executed consider each value in an array independently, and the output array or arrays have the same length as the input arrays. The result for each array cell is generally independent of its position in the array. -* Vector functions, which produce a result whose output is generally dependent +- Vector functions, which produce a result whose output is generally dependent on the entire contents of the input arrays. These functions **are generally not valid** for SQL-like processing because the output size may be different than the input size, and the result may change based on the order of the values in the array. This includes things like array subselection, sorting, hashing, and more. 
-* Scalar aggregate functions of which can be used in a SQL-like context \ No newline at end of file +- Scalar aggregate functions of which can be used in a SQL-like context diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index 12ead669f79d..d7726fbf92aa 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -21,14 +21,14 @@ use multiversion::multiversion; use std::ops::Add; use crate::array::{ - Array, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, - PrimitiveArray, + as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray, + GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; -use crate::datatypes::{ArrowNativeType, ArrowNumericType}; +use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType}; /// Generic test for NaN, the optimizer should be able to remove this for integer types. #[inline] -fn is_nan(a: T) -> bool { +pub(crate) fn is_nan(a: T) -> bool { #[allow(clippy::eq_op)] !(a == a) } @@ -185,6 +185,99 @@ pub fn min_string(array: &GenericStringArray) -> Option<& } /// Returns the sum of values in the array. +pub fn sum_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: Add, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let iter = ArrayIter::new(array); + let sum = iter + .into_iter() + .fold(T::default_value(), |accumulator, value| { + if let Some(value) = value { + accumulator + value + } else { + accumulator + } + }); + + Some(sum) + } + _ => sum::(as_primitive_array(&array)), + } +} + +/// Returns the min of values in the array. +pub fn min_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeType, +{ + min_max_array_helper::( + array, + |a, b| (!is_nan(*a) & is_nan(*b)) || a < b, + min, + ) +} + +/// Returns the max of values in the array. +pub fn max_array>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: ArrowNativeType, +{ + min_max_array_helper::( + array, + |a, b| (is_nan(*a) & !is_nan(*b)) || a > b, + max, + ) +} + +fn min_max_array_helper, F, M>( + array: A, + cmp: F, + m: M, +) -> Option +where + T: ArrowNumericType, + F: Fn(&T::Native, &T::Native) -> bool, + M: Fn(&PrimitiveArray) -> Option, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let mut has_value = false; + let mut n = T::default_value(); + let iter = ArrayIter::new(array); + iter.into_iter().for_each(|value| { + if let Some(value) = value { + if !has_value || cmp(&value, &n) { + has_value = true; + n = value; + } + } + }); + + Some(n) + } + _ => m(as_primitive_array(&array)), + } +} + +/// Returns the sum of values in the primitive array. /// /// Returns `None` if the array is empty or only contains null values. #[cfg(not(feature = "simd"))] @@ -583,7 +676,7 @@ mod simd { } } -/// Returns the sum of values in the array. +/// Returns the sum of values in the primitive array. /// /// Returns `None` if the array is empty or only contains null values. 
#[cfg(feature = "simd")] @@ -625,6 +718,7 @@ mod tests { use super::*; use crate::array::*; use crate::compute::add; + use crate::datatypes::{Float32Type, Int32Type, Int8Type}; #[test] fn test_primitive_array_sum() { @@ -1003,4 +1097,71 @@ mod tests { assert_eq!(Some(true), min_boolean(&a)); assert_eq!(Some(true), max_boolean(&a)); } + + #[test] + fn test_sum_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(39, sum_array::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(15, sum_array::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(26, sum_array::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert!(sum_array::(array).is_none()); + } + + #[test] + fn test_max_min_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(14, max_array::(array).unwrap()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_array::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(5, max_array::(&a).unwrap()); + assert_eq!(1, min_array::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(7)]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(17, max_array::(array).unwrap()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_array::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_array::(array).is_none()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(min_array::(array).is_none()); + } + + #[test] + fn test_max_min_dyn_nan() { + let values = Float32Array::from(vec![5.0_f32, 2.0_f32, f32::NAN]); + let keys = Int8Array::from_iter_values([0_i8, 1, 2]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_array::(array).unwrap().is_nan()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(2.0_f32, min_array::(array).unwrap()); + } } diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 67fb61356d7c..fff687e18b3c 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -1229,8 +1229,8 @@ mod tests { #[test] fn test_primitive_array_add_dyn_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); 
builder.append(5).unwrap(); builder.append(6).unwrap(); @@ -1239,13 +1239,13 @@ mod tests { builder.append(9).unwrap(); let a = builder.finish(); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(6).unwrap(); builder.append(7).unwrap(); builder.append(8).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(10).unwrap(); let b = builder.finish(); @@ -1270,11 +1270,11 @@ mod tests { assert!(c.is_null(3)); assert_eq!(10, c.value(4)); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(5).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(7).unwrap(); builder.append(8).unwrap(); builder.append(9).unwrap(); @@ -1313,8 +1313,8 @@ mod tests { #[test] fn test_primitive_array_subtract_dyn_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(15).unwrap(); builder.append(8).unwrap(); @@ -1323,13 +1323,13 @@ mod tests { builder.append(20).unwrap(); let a = builder.finish(); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(6).unwrap(); builder.append(7).unwrap(); builder.append(8).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(10).unwrap(); let b = builder.finish(); @@ -1354,11 +1354,11 @@ mod tests { assert!(c.is_null(3)); assert_eq!(8, c.value(4)); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(5).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(7).unwrap(); builder.append(8).unwrap(); builder.append(9).unwrap(); @@ -1397,8 +1397,8 @@ mod tests { #[test] fn test_primitive_array_multiply_dyn_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(5).unwrap(); builder.append(6).unwrap(); @@ -1407,13 +1407,13 @@ mod tests { builder.append(9).unwrap(); let a = builder.finish(); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = 
PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(6).unwrap(); builder.append(7).unwrap(); builder.append(8).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(10).unwrap(); let b = builder.finish(); @@ -1441,8 +1441,8 @@ mod tests { #[test] fn test_primitive_array_divide_dyn_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(15).unwrap(); builder.append(6).unwrap(); @@ -1451,13 +1451,13 @@ mod tests { builder.append(9).unwrap(); let a = builder.finish(); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(5).unwrap(); builder.append(3).unwrap(); builder.append(1).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(3).unwrap(); let b = builder.finish(); @@ -1482,11 +1482,11 @@ mod tests { assert!(c.is_null(3)); assert_eq!(18, c.value(4)); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(5).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(7).unwrap(); builder.append(8).unwrap(); builder.append(9).unwrap(); @@ -1668,11 +1668,11 @@ mod tests { assert!(c.is_null(3)); assert_eq!(4, c.value(4)); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(5).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(7).unwrap(); builder.append(8).unwrap(); builder.append(9).unwrap(); @@ -1936,14 +1936,14 @@ mod tests { #[test] #[should_panic(expected = "DivideByZero")] fn test_primitive_array_divide_dyn_by_zero_dict() { - let key_builder = PrimitiveBuilder::::new(1); - let value_builder = PrimitiveBuilder::::new(1); + let key_builder = PrimitiveBuilder::::with_capacity(1); + let value_builder = PrimitiveBuilder::::with_capacity(1); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(15).unwrap(); let a = builder.finish(); - let key_builder = PrimitiveBuilder::::new(1); - let value_builder = PrimitiveBuilder::::new(1); + let key_builder = PrimitiveBuilder::::with_capacity(1); + let value_builder = PrimitiveBuilder::::with_capacity(1); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(0).unwrap(); let b = builder.finish(); diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 5135218168f7..be9d56ebb19b 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -228,25 +228,25 @@ mod tests { #[test] fn test_unary_dict_and_unary_dyn() { - let key_builder = 
PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(5).unwrap(); builder.append(6).unwrap(); builder.append(7).unwrap(); builder.append(8).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(9).unwrap(); let dictionary_array = builder.finish(); - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(6).unwrap(); builder.append(7).unwrap(); builder.append(8).unwrap(); builder.append(9).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(10).unwrap(); let expected = builder.finish(); diff --git a/arrow/src/compute/kernels/boolean.rs b/arrow/src/compute/kernels/boolean.rs index 209edc48d195..c51953a7540c 100644 --- a/arrow/src/compute/kernels/boolean.rs +++ b/arrow/src/compute/kernels/boolean.rs @@ -26,162 +26,168 @@ use std::ops::Not; use crate::array::{Array, ArrayData, BooleanArray, PrimitiveArray}; use crate::buffer::{ - buffer_bin_and, buffer_bin_or, buffer_unary_not, Buffer, MutableBuffer, + bitwise_bin_op_helper, bitwise_quaternary_op_helper, buffer_bin_and, buffer_bin_or, + buffer_unary_not, Buffer, MutableBuffer, }; use crate::compute::util::combine_option_bitmap; use crate::datatypes::{ArrowNumericType, DataType}; use crate::error::{ArrowError, Result}; -use crate::util::bit_util::{ceil, round_upto_multiple_of_64}; -use core::iter; -use num::Zero; - -fn binary_boolean_kleene_kernel( - left: &BooleanArray, - right: &BooleanArray, - op: F, -) -> Result -where - F: Fn(u64, u64, u64, u64) -> (u64, u64), -{ - if left.len() != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform bitwise operation on arrays of different length".to_string(), - )); - } - - // length and offset of boolean array is measured in bits - let len = left.len(); +use crate::util::bit_util::ceil; + +/// Updates null buffer based on data buffer and null buffer of the operand at other side +/// in boolean AND kernel with Kleene logic. In short, because for AND kernel, null AND false +/// results false. So we cannot simply AND two null buffers. This function updates null buffer +/// of one side if other side is a false value. +pub(crate) fn build_null_buffer_for_and_kleene( + left_data: &ArrayData, + left_offset: usize, + right_data: &ArrayData, + right_offset: usize, + len_in_bits: usize, +) -> Option { + let left_buffer = &left_data.buffers()[0]; + let right_buffer = &right_data.buffers()[0]; - // result length measured in bytes (incl. remainder) - let mut result_len = round_upto_multiple_of_64(len) / 8; - // The iterator that applies the kleene_op closure always chains an additional iteration - // for the remainder chunk, even without a remainder. If the remainder is absent - // (length % 64 == 0), kleene_op would resize the result buffers (value_buffer and - // valid_buffer) to store 8 additional bytes, because result_len wouldn't include a remainder - // chunk. The resizing is unnecessary and expensive. We can prevent it by adding 8 bytes to - // result_len here. Nonetheless, all bits of these 8 bytes will be 0. 
- if len % 64 == 0 { - result_len += 8; + let left_null_buffer = left_data.null_buffer(); + let right_null_buffer = right_data.null_buffer(); + + match (left_null_buffer, right_null_buffer) { + (None, None) => None, + (Some(left_null_buffer), None) => { + // The right side has no null values. + // The final null bit is set only if: + // 1. left null bit is set, or + // 2. right data bit is false (because null AND false = false). + Some(bitwise_bin_op_helper( + left_null_buffer, + left_offset, + right_buffer, + right_offset, + len_in_bits, + |a, b| a | !b, + )) + } + (None, Some(right_null_buffer)) => { + // Same as above + Some(bitwise_bin_op_helper( + right_null_buffer, + right_offset, + left_buffer, + left_offset, + len_in_bits, + |a, b| a | !b, + )) + } + (Some(left_null_buffer), Some(right_null_buffer)) => { + // Follow the same logic above. Both sides have null values. + // Assume a is left null bits, b is left data bits, c is right null bits, + // d is right data bits. + // The final null bits are: + // (a | (c & !d)) & (c | (a & !b)) + Some(bitwise_quaternary_op_helper( + left_null_buffer, + left_offset, + left_buffer, + left_offset, + right_null_buffer, + right_offset, + right_buffer, + right_offset, + len_in_bits, + |a, b, c, d| (a | (c & !d)) & (c | (a & !b)), + )) + } } +} - let mut value_buffer = MutableBuffer::new(result_len); - let mut valid_buffer = MutableBuffer::new(result_len); - - let kleene_op = |((left_data, left_valid), (right_data, right_valid)): ( - (u64, u64), - (u64, u64), - )| { - let left_true = left_valid & left_data; - let left_false = left_valid & !left_data; - - let right_true = right_valid & right_data; - let right_false = right_valid & !right_data; - - let (value, valid) = op(left_true, left_false, right_true, right_false); - - value_buffer.extend_from_slice(&[value]); - valid_buffer.extend_from_slice(&[valid]); - }; +/// For AND/OR kernels, the result of null buffer is simply a bitwise `and` operation. +pub(crate) fn build_null_buffer_for_and_or( + left_data: &ArrayData, + _left_offset: usize, + right_data: &ArrayData, + _right_offset: usize, + len_in_bits: usize, +) -> Option { + // `arrays` are not empty, so safely do `unwrap` directly. + combine_option_bitmap(&[left_data, right_data], len_in_bits).unwrap() +} - let left_offset = left.offset(); - let right_offset = right.offset(); +/// Updates null buffer based on data buffer and null buffer of the operand at other side +/// in boolean OR kernel with Kleene logic. In short, because for OR kernel, null OR true +/// results true. So we cannot simply AND two null buffers. This function updates null +/// buffer of one side if other side is a true value. 
+pub(crate) fn build_null_buffer_for_or_kleene( + left_data: &ArrayData, + left_offset: usize, + right_data: &ArrayData, + right_offset: usize, + len_in_bits: usize, +) -> Option { + let left_buffer = &left_data.buffers()[0]; + let right_buffer = &right_data.buffers()[0]; - let left_buffer = left.values(); - let right_buffer = right.values(); - - let left_chunks = left_buffer.bit_chunks(left_offset, len); - let right_chunks = right_buffer.bit_chunks(right_offset, len); - - let left_rem = left_chunks.remainder_bits(); - let right_rem = right_chunks.remainder_bits(); - - let opt_left_valid_chunks_and_rem = left - .data_ref() - .null_buffer() - .map(|b| b.bit_chunks(left_offset, len)) - .map(|chunks| (chunks.iter(), chunks.remainder_bits())); - let opt_right_valid_chunks_and_rem = right - .data_ref() - .null_buffer() - .map(|b| b.bit_chunks(right_offset, len)) - .map(|chunks| (chunks.iter(), chunks.remainder_bits())); - - match ( - opt_left_valid_chunks_and_rem, - opt_right_valid_chunks_and_rem, - ) { - ( - Some((left_valid_chunks, left_valid_rem)), - Some((right_valid_chunks, right_valid_rem)), - ) => { - left_chunks - .iter() - .zip(left_valid_chunks) - .zip(right_chunks.iter().zip(right_valid_chunks)) - .chain(iter::once(( - (left_rem, left_valid_rem), - (right_rem, right_valid_rem), - ))) - .for_each(kleene_op); - } - (Some((left_valid_chunks, left_valid_rem)), None) => { - left_chunks - .iter() - .zip(left_valid_chunks) - .zip(right_chunks.iter().zip(iter::repeat(u64::MAX))) - .chain(iter::once(( - (left_rem, left_valid_rem), - (right_rem, u64::MAX), - ))) - .for_each(kleene_op); + let left_null_buffer = left_data.null_buffer(); + let right_null_buffer = right_data.null_buffer(); + + match (left_null_buffer, right_null_buffer) { + (None, None) => None, + (Some(left_null_buffer), None) => { + // The right side has no null values. + // The final null bit is set only if: + // 1. left null bit is set, or + // 2. right data bit is true (because null OR true = true). + Some(bitwise_bin_op_helper( + left_null_buffer, + left_offset, + right_buffer, + right_offset, + len_in_bits, + |a, b| a | b, + )) } - (None, Some((right_valid_chunks, right_valid_rem))) => { - left_chunks - .iter() - .zip(iter::repeat(u64::MAX)) - .zip(right_chunks.iter().zip(right_valid_chunks)) - .chain(iter::once(( - (left_rem, u64::MAX), - (right_rem, right_valid_rem), - ))) - .for_each(kleene_op); + (None, Some(right_null_buffer)) => { + // Same as above + Some(bitwise_bin_op_helper( + right_null_buffer, + right_offset, + left_buffer, + left_offset, + len_in_bits, + |a, b| a | b, + )) } - (None, None) => { - left_chunks - .iter() - .zip(iter::repeat(u64::MAX)) - .zip(right_chunks.iter().zip(iter::repeat(u64::MAX))) - .chain(iter::once(((left_rem, u64::MAX), (right_rem, u64::MAX)))) - .for_each(kleene_op); + (Some(left_null_buffer), Some(right_null_buffer)) => { + // Follow the same logic above. Both sides have null values. + // Assume a is left null bits, b is left data bits, c is right null bits, + // d is right data bits. 
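These null-buffer helpers encode Kleene three-valued logic: for `AND`, a null operand still yields a known `false` whenever the other side is `false`, which is why a result bit is valid when each side is either valid itself or paired with a known-false value (the `(a | (c & !d)) & (c | (a & !b))` expression above, applied bit-parallel on 64-bit words). A small sketch over `Option<bool>` rather than packed bitmaps, showing the truth table those formulas implement:

```rust
/// Kleene AND over nullable booleans: a known `false` dominates,
/// and the result is null only when it is genuinely undecidable.
fn and_kleene(a: Option<bool>, b: Option<bool>) -> Option<bool> {
    match (a, b) {
        (Some(false), _) | (_, Some(false)) => Some(false),
        (Some(true), Some(true)) => Some(true),
        _ => None, // null AND true, null AND null
    }
}

fn main() {
    assert_eq!(Some(false), and_kleene(None, Some(false))); // null AND false = false
    assert_eq!(None, and_kleene(None, Some(true)));         // null AND true  = null
    assert_eq!(Some(true), and_kleene(Some(true), Some(true)));
}
```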
+ // The final null bits are: + // (a | (c & d)) & (c | (a & b)) + Some(bitwise_quaternary_op_helper( + left_null_buffer, + left_offset, + left_buffer, + left_offset, + right_null_buffer, + right_offset, + right_buffer, + right_offset, + len_in_bits, + |a, b, c, d| (a | (c & d)) & (c | (a & b)), + )) } - }; - - let bool_buffer: Buffer = value_buffer.into(); - let bool_valid_buffer: Buffer = valid_buffer.into(); - - let array_data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - None, - Some(bool_valid_buffer), - left_offset, - vec![bool_buffer], - vec![], - ) - }; - - Ok(BooleanArray::from(array_data)) + } } /// Helper function to implement binary kernels -pub(crate) fn binary_boolean_kernel( +pub(crate) fn binary_boolean_kernel( left: &BooleanArray, right: &BooleanArray, op: F, + null_op: U, ) -> Result where F: Fn(&Buffer, usize, &Buffer, usize, usize) -> Buffer, + U: Fn(&ArrayData, usize, &ArrayData, usize, usize) -> Option, { if left.len() != right.len() { return Err(ArrowError::ComputeError( @@ -193,13 +199,14 @@ where let left_data = left.data_ref(); let right_data = right.data_ref(); - let null_bit_buffer = combine_option_bitmap(&[left_data, right_data], len)?; let left_buffer = &left_data.buffers()[0]; let right_buffer = &right_data.buffers()[0]; let left_offset = left.offset(); let right_offset = right.offset(); + let null_bit_buffer = null_op(left_data, left_offset, right_data, right_offset, len); + let values = op(left_buffer, left_offset, right_buffer, right_offset, len); let data = unsafe { @@ -234,7 +241,7 @@ where /// # } /// ``` pub fn and(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_kernel(left, right, buffer_bin_and) + binary_boolean_kernel(left, right, buffer_bin_and, build_null_buffer_for_and_or) } /// Logical 'and' boolean values with Kleene logic @@ -272,18 +279,12 @@ pub fn and(left: &BooleanArray, right: &BooleanArray) -> Result { /// /// If the operands have different lengths pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { - if left.null_count().is_zero() && right.null_count().is_zero() { - return and(left, right); - } - - let op = |left_true, left_false, right_true, right_false| { - ( - left_true & right_true, - left_false | right_false | (left_true & right_true), - ) - }; - - binary_boolean_kleene_kernel(left, right, op) + binary_boolean_kernel( + left, + right, + buffer_bin_and, + build_null_buffer_for_and_kleene, + ) } /// Performs `OR` operation on two arrays. If either left or right value is null then the @@ -304,7 +305,7 @@ pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result Result { - binary_boolean_kernel(left, right, buffer_bin_or) + binary_boolean_kernel(left, right, buffer_bin_or, build_null_buffer_for_and_or) } /// Logical 'or' boolean values with Kleene logic @@ -342,18 +343,7 @@ pub fn or(left: &BooleanArray, right: &BooleanArray) -> Result { /// /// If the operands have different lengths pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result { - if left.null_count().is_zero() && right.null_count().is_zero() { - return or(left, right); - } - - let op = |left_true, left_false, right_true, right_false| { - ( - left_true | right_true, - left_true | right_true | (left_false & right_false), - ) - }; - - binary_boolean_kleene_kernel(left, right, op) + binary_boolean_kernel(left, right, buffer_bin_or, build_null_buffer_for_or_kleene) } /// Performs unary `NOT` operation on an arrays. 
If value is null then the result is also @@ -644,33 +634,6 @@ mod tests { assert_eq!(c, expected); } - #[test] - fn test_binary_boolean_kleene_kernel() { - // the kleene kernel is based on chunking and we want to also create - // cases, where the number of values is not a multiple of 64 - for &value in [true, false].iter() { - for &is_valid in [true, false].iter() { - for &n in [0usize, 1, 63, 64, 65, 127, 128].iter() { - let a = BooleanArray::from(vec![Some(true); n]); - let b = BooleanArray::from(vec![None; n]); - - let result = binary_boolean_kleene_kernel(&a, &b, |_, _, _, _| { - let tmp_value = if value { u64::MAX } else { 0 }; - let tmp_is_valid = if is_valid { u64::MAX } else { 0 }; - (tmp_value, tmp_is_valid) - }) - .unwrap(); - - assert_eq!(result.len(), n); - (0..n).for_each(|idx| { - assert_eq!(value, result.value(idx)); - assert_eq!(is_valid, result.is_valid(idx)); - }); - } - } - } - } - #[test] fn test_boolean_array_kleene_no_remainder() { let n = 1024; diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index 3dd2ad69264e..6b4f224708da 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -35,19 +35,31 @@ //! assert_eq!(7.0, c.value(2)); //! ``` +use chrono::format::strftime::StrftimeItems; +use chrono::format::{parse, Parsed}; +use chrono::Timelike; +use std::ops::{Div, Mul}; use std::str; use std::sync::Arc; -use crate::array::BasicDecimalArray; use crate::buffer::MutableBuffer; +use crate::compute::divide_scalar; use crate::compute::kernels::arithmetic::{divide, multiply}; use crate::compute::kernels::arity::unary; use crate::compute::kernels::cast_utils::string_to_timestamp_nanos; +use crate::compute::kernels::temporal::extract_component_from_array; +use crate::compute::kernels::temporal::return_compute_error_with; +use crate::compute::using_chrono_tz_and_utc_naive_date_time; use crate::datatypes::*; use crate::error::{ArrowError, Result}; +use crate::temporal_conversions::{ + EPOCH_DAYS_FROM_CE, MICROSECONDS, MILLISECONDS, MILLISECONDS_IN_DAY, NANOSECONDS, + SECONDS_IN_DAY, +}; use crate::{array::*, compute::take}; use crate::{buffer::Buffer, util::serialization::lexical_to_string}; -use num::{NumCast, ToPrimitive}; +use num::cast::AsPrimitive; +use num::{BigInt, NumCast, ToPrimitive}; /// CastOptions provides a way to override the default cast behaviors #[derive(Debug)] @@ -71,11 +83,14 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { match (from_type, to_type) { // TODO UTF8/unsigned numeric to decimal // cast one decimal type to another decimal type - (Decimal(_, _), Decimal(_, _)) => true, + (Decimal128(_, _), Decimal128(_, _)) => true, + (Decimal256(_, _), Decimal256(_, _)) => true, + (Decimal128(_, _), Decimal256(_, _)) => true, + (Decimal256(_, _), Decimal128(_, _)) => true, // signed numeric to decimal - (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal(_, _)) | + (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) | // decimal to signed numeric - (Decimal(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) + (Decimal128(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) | ( Null, Boolean @@ -108,8 +123,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { | Map(_, _) | Dictionary(_, _) ) => true, - (Decimal(_, _), _) => false, - (_, Decimal(_, _)) => false, + (Decimal128(_, _), _) => false, + (_, Decimal128(_, _)) => false, (Struct(_), _) => false, (_, Struct(_)) => false, 
(LargeList(list_from), LargeList(list_to)) => { @@ -135,9 +150,27 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Utf8, LargeUtf8) => true, (LargeUtf8, Utf8) => true, - (Utf8, Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, None)) => true, + (Utf8, + Binary + | Date32 + | Date64 + | Time32(TimeUnit::Second) + | Time32(TimeUnit::Millisecond) + | Time64(TimeUnit::Microsecond) + | Time64(TimeUnit::Nanosecond) + | Timestamp(TimeUnit::Nanosecond, None) + ) => true, (Utf8, _) => DataType::is_numeric(to_type), - (LargeUtf8, Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, None)) => true, + (LargeUtf8, + LargeBinary + | Date32 + | Date64 + | Time32(TimeUnit::Second) + | Time32(TimeUnit::Millisecond) + | Time64(TimeUnit::Microsecond) + | Time64(TimeUnit::Nanosecond) + | Timestamp(TimeUnit::Nanosecond, None) + ) => true, (LargeUtf8, _) => DataType::is_numeric(to_type), (Timestamp(_, _), Utf8) | (Timestamp(_, _), LargeUtf8) => true, (Date32, Utf8) | (Date32, LargeUtf8) => true, @@ -269,65 +302,80 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result { cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS) } -// cast the integer array to defined decimal data type array -macro_rules! cast_integer_to_decimal { - ($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - let mul: i128 = 10_i128.pow(*$SCALE as u32); - let decimal_array = array - .iter() - .map(|v| { - v.map(|v| { - let v = v as i128; - // with_precision_and_scale validates the - // value is within range for the output precision - mul * v - }) - }) - .collect::() - .with_precision_and_scale(*$PRECISION, *$SCALE)?; - Ok(Arc::new(decimal_array)) - }}; +/// Cast the primitive array to defined decimal data type array +fn cast_primitive_to_decimal( + array: T, + op: F, + precision: u8, + scale: u8, +) -> Result> +where + F: Fn(T::Item) -> i128, +{ + #[allow(clippy::redundant_closure)] + let decimal_array = ArrayIter::new(array) + .map(|v| v.map(|v| op(v))) + .collect::() + .with_precision_and_scale(precision, scale)?; + + Ok(Arc::new(decimal_array)) } -// cast the floating-point array to defined decimal data type array -macro_rules! 
cast_floating_point_to_decimal { - ($ARRAY: expr, $ARRAY_TYPE: ident, $PRECISION : ident, $SCALE : ident) => {{ - let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); - let mul = 10_f64.powi(*$SCALE as i32); - let decimal_array = array - .iter() - .map(|v| { - v.map(|v| { - // with_precision_and_scale validates the - // value is within range for the output precision - ((v as f64) * mul) as i128 - }) - }) - .collect::() - .with_precision_and_scale(*$PRECISION, *$SCALE)?; - Ok(Arc::new(decimal_array)) - }}; +fn cast_integer_to_decimal( + array: &PrimitiveArray, + precision: u8, + scale: u8, +) -> Result> +where + ::Native: AsPrimitive, +{ + let mul: i128 = 10_i128.pow(scale as u32); + + // with_precision_and_scale validates the + // value is within range for the output precision + cast_primitive_to_decimal(array, |v| v.as_() * mul, precision, scale) +} + +fn cast_floating_point_to_decimal( + array: &PrimitiveArray, + precision: u8, + scale: u8, +) -> Result> +where + ::Native: AsPrimitive, +{ + let mul = 10_f64.powi(scale as i32); + + cast_primitive_to_decimal( + array, + |v| { + // with_precision_and_scale validates the + // value is within range for the output precision + (v.as_() * mul) as i128 + }, + precision, + scale, + ) } // cast the decimal array to integer array macro_rules! cast_decimal_to_integer { ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ident, $DATA_TYPE : expr) => {{ - let array = $ARRAY.as_any().downcast_ref::().unwrap(); - let mut value_builder = $VALUE_BUILDER::new(array.len()); + let array = $ARRAY.as_any().downcast_ref::().unwrap(); + let mut value_builder = $VALUE_BUILDER::with_capacity(array.len()); let div: i128 = 10_i128.pow(*$SCALE as u32); let min_bound = ($NATIVE_TYPE::MIN) as i128; let max_bound = ($NATIVE_TYPE::MAX) as i128; for i in 0..array.len() { if array.is_null(i) { - value_builder.append_null()?; + value_builder.append_null(); } else { let v = array.value(i).as_i128() / div; // check the overflow // For example: Decimal(128,10,0) as i8 // 128 is out of range i8 if v <= max_bound && v >= min_bound { - value_builder.append_value(v as $NATIVE_TYPE)?; + value_builder.append_value(v as $NATIVE_TYPE); } else { return Err(ArrowError::CastError(format!( "value of {} is out of range {}", @@ -343,17 +391,17 @@ macro_rules! cast_decimal_to_integer { // cast the decimal array to floating-point array macro_rules! cast_decimal_to_float { ($ARRAY:expr, $SCALE : ident, $VALUE_BUILDER: ident, $NATIVE_TYPE : ty) => {{ - let array = $ARRAY.as_any().downcast_ref::().unwrap(); + let array = $ARRAY.as_any().downcast_ref::().unwrap(); let div = 10_f64.powi(*$SCALE as i32); - let mut value_builder = $VALUE_BUILDER::new(array.len()); + let mut value_builder = $VALUE_BUILDER::with_capacity(array.len()); for i in 0..array.len() { if array.is_null(i) { - value_builder.append_null()?; + value_builder.append_null(); } else { // The range of f32 or f64 is larger than i128, we don't need to check overflow. // cast the i128 to f64 will lose precision, for example the `112345678901234568` will be as `112345678901234560`. 
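The rewritten `cast_integer_to_decimal` and `cast_floating_point_to_decimal` keep the same arithmetic as the macros they replace: a value destined for scale `s` is multiplied by `10^s` (in `f64` for floats, then truncated toward zero), while decimal-to-integer goes back through a division plus a range check. A standalone sketch of that scaling, using the same values as the tests in this diff:

```rust
fn main() {
    // Integer -> Decimal128(38, 6): multiply by 10^scale.
    let scale: u32 = 6;
    let unscaled = 3_i128 * 10_i128.pow(scale);
    assert_eq!(3_000_000_i128, unscaled);

    // Float -> Decimal128(38, 6): scale in f64, then truncate toward zero.
    let unscaled_f = (1.1_f64 * 10_f64.powi(scale as i32)) as i128;
    assert_eq!(1_100_000_i128, unscaled_f);

    // Decimal128(38, 2) -> i8: divide by 10^scale and range-check the result.
    let d: i128 = 12_500; // the stored representation of 125.00
    let v = d / 10_i128.pow(2);
    assert!(v >= i8::MIN as i128 && v <= i8::MAX as i128);
    assert_eq!(125_i128, v);
}
```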
let v = (array.value(i).as_i128() as f64 / div) as $NATIVE_TYPE; - value_builder.append_value(v)?; + value_builder.append_value(v); } } Ok(Arc::new(value_builder.finish())) @@ -394,8 +442,19 @@ pub fn cast_with_options( return Ok(array.clone()); } match (from_type, to_type) { - (Decimal(_, s1), Decimal(p2, s2)) => cast_decimal_to_decimal(array, s1, p2, s2), - (Decimal(_, scale), _) => { + (Decimal128(_, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal::<16, 16>(array, s1, p2, s2) + } + (Decimal256(_, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal::<32, 32>(array, s1, p2, s2) + } + (Decimal128(_, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal::<16, 32>(array, s1, p2, s2) + } + (Decimal256(_, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal::<32, 16>(array, s1, p2, s2) + } + (Decimal128(_, scale), _) => { // cast decimal to other type match to_type { Int8 => { @@ -423,28 +482,40 @@ pub fn cast_with_options( ))), } } - (_, Decimal(precision, scale)) => { + (_, Decimal128(precision, scale)) => { // cast data to decimal match from_type { // TODO now just support signed numeric to decimal, support decimal to numeric later - Int8 => { - cast_integer_to_decimal!(array, Int8Array, precision, scale) - } - Int16 => { - cast_integer_to_decimal!(array, Int16Array, precision, scale) - } - Int32 => { - cast_integer_to_decimal!(array, Int32Array, precision, scale) - } - Int64 => { - cast_integer_to_decimal!(array, Int64Array, precision, scale) - } - Float32 => { - cast_floating_point_to_decimal!(array, Float32Array, precision, scale) - } - Float64 => { - cast_floating_point_to_decimal!(array, Float64Array, precision, scale) - } + Int8 => cast_integer_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Int16 => cast_integer_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Int32 => cast_integer_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Int64 => cast_integer_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Float32 => cast_floating_point_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), + Float64 => cast_floating_point_to_decimal( + as_primitive_array::(array), + *precision, + *scale, + ), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( "Casting from {:?} to {:?} not supported", @@ -629,6 +700,19 @@ pub fn cast_with_options( Float64 => cast_string_to_numeric::(array, cast_options), Date32 => cast_string_to_date32::(&**array, cast_options), Date64 => cast_string_to_date64::(&**array, cast_options), + Binary => cast_string_to_binary(array), + Time32(TimeUnit::Second) => { + cast_string_to_time32second::(&**array, cast_options) + } + Time32(TimeUnit::Millisecond) => { + cast_string_to_time32millisecond::(&**array, cast_options) + } + Time64(TimeUnit::Microsecond) => { + cast_string_to_time64microsecond::(&**array, cast_options) + } + Time64(TimeUnit::Nanosecond) => { + cast_string_to_time64nanosecond::(&**array, cast_options) + } Timestamp(TimeUnit::Nanosecond, None) => { cast_string_to_timestamp_ns::(&**array, cast_options) } @@ -649,18 +733,18 @@ pub fn cast_with_options( Int64 => cast_numeric_to_string::(array), Float32 => cast_numeric_to_string::(array), Float64 => cast_numeric_to_string::(array), - Timestamp(unit, _) => match unit { + Timestamp(unit, tz) => match unit { TimeUnit::Nanosecond => { - cast_timestamp_to_string::(array) + cast_timestamp_to_string::(array, tz) } TimeUnit::Microsecond => { - 
cast_timestamp_to_string::(array) + cast_timestamp_to_string::(array, tz) } TimeUnit::Millisecond => { - cast_timestamp_to_string::(array) + cast_timestamp_to_string::(array, tz) } TimeUnit::Second => { - cast_timestamp_to_string::(array) + cast_timestamp_to_string::(array, tz) } }, Date32 => cast_date32_to_string::(array), @@ -705,18 +789,18 @@ pub fn cast_with_options( Int64 => cast_numeric_to_string::(array), Float32 => cast_numeric_to_string::(array), Float64 => cast_numeric_to_string::(array), - Timestamp(unit, _) => match unit { + Timestamp(unit, tz) => match unit { TimeUnit::Nanosecond => { - cast_timestamp_to_string::(array) + cast_timestamp_to_string::(array, tz) } TimeUnit::Microsecond => { - cast_timestamp_to_string::(array) + cast_timestamp_to_string::(array, tz) } TimeUnit::Millisecond => { - cast_timestamp_to_string::(array) + cast_timestamp_to_string::(array, tz) } TimeUnit::Second => { - cast_timestamp_to_string::(array) + cast_timestamp_to_string::(array, tz) } }, Date32 => cast_date32_to_string::(array), @@ -763,6 +847,19 @@ pub fn cast_with_options( Float64 => cast_string_to_numeric::(array, cast_options), Date32 => cast_string_to_date32::(&**array, cast_options), Date64 => cast_string_to_date64::(&**array, cast_options), + LargeBinary => cast_string_to_binary(array), + Time32(TimeUnit::Second) => { + cast_string_to_time32second::(&**array, cast_options) + } + Time32(TimeUnit::Millisecond) => { + cast_string_to_time32millisecond::(&**array, cast_options) + } + Time64(TimeUnit::Microsecond) => { + cast_string_to_time64microsecond::(&**array, cast_options) + } + Time64(TimeUnit::Nanosecond) => { + cast_string_to_time64nanosecond::(&**array, cast_options) + } Timestamp(TimeUnit::Nanosecond, None) => { cast_string_to_timestamp_ns::(&**array, cast_options) } @@ -1042,10 +1139,7 @@ pub fn cast_with_options( // we either divide or multiply, depending on size of each unit // units are never the same when the types are the same let converted = if from_size >= to_size { - divide( - &time_array, - &Int64Array::from(vec![from_size / to_size; array.len()]), - )? + divide_scalar(&time_array, from_size / to_size)? 
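For the timestamp/time casts just below, converting between units is a single scalar multiply or divide by the ratio of the two unit multiples (seconds = 1, milliseconds = 1_000, microseconds = 1_000_000, nanoseconds = 1_000_000_000); switching to `divide_scalar` simply avoids materializing a constant divisor array. A sketch of the arithmetic:

```rust
// Unit multiples as returned by `time_unit_multiple`.
const MILLISECONDS: i64 = 1_000;
const NANOSECONDS: i64 = 1_000_000_000;

fn main() {
    // Nanosecond -> Millisecond: from_size >= to_size, so divide by the ratio.
    let ts_ns: i64 = 1_599_566_400_123_456_789;
    let ts_ms = ts_ns / (NANOSECONDS / MILLISECONDS);
    assert_eq!(1_599_566_400_123, ts_ms);

    // Millisecond -> Nanosecond: from_size < to_size, so multiply instead.
    assert_eq!(1_599_566_400_123_000_000, ts_ms * (NANOSECONDS / MILLISECONDS));
}
```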
} else { multiply( &time_array, @@ -1075,12 +1169,15 @@ pub fn cast_with_options( (Timestamp(from_unit, _), Date32) => { let time_array = Int64Array::from(array.data().clone()); let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY; - let mut b = Date32Builder::new(array.len()); + + // Int32Array::from_iter(tim.iter) + let mut b = Date32Builder::with_capacity(array.len()); + for i in 0..array.len() { - if array.is_null(i) { - b.append_null()?; + if time_array.is_null(i) { + b.append_null(); } else { - b.append_value((time_array.value(i) / from_size) as i32)?; + b.append_value((time_array.value(i) / from_size) as i32); } } @@ -1166,6 +1263,41 @@ pub fn cast_with_options( } } +/// Cast to string array to binary array +fn cast_string_to_binary(array: &ArrayRef) -> Result { + let from_type = array.data_type(); + match *from_type { + DataType::Utf8 => { + let data = unsafe { + array + .data() + .clone() + .into_builder() + .data_type(DataType::Binary) + .build_unchecked() + }; + + Ok(Arc::new(BinaryArray::from(data)) as ArrayRef) + } + DataType::LargeUtf8 => { + let data = unsafe { + array + .data() + .clone() + .into_builder() + .data_type(DataType::LargeBinary) + .build_unchecked() + }; + + Ok(Arc::new(LargeBinaryArray::from(data)) as ArrayRef) + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "{:?} cannot be converted to binary array", + from_type + ))), + } +} + /// Get the time unit as a multiple of a second const fn time_unit_multiple(unit: &TimeUnit) -> i64 { match unit { @@ -1176,48 +1308,124 @@ const fn time_unit_multiple(unit: &TimeUnit) -> i64 { } } -/// Number of seconds in a day -const SECONDS_IN_DAY: i64 = 86_400; -/// Number of milliseconds in a second -const MILLISECONDS: i64 = 1_000; -/// Number of microseconds in a second -const MICROSECONDS: i64 = 1_000_000; -/// Number of nanoseconds in a second -const NANOSECONDS: i64 = 1_000_000_000; -/// Number of milliseconds in a day -const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; -/// Number of days between 0001-01-01 and 1970-01-01 -const EPOCH_DAYS_FROM_CE: i32 = 719_163; - /// Cast one type of decimal array to another type of decimal array -fn cast_decimal_to_decimal( +fn cast_decimal_to_decimal( array: &ArrayRef, - input_scale: &usize, - output_precision: &usize, - output_scale: &usize, + input_scale: &u8, + output_precision: &u8, + output_scale: &u8, ) -> Result { - let array = array.as_any().downcast_ref::().unwrap(); - - let output_array = if input_scale > output_scale { + if input_scale > output_scale { // For example, input_scale is 4 and output_scale is 3; // Original value is 11234_i128, and will be cast to 1123_i128. 
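The rescaling in `cast_decimal_to_decimal` is unchanged by the new `BYTE_WIDTH` const generics: dropping scale divides by `10^(input_scale - output_scale)` and raising scale multiplies by `10^(output_scale - input_scale)`; only the Decimal256-to-Decimal128 paths need the extra check that the `BigInt` still fits in an `i128`. A plain `i128` sketch of the scale adjustment, using the examples from the comments in this function:

```rust
/// Adjust an unscaled decimal value from `input_scale` to `output_scale`,
/// the same per-element arithmetic `cast_decimal_to_decimal` applies.
fn rescale(v: i128, input_scale: u32, output_scale: u32) -> i128 {
    if input_scale > output_scale {
        v / 10_i128.pow(input_scale - output_scale)
    } else {
        v * 10_i128.pow(output_scale - input_scale)
    }
}

fn main() {
    assert_eq!(1_123, rescale(11_234, 4, 3));   // scale 4 -> 3: divide
    assert_eq!(11_230, rescale(1_123, 3, 4));   // scale 3 -> 4: multiply
    // Matches test_cast_decimal128_to_decimal128: 1123456 @ 3 -> 11234560 @ 4.
    assert_eq!(11_234_560, rescale(1_123_456, 3, 4));
}
```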
let div = 10_i128.pow((input_scale - output_scale) as u32); - array - .iter() - .map(|v| v.map(|v| v / div)) - .collect::() + if BYTE_WIDTH1 == 16 { + let array = array.as_any().downcast_ref::().unwrap(); + let iter = array.iter().map(|v| v.map(|v| v.as_i128() / div)); + if BYTE_WIDTH2 == 16 { + let output_array = iter + .collect::() + .with_precision_and_scale(*output_precision, *output_scale)?; + + Ok(Arc::new(output_array)) + } else { + let output_array = iter + .map(|v| v.map(BigInt::from)) + .collect::() + .with_precision_and_scale(*output_precision, *output_scale)?; + + Ok(Arc::new(output_array)) + } + } else { + let array = array.as_any().downcast_ref::().unwrap(); + let iter = array.iter().map(|v| v.map(|v| v.to_big_int().div(div))); + if BYTE_WIDTH2 == 16 { + let values = iter + .map(|v| { + if v.is_none() { + Ok(None) + } else { + v.as_ref().and_then(|v| v.to_i128()) + .ok_or_else(|| { + ArrowError::InvalidArgumentError( + format!("{:?} cannot be casted to 128-bit integer for Decimal128", v), + ) + }) + .map(Some) + } + }) + .collect::>>()?; + + let output_array = values + .into_iter() + .collect::() + .with_precision_and_scale(*output_precision, *output_scale)?; + + Ok(Arc::new(output_array)) + } else { + let output_array = iter + .collect::() + .with_precision_and_scale(*output_precision, *output_scale)?; + + Ok(Arc::new(output_array)) + } + } } else { // For example, input_scale is 3 and output_scale is 4; // Original value is 1123_i128, and will be cast to 11230_i128. let mul = 10_i128.pow((output_scale - input_scale) as u32); - array - .iter() - .map(|v| v.map(|v| v * mul)) - .collect::() - } - .with_precision_and_scale(*output_precision, *output_scale)?; + if BYTE_WIDTH1 == 16 { + let array = array.as_any().downcast_ref::().unwrap(); + let iter = array.iter().map(|v| v.map(|v| v.as_i128() * mul)); + if BYTE_WIDTH2 == 16 { + let output_array = iter + .collect::() + .with_precision_and_scale(*output_precision, *output_scale)?; + + Ok(Arc::new(output_array)) + } else { + let output_array = iter + .map(|v| v.map(BigInt::from)) + .collect::() + .with_precision_and_scale(*output_precision, *output_scale)?; - Ok(Arc::new(output_array)) + Ok(Arc::new(output_array)) + } + } else { + let array = array.as_any().downcast_ref::().unwrap(); + let iter = array.iter().map(|v| v.map(|v| v.to_big_int().mul(mul))); + if BYTE_WIDTH2 == 16 { + let values = iter + .map(|v| { + if v.is_none() { + Ok(None) + } else { + v.as_ref().and_then(|v| v.to_i128()) + .ok_or_else(|| { + ArrowError::InvalidArgumentError( + format!("{:?} cannot be casted to 128-bit integer for Decimal128", v), + ) + }) + .map(Some) + } + }) + .collect::>>()?; + + let output_array = values + .into_iter() + .collect::() + .with_precision_and_scale(*output_precision, *output_scale)?; + + Ok(Arc::new(output_array)) + } else { + let output_array = iter + .collect::() + .with_precision_and_scale(*output_precision, *output_scale)?; + + Ok(Arc::new(output_array)) + } + } + } } /// Cast an array by changing its array_data type to the desired type @@ -1279,7 +1487,10 @@ where } /// Cast timestamp types to Utf8/LargeUtf8 -fn cast_timestamp_to_string(array: &ArrayRef) -> Result +fn cast_timestamp_to_string( + array: &ArrayRef, + tz: &Option, +) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: From<::Native>, @@ -1287,17 +1498,28 @@ where { let array = array.as_any().downcast_ref::>().unwrap(); - Ok(Arc::new( - (0..array.len()) - .map(|ix| { - if array.is_null(ix) { - None - } else { - array.value_as_datetime(ix).map(|v| 
v.to_string()) - } - }) - .collect::>(), - )) + let mut builder = GenericStringBuilder::::new(); + + if let Some(tz) = tz { + let mut scratch = Parsed::new(); + // The macro calls `value_as_datetime_with_tz` on timestamp values of the array. + // After applying timezone offset on the datatime, calling `to_string` to get + // the strings. + extract_component_from_array!( + array, + builder, + to_string, + value_as_datetime_with_tz, + tz, + scratch, + |h| h + ) + } else { + // No timezone available. Calling `to_string` on the datatime value simply. + extract_component_from_array!(array, builder, to_string, value_as_datetime, |h| h) + } + + Ok(Arc::new(builder.finish()) as ArrayRef) } /// Cast date32 types to Utf8/LargeUtf8 @@ -1392,35 +1614,28 @@ where ::Native: lexical_core::FromLexical, { if cast_options.safe { - let iter = (0..from.len()).map(|i| { - if from.is_null(i) { - None - } else { - lexical_core::parse(from.value(i).as_bytes()).ok() - } - }); + let iter = from + .iter() + .map(|v| v.and_then(|v| lexical_core::parse(v.as_bytes()).ok())); // Benefit: // 20% performance improvement // Soundness: // The iterator is trustedLen because it comes from an `StringArray`. Ok(unsafe { PrimitiveArray::::from_trusted_len_iter(iter) }) } else { - let vec = (0..from.len()) - .map(|i| { - if from.is_null(i) { - Ok(None) - } else { - let string = from.value(i); - let result = lexical_core::parse(string.as_bytes()); - Some(result.map_err(|_| { + let vec = from + .iter() + .map(|v| { + v.map(|v| { + lexical_core::parse(v.as_bytes()).map_err(|_| { ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {} type", - string, - std::any::type_name::() + "Cannot cast string '{}' to value of {:?} type", + v, + T::DATA_TYPE, )) - })) - .transpose() - } + }) + }) + .transpose() }) .collect::>>()?; // Benefit: @@ -1443,16 +1658,12 @@ fn cast_string_to_date32( .unwrap(); let array = if cast_options.safe { - let iter = (0..string_array.len()).map(|i| { - if string_array.is_null(i) { - None - } else { - string_array - .value(i) - .parse::() + let iter = string_array.iter().map(|v| { + v.and_then(|v| { + v.parse::() .map(|date| date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) .ok() - } + }) }); // Benefit: @@ -1461,25 +1672,21 @@ fn cast_string_to_date32( // The iterator is trustedLen because it comes from an `StringArray`. 
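The refactored string-to-date casts still parse with chrono and then shift by `EPOCH_DAYS_FROM_CE` (the number of days between 0001-01-01 and the Unix epoch, now imported from `temporal_conversions`). A standalone chrono sketch of the Date32 arithmetic; 17890 happens to correspond to 2018-12-25 and is the same number the Date32 test further down asserts:

```rust
use chrono::{Datelike, NaiveDate};

/// Days between 0001-01-01 and 1970-01-01, as defined in `temporal_conversions`.
const EPOCH_DAYS_FROM_CE: i32 = 719_163;

fn main() {
    // Date32 stores days since the Unix epoch.
    let date = "2018-12-25".parse::<NaiveDate>().unwrap();
    assert_eq!(17890, date.num_days_from_ce() - EPOCH_DAYS_FROM_CE);

    // The epoch itself maps to day 0.
    let epoch = "1970-01-01".parse::<NaiveDate>().unwrap();
    assert_eq!(0, epoch.num_days_from_ce() - EPOCH_DAYS_FROM_CE);
}
```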
unsafe { Date32Array::from_trusted_len_iter(iter) } } else { - let vec = (0..string_array.len()) - .map(|i| { - if string_array.is_null(i) { - Ok(None) - } else { - let string = string_array - .value(i); - - let result = string - .parse::() - .map(|date| date.num_days_from_ce() - EPOCH_DAYS_FROM_CE); - - Some(result.map_err(|_| { - ArrowError::CastError( - format!("Cannot cast string '{}' to value of arrow::datatypes::types::Date32Type type", string), - ) - })) - .transpose() - } + let vec = string_array + .iter() + .map(|v| { + v.map(|v| { + v.parse::() + .map(|date| date.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + DataType::Date32 + )) + }) + }) + .transpose() }) .collect::>>>()?; @@ -1504,16 +1711,12 @@ fn cast_string_to_date64( .unwrap(); let array = if cast_options.safe { - let iter = (0..string_array.len()).map(|i| { - if string_array.is_null(i) { - None - } else { - string_array - .value(i) - .parse::() + let iter = string_array.iter().map(|v| { + v.and_then(|v| { + v.parse::() .map(|datetime| datetime.timestamp_millis()) .ok() - } + }) }); // Benefit: @@ -1522,25 +1725,21 @@ fn cast_string_to_date64( // The iterator is trustedLen because it comes from an `StringArray`. unsafe { Date64Array::from_trusted_len_iter(iter) } } else { - let vec = (0..string_array.len()) - .map(|i| { - if string_array.is_null(i) { - Ok(None) - } else { - let string = string_array - .value(i); - - let result = string - .parse::() - .map(|datetime| datetime.timestamp_millis()); - - Some(result.map_err(|_| { - ArrowError::CastError( - format!("Cannot cast string '{}' to value of arrow::datatypes::types::Date64Type type", string), - ) - })) - .transpose() - } + let vec = string_array + .iter() + .map(|v| { + v.map(|v| { + v.parse::() + .map(|datetime| datetime.timestamp_millis()) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + DataType::Date64 + )) + }) + }) + .transpose() }) .collect::>>>()?; @@ -1554,6 +1753,262 @@ fn cast_string_to_date64( Ok(Arc::new(array) as ArrayRef) } +/// Casts generic string arrays to `Time32SecondArray` +fn cast_string_to_time32second( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + /// The number of nanoseconds per millisecond. + const NANOS_PER_SEC: u32 = 1_000_000_000; + + let string_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let array = if cast_options.safe { + let iter = string_array.iter().map(|v| { + v.and_then(|v| { + v.parse::() + .map(|time| { + (time.num_seconds_from_midnight() + + time.nanosecond() / NANOS_PER_SEC) + as i32 + }) + .ok() + }) + }); + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { Time32SecondArray::from_trusted_len_iter(iter) } + } else { + let vec = string_array + .iter() + .map(|v| { + v.map(|v| { + v.parse::() + .map(|time| { + (time.num_seconds_from_midnight() + + time.nanosecond() / NANOS_PER_SEC) + as i32 + }) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + DataType::Time32(TimeUnit::Second) + )) + }) + }) + .transpose() + }) + .collect::>>>()?; + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. 
+ unsafe { Time32SecondArray::from_trusted_len_iter(vec.iter()) } + }; + + Ok(Arc::new(array) as ArrayRef) +} + +/// Casts generic string arrays to `Time32MillisecondArray` +fn cast_string_to_time32millisecond( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + /// The number of nanoseconds per millisecond. + const NANOS_PER_MILLI: u32 = 1_000_000; + /// The number of milliseconds per second. + const MILLIS_PER_SEC: u32 = 1_000; + + let string_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let array = if cast_options.safe { + let iter = string_array.iter().map(|v| { + v.and_then(|v| { + v.parse::() + .map(|time| { + (time.num_seconds_from_midnight() * MILLIS_PER_SEC + + time.nanosecond() / NANOS_PER_MILLI) + as i32 + }) + .ok() + }) + }); + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { Time32MillisecondArray::from_trusted_len_iter(iter) } + } else { + let vec = string_array + .iter() + .map(|v| { + v.map(|v| { + v.parse::() + .map(|time| { + (time.num_seconds_from_midnight() * MILLIS_PER_SEC + + time.nanosecond() / NANOS_PER_MILLI) + as i32 + }) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + DataType::Time32(TimeUnit::Millisecond) + )) + }) + }) + .transpose() + }) + .collect::>>>()?; + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { Time32MillisecondArray::from_trusted_len_iter(vec.iter()) } + }; + + Ok(Arc::new(array) as ArrayRef) +} + +/// Casts generic string arrays to `Time64MicrosecondArray` +fn cast_string_to_time64microsecond( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + /// The number of nanoseconds per microsecond. + const NANOS_PER_MICRO: i64 = 1_000; + /// The number of microseconds per second. + const MICROS_PER_SEC: i64 = 1_000_000; + + let string_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let array = if cast_options.safe { + let iter = string_array.iter().map(|v| { + v.and_then(|v| { + v.parse::() + .map(|time| { + time.num_seconds_from_midnight() as i64 * MICROS_PER_SEC + + time.nanosecond() as i64 / NANOS_PER_MICRO + }) + .ok() + }) + }); + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { Time64MicrosecondArray::from_trusted_len_iter(iter) } + } else { + let vec = string_array + .iter() + .map(|v| { + v.map(|v| { + v.parse::() + .map(|time| { + time.num_seconds_from_midnight() as i64 * MICROS_PER_SEC + + time.nanosecond() as i64 / NANOS_PER_MICRO + }) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + DataType::Time64(TimeUnit::Microsecond) + )) + }) + }) + .transpose() + }) + .collect::>>>()?; + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { Time64MicrosecondArray::from_trusted_len_iter(vec.iter()) } + }; + + Ok(Arc::new(array) as ArrayRef) +} + +/// Casts generic string arrays to `Time64NanosecondArray` +fn cast_string_to_time64nanosecond( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + /// The number of nanoseconds per second. 
+ const NANOS_PER_SEC: i64 = 1_000_000_000; + + let string_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + let array = if cast_options.safe { + let iter = string_array.iter().map(|v| { + v.and_then(|v| { + v.parse::() + .map(|time| { + time.num_seconds_from_midnight() as i64 * NANOS_PER_SEC + + time.nanosecond() as i64 + }) + .ok() + }) + }); + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { Time64NanosecondArray::from_trusted_len_iter(iter) } + } else { + let vec = string_array + .iter() + .map(|v| { + v.map(|v| { + v.parse::() + .map(|time| { + time.num_seconds_from_midnight() as i64 * NANOS_PER_SEC + + time.nanosecond() as i64 + }) + .map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to value of {:?} type", + v, + DataType::Time64(TimeUnit::Nanosecond) + )) + }) + }) + .transpose() + }) + .collect::>>>()?; + + // Benefit: + // 20% performance improvement + // Soundness: + // The iterator is trustedLen because it comes from an `StringArray`. + unsafe { Time64NanosecondArray::from_trusted_len_iter(vec.iter()) } + }; + + Ok(Arc::new(array) as ArrayRef) +} + /// Casts generic string arrays to TimeStampNanosecondArray fn cast_string_to_timestamp_ns( array: &dyn Array, @@ -1565,28 +2020,18 @@ fn cast_string_to_timestamp_ns( .unwrap(); let array = if cast_options.safe { - let iter = (0..string_array.len()).map(|i| { - if string_array.is_null(i) { - None - } else { - string_to_timestamp_nanos(string_array.value(i)).ok() - } - }); + let iter = string_array + .iter() + .map(|v| v.and_then(|v| string_to_timestamp_nanos(v).ok())); // Benefit: // 20% performance improvement // Soundness: // The iterator is trustedLen because it comes from an `StringArray`. 
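All of the new `cast_string_to_time32*` / `cast_string_to_time64*` functions parse with `chrono::NaiveTime` and scale seconds-from-midnight plus the sub-second nanoseconds into the target unit; chrono reports a leap second as `nanosecond() >= 1_000_000_000`, which is why the second and millisecond variants add a `nanosecond()`-derived term. A standalone chrono sketch reproducing the values asserted in the tests below:

```rust
use chrono::{NaiveTime, Timelike};

fn main() {
    let time = "08:08:35.091323414".parse::<NaiveTime>().unwrap();

    // Time32(Second): whole seconds since midnight.
    assert_eq!(29_315, time.num_seconds_from_midnight());

    // Time64(Nanosecond): seconds scaled to the unit plus the sub-second part.
    let nanos =
        time.num_seconds_from_midnight() as i64 * 1_000_000_000 + time.nanosecond() as i64;
    assert_eq!(29_315_091_323_414_i64, nanos);

    // A leap second such as "08:08:60" shows up as nanosecond() >= 1e9,
    // so adding nanosecond() / NANOS_PER_SEC folds it into the whole seconds.
    let leap = "08:08:60.091323414".parse::<NaiveTime>().unwrap();
    assert_eq!(
        29_340,
        leap.num_seconds_from_midnight() + leap.nanosecond() / 1_000_000_000
    );
}
```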
unsafe { TimestampNanosecondArray::from_trusted_len_iter(iter) } } else { - let vec = (0..string_array.len()) - .map(|i| { - if string_array.is_null(i) { - Ok(None) - } else { - let result = string_to_timestamp_nanos(string_array.value(i)); - Some(result).transpose() - } - }) + let vec = string_array + .iter() + .map(|v| v.map(string_to_timestamp_nanos).transpose()) .collect::>>>()?; // Benefit: @@ -1646,15 +2091,15 @@ fn numeric_to_bool_cast(from: &PrimitiveArray) -> Result where T: ArrowPrimitiveType + ArrowNumericType, { - let mut b = BooleanBuilder::new(from.len()); + let mut b = BooleanBuilder::with_capacity(from.len()); for i in 0..from.len() { if from.is_null(i) { - b.append_null()?; + b.append_null(); } else if from.value(i) != T::default_value() { - b.append_value(true)?; + b.append_value(true); } else { - b.append_value(false)?; + b.append_value(false); } } @@ -1901,14 +2346,14 @@ where .downcast_ref::>() .unwrap(); - let keys_builder = PrimitiveBuilder::::new(values.len()); - let values_builder = PrimitiveBuilder::::new(values.len()); + let keys_builder = PrimitiveBuilder::::with_capacity(values.len()); + let values_builder = PrimitiveBuilder::::with_capacity(values.len()); let mut b = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); // copy each element one at a time for i in 0..values.len() { if values.is_null(i) { - b.append_null()?; + b.append_null(); } else { b.append(values.value(i))?; } @@ -1928,14 +2373,14 @@ where let cast_values = cast_with_options(array, &DataType::Utf8, cast_options)?; let values = cast_values.as_any().downcast_ref::().unwrap(); - let keys_builder = PrimitiveBuilder::::new(values.len()); - let values_builder = StringBuilder::new(values.len()); + let keys_builder = PrimitiveBuilder::::with_capacity(values.len()); + let values_builder = StringBuilder::with_capacity(1024, values.len()); let mut b = StringDictionaryBuilder::new(keys_builder, values_builder); // copy each element one at a time for i in 0..values.len() { if values.is_null(i) { - b.append_null()?; + b.append_null(); } else { b.append(values.value(i))?; } @@ -2135,8 +2580,8 @@ where #[cfg(test)] mod tests { use super::*; - use crate::array::BasicDecimalArray; - use crate::util::decimal::Decimal128; + use crate::datatypes::TimeUnit; + use crate::util::decimal::{Decimal128, Decimal256}; use crate::{buffer::Buffer, util::display::array_value_to_string}; macro_rules! 
generate_cast_test_case { @@ -2165,27 +2610,38 @@ mod tests { } fn create_decimal_array( - array: &[Option], - precision: usize, - scale: usize, - ) -> Result { + array: Vec>, + precision: u8, + scale: u8, + ) -> Result { array - .iter() - .collect::() + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + fn create_decimal256_array( + array: Vec>, + precision: u8, + scale: u8, + ) -> Result { + array + .into_iter() + .collect::() .with_precision_and_scale(precision, scale) } #[test] - fn test_cast_decimal_to_decimal() { - let input_type = DataType::Decimal(20, 3); - let output_type = DataType::Decimal(20, 4); + fn test_cast_decimal128_to_decimal128() { + let input_type = DataType::Decimal128(20, 3); + let output_type = DataType::Decimal128(20, 4); assert!(can_cast_types(&input_type, &output_type)); let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; - let input_decimal_array = create_decimal_array(&array, 20, 3).unwrap(); + let input_decimal_array = create_decimal_array(array, 20, 3).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; generate_cast_test_case!( &array, - DecimalArray, + Decimal128Array, &output_type, vec![ Some(Decimal128::new_from_i128(20, 4, 11234560_i128)), @@ -2196,22 +2652,113 @@ mod tests { ); // negative test let array = vec![Some(123456), None]; - let input_decimal_array = create_decimal_array(&array, 10, 0).unwrap(); + let input_decimal_array = create_decimal_array(array, 10, 0).unwrap(); let array = Arc::new(input_decimal_array) as ArrayRef; - let result = cast(&array, &DataType::Decimal(2, 2)); + let result = cast(&array, &DataType::Decimal128(2, 2)); assert!(result.is_err()); - assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal of precision 2. Max is 99", + assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. 
Max is 99", result.unwrap_err().to_string()); } + #[test] + fn test_cast_decimal128_to_decimal256() { + let input_type = DataType::Decimal128(20, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let input_decimal_array = create_decimal_array(array, 20, 3).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some( + Decimal256::from_big_int(&BigInt::from(11234560_i128), 20, 4) + .unwrap() + ), + Some( + Decimal256::from_big_int(&BigInt::from(21234560_i128), 20, 4) + .unwrap() + ), + Some( + Decimal256::from_big_int(&BigInt::from(31234560_i128), 20, 4) + .unwrap() + ), + None + ] + ); + } + + #[test] + fn test_cast_decimal256_to_decimal128() { + let input_type = DataType::Decimal256(20, 3); + let output_type = DataType::Decimal128(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![ + Some(BigInt::from(1123456)), + Some(BigInt::from(2123456)), + Some(BigInt::from(3123456)), + None, + ]; + let input_decimal_array = create_decimal256_array(array, 20, 3).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(Decimal128::new_from_i128(20, 4, 11234560_i128)), + Some(Decimal128::new_from_i128(20, 4, 21234560_i128)), + Some(Decimal128::new_from_i128(20, 4, 31234560_i128)), + None + ] + ); + } + + #[test] + fn test_cast_decimal256_to_decimal256() { + let input_type = DataType::Decimal256(20, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![ + Some(BigInt::from(1123456)), + Some(BigInt::from(2123456)), + Some(BigInt::from(3123456)), + None, + ]; + let input_decimal_array = create_decimal256_array(array, 20, 3).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some( + Decimal256::from_big_int(&BigInt::from(11234560_i128), 20, 4) + .unwrap() + ), + Some( + Decimal256::from_big_int(&BigInt::from(21234560_i128), 20, 4) + .unwrap() + ), + Some( + Decimal256::from_big_int(&BigInt::from(31234560_i128), 20, 4) + .unwrap() + ), + None + ] + ); + } + #[test] fn test_cast_decimal_to_numeric() { - let decimal_type = DataType::Decimal(38, 2); + let decimal_type = DataType::Decimal128(38, 2); // negative test assert!(!can_cast_types(&decimal_type, &DataType::UInt8)); let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; - let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let decimal_array = create_decimal_array(value_array, 38, 2).unwrap(); let array = Arc::new(decimal_array) as ArrayRef; // i8 generate_cast_test_case!( @@ -2258,7 +2805,7 @@ mod tests { // overflow test: out of range of max i8 let value_array: Vec> = vec![Some(24400)]; - let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let decimal_array = create_decimal_array(value_array, 38, 2).unwrap(); let array = Arc::new(decimal_array) as ArrayRef; let casted_array = cast(&array, &DataType::Int8); assert_eq!( @@ -2278,7 +2825,7 @@ mod tests { Some(112345678), Some(112345679), ]; - let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let decimal_array = create_decimal_array(value_array, 38, 2).unwrap(); let array = Arc::new(decimal_array) as 
ArrayRef; generate_cast_test_case!( &array, @@ -2306,7 +2853,7 @@ mod tests { Some(112345678901234568), Some(112345678901234560), ]; - let decimal_array = create_decimal_array(&value_array, 38, 2).unwrap(); + let decimal_array = create_decimal_array(value_array, 38, 2).unwrap(); let array = Arc::new(decimal_array) as ArrayRef; generate_cast_test_case!( &array, @@ -2327,7 +2874,7 @@ mod tests { #[test] fn test_cast_numeric_to_decimal() { // test negative cast type - let decimal_type = DataType::Decimal(38, 6); + let decimal_type = DataType::Decimal128(38, 6); assert!(!can_cast_types(&DataType::UInt64, &decimal_type)); // i8, i16, i32, i64 @@ -2364,7 +2911,7 @@ mod tests { for array in input_datas { generate_cast_test_case!( &array, - DecimalArray, + Decimal128Array, &decimal_type, vec![ Some(Decimal128::new_from_i128(38, 6, 1000000_i128)), @@ -2380,9 +2927,9 @@ mod tests { // the 100 will be converted to 1000_i128, but it is out of range for max value in the precision 3. let array = Int8Array::from(vec![1, 2, 3, 4, 100]); let array = Arc::new(array) as ArrayRef; - let casted_array = cast(&array, &DataType::Decimal(3, 1)); + let casted_array = cast(&array, &DataType::Decimal128(3, 1)); assert!(casted_array.is_err()); - assert_eq!("Invalid argument error: 1000 is too large to store in a Decimal of precision 3. Max is 999", casted_array.unwrap_err().to_string()); + assert_eq!("Invalid argument error: 1000 is too large to store in a Decimal128 of precision 3. Max is 999", casted_array.unwrap_err().to_string()); // test f32 to decimal type let array = Float32Array::from(vec![ @@ -2396,7 +2943,7 @@ mod tests { let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( &array, - DecimalArray, + Decimal128Array, &decimal_type, vec![ Some(Decimal128::new_from_i128(38, 6, 1100000_i128)), @@ -2421,7 +2968,7 @@ mod tests { let array = Arc::new(array) as ArrayRef; generate_cast_test_case!( &array, - DecimalArray, + Decimal128Array, &decimal_type, vec![ Some(Decimal128::new_from_i128(38, 6, 1100000_i128)), @@ -2595,9 +3142,13 @@ mod tests { match result { Ok(_) => panic!("expected error"), Err(e) => { - assert!(e.to_string().contains( - "Cast error: Cannot cast string 'seven' to value of arrow::datatypes::types::Int32Type type" - )) + assert!( + e.to_string().contains( + "Cast error: Cannot cast string 'seven' to value of Int32 type", + ), + "Error: {}", + e + ) } } } @@ -2791,8 +3342,8 @@ mod tests { None, ])) as ArrayRef; for array in &[a1, a2] { - let b = - cast(array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap(); + let to_type = DataType::Timestamp(TimeUnit::Nanosecond, None); + let b = cast(array, &to_type).unwrap(); let c = b .as_any() .downcast_ref::() @@ -2800,6 +3351,13 @@ mod tests { assert_eq!(1599566400000000000, c.value(0)); assert!(c.is_null(1)); assert!(c.is_null(2)); + + let options = CastOptions { safe: false }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!( + err.to_string(), + "Cast error: Error parsing 'Not a valid date' as timestamp" + ); } } @@ -2816,11 +3374,132 @@ mod tests { None, ])) as ArrayRef; for array in &[a1, a2] { - let b = cast(array, &DataType::Date32).unwrap(); + let to_type = DataType::Date32; + let b = cast(array, &to_type).unwrap(); let c = b.as_any().downcast_ref::().unwrap(); assert_eq!(17890, c.value(0)); assert!(c.is_null(1)); assert!(c.is_null(2)); + + let options = CastOptions { safe: false }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast 
error: Cannot cast string 'Not a valid date' to value of Date32 type"); + } + } + + #[test] + fn test_cast_string_to_time32second() { + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Time32(TimeUnit::Second); + let b = cast(array, &to_type).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(29315, c.value(0)); + assert_eq!(29340, c.value(1)); + assert!(c.is_null(2)); + assert!(c.is_null(3)); + assert!(c.is_null(4)); + + let options = CastOptions { safe: false }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Second) type"); + } + } + + #[test] + fn test_cast_string_to_time32millisecond() { + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("08:08:60.091323414"), // leap second + Some("08:08:61.091323414"), // not valid + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Time32(TimeUnit::Millisecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(29315091, c.value(0)); + assert_eq!(29340091, c.value(1)); + assert!(c.is_null(2)); + assert!(c.is_null(3)); + assert!(c.is_null(4)); + + let options = CastOptions { safe: false }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Millisecond) type"); + } + } + + #[test] + fn test_cast_string_to_time64microsecond() { + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Time64(TimeUnit::Microsecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(29315091323, c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { safe: false }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Microsecond) type"); + } + } + + #[test] + fn test_cast_string_to_time64nanosecond() { + let a1 = Arc::new(StringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(vec![ + Some("08:08:35.091323414"), + Some("Not a valid time"), + None, + ])) as ArrayRef; + for array in &[a1, a2] { + let to_type = DataType::Time64(TimeUnit::Nanosecond); + let b = cast(array, &to_type).unwrap(); + let c = b.as_any().downcast_ref::().unwrap(); + assert_eq!(29315091323414, 
c.value(0)); + assert!(c.is_null(1)); + assert!(c.is_null(2)); + + let options = CastOptions { safe: false }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Nanosecond) type"); } } @@ -2837,14 +3516,47 @@ mod tests { None, ])) as ArrayRef; for array in &[a1, a2] { - let b = cast(array, &DataType::Date64).unwrap(); + let to_type = DataType::Date64; + let b = cast(array, &to_type).unwrap(); let c = b.as_any().downcast_ref::().unwrap(); assert_eq!(1599566400000, c.value(0)); assert!(c.is_null(1)); assert!(c.is_null(2)); + + let options = CastOptions { safe: false }; + let err = cast_with_options(array, &to_type, &options).unwrap_err(); + assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid date' to value of Date64 type"); } } + #[test] + fn test_cast_string_to_binary() { + let string_1 = "Hi"; + let string_2 = "Hello"; + + let bytes_1 = string_1.as_bytes(); + let bytes_2 = string_2.as_bytes(); + + let string_data = vec![Some(string_1), Some(string_2), None]; + let a1 = Arc::new(StringArray::from(string_data.clone())) as ArrayRef; + let a2 = Arc::new(LargeStringArray::from(string_data)) as ArrayRef; + + let mut array_ref = cast(&a1, &DataType::Binary).unwrap(); + let down_cast = array_ref.as_any().downcast_ref::().unwrap(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + + array_ref = cast(&a2, &DataType::LargeBinary).unwrap(); + let down_cast = array_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(bytes_1, down_cast.value(0)); + assert_eq!(bytes_2, down_cast.value(1)); + assert!(down_cast.is_null(2)); + } + #[test] fn test_cast_date32_to_int32() { let a = Date32Array::from(vec![10000, 17890]); @@ -2909,6 +3621,7 @@ mod tests { } #[test] + #[cfg(feature = "chrono-tz")] fn test_cast_timestamp_to_string() { let a = TimestampMillisecondArray::from_opt_vec( vec![Some(864000000005), Some(1545696000001), None], @@ -3062,15 +3775,15 @@ mod tests { #[test] fn test_cast_from_f64() { let f64_values: Vec = vec![ - std::i64::MIN as f64, - std::i32::MIN as f64, - std::i16::MIN as f64, - std::i8::MIN as f64, + i64::MIN as f64, + i32::MIN as f64, + i16::MIN as f64, + i8::MIN as f64, 0_f64, - std::u8::MAX as f64, - std::u16::MAX as f64, - std::u32::MAX as f64, - std::u64::MAX as f64, + u8::MAX as f64, + u16::MAX as f64, + u32::MAX as f64, + u64::MAX as f64, ]; let f64_array: ArrayRef = Arc::new(Float64Array::from(f64_values)); @@ -3212,15 +3925,15 @@ mod tests { #[test] fn test_cast_from_f32() { let f32_values: Vec = vec![ - std::i32::MIN as f32, - std::i32::MIN as f32, - std::i16::MIN as f32, - std::i8::MIN as f32, + i32::MIN as f32, + i32::MIN as f32, + i16::MIN as f32, + i8::MIN as f32, 0_f32, - std::u8::MAX as f32, - std::u16::MAX as f32, - std::u32::MAX as f32, - std::u32::MAX as f32, + u8::MAX as f32, + u16::MAX as f32, + u32::MAX as f32, + u32::MAX as f32, ]; let f32_array: ArrayRef = Arc::new(Float32Array::from(f32_values)); @@ -3349,10 +4062,10 @@ mod tests { fn test_cast_from_uint64() { let u64_values: Vec = vec![ 0, - std::u8::MAX as u64, - std::u16::MAX as u64, - std::u32::MAX as u64, - std::u64::MAX, + u8::MAX as u64, + u16::MAX as u64, + u32::MAX as u64, + u64::MAX, ]; let u64_array: ArrayRef = Arc::new(UInt64Array::from(u64_values)); @@ -3428,12 +4141,8 @@ mod tests { #[test] fn test_cast_from_uint32() { - let u32_values: Vec = vec![ - 0, - std::u8::MAX as u32, - 
std::u16::MAX as u32, - std::u32::MAX as u32, - ]; + let u32_values: Vec = + vec![0, u8::MAX as u32, u16::MAX as u32, u32::MAX as u32]; let u32_array: ArrayRef = Arc::new(UInt32Array::from(u32_values)); let f64_expected = vec!["0.0", "255.0", "65535.0", "4294967295.0"]; @@ -3499,7 +4208,7 @@ mod tests { #[test] fn test_cast_from_uint16() { - let u16_values: Vec = vec![0, std::u8::MAX as u16, std::u16::MAX as u16]; + let u16_values: Vec = vec![0, u8::MAX as u16, u16::MAX as u16]; let u16_array: ArrayRef = Arc::new(UInt16Array::from(u16_values)); let f64_expected = vec!["0.0", "255.0", "65535.0"]; @@ -3565,7 +4274,7 @@ mod tests { #[test] fn test_cast_from_uint8() { - let u8_values: Vec = vec![0, std::u8::MAX]; + let u8_values: Vec = vec![0, u8::MAX]; let u8_array: ArrayRef = Arc::new(UInt8Array::from(u8_values)); let f64_expected = vec!["0.0", "255.0"]; @@ -3632,15 +4341,15 @@ mod tests { #[test] fn test_cast_from_int64() { let i64_values: Vec = vec![ - std::i64::MIN, - std::i32::MIN as i64, - std::i16::MIN as i64, - std::i8::MIN as i64, + i64::MIN, + i32::MIN as i64, + i16::MIN as i64, + i8::MIN as i64, 0, - std::i8::MAX as i64, - std::i16::MAX as i64, - std::i32::MAX as i64, - std::i64::MAX, + i8::MAX as i64, + i16::MAX as i64, + i32::MAX as i64, + i64::MAX, ]; let i64_array: ArrayRef = Arc::new(Int64Array::from(i64_values)); @@ -3787,13 +4496,13 @@ mod tests { #[test] fn test_cast_from_int32() { let i32_values: Vec = vec![ - std::i32::MIN as i32, - std::i16::MIN as i32, - std::i8::MIN as i32, + i32::MIN as i32, + i16::MIN as i32, + i8::MIN as i32, 0, - std::i8::MAX as i32, - std::i16::MAX as i32, - std::i32::MAX as i32, + i8::MAX as i32, + i16::MAX as i32, + i32::MAX as i32, ]; let i32_array: ArrayRef = Arc::new(Int32Array::from(i32_values)); @@ -3881,13 +4590,8 @@ mod tests { #[test] fn test_cast_from_int16() { - let i16_values: Vec = vec![ - std::i16::MIN, - std::i8::MIN as i16, - 0, - std::i8::MAX as i16, - std::i16::MAX, - ]; + let i16_values: Vec = + vec![i16::MIN, i8::MIN as i16, 0, i8::MAX as i16, i16::MAX]; let i16_array: ArrayRef = Arc::new(Int16Array::from(i16_values)); let f64_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"]; @@ -3954,13 +4658,13 @@ mod tests { #[test] fn test_cast_from_date32() { let i32_values: Vec = vec![ - std::i32::MIN as i32, - std::i16::MIN as i32, - std::i8::MIN as i32, + i32::MIN as i32, + i16::MIN as i32, + i8::MIN as i32, 0, - std::i8::MAX as i32, - std::i16::MAX as i32, - std::i32::MAX as i32, + i8::MAX as i32, + i16::MAX as i32, + i32::MAX as i32, ]; let date32_array: ArrayRef = Arc::new(Date32Array::from(i32_values)); @@ -3981,7 +4685,7 @@ mod tests { #[test] fn test_cast_from_int8() { - let i8_values: Vec = vec![std::i8::MIN, 0, std::i8::MAX]; + let i8_values: Vec = vec![i8::MIN, 0, i8::MAX]; let i8_array: ArrayRef = Arc::new(Int8Array::from(i8_values)); let f64_expected = vec!["-128.0", "0.0", "127.0"]; @@ -4068,11 +4772,11 @@ mod tests { // FROM a dictionary with of Utf8 values use DataType::*; - let keys_builder = PrimitiveBuilder::::new(10); - let values_builder = StringBuilder::new(10); + let keys_builder = PrimitiveBuilder::::new(); + let values_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); builder.append("one").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("three").unwrap(); let array: ArrayRef = Arc::new(builder.finish()); @@ -4129,8 +4833,8 @@ mod tests { // that are out of bounds for a particular other kind of // index. 
- let keys_builder = PrimitiveBuilder::::new(10); - let values_builder = PrimitiveBuilder::::new(10); + let keys_builder = PrimitiveBuilder::::new(); + let values_builder = PrimitiveBuilder::::new(); let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); // add 200 distinct values (which can be stored by a @@ -4160,8 +4864,8 @@ mod tests { // Same test as test_cast_dict_to_dict_bad_index_value but use // string values (and encode the expected behavior here); - let keys_builder = PrimitiveBuilder::::new(10); - let values_builder = StringBuilder::new(10); + let keys_builder = PrimitiveBuilder::::new(); + let values_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); // add 200 distinct values (which can be stored by a @@ -4191,11 +4895,11 @@ mod tests { // FROM a dictionary with of INT32 values use DataType::*; - let keys_builder = PrimitiveBuilder::::new(10); - let values_builder = PrimitiveBuilder::::new(10); + let keys_builder = PrimitiveBuilder::::new(); + let values_builder = PrimitiveBuilder::::new(); let mut builder = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); builder.append(1).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(3).unwrap(); let array: ArrayRef = Arc::new(builder.finish()); @@ -4215,10 +4919,10 @@ mod tests { fn test_cast_primitive_array_to_dict() { use DataType::*; - let mut builder = PrimitiveBuilder::::new(10); - builder.append_value(1).unwrap(); - builder.append_null().unwrap(); - builder.append_value(3).unwrap(); + let mut builder = PrimitiveBuilder::::new(); + builder.append_value(1); + builder.append_null(); + builder.append_value(3); let array: ArrayRef = Arc::new(builder.finish()); let expected = vec!["1", "null", "3"]; @@ -4254,7 +4958,7 @@ mod tests { #[test] fn test_cast_null_array_to_from_decimal_array() { - let data_type = DataType::Decimal(12, 4); + let data_type = DataType::Decimal128(12, 4); let array = new_null_array(&DataType::Null, 4); assert_eq!(array.data_type(), &DataType::Null); let cast_array = cast(&array, &data_type).expect("cast failed"); @@ -4443,6 +5147,7 @@ mod tests { #[test] #[cfg_attr(miri, ignore)] // running forever + #[cfg(feature = "chrono-tz")] fn test_can_cast_types() { // this function attempts to ensure that can_cast_types stays // in sync with cast. 
It simply tries all combinations of @@ -4510,6 +5215,7 @@ mod tests { } /// Create instances of arrays with varying types for cast tests + #[cfg(feature = "chrono-tz")] fn get_arrays_of_all_types() -> Vec { let tz_name = String::from("America/New_York"); let binary_data: Vec<&[u8]> = vec![b"foo", b"bar"]; @@ -4596,7 +5302,8 @@ mod tests { Arc::new(DurationMicrosecondArray::from(vec![1000, 2000])), Arc::new(DurationNanosecondArray::from(vec![1000, 2000])), Arc::new( - create_decimal_array(&[Some(1), Some(2), Some(3), None], 38, 0).unwrap(), + create_decimal_array(vec![Some(1), Some(2), Some(3), None], 38, 0) + .unwrap(), ), ] } @@ -4649,6 +5356,7 @@ mod tests { LargeListArray::from(list_data) } + #[cfg(feature = "chrono-tz")] fn make_fixed_size_list_array() -> FixedSizeListArray { // Construct a value array let value_data = ArrayData::builder(DataType::Int32) @@ -4670,6 +5378,7 @@ mod tests { FixedSizeListArray::from(list_data) } + #[cfg(feature = "chrono-tz")] fn make_fixed_size_binary_array() -> FixedSizeBinaryArray { let values: [u8; 15] = *b"hellotherearrow"; @@ -4681,18 +5390,20 @@ mod tests { FixedSizeBinaryArray::from(array_data) } + #[cfg(feature = "chrono-tz")] fn make_union_array() -> UnionArray { - let mut builder = UnionBuilder::new_dense(7); + let mut builder = UnionBuilder::with_capacity_dense(7); builder.append::("a", 1).unwrap(); builder.append::("b", 2).unwrap(); builder.build().unwrap() } /// Creates a dictionary with primitive dictionary values, and keys of type K + #[cfg(feature = "chrono-tz")] fn make_dictionary_primitive() -> ArrayRef { - let keys_builder = PrimitiveBuilder::::new(2); + let keys_builder = PrimitiveBuilder::::new(); // Pick Int32 arbitrarily for dictionary values - let values_builder = PrimitiveBuilder::::new(2); + let values_builder = PrimitiveBuilder::::new(); let mut b = PrimitiveDictionaryBuilder::new(keys_builder, values_builder); b.append(1).unwrap(); b.append(2).unwrap(); @@ -4700,10 +5411,11 @@ mod tests { } /// Creates a dictionary with utf8 values, and keys of type K + #[cfg(feature = "chrono-tz")] fn make_dictionary_utf8() -> ArrayRef { - let keys_builder = PrimitiveBuilder::::new(2); + let keys_builder = PrimitiveBuilder::::new(); // Pick Int32 arbitrarily for dictionary values - let values_builder = StringBuilder::new(2); + let values_builder = StringBuilder::new(); let mut b = StringDictionaryBuilder::new(keys_builder, values_builder); b.append("foo").unwrap(); b.append("bar").unwrap(); @@ -4711,6 +5423,7 @@ mod tests { } // Get a selection of datatypes to try and cast to + #[cfg(feature = "chrono-tz")] fn get_all_types() -> Vec { use DataType::*; let tz_name = String::from("America/New_York"); @@ -4776,7 +5489,7 @@ mod tests { Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), Dictionary(Box::new(DataType::Int16), Box::new(DataType::Utf8)), Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), - Decimal(38, 0), + Decimal128(38, 0), ] } @@ -4805,4 +5518,39 @@ mod tests { assert_eq!(&out1, &out2.slice(1, 2)) } + + #[test] + #[cfg(feature = "chrono-tz")] + fn test_timestamp_cast_utf8() { + let array: PrimitiveArray = + vec![Some(37800000000), None, Some(86339000000)].into(); + let out = cast(&(Arc::new(array) as ArrayRef), &DataType::Utf8).unwrap(); + + let expected = StringArray::from(vec![ + Some("1970-01-01 10:30:00"), + None, + Some("1970-01-01 23:58:59"), + ]); + + assert_eq!( + out.as_any().downcast_ref::().unwrap(), + &expected + ); + + let array: PrimitiveArray = + vec![Some(37800000000), None, 
Some(86339000000)].into(); + let array = array.with_timezone("Australia/Sydney".to_string()); + let out = cast(&(Arc::new(array) as ArrayRef), &DataType::Utf8).unwrap(); + + let expected = StringArray::from(vec![ + Some("1970-01-01 20:30:00"), + None, + Some("1970-01-02 09:58:59"), + ]); + + assert_eq!( + out.as_any().downcast_ref::().unwrap(), + &expected + ); + } } diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 0a6d60cea470..dd9d4fc5d492 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -24,184 +24,91 @@ //! use crate::array::*; -use crate::buffer::{bitwise_bin_op_helper, buffer_unary_not, Buffer, MutableBuffer}; -use crate::compute::binary_boolean_kernel; +use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer}; use crate::compute::util::combine_option_bitmap; use crate::datatypes::{ ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, + IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -use regex::{escape, Regex}; -use std::any::type_name; +use regex::Regex; use std::collections::HashMap; -/// Helper function to perform boolean lambda function on values from two arrays, this +/// Helper function to perform boolean lambda function on values from two array accessors, this /// version does not attempt to use SIMD. -macro_rules! compare_op { - ($left: expr, $right:expr, $op:expr) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let null_bit_buffer = - combine_option_bitmap(&[$left.data_ref(), $right.data_ref()], $left.len())?; - - // Safety: - // `i < $left.len()` and $left.len() == $right.len() - let comparison = (0..$left.len()) - .map(|i| unsafe { $op($left.value_unchecked(i), $right.value_unchecked(i)) }); - // same size as $left.len() and $right.len() - let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - $left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(buffer)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) - }}; -} +fn compare_op( + left: T, + right: S, + op: F, +) -> Result +where + F: Fn(T::Item, S::Item) -> bool, +{ + if left.len() != right.len() { + return Err(ArrowError::ComputeError( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } -macro_rules! 
compare_op_primitive { - ($left: expr, $right:expr, $op:expr) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } + let null_bit_buffer = + combine_option_bitmap(&[left.data_ref(), right.data_ref()], left.len())?; - let null_bit_buffer = - combine_option_bitmap(&[$left.data_ref(), $right.data_ref()], $left.len())?; - - let mut values = MutableBuffer::from_len_zeroed(($left.len() + 7) / 8); - let lhs_chunks_iter = $left.values().chunks_exact(8); - let lhs_remainder = lhs_chunks_iter.remainder(); - let rhs_chunks_iter = $right.values().chunks_exact(8); - let rhs_remainder = rhs_chunks_iter.remainder(); - let chunks = $left.len() / 8; - - values[..chunks] - .iter_mut() - .zip(lhs_chunks_iter) - .zip(rhs_chunks_iter) - .for_each(|((byte, lhs), rhs)| { - lhs.iter() - .zip(rhs.iter()) - .enumerate() - .for_each(|(i, (&lhs, &rhs))| { - *byte |= if $op(lhs, rhs) { 1 << i } else { 0 }; - }); - }); + // Safety: + // `i < $left.len()` and $left.len() == $right.len() + let comparison = (0..left.len()) + .map(|i| unsafe { op(left.value_unchecked(i), right.value_unchecked(i)) }); + // same size as $left.len() and $right.len() + let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; - if !lhs_remainder.is_empty() { - let last = &mut values[chunks]; - lhs_remainder - .iter() - .zip(rhs_remainder.iter()) - .enumerate() - .for_each(|(i, (&lhs, &rhs))| { - *last |= if $op(lhs, rhs) { 1 << i } else { 0 }; - }); - }; - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - $left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(values)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) - }}; + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + 0, + vec![Buffer::from(buffer)], + vec![], + ) + }; + Ok(BooleanArray::from(data)) } -macro_rules! compare_op_scalar { - ($left:expr, $op:expr) => {{ - let null_bit_buffer = $left - .data() - .null_buffer() - .map(|b| b.bit_slice($left.offset(), $left.len())); - - // Safety: - // `i < $left.len()` - let comparison = - (0..$left.len()).map(|i| unsafe { $op($left.value_unchecked(i)) }); - // same as $left.len() - let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - $left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(buffer)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) - }}; -} +/// Helper function to perform boolean lambda function on values from array accessor, this +/// version does not attempt to use SIMD. +fn compare_op_scalar(left: T, op: F) -> Result +where + F: Fn(T::Item) -> bool, +{ + let null_bit_buffer = left + .data() + .null_buffer() + .map(|b| b.bit_slice(left.offset(), left.len())); -macro_rules! 
compare_op_scalar_primitive { - ($left: expr, $right:expr, $op:expr) => {{ - let null_bit_buffer = $left - .data() - .null_buffer() - .map(|b| b.bit_slice($left.offset(), $left.len())); - - let mut values = MutableBuffer::from_len_zeroed(($left.len() + 7) / 8); - let lhs_chunks_iter = $left.values().chunks_exact(8); - let lhs_remainder = lhs_chunks_iter.remainder(); - let chunks = $left.len() / 8; - - values[..chunks] - .iter_mut() - .zip(lhs_chunks_iter) - .for_each(|(byte, chunk)| { - chunk.iter().enumerate().for_each(|(i, &c_i)| { - *byte |= if $op(c_i, $right) { 1 << i } else { 0 }; - }); - }); - if !lhs_remainder.is_empty() { - let last = &mut values[chunks]; - lhs_remainder.iter().enumerate().for_each(|(i, &lhs)| { - *last |= if $op(lhs, $right) { 1 << i } else { 0 }; - }); - }; + // Safety: + // `i < $left.len()` + let comparison = (0..left.len()).map(|i| unsafe { op(left.value_unchecked(i)) }); + // same as $left.len() + let buffer = unsafe { MutableBuffer::from_trusted_len_iter_bool(comparison) }; - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - $left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(values)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) - }}; + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + 0, + vec![Buffer::from(buffer)], + vec![], + ) + }; + Ok(BooleanArray::from(data)) } /// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified @@ -215,7 +122,7 @@ where T: ArrowNumericType, F: Fn(T::Native, T::Native) -> bool, { - compare_op_primitive!(left, right, op) + compare_op(left, right, op) } /// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using @@ -229,7 +136,7 @@ where T: ArrowNumericType, F: Fn(T::Native, T::Native) -> bool, { - compare_op_scalar_primitive!(left, right, op) + compare_op_scalar(left, |l| op(l, right)) } fn is_like_pattern(c: char) -> bool { @@ -267,7 +174,7 @@ where let re = if let Some(ref regex) = map.get(pat) { regex } else { - let re_pattern = escape(pat).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(pat)?; let re = op(&re_pattern)?; map.insert(pat, re); map.get(pat).unwrap() @@ -326,12 +233,9 @@ pub fn like_utf8( }) } -/// Perform SQL `left LIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. 
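[Editor's note, not part of the patch: a minimal sketch of the pattern this refactor enables. The macro-based helpers removed above are replaced by the generic compare_op / compare_op_scalar functions, so a scalar kernel becomes an ordinary function that passes a closure, as cmp_primitive_array_scalar in this diff does. The function name and the extra PartialOrd bound below are illustrative assumptions, not code from the patch.]

// Illustrative only: a scalar kernel expressed via the new generic helper.
fn gt_scalar_sketch<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
    T: ArrowNumericType,
    T::Native: PartialOrd,
{
    // The closure captures the scalar; compare_op_scalar handles the null
    // bitmap and packs the comparison results into a BooleanArray.
    compare_op_scalar(left, |a| a > right)
}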
-pub fn like_utf8_scalar( - left: &GenericStringArray, +#[inline] +fn like_scalar<'a, L: ArrayAccessor>( + left: L, right: &str, ) -> Result { let null_bit_buffer = left.data().null_buffer().cloned(); @@ -342,29 +246,51 @@ pub fn like_utf8_scalar( if !right.contains(is_like_pattern) { // fast path, can use equals for i in 0..left.len() { - if left.value(i) == right { - bit_util::set_bit(bool_slice, i); + unsafe { + if left.value_unchecked(i) == right { + bit_util::set_bit(bool_slice, i); + } } } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use starts_with let starts_with = &right[..right.len() - 1]; for i in 0..left.len() { - if left.value(i).starts_with(starts_with) { - bit_util::set_bit(bool_slice, i); + unsafe { + if left.value_unchecked(i).starts_with(starts_with) { + bit_util::set_bit(bool_slice, i); + } } } } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { // fast path, can use ends_with let ends_with = &right[1..]; + for i in 0..left.len() { - if left.value(i).ends_with(ends_with) { - bit_util::set_bit(bool_slice, i); + unsafe { + if left.value_unchecked(i).ends_with(ends_with) { + bit_util::set_bit(bool_slice, i); + } + } + } + } else if right.starts_with('%') + && right.ends_with('%') + && !right[1..right.len() - 1].contains(is_like_pattern) + { + // fast path, can use contains + let contains = &right[1..right.len() - 1]; + for i in 0..left.len() { + unsafe { + if left.value_unchecked(i).contains(contains) { + bit_util::set_bit(bool_slice, i); + } } } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", @@ -373,7 +299,7 @@ pub fn like_utf8_scalar( })?; for i in 0..left.len() { - let haystack = left.value(i); + let haystack = unsafe { left.value_unchecked(i) }; if re.is_match(haystack) { bit_util::set_bit(bool_slice, i); } @@ -394,6 +320,79 @@ pub fn like_utf8_scalar( Ok(BooleanArray::from(data)) } +/// Perform SQL `left LIKE right` operation on [`StringArray`] / +/// [`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn like_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { + like_scalar(left, right) +} + +/// Perform SQL `left LIKE right` operation on [`DictionaryArray`] with values +/// [`StringArray`]/[`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn like_dict_scalar( + left: &DictionaryArray, + right: &str, +) -> Result { + match left.value_type() { + DataType::Utf8 => { + let left = left.downcast_dict::>().unwrap(); + like_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = left.downcast_dict::>().unwrap(); + like_scalar(left, right) + } + _ => { + Err(ArrowError::ComputeError( + "like_dict_scalar only supports DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )) + } + } +} + +/// Transforms a like `pattern` to a regex compatible pattern. To achieve that, it does: +/// +/// 1. Replace like wildcards for regex expressions as the pattern will be evaluated using regex match: `%` => `.*` and `_` => `.` +/// 2. Escape regex meta characters to match them and not be evaluated as regex special chars. 
For example: `.` => `\\.` +/// 3. Replace escaped like wildcards removing the escape characters to be able to match it as a regex. For example: `\\%` => `%` +fn replace_like_wildcards(pattern: &str) -> Result { + let mut result = String::new(); + let pattern = String::from(pattern); + let mut chars_iter = pattern.chars().peekable(); + while let Some(c) = chars_iter.next() { + if c == '\\' { + let next = chars_iter.peek(); + match next { + Some(next) if is_like_pattern(*next) => { + result.push(*next); + // Skipping the next char as it is already appended + chars_iter.next(); + } + _ => { + result.push('\\'); + result.push('\\'); + } + } + } else if regex_syntax::is_meta_character(c) { + result.push('\\'); + result.push(c); + } else if c == '%' { + result.push_str(".*"); + } else if c == '_' { + result.push('.'); + } else { + result.push(c); + } + } + Ok(result) +} + /// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / /// [`LargeStringArray`]. /// @@ -412,46 +411,78 @@ pub fn nlike_utf8( }) } -/// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. -pub fn nlike_utf8_scalar( - left: &GenericStringArray, +#[inline] +fn nlike_scalar<'a, L: ArrayAccessor>( + left: L, right: &str, ) -> Result { let null_bit_buffer = left.data().null_buffer().cloned(); - let mut result = BooleanBufferBuilder::new(left.len()); + let bytes = bit_util::ceil(left.len(), 8); + let mut bool_buf = MutableBuffer::from_len_zeroed(bytes); + let bool_slice = bool_buf.as_slice_mut(); if !right.contains(is_like_pattern) { // fast path, can use equals for i in 0..left.len() { - result.append(left.value(i) != right); + unsafe { + if left.value_unchecked(i) != right { + bit_util::set_bit(bool_slice, i); + } + } } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { - // fast path, can use ends_with + // fast path, can use starts_with + let starts_with = &right[..right.len() - 1]; for i in 0..left.len() { - result.append(!left.value(i).starts_with(&right[..right.len() - 1])); + unsafe { + if !(left.value_unchecked(i).starts_with(starts_with)) { + bit_util::set_bit(bool_slice, i); + } + } } } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { - // fast path, can use starts_with + // fast path, can use ends_with + let ends_with = &right[1..]; + + for i in 0..left.len() { + unsafe { + if !(left.value_unchecked(i).ends_with(ends_with)) { + bit_util::set_bit(bool_slice, i); + } + } + } + } else if right.starts_with('%') + && right.ends_with('%') + && !right[1..right.len() - 1].contains(is_like_pattern) + { + // fast path, can use contains + let contains = &right[1..right.len() - 1]; for i in 0..left.len() { - result.append(!left.value(i).ends_with(&right[1..])); + unsafe { + if !(left.value_unchecked(i).contains(contains)) { + bit_util::set_bit(bool_slice, i); + } + } } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from LIKE pattern: {}", e )) })?; + for i in 0..left.len() { - let haystack = left.value(i); - result.append(!re.is_match(haystack)); + let haystack = unsafe { left.value_unchecked(i) }; + if 
!re.is_match(haystack) { + bit_util::set_bit(bool_slice, i); + } } - } + }; let data = unsafe { ArrayData::new_unchecked( @@ -460,13 +491,49 @@ pub fn nlike_utf8_scalar( None, null_bit_buffer, 0, - vec![result.finish()], + vec![bool_buf.into()], vec![], ) }; Ok(BooleanArray::from(data)) } +/// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / +/// [`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn nlike_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { + nlike_scalar(left, right) +} + +/// Perform SQL `left NOT LIKE right` operation on [`DictionaryArray`] with values +/// [`StringArray`]/[`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn nlike_dict_scalar( + left: &DictionaryArray, + right: &str, +) -> Result { + match left.value_type() { + DataType::Utf8 => { + let left = left.downcast_dict::>().unwrap(); + nlike_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = left.downcast_dict::>().unwrap(); + nlike_scalar(left, right) + } + _ => { + Err(ArrowError::ComputeError( + "nlike_dict_scalar only supports DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )) + } + } +} + /// Perform SQL `left ILIKE right` operation on [`StringArray`] / /// [`LargeStringArray`]. /// @@ -485,54 +552,83 @@ pub fn ilike_utf8( }) } -/// Perform SQL `left ILIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. -pub fn ilike_utf8_scalar( - left: &GenericStringArray, +#[inline] +fn ilike_scalar<'a, L: ArrayAccessor>( + left: L, right: &str, ) -> Result { let null_bit_buffer = left.data().null_buffer().cloned(); - let mut result = BooleanBufferBuilder::new(left.len()); + let bytes = bit_util::ceil(left.len(), 8); + let mut bool_buf = MutableBuffer::from_len_zeroed(bytes); + let bool_slice = bool_buf.as_slice_mut(); if !right.contains(is_like_pattern) { // fast path, can use equals + let right_uppercase = right.to_uppercase(); for i in 0..left.len() { - result.append(left.value(i) == right); + unsafe { + if left.value_unchecked(i).to_uppercase() == right_uppercase { + bit_util::set_bit(bool_slice, i); + } + } } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { - // fast path, can use ends_with + // fast path, can use starts_with + let start_str = &right[..right.len() - 1].to_uppercase(); for i in 0..left.len() { - result.append( - left.value(i) + unsafe { + if left + .value_unchecked(i) .to_uppercase() - .starts_with(&right[..right.len() - 1].to_uppercase()), - ); + .starts_with(start_str) + { + bit_util::set_bit(bool_slice, i); + } + } } } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { - // fast path, can use starts_with + // fast path, can use ends_with + let ends_str = &right[1..].to_uppercase(); + for i in 0..left.len() { - result.append( - left.value(i) - .to_uppercase() - .ends_with(&right[1..].to_uppercase()), - ); + unsafe { + if left.value_unchecked(i).to_uppercase().ends_with(ends_str) { + bit_util::set_bit(bool_slice, i); + } + } + } + } else if right.starts_with('%') + && right.ends_with('%') + && !right[1..right.len() - 1].contains(is_like_pattern) + { + // fast path, can use contains + let contains = &right[1..right.len() - 1].to_uppercase(); + for i in 
0..left.len() { + unsafe { + if left.value_unchecked(i).to_uppercase().contains(contains) { + bit_util::set_bit(bool_slice, i); + } + } } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from ILIKE pattern: {}", e )) })?; + for i in 0..left.len() { - let haystack = left.value(i); - result.append(re.is_match(haystack)); + let haystack = unsafe { left.value_unchecked(i) }; + if re.is_match(haystack) { + bit_util::set_bit(bool_slice, i); + } } - } + }; let data = unsafe { ArrayData::new_unchecked( @@ -541,13 +637,49 @@ pub fn ilike_utf8_scalar( None, null_bit_buffer, 0, - vec![result.finish()], + vec![bool_buf.into()], vec![], ) }; Ok(BooleanArray::from(data)) } +/// Perform SQL `left ILIKE right` operation on [`StringArray`] / +/// [`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn ilike_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { + ilike_scalar(left, right) +} + +/// Perform SQL `left ILIKE right` operation on [`DictionaryArray`] with values +/// [`StringArray`]/[`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn ilike_dict_scalar( + left: &DictionaryArray, + right: &str, +) -> Result { + match left.value_type() { + DataType::Utf8 => { + let left = left.downcast_dict::>().unwrap(); + ilike_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = left.downcast_dict::>().unwrap(); + ilike_scalar(left, right) + } + _ => { + Err(ArrowError::ComputeError( + "ilike_dict_scalar only supports DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )) + } + } +} + /// Perform SQL `left NOT ILIKE right` operation on [`StringArray`] / /// [`LargeStringArray`]. /// @@ -566,56 +698,83 @@ pub fn nilike_utf8( }) } -/// Perform SQL `left NOT ILIKE right` operation on [`StringArray`] / -/// [`LargeStringArray`] and a scalar. -/// -/// See the documentation on [`like_utf8`] for more details. 
-pub fn nilike_utf8_scalar( - left: &GenericStringArray, +#[inline] +fn nilike_scalar<'a, L: ArrayAccessor>( + left: L, right: &str, ) -> Result { let null_bit_buffer = left.data().null_buffer().cloned(); - let mut result = BooleanBufferBuilder::new(left.len()); + let bytes = bit_util::ceil(left.len(), 8); + let mut bool_buf = MutableBuffer::from_len_zeroed(bytes); + let bool_slice = bool_buf.as_slice_mut(); if !right.contains(is_like_pattern) { // fast path, can use equals + let right_uppercase = right.to_uppercase(); for i in 0..left.len() { - result.append(left.value(i) != right); + unsafe { + if left.value_unchecked(i).to_uppercase() != right_uppercase { + bit_util::set_bit(bool_slice, i); + } + } } - } else if right.ends_with('%') && !right[..right.len() - 1].contains(is_like_pattern) + } else if right.ends_with('%') + && !right.ends_with("\\%") + && !right[..right.len() - 1].contains(is_like_pattern) { - // fast path, can use ends_with + // fast path, can use starts_with + let start_str = &right[..right.len() - 1].to_uppercase(); for i in 0..left.len() { - result.append( - !left - .value(i) + unsafe { + if !(left + .value_unchecked(i) .to_uppercase() - .starts_with(&right[..right.len() - 1].to_uppercase()), - ); + .starts_with(start_str)) + { + bit_util::set_bit(bool_slice, i); + } + } } } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { - // fast path, can use starts_with + // fast path, can use ends_with + let ends_str = &right[1..].to_uppercase(); + for i in 0..left.len() { - result.append( - !left - .value(i) - .to_uppercase() - .ends_with(&right[1..].to_uppercase()), - ); + unsafe { + if !(left.value_unchecked(i).to_uppercase().ends_with(ends_str)) { + bit_util::set_bit(bool_slice, i); + } + } + } + } else if right.starts_with('%') + && right.ends_with('%') + && !right[1..right.len() - 1].contains(is_like_pattern) + { + // fast path, can use contains + let contains = &right[1..right.len() - 1].to_uppercase(); + for i in 0..left.len() { + unsafe { + if !(left.value_unchecked(i).to_uppercase().contains(contains)) { + bit_util::set_bit(bool_slice, i); + } + } } } else { - let re_pattern = escape(right).replace('%', ".*").replace('_', "."); + let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("(?i)^{}$", re_pattern)).map_err(|e| { ArrowError::ComputeError(format!( "Unable to build regex from ILIKE pattern: {}", e )) })?; + for i in 0..left.len() { - let haystack = left.value(i); - result.append(!re.is_match(haystack)); + let haystack = unsafe { left.value_unchecked(i) }; + if !re.is_match(haystack) { + bit_util::set_bit(bool_slice, i); + } } - } + }; let data = unsafe { ArrayData::new_unchecked( @@ -624,13 +783,49 @@ pub fn nilike_utf8_scalar( None, null_bit_buffer, 0, - vec![result.finish()], + vec![bool_buf.into()], vec![], ) }; Ok(BooleanArray::from(data)) } +/// Perform SQL `left NOT ILIKE right` operation on [`StringArray`] / +/// [`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. +pub fn nilike_utf8_scalar( + left: &GenericStringArray, + right: &str, +) -> Result { + nilike_scalar(left, right) +} + +/// Perform SQL `left NOT ILIKE right` operation on [`DictionaryArray`] with values +/// [`StringArray`]/[`LargeStringArray`] and a scalar. +/// +/// See the documentation on [`like_utf8`] for more details. 
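[Editor's note, not part of the patch: a hedged sketch of the wildcard translation performed by replace_like_wildcards, which all of the LIKE/ILIKE kernels in this file fall back to when no fast path applies. The expected strings below follow directly from the rules in its doc comment; they are illustrative expectations, not tests from the diff.]

// Assuming replace_like_wildcards as added in this change:
assert_eq!(replace_like_wildcards("Ha%").unwrap(), "Ha.*");   // `%` becomes `.*`
assert_eq!(replace_like_wildcards("Ha_").unwrap(), "Ha.");    // `_` becomes `.`
assert_eq!(replace_like_wildcards("a.b").unwrap(), "a\\.b");  // regex metacharacters are escaped
assert_eq!(replace_like_wildcards("50\\%").unwrap(), "50%");  // escaped `\%` matches a literal `%`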
+pub fn nilike_dict_scalar( + left: &DictionaryArray, + right: &str, +) -> Result { + match left.value_type() { + DataType::Utf8 => { + let left = left.downcast_dict::>().unwrap(); + nilike_scalar(left, right) + } + DataType::LargeUtf8 => { + let left = left.downcast_dict::>().unwrap(); + nilike_scalar(left, right) + } + _ => { + Err(ArrowError::ComputeError( + "nilike_dict_scalar only supports DictionaryArray with Utf8 or LargeUtf8 values".to_string(), + )) + } + } +} + /// Perform SQL `array ~ regex_array` operation on [`StringArray`] / [`LargeStringArray`]. /// If `regex_array` element has an empty value, the corresponding result value is always true. /// @@ -769,7 +964,7 @@ pub fn eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a == b) + compare_op(left, right, |a, b| a == b) } /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -777,66 +972,37 @@ pub fn eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a == right) -} - -#[inline] -fn binary_boolean_op( - left: &BooleanArray, - right: &BooleanArray, - op: F, -) -> Result -where - F: Copy + Fn(u64, u64) -> u64, -{ - binary_boolean_kernel( - left, - right, - |left: &Buffer, - left_offset_in_bits: usize, - right: &Buffer, - right_offset_in_bits: usize, - len_in_bits: usize| { - bitwise_bin_op_helper( - left, - left_offset_in_bits, - right, - right_offset_in_bits, - len_in_bits, - op, - ) - }, - ) + compare_op_scalar(left, |a| a == right) } /// Perform `left == right` operation on [`BooleanArray`] pub fn eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| !(a ^ b)) + compare_op(left, right, |a, b| !(a ^ b)) } /// Perform `left != right` operation on [`BooleanArray`] pub fn neq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| (a ^ b)) + compare_op(left, right, |a, b| (a ^ b)) } /// Perform `left < right` operation on [`BooleanArray`] pub fn lt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| ((!a) & b)) + compare_op(left, right, |a, b| ((!a) & b)) } /// Perform `left <= right` operation on [`BooleanArray`] pub fn lt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| !(a & (!b))) + compare_op(left, right, |a, b| !(a & (!b))) } /// Perform `left > right` operation on [`BooleanArray`] pub fn gt_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| (a & (!b))) + compare_op(left, right, |a, b| (a & (!b))) } /// Perform `left >= right` operation on [`BooleanArray`] pub fn gt_eq_bool(left: &BooleanArray, right: &BooleanArray) -> Result { - binary_boolean_op(left, right, |a, b| !((!a) & b)) + compare_op(left, right, |a, b| !((!a) & b)) } /// Perform `left == right` operation on [`BooleanArray`] and a scalar @@ -870,22 +1036,22 @@ pub fn eq_bool_scalar(left: &BooleanArray, right: bool) -> Result /// Perform `left < right` operation on [`BooleanArray`] and a scalar pub fn lt_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, |a: bool| !a & right) + compare_op_scalar(left, |a: bool| !a & right) } /// Perform `left <= right` operation on [`BooleanArray`] and a scalar pub fn lt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, |a| a <= right) + compare_op_scalar(left, |a| a <= 
right) } /// Perform `left > right` operation on [`BooleanArray`] and a scalar pub fn gt_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, |a: bool| a & !right) + compare_op_scalar(left, |a: bool| a & !right) } /// Perform `left >= right` operation on [`BooleanArray`] and a scalar pub fn gt_eq_bool_scalar(left: &BooleanArray, right: bool) -> Result { - compare_op_scalar!(left, |a| a >= right) + compare_op_scalar(left, |a| a >= right) } /// Perform `left != right` operation on [`BooleanArray`] and a scalar @@ -898,7 +1064,7 @@ pub fn eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a == b) + compare_op(left, right, |a, b| a == b) } /// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar @@ -906,7 +1072,7 @@ pub fn eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a == right) + compare_op_scalar(left, |a| a == right) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -914,7 +1080,7 @@ pub fn neq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a != b) + compare_op(left, right, |a, b| a != b) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -922,7 +1088,7 @@ pub fn neq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a != right) + compare_op_scalar(left, |a| a != right) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -930,7 +1096,7 @@ pub fn lt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a < b) + compare_op(left, right, |a, b| a < b) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -938,7 +1104,7 @@ pub fn lt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a < right) + compare_op_scalar(left, |a| a < right) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -946,7 +1112,7 @@ pub fn lt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a <= b) + compare_op(left, right, |a, b| a <= b) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -954,7 +1120,7 @@ pub fn lt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a <= right) + compare_op_scalar(left, |a| a <= right) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. @@ -962,7 +1128,7 @@ pub fn gt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a > b) + compare_op(left, right, |a, b| a > b) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -970,7 +1136,7 @@ pub fn gt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a > right) + compare_op_scalar(left, |a| a > right) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. 
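[Editor's note, not part of the patch: a short usage sketch of the scalar string kernels above, which now route through the generic compare_op_scalar helper instead of the removed macros. StringArray, BooleanArray and lt_utf8_scalar are existing public APIs of this crate; the values are made up for illustration.]

// Nulls propagate: a null slot in `left` yields a null slot in the result.
let left = StringArray::from(vec![Some("arrow"), None, Some("parquet")]);
let result = lt_utf8_scalar(&left, "flight").unwrap();
assert_eq!(result, BooleanArray::from(vec![Some(true), None, Some(false)]));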
@@ -978,7 +1144,7 @@ pub fn gt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op!(left, right, |a, b| a >= b) + compare_op(left, right, |a, b| a >= b) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. @@ -986,7 +1152,7 @@ pub fn gt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar!(left, |a| a >= right) + compare_op_scalar(left, |a| a >= right) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -994,7 +1160,7 @@ pub fn neq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a != b) + compare_op(left, right, |a, b| a != b) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1002,7 +1168,7 @@ pub fn neq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a != right) + compare_op_scalar(left, |a| a != right) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1010,7 +1176,7 @@ pub fn lt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a < b) + compare_op(left, right, |a, b| a < b) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1018,7 +1184,7 @@ pub fn lt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a < right) + compare_op_scalar(left, |a| a < right) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1026,7 +1192,7 @@ pub fn lt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a <= b) + compare_op(left, right, |a, b| a <= b) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1034,7 +1200,7 @@ pub fn lt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a <= right) + compare_op_scalar(left, |a| a <= right) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1042,7 +1208,7 @@ pub fn gt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a > b) + compare_op(left, right, |a, b| a > b) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1050,7 +1216,7 @@ pub fn gt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a > right) + compare_op_scalar(left, |a| a > right) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`]. @@ -1058,7 +1224,7 @@ pub fn gt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op!(left, right, |a, b| a >= b) + compare_op(left, right, |a, b| a >= b) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. @@ -1066,21 +1232,22 @@ pub fn gt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar!(left, |a| a >= right) + compare_op_scalar(left, |a| a >= right) +} + +// Avoids creating a closure for each combination of `$RIGHT` and `$TY` +fn try_to_type_result(value: Option, right: &str, ty: &str) -> Result { + value.ok_or_else(|| { + ArrowError::ComputeError(format!("Could not convert {} with {}", right, ty,)) + }) } /// Calls $RIGHT.$TY() (e.g. 
`right.to_i128()`) with a nice error message. /// Type of expression is `Result<.., ArrowError>` macro_rules! try_to_type { - ($RIGHT: expr, $TY: ident) => {{ - $RIGHT.$TY().ok_or_else(|| { - ArrowError::ComputeError(format!( - "Could not convert {} with {}", - stringify!($RIGHT), - stringify!($TY) - )) - }) - }}; + ($RIGHT: expr, $TY: ident) => { + try_to_type_result($RIGHT.$TY(), stringify!($RIGHT), stringify!($TYPE)) + }; } macro_rules! dyn_compare_scalar { @@ -1150,59 +1317,35 @@ macro_rules! dyn_compare_scalar { match $KT.as_ref() { DataType::UInt8 => { let left = as_dictionary_array::($LEFT); - unpack_dict_comparison( - left, - dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, - ) + unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) } DataType::UInt16 => { let left = as_dictionary_array::($LEFT); - unpack_dict_comparison( - left, - dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, - ) + unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) } DataType::UInt32 => { let left = as_dictionary_array::($LEFT); - unpack_dict_comparison( - left, - dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, - ) + unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) } DataType::UInt64 => { let left = as_dictionary_array::($LEFT); - unpack_dict_comparison( - left, - dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, - ) + unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) } DataType::Int8 => { let left = as_dictionary_array::($LEFT); - unpack_dict_comparison( - left, - dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, - ) + unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) } DataType::Int16 => { let left = as_dictionary_array::($LEFT); - unpack_dict_comparison( - left, - dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, - ) + unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) } DataType::Int32 => { let left = as_dictionary_array::($LEFT); - unpack_dict_comparison( - left, - dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, - ) + unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) } DataType::Int64 => { let left = as_dictionary_array::($LEFT); - unpack_dict_comparison( - left, - dyn_compare_scalar!(left.values(), $RIGHT, $OP)?, - ) + unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) 
} _ => Err(ArrowError::ComputeError(format!( "Unsupported dictionary key type {:?}", @@ -1268,7 +1411,7 @@ where { match left.data_type() { DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, eq_scalar) + dyn_compare_scalar!(left, right, key_type, eq_dyn_scalar) } _ => dyn_compare_scalar!(left, right, eq_scalar), } @@ -1282,7 +1425,7 @@ where { match left.data_type() { DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, lt_scalar) + dyn_compare_scalar!(left, right, key_type, lt_dyn_scalar) } _ => dyn_compare_scalar!(left, right, lt_scalar), } @@ -1296,7 +1439,7 @@ where { match left.data_type() { DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, lt_eq_scalar) + dyn_compare_scalar!(left, right, key_type, lt_eq_dyn_scalar) } _ => dyn_compare_scalar!(left, right, lt_eq_scalar), } @@ -1310,7 +1453,7 @@ where { match left.data_type() { DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, gt_scalar) + dyn_compare_scalar!(left, right, key_type, gt_dyn_scalar) } _ => dyn_compare_scalar!(left, right, gt_scalar), } @@ -1324,7 +1467,7 @@ where { match left.data_type() { DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, gt_eq_scalar) + dyn_compare_scalar!(left, right, key_type, gt_eq_dyn_scalar) } _ => dyn_compare_scalar!(left, right, gt_eq_scalar), } @@ -1338,7 +1481,7 @@ where { match left.data_type() { DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, neq_scalar) + dyn_compare_scalar!(left, right, key_type, neq_dyn_scalar) } _ => dyn_compare_scalar!(left, right, neq_scalar), } @@ -1931,177 +2074,324 @@ where Ok(BooleanArray::from(data)) } -macro_rules! typed_cmp { - ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident) => {{ - let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| { - ArrowError::CastError(format!( - "Left array cannot be cast to {}", - type_name::<$T>() - )) - })?; - let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| { - ArrowError::CastError(format!( - "Right array cannot be cast to {}", - type_name::<$T>(), - )) - })?; - $OP(left, right) +fn cmp_primitive_array( + left: &dyn Array, + right: &dyn Array, + op: F, +) -> Result +where + F: Fn(T::Native, T::Native) -> bool, +{ + let left_array = as_primitive_array::(left); + let right_array = as_primitive_array::(right); + compare_op(left_array, right_array, op) +} + +#[cfg(feature = "dyn_cmp_dict")] +macro_rules! 
typed_dict_non_dict_cmp { + ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP_BOOL: expr, $OP: expr) => {{ + match $LEFT_KEY_TYPE { + DataType::Int8 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int16 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int32 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int64 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt8 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt16 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt32 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt64 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + t => Err(ArrowError::NotYetImplemented(format!( + "Cannot compare dictionary array of key type {}", + t + ))), + } }}; - ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{ - let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| { - ArrowError::CastError(format!( - "Left array cannot be cast to {}", - type_name::<$T>() - )) - })?; - let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| { - ArrowError::CastError(format!( - "Right array cannot be cast to {}", - type_name::<$T>(), - )) - })?; - $OP::<$TT>(left, right) +} + +#[cfg(feature = "dyn_cmp_dict")] +macro_rules! typed_dict_string_array_cmp { + ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP: expr) => {{ + match $LEFT_KEY_TYPE { + DataType::Int8 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int16 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int32 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::Int64 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt8 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt16 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt32 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + DataType::UInt64 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) + } + t => Err(ArrowError::NotYetImplemented(format!( + "Cannot compare dictionary array of key type {}", + t + ))), + } + }}; +} + +#[cfg(feature = "dyn_cmp_dict")] +macro_rules! 
typed_dict_boolean_array_cmp { + ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $OP: expr) => {{ + match $LEFT_KEY_TYPE { + DataType::Int8 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) + } + DataType::Int16 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) + } + DataType::Int32 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) + } + DataType::Int64 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) + } + DataType::UInt8 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) + } + DataType::UInt16 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) + } + DataType::UInt32 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) + } + DataType::UInt64 => { + let left = as_dictionary_array::($LEFT); + cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) + } + t => Err(ArrowError::NotYetImplemented(format!( + "Cannot compare dictionary array of key type {}", + t + ))), + } + }}; +} + +#[cfg(feature = "dyn_cmp_dict")] +macro_rules! typed_cmp_dict_non_dict { + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ + match ($LEFT.data_type(), $RIGHT.data_type()) { + (DataType::Dictionary(left_key_type, left_value_type), right_type) => { + match (left_value_type.as_ref(), right_type) { + (DataType::Boolean, DataType::Boolean) => { + typed_dict_boolean_array_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), $OP_BOOL) + } + (DataType::Int8, DataType::Int8) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int8Type, $OP_BOOL, $OP) + } + (DataType::Int16, DataType::Int16) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int16Type, $OP_BOOL, $OP) + } + (DataType::Int32, DataType::Int32) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int32Type, $OP_BOOL, $OP) + } + (DataType::Int64, DataType::Int64) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int64Type, $OP_BOOL, $OP) + } + (DataType::UInt8, DataType::UInt8) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt8Type, $OP_BOOL, $OP) + } + (DataType::UInt16, DataType::UInt16) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt16Type, $OP_BOOL, $OP) + } + (DataType::UInt32, DataType::UInt32) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt32Type, $OP_BOOL, $OP) + } + (DataType::UInt64, DataType::UInt64) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt64Type, $OP_BOOL, $OP) + } + (DataType::Float32, DataType::Float32) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float32Type, $OP_BOOL, $OP_FLOAT) + } + (DataType::Float64, DataType::Float64) => { + typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float64Type, $OP_BOOL, $OP_FLOAT) + } + (DataType::Utf8, DataType::Utf8) => { + typed_dict_string_array_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), i32, $OP) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + typed_dict_string_array_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), i64, $OP) + } + (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( + "Comparing dictionary array of type {} with array of type {} is not yet implemented", + t1, t2 + ))), + (t1, t2) => 
Err(ArrowError::CastError(format!( + "Cannot compare dictionary array with array of different value types ({} and {})", + t1, t2 + ))), + } + } + _ => unreachable!("Should not reach this branch"), + } }}; } +#[cfg(not(feature = "dyn_cmp_dict"))] +macro_rules! typed_cmp_dict_non_dict { + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ + Err(ArrowError::CastError(format!( + "Comparing dictionary array of type {} with array of type {} requires \"dyn_cmp_dict\" feature", + $LEFT.data_type(), $RIGHT.data_type() + ))) + }} +} + macro_rules! typed_compares { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident, $OP_BINARY: ident) => {{ + ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Boolean, DataType::Boolean) => { - typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL) + compare_op(as_boolean_array($LEFT), as_boolean_array($RIGHT), $OP_BOOL) } (DataType::Int8, DataType::Int8) => { - typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Int16, DataType::Int16) => { - typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Int32, DataType::Int32) => { - typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Int64, DataType::Int64) => { - typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::UInt8, DataType::UInt8) => { - typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::UInt16, DataType::UInt16) => { - typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::UInt32, DataType::UInt32) => { - typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::UInt64, DataType::UInt64) => { - typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Float32, DataType::Float32) => { - typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP_FLOAT) } (DataType::Float64, DataType::Float64) => { - typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP_FLOAT) } (DataType::Utf8, DataType::Utf8) => { - typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64) - } - (DataType::Binary, DataType::Binary) => { - typed_cmp!($LEFT, $RIGHT, BinaryArray, $OP_BINARY, i32) - } - (DataType::LargeBinary, DataType::LargeBinary) => { - typed_cmp!($LEFT, $RIGHT, LargeBinaryArray, $OP_BINARY, i64) + compare_op(as_string_array($LEFT), as_string_array($RIGHT), $OP) } + (DataType::LargeUtf8, DataType::LargeUtf8) => compare_op( + as_largestring_array($LEFT), + as_largestring_array($RIGHT), + $OP, + ), + (DataType::Binary, DataType::Binary) => compare_op( + as_generic_binary_array::($LEFT), + as_generic_binary_array::($RIGHT), + $OP, + ), + (DataType::LargeBinary, DataType::LargeBinary) => compare_op( + as_generic_binary_array::($LEFT), + as_generic_binary_array::($RIGHT), + $OP, + ), ( DataType::Timestamp(TimeUnit::Nanosecond, _), DataType::Timestamp(TimeUnit::Nanosecond, _), - ) => { - typed_cmp!( - $LEFT, - 
$RIGHT, - TimestampNanosecondArray, - $OP_PRIM, - TimestampNanosecondType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Timestamp(TimeUnit::Microsecond, _), DataType::Timestamp(TimeUnit::Microsecond, _), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - TimestampMicrosecondArray, - $OP_PRIM, - TimestampMicrosecondType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Timestamp(TimeUnit::Millisecond, _), DataType::Timestamp(TimeUnit::Millisecond, _), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - TimestampMillisecondArray, - $OP_PRIM, - TimestampMillisecondType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Timestamp(TimeUnit::Second, _), DataType::Timestamp(TimeUnit::Second, _), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - TimestampSecondArray, - $OP_PRIM, - TimestampSecondType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), (DataType::Date32, DataType::Date32) => { - typed_cmp!($LEFT, $RIGHT, Date32Array, $OP_PRIM, Date32Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) } (DataType::Date64, DataType::Date64) => { - typed_cmp!($LEFT, $RIGHT, Date64Array, $OP_PRIM, Date64Type) + cmp_primitive_array::($LEFT, $RIGHT, $OP) + } + (DataType::Time32(TimeUnit::Second), DataType::Time32(TimeUnit::Second)) => { + cmp_primitive_array::($LEFT, $RIGHT, $OP) } + ( + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Millisecond), + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), + ( + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Microsecond), + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), + ( + DataType::Time64(TimeUnit::Nanosecond), + DataType::Time64(TimeUnit::Nanosecond), + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Interval(IntervalUnit::YearMonth), DataType::Interval(IntervalUnit::YearMonth), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - IntervalYearMonthArray, - $OP_PRIM, - IntervalYearMonthType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Interval(IntervalUnit::DayTime), DataType::Interval(IntervalUnit::DayTime), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - IntervalDayTimeArray, - $OP_PRIM, - IntervalDayTimeType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), ( DataType::Interval(IntervalUnit::MonthDayNano), DataType::Interval(IntervalUnit::MonthDayNano), - ) => { - typed_cmp!( - $LEFT, - $RIGHT, - IntervalMonthDayNanoArray, - $OP_PRIM, - IntervalMonthDayNanoType - ) - } + ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing arrays of type {} is not yet implemented", t1 @@ -2115,8 +2405,9 @@ macro_rules! typed_compares { } /// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT +#[cfg(feature = "dyn_cmp_dict")] macro_rules! typed_dict_cmp { - ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr, $KT: tt) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr, $KT: tt) => {{ match ($LEFT.value_type(), $RIGHT.value_type()) { (DataType::Boolean, DataType::Boolean) => { cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP_BOOL) @@ -2146,10 +2437,10 @@ macro_rules! 
typed_dict_cmp { cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP) } (DataType::Float32, DataType::Float32) => { - cmp_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP) + cmp_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP_FLOAT) } (DataType::Float64, DataType::Float64) => { - cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP) + cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP_FLOAT) } (DataType::Utf8, DataType::Utf8) => { cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP) @@ -2193,6 +2484,30 @@ macro_rules! typed_dict_cmp { (DataType::Date64, DataType::Date64) => { cmp_dict::<$KT, Date64Type, _>($LEFT, $RIGHT, $OP) } + ( + DataType::Time32(TimeUnit::Second), + DataType::Time32(TimeUnit::Second), + ) => { + cmp_dict::<$KT, Time32SecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Millisecond), + ) => { + cmp_dict::<$KT, Time32MillisecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Microsecond), + ) => { + cmp_dict::<$KT, Time64MicrosecondType, _>($LEFT, $RIGHT, $OP) + } + ( + DataType::Time64(TimeUnit::Nanosecond), + DataType::Time64(TimeUnit::Nanosecond), + ) => { + cmp_dict::<$KT, Time64NanosecondType, _>($LEFT, $RIGHT, $OP) + } ( DataType::Interval(IntervalUnit::YearMonth), DataType::Interval(IntervalUnit::YearMonth), @@ -2223,51 +2538,52 @@ macro_rules! typed_dict_cmp { }}; } +#[cfg(feature = "dyn_cmp_dict")] macro_rules! typed_dict_compares { // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` - ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_BOOL: expr) => {{ + ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{ match ($LEFT.data_type(), $RIGHT.data_type()) { (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { match (left_key_type.as_ref(), right_key_type.as_ref()) { (DataType::Int8, DataType::Int8) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int8Type) + typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int8Type) } (DataType::Int16, DataType::Int16) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int16Type) + typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int16Type) } (DataType::Int32, DataType::Int32) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int32Type) + typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int32Type) } (DataType::Int64, DataType::Int64) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_BOOL, Int64Type) + typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int64Type) } (DataType::UInt8, DataType::UInt8) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt8Type) + typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt8Type) } (DataType::UInt16, DataType::UInt16) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt16Type) + typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt16Type) } (DataType::UInt32, DataType::UInt32) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - 
typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt32Type) + typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt32Type) } (DataType::UInt64, DataType::UInt64) => { let left = as_dictionary_array::($LEFT); let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_BOOL, UInt64Type) + typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt64Type) } (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( "Comparing dictionary arrays of type {} is not yet implemented", @@ -2287,70 +2603,102 @@ macro_rules! typed_dict_compares { }}; } -/// Helper function to perform boolean lambda function on values from two dictionary arrays, this -/// version does not attempt to use SIMD explicitly (though the compiler may auto vectorize) -macro_rules! compare_dict_op { - ($left: expr, $right:expr, $op:expr, $value_ty:ty) => {{ - if $left.len() != $right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - // Safety justification: Since the inputs are valid Arrow arrays, all values are - // valid indexes into the dictionary (which is verified during construction) - - let left_iter = unsafe { - $left - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($left.keys_iter()) - }; - - let right_iter = unsafe { - $right - .values() - .as_any() - .downcast_ref::<$value_ty>() - .unwrap() - .take_iter_unchecked($right.keys_iter()) - }; - - let result = left_iter - .zip(right_iter) - .map(|(left_value, right_value)| { - if let (Some(left), Some(right)) = (left_value, right_value) { - Some($op(left, right)) - } else { - None - } - }) - .collect(); - - Ok(result) - }}; +#[cfg(not(feature = "dyn_cmp_dict"))] +macro_rules! typed_dict_compares { + ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{ + Err(ArrowError::CastError(format!( + "Comparing array of type {} with array of type {} requires \"dyn_cmp_dict\" feature", + $LEFT.data_type(), $RIGHT.data_type() + ))) + }} } -/// Perform given operation on two `DictionaryArray`s. -/// Returns an error if the two arrays have different value type -pub fn cmp_dict( +/// Perform given operation on `DictionaryArray` and `PrimitiveArray`. The value +/// type of `DictionaryArray` is same as `PrimitiveArray`'s type. +#[cfg(feature = "dyn_cmp_dict")] +fn cmp_dict_primitive( left: &DictionaryArray, - right: &DictionaryArray, + right: &dyn Array, op: F, ) -> Result where K: ArrowNumericType, - T: ArrowNumericType, + T: ArrowNumericType + Sync + Send, F: Fn(T::Native, T::Native) -> bool, { - compare_dict_op!(left, right, op, PrimitiveArray) + compare_op( + left.downcast_dict::>().unwrap(), + as_primitive_array::(right), + op, + ) } -/// Perform the given operation on two `DictionaryArray`s which value type is -/// `DataType::Boolean`. +/// Perform given operation on `DictionaryArray` and `GenericStringArray`. The value +/// type of `DictionaryArray` is same as `GenericStringArray`'s type. +#[cfg(feature = "dyn_cmp_dict")] +fn cmp_dict_string_array( + left: &DictionaryArray, + right: &dyn Array, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(&str, &str) -> bool, +{ + compare_op( + left.downcast_dict::>() + .unwrap(), + right + .as_any() + .downcast_ref::>() + .unwrap(), + op, + ) +} + +/// Perform given operation on `DictionaryArray` and `BooleanArray`. The value +/// type of `DictionaryArray` is same as `BooleanArray`'s type. 
+#[cfg(feature = "dyn_cmp_dict")] +fn cmp_dict_boolean_array( + left: &DictionaryArray, + right: &dyn Array, + op: F, +) -> Result +where + K: ArrowNumericType, + F: Fn(bool, bool) -> bool, +{ + compare_op( + left.downcast_dict::().unwrap(), + right.as_any().downcast_ref::().unwrap(), + op, + ) +} + +/// Perform given operation on two `DictionaryArray`s which value type is +/// primitive type. Returns an error if the two arrays have different value +/// type +#[cfg(feature = "dyn_cmp_dict")] +pub fn cmp_dict( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + T: ArrowNumericType + Sync + Send, + F: Fn(T::Native, T::Native) -> bool, +{ + compare_op( + left.downcast_dict::>().unwrap(), + right.downcast_dict::>().unwrap(), + op, + ) +} + +/// Perform the given operation on two `DictionaryArray`s which value type is +/// `DataType::Boolean`. +#[cfg(feature = "dyn_cmp_dict")] pub fn cmp_dict_bool( left: &DictionaryArray, right: &DictionaryArray, @@ -2360,11 +2708,16 @@ where K: ArrowNumericType, F: Fn(bool, bool) -> bool, { - compare_dict_op!(left, right, op, BooleanArray) + compare_op( + left.downcast_dict::().unwrap(), + right.downcast_dict::().unwrap(), + op, + ) } /// Perform the given operation on two `DictionaryArray`s which value type is /// `DataType::Utf8` or `DataType::LargeUtf8`. +#[cfg(feature = "dyn_cmp_dict")] pub fn cmp_dict_utf8( left: &DictionaryArray, right: &DictionaryArray, @@ -2374,11 +2727,19 @@ where K: ArrowNumericType, F: Fn(&str, &str) -> bool, { - compare_dict_op!(left, right, op, GenericStringArray) + compare_op( + left.downcast_dict::>() + .unwrap(), + right + .downcast_dict::>() + .unwrap(), + op, + ) } /// Perform the given operation on two `DictionaryArray`s which value type is /// `DataType::Binary` or `DataType::LargeBinary`. +#[cfg(feature = "dyn_cmp_dict")] pub fn cmp_dict_binary( left: &DictionaryArray, right: &DictionaryArray, @@ -2388,7 +2749,14 @@ where K: ArrowNumericType, F: Fn(&[u8], &[u8]) -> bool, { - compare_dict_op!(left, right, op, GenericBinaryArray) + compare_op( + left.downcast_dict::>() + .unwrap(), + right + .downcast_dict::>() + .unwrap(), + op, + ) } /// Perform `left == right` operation on two (dynamic) [`Array`]s. @@ -2396,6 +2764,10 @@ where /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. /// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+/// /// # Example /// ``` /// use arrow::array::{StringArray, BooleanArray}; @@ -2407,10 +2779,34 @@ where /// ``` pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { - DataType::Dictionary(_, _) => { - typed_dict_compares!(left, right, |a, b| a == b, |a, b| a == b) + DataType::Dictionary(_, _) + if matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_dict_compares!( + left, + right, + |a, b| a == b, + |a, b| a.total_cmp(&b).is_eq(), + |a, b| a == b + ) + } + DataType::Dictionary(_, _) + if !matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_cmp_dict_non_dict!(left, right, |a, b| a == b, |a, b| a == b, |a, b| a + .total_cmp(&b) + .is_eq()) + } + _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { + typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b, |a, b| a + .total_cmp(&b) + .is_eq()) + } + _ => { + typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b, |a, b| a + .total_cmp(&b) + .is_eq()) } - _ => typed_compares!(left, right, eq_bool, eq, eq_utf8, eq_binary), } } @@ -2419,6 +2815,10 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. /// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Please refer to `f32::total_cmp` and `f64::total_cmp`. +/// /// # Example /// ``` /// use arrow::array::{BinaryArray, BooleanArray}; @@ -2432,10 +2832,34 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// ``` pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { - DataType::Dictionary(_, _) => { - typed_dict_compares!(left, right, |a, b| a != b, |a, b| a != b) + DataType::Dictionary(_, _) + if matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_dict_compares!( + left, + right, + |a, b| a != b, + |a, b| a.total_cmp(&b).is_ne(), + |a, b| a != b + ) + } + DataType::Dictionary(_, _) + if !matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_cmp_dict_non_dict!(left, right, |a, b| a != b, |a, b| a != b, |a, b| a + .total_cmp(&b) + .is_ne()) + } + _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { + typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b, |a, b| a + .total_cmp(&b) + .is_ne()) + } + _ => { + typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b, |a, b| a + .total_cmp(&b) + .is_ne()) } - _ => typed_compares!(left, right, neq_bool, neq, neq_utf8, neq_binary), } } @@ -2444,6 +2868,10 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. /// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+/// /// # Example /// ``` /// use arrow::array::{PrimitiveArray, BooleanArray}; @@ -2457,10 +2885,34 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result { #[allow(clippy::bool_comparison)] pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { - DataType::Dictionary(_, _) => { - typed_dict_compares!(left, right, |a, b| a < b, |a, b| a < b) + DataType::Dictionary(_, _) + if matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_dict_compares!( + left, + right, + |a, b| a < b, + |a, b| a.total_cmp(&b).is_lt(), + |a, b| a < b + ) + } + DataType::Dictionary(_, _) + if !matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_cmp_dict_non_dict!(left, right, |a, b| a < b, |a, b| a < b, |a, b| a + .total_cmp(&b) + .is_lt()) + } + _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { + typed_cmp_dict_non_dict!(right, left, |a, b| a > b, |a, b| a > b, |a, b| b + .total_cmp(&a) + .is_lt()) + } + _ => { + typed_compares!(left, right, |a, b| ((!a) & b), |a, b| a < b, |a, b| a + .total_cmp(&b) + .is_lt()) } - _ => typed_compares!(left, right, lt_bool, lt, lt_utf8, lt_binary), } } @@ -2469,6 +2921,10 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. /// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Please refer to `f32::total_cmp` and `f64::total_cmp`. +/// /// # Example /// ``` /// use arrow::array::{PrimitiveArray, BooleanArray}; @@ -2481,10 +2937,34 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// ``` pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { - DataType::Dictionary(_, _) => { - typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a <= b) + DataType::Dictionary(_, _) + if matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_dict_compares!( + left, + right, + |a, b| a <= b, + |a, b| a.total_cmp(&b).is_le(), + |a, b| a <= b + ) + } + DataType::Dictionary(_, _) + if !matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_cmp_dict_non_dict!(left, right, |a, b| a <= b, |a, b| a <= b, |a, b| a + .total_cmp(&b) + .is_le()) + } + _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { + typed_cmp_dict_non_dict!(right, left, |a, b| a >= b, |a, b| a >= b, |a, b| b + .total_cmp(&a) + .is_le()) + } + _ => { + typed_compares!(left, right, |a, b| !(a & (!b)), |a, b| a <= b, |a, b| a + .total_cmp(&b) + .is_le()) } - _ => typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8, lt_eq_binary), } } @@ -2493,6 +2973,10 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. /// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+/// /// # Example /// ``` /// use arrow::array::BooleanArray; @@ -2505,10 +2989,34 @@ pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { #[allow(clippy::bool_comparison)] pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { - DataType::Dictionary(_, _) => { - typed_dict_compares!(left, right, |a, b| a > b, |a, b| a > b) + DataType::Dictionary(_, _) + if matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_dict_compares!( + left, + right, + |a, b| a > b, + |a, b| a.total_cmp(&b).is_gt(), + |a, b| a > b + ) + } + DataType::Dictionary(_, _) + if !matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_cmp_dict_non_dict!(left, right, |a, b| a > b, |a, b| a > b, |a, b| a + .total_cmp(&b) + .is_gt()) + } + _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { + typed_cmp_dict_non_dict!(right, left, |a, b| a < b, |a, b| a < b, |a, b| b + .total_cmp(&a) + .is_gt()) + } + _ => { + typed_compares!(left, right, |a, b| (a & (!b)), |a, b| a > b, |a, b| a + .total_cmp(&b) + .is_gt()) } - _ => typed_compares!(left, right, gt_bool, gt, gt_utf8, gt_binary), } } @@ -2517,6 +3025,10 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// Only when two arrays are of the same type the comparison will happen otherwise it will err /// with a casting error. /// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Please refer to `f32::total_cmp` and `f64::total_cmp`. +/// /// # Example /// ``` /// use arrow::array::{BooleanArray, StringArray}; @@ -2528,10 +3040,34 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { /// ``` pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { match left.data_type() { - DataType::Dictionary(_, _) => { - typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a >= b) + DataType::Dictionary(_, _) + if matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_dict_compares!( + left, + right, + |a, b| a >= b, + |a, b| a.total_cmp(&b).is_ge(), + |a, b| a >= b + ) + } + DataType::Dictionary(_, _) + if !matches!(right.data_type(), DataType::Dictionary(_, _)) => + { + typed_cmp_dict_non_dict!(left, right, |a, b| a >= b, |a, b| a >= b, |a, b| a + .total_cmp(&b) + .is_ge()) + } + _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { + typed_cmp_dict_non_dict!(right, left, |a, b| a <= b, |a, b| a <= b, |a, b| b + .total_cmp(&a) + .is_ge()) + } + _ => { + typed_compares!(left, right, |a, b| !((!a) & b), |a, b| a >= b, |a, b| a + .total_cmp(&b) + .is_ge()) } - _ => typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8, gt_eq_binary), } } @@ -2543,7 +3079,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::eq, |a, b| a == b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a == b); + return compare_op(left, right, |a, b| a == b); } /// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2554,7 +3090,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a == right); + return compare_op_scalar(left, |a| a == right); } /// Applies an unary and infallible comparison function to a primitive array. 
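// Editor's note (illustrative sketch, not part of this patch): the *_dyn kernels in the
// hunks above now compare f32/f64 values through `total_cmp`, so NaN participates in the
// IEEE 754 totalOrder: NaN == NaN is true and every finite value sorts below NaN. A minimal
// usage sketch, mirroring the `test_eq_dyn_neq_dyn_float_nan` and
// `test_lt_dyn_lt_eq_dyn_float_nan` tests added later in this diff:

use arrow::array::{BooleanArray, Float32Array};
use arrow::compute::{eq_dyn, lt_dyn};

fn nan_total_order_sketch() -> arrow::error::Result<()> {
    let left = Float32Array::from(vec![f32::NAN, 7.0, 8.0]);
    let right = Float32Array::from(vec![f32::NAN, f32::NAN, 8.0]);

    // NaN == NaN under the totalOrder predicate
    assert_eq!(
        eq_dyn(&left, &right)?,
        BooleanArray::from(vec![true, false, true])
    );
    // 7.0 < NaN, because NaN is the greatest value in the total order
    assert_eq!(
        lt_dyn(&left, &right)?,
        BooleanArray::from(vec![false, true, false])
    );
    Ok(())
}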
@@ -2563,7 +3099,7 @@ where T: ArrowNumericType, F: Fn(T::Native) -> bool, { - return compare_op_scalar!(left, op); + compare_op_scalar(left, op) } /// Perform `left != right` operation on two [`PrimitiveArray`]s. @@ -2574,7 +3110,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::ne, |a, b| a != b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a != b); + return compare_op(left, right, |a, b| a != b); } /// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2585,7 +3121,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a != right); + return compare_op_scalar(left, |a| a != right); } /// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2597,7 +3133,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::lt, |a, b| a < b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a < b); + return compare_op(left, right, |a, b| a < b); } /// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2609,7 +3145,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a < right); + return compare_op_scalar(left, |a| a < right); } /// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2624,7 +3160,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::le, |a, b| a <= b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a <= b); + return compare_op(left, right, |a, b| a <= b); } /// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2636,7 +3172,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a <= right); + return compare_op_scalar(left, |a| a <= right); } /// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2648,7 +3184,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::gt, |a, b| a > b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a > b); + return compare_op(left, right, |a, b| a > b); } /// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2660,7 +3196,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a > right); + return compare_op_scalar(left, |a| a > right); } /// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2675,7 +3211,7 @@ where #[cfg(feature = "simd")] return simd_compare_op(left, right, T::ge, |a, b| a >= b); #[cfg(not(feature = "simd"))] - return compare_op!(left, right, |a, b| a >= b); + return compare_op(left, right, |a, b| a >= b); } /// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value. 
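// Editor's note (illustrative sketch, not part of this patch): with the new dictionary vs.
// non-dictionary path introduced above, `eq_dyn` and the other *_dyn kernels can compare a
// `DictionaryArray` directly against a plain array of the same value type. That path is only
// compiled with the `dyn_cmp_dict` feature; without it the kernels return a `CastError`.
// The sketch mirrors `test_eq_dyn_neq_dyn_dictionary_i8_i8_array` added later in this diff.

use arrow::array::{BooleanArray, DictionaryArray, Int8Array};
use arrow::compute::eq_dyn;

fn dict_vs_plain_sketch() -> arrow::error::Result<()> {
    let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14]);
    let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
    // The dictionary logically holds [12, 13, 14]
    let dict = DictionaryArray::try_new(&keys, &values)?;

    let plain = Int8Array::from(vec![12_i8, 23, 14]);

    // Argument order does not matter; the non-dictionary side is handled either way.
    assert_eq!(
        eq_dyn(&dict, &plain)?,
        BooleanArray::from(vec![true, false, true])
    );
    assert_eq!(
        eq_dyn(&plain, &dict)?,
        BooleanArray::from(vec![true, false, true])
    );
    Ok(())
}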
@@ -2687,7 +3223,7 @@ where #[cfg(feature = "simd")] return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b); #[cfg(not(feature = "simd"))] - return compare_op_scalar!(left, |a| a >= right); + return compare_op_scalar(left, |a| a >= right); } /// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`] @@ -2913,6 +3449,42 @@ mod tests { vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], vec![false, false, true, false, false, false, false, true, false, false] ); + + cmp_vec!( + eq, + eq_dyn, + Time32SecondArray, + vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, false, false, false, false, true, false, false] + ); + + cmp_vec!( + eq, + eq_dyn, + Time32MillisecondArray, + vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, false, false, false, false, true, false, false] + ); + + cmp_vec!( + eq, + eq_dyn, + Time64MicrosecondArray, + vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, false, false, false, false, true, false, false] + ); + + cmp_vec!( + eq, + eq_dyn, + Time64NanosecondArray, + vec![8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + vec![6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, false, false, false, false, true, false, false] + ); } #[test] @@ -3677,20 +4249,20 @@ mod tests { // contains(null, null) = false #[test] fn test_contains_utf8() { - let values_builder = StringBuilder::new(10); + let values_builder = StringBuilder::new(); let mut builder = ListBuilder::new(values_builder); - builder.values().append_value("Lorem").unwrap(); - builder.values().append_value("ipsum").unwrap(); - builder.values().append_null().unwrap(); - builder.append(true).unwrap(); - builder.values().append_value("sit").unwrap(); - builder.values().append_value("amet").unwrap(); - builder.values().append_value("Lorem").unwrap(); - builder.append(true).unwrap(); - builder.append(false).unwrap(); - builder.values().append_value("ipsum").unwrap(); - builder.append(true).unwrap(); + builder.values().append_value("Lorem"); + builder.values().append_value("ipsum"); + builder.values().append_null(); + builder.append(true); + builder.values().append_value("sit"); + builder.values().append_value("amet"); + builder.values().append_value("Lorem"); + builder.append(true); + builder.append(false); + builder.values().append_value("ipsum"); + builder.append(true); // [["Lorem", "ipsum", null], ["sit", "amet", "Lorem"], null, ["ipsum"]] // value_offsets = [0, 3, 6, 6] @@ -3937,6 +4509,50 @@ mod tests { vec![false, true, false, false] ); + test_utf8_scalar!( + test_utf8_scalar_like_escape, + vec!["a%", "a\\x"], + "a\\%", + like_utf8_scalar, + vec![true, false] + ); + + test_utf8!( + test_utf8_scalar_ilike_regex, + vec!["%%%"], + vec![r#"\%_\%"#], + ilike_utf8, + vec![true] + ); + + #[test] + fn test_replace_like_wildcards() { + let a_eq = "_%"; + let expected = "..*"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_leave_like_meta_chars() { + let a_eq = "\\%\\_"; + let expected = "%_"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_with_multiple_escape_chars() { + let a_eq = "\\\\%"; + let expected = "\\\\%"; + assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + + #[test] + fn test_replace_like_wildcards_escape_regex_meta_char() { + let a_eq = "."; + let expected = "\\."; + 
assert_eq!(replace_like_wildcards(a_eq).unwrap(), expected); + } + test_utf8!( test_utf8_array_eq, vec!["arrow", "arrow", "arrow", "arrow"], @@ -4063,7 +4679,7 @@ mod tests { test_utf8_scalar!( test_utf8_array_ilike_scalar_equals, vec!["arrow", "parrow", "arrows", "arr"], - "arrow", + "Arrow", ilike_utf8_scalar, vec![true, false, false, false] ); @@ -4116,8 +4732,8 @@ mod tests { test_utf8_scalar!( test_utf8_array_nilike_scalar_equals, - vec!["arrow", "parrow", "arrows", "arr"], - "arrow", + vec!["arRow", "parrow", "arrows", "arr"], + "Arrow", nilike_utf8_scalar, vec![false, true, true, true] ); @@ -4257,11 +4873,11 @@ mod tests { #[test] fn test_eq_dyn_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(123).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(23).unwrap(); let array = builder.finish(); let a_eq = eq_dyn_scalar(&array, 123).unwrap(); @@ -4301,11 +4917,11 @@ mod tests { #[test] fn test_lt_dyn_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(123).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(23).unwrap(); let array = builder.finish(); let a_eq = lt_dyn_scalar(&array, 123).unwrap(); @@ -4344,11 +4960,11 @@ mod tests { } #[test] fn test_lt_eq_dyn_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = PrimitiveBuilder::::new(); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(123).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(23).unwrap(); let array = builder.finish(); let a_eq = lt_eq_dyn_scalar(&array, 23).unwrap(); @@ -4388,11 +5004,11 @@ mod tests { #[test] fn test_gt_dyn_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(123).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(23).unwrap(); let array = builder.finish(); let a_eq = gt_dyn_scalar(&array, 23).unwrap(); @@ -4432,11 +5048,11 @@ mod tests { #[test] fn test_gt_eq_dyn_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = PrimitiveBuilder::::new(); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(22).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(23).unwrap(); let array = builder.finish(); let a_eq = gt_eq_dyn_scalar(&array, 23).unwrap(); @@ -4476,11 +5092,11 @@ mod tests { #[test] fn test_neq_dyn_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - 
let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = PrimitiveBuilder::::new(); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(22).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(23).unwrap(); let array = builder.finish(); let a_eq = neq_dyn_scalar(&array, 23).unwrap(); @@ -4620,11 +5236,11 @@ mod tests { #[test] fn test_eq_dyn_utf8_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = StringBuilder::new(100); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("def").unwrap(); builder.append("def").unwrap(); builder.append("abc").unwrap(); @@ -4648,11 +5264,11 @@ mod tests { } #[test] fn test_lt_dyn_utf8_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = StringBuilder::new(100); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("def").unwrap(); builder.append("def").unwrap(); builder.append("abc").unwrap(); @@ -4677,11 +5293,11 @@ mod tests { } #[test] fn test_lt_eq_dyn_utf8_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = StringBuilder::new(100); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("def").unwrap(); builder.append("def").unwrap(); builder.append("xyz").unwrap(); @@ -4706,11 +5322,11 @@ mod tests { } #[test] fn test_gt_eq_dyn_utf8_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = StringBuilder::new(100); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("def").unwrap(); builder.append("def").unwrap(); builder.append("xyz").unwrap(); @@ -4736,11 +5352,11 @@ mod tests { #[test] fn test_gt_dyn_utf8_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = StringBuilder::new(100); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append("def").unwrap(); builder.append("def").unwrap(); builder.append("xyz").unwrap(); @@ -4765,11 +5381,11 @@ mod tests { } #[test] fn test_neq_dyn_utf8_scalar_with_dict() { - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = StringBuilder::new(100); + let key_builder = PrimitiveBuilder::::new(); + let value_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(key_builder, value_builder); builder.append("abc").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); 
builder.append("def").unwrap(); builder.append("def").unwrap(); builder.append("abc").unwrap(); @@ -4844,6 +5460,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_i8_array() { // Construct a value array let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); @@ -4864,6 +5481,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_u64_array() { let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]); @@ -4885,6 +5503,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_utf8_array() { let test1 = vec!["a", "a", "b", "c"]; let test2 = vec!["a", "b", "b", "c"]; @@ -4912,6 +5531,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_binary_array() { let values: BinaryArray = ["hello", "", "parquet"] .into_iter() @@ -4936,6 +5556,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_interval_array() { let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]); @@ -4957,6 +5578,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_date_array() { let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]); @@ -4978,6 +5600,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_bool_array() { let values = BooleanArray::from(vec![true, false]); @@ -4999,6 +5622,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_gt_dyn_dictionary_i8_array() { // Construct a value array let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); @@ -5028,6 +5652,7 @@ mod tests { } #[test] + #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_gt_dyn_dictionary_bool_array() { let values = BooleanArray::from(vec![true, false]); @@ -5068,4 +5693,809 @@ mod tests { BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]) ); } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_eq_dyn_neq_dyn_dictionary_i8_i8_array() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + + let array = Int8Array::from_iter([Some(12_i8), None, Some(14)]); + + let result = eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + let result = eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + let result = neq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, Some(false)]) + ); + + let result = neq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, Some(false)]) + ); + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_i8_i8_array() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + + let array = Int8Array::from_iter([Some(12_i8), None, Some(11)]); + + let result = lt_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, Some(false)]) + ); + + let result = lt_dyn(&array, &dict_array); + 
assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, Some(true)]) + ); + + let result = lt_eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + + let result = lt_eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + let result = gt_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, Some(true)]) + ); + + let result = gt_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, Some(false)]) + ); + + let result = gt_eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + let result = gt_eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, Some(false)]) + ); + } + + #[test] + fn test_eq_dyn_neq_dyn_float_nan() { + let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let array2: Float32Array = vec![f32::NAN, f32::NAN, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(true), Some(true)], + ); + assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let array2: Float64Array = vec![f64::NAN, f64::NAN, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(true), Some(true)], + ); + assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); + } + + #[test] + fn test_lt_dyn_lt_eq_dyn_float_nan() { + let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] + .into_iter() + .map(Some) + .collect(); + let array2: Float32Array = vec![f32::NAN, f32::NAN, 8.0, 9.0, 10.0, 1.0] + .into_iter() + .map(Some) + .collect(); + + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] + .into_iter() + .map(Some) + .collect(); + let array2: Float64Array = vec![f64::NAN, f64::NAN, 8.0, 9.0, 10.0, 1.0] + .into_iter() + .map(Some) + .collect(); + + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); + } + + #[test] + fn test_gt_dyn_gt_eq_dyn_float_nan() { + let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] 
+ .into_iter() + .map(Some) + .collect(); + let array2: Float32Array = vec![f32::NAN, f32::NAN, 8.0, 9.0, 10.0, 1.0] + .into_iter() + .map(Some) + .collect(); + + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] + .into_iter() + .map(Some) + .collect(); + let array2: Float64Array = vec![f64::NAN, f64::NAN, 8.0, 9.0, 10.0, 1.0] + .into_iter() + .map(Some) + .collect(); + + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); + } + + #[test] + fn test_eq_dyn_scalar_neq_dyn_scalar_float_nan() { + let array: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(true)], + ); + assert_eq!(neq_dyn_scalar(&array, f32::NAN).unwrap(), expected); + + let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(eq_dyn_scalar(&array, f64::NAN).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(true)], + ); + assert_eq!(neq_dyn_scalar(&array, f64::NAN).unwrap(), expected); + } + + #[test] + fn test_lt_dyn_scalar_lt_eq_dyn_scalar_float_nan() { + let array: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(lt_dyn_scalar(&array, f32::NAN).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); + + let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(lt_dyn_scalar(&array, f64::NAN).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn_scalar(&array, f64::NAN).unwrap(), expected); + } + + #[test] + fn test_gt_dyn_scalar_gt_eq_dyn_scalar_float_nan() { + let array: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(gt_dyn_scalar(&array, f32::NAN).unwrap(), expected); + + let expected = 
BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(gt_eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); + + let array: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(gt_dyn_scalar(&array, f64::NAN).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(false)], + ); + assert_eq!(gt_eq_dyn_scalar(&array, f64::NAN).unwrap(), expected); + } + + #[test] + fn test_dict_like_kernels() { + let data = + vec![Some("Earth"), Some("Fire"), Some("Water"), Some("Air"), None, Some("Air")]; + + let dict_array: DictionaryArray = data.into_iter().collect(); + + assert_eq!( + like_dict_scalar(&dict_array, "Air").unwrap(), + BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(true), None, Some(true)] + ), + ); + + assert_eq!( + like_dict_scalar(&dict_array, "Wa%").unwrap(), + BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(false), None, Some(false)] + ), + ); + + assert_eq!( + like_dict_scalar(&dict_array, "%r").unwrap(), + BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), None, Some(true)] + ), + ); + + assert_eq!( + like_dict_scalar(&dict_array, "%i%").unwrap(), + BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), None, Some(true)] + ), + ); + + assert_eq!( + like_dict_scalar(&dict_array, "%a%r%").unwrap(), + BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), None, Some(false)] + ), + ); + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_eq_dyn_neq_dyn_dictionary_to_utf8_array() { + let test1 = vec!["a", "a", "b", "c"]; + let test2 = vec!["a", "b", "b", "d"]; + + let dict_array: DictionaryArray = test1 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + + let array: StringArray = test2 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + + let result = eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(false)]) + ); + + let result = eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(false)]) + ); + + let result = neq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(true)]) + ); + + let result = neq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(true)]) + ); + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_to_utf8_array() { + let test1 = vec!["abc", "abc", "b", "cde"]; + let test2 = vec!["abc", "b", "b", "def"]; + + let dict_array: DictionaryArray = test1 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + + let array: StringArray = test2 + .iter() + .map(|&x| if x == "b" { None } else { Some(x) }) + .collect(); + + let result = lt_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(true)]) + ); + + let result = lt_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(false)]) + ); + + let result = lt_eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + 
BooleanArray::from(vec![Some(true), None, None, Some(true)]) + ); + + let result = lt_eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(false)]) + ); + + let result = gt_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(false)]) + ); + + let result = gt_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(true)]) + ); + + let result = gt_eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(false)]) + ); + + let result = gt_eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(true)]) + ); + } + + #[test] + fn test_dict_nlike_kernels() { + let data = + vec![Some("Earth"), Some("Fire"), Some("Water"), Some("Air"), None, Some("Air")]; + + let dict_array: DictionaryArray = data.into_iter().collect(); + + assert_eq!( + nlike_dict_scalar(&dict_array, "Air").unwrap(), + BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(false), None, Some(false)] + ), + ); + + assert_eq!( + nlike_dict_scalar(&dict_array, "Wa%").unwrap(), + BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(true), None, Some(true)] + ), + ); + + assert_eq!( + nlike_dict_scalar(&dict_array, "%r").unwrap(), + BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), None, Some(false)] + ), + ); + + assert_eq!( + nlike_dict_scalar(&dict_array, "%i%").unwrap(), + BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), None, Some(false)] + ), + ); + + assert_eq!( + nlike_dict_scalar(&dict_array, "%a%r%").unwrap(), + BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), None, Some(true)] + ), + ); + } + + #[test] + fn test_dict_ilike_kernels() { + let data = + vec![Some("Earth"), Some("Fire"), Some("Water"), Some("Air"), None, Some("Air")]; + + let dict_array: DictionaryArray = data.into_iter().collect(); + + assert_eq!( + ilike_dict_scalar(&dict_array, "air").unwrap(), + BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(true), None, Some(true)] + ), + ); + + assert_eq!( + ilike_dict_scalar(&dict_array, "wa%").unwrap(), + BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(false), None, Some(false)] + ), + ); + + assert_eq!( + ilike_dict_scalar(&dict_array, "%R").unwrap(), + BooleanArray::from( + vec![Some(false), Some(false), Some(true), Some(true), None, Some(true)] + ), + ); + + assert_eq!( + ilike_dict_scalar(&dict_array, "%I%").unwrap(), + BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), None, Some(true)] + ), + ); + + assert_eq!( + ilike_dict_scalar(&dict_array, "%A%r%").unwrap(), + BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(true), None, Some(true)] + ), + ); + } + + #[test] + fn test_dict_nilike_kernels() { + let data = + vec![Some("Earth"), Some("Fire"), Some("Water"), Some("Air"), None, Some("Air")]; + + let dict_array: DictionaryArray = data.into_iter().collect(); + + assert_eq!( + nilike_dict_scalar(&dict_array, "air").unwrap(), + BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(false), None, Some(false)] + ), + ); + + assert_eq!( + nilike_dict_scalar(&dict_array, "wa%").unwrap(), + BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(true), None, Some(true)] + ), + ); + + assert_eq!( + 
nilike_dict_scalar(&dict_array, "%R").unwrap(), + BooleanArray::from( + vec![Some(true), Some(true), Some(false), Some(false), None, Some(false)] + ), + ); + + assert_eq!( + nilike_dict_scalar(&dict_array, "%I%").unwrap(), + BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), None, Some(false)] + ), + ); + + assert_eq!( + nilike_dict_scalar(&dict_array, "%A%r%").unwrap(), + BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(false), None, Some(false)] + ), + ); + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_eq_dyn_neq_dyn_dict_non_dict_float_nan() { + let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let values = Float32Array::from(vec![f32::NAN, 8.0, 10.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 1, 2]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(true), Some(true)], + ); + assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] + .into_iter() + .map(Some) + .collect(); + let values = Float64Array::from(vec![f64::NAN, 8.0, 10.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 1, 2]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(true), Some(true)], + ); + assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(false), Some(false)], + ); + assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_lt_dyn_lt_eq_dyn_dict_non_dict_float_nan() { + let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] + .into_iter() + .map(Some) + .collect(); + let values = Float32Array::from(vec![f32::NAN, 8.0, 9.0, 10.0, 1.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] + .into_iter() + .map(Some) + .collect(); + let values = Float64Array::from(vec![f64::NAN, 8.0, 9.0, 10.0, 1.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + let expected = BooleanArray::from( + vec![Some(false), Some(true), Some(false), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(true), Some(true), Some(true), Some(false), Some(false)], + ); + assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_gt_dyn_gt_eq_dyn_dict_non_dict_float_nan() { + let array1: Float32Array = vec![f32::NAN, 
7.0, 8.0, 8.0, 11.0, f32::NAN] + .into_iter() + .map(Some) + .collect(); + let values = Float32Array::from(vec![f32::NAN, 8.0, 9.0, 10.0, 1.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); + + let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] + .into_iter() + .map(Some) + .collect(); + let values = Float64Array::from(vec![f64::NAN, 8.0, 9.0, 10.0, 1.0]); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]); + let array2 = DictionaryArray::try_new(&keys, &values).unwrap(); + + let expected = BooleanArray::from( + vec![Some(false), Some(false), Some(false), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); + + let expected = BooleanArray::from( + vec![Some(true), Some(false), Some(true), Some(false), Some(true), Some(true)], + ); + assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_eq_dyn_neq_dyn_dictionary_to_boolean_array() { + let test1 = vec![Some(true), None, Some(false)]; + let test2 = vec![Some(true), None, None, Some(true)]; + + let values = BooleanArray::from(test1); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + + let array: BooleanArray = test2.iter().collect(); + + let result = eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(false)]) + ); + + let result = eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(false)]) + ); + + let result = neq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(true)]) + ); + + let result = neq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(true)]) + ); + } + + #[test] + #[cfg(feature = "dyn_cmp_dict")] + fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_to_boolean_array() { + let test1 = vec![Some(true), None, Some(false)]; + let test2 = vec![Some(true), None, None, Some(true)]; + + let values = BooleanArray::from(test1); + let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + + let array: BooleanArray = test2.iter().collect(); + + let result = lt_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(true)]) + ); + + let result = lt_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(false)]) + ); + + let result = lt_eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(true)]) + ); + + let result = lt_eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(false)]) + ); + + let result = gt_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + 
BooleanArray::from(vec![Some(false), None, None, Some(false)]) + ); + + let result = gt_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(false), None, None, Some(true)]) + ); + + let result = gt_eq_dyn(&dict_array, &array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(false)]) + ); + + let result = gt_eq_dyn(&array, &dict_array); + assert_eq!( + result.unwrap(), + BooleanArray::from(vec![Some(true), None, None, Some(true)]) + ); + } } diff --git a/arrow/src/compute/kernels/concat_elements.rs b/arrow/src/compute/kernels/concat_elements.rs index 7d460b21cb0d..ac365a0968ec 100644 --- a/arrow/src/compute/kernels/concat_elements.rs +++ b/arrow/src/compute/kernels/concat_elements.rs @@ -75,7 +75,7 @@ pub fn concat_elements_utf8( output_offsets.append(Offset::from_usize(output_values.len()).unwrap()); } - let builder = ArrayDataBuilder::new(GenericStringArray::::get_data_type()) + let builder = ArrayDataBuilder::new(GenericStringArray::::DATA_TYPE) .len(left.len()) .add_buffer(output_offsets.finish()) .add_buffer(output_values.finish()) @@ -155,7 +155,7 @@ pub fn concat_elements_utf8_many( output_offsets.append(Offset::from_usize(output_values.len()).unwrap()); } - let builder = ArrayDataBuilder::new(GenericStringArray::::get_data_type()) + let builder = ArrayDataBuilder::new(GenericStringArray::::DATA_TYPE) .len(size) .add_buffer(output_offsets.finish()) .add_buffer(output_values.finish()) diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 7b88de7b8e82..52664a175447 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -22,8 +22,6 @@ use std::sync::Arc; use num::Zero; -use TimeUnit::*; - use crate::array::*; use crate::buffer::{buffer_bin_and, Buffer, MutableBuffer}; use crate::datatypes::*; @@ -31,6 +29,7 @@ use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; use crate::util::bit_iterator::{BitIndexIterator, BitSliceIterator}; use crate::util::bit_util; +use crate::{downcast_dictionary_array, downcast_primitive_array}; /// If the filter selects more than this fraction of rows, use /// [`SlicesIterator`] to copy ranges of values. Otherwise iterate @@ -40,27 +39,6 @@ use crate::util::bit_util; /// const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; -macro_rules! downcast_filter { - ($type: ty, $values: expr, $filter: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a primitive array"); - - Ok(Arc::new(filter_primitive::<$type>(&values, $filter))) - }}; -} - -macro_rules! downcast_dict_filter { - ($type: ty, $values: expr, $filter: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a dictionary array"); - Ok(Arc::new(filter_dict::<$type>(values, $filter))) - }}; -} - /// An iterator of `(usize, usize)` each representing an interval /// `[start, end)` whose slots of a [BooleanArray] are true. Each /// interval corresponds to a contiguous region of memory to be @@ -358,92 +336,12 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result Ok(new_empty_array(values.data_type())), IterationStrategy::All => Ok(make_array(values.data().slice(0, predicate.count))), // actually filter - _ => match values.data_type() { + _ => downcast_primitive_array! 
{ + values => Ok(Arc::new(filter_primitive(values, predicate))), DataType::Boolean => { let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(filter_boolean(values, predicate))) } - DataType::Int8 => { - downcast_filter!(Int8Type, values, predicate) - } - DataType::Int16 => { - downcast_filter!(Int16Type, values, predicate) - } - DataType::Int32 => { - downcast_filter!(Int32Type, values, predicate) - } - DataType::Int64 => { - downcast_filter!(Int64Type, values, predicate) - } - DataType::UInt8 => { - downcast_filter!(UInt8Type, values, predicate) - } - DataType::UInt16 => { - downcast_filter!(UInt16Type, values, predicate) - } - DataType::UInt32 => { - downcast_filter!(UInt32Type, values, predicate) - } - DataType::UInt64 => { - downcast_filter!(UInt64Type, values, predicate) - } - DataType::Float32 => { - downcast_filter!(Float32Type, values, predicate) - } - DataType::Float64 => { - downcast_filter!(Float64Type, values, predicate) - } - DataType::Date32 => { - downcast_filter!(Date32Type, values, predicate) - } - DataType::Date64 => { - downcast_filter!(Date64Type, values, predicate) - } - DataType::Time32(Second) => { - downcast_filter!(Time32SecondType, values, predicate) - } - DataType::Time32(Millisecond) => { - downcast_filter!(Time32MillisecondType, values, predicate) - } - DataType::Time64(Microsecond) => { - downcast_filter!(Time64MicrosecondType, values, predicate) - } - DataType::Time64(Nanosecond) => { - downcast_filter!(Time64NanosecondType, values, predicate) - } - DataType::Timestamp(Second, _) => { - downcast_filter!(TimestampSecondType, values, predicate) - } - DataType::Timestamp(Millisecond, _) => { - downcast_filter!(TimestampMillisecondType, values, predicate) - } - DataType::Timestamp(Microsecond, _) => { - downcast_filter!(TimestampMicrosecondType, values, predicate) - } - DataType::Timestamp(Nanosecond, _) => { - downcast_filter!(TimestampNanosecondType, values, predicate) - } - DataType::Interval(IntervalUnit::YearMonth) => { - downcast_filter!(IntervalYearMonthType, values, predicate) - } - DataType::Interval(IntervalUnit::DayTime) => { - downcast_filter!(IntervalDayTimeType, values, predicate) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - downcast_filter!(IntervalMonthDayNanoType, values, predicate) - } - DataType::Duration(TimeUnit::Second) => { - downcast_filter!(DurationSecondType, values, predicate) - } - DataType::Duration(TimeUnit::Millisecond) => { - downcast_filter!(DurationMillisecondType, values, predicate) - } - DataType::Duration(TimeUnit::Microsecond) => { - downcast_filter!(DurationMicrosecondType, values, predicate) - } - DataType::Duration(TimeUnit::Nanosecond) => { - downcast_filter!(DurationNanosecondType, values, predicate) - } DataType::Utf8 => { let values = values .as_any() @@ -458,19 +356,10 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result(values, predicate))) } - DataType::Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => downcast_dict_filter!(Int8Type, values, predicate), - DataType::Int16 => downcast_dict_filter!(Int16Type, values, predicate), - DataType::Int32 => downcast_dict_filter!(Int32Type, values, predicate), - DataType::Int64 => downcast_dict_filter!(Int64Type, values, predicate), - DataType::UInt8 => downcast_dict_filter!(UInt8Type, values, predicate), - DataType::UInt16 => downcast_dict_filter!(UInt16Type, values, predicate), - DataType::UInt32 => downcast_dict_filter!(UInt32Type, values, predicate), - DataType::UInt64 => 
downcast_dict_filter!(UInt64Type, values, predicate), - t => { - unimplemented!("Filter not supported for dictionary key type {:?}", t) - } - }, + DataType::Dictionary(_, _) => downcast_dictionary_array! { + values => Ok(Arc::new(filter_dict(values, predicate))), + t => unimplemented!("Filter not supported for dictionary type {:?}", t) + } _ => { // fallback to using MutableArrayData let mut mutable = MutableArrayData::new( @@ -1039,11 +928,11 @@ mod tests { #[test] fn test_filter_string_array_with_negated_boolean_array() { let a = StringArray::from(vec!["hello", " ", "world", "!"]); - let mut bb = BooleanBuilder::new(2); - bb.append_value(false).unwrap(); - bb.append_value(true).unwrap(); - bb.append_value(false).unwrap(); - bb.append_value(true).unwrap(); + let mut bb = BooleanBuilder::with_capacity(2); + bb.append_value(false); + bb.append_value(true); + bb.append_value(false); + bb.append_value(true); let b = bb.finish(); let b = crate::compute::not(&b).unwrap(); @@ -1416,19 +1305,19 @@ mod tests { #[test] fn test_filter_map() { let mut builder = - MapBuilder::new(None, StringBuilder::new(16), Int64Builder::new(4)); + MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4)); // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1} - builder.keys().append_value("key1").unwrap(); - builder.values().append_value(1).unwrap(); + builder.keys().append_value("key1"); + builder.values().append_value(1); builder.append(true).unwrap(); - builder.keys().append_value("key2").unwrap(); - builder.keys().append_value("key3").unwrap(); - builder.values().append_value(2).unwrap(); - builder.values().append_value(3).unwrap(); + builder.keys().append_value("key2"); + builder.keys().append_value("key3"); + builder.values().append_value(2); + builder.values().append_value(3); builder.append(true).unwrap(); builder.append(false).unwrap(); - builder.keys().append_value("key1").unwrap(); - builder.values().append_value(1).unwrap(); + builder.keys().append_value("key1"); + builder.values().append_value(1); builder.append(true).unwrap(); let maparray = Arc::new(builder.finish()) as ArrayRef; @@ -1438,12 +1327,12 @@ mod tests { let got = filter(&maparray, &indices).unwrap(); let mut builder = - MapBuilder::new(None, StringBuilder::new(8), Int64Builder::new(2)); - builder.keys().append_value("key1").unwrap(); - builder.values().append_value(1).unwrap(); + MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2)); + builder.keys().append_value("key1"); + builder.values().append_value(1); builder.append(true).unwrap(); - builder.keys().append_value("key1").unwrap(); - builder.values().append_value(1).unwrap(); + builder.keys().append_value("key1"); + builder.values().append_value(1); builder.append(true).unwrap(); let expected = Arc::new(builder.finish()) as ArrayRef; @@ -1553,7 +1442,7 @@ mod tests { let c = filter(&array, &filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_dense(1); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); let expected_array = builder.build().unwrap(); @@ -1563,7 +1452,7 @@ mod tests { let c = filter(&array, &filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_dense(2); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); builder.append::("A", 34).unwrap(); let expected_array = builder.build().unwrap(); @@ -1574,7 +1463,7 @@ mod tests { let c = filter(&array, 
&filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_dense(2); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); builder.append::("B", 3.2).unwrap(); let expected_array = builder.build().unwrap(); @@ -1584,7 +1473,7 @@ mod tests { #[test] fn test_filter_union_array_dense() { - let mut builder = UnionBuilder::new_dense(3); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); builder.append::("B", 3.2).unwrap(); builder.append::("A", 34).unwrap(); @@ -1595,7 +1484,7 @@ mod tests { #[test] fn test_filter_run_union_array_dense() { - let mut builder = UnionBuilder::new_dense(3); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); builder.append::("A", 3).unwrap(); builder.append::("A", 34).unwrap(); @@ -1605,7 +1494,7 @@ mod tests { let c = filter(&array, &filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_dense(3); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); builder.append::("A", 3).unwrap(); let expected = builder.build().unwrap(); @@ -1615,7 +1504,7 @@ mod tests { #[test] fn test_filter_union_array_dense_with_nulls() { - let mut builder = UnionBuilder::new_dense(4); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); builder.append::("B", 3.2).unwrap(); builder.append_null::("B").unwrap(); @@ -1626,7 +1515,7 @@ mod tests { let c = filter(&array, &filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_dense(2); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); builder.append::("B", 3.2).unwrap(); let expected_array = builder.build().unwrap(); @@ -1637,7 +1526,7 @@ mod tests { let c = filter(&array, &filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_dense(2); + let mut builder = UnionBuilder::new_dense(); builder.append::("A", 1).unwrap(); builder.append_null::("B").unwrap(); let expected_array = builder.build().unwrap(); @@ -1647,7 +1536,7 @@ mod tests { #[test] fn test_filter_union_array_sparse() { - let mut builder = UnionBuilder::new_sparse(3); + let mut builder = UnionBuilder::new_sparse(); builder.append::("A", 1).unwrap(); builder.append::("B", 3.2).unwrap(); builder.append::("A", 34).unwrap(); @@ -1658,7 +1547,7 @@ mod tests { #[test] fn test_filter_union_array_sparse_with_nulls() { - let mut builder = UnionBuilder::new_sparse(4); + let mut builder = UnionBuilder::new_sparse(); builder.append::("A", 1).unwrap(); builder.append::("B", 3.2).unwrap(); builder.append_null::("B").unwrap(); @@ -1669,7 +1558,7 @@ mod tests { let c = filter(&array, &filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_sparse(2); + let mut builder = UnionBuilder::new_sparse(); builder.append::("A", 1).unwrap(); builder.append_null::("B").unwrap(); let expected_array = builder.build().unwrap(); diff --git a/arrow/src/compute/kernels/regexp.rs b/arrow/src/compute/kernels/regexp.rs index 081a6e193bda..1c5fa1927756 100644 --- a/arrow/src/compute/kernels/regexp.rs +++ b/arrow/src/compute/kernels/regexp.rs @@ -35,7 +35,8 @@ pub fn regexp_match( flags_array: Option<&GenericStringArray>, ) -> Result { let mut patterns: HashMap = HashMap::new(); - let builder: GenericStringBuilder = GenericStringBuilder::new(0); + let 
builder: GenericStringBuilder = + GenericStringBuilder::with_capacity(0, 0); let mut list_builder = ListBuilder::new(builder); let complete_pattern = match flags_array { @@ -61,8 +62,8 @@ pub fn regexp_match( // Required for Postgres compatibility: // SELECT regexp_match('foobarbequebaz', ''); = {""} (Some(_), Some(pattern)) if pattern == *"" => { - list_builder.values().append_value("")?; - list_builder.append(true)?; + list_builder.values().append_value(""); + list_builder.append(true); } (Some(value), Some(pattern)) => { let existing_pattern = patterns.get(&pattern); @@ -82,14 +83,14 @@ pub fn regexp_match( match re.captures(value) { Some(caps) => { for m in caps.iter().skip(1).flatten() { - list_builder.values().append_value(m.as_str())?; + list_builder.values().append_value(m.as_str()); } - list_builder.append(true)? + list_builder.append(true); } - None => list_builder.append(false)?, + None => list_builder.append(false), } } - _ => list_builder.append(false)?, + _ => list_builder.append(false), } Ok(()) }) @@ -103,7 +104,7 @@ mod tests { use crate::array::{ListArray, StringArray}; #[test] - fn match_single_group() -> Result<()> { + fn match_single_group() { let values = vec![ Some("abc-005-def"), Some("X-7-5"), @@ -117,41 +118,40 @@ mod tests { pattern_values.push(r"(bar)(bequ1e)"); pattern_values.push(""); let pattern = StringArray::from(pattern_values); - let actual = regexp_match(&array, &pattern, None)?; - let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); + let actual = regexp_match(&array, &pattern, None).unwrap(); + let elem_builder: GenericStringBuilder = GenericStringBuilder::new(); let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.values().append_value("005")?; - expected_builder.append(true)?; - expected_builder.values().append_value("7")?; - expected_builder.append(true)?; - expected_builder.append(false)?; - expected_builder.append(false)?; - expected_builder.append(false)?; - expected_builder.values().append_value("")?; - expected_builder.append(true)?; + expected_builder.values().append_value("005"); + expected_builder.append(true); + expected_builder.values().append_value("7"); + expected_builder.append(true); + expected_builder.append(false); + expected_builder.append(false); + expected_builder.append(false); + expected_builder.values().append_value(""); + expected_builder.append(true); let expected = expected_builder.finish(); let result = actual.as_any().downcast_ref::().unwrap(); assert_eq!(&expected, result); - Ok(()) } #[test] - fn match_single_group_with_flags() -> Result<()> { + fn match_single_group_with_flags() { let values = vec![Some("abc-005-def"), Some("X-7-5"), Some("X545"), None]; let array = StringArray::from(values); let pattern = StringArray::from(vec![r"x.*-(\d*)-.*"; 4]); let flags = StringArray::from(vec!["i"; 4]); - let actual = regexp_match(&array, &pattern, Some(&flags))?; - let elem_builder: GenericStringBuilder = GenericStringBuilder::new(0); + let actual = regexp_match(&array, &pattern, Some(&flags)).unwrap(); + let elem_builder: GenericStringBuilder = + GenericStringBuilder::with_capacity(0, 0); let mut expected_builder = ListBuilder::new(elem_builder); - expected_builder.append(false)?; - expected_builder.values().append_value("7")?; - expected_builder.append(true)?; - expected_builder.append(false)?; - expected_builder.append(false)?; + expected_builder.append(false); + expected_builder.values().append_value("7"); + expected_builder.append(true); + expected_builder.append(false); + 
expected_builder.append(false); let expected = expected_builder.finish(); let result = actual.as_any().downcast_ref::().unwrap(); assert_eq!(&expected, result); - Ok(()) } } diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 8e0831c6140e..0e2273e92525 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -17,7 +17,6 @@ //! Defines sort kernel for `ArrayRef` -use crate::array::BasicDecimalArray; use crate::array::*; use crate::buffer::MutableBuffer; use crate::compute::take; @@ -113,32 +112,6 @@ where } } -// implements comparison using IEEE 754 total ordering for f32 -// Original implementation from https://doc.rust-lang.org/std/primitive.f64.html#method.total_cmp -// TODO to change to use std when it becomes stable -fn total_cmp_32(l: f32, r: f32) -> std::cmp::Ordering { - let mut left = l.to_bits() as i32; - let mut right = r.to_bits() as i32; - - left ^= (((left >> 31) as u32) >> 1) as i32; - right ^= (((right >> 31) as u32) >> 1) as i32; - - left.cmp(&right) -} - -// implements comparison using IEEE 754 total ordering for f64 -// Original implementation from https://doc.rust-lang.org/std/primitive.f64.html#method.total_cmp -// TODO to change to use std when it becomes stable -fn total_cmp_64(l: f64, r: f64) -> std::cmp::Ordering { - let mut left = l.to_bits() as i64; - let mut right = r.to_bits() as i64; - - left ^= (((left >> 63) as u64) >> 1) as i64; - right ^= (((right >> 63) as u64) >> 1) as i64; - - left.cmp(&right) -} - fn cmp(l: T, r: T) -> std::cmp::Ordering where T: Ord, @@ -171,7 +144,7 @@ pub fn sort_to_indices( let (v, n) = partition_validity(values); Ok(match values.data_type() { - DataType::Decimal(_, _) => sort_decimal(values, v, n, cmp, &options, limit), + DataType::Decimal128(_, _) => sort_decimal(values, v, n, cmp, &options, limit), DataType::Boolean => sort_boolean(values, v, n, &options, limit), DataType::Int8 => { sort_primitive::(values, v, n, cmp, &options, limit) @@ -197,12 +170,22 @@ pub fn sort_to_indices( DataType::UInt64 => { sort_primitive::(values, v, n, cmp, &options, limit) } - DataType::Float32 => { - sort_primitive::(values, v, n, total_cmp_32, &options, limit) - } - DataType::Float64 => { - sort_primitive::(values, v, n, total_cmp_64, &options, limit) - } + DataType::Float32 => sort_primitive::( + values, + v, + n, + |x, y| x.total_cmp(&y), + &options, + limit, + ), + DataType::Float64 => sort_primitive::( + values, + v, + n, + |x, y| x.total_cmp(&y), + &options, + limit, + ), DataType::Date32 => { sort_primitive::(values, v, n, cmp, &options, limit) } @@ -500,7 +483,7 @@ where // downcast to decimal array let decimal_array = decimal_values .as_any() - .downcast_ref::() + .downcast_ref::() .expect("Unable to downcast to decimal array"); let valids = value_indices .into_iter() @@ -1079,9 +1062,9 @@ mod tests { use std::convert::TryFrom; use std::sync::Arc; - fn create_decimal_array(data: &[Option]) -> DecimalArray { - data.iter() - .collect::() + fn create_decimal_array(data: Vec>) -> Decimal128Array { + data.into_iter() + .collect::() .with_precision_and_scale(23, 6) .unwrap() } @@ -1092,7 +1075,7 @@ mod tests { limit: Option, expected_data: Vec, ) { - let output = create_decimal_array(&data); + let output = create_decimal_array(data); let expected = UInt32Array::from(expected_data); let output = sort_to_indices(&(Arc::new(output) as ArrayRef), options, limit).unwrap(); @@ -1105,8 +1088,8 @@ mod tests { limit: Option, expected_data: Vec>, ) { - let output = 
create_decimal_array(&data); - let expected = Arc::new(create_decimal_array(&expected_data)) as ArrayRef; + let output = create_decimal_array(data); + let expected = Arc::new(create_decimal_array(expected_data)) as ArrayRef; let output = match limit { Some(_) => { sort_limit(&(Arc::new(output) as ArrayRef), options, limit).unwrap() diff --git a/arrow/src/compute/kernels/substring.rs b/arrow/src/compute/kernels/substring.rs index 024f5633fef4..5190d0bf0b67 100644 --- a/arrow/src/compute/kernels/substring.rs +++ b/arrow/src/compute/kernels/substring.rs @@ -205,7 +205,7 @@ pub fn substring_by_char( }); let data = unsafe { ArrayData::new_unchecked( - GenericStringArray::::get_data_type(), + GenericStringArray::::DATA_TYPE, array.len(), None, array @@ -294,7 +294,7 @@ fn binary_substring( let data = unsafe { ArrayData::new_unchecked( - GenericBinaryArray::::get_data_type(), + GenericBinaryArray::::DATA_TYPE, array.len(), None, array @@ -425,7 +425,7 @@ fn utf8_substring( let data = unsafe { ArrayData::new_unchecked( - GenericStringArray::::get_data_type(), + GenericStringArray::::DATA_TYPE, array.len(), None, array @@ -587,7 +587,7 @@ mod tests { // set the first and third element to be valid let bitmap = [0b101_u8]; - let data = ArrayData::builder(GenericBinaryArray::::get_data_type()) + let data = ArrayData::builder(GenericBinaryArray::::DATA_TYPE) .len(2) .add_buffer(Buffer::from_slice_ref(offsets)) .add_buffer(Buffer::from_iter(values)) @@ -814,7 +814,7 @@ mod tests { // set the first and third element to be valid let bitmap = [0b101_u8]; - let data = ArrayData::builder(GenericStringArray::::get_data_type()) + let data = ArrayData::builder(GenericStringArray::::DATA_TYPE) .len(2) .add_buffer(Buffer::from_slice_ref(offsets)) .add_buffer(Buffer::from(values)) @@ -939,7 +939,7 @@ mod tests { // set the first and third element to be valid let bitmap = [0b101_u8]; - let data = ArrayData::builder(GenericStringArray::::get_data_type()) + let data = ArrayData::builder(GenericStringArray::::DATA_TYPE) .len(2) .add_buffer(Buffer::from_slice_ref(offsets)) .add_buffer(Buffer::from(values)) diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index fa907656ae8c..19eb1b17ca21 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -19,8 +19,6 @@ use std::{ops::AddAssign, sync::Arc}; -use crate::array::BasicDecimalArray; - use crate::buffer::{Buffer, MutableBuffer}; use crate::compute::util::{ take_value_indices_from_fixed_size_list, take_value_indices_from_list, @@ -28,30 +26,11 @@ use crate::compute::util::{ use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -use crate::{array::*, buffer::buffer_bin_and}; +use crate::{ + array::*, buffer::buffer_bin_and, downcast_dictionary_array, downcast_primitive_array, +}; use num::{ToPrimitive, Zero}; -use TimeUnit::*; - -macro_rules! downcast_take { - ($type: ty, $values: expr, $indices: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a primitive array"); - Ok(Arc::new(take_primitive::<$type, _>(&values, $indices)?)) - }}; -} - -macro_rules! downcast_dict_take { - ($type: ty, $values: expr, $indices: expr) => {{ - let values = $values - .as_any() - .downcast_ref::>() - .expect("Unable to downcast to a dictionary array"); - Ok(Arc::new(take_dict::<$type, _>(values, $indices)?)) - }}; -} /// Take elements by index from [Array], creating a new [Array] from those indexes. /// @@ -143,70 +122,18 @@ where })? 
} } - match values.data_type() { + + downcast_primitive_array! { + values => Ok(Arc::new(take_primitive(values, indices)?)), DataType::Boolean => { let values = values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(take_boolean(values, indices)?)) } - DataType::Decimal(_, _) => { - let decimal_values = values.as_any().downcast_ref::().unwrap(); + DataType::Decimal128(_, _) => { + let decimal_values = + values.as_any().downcast_ref::().unwrap(); Ok(Arc::new(take_decimal128(decimal_values, indices)?)) } - DataType::Int8 => downcast_take!(Int8Type, values, indices), - DataType::Int16 => downcast_take!(Int16Type, values, indices), - DataType::Int32 => downcast_take!(Int32Type, values, indices), - DataType::Int64 => downcast_take!(Int64Type, values, indices), - DataType::UInt8 => downcast_take!(UInt8Type, values, indices), - DataType::UInt16 => downcast_take!(UInt16Type, values, indices), - DataType::UInt32 => downcast_take!(UInt32Type, values, indices), - DataType::UInt64 => downcast_take!(UInt64Type, values, indices), - DataType::Float32 => downcast_take!(Float32Type, values, indices), - DataType::Float64 => downcast_take!(Float64Type, values, indices), - DataType::Date32 => downcast_take!(Date32Type, values, indices), - DataType::Date64 => downcast_take!(Date64Type, values, indices), - DataType::Time32(Second) => downcast_take!(Time32SecondType, values, indices), - DataType::Time32(Millisecond) => { - downcast_take!(Time32MillisecondType, values, indices) - } - DataType::Time64(Microsecond) => { - downcast_take!(Time64MicrosecondType, values, indices) - } - DataType::Time64(Nanosecond) => { - downcast_take!(Time64NanosecondType, values, indices) - } - DataType::Timestamp(Second, _) => { - downcast_take!(TimestampSecondType, values, indices) - } - DataType::Timestamp(Millisecond, _) => { - downcast_take!(TimestampMillisecondType, values, indices) - } - DataType::Timestamp(Microsecond, _) => { - downcast_take!(TimestampMicrosecondType, values, indices) - } - DataType::Timestamp(Nanosecond, _) => { - downcast_take!(TimestampNanosecondType, values, indices) - } - DataType::Interval(IntervalUnit::YearMonth) => { - downcast_take!(IntervalYearMonthType, values, indices) - } - DataType::Interval(IntervalUnit::DayTime) => { - downcast_take!(IntervalDayTimeType, values, indices) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - downcast_take!(IntervalMonthDayNanoType, values, indices) - } - DataType::Duration(TimeUnit::Second) => { - downcast_take!(DurationSecondType, values, indices) - } - DataType::Duration(TimeUnit::Millisecond) => { - downcast_take!(DurationMillisecondType, values, indices) - } - DataType::Duration(TimeUnit::Microsecond) => { - downcast_take!(DurationMicrosecondType, values, indices) - } - DataType::Duration(TimeUnit::Nanosecond) => { - downcast_take!(DurationNanosecondType, values, indices) - } DataType::Utf8 => { let values = values .as_any() @@ -272,17 +199,10 @@ where Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef) } - DataType::Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => downcast_dict_take!(Int8Type, values, indices), - DataType::Int16 => downcast_dict_take!(Int16Type, values, indices), - DataType::Int32 => downcast_dict_take!(Int32Type, values, indices), - DataType::Int64 => downcast_dict_take!(Int64Type, values, indices), - DataType::UInt8 => downcast_dict_take!(UInt8Type, values, indices), - DataType::UInt16 => downcast_dict_take!(UInt16Type, values, indices), - DataType::UInt32 => downcast_dict_take!(UInt32Type, 
values, indices), - DataType::UInt64 => downcast_dict_take!(UInt64Type, values, indices), - t => unimplemented!("Take not supported for dictionary key type {:?}", t), - }, + DataType::Dictionary(_, _) => downcast_dictionary_array! { + values => Ok(Arc::new(take_dict(values, indices)?)), + t => unimplemented!("Take not supported for dictionary type {:?}", t) + } DataType::Binary => { let values = values .as_any() @@ -315,7 +235,7 @@ where Ok(new_null_array(&DataType::Null, indices.len())) } } - t => unimplemented!("Take not supported for data type {:?}", t), + t => unimplemented!("Take not supported for data type {:?}", t) } } @@ -506,9 +426,9 @@ where /// `take` implementation for decimal arrays fn take_decimal128( - decimal_values: &DecimalArray, + decimal_values: &Decimal128Array, indices: &PrimitiveArray, -) -> Result +) -> Result where IndexType: ArrowNumericType, IndexType::Native: ToPrimitive, @@ -533,9 +453,9 @@ where let t: Result> = t.map(|t| t.flatten()); t }) - .collect::>()? + .collect::>()? // PERF: we could avoid re-validating that the data in - // DecimalArray was in range as we know it came from a valid DecimalArray + // Decimal128Array was in range as we know it came from a valid Decimal128Array .with_precision_and_scale(decimal_values.precision(), decimal_values.scale()) } @@ -599,66 +519,76 @@ where Ok(PrimitiveArray::::from(data)) } -/// `take` implementation for boolean arrays -fn take_boolean( - values: &BooleanArray, +fn take_bits( + values: &Buffer, + values_offset: usize, indices: &PrimitiveArray, -) -> Result +) -> Result where IndexType: ArrowNumericType, IndexType::Native: ToPrimitive, { - let data_len = indices.len(); + let len = indices.len(); + let values_slice = values.as_slice(); + let mut output_buffer = MutableBuffer::new_null(len); + let output_slice = output_buffer.as_slice_mut(); - let num_byte = bit_util::ceil(data_len, 8); - let mut val_buf = MutableBuffer::from_len_zeroed(num_byte); + let indices_has_nulls = indices.null_count() > 0; - let val_slice = val_buf.as_slice_mut(); + if indices_has_nulls { + indices + .iter() + .enumerate() + .try_for_each::<_, Result<()>>(|(i, index)| { + if let Some(index) = index { + let index = ToPrimitive::to_usize(&index).ok_or_else(|| { + ArrowError::ComputeError("Cast to usize failed".to_string()) + })?; - let null_count = values.null_count(); + if bit_util::get_bit(values_slice, values_offset + index) { + bit_util::set_bit(output_slice, i); + } + } - let nulls = if null_count == 0 { - (0..data_len).try_for_each::<_, Result<()>>(|i| { - let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) + Ok(()) })?; - - if values.value(index) { - bit_util::set_bit(val_slice, i); - } - - Ok(()) - })?; - - indices.data_ref().null_buffer().cloned() } else { - let mut null_buf = MutableBuffer::new(num_byte).with_bitset(num_byte, true); - let null_slice = null_buf.as_slice_mut(); + indices + .values() + .iter() + .enumerate() + .try_for_each::<_, Result<()>>(|(i, index)| { + let index = ToPrimitive::to_usize(index).ok_or_else(|| { + ArrowError::ComputeError("Cast to usize failed".to_string()) + })?; - (0..data_len).try_for_each::<_, Result<()>>(|i| { - let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| { - ArrowError::ComputeError("Cast to usize failed".to_string()) + if bit_util::get_bit(values_slice, values_offset + index) { + bit_util::set_bit(output_slice, i); + } + Ok(()) })?; + } + Ok(output_buffer.into()) +} - if 
values.is_null(index) { - bit_util::unset_bit(null_slice, i); - } else if values.value(index) { - bit_util::set_bit(val_slice, i); - } - - Ok(()) - })?; - - match indices.data_ref().null_buffer() { - Some(buffer) => Some(buffer_bin_and( - buffer, - indices.offset(), - &null_buf.into(), - 0, - indices.len(), - )), - None => Some(null_buf.into()), +/// `take` implementation for boolean arrays +fn take_boolean( + values: &BooleanArray, + indices: &PrimitiveArray, +) -> Result +where + IndexType: ArrowNumericType, + IndexType::Native: ToPrimitive, +{ + let val_buf = take_bits(values.values(), values.offset(), indices)?; + let null_buf = match values.data().null_buffer() { + Some(buf) if values.null_count() > 0 => { + Some(take_bits(buf, values.offset(), indices)?) } + _ => indices + .data() + .null_buffer() + .map(|b| b.bit_slice(indices.offset(), indices.len())), }; let data = unsafe { @@ -666,9 +596,9 @@ where DataType::Boolean, indices.len(), None, - nulls, + null_buf, 0, - vec![val_buf.into()], + vec![val_buf], vec![], ) }; @@ -778,12 +708,11 @@ where }; } - let array_data = - ArrayData::builder(GenericStringArray::::get_data_type()) - .len(data_len) - .add_buffer(offsets_buffer.into()) - .add_buffer(values.into()) - .null_bit_buffer(nulls); + let array_data = ArrayData::builder(GenericStringArray::::DATA_TYPE) + .len(data_len) + .add_buffer(offsets_buffer.into()) + .add_buffer(values.into()) + .null_bit_buffer(nulls); let array_data = unsafe { array_data.build_unchecked() }; @@ -979,18 +908,18 @@ mod tests { index: &UInt32Array, options: Option, expected_data: Vec>, - precision: &usize, - scale: &usize, + precision: &u8, + scale: &u8, ) -> Result<()> { let output = data .into_iter() - .collect::() + .collect::() .with_precision_and_scale(*precision, *scale) .unwrap(); let expected = expected_data .into_iter() - .collect::() + .collect::() .with_precision_and_scale(*precision, *scale) .unwrap(); @@ -1075,8 +1004,8 @@ mod tests { Field::new("b", DataType::Int32, true), ], vec![ - Box::new(BooleanBuilder::new(values.len())), - Box::new(Int32Builder::new(values.len())), + Box::new(BooleanBuilder::with_capacity(values.len())), + Box::new(Int32Builder::with_capacity(values.len())), ], ); @@ -1084,14 +1013,12 @@ mod tests { struct_builder .field_builder::(0) .unwrap() - .append_option(value.and_then(|v| v.0)) - .unwrap(); + .append_option(value.and_then(|v| v.0)); struct_builder .field_builder::(1) .unwrap() - .append_option(value.and_then(|v| v.1)) - .unwrap(); - struct_builder.append(value.is_some()).unwrap(); + .append_option(value.and_then(|v| v.1)); + struct_builder.append(value.is_some()); } struct_builder.finish() } @@ -1099,8 +1026,8 @@ mod tests { #[test] fn test_take_decimal128_non_null_indices() { let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]); - let precision: usize = 10; - let scale: usize = 5; + let precision: u8 = 10; + let scale: u8 = 5; test_take_decimal_arrays( vec![None, Some(3), Some(5), Some(2), Some(3), None], &index, @@ -1115,8 +1042,8 @@ mod tests { #[test] fn test_take_decimal128() { let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); - let precision: usize = 10; - let scale: usize = 5; + let precision: u8 = 10; + let scale: u8 = 5; test_take_decimal_arrays( vec![Some(0), Some(1), Some(2), Some(3), Some(4)], &index, @@ -1467,6 +1394,52 @@ mod tests { ); } + #[test] + fn test_take_bool_nullable_index() { + // indices where the masked invalid elements would be out of bounds + let index_data = ArrayData::try_new( + DataType::Int32, + 6, + 
Some(Buffer::from_iter(vec![ + false, true, false, true, false, true, + ])), + 0, + vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])], + vec![], + ) + .unwrap(); + let index = UInt32Array::from(index_data); + test_take_boolean_arrays( + vec![Some(true), None, Some(false)], + &index, + None, + vec![None, Some(true), None, None, None, Some(false)], + ); + } + + #[test] + fn test_take_bool_nullable_index_nonnull_values() { + // indices where the masked invalid elements would be out of bounds + let index_data = ArrayData::try_new( + DataType::Int32, + 6, + Some(Buffer::from_iter(vec![ + false, true, false, true, false, true, + ])), + 0, + vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])], + vec![], + ) + .unwrap(); + let index = UInt32Array::from(index_data); + test_take_boolean_arrays( + vec![Some(true), Some(true), Some(false)], + &index, + None, + vec![None, Some(true), None, Some(true), None, Some(false)], + ); + } + #[test] fn test_take_bool_with_offset() { let index = @@ -1987,15 +1960,15 @@ mod tests { #[test] fn test_take_dict() { - let keys_builder = Int16Builder::new(8); - let values_builder = StringBuilder::new(4); + let keys_builder = Int16Builder::new(); + let values_builder = StringBuilder::new(); let mut dict_builder = StringDictionaryBuilder::new(keys_builder, values_builder); dict_builder.append("foo").unwrap(); dict_builder.append("bar").unwrap(); dict_builder.append("").unwrap(); - dict_builder.append_null().unwrap(); + dict_builder.append_null(); dict_builder.append("foo").unwrap(); dict_builder.append("bar").unwrap(); dict_builder.append("bar").unwrap(); diff --git a/arrow/src/compute/kernels/temporal.rs b/arrow/src/compute/kernels/temporal.rs index efb828430629..1bec1d84f681 100644 --- a/arrow/src/compute/kernels/temporal.rs +++ b/arrow/src/compute/kernels/temporal.rs @@ -28,33 +28,33 @@ use chrono::format::{parse, Parsed}; use chrono::FixedOffset; macro_rules! extract_component_from_array { - ($array:ident, $builder:ident, $extract_fn:ident, $using:ident) => { + ($array:ident, $builder:ident, $extract_fn:ident, $using:ident, $convert:expr) => { for i in 0..$array.len() { if $array.is_null(i) { - $builder.append_null()?; + $builder.append_null(); } else { match $array.$using(i) { - Some(dt) => $builder.append_value(dt.$extract_fn() as i32)?, - None => $builder.append_null()?, + Some(dt) => $builder.append_value($convert(dt.$extract_fn())), + None => $builder.append_null(), } } } }; - ($array:ident, $builder:ident, $extract_fn1:ident, $extract_fn2:ident, $using:ident) => { + ($array:ident, $builder:ident, $extract_fn1:ident, $extract_fn2:ident, $using:ident, $convert:expr) => { for i in 0..$array.len() { if $array.is_null(i) { - $builder.append_null()?; + $builder.append_null(); } else { match $array.$using(i) { Some(dt) => { - $builder.append_value(dt.$extract_fn1().$extract_fn2() as i32)? + $builder.append_value($convert(dt.$extract_fn1().$extract_fn2())); } - None => $builder.append_null()?, + None => $builder.append_null(), } } } }; - ($array:ident, $builder:ident, $extract_fn:ident, $using:ident, $tz:ident, $parsed:ident) => { + ($array:ident, $builder:ident, $extract_fn:ident, $using:ident, $tz:ident, $parsed:ident, $convert:expr) => { if ($tz.starts_with('+') || $tz.starts_with('-')) && !$tz.contains(':') { return_compute_error_with!( "Invalid timezone", @@ -72,7 +72,7 @@ macro_rules! 
extract_component_from_array { for i in 0..$array.len() { if $array.is_null(i) { - $builder.append_null()?; + $builder.append_null(); } else { match $array.value_as_datetime(i) { Some(utc) => { @@ -90,9 +90,9 @@ macro_rules! extract_component_from_array { }; match $array.$using(i, fixed_offset) { Some(dt) => { - $builder.append_value(dt.$extract_fn() as i32)? + $builder.append_value($convert(dt.$extract_fn())); } - None => $builder.append_null()?, + None => $builder.append_null(), } } err => return_compute_error_with!( @@ -112,6 +112,9 @@ macro_rules! return_compute_error_with { }; } +pub(crate) use extract_component_from_array; +pub(crate) use return_compute_error_with; + // Internal trait, which is used for mapping values from DateLike structures trait ChronoDateExt { /// Returns a value in range `1..=4` indicating the quarter this date falls into @@ -174,13 +177,13 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Time32(_) | &DataType::Time64(_) => { - extract_component_from_array!(array, b, hour, value_as_time) + extract_component_from_array!(array, b, hour, value_as_time, |h| h as i32) } &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { - extract_component_from_array!(array, b, hour, value_as_datetime) + extract_component_from_array!(array, b, hour, value_as_datetime, |h| h as i32) } &DataType::Timestamp(_, Some(ref tz)) => { let mut scratch = Parsed::new(); @@ -190,7 +193,8 @@ where hour, value_as_datetime_with_tz, tz, - scratch + scratch, + |h| h as i32 ) } dt => return_compute_error_with!("hour does not support", dt), @@ -205,10 +209,10 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, _) => { - extract_component_from_array!(array, b, year, value_as_datetime) + extract_component_from_array!(array, b, year, value_as_datetime, |h| h as i32) } dt => return_compute_error_with!("year does not support", dt), } @@ -222,10 +226,11 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { - extract_component_from_array!(array, b, quarter, value_as_datetime) + extract_component_from_array!(array, b, quarter, value_as_datetime, |h| h + as i32) } &DataType::Timestamp(_, Some(ref tz)) => { let mut scratch = Parsed::new(); @@ -235,7 +240,8 @@ where quarter, value_as_datetime_with_tz, tz, - scratch + scratch, + |h| h as i32 ) } dt => return_compute_error_with!("quarter does not support", dt), @@ -250,10 +256,11 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { - extract_component_from_array!(array, b, month, value_as_datetime) + extract_component_from_array!(array, b, month, value_as_datetime, |h| h + as i32) } &DataType::Timestamp(_, Some(ref tz)) => { let mut scratch = Parsed::new(); @@ -263,7 +270,8 @@ where month, value_as_datetime_with_tz, tz, - 
scratch + scratch, + |h| h as i32 ) } dt => return_compute_error_with!("month does not support", dt), @@ -283,14 +291,15 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { extract_component_from_array!( array, b, num_days_from_monday, - value_as_datetime + value_as_datetime, + |h| h as i32 ) } &DataType::Timestamp(_, Some(ref tz)) => { @@ -301,7 +310,8 @@ where num_days_from_monday, value_as_datetime_with_tz, tz, - scratch + scratch, + |h| h as i32 ) } dt => return_compute_error_with!("weekday does not support", dt), @@ -321,14 +331,15 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { extract_component_from_array!( array, b, num_days_from_sunday, - value_as_datetime + value_as_datetime, + |h| h as i32 ) } &DataType::Timestamp(_, Some(ref tz)) => { @@ -339,10 +350,11 @@ where num_days_from_sunday, value_as_datetime_with_tz, tz, - scratch + scratch, + |h| h as i32 ) } - dt => return_compute_error_with!("weekday does not support", dt), + dt => return_compute_error_with!("num_days_from_sunday does not support", dt), } Ok(b.finish()) @@ -354,10 +366,10 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { - extract_component_from_array!(array, b, day, value_as_datetime) + extract_component_from_array!(array, b, day, value_as_datetime, |h| h as i32) } &DataType::Timestamp(_, Some(ref tz)) => { let mut scratch = Parsed::new(); @@ -367,7 +379,8 @@ where day, value_as_datetime_with_tz, tz, - scratch + scratch, + |h| h as i32 ) } dt => return_compute_error_with!("day does not support", dt), @@ -376,16 +389,48 @@ where Ok(b.finish()) } +/// Extracts the day of year of a given temporal array as an array of integers +/// The day of year that ranges from 1 to 366 +pub fn doy(array: &PrimitiveArray) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: std::convert::From, +{ + let mut b = Int32Builder::with_capacity(array.len()); + match array.data_type() { + &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { + extract_component_from_array!(array, b, ordinal, value_as_datetime, |h| h + as i32) + } + &DataType::Timestamp(_, Some(ref tz)) => { + let mut scratch = Parsed::new(); + extract_component_from_array!( + array, + b, + ordinal, + value_as_datetime_with_tz, + tz, + scratch, + |h| h as i32 + ) + } + dt => return_compute_error_with!("doy does not support", dt), + } + + Ok(b.finish()) +} + /// Extracts the minutes of a given temporal array as an array of integers pub fn minute(array: &PrimitiveArray) -> Result where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date64 | &DataType::Timestamp(_, None) => { - extract_component_from_array!(array, b, minute, value_as_datetime) + extract_component_from_array!(array, b, minute, 
value_as_datetime, |h| h + as i32) } &DataType::Timestamp(_, Some(ref tz)) => { let mut scratch = Parsed::new(); @@ -395,7 +440,8 @@ where minute, value_as_datetime_with_tz, tz, - scratch + scratch, + |h| h as i32 ) } dt => return_compute_error_with!("minute does not support", dt), @@ -410,11 +456,18 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date32 | &DataType::Date64 | &DataType::Timestamp(_, None) => { - extract_component_from_array!(array, b, iso_week, week, value_as_datetime) + extract_component_from_array!( + array, + b, + iso_week, + week, + value_as_datetime, + |h| h as i32 + ) } dt => return_compute_error_with!("week does not support", dt), } @@ -428,10 +481,11 @@ where T: ArrowTemporalType + ArrowNumericType, i64: std::convert::From, { - let mut b = Int32Builder::new(array.len()); + let mut b = Int32Builder::with_capacity(array.len()); match array.data_type() { &DataType::Date64 | &DataType::Timestamp(_, None) => { - extract_component_from_array!(array, b, second, value_as_datetime) + extract_component_from_array!(array, b, second, value_as_datetime, |h| h + as i32) } &DataType::Timestamp(_, Some(ref tz)) => { let mut scratch = Parsed::new(); @@ -441,7 +495,8 @@ where second, value_as_datetime_with_tz, tz, - scratch + scratch, + |h| h as i32 ) } dt => return_compute_error_with!("second does not support", dt), @@ -685,6 +740,26 @@ mod tests { assert_eq!(1, b.value(2)); } + #[test] + fn test_temporal_array_date64_doy() { + //1483228800000 -> 2017-01-01 (Sunday) + //1514764800000 -> 2018-01-01 + //1550636625000 -> 2019-02-20 + let a: PrimitiveArray = vec![ + Some(1483228800000), + Some(1514764800000), + None, + Some(1550636625000), + ] + .into(); + + let b = doy(&a).unwrap(); + assert_eq!(1, b.value(0)); + assert_eq!(1, b.value(1)); + assert!(!b.is_valid(2)); + assert_eq!(51, b.value(3)); + } + #[test] fn test_temporal_array_timestamp_micro_year() { let a: TimestampMicrosecondArray = diff --git a/arrow/src/compute/kernels/zip.rs b/arrow/src/compute/kernels/zip.rs index 0ee8e47bede0..c28529cf6762 100644 --- a/arrow/src/compute/kernels/zip.rs +++ b/arrow/src/compute/kernels/zip.rs @@ -44,7 +44,7 @@ pub fn zip( let falsy = falsy.data(); let truthy = truthy.data(); - let mut mutable = MutableArrayData::new(vec![&*truthy, &*falsy], false, truthy.len()); + let mut mutable = MutableArrayData::new(vec![truthy, falsy], false, truthy.len()); // the SlicesIterator slices only the true values. 
So the gaps left by this iterator we need to // fill with falsy values diff --git a/arrow/src/compute/util.rs b/arrow/src/compute/util.rs index 29a90b65c237..974af9593e36 100644 --- a/arrow/src/compute/util.rs +++ b/arrow/src/compute/util.rs @@ -351,9 +351,7 @@ pub(super) mod tests { T: ArrowPrimitiveType, PrimitiveArray: From>>, { - use std::any::TypeId; - - let mut offset = vec![0]; + let mut offset = vec![S::zero()]; let mut values = vec![]; let list_len = data.len(); @@ -367,34 +365,18 @@ pub(super) mod tests { list_null_count += 1; bit_util::unset_bit(list_bitmap.as_slice_mut(), idx); } - offset.push(values.len() as i64); + offset.push(S::from_usize(values.len()).unwrap()); } let value_data = PrimitiveArray::::from(values).into_data(); - let (list_data_type, value_offsets) = if TypeId::of::() == TypeId::of::() - { - ( - DataType::List(Box::new(Field::new( - "item", - T::DATA_TYPE, - list_null_count == 0, - ))), - Buffer::from_slice_ref( - &offset.into_iter().map(|x| x as i32).collect::>(), - ), - ) - } else if TypeId::of::() == TypeId::of::() { - ( - DataType::LargeList(Box::new(Field::new( - "item", - T::DATA_TYPE, - list_null_count == 0, - ))), - Buffer::from_slice_ref(&offset), - ) - } else { - unreachable!() - }; + let (list_data_type, value_offsets) = ( + GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new(Field::new( + "item", + T::DATA_TYPE, + list_null_count == 0, + ))), + Buffer::from_slice_ref(&offset), + ); let list_data = ArrayData::builder(list_data_type) .len(list_len) diff --git a/arrow/src/csv/reader.rs b/arrow/src/csv/reader.rs index 7250f943e48a..d164d35c3c8c 100644 --- a/arrow/src/csv/reader.rs +++ b/arrow/src/csv/reader.rs @@ -50,11 +50,12 @@ use std::io::{Read, Seek, SeekFrom}; use std::sync::Arc; use crate::array::{ - ArrayRef, BooleanArray, DecimalBuilder, DictionaryArray, PrimitiveArray, StringArray, + ArrayRef, BooleanArray, Decimal128Builder, DictionaryArray, PrimitiveArray, + StringArray, }; use crate::datatypes::*; use crate::error::{ArrowError, Result}; -use crate::record_batch::RecordBatch; +use crate::record_batch::{RecordBatch, RecordBatchOptions}; use crate::util::reader_parser::Parser; use csv_crate::{ByteRecord, StringRecord}; @@ -543,7 +544,7 @@ fn parse( let field = &fields[i]; match field.data_type() { DataType::Boolean => build_boolean_array(line_number, rows, i), - DataType::Decimal(precision, scale) => { + DataType::Decimal128(precision, scale) => { build_decimal_array(line_number, rows, i, *precision, *scale) } DataType::Int8 => { @@ -670,7 +671,16 @@ fn parse( Some(metadata) => Schema::new_with_metadata(projected_fields, metadata), }); - arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr)) + arrays.and_then(|arr| { + RecordBatch::try_new_with_options( + projected_schema, + arr, + &RecordBatchOptions { + match_field_names: true, + row_count: Some(rows.len()), + }, + ) + }) } fn parse_item(string: &str) -> Option { T::parse(string) @@ -695,21 +705,22 @@ fn build_decimal_array( _line_number: usize, rows: &[StringRecord], col_idx: usize, - precision: usize, - scale: usize, + precision: u8, + scale: u8, ) -> Result { - let mut decimal_builder = DecimalBuilder::new(rows.len(), precision, scale); + let mut decimal_builder = + Decimal128Builder::with_capacity(rows.len(), precision, scale); for row in rows { let col_s = row.get(col_idx); match col_s { None => { // No data for this row - decimal_builder.append_null()?; + decimal_builder.append_null(); } Some(s) => { if s.is_empty() { // append null - decimal_builder.append_null()?; + 
decimal_builder.append_null(); } else { let decimal_value: Result = parse_decimal_with_parameter(s, precision, scale); @@ -730,11 +741,12 @@ fn build_decimal_array( // Parse the string format decimal value to i128 format and checking the precision and scale. // The result i128 value can't be out of bounds. -fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) -> Result { +fn parse_decimal_with_parameter(s: &str, precision: u8, scale: u8) -> Result { if PARSE_DECIMAL_RE.is_match(s) { let mut offset = s.len(); let len = s.len(); let mut base = 1; + let scale_usize = usize::from(scale); // handle the value after the '.' and meet the scale let delimiter_position = s.find('.'); @@ -745,12 +757,12 @@ fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) -> Resu } Some(mid) => { // there is the '.' - if len - mid >= scale + 1 { + if len - mid >= scale_usize + 1 { // If the string value is "123.12345" and the scale is 2, we should just remain '.12' and drop the '345' value. - offset -= len - mid - 1 - scale; + offset -= len - mid - 1 - scale_usize; } else { // If the string value is "123.12" and the scale is 4, we should append '00' to the tail. - base = 10_i128.pow((scale + 1 + mid - len) as u32); + base = 10_i128.pow((scale_usize + 1 + mid - len) as u32); } } }; @@ -775,8 +787,14 @@ fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) -> Resu if negative { result = result.neg(); } - validate_decimal_precision(result, precision) - .map_err(|e| ArrowError::ParseError(format!("parse decimal overflow: {}", e))) + + match validate_decimal_precision(result, precision) { + Ok(_) => Ok(result), + Err(e) => Err(ArrowError::ParseError(format!( + "parse decimal overflow: {}", + e + ))), + } } else { Err(ArrowError::ParseError(format!( "can't parse the string value {} to decimal", @@ -1115,7 +1133,6 @@ mod tests { use std::io::{Cursor, Write}; use tempfile::NamedTempFile; - use crate::array::BasicDecimalArray; use crate::array::*; use crate::compute::cast; use crate::datatypes::Field; @@ -1205,8 +1222,8 @@ mod tests { fn test_csv_reader_with_decimal() { let schema = Schema::new(vec![ Field::new("city", DataType::Utf8, false), - Field::new("lat", DataType::Decimal(38, 6), false), - Field::new("lng", DataType::Decimal(38, 6), false), + Field::new("lat", DataType::Decimal128(38, 6), false), + Field::new("lng", DataType::Decimal128(38, 6), false), ]); let file = File::open("test/data/decimal_test.csv").unwrap(); @@ -1218,7 +1235,7 @@ mod tests { let lat = batch .column(1) .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); assert_eq!("57.653484", lat.value_as_string(0)); @@ -1861,6 +1878,38 @@ mod tests { assert!(csv.next().is_none()); } + #[test] + fn test_empty_projection() { + let schema = Schema::new(vec![Field::new("int", DataType::UInt32, false)]); + let data = vec![vec!["0"], vec!["1"]]; + + let data = data + .iter() + .map(|x| x.join(",")) + .collect::>() + .join("\n"); + let data = data.as_bytes(); + + let reader = std::io::Cursor::new(data); + + let mut csv = Reader::new( + reader, + Arc::new(schema), + false, + None, + 2, + None, + Some(vec![]), + None, + ); + + let batch = csv.next().unwrap().unwrap(); + assert_eq!(batch.columns().len(), 0); + assert_eq!(batch.num_rows(), 2); + + assert!(csv.next().is_none()); + } + #[test] fn test_parsing_bool() { // Encode the expected behavior of boolean parsing diff --git a/arrow/src/csv/writer.rs b/arrow/src/csv/writer.rs index 6735d9668560..7097706ba5f3 100644 --- a/arrow/src/csv/writer.rs +++ 
b/arrow/src/csv/writer.rs @@ -27,8 +27,6 @@ //! use arrow::csv; //! use arrow::datatypes::*; //! use arrow::record_batch::RecordBatch; -//! use arrow::util::test_util::get_temp_file; -//! use std::fs::File; //! use std::sync::Arc; //! //! let schema = Schema::new(vec![ @@ -56,9 +54,9 @@ //! ) //! .unwrap(); //! -//! let file = get_temp_file("out.csv", &[]); +//! let mut output = Vec::with_capacity(1024); //! -//! let mut writer = csv::Writer::new(file); +//! let mut writer = csv::Writer::new(&mut output); //! let batches = vec![&batch, &batch]; //! for batch in batches { //! writer.write(batch).unwrap(); @@ -223,7 +221,7 @@ impl Writer { DataType::Timestamp(time_unit, time_zone) => { self.handle_timestamp(time_unit, time_zone.as_ref(), row_index, col)? } - DataType::Decimal(..) => make_string_from_decimal(col, row_index)?, + DataType::Decimal128(..) => make_string_from_decimal(col, row_index)?, t => { // List and Struct arrays not supported by the writer, any // other type needs to be implemented diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs index d65915bd7ad9..b65bfd7725ac 100644 --- a/arrow/src/datatypes/datatype.rs +++ b/arrow/src/datatypes/datatype.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. +use num::BigInt; +use std::cmp::Ordering; use std::fmt; -use serde_derive::{Deserialize, Serialize}; -use serde_json::{json, Value, Value::String as VString}; - use crate::error::{ArrowError, Result}; +use crate::util::decimal::singed_cmp_le_bytes; use super::Field; @@ -39,7 +39,8 @@ use super::Field; /// Nested types can themselves be nested within other arrays. /// For more information on these types please see /// [the physical memory layout of Apache Arrow](https://arrow.apache.org/docs/format/Columnar.html#physical-memory-layout). -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum DataType { /// Null type Null, @@ -188,13 +189,20 @@ pub enum DataType { /// This type mostly used to represent low cardinality string /// arrays or a limited set of primitive types as integers. Dictionary(Box, Box), - /// Exact decimal value with precision and scale + /// Exact 128-bit width decimal value with precision and scale /// /// * precision is the total number of digits /// * scale is the number of digits past the decimal /// /// For example the number 123.45 has precision 5 and scale 2. - Decimal(usize, usize), + Decimal128(u8, u8), + /// Exact 256-bit width decimal value with precision and scale + /// + /// * precision is the total number of digits + /// * scale is the number of digits past the decimal + /// + /// For example the number 123.45 has precision 5 and scale 2. + Decimal256(u8, u8), /// A Map is a logical nested type that is represented as /// /// `List>` @@ -212,7 +220,8 @@ pub enum DataType { } /// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum TimeUnit { /// Time in seconds. Second, @@ -225,7 +234,8 @@ pub enum TimeUnit { } /// YEAR_MONTH, DAY_TIME, MONTH_DAY_NANO interval in SQL style. 
-#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum IntervalUnit { /// Indicates the number of elapsed whole months, stored as 4-byte integers. YearMonth, @@ -243,7 +253,8 @@ pub enum IntervalUnit { } // Sparse or Dense union layouts -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum UnionMode { Sparse, Dense, @@ -255,8 +266,628 @@ impl fmt::Display for DataType { } } +// MAX decimal256 value of little-endian format for each precision. +// Each element is the max value of signed 256-bit integer for the specified precision which +// is encoded to the 32-byte width format of little-endian. +pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [[u8; 32]; 76] = [ + [ + 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + ], + [ + 99, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + ], + [ + 231, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ], + [ + 15, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ], + [ + 159, 134, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ], + [ + 63, 66, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ], + [ + 127, 150, 152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 224, 245, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 201, 154, 59, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 227, 11, 84, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 231, 118, 72, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 15, 165, 212, 232, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 159, 114, 78, 24, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 63, 122, 16, 243, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 127, 198, 164, 126, 141, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 192, 111, 242, 134, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 137, 93, 120, 69, 99, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 99, 167, 179, 182, 224, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 231, 137, 4, 35, 199, 138, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 15, 99, 45, 94, 199, 107, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 159, 222, 197, 173, 201, 53, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 63, 178, 186, 201, 224, 25, 30, 2, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 127, 246, 74, 225, 199, 2, 45, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 160, 237, 204, 206, 27, 194, 211, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 73, 72, 1, 20, 22, 149, 69, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 227, 210, 12, 200, 220, 210, 183, 82, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 231, 60, 128, 208, 159, 60, 46, 59, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 15, 97, 2, 37, 62, 94, 206, 79, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 159, 202, 23, 114, 109, 174, 15, 30, 67, 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 63, 234, 237, 116, 70, 208, 156, 44, 159, 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 127, 38, 75, 145, 192, 34, 32, 190, 55, 126, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 128, 239, 172, 133, 91, 65, 109, 45, 238, 4, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 9, 91, 193, 56, 147, 141, 68, 198, 77, 49, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 99, 142, 141, 55, 192, 135, 173, 190, 9, 237, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 231, 143, 135, 43, 130, 77, 199, 114, 97, 66, 19, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 15, 159, 75, 179, 21, 7, 201, 123, 206, 151, 192, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 159, 54, 244, 0, 217, 70, 218, 213, 16, 238, 133, 7, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 63, 34, 138, 9, 122, 196, 134, 90, 168, 76, 59, 75, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 127, 86, 101, 95, 196, 172, 67, 137, 147, 254, 80, 240, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 96, 245, 185, 171, 191, 164, 92, 195, 241, 41, 99, 29, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 201, 149, 67, 181, 124, 111, 158, 161, 113, 163, 223, + 37, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 227, 217, 163, 20, 223, 90, 48, 80, 112, 98, 188, 122, + 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 231, 130, 102, 206, 182, 140, 227, 33, 99, 216, 91, 203, + 114, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 15, 29, 1, 16, 36, 127, 227, 82, 223, 115, 150, 241, + 123, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 159, 34, 11, 160, 104, 247, 226, 60, 185, 134, 224, 111, + 215, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 63, 90, 111, 64, 22, 170, 221, 96, 60, 67, 197, 94, 106, + 192, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 127, 134, 89, 132, 222, 164, 168, 200, 91, 160, 180, + 179, 39, 132, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 64, 127, 43, 177, 112, 150, 214, 149, 67, 14, 5, + 141, 41, 175, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 
255, 255, 255, 255, 255, 255, 137, 248, 178, 235, 102, 224, 97, 218, 163, 142, + 50, 130, 159, 215, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 99, 181, 253, 52, 5, 196, 210, 135, 102, 146, 249, + 21, 59, 108, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 231, 21, 233, 17, 52, 168, 59, 78, 1, 184, 191, + 219, 78, 58, 172, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 15, 219, 26, 179, 8, 146, 84, 14, 13, 48, 125, 149, + 20, 71, 186, 26, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 159, 142, 12, 255, 86, 180, 77, 143, 130, 224, 227, + 214, 205, 198, 70, 11, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 63, 146, 125, 246, 101, 11, 9, 153, 25, 197, 230, + 100, 10, 196, 195, 112, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 127, 182, 231, 160, 251, 113, 90, 250, 255, 178, 3, + 241, 103, 168, 165, 103, 104, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 32, 13, 73, 212, 115, 136, 199, 255, 253, 36, + 106, 15, 148, 120, 12, 20, 4, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 73, 131, 218, 74, 134, 84, 203, 253, 235, 113, + 37, 154, 200, 181, 124, 200, 40, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 227, 32, 137, 236, 62, 77, 241, 233, 55, 115, + 118, 5, 214, 25, 223, 212, 151, 1, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 231, 72, 91, 61, 117, 4, 109, 35, 47, 128, + 160, 54, 92, 2, 183, 80, 238, 15, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 15, 217, 144, 101, 148, 44, 66, 98, 215, 1, + 69, 34, 154, 23, 38, 39, 79, 159, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 159, 122, 168, 247, 203, 189, 149, 214, 105, + 18, 178, 86, 5, 236, 124, 135, 23, 57, 6, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 63, 202, 148, 172, 247, 105, 217, 97, 34, 184, + 244, 98, 53, 56, 225, 74, 235, 58, 62, 0, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 127, 230, 207, 189, 172, 35, 126, 210, 87, 49, + 143, 221, 21, 50, 204, 236, 48, 77, 110, 2, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 0, 31, 106, 191, 100, 237, 56, 110, 237, + 151, 167, 218, 244, 249, 63, 233, 3, 79, 24, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 9, 54, 37, 122, 239, 69, 57, 78, 70, 239, + 139, 138, 144, 195, 127, 28, 39, 22, 243, 0, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 99, 28, 116, 197, 90, 187, 60, 14, 191, + 88, 119, 105, 165, 163, 253, 28, 135, 221, 126, 9, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 231, 27, 137, 182, 139, 81, 95, 142, 118, + 119, 169, 30, 118, 100, 232, 33, 71, 167, 244, 94, 0, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 15, 23, 91, 33, 117, 47, 185, 143, 161, + 170, 158, 50, 157, 236, 19, 83, 199, 136, 142, 181, 3, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 159, 230, 142, 77, 147, 218, 59, 157, 79, + 170, 50, 250, 35, 62, 199, 62, 201, 87, 145, 23, 37, 0, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 63, 2, 149, 7, 193, 137, 86, 36, 28, 167, + 250, 197, 103, 109, 200, 115, 220, 109, 173, 235, 114, 1, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 127, 22, 210, 75, 138, 97, 97, 107, 25, + 135, 202, 187, 13, 70, 212, 133, 156, 74, 198, 52, 125, 14, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 224, 52, 246, 102, 207, 205, 49, 
+ 254, 70, 233, 85, 137, 188, 74, 58, 29, 234, 190, 15, 228, 144, 0, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 201, 16, 158, 5, 26, 10, 242, 237, + 197, 28, 91, 93, 93, 235, 70, 36, 37, 117, 157, 232, 168, 5, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 227, 167, 44, 56, 4, 101, 116, 75, + 187, 31, 143, 165, 165, 49, 197, 106, 115, 147, 38, 22, 153, 56, 0, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 231, 142, 190, 49, 42, 242, 139, + 242, 80, 61, 151, 119, 120, 240, 179, 43, 130, 194, 129, 221, 250, 53, 2, + ], + [ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 15, 149, 113, 241, 165, 117, 119, + 121, 41, 101, 232, 171, 180, 100, 7, 181, 21, 153, 17, 167, 204, 27, 22, + ], +]; + +// MIN decimal256 value of little-endian format for each precision. +// Each element is the min value of signed 256-bit integer for the specified precision which +// is encoded to the 76-byte width format of little-endian. +pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [[u8; 32]; 76] = [ + [ + 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 157, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 25, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 241, 216, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 97, 121, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 193, 189, 240, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 129, 105, 103, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 31, 10, 250, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 54, 101, 196, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 28, 244, 171, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 24, 137, 183, 232, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 240, 90, 43, 23, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 96, 141, 177, 231, 246, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 192, 133, 239, 12, 165, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 128, 57, 91, 129, 114, 252, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 
255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 63, 144, 13, 121, 220, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 118, 162, 135, 186, 156, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 156, 88, 76, 73, 31, 242, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 24, 118, 251, 220, 56, 117, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 240, 156, 210, 161, 56, 148, 250, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 96, 33, 58, 82, 54, 202, 201, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 192, 77, 69, 54, 31, 230, 225, 253, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 128, 9, 181, 30, 56, 253, 210, 234, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 95, 18, 51, 49, 228, 61, 44, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 182, 183, 254, 235, 233, 106, 186, 247, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 28, 45, 243, 55, 35, 45, 72, 173, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 24, 195, 127, 47, 96, 195, 209, 196, 252, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 240, 158, 253, 218, 193, 161, 49, 176, 223, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 96, 53, 232, 141, 146, 81, 240, 225, 188, 254, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 192, 21, 18, 139, 185, 47, 99, 211, 96, 243, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 128, 217, 180, 110, 63, 221, 223, 65, 200, 129, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 127, 16, 83, 122, 164, 190, 146, 210, 17, 251, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 246, 164, 62, 199, 108, 114, 187, 57, 178, 206, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 156, 113, 114, 200, 63, 120, 82, 65, 246, 18, 254, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 24, 112, 120, 212, 125, 178, 56, 141, 158, 189, 236, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 240, 96, 180, 76, 234, 248, 54, 132, 49, 104, 63, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 96, 201, 11, 255, 38, 185, 37, 42, 
239, 17, 122, 248, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 192, 221, 117, 246, 133, 59, 121, 165, 87, 179, 196, 180, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 128, 169, 154, 160, 59, 83, 188, 118, 108, 1, 175, 15, 253, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 159, 10, 70, 84, 64, 91, 163, 60, 14, 214, 156, 226, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 54, 106, 188, 74, 131, 144, 97, 94, 142, 92, 32, 218, 254, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 28, 38, 92, 235, 32, 165, 207, 175, 143, 157, 67, 133, 244, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 24, 125, 153, 49, 73, 115, 28, 222, 156, 39, 164, 52, 141, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 240, 226, 254, 239, 219, 128, 28, 173, 32, 140, 105, 14, 132, 251, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 96, 221, 244, 95, 151, 8, 29, 195, 70, 121, 31, 144, 40, 211, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 192, 165, 144, 191, 233, 85, 34, 159, 195, 188, 58, 161, 149, 63, + 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 128, 121, 166, 123, 33, 91, 87, 55, 164, 95, 75, 76, 216, 123, + 238, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 191, 128, 212, 78, 143, 105, 41, 106, 188, 241, 250, 114, 214, + 80, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 118, 7, 77, 20, 153, 31, 158, 37, 92, 113, 205, 125, 96, 40, + 249, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 156, 74, 2, 203, 250, 59, 45, 120, 153, 109, 6, 234, 196, 147, + 187, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 24, 234, 22, 238, 203, 87, 196, 177, 254, 71, 64, 36, 177, 197, + 83, 253, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 240, 36, 229, 76, 247, 109, 171, 241, 242, 207, 130, 106, 235, + 184, 69, 229, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 96, 113, 243, 0, 169, 75, 178, 112, 125, 31, 28, 41, 50, 57, + 185, 244, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 192, 109, 130, 9, 154, 244, 246, 102, 230, 58, 25, 155, 245, + 59, 60, 143, 245, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 128, 73, 24, 95, 4, 142, 165, 5, 0, 77, 252, 14, 152, 87, 90, + 152, 151, 255, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 223, 242, 182, 43, 140, 119, 56, 0, 2, 219, 149, 240, 107, + 135, 243, 235, 251, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 182, 124, 37, 181, 121, 171, 52, 2, 20, 142, 218, 101, 55, + 74, 131, 55, 215, 255, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 28, 223, 118, 19, 193, 178, 14, 22, 200, 140, 137, 250, 41, + 230, 32, 43, 104, 254, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 24, 183, 164, 194, 138, 251, 146, 220, 208, 127, 95, 201, + 163, 253, 72, 175, 17, 240, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 
0, 240, 38, 111, 154, 107, 211, 189, 157, 40, 254, 186, 221, + 101, 232, 217, 216, 176, 96, 255, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 96, 133, 87, 8, 52, 66, 106, 41, 150, 237, 77, 169, 250, 19, + 131, 120, 232, 198, 249, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 192, 53, 107, 83, 8, 150, 38, 158, 221, 71, 11, 157, 202, + 199, 30, 181, 20, 197, 193, 255, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 128, 25, 48, 66, 83, 220, 129, 45, 168, 206, 112, 34, 234, + 205, 51, 19, 207, 178, 145, 253, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 255, 224, 149, 64, 155, 18, 199, 145, 18, 104, 88, 37, + 11, 6, 192, 22, 252, 176, 231, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 246, 201, 218, 133, 16, 186, 198, 177, 185, 16, 116, 117, + 111, 60, 128, 227, 216, 233, 12, 255, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 156, 227, 139, 58, 165, 68, 195, 241, 64, 167, 136, 150, + 90, 92, 2, 227, 120, 34, 129, 246, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 24, 228, 118, 73, 116, 174, 160, 113, 137, 136, 86, 225, + 137, 155, 23, 222, 184, 88, 11, 161, 255, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 240, 232, 164, 222, 138, 208, 70, 112, 94, 85, 97, 205, + 98, 19, 236, 172, 56, 119, 113, 74, 252, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 96, 25, 113, 178, 108, 37, 196, 98, 176, 85, 205, 5, 220, + 193, 56, 193, 54, 168, 110, 232, 218, 255, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 192, 253, 106, 248, 62, 118, 169, 219, 227, 88, 5, 58, + 152, 146, 55, 140, 35, 146, 82, 20, 141, 254, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 128, 233, 45, 180, 117, 158, 158, 148, 230, 120, 53, 68, + 242, 185, 43, 122, 99, 181, 57, 203, 130, 241, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 31, 203, 9, 153, 48, 50, 206, 1, 185, 22, 170, 118, + 67, 181, 197, 226, 21, 65, 240, 27, 111, 255, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 54, 239, 97, 250, 229, 245, 13, 18, 58, 227, 164, 162, + 162, 20, 185, 219, 218, 138, 98, 23, 87, 250, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 28, 88, 211, 199, 251, 154, 139, 180, 68, 224, 112, + 90, 90, 206, 58, 149, 140, 108, 217, 233, 102, 199, 255, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 24, 113, 65, 206, 213, 13, 116, 13, 175, 194, 104, + 136, 135, 15, 76, 212, 125, 61, 126, 34, 5, 202, 253, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 240, 106, 142, 14, 90, 138, 136, 134, 214, 154, 23, + 84, 75, 155, 248, 74, 234, 102, 238, 88, 51, 228, 233, + ], +]; + /// `MAX_DECIMAL_FOR_EACH_PRECISION[p]` holds the maximum `i128` value -/// that can be stored in [DataType::Decimal] value of precision `p` +/// that can be stored in [DataType::Decimal128] value of precision `p` pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ 9, 99, @@ -299,7 +930,7 @@ pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ ]; /// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value -/// that can be stored in a [DataType::Decimal] value of precision `p` +/// that can be stored in a [DataType::Decimal128] value of precision `p` pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ -9, -99, @@ -341,45 +972,90 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ -99999999999999999999999999999999999999, ]; -/// The maximum precision for [DataType::Decimal] values -pub const DECIMAL_MAX_PRECISION: usize = 38; +/// The maximum precision for [DataType::Decimal128] values +pub const DECIMAL128_MAX_PRECISION: u8 = 38; + +/// The maximum scale for [DataType::Decimal128] values 
+pub const DECIMAL128_MAX_SCALE: u8 = 38; -/// The maximum scale for [DataType::Decimal] values -pub const DECIMAL_MAX_SCALE: usize = 38; +/// The maximum precision for [DataType::Decimal256] values +pub const DECIMAL256_MAX_PRECISION: u8 = 76; -/// The default scale for [DataType::Decimal] values -pub const DECIMAL_DEFAULT_SCALE: usize = 10; +/// The maximum scale for [DataType::Decimal256] values +pub const DECIMAL256_MAX_SCALE: u8 = 76; + +/// The default scale for [DataType::Decimal128] and [DataType::Decimal256] values +pub const DECIMAL_DEFAULT_SCALE: u8 = 10; /// Validates that the specified `i128` value can be properly /// interpreted as a Decimal number with precision `precision` #[inline] -pub(crate) fn validate_decimal_precision(value: i128, precision: usize) -> Result { - // TODO: add validation logic for precision > 38 - if precision > 38 { - return Ok(value); +pub(crate) fn validate_decimal_precision(value: i128, precision: u8) -> Result<()> { + if precision > DECIMAL128_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "Max precision of a Decimal128 is {}, but got {}", + DECIMAL128_MAX_PRECISION, precision, + ))); } - let max = MAX_DECIMAL_FOR_EACH_PRECISION[precision - 1]; - let min = MIN_DECIMAL_FOR_EACH_PRECISION[precision - 1]; + let max = MAX_DECIMAL_FOR_EACH_PRECISION[usize::from(precision) - 1]; + let min = MIN_DECIMAL_FOR_EACH_PRECISION[usize::from(precision) - 1]; if value > max { Err(ArrowError::InvalidArgumentError(format!( - "{} is too large to store in a Decimal of precision {}. Max is {}", + "{} is too large to store in a Decimal128 of precision {}. Max is {}", value, precision, max ))) } else if value < min { Err(ArrowError::InvalidArgumentError(format!( - "{} is too small to store in a Decimal of precision {}. Min is {}", + "{} is too small to store in a Decimal128 of precision {}. Min is {}", value, precision, min ))) } else { - Ok(value) + Ok(()) + } +} + +/// Validates that the specified `byte_array` of little-endian format +/// value can be properly interpreted as a Decimal256 number with precision `precision` +#[inline] +pub(crate) fn validate_decimal256_precision_with_lt_bytes( + lt_value: &[u8], + precision: u8, +) -> Result<()> { + if precision > DECIMAL256_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "Max precision of a Decimal256 is {}, but got {}", + DECIMAL256_MAX_PRECISION, precision, + ))); + } + let max = MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[usize::from(precision) - 1]; + let min = MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[usize::from(precision) - 1]; + + if singed_cmp_le_bytes(lt_value, &max) == Ordering::Greater { + Err(ArrowError::InvalidArgumentError(format!( + "{:?} is too large to store in a Decimal256 of precision {}. Max is {:?}", + BigInt::from_signed_bytes_le(lt_value), + precision, + BigInt::from_signed_bytes_le(&max) + ))) + } else if singed_cmp_le_bytes(lt_value, &min) == Ordering::Less { + Err(ArrowError::InvalidArgumentError(format!( + "{:?} is too small to store in a Decimal256 of precision {}. Min is {:?}", + BigInt::from_signed_bytes_le(lt_value), + precision, + BigInt::from_signed_bytes_le(&min) + ))) + } else { + Ok(()) } } impl DataType { /// Parse a data type from a JSON representation. 
- pub(crate) fn from(json: &Value) -> Result { + #[cfg(feature = "json")] + pub(crate) fn from(json: &serde_json::Value) -> Result { + use serde_json::Value; let default_field = Field::new("", DataType::Boolean, true); match *json { Value::Object(ref map) => match map.get("name") { @@ -402,19 +1078,31 @@ impl DataType { Some(s) if s == "decimal" => { // return a list with any type as its child isn't defined in the map let precision = match map.get("precision") { - Some(p) => Ok(p.as_u64().unwrap() as usize), + Some(p) => Ok(p.as_u64().unwrap().try_into().unwrap()), None => Err(ArrowError::ParseError( "Expecting a precision for decimal".to_string(), )), - }; + }?; let scale = match map.get("scale") { - Some(s) => Ok(s.as_u64().unwrap() as usize), + Some(s) => Ok(s.as_u64().unwrap().try_into().unwrap()), _ => Err(ArrowError::ParseError( "Expecting a scale for decimal".to_string(), )), + }?; + let bit_width: usize = match map.get("bitWidth") { + Some(b) => b.as_u64().unwrap() as usize, + _ => 128, // Default bit width }; - Ok(DataType::Decimal(precision?, scale?)) + if bit_width == 128 { + Ok(DataType::Decimal128(precision, scale)) + } else if bit_width == 256 { + Ok(DataType::Decimal256(precision, scale)) + } else { + Err(ArrowError::ParseError( + "Decimal bit_width invalid".to_string(), + )) + } } Some(s) if s == "floatingpoint" => match map.get("precision") { Some(p) if p == "HALF" => Ok(DataType::Float16), @@ -436,7 +1124,7 @@ impl DataType { }; let tz = match map.get("timezone") { None => Ok(None), - Some(VString(tz)) => Ok(Some(tz.clone())), + Some(serde_json::Value::String(tz)) => Ok(Some(tz.clone())), _ => Err(ArrowError::ParseError( "timezone must be a string".to_string(), )), @@ -615,7 +1303,9 @@ impl DataType { } /// Generate a JSON representation of the data type. - pub fn to_json(&self) -> Value { + #[cfg(feature = "json")] + pub fn to_json(&self) -> serde_json::Value { + use serde_json::json; match self { DataType::Null => json!({"name": "null"}), DataType::Boolean => json!({"name": "bool"}), @@ -694,8 +1384,11 @@ impl DataType { TimeUnit::Nanosecond => "NANOSECOND", }}), DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), - DataType::Decimal(precision, scale) => { - json!({"name": "decimal", "precision": precision, "scale": scale}) + DataType::Decimal128(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128}) + } + DataType::Decimal256(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256}) } DataType::Map(_, keys_sorted) => { json!({"name": "map", "keysSorted": keys_sorted}) @@ -703,7 +1396,7 @@ impl DataType { } } - /// Returns true if this type is numeric: (UInt*, Unit*, or Float*). + /// Returns true if this type is numeric: (UInt*, Int*, or Float*). 
pub fn is_numeric(t: &DataType) -> bool { use DataType::*; matches!( @@ -775,3 +1468,32 @@ impl DataType { } } } + +#[cfg(test)] +mod test { + use crate::datatypes::datatype::{ + MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION, + MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION, + }; + use crate::util::decimal::Decimal256; + use num::{BigInt, Num}; + + #[test] + fn test_decimal256_min_max_for_precision() { + // The precision from 1 to 76 + let mut max_value = "9".to_string(); + let mut min_value = "-9".to_string(); + for i in 1..77 { + let max_decimal = + Decimal256::from(BigInt::from_str_radix(max_value.as_str(), 10).unwrap()); + let min_decimal = + Decimal256::from(BigInt::from_str_radix(min_value.as_str(), 10).unwrap()); + let max_bytes = MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[i - 1]; + let min_bytes = MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[i - 1]; + max_value += "9"; + min_value += "9"; + assert_eq!(max_decimal.raw_value(), &max_bytes); + assert_eq!(min_decimal.raw_value(), &min_bytes); + } + } +} diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 7ad468b5ed9e..ef303dfdd1ff 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -98,33 +98,33 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { ["d", extra] => { match extra.splitn(3, ',').collect::>().as_slice() { [precision, scale] => { - let parsed_precision = precision.parse::().map_err(|_| { + let parsed_precision = precision.parse::().map_err(|_| { ArrowError::CDataInterface( "The decimal type requires an integer precision".to_string(), ) })?; - let parsed_scale = scale.parse::().map_err(|_| { + let parsed_scale = scale.parse::().map_err(|_| { ArrowError::CDataInterface( "The decimal type requires an integer scale".to_string(), ) })?; - DataType::Decimal(parsed_precision, parsed_scale) + DataType::Decimal128(parsed_precision, parsed_scale) }, [precision, scale, bits] => { if *bits != "128" { return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string())); } - let parsed_precision = precision.parse::().map_err(|_| { + let parsed_precision = precision.parse::().map_err(|_| { ArrowError::CDataInterface( "The decimal type requires an integer precision".to_string(), ) })?; - let parsed_scale = scale.parse::().map_err(|_| { + let parsed_scale = scale.parse::().map_err(|_| { ArrowError::CDataInterface( "The decimal type requires an integer scale".to_string(), ) })?; - DataType::Decimal(parsed_precision, parsed_scale) + DataType::Decimal128(parsed_precision, parsed_scale) } _ => { return Err(ArrowError::CDataInterface(format!( @@ -253,7 +253,9 @@ fn get_format_string(dtype: &DataType) -> Result { DataType::LargeUtf8 => Ok("U".to_string()), DataType::FixedSizeBinary(num_bytes) => Ok(format!("w:{}", num_bytes)), DataType::FixedSizeList(_, num_elems) => Ok(format!("+w:{}", num_elems)), - DataType::Decimal(precision, scale) => Ok(format!("d:{},{}", precision, scale)), + DataType::Decimal128(precision, scale) => { + Ok(format!("d:{},{}", precision, scale)) + } DataType::Date32 => Ok("tdD".to_string()), DataType::Date64 => Ok("tdm".to_string()), DataType::Time32(TimeUnit::Second) => Ok("tts".to_string()), diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs index ade48d93dab1..ac966cafe34f 100644 --- a/arrow/src/datatypes/field.rs +++ b/arrow/src/datatypes/field.rs @@ -15,22 +15,19 @@ // specific language governing permissions and limitations // under the License. 
+use crate::error::{ArrowError, Result}; use std::cmp::Ordering; use std::collections::BTreeMap; use std::hash::{Hash, Hasher}; -use serde_derive::{Deserialize, Serialize}; -use serde_json::{json, Value}; - -use crate::error::{ArrowError, Result}; - use super::DataType; /// Describes a single column in a [`Schema`](super::Schema). /// /// A [`Schema`](super::Schema) is an ordered collection of /// [`Field`] objects. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Field { name: String, data_type: DataType, @@ -38,7 +35,7 @@ pub struct Field { dict_id: i64, dict_is_ordered: bool, /// A map of key-value pairs containing additional custom meta data. - #[serde(skip_serializing_if = "Option::is_none")] + #[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))] metadata: Option>, } @@ -209,23 +206,17 @@ impl Field { } fn _fields<'a>(&'a self, dt: &'a DataType) -> Vec<&Field> { - let mut collected_fields = vec![]; - match dt { DataType::Struct(fields) | DataType::Union(fields, _, _) => { - collected_fields.extend(fields.iter().flat_map(|f| f.fields())) + fields.iter().flat_map(|f| f.fields()).collect() } DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) - | DataType::Map(field, _) => collected_fields.extend(field.fields()), - DataType::Dictionary(_, value_field) => { - collected_fields.append(&mut self._fields(value_field.as_ref())) - } - _ => (), + | DataType::Map(field, _) => field.fields(), + DataType::Dictionary(_, value_field) => self._fields(value_field.as_ref()), + _ => vec![], } - - collected_fields } /// Returns a vector containing all (potentially nested) `Field` instances selected by the @@ -260,7 +251,9 @@ impl Field { } /// Parse a `Field` definition from a JSON representation. - pub fn from(json: &Value) -> Result { + #[cfg(feature = "json")] + pub fn from(json: &serde_json::Value) -> Result { + use serde_json::Value; match *json { Value::Object(ref map) => { let name = match map.get("name") { @@ -503,19 +496,18 @@ impl Field { } /// Generate a JSON representation of the `Field`. 
- pub fn to_json(&self) -> Value { - let children: Vec = match self.data_type() { + #[cfg(feature = "json")] + pub fn to_json(&self) -> serde_json::Value { + let children: Vec = match self.data_type() { DataType::Struct(fields) => fields.iter().map(|f| f.to_json()).collect(), - DataType::List(field) => vec![field.to_json()], - DataType::LargeList(field) => vec![field.to_json()], - DataType::FixedSizeList(field, _) => vec![field.to_json()], - DataType::Map(field, _) => { - vec![field.to_json()] - } + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) + | DataType::Map(field, _) => vec![field.to_json()], _ => vec![], }; match self.data_type() { - DataType::Dictionary(ref index_type, ref value_type) => json!({ + DataType::Dictionary(ref index_type, ref value_type) => serde_json::json!({ "name": self.name, "nullable": self.nullable, "type": value_type.to_json(), @@ -526,7 +518,7 @@ impl Field { "isOrdered": self.dict_is_ordered } }), - _ => json!({ + _ => serde_json::json!({ "name": self.name, "nullable": self.nullable, "type": self.data_type.to_json(), @@ -550,6 +542,17 @@ impl Field { /// assert!(field.is_nullable()); /// ``` pub fn try_merge(&mut self, from: &Field) -> Result<()> { + if from.dict_id != self.dict_id { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting dict_id".to_string(), + )); + } + if from.dict_is_ordered != self.dict_is_ordered { + return Err(ArrowError::SchemaError( + "Fail to merge schema Field due to conflicting dict_is_ordered" + .to_string(), + )); + } // merge metadata match (self.metadata(), from.metadata()) { (Some(self_metadata), Some(from_metadata)) => { @@ -572,31 +575,16 @@ impl Field { } _ => {} } - if from.dict_id != self.dict_id { - return Err(ArrowError::SchemaError( - "Fail to merge schema Field due to conflicting dict_id".to_string(), - )); - } - if from.dict_is_ordered != self.dict_is_ordered { - return Err(ArrowError::SchemaError( - "Fail to merge schema Field due to conflicting dict_is_ordered" - .to_string(), - )); - } match &mut self.data_type { DataType::Struct(nested_fields) => match &from.data_type { DataType::Struct(from_nested_fields) => { for from_field in from_nested_fields { - let mut is_new_field = true; - for self_field in nested_fields.iter_mut() { - if self_field.name != from_field.name { - continue; - } - is_new_field = false; - self_field.try_merge(from_field)?; - } - if is_new_field { - nested_fields.push(from_field.clone()); + match nested_fields + .iter_mut() + .find(|self_field| self_field.name == from_field.name) + { + Some(self_field) => self_field.try_merge(from_field)?, + None => nested_fields.push(from_field.clone()), } } } @@ -675,7 +663,8 @@ impl Field { | DataType::FixedSizeBinary(_) | DataType::Utf8 | DataType::LargeUtf8 - | DataType::Decimal(_, _) => { + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => { if self.data_type != from.data_type { return Err(ArrowError::SchemaError( "Fail to merge schema Field due to conflicting datatype" @@ -684,9 +673,7 @@ impl Field { } } } - if from.nullable { - self.nullable = from.nullable; - } + self.nullable |= from.nullable; Ok(()) } @@ -697,41 +684,25 @@ impl Field { /// * self.metadata is a superset of other.metadata /// * all other fields are equal pub fn contains(&self, other: &Field) -> bool { - if self.name != other.name - || self.data_type != other.data_type - || self.dict_id != other.dict_id - || self.dict_is_ordered != other.dict_is_ordered - { - return false; - } - - if self.nullable != 
other.nullable && !self.nullable { - return false; - } - + self.name == other.name + && self.data_type == other.data_type + && self.dict_id == other.dict_id + && self.dict_is_ordered == other.dict_is_ordered + // self need to be nullable or both of them are not nullable + && (self.nullable || !other.nullable) // make sure self.metadata is a superset of other.metadata - match (&self.metadata, &other.metadata) { - (None, Some(_)) => { - return false; - } + && match (&self.metadata, &other.metadata) { + (_, None) => true, + (None, Some(_)) => false, (Some(self_meta), Some(other_meta)) => { - for (k, v) in other_meta.iter() { + other_meta.iter().all(|(k, v)| { match self_meta.get(k) { - Some(s) => { - if s != v { - return false; - } - } - None => { - return false; - } + Some(s) => s == v, + None => false } - } + }) } - _ => {} } - - true } } @@ -744,7 +715,7 @@ impl std::fmt::Display for Field { #[cfg(test)] mod test { - use super::{DataType, Field}; + use super::*; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; @@ -839,4 +810,72 @@ mod test { assert_ne!(dict1, dict2); assert_ne!(get_field_hash(&dict1), get_field_hash(&dict2)); } + + #[test] + fn test_contains_reflexivity() { + let mut field = Field::new("field1", DataType::Float16, false); + field.set_metadata(Some(BTreeMap::from([ + (String::from("k0"), String::from("v0")), + (String::from("k1"), String::from("v1")), + ]))); + assert!(field.contains(&field)) + } + + #[test] + fn test_contains_transitivity() { + let child_field = Field::new("child1", DataType::Float16, false); + + let mut field1 = Field::new("field1", DataType::Struct(vec![child_field]), false); + field1.set_metadata(Some(BTreeMap::from([( + String::from("k1"), + String::from("v1"), + )]))); + + let mut field2 = Field::new("field1", DataType::Struct(vec![]), true); + field2.set_metadata(Some(BTreeMap::from([( + String::from("k2"), + String::from("v2"), + )]))); + field2.try_merge(&field1).unwrap(); + + let mut field3 = Field::new("field1", DataType::Struct(vec![]), false); + field3.set_metadata(Some(BTreeMap::from([( + String::from("k3"), + String::from("v3"), + )]))); + field3.try_merge(&field2).unwrap(); + + assert!(field2.contains(&field1)); + assert!(field3.contains(&field2)); + assert!(field3.contains(&field1)); + + assert!(!field1.contains(&field2)); + assert!(!field1.contains(&field3)); + assert!(!field2.contains(&field3)); + } + + #[test] + fn test_contains_nullable() { + let field1 = Field::new("field1", DataType::Boolean, true); + let field2 = Field::new("field1", DataType::Boolean, false); + assert!(field1.contains(&field2)); + assert!(!field2.contains(&field1)); + } + + #[test] + fn test_contains_must_have_same_fields() { + let child_field1 = Field::new("child1", DataType::Float16, false); + let child_field2 = Field::new("child2", DataType::Float16, false); + + let field1 = + Field::new("field1", DataType::Struct(vec![child_field1.clone()]), true); + let field2 = Field::new( + "field1", + DataType::Struct(vec![child_field1, child_field2]), + true, + ); + + assert!(!field1.contains(&field2)); + assert!(!field2.contains(&field1)); + } } diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index c082bc64c660..38b6c7bf9744 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -37,8 +37,10 @@ pub use types::*; mod datatype; pub use datatype::*; mod delta; -mod ffi; +#[cfg(feature = "ffi")] +mod ffi; +#[cfg(feature = "ffi")] pub use ffi::*; /// A reference-counted reference to a 
[`Schema`](crate::datatypes::Schema). @@ -48,11 +50,15 @@ pub type SchemaRef = Arc; mod tests { use super::*; use crate::error::Result; - use serde_json::Value::{Bool, Number as VNumber, String as VString}; - use serde_json::{Number, Value}; - use std::{ - collections::{BTreeMap, HashMap}, - f32::NAN, + use std::collections::{BTreeMap, HashMap}; + + #[cfg(feature = "json")] + use crate::json::JsonSerializable; + + #[cfg(feature = "json")] + use serde_json::{ + Number, Value, + Value::{Bool, Number as VNumber, String as VString}, }; #[test] @@ -104,6 +110,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn create_struct_type() { let _person = DataType::Struct(vec![ Field::new("first_name", DataType::Utf8, false), @@ -120,6 +127,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn serde_struct_type() { let kv_array = [("k".to_string(), "v".to_string())]; let field_metadata: BTreeMap = kv_array.iter().cloned().collect(); @@ -167,6 +175,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn struct_field_to_json() { let f = Field::new( "address", @@ -210,6 +219,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn map_field_to_json() { let f = Field::new( "my_map", @@ -270,6 +280,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn primitive_field_to_json() { let f = Field::new("first_name", DataType::Utf8, false); let value: Value = serde_json::from_str( @@ -286,6 +297,7 @@ mod tests { assert_eq!(value, f.to_json()); } #[test] + #[cfg(feature = "json")] fn parse_struct_from_json() { let json = r#" { @@ -332,6 +344,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn parse_map_from_json() { let json = r#" { @@ -395,6 +408,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn parse_union_from_json() { let json = r#" { @@ -450,6 +464,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn parse_utf8_from_json() { let json = "{\"name\":\"utf8\"}"; let value: Value = serde_json::from_str(json).unwrap(); @@ -458,6 +473,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn parse_int32_from_json() { let json = "{\"name\": \"int\", \"isSigned\": true, \"bitWidth\": 32}"; let value: Value = serde_json::from_str(json).unwrap(); @@ -466,6 +482,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn schema_json() { // Add some custom metadata let metadata: HashMap = @@ -1226,6 +1243,7 @@ mod tests { } #[test] + #[cfg(feature = "json")] fn test_arrow_native_type_to_json() { assert_eq!(Some(Bool(true)), true.into_json_value()); assert_eq!(Some(VNumber(Number::from(1))), 1i8.into_json_value()); @@ -1245,7 +1263,7 @@ mod tests { Some(VNumber(Number::from_f64(0.01f64).unwrap())), 0.01f64.into_json_value() ); - assert_eq!(None, NAN.into_json_value()); + assert_eq!(None, f32::NAN.into_json_value()); } fn person_schema() -> Schema { @@ -1481,23 +1499,31 @@ mod tests { .is_err()); // incompatible metadata should throw error - assert!(Schema::try_merge(vec![ + let res = Schema::try_merge(vec![ Schema::new_with_metadata( vec![Field::new("first_name", DataType::Utf8, false)], - [("foo".to_string(), "bar".to_string()),] + [("foo".to_string(), "bar".to_string())] .iter() .cloned() - .collect::>() + .collect::>(), ), Schema::new_with_metadata( vec![Field::new("last_name", DataType::Utf8, false)], - [("foo".to_string(), "baz".to_string()),] + [("foo".to_string(), "baz".to_string())] .iter() .cloned() - .collect::>() - ) + .collect::>(), + ), ]) - .is_err()); + .unwrap_err(); + + let expected = "Fail to merge schema due to conflicting metadata. 
Key 'foo' has different values 'bar' and 'baz'"; + assert!( + res.to_string().contains(expected), + "Could not find expected string '{}' in '{}'", + expected, + res + ); Ok(()) } diff --git a/arrow/src/datatypes/native.rs b/arrow/src/datatypes/native.rs index d9a3f667d8e4..207e8cb40330 100644 --- a/arrow/src/datatypes/native.rs +++ b/arrow/src/datatypes/native.rs @@ -17,17 +17,11 @@ use super::DataType; use half::f16; -use serde_json::{Number, Value}; mod private { pub trait Sealed {} } -/// Trait declaring any type that is serializable to JSON. This includes all primitive types (bool, i32, etc.). -pub trait JsonSerializable: 'static { - fn into_json_value(self) -> Option; -} - /// Trait expressing a Rust type that has the same in-memory representation /// as Arrow. This includes `i16`, `f32`, but excludes `bool` (which in arrow is represented in bits). /// @@ -58,8 +52,8 @@ pub trait ArrowNativeType: + PartialOrd + std::str::FromStr + Default - + JsonSerializable + private::Sealed + + 'static { /// Convert native type from usize. #[inline] @@ -120,18 +114,6 @@ pub trait ArrowPrimitiveType: 'static { } } -impl JsonSerializable for bool { - fn into_json_value(self) -> Option { - Some(self.into()) - } -} - -impl JsonSerializable for i8 { - fn into_json_value(self) -> Option { - Some(self.into()) - } -} - impl private::Sealed for i8 {} impl ArrowNativeType for i8 { #[inline] @@ -150,12 +132,6 @@ impl ArrowNativeType for i8 { } } -impl JsonSerializable for i16 { - fn into_json_value(self) -> Option { - Some(self.into()) - } -} - impl private::Sealed for i16 {} impl ArrowNativeType for i16 { #[inline] @@ -174,12 +150,6 @@ impl ArrowNativeType for i16 { } } -impl JsonSerializable for i32 { - fn into_json_value(self) -> Option { - Some(self.into()) - } -} - impl private::Sealed for i32 {} impl ArrowNativeType for i32 { #[inline] @@ -204,12 +174,6 @@ impl ArrowNativeType for i32 { } } -impl JsonSerializable for i64 { - fn into_json_value(self) -> Option { - Some(Value::Number(Number::from(self))) - } -} - impl private::Sealed for i64 {} impl ArrowNativeType for i64 { #[inline] @@ -234,16 +198,6 @@ impl ArrowNativeType for i64 { } } -impl JsonSerializable for i128 { - fn into_json_value(self) -> Option { - // Serialize as string to avoid issues with arbitrary_precision serde_json feature - // - https://github.com/serde-rs/json/issues/559 - // - https://github.com/serde-rs/json/issues/845 - // - https://github.com/serde-rs/json/issues/846 - Some(self.to_string().into()) - } -} - impl private::Sealed for i128 {} impl ArrowNativeType for i128 { #[inline] @@ -268,12 +222,6 @@ impl ArrowNativeType for i128 { } } -impl JsonSerializable for u8 { - fn into_json_value(self) -> Option { - Some(self.into()) - } -} - impl private::Sealed for u8 {} impl ArrowNativeType for u8 { #[inline] @@ -292,12 +240,6 @@ impl ArrowNativeType for u8 { } } -impl JsonSerializable for u16 { - fn into_json_value(self) -> Option { - Some(self.into()) - } -} - impl private::Sealed for u16 {} impl ArrowNativeType for u16 { #[inline] @@ -316,12 +258,6 @@ impl ArrowNativeType for u16 { } } -impl JsonSerializable for u32 { - fn into_json_value(self) -> Option { - Some(self.into()) - } -} - impl private::Sealed for u32 {} impl ArrowNativeType for u32 { #[inline] @@ -340,12 +276,6 @@ impl ArrowNativeType for u32 { } } -impl JsonSerializable for u64 { - fn into_json_value(self) -> Option { - Some(self.into()) - } -} - impl private::Sealed for u64 {} impl ArrowNativeType for u64 { #[inline] @@ -364,24 +294,6 @@ impl ArrowNativeType for 
u64 { } } -impl JsonSerializable for f16 { - fn into_json_value(self) -> Option { - Number::from_f64(f64::round(f64::from(self) * 1000.0) / 1000.0).map(Value::Number) - } -} - -impl JsonSerializable for f32 { - fn into_json_value(self) -> Option { - Number::from_f64(f64::round(self as f64 * 1000.0) / 1000.0).map(Value::Number) - } -} - -impl JsonSerializable for f64 { - fn into_json_value(self) -> Option { - Number::from_f64(self).map(Value::Number) - } -} - impl ArrowNativeType for f16 {} impl private::Sealed for f16 {} impl ArrowNativeType for f32 {} diff --git a/arrow/src/datatypes/schema.rs b/arrow/src/datatypes/schema.rs index 5a7336624f5b..efde4edefa66 100644 --- a/arrow/src/datatypes/schema.rs +++ b/arrow/src/datatypes/schema.rs @@ -16,11 +16,8 @@ // under the License. use std::collections::HashMap; -use std::default::Default; use std::fmt; - -use serde_derive::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use std::hash::Hash; use crate::error::{ArrowError, Result}; @@ -30,13 +27,16 @@ use super::Field; /// /// Note that this information is only part of the meta-data and not part of the physical /// memory layout. -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Schema { - pub(crate) fields: Vec, + pub fields: Vec, /// A map of key-value pairs containing additional meta data. - #[serde(skip_serializing_if = "HashMap::is_empty")] - #[serde(default)] - pub(crate) metadata: HashMap, + #[cfg_attr( + feature = "serde", + serde(skip_serializing_if = "HashMap::is_empty", default) + )] + pub metadata: HashMap, } impl Schema { @@ -48,7 +48,7 @@ impl Schema { } } - /// Creates a new `Schema` from a sequence of `Field` values. + /// Creates a new [`Schema`] from a sequence of [`Field`] values. /// /// # Example /// @@ -63,7 +63,7 @@ impl Schema { Self::new_with_metadata(fields, HashMap::new()) } - /// Creates a new `Schema` from a sequence of `Field` values + /// Creates a new [`Schema`] from a sequence of [`Field`] values /// and adds additional metadata in form of key value pairs. /// /// # Example @@ -148,27 +148,23 @@ impl Schema { // merge metadata if let Some(old_val) = merged.metadata.get(&key) { if old_val != &value { - return Err(ArrowError::SchemaError( - "Fail to merge schema due to conflicting metadata." - .to_string(), - )); + return Err(ArrowError::SchemaError(format!( + "Fail to merge schema due to conflicting metadata. \ + Key '{}' has different values '{}' and '{}'", + key, old_val, value + ))); } } merged.metadata.insert(key, value); } // merge fields for field in fields.into_iter() { - let mut new_field = true; - for merged_field in &mut merged.fields { - if field.name() != merged_field.name() { - continue; - } - new_field = false; - merged_field.try_merge(&field)? 
- } - // found a new field, add to field list - if new_field { - merged.fields.push(field); + let merged_field = + merged.fields.iter_mut().find(|f| f.name() == field.name()); + match merged_field { + Some(merged_field) => merged_field.try_merge(&field)?, + // found a new field, add to field list + None => merged.fields.push(field), } } Ok(merged) @@ -183,22 +179,23 @@ impl Schema { /// Returns a vector with references to all fields (including nested fields) #[inline] + #[cfg(feature = "ipc")] pub(crate) fn all_fields(&self) -> Vec<&Field> { self.fields.iter().flat_map(|f| f.fields()).collect() } - /// Returns an immutable reference of a specific `Field` instance selected using an + /// Returns an immutable reference of a specific [`Field`] instance selected using an /// offset within the internal `fields` vector. pub fn field(&self, i: usize) -> &Field { &self.fields[i] } - /// Returns an immutable reference of a specific `Field` instance selected by name. + /// Returns an immutable reference of a specific [`Field`] instance selected by name. pub fn field_with_name(&self, name: &str) -> Result<&Field> { Ok(&self.fields[self.index_of(name)?]) } - /// Returns a vector of immutable references to all `Field` instances selected by + /// Returns a vector of immutable references to all [`Field`] instances selected by /// the dictionary ID they use. pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> { self.fields @@ -209,17 +206,16 @@ impl Schema { /// Find the index of the column with the given name. pub fn index_of(&self, name: &str) -> Result { - for i in 0..self.fields.len() { - if self.fields[i].name() == name { - return Ok(i); - } - } - let valid_fields: Vec = - self.fields.iter().map(|f| f.name().clone()).collect(); - Err(ArrowError::InvalidArgumentError(format!( - "Unable to get field named \"{}\". Valid fields: {:?}", - name, valid_fields - ))) + (0..self.fields.len()) + .find(|idx| self.fields[*idx].name() == name) + .ok_or_else(|| { + let valid_fields: Vec = + self.fields.iter().map(|f| f.name().clone()).collect(); + ArrowError::InvalidArgumentError(format!( + "Unable to get field named \"{}\". Valid fields: {:?}", + name, valid_fields + )) + }) } /// Returns an immutable reference to the Map of custom metadata key-value pairs. @@ -238,15 +234,18 @@ impl Schema { } /// Generate a JSON representation of the `Schema`. - pub fn to_json(&self) -> Value { - json!({ - "fields": self.fields.iter().map(|field| field.to_json()).collect::>(), + #[cfg(feature = "json")] + pub fn to_json(&self) -> serde_json::Value { + serde_json::json!({ + "fields": self.fields.iter().map(|field| field.to_json()).collect::>(), "metadata": serde_json::to_value(&self.metadata).unwrap() }) } /// Parse a `Schema` definition from a JSON representation. - pub fn from(json: &Value) -> Result { + #[cfg(feature = "json")] + pub fn from(json: &serde_json::Value) -> Result { + use serde_json::Value; match *json { Value::Object(ref schema) => { let fields = if let Some(Value::Array(fields)) = schema.get("fields") { @@ -273,7 +272,9 @@ impl Schema { /// Parse a `metadata` definition from a JSON representation. /// The JSON can either be an Object or an Array of Objects. - fn from_metadata(json: &Value) -> Result> { + #[cfg(feature = "json")] + fn from_metadata(json: &serde_json::Value) -> Result> { + use serde_json::Value; match json { Value::Array(_) => { let mut hashmap = HashMap::new(); @@ -315,31 +316,13 @@ impl Schema { /// /// In other words, any record conforms to `other` should also conform to `self`. 
pub fn contains(&self, other: &Schema) -> bool { - if self.fields.len() != other.fields.len() { - return false; - } - - for (i, field) in other.fields.iter().enumerate() { - if !self.fields[i].contains(field) { - return false; - } - } - + self.fields.len() == other.fields.len() + && self.fields.iter().zip(other.fields.iter()).all(|(f1, f2)| f1.contains(f2)) // make sure self.metadata is a superset of other.metadata - for (k, v) in &other.metadata { - match self.metadata.get(k) { - Some(s) => { - if s != v { - return false; - } - } - None => { - return false; - } - } - } - - true + && other.metadata.iter().all(|(k, v1)| match self.metadata.get(k) { + Some(v2) => v1 == v2, + _ => false, + }) } } @@ -356,7 +339,24 @@ impl fmt::Display for Schema { } } -#[derive(Deserialize)] +// need to implement `Hash` manually because `HashMap` implement Eq but no `Hash` +#[allow(clippy::derive_hash_xor_eq)] +impl Hash for Schema { + fn hash(&self, state: &mut H) { + self.fields.hash(state); + + // ensure deterministic key order + let mut keys: Vec<&String> = self.metadata.keys().collect(); + keys.sort(); + for k in keys { + k.hash(state); + self.metadata.get(k).expect("key valid").hash(state); + } + } +} + +#[cfg(feature = "json")] +#[derive(serde::Deserialize)] struct MetadataKeyValue { key: String, value: String, @@ -369,6 +369,7 @@ mod tests { use super::*; #[test] + #[cfg(feature = "json")] fn test_ser_de_metadata() { // ser/de with empty metadata let schema = Schema::new(vec![ @@ -433,4 +434,34 @@ mod tests { ) } } + + #[test] + fn test_schema_contains() { + let mut metadata1 = HashMap::new(); + metadata1.insert("meta".to_string(), "data".to_string()); + + let schema1 = Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ]) + .with_metadata(metadata1.clone()); + + let mut metadata2 = HashMap::new(); + metadata2.insert("meta".to_string(), "data".to_string()); + metadata2.insert("meta2".to_string(), "data".to_string()); + let schema2 = Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ]) + .with_metadata(metadata2); + + // reflexivity + assert!(schema1.contains(&schema1)); + assert!(schema2.contains(&schema2)); + + assert!(!schema1.contains(&schema2)); + assert!(schema2.contains(&schema1)); + } } diff --git a/arrow/src/datatypes/types.rs b/arrow/src/datatypes/types.rs index 223f969285ec..1b7d0675bb43 100644 --- a/arrow/src/datatypes/types.rs +++ b/arrow/src/datatypes/types.rs @@ -17,6 +17,10 @@ use super::{ArrowPrimitiveType, DataType, IntervalUnit, TimeUnit}; use crate::datatypes::delta::shift_months; +use crate::datatypes::{ + DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, +}; use chrono::{Duration, NaiveDate}; use half::f16; use std::ops::{Add, Sub}; @@ -232,6 +236,18 @@ impl IntervalDayTimeType { days: i32, millis: i32, ) -> ::Native { + /* + https://github.com/apache/arrow/blob/02c8598d264c839a5b5cf3109bfd406f3b8a6ba5/cpp/src/arrow/type.h#L1433 + struct DayMilliseconds { + int32_t days = 0; + int32_t milliseconds = 0; + ... 
+ } + 64 56 48 40 32 24 16 8 0 + +-------+-------+-------+-------+-------+-------+-------+-------+ + | days | milliseconds | + +-------+-------+-------+-------+-------+-------+-------+-------+ + */ let m = millis as u64 & u32::MAX as u64; let d = (days as u64 & u32::MAX as u64) << 32; (m | d) as ::Native @@ -264,9 +280,21 @@ impl IntervalMonthDayNanoType { days: i32, nanos: i64, ) -> ::Native { - let m = months as u128 & u32::MAX as u128; - let d = (days as u128 & u32::MAX as u128) << 32; - let n = (nanos as u128) << 64; + /* + https://github.com/apache/arrow/blob/02c8598d264c839a5b5cf3109bfd406f3b8a6ba5/cpp/src/arrow/type.h#L1475 + struct MonthDayNanos { + int32_t months; + int32_t days; + int64_t nanoseconds; + } + 128 112 96 80 64 48 32 16 0 + +-------+-------+-------+-------+-------+-------+-------+-------+ + | months | days | nanos | + +-------+-------+-------+-------+-------+-------+-------+-------+ + */ + let m = (months as u128 & u32::MAX as u128) << 96; + let d = (days as u128 & u32::MAX as u128) << 64; + let n = nanos as u128 & u64::MAX as u128; (m | d | n) as ::Native } @@ -278,9 +306,9 @@ impl IntervalMonthDayNanoType { pub fn to_parts( i: ::Native, ) -> (i32, i32, i64) { - let nanos = (i >> 64) as i64; - let days = (i >> 32) as i32; - let months = i as i32; + let months = (i >> 96) as i32; + let days = (i >> 64) as i32; + let nanos = i as i64; (months, days, nanos) } } @@ -430,3 +458,112 @@ impl Date64Type { Date64Type::from_naive_date(res) } } + +mod private { + use super::*; + + pub trait DecimalTypeSealed {} + impl DecimalTypeSealed for Decimal128Type {} + impl DecimalTypeSealed for Decimal256Type {} +} + +/// Trait representing the in-memory layout of a decimal type +pub trait NativeDecimalType: Send + Sync + Copy + AsRef<[u8]> { + fn from_slice(slice: &[u8]) -> Self; +} + +impl NativeDecimalType for [u8; N] { + fn from_slice(slice: &[u8]) -> Self { + slice.try_into().unwrap() + } +} + +/// A trait over the decimal types, used by [`DecimalArray`] to provide a generic +/// implementation across the various decimal types +/// +/// Implemented by [`Decimal128Type`] and [`Decimal256Type`] for [`Decimal128Array`] +/// and [`Decimal256Array`] respectively +/// +/// [`DecimalArray`]: [crate::array::DecimalArray] +/// [`Decimal128Array`]: [crate::array::Decimal128Array] +/// [`Decimal256Array`]: [crate::array::Decimal256Array] +pub trait DecimalType: 'static + Send + Sync + private::DecimalTypeSealed { + type Native: NativeDecimalType; + + const BYTE_LENGTH: usize; + const MAX_PRECISION: u8; + const MAX_SCALE: u8; + const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType; + const DEFAULT_TYPE: DataType; +} + +/// The decimal type for a Decimal128Array +#[derive(Debug)] +pub struct Decimal128Type {} + +impl DecimalType for Decimal128Type { + type Native = [u8; 16]; + + const BYTE_LENGTH: usize = 16; + const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION; + const MAX_SCALE: u8 = DECIMAL128_MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal128; + const DEFAULT_TYPE: DataType = + DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); +} + +/// The decimal type for a Decimal256Array +#[derive(Debug)] +pub struct Decimal256Type {} + +impl DecimalType for Decimal256Type { + type Native = [u8; 32]; + + const BYTE_LENGTH: usize = 32; + const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION; + const MAX_SCALE: u8 = DECIMAL256_MAX_SCALE; + const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = DataType::Decimal256; + const DEFAULT_TYPE: DataType = + 
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn month_day_nano_should_roundtrip() { + let value = IntervalMonthDayNanoType::make_value(1, 2, 3); + assert_eq!(IntervalMonthDayNanoType::to_parts(value), (1, 2, 3)); + } + + #[test] + fn month_day_nano_should_roundtrip_neg() { + let value = IntervalMonthDayNanoType::make_value(-1, -2, -3); + assert_eq!(IntervalMonthDayNanoType::to_parts(value), (-1, -2, -3)); + } + + #[test] + fn day_time_should_roundtrip() { + let value = IntervalDayTimeType::make_value(1, 2); + assert_eq!(IntervalDayTimeType::to_parts(value), (1, 2)); + } + + #[test] + fn day_time_should_roundtrip_neg() { + let value = IntervalDayTimeType::make_value(-1, -2); + assert_eq!(IntervalDayTimeType::to_parts(value), (-1, -2)); + } + + #[test] + fn year_month_should_roundtrip() { + let value = IntervalYearMonthType::make_value(1, 2); + assert_eq!(IntervalYearMonthType::to_months(value), 14); + } + + #[test] + fn year_month_should_roundtrip_neg() { + let value = IntervalYearMonthType::make_value(-1, -2); + assert_eq!(IntervalYearMonthType::to_months(value), -14); + } +} diff --git a/arrow/src/error.rs b/arrow/src/error.rs index ef7abbbddef9..5d92fb930170 100644 --- a/arrow/src/error.rs +++ b/arrow/src/error.rs @@ -85,6 +85,7 @@ impl From<::std::string::FromUtf8Error> for ArrowError { } } +#[cfg(feature = "json")] impl From for ArrowError { fn from(error: serde_json::Error) -> Self { ArrowError::JsonError(error.to_string()) diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 97cbe76c84c0..528f3adc2d84 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -29,14 +29,16 @@ //! # use arrow::array::{Int32Array, Array, ArrayData, export_array_into_raw, make_array, make_array_from_raw}; //! # use arrow::error::{Result, ArrowError}; //! # use arrow::compute::kernels::arithmetic; -//! # use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; +//! # use arrow::ffi::{ArrowArray, FFI_ArrowArray, FFI_ArrowSchema}; //! # use std::convert::TryFrom; //! # fn main() -> Result<()> { //! // create an array natively //! let array = Int32Array::from(vec![Some(1), None, Some(3)]); //! //! // export it -//! let (array_ptr, schema_ptr) = array.to_raw()?; +//! +//! let ffi_array = ArrowArray::try_new(array.data().clone())?; +//! let (array_ptr, schema_ptr) = ArrowArray::into_raw(ffi_array); //! //! // consumed and used by something else... //! @@ -322,7 +324,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Int64, 1) | (DataType::Date64, 1) | (DataType::Time64(_), 1) => size_of::() * 8, (DataType::Float32, 1) => size_of::() * 8, (DataType::Float64, 1) => size_of::() * 8, - (DataType::Decimal(..), 1) => size_of::() * 8, + (DataType::Decimal128(..), 1) => size_of::() * 8, (DataType::Timestamp(..), 1) => size_of::() * 8, (DataType::Duration(..), 1) => size_of::() * 8, // primitive types have a single buffer @@ -337,7 +339,7 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Int64, _) | (DataType::Date64, _) | (DataType::Time64(_), _) | (DataType::Float32, _) | (DataType::Float64, _) | - (DataType::Decimal(..), _) | + (DataType::Decimal128(..), _) | (DataType::Timestamp(..), _) | (DataType::Duration(..), _) => { return Err(ArrowError::CDataInterface(format!( @@ -456,7 +458,7 @@ struct ArrayPrivateData { impl FFI_ArrowArray { /// creates a new `FFI_ArrowArray` from existing data. - /// # Safety + /// # Memory Leaks /// This method releases `buffers`. 
Consumers of this struct *must* call `release` before /// releasing this struct, or contents in `buffers` leak. pub fn new(data: &ArrayData) -> Self { @@ -836,10 +838,11 @@ impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { impl ArrowArray { /// creates a new `ArrowArray`. This is used to export to the C Data Interface. - /// # Safety - /// See safety of [ArrowArray] - #[allow(clippy::too_many_arguments)] - pub unsafe fn try_new(data: ArrayData) -> Result { + /// + /// # Memory Leaks + /// This method releases `buffers`. Consumers of this struct *must* call `release` before + /// releasing this struct, or contents in `buffers` leak. + pub fn try_new(data: ArrayData) -> Result { let array = Arc::new(FFI_ArrowArray::new(&data)); let schema = Arc::new(FFI_ArrowSchema::try_from(data.data_type())?); Ok(ArrowArray { array, schema }) @@ -909,11 +912,11 @@ impl<'a> ArrowArrayChild<'a> { mod tests { use super::*; use crate::array::{ - export_array_into_raw, make_array, Array, ArrayData, BooleanArray, DecimalArray, - DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, FixedSizeListArray, - GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array, MapArray, - NullArray, OffsetSizeTrait, Time32MillisecondArray, TimestampMillisecondArray, - UInt32Array, + export_array_into_raw, make_array, Array, ArrayData, BooleanArray, + Decimal128Array, DictionaryArray, DurationSecondArray, FixedSizeBinaryArray, + FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericStringArray, + Int32Array, MapArray, NullArray, OffsetSizeTrait, Time32MillisecondArray, + TimestampMillisecondArray, UInt32Array, }; use crate::compute::kernels; use crate::datatypes::{Field, Int8Type}; @@ -948,19 +951,19 @@ mod tests { // create an array natively let original_array = [Some(12345_i128), Some(-12345_i128), None] .into_iter() - .collect::() + .collect::() .with_precision_and_scale(6, 2) .unwrap(); // export it - let array = ArrowArray::try_from(original_array.data().clone())?; + let array = ArrowArray::try_from(Array::data(&original_array).clone())?; // (simulate consumer) import it let data = ArrayData::try_from(array)?; let array = make_array(data); // perform some operation - let array = array.as_any().downcast_ref::().unwrap(); + let array = array.as_any().downcast_ref::().unwrap(); // verify assert_eq!(array, &original_array); @@ -1030,12 +1033,9 @@ mod tests { .collect::(); // Construct a list array from the above two - let list_data_type = match std::mem::size_of::() { - 4 => DataType::List(Box::new(Field::new("item", DataType::Int32, false))), - _ => { - DataType::LargeList(Box::new(Field::new("item", DataType::Int32, false))) - } - }; + let list_data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( + Field::new("item", DataType::Int32, false), + )); let list_data = ArrayData::builder(list_data_type) .len(3) diff --git a/arrow/src/ipc/compression/codec.rs b/arrow/src/ipc/compression/codec.rs new file mode 100644 index 000000000000..58ba8cb86585 --- /dev/null +++ b/arrow/src/ipc/compression/codec.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::buffer::Buffer; +use crate::error::{ArrowError, Result}; +use crate::ipc::CompressionType; +use std::io::{Read, Write}; + +const LENGTH_NO_COMPRESSED_DATA: i64 = -1; +const LENGTH_OF_PREFIX_DATA: i64 = 8; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Represents compressing a ipc stream using a particular compression algorithm +pub enum CompressionCodec { + Lz4Frame, + Zstd, +} + +impl TryFrom for CompressionCodec { + type Error = ArrowError; + + fn try_from(compression_type: CompressionType) -> Result { + match compression_type { + CompressionType::ZSTD => Ok(CompressionCodec::Zstd), + CompressionType::LZ4_FRAME => Ok(CompressionCodec::Lz4Frame), + other_type => Err(ArrowError::NotYetImplemented(format!( + "compression type {:?} not supported ", + other_type + ))), + } + } +} + +impl CompressionCodec { + /// Compresses the data in `input` to `output` and appends the + /// data using the specified compression mechanism. + /// + /// returns the number of bytes written to the stream + /// + /// Writes this format to output: + /// ```text + /// [8 bytes]: uncompressed length + /// [remaining bytes]: compressed data stream + /// ``` + pub(crate) fn compress_to_vec( + &self, + input: &[u8], + output: &mut Vec, + ) -> Result { + let uncompressed_data_len = input.len(); + let original_output_len = output.len(); + + if input.is_empty() { + // empty input, nothing to do + } else { + // write compressed data directly into the output buffer + output.extend_from_slice(&uncompressed_data_len.to_le_bytes()); + self.compress(input, output)?; + + let compression_len = output.len(); + if compression_len > uncompressed_data_len { + // length of compressed data was larger than + // uncompressed data, use the uncompressed data with + // length -1 to indicate that we don't compress the + // data + output.truncate(original_output_len); + output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes()); + output.extend_from_slice(input); + } + } + Ok(output.len() - original_output_len) + } + + /// Decompresses the input into a [`Buffer`] + /// + /// The input should look like: + /// ```text + /// [8 bytes]: uncompressed length + /// [remaining bytes]: compressed data stream + /// ``` + pub(crate) fn decompress_to_buffer(&self, input: &Buffer) -> Result { + // read the first 8 bytes to determine if the data is + // compressed + let decompressed_length = read_uncompressed_size(input); + let buffer = if decompressed_length == 0 { + // emtpy + Buffer::from([]) + } else if decompressed_length == LENGTH_NO_COMPRESSED_DATA { + // no compression + input.slice(LENGTH_OF_PREFIX_DATA as usize) + } else { + // decompress data using the codec + let mut uncompressed_buffer = + Vec::with_capacity(decompressed_length as usize); + let input_data = &input[(LENGTH_OF_PREFIX_DATA as usize)..]; + self.decompress(input_data, &mut uncompressed_buffer)?; + Buffer::from(uncompressed_buffer) + }; + Ok(buffer) + } + + /// Compress the data in input buffer and write to output buffer + /// using the specified compression + fn compress(&self, input: &[u8], output: &mut Vec) -> Result<()> { + match self 
{ + CompressionCodec::Lz4Frame => { + let mut encoder = lz4::EncoderBuilder::new().build(output)?; + encoder.write_all(input)?; + match encoder.finish().1 { + Ok(_) => Ok(()), + Err(e) => Err(e.into()), + } + } + CompressionCodec::Zstd => { + let mut encoder = zstd::Encoder::new(output, 0)?; + encoder.write_all(input)?; + match encoder.finish() { + Ok(_) => Ok(()), + Err(e) => Err(e.into()), + } + } + } + } + + /// Decompress the data in input buffer and write to output buffer + /// using the specified compression + fn decompress(&self, input: &[u8], output: &mut Vec) -> Result { + let result: Result = match self { + CompressionCodec::Lz4Frame => { + let mut decoder = lz4::Decoder::new(input)?; + match decoder.read_to_end(output) { + Ok(size) => Ok(size), + Err(e) => Err(e.into()), + } + } + CompressionCodec::Zstd => { + let mut decoder = zstd::Decoder::new(input)?; + match decoder.read_to_end(output) { + Ok(size) => Ok(size), + Err(e) => Err(e.into()), + } + } + }; + result + } +} + +/// Get the uncompressed length +/// Notes: +/// LENGTH_NO_COMPRESSED_DATA: indicate that the data that follows is not compressed +/// 0: indicate that there is no data +/// positive number: indicate the uncompressed length for the following data +#[inline] +fn read_uncompressed_size(buffer: &[u8]) -> i64 { + let len_buffer = &buffer[0..8]; + // 64-bit little-endian signed integer + i64::from_le_bytes(len_buffer.try_into().unwrap()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_lz4_compression() { + let input_bytes = "hello lz4".as_bytes(); + let codec: CompressionCodec = CompressionCodec::Lz4Frame; + let mut output_bytes: Vec = Vec::new(); + codec.compress(input_bytes, &mut output_bytes).unwrap(); + let mut result_output_bytes: Vec = Vec::new(); + codec + .decompress(output_bytes.as_slice(), &mut result_output_bytes) + .unwrap(); + assert_eq!(input_bytes, result_output_bytes.as_slice()); + } + + #[test] + fn test_zstd_compression() { + let input_bytes = "hello zstd".as_bytes(); + let codec: CompressionCodec = CompressionCodec::Zstd; + let mut output_bytes: Vec = Vec::new(); + codec.compress(input_bytes, &mut output_bytes).unwrap(); + let mut result_output_bytes: Vec = Vec::new(); + codec + .decompress(output_bytes.as_slice(), &mut result_output_bytes) + .unwrap(); + assert_eq!(input_bytes, result_output_bytes.as_slice()); + } +} diff --git a/arrow/src/ipc/compression/mod.rs b/arrow/src/ipc/compression/mod.rs new file mode 100644 index 000000000000..666fa6d86a27 --- /dev/null +++ b/arrow/src/ipc/compression/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
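
To make the length-prefix framing produced by `compress_to_vec` above concrete, here is a minimal, std-only sketch (not arrow code): an 8-byte little-endian `i64` prefix carries the uncompressed length, and `-1` signals that compression was skipped because it did not shrink the data. The `frame` function and the `compress` closure are illustrative stand-ins for the real codec and an LZ4/ZSTD encoder; only the constant name mirrors the source.

```rust
// Sketch of the write-side framing used by the IPC compression codec above.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;

fn frame<F>(input: &[u8], compress: F, output: &mut Vec<u8>) -> usize
where
    F: Fn(&[u8]) -> Vec<u8>,
{
    let start = output.len();
    if input.is_empty() {
        // nothing to write for an empty buffer
        return 0;
    }
    let compressed = compress(input);
    if compressed.len() < input.len() {
        // compression helped: prefix with the uncompressed length
        output.extend_from_slice(&(input.len() as i64).to_le_bytes());
        output.extend_from_slice(&compressed);
    } else {
        // compression did not help: -1 marker followed by the raw bytes
        output.extend_from_slice(&LENGTH_NO_COMPRESSED_DATA.to_le_bytes());
        output.extend_from_slice(input);
    }
    output.len() - start
}

fn main() {
    let mut out = Vec::new();
    // an identity "compressor" never shrinks the input, so the -1 path is taken
    let written = frame(b"hello ipc", |bytes| bytes.to_vec(), &mut out);
    let prefix = i64::from_le_bytes(out[0..8].try_into().unwrap());
    assert_eq!(prefix, LENGTH_NO_COMPRESSED_DATA);
    println!("wrote {written} bytes, length prefix {prefix}");
}
```

In the real codec the compressed bytes are written directly into the output `Vec`, which is truncated back to its original length before falling back to the uncompressed path.
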
+ +#[cfg(feature = "ipc_compression")] +mod codec; +#[cfg(feature = "ipc_compression")] +pub(crate) use codec::CompressionCodec; + +#[cfg(not(feature = "ipc_compression"))] +mod stub; +#[cfg(not(feature = "ipc_compression"))] +pub(crate) use stub::CompressionCodec; diff --git a/arrow/src/ipc/compression/stub.rs b/arrow/src/ipc/compression/stub.rs new file mode 100644 index 000000000000..6240f084be3f --- /dev/null +++ b/arrow/src/ipc/compression/stub.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Stubs that implement the same interface as the ipc_compression +//! codec module, but always errors. + +use crate::buffer::Buffer; +use crate::error::{ArrowError, Result}; +use crate::ipc::CompressionType; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CompressionCodec {} + +impl TryFrom for CompressionType { + type Error = ArrowError; + fn try_from(codec: CompressionCodec) -> Result { + Err(ArrowError::InvalidArgumentError( + format!("codec type {:?} not supported because arrow was not compiled with the ipc_compression feature", codec))) + } +} + +impl TryFrom for CompressionCodec { + type Error = ArrowError; + + fn try_from(compression_type: CompressionType) -> Result { + Err(ArrowError::InvalidArgumentError( + format!("compression type {:?} not supported because arrow was not compiled with the ipc_compression feature", compression_type)) + ) + } +} + +impl CompressionCodec { + #[allow(clippy::ptr_arg)] + pub(crate) fn compress_to_vec( + &self, + _input: &[u8], + _output: &mut Vec, + ) -> Result { + Err(ArrowError::InvalidArgumentError( + "compression not supported because arrow was not compiled with the ipc_compression feature".to_string() + )) + } + + pub(crate) fn decompress_to_buffer(&self, _input: &[u8]) -> Result { + Err(ArrowError::InvalidArgumentError( + "decompression not supported because arrow was not compiled with the ipc_compression feature".to_string() + )) + } +} diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs index c81ea8278c4f..00503d50e338 100644 --- a/arrow/src/ipc/convert.rs +++ b/arrow/src/ipc/convert.rs @@ -320,7 +320,20 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT } ipc::Type::Decimal => { let fsb = field.type_as_decimal().unwrap(); - DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize) + let bit_width = fsb.bitWidth(); + if bit_width == 128 { + DataType::Decimal128( + fsb.precision().try_into().unwrap(), + fsb.scale().try_into().unwrap(), + ) + } else if bit_width == 256 { + DataType::Decimal256( + fsb.precision().try_into().unwrap(), + fsb.scale().try_into().unwrap(), + ) + } else { + panic!("Unexpected decimal bit width {}", bit_width) + } } ipc::Type::Union => { let union = field.type_as_union().unwrap(); @@ -660,7 
+673,7 @@ pub(crate) fn get_fb_field_type<'a>( // type in the DictionaryEncoding metadata in the parent field get_fb_field_type(value_type, is_nullable, fbb) } - Decimal(precision, scale) => { + Decimal128(precision, scale) => { let mut builder = ipc::DecimalBuilder::new(fbb); builder.add_precision(*precision as i32); builder.add_scale(*scale as i32); @@ -671,6 +684,17 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&empty_fields[..])), } } + Decimal256(precision, scale) => { + let mut builder = ipc::DecimalBuilder::new(fbb); + builder.add_precision(*precision as i32); + builder.add_scale(*scale as i32); + builder.add_bitWidth(256); + FBFieldType { + type_type: ipc::Type::Decimal, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } Union(fields, type_ids, mode) => { let mut children = vec![]; for field in fields { @@ -947,7 +971,7 @@ mod tests { 123, true, ), - Field::new("decimal", DataType::Decimal(10, 6), false), + Field::new("decimal", DataType::Decimal128(10, 6), false), ], md, ); diff --git a/arrow/src/ipc/mod.rs b/arrow/src/ipc/mod.rs index d5455b454e7f..2b30e72206c3 100644 --- a/arrow/src/ipc/mod.rs +++ b/arrow/src/ipc/mod.rs @@ -22,6 +22,8 @@ pub mod convert; pub mod reader; pub mod writer; +mod compression; + #[allow(clippy::redundant_closure)] #[allow(clippy::needless_lifetimes)] #[allow(clippy::extra_unused_lifetimes)] diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index e8abd3a63269..969c8c43f026 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -21,26 +21,43 @@ //! however the `FileReader` expects a reader that supports `Seek`ing use std::collections::HashMap; +use std::fmt; use std::io::{BufReader, Read, Seek, SeekFrom}; use std::sync::Arc; use crate::array::*; -use crate::buffer::Buffer; +use crate::buffer::{Buffer, MutableBuffer}; use crate::compute::cast; use crate::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef, UnionMode}; use crate::error::{ArrowError, Result}; use crate::ipc; use crate::record_batch::{RecordBatch, RecordBatchOptions, RecordBatchReader}; +use crate::ipc::compression::CompressionCodec; use ipc::CONTINUATION_MARKER; use DataType::*; /// Read a buffer based on offset and length -fn read_buffer(buf: &ipc::Buffer, a_data: &[u8]) -> Buffer { +/// From +/// Each constituent buffer is first compressed with the indicated +/// compressor, and then written with the uncompressed length in the first 8 +/// bytes as a 64-bit little-endian signed integer followed by the compressed +/// buffer bytes (and then padding as required by the protocol). The +/// uncompressed length may be set to -1 to indicate that the data that +/// follows is not compressed, which can be useful for cases where +/// compression does not yield appreciable savings. +fn read_buffer( + buf: &ipc::Buffer, + a_data: &Buffer, + compression_codec: &Option, +) -> Result { let start_offset = buf.offset() as usize; - let end_offset = start_offset + buf.length() as usize; - let buf_data = &a_data[start_offset..end_offset]; - Buffer::from(&buf_data) + let buf_data = a_data.slice_with_length(start_offset, buf.length() as usize); + // corner case: empty buffer + match (buf_data.is_empty(), compression_codec) { + (true, _) | (_, None) => Ok(buf_data), + (false, Some(decompressor)) => decompressor.decompress_to_buffer(&buf_data), + } } /// Coordinates reading arrays based on data types. 
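
For reference, a self-contained sketch of the reader side of the framing that `read_buffer` describes above: the first 8 bytes are read as a little-endian `i64`, where `0` means no data, `-1` means the remaining bytes are stored uncompressed, and a positive value is the uncompressed length of the compressed stream that follows. `unframe` and the `decompress` closure are illustrative names, not arrow APIs; the constants mirror the source.

```rust
// Sketch of the read-side handling of the 8-byte length prefix.
const LENGTH_NO_COMPRESSED_DATA: i64 = -1;
const LENGTH_OF_PREFIX_DATA: usize = 8;

fn unframe<F>(input: &[u8], decompress: F) -> Vec<u8>
where
    F: Fn(&[u8], usize) -> Vec<u8>,
{
    if input.is_empty() {
        // corner case: an entirely empty buffer
        return Vec::new();
    }
    // 64-bit little-endian signed integer holding the uncompressed length
    let uncompressed_len =
        i64::from_le_bytes(input[..LENGTH_OF_PREFIX_DATA].try_into().unwrap());
    let body = &input[LENGTH_OF_PREFIX_DATA..];
    if uncompressed_len == 0 {
        // no data follows the prefix
        Vec::new()
    } else if uncompressed_len == LENGTH_NO_COMPRESSED_DATA {
        // the body was stored uncompressed
        body.to_vec()
    } else {
        // the body is a compressed stream of `uncompressed_len` bytes
        decompress(body, uncompressed_len as usize)
    }
}

fn main() {
    // Frame "abc" with the -1 marker (stored uncompressed), then read it back.
    let mut framed = LENGTH_NO_COMPRESSED_DATA.to_le_bytes().to_vec();
    framed.extend_from_slice(b"abc");
    let round_tripped = unframe(&framed, |body, _len| body.to_vec());
    assert_eq!(round_tripped, b"abc");
    println!("recovered {} bytes", round_tripped.len());
}
```

In the actual reader, `read_buffer` also short-circuits when the IPC message carries no compression codec at all, returning the sliced buffer untouched.
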
@@ -56,24 +73,24 @@ fn read_buffer(buf: &ipc::Buffer, a_data: &[u8]) -> Buffer { fn create_array( nodes: &[ipc::FieldNode], field: &Field, - data: &[u8], + data: &Buffer, buffers: &[ipc::Buffer], dictionaries_by_id: &HashMap, mut node_index: usize, mut buffer_index: usize, + compression_codec: &Option, metadata: &ipc::MetadataVersion, ) -> Result<(ArrayRef, usize, usize)> { - use DataType::*; let data_type = field.data_type(); let array = match data_type { Utf8 | Binary | LargeBinary | LargeUtf8 => { let array = create_primitive_array( &nodes[node_index], data_type, - buffers[buffer_index..buffer_index + 3] + &buffers[buffer_index..buffer_index + 3] .iter() - .map(|buf| read_buffer(buf, data)) - .collect(), + .map(|buf| read_buffer(buf, data, compression_codec)) + .collect::>>()?, ); node_index += 1; buffer_index += 3; @@ -83,10 +100,10 @@ fn create_array( let array = create_primitive_array( &nodes[node_index], data_type, - buffers[buffer_index..buffer_index + 2] + &buffers[buffer_index..buffer_index + 2] .iter() - .map(|buf| read_buffer(buf, data)) - .collect(), + .map(|buf| read_buffer(buf, data, compression_codec)) + .collect::>>()?, ); node_index += 1; buffer_index += 2; @@ -96,8 +113,8 @@ fn create_array( let list_node = &nodes[node_index]; let list_buffers: Vec = buffers[buffer_index..buffer_index + 2] .iter() - .map(|buf| read_buffer(buf, data)) - .collect(); + .map(|buf| read_buffer(buf, data, compression_codec)) + .collect::>()?; node_index += 1; buffer_index += 2; let triple = create_array( @@ -108,6 +125,7 @@ fn create_array( dictionaries_by_id, node_index, buffer_index, + compression_codec, metadata, )?; node_index = triple.1; @@ -119,8 +137,8 @@ fn create_array( let list_node = &nodes[node_index]; let list_buffers: Vec = buffers[buffer_index..=buffer_index] .iter() - .map(|buf| read_buffer(buf, data)) - .collect(); + .map(|buf| read_buffer(buf, data, compression_codec)) + .collect::>()?; node_index += 1; buffer_index += 1; let triple = create_array( @@ -131,6 +149,7 @@ fn create_array( dictionaries_by_id, node_index, buffer_index, + compression_codec, metadata, )?; node_index = triple.1; @@ -140,7 +159,8 @@ fn create_array( } Struct(struct_fields) => { let struct_node = &nodes[node_index]; - let null_buffer: Buffer = read_buffer(&buffers[buffer_index], data); + let null_buffer: Buffer = + read_buffer(&buffers[buffer_index], data, compression_codec)?; node_index += 1; buffer_index += 1; @@ -157,6 +177,7 @@ fn create_array( dictionaries_by_id, node_index, buffer_index, + compression_codec, metadata, )?; node_index = triple.1; @@ -177,8 +198,8 @@ fn create_array( let index_node = &nodes[node_index]; let index_buffers: Vec = buffers[buffer_index..buffer_index + 2] .iter() - .map(|buf| read_buffer(buf, data)) - .collect(); + .map(|buf| read_buffer(buf, data, compression_codec)) + .collect::>()?; let dict_id = field.dict_id().ok_or_else(|| { ArrowError::IoError(format!("Field {} does not have dict id", field)) @@ -209,18 +230,20 @@ fn create_array( // In V4, union types has validity bitmap // In V5 and later, union types have no validity bitmap if metadata < &ipc::MetadataVersion::V5 { - read_buffer(&buffers[buffer_index], data); + read_buffer(&buffers[buffer_index], data, compression_codec)?; buffer_index += 1; } let type_ids: Buffer = - read_buffer(&buffers[buffer_index], data)[..len].into(); + read_buffer(&buffers[buffer_index], data, compression_codec)?[..len] + .into(); buffer_index += 1; let value_offsets = match mode { UnionMode::Dense => { - let buffer = 
read_buffer(&buffers[buffer_index], data); + let buffer = + read_buffer(&buffers[buffer_index], data, compression_codec)?; buffer_index += 1; Some(buffer[..len * 4].into()) } @@ -238,6 +261,7 @@ fn create_array( dictionaries_by_id, node_index, buffer_index, + compression_codec, metadata, )?; @@ -275,10 +299,10 @@ fn create_array( let array = create_primitive_array( &nodes[node_index], data_type, - buffers[buffer_index..buffer_index + 2] + &buffers[buffer_index..buffer_index + 2] .iter() - .map(|buf| read_buffer(buf, data)) - .collect(), + .map(|buf| read_buffer(buf, data, compression_codec)) + .collect::>>()?, ); node_index += 1; buffer_index += 2; @@ -292,16 +316,10 @@ fn create_array( /// This function should be called when doing projection in fn `read_record_batch`. /// The advancement logic references fn `create_array`. fn skip_field( - nodes: &[ipc::FieldNode], - field: &Field, - data: &[u8], - buffers: &[ipc::Buffer], - dictionaries_by_id: &HashMap, + data_type: &DataType, mut node_index: usize, mut buffer_index: usize, ) -> Result<(usize, usize)> { - use DataType::*; - let data_type = field.data_type(); match data_type { Utf8 | Binary | LargeBinary | LargeUtf8 => { node_index += 1; @@ -314,30 +332,14 @@ fn skip_field( List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => { node_index += 1; buffer_index += 2; - let tuple = skip_field( - nodes, - list_field, - data, - buffers, - dictionaries_by_id, - node_index, - buffer_index, - )?; + let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?; node_index = tuple.0; buffer_index = tuple.1; } FixedSizeList(ref list_field, _) => { node_index += 1; buffer_index += 1; - let tuple = skip_field( - nodes, - list_field, - data, - buffers, - dictionaries_by_id, - node_index, - buffer_index, - )?; + let tuple = skip_field(list_field.data_type(), node_index, buffer_index)?; node_index = tuple.0; buffer_index = tuple.1; } @@ -347,15 +349,8 @@ fn skip_field( // skip for each field for struct_field in struct_fields { - let tuple = skip_field( - nodes, - struct_field, - data, - buffers, - dictionaries_by_id, - node_index, - buffer_index, - )?; + let tuple = + skip_field(struct_field.data_type(), node_index, buffer_index)?; node_index = tuple.0; buffer_index = tuple.1; } @@ -376,15 +371,7 @@ fn skip_field( }; for field in fields { - let tuple = skip_field( - nodes, - field, - data, - buffers, - dictionaries_by_id, - node_index, - buffer_index, - )?; + let tuple = skip_field(field.data_type(), node_index, buffer_index)?; node_index = tuple.0; buffer_index = tuple.1; @@ -407,30 +394,28 @@ fn skip_field( fn create_primitive_array( field_node: &ipc::FieldNode, data_type: &DataType, - buffers: Vec, + buffers: &[Buffer], ) -> ArrayRef { let length = field_node.length() as usize; - let null_count = field_node.null_count() as usize; + let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); let array_data = match data_type { Utf8 | Binary | LargeBinary | LargeUtf8 => { - // read 3 buffers + // read 3 buffers: null buffer (optional), offsets buffer and data buffer ArrayData::builder(data_type.clone()) .len(length) .buffers(buffers[1..3].to_vec()) - .offset(0) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())) + .null_bit_buffer(null_buffer) .build() .unwrap() } FixedSizeBinary(_) => { - // read 3 buffers - let builder = ArrayData::builder(data_type.clone()) + // read 2 buffers: null buffer (optional) and data buffer + ArrayData::builder(data_type.clone()) .len(length) - 
.buffers(buffers[1..2].to_vec()) - .offset(0) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - unsafe { builder.build_unchecked() } + .add_buffer(buffers[1].clone()) + .null_bit_buffer(null_buffer) + .build() + .unwrap() } Int8 | Int16 @@ -443,49 +428,45 @@ fn create_primitive_array( | Interval(IntervalUnit::YearMonth) => { if buffers[1].len() / 8 == length && length != 1 { // interpret as a signed i64, and cast appropriately - let builder = ArrayData::builder(DataType::Int64) + let data = ArrayData::builder(DataType::Int64) .len(length) - .buffers(buffers[1..].to_vec()) - .offset(0) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - let data = unsafe { builder.build_unchecked() }; + .add_buffer(buffers[1].clone()) + .null_bit_buffer(null_buffer) + .build() + .unwrap(); let values = Arc::new(Int64Array::from(data)) as ArrayRef; // this cast is infallible, the unwrap is safe let casted = cast(&values, data_type).unwrap(); casted.into_data() } else { - let builder = ArrayData::builder(data_type.clone()) + ArrayData::builder(data_type.clone()) .len(length) - .buffers(buffers[1..].to_vec()) - .offset(0) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - unsafe { builder.build_unchecked() } + .add_buffer(buffers[1].clone()) + .null_bit_buffer(null_buffer) + .build() + .unwrap() } } Float32 => { if buffers[1].len() / 8 == length && length != 1 { // interpret as a f64, and cast appropriately - let builder = ArrayData::builder(DataType::Float64) + let data = ArrayData::builder(DataType::Float64) .len(length) - .buffers(buffers[1..].to_vec()) - .offset(0) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - let data = unsafe { builder.build_unchecked() }; + .add_buffer(buffers[1].clone()) + .null_bit_buffer(null_buffer) + .build() + .unwrap(); let values = Arc::new(Float64Array::from(data)) as ArrayRef; // this cast is infallible, the unwrap is safe let casted = cast(&values, data_type).unwrap(); casted.into_data() } else { - let builder = ArrayData::builder(data_type.clone()) + ArrayData::builder(data_type.clone()) .len(length) - .buffers(buffers[1..].to_vec()) - .offset(0) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - unsafe { builder.build_unchecked() } + .add_buffer(buffers[1].clone()) + .null_bit_buffer(null_buffer) + .build() + .unwrap() } } Boolean @@ -497,26 +478,26 @@ fn create_primitive_array( | Date64 | Duration(_) | Interval(IntervalUnit::DayTime) - | Interval(IntervalUnit::MonthDayNano) => { - let builder = ArrayData::builder(data_type.clone()) - .len(length) - .buffers(buffers[1..].to_vec()) - .offset(0) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - unsafe { builder.build_unchecked() } - } - Decimal(_, _) => { - // read 3 buffers + | Interval(IntervalUnit::MonthDayNano) => ArrayData::builder(data_type.clone()) + .len(length) + .add_buffer(buffers[1].clone()) + .null_bit_buffer(null_buffer) + .build() + .unwrap(), + Decimal128(_, _) | Decimal256(_, _) => { + // read 2 buffers: null buffer (optional) and data buffer let builder = ArrayData::builder(data_type.clone()) .len(length) - .buffers(buffers[1..2].to_vec()) - .offset(0) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + .add_buffer(buffers[1].clone()) + .null_bit_buffer(null_buffer); + // Don't validate the decimal array so far, + // becasue validating decimal is some what complicated + // and there is no conclusion on whether we should do it. 
+ // For more infomation, please look at https://github.com/apache/arrow-rs/issues/2387 unsafe { builder.build_unchecked() } } - t => panic!("Data type {:?} either unsupported or not primitive", t), + t => unreachable!("Data type {:?} either unsupported or not primitive", t), }; make_array(array_data) @@ -530,39 +511,24 @@ fn create_list_array( buffers: &[Buffer], child_array: ArrayRef, ) -> ArrayRef { - if let DataType::List(_) | DataType::LargeList(_) = *data_type { - let null_count = field_node.null_count() as usize; - let builder = ArrayData::builder(data_type.clone()) - .len(field_node.length() as usize) - .buffers(buffers[1..2].to_vec()) - .offset(0) - .child_data(vec![child_array.into_data()]) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - make_array(unsafe { builder.build_unchecked() }) - } else if let DataType::FixedSizeList(_, _) = *data_type { - let null_count = field_node.null_count() as usize; - let builder = ArrayData::builder(data_type.clone()) - .len(field_node.length() as usize) - .buffers(buffers[1..1].to_vec()) - .offset(0) - .child_data(vec![child_array.into_data()]) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - make_array(unsafe { builder.build_unchecked() }) - } else if let DataType::Map(_, _) = *data_type { - let null_count = field_node.null_count() as usize; - let builder = ArrayData::builder(data_type.clone()) - .len(field_node.length() as usize) - .buffers(buffers[1..2].to_vec()) - .offset(0) - .child_data(vec![child_array.into_data()]) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); - - make_array(unsafe { builder.build_unchecked() }) - } else { - panic!("Cannot create list or map array from {:?}", data_type) - } + let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); + let length = field_node.length() as usize; + let child_data = child_array.into_data(); + let builder = match data_type { + List(_) | LargeList(_) | Map(_, _) => ArrayData::builder(data_type.clone()) + .len(length) + .add_buffer(buffers[1].clone()) + .add_child_data(child_data) + .null_bit_buffer(null_buffer), + + FixedSizeList(_, _) => ArrayData::builder(data_type.clone()) + .len(length) + .add_child_data(child_data) + .null_bit_buffer(null_buffer), + + _ => unreachable!("Cannot create list or map array from {:?}", data_type), + }; + make_array(builder.build().unwrap()) } /// Reads the correct number of buffers based on list type and null_count, and creates a @@ -573,14 +539,13 @@ fn create_dictionary_array( buffers: &[Buffer], value_array: ArrayRef, ) -> ArrayRef { - if let DataType::Dictionary(_, _) = *data_type { - let null_count = field_node.null_count() as usize; + if let Dictionary(_, _) = *data_type { + let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone()); let builder = ArrayData::builder(data_type.clone()) .len(field_node.length() as usize) - .buffers(buffers[1..2].to_vec()) - .offset(0) - .child_data(vec![value_array.into_data()]) - .null_bit_buffer((null_count > 0).then(|| buffers[0].clone())); + .add_buffer(buffers[1].clone()) + .add_child_data(value_array.into_data()) + .null_bit_buffer(null_buffer); make_array(unsafe { builder.build_unchecked() }) } else { @@ -590,7 +555,7 @@ fn create_dictionary_array( /// Creates a record batch from binary data using the `ipc::RecordBatch` indexes and the `Schema` pub fn read_record_batch( - buf: &[u8], + buf: &Buffer, batch: ipc::RecordBatch, schema: SchemaRef, dictionaries_by_id: &HashMap, @@ -603,6 +568,11 @@ pub fn 
read_record_batch( let field_nodes = batch.nodes().ok_or_else(|| { ArrowError::IoError("Unable to get field nodes from IPC RecordBatch".to_string()) })?; + let batch_compression = batch.compression(); + let compression_codec: Option = batch_compression + .map(|batch_compression| batch_compression.codec().try_into()) + .transpose()?; + // keep track of buffer and node index, the functions that create arrays mutate these let mut buffer_index = 0; let mut node_index = 0; @@ -626,6 +596,7 @@ pub fn read_record_batch( dictionaries_by_id, node_index, buffer_index, + &compression_codec, metadata, )?; node_index = triple.1; @@ -634,15 +605,7 @@ pub fn read_record_batch( } else { // Skip field. // This must be called to advance `node_index` and `buffer_index`. - let tuple = skip_field( - field_nodes, - field, - buf, - buffers, - dictionaries_by_id, - node_index, - buffer_index, - )?; + let tuple = skip_field(field.data_type(), node_index, buffer_index)?; node_index = tuple.0; buffer_index = tuple.1; } @@ -664,6 +627,7 @@ pub fn read_record_batch( dictionaries_by_id, node_index, buffer_index, + &compression_codec, metadata, )?; node_index = triple.1; @@ -677,7 +641,7 @@ pub fn read_record_batch( /// Read the dictionary from the buffer and provided metadata, /// updating the `dictionaries_by_id` with the resulting dictionary pub fn read_dictionary( - buf: &[u8], + buf: &Buffer, batch: ipc::DictionaryBatch, schema: &Schema, dictionaries_by_id: &mut HashMap, @@ -761,6 +725,21 @@ pub struct FileReader { projection: Option<(Vec, Schema)>, } +impl fmt::Debug for FileReader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> { + f.debug_struct("FileReader") + .field("reader", &"BufReader<..>") + .field("schema", &self.schema) + .field("blocks", &self.blocks) + .field("current_block", &self.current_block) + .field("total_blocks", &self.total_blocks) + .field("dictionaries_by_id", &self.dictionaries_by_id) + .field("metadata_version", &self.metadata_version) + .field("projection", &self.projection) + .finish() + } +} + impl FileReader { /// Try to create a new file reader /// @@ -837,14 +816,15 @@ impl FileReader { let batch = message.header_as_dictionary_batch().unwrap(); // read the block that makes up the dictionary batch into a buffer - let mut buf = vec![0; block.bodyLength() as usize]; + let mut buf = + MutableBuffer::from_len_zeroed(message.bodyLength() as usize); reader.seek(SeekFrom::Start( block.offset() as u64 + block.metaDataLength() as u64, ))?; reader.read_exact(&mut buf)?; read_dictionary( - &buf, + &buf.into(), batch, &schema, &mut dictionaries_by_id, @@ -921,7 +901,6 @@ impl FileReader { let mut block_data = vec![0; meta_len as usize]; self.reader.read_exact(&mut block_data)?; - let message = ipc::root_as_message(&block_data[..]).map_err(|err| { ArrowError::IoError(format!("Unable to get root as footer: {:?}", err)) })?; @@ -946,14 +925,14 @@ impl FileReader { ) })?; // read the block that makes up the record batch into a buffer - let mut buf = vec![0; block.bodyLength() as usize]; + let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize); self.reader.seek(SeekFrom::Start( block.offset() as u64 + block.metaDataLength() as u64, ))?; self.reader.read_exact(&mut buf)?; read_record_batch( - &buf, + &buf.into(), batch, self.schema(), &self.dictionaries_by_id, @@ -1013,6 +992,18 @@ pub struct StreamReader { projection: Option<(Vec, Schema)>, } +impl fmt::Debug for StreamReader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> 
std::result::Result<(), fmt::Error> { + f.debug_struct("StreamReader") + .field("reader", &"BufReader<..>") + .field("schema", &self.schema) + .field("dictionaries_by_id", &self.dictionaries_by_id) + .field("finished", &self.finished) + .field("projection", &self.projection) + .finish() + } +} + impl StreamReader { /// Try to create a new stream reader /// @@ -1130,10 +1121,10 @@ impl StreamReader { ) })?; // read the block that makes up the record batch into a buffer - let mut buf = vec![0; message.bodyLength() as usize]; + let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize); self.reader.read_exact(&mut buf)?; - read_record_batch(&buf, batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version()).map(Some) + read_record_batch(&buf.into(), batch, self.schema(), &self.dictionaries_by_id, self.projection.as_ref().map(|x| x.0.as_ref()), &message.version()).map(Some) } ipc::MessageHeader::DictionaryBatch => { let batch = message.header_as_dictionary_batch().ok_or_else(|| { @@ -1142,11 +1133,11 @@ impl StreamReader { ) })?; // read the block that makes up the dictionary batch into a buffer - let mut buf = vec![0; message.bodyLength() as usize]; + let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize); self.reader.read_exact(&mut buf)?; read_dictionary( - &buf, batch, &self.schema, &mut self.dictionaries_by_id, &message.version() + &buf.into(), batch, &self.schema, &mut self.dictionaries_by_id, &message.version() )?; // read the next message until we encounter a RecordBatch @@ -1182,237 +1173,8 @@ mod tests { use std::fs::File; - use flate2::read::GzDecoder; - + use crate::datatypes; use crate::datatypes::{ArrowNativeType, Float64Type, Int32Type, Int8Type}; - use crate::{datatypes, util::integration_util::*}; - - #[test] - #[cfg(not(feature = "force_validate"))] - fn read_generated_files_014() { - let testdata = crate::util::test_util::arrow_test_data(); - let version = "0.14.1"; - // the test is repetitive, thus we can read all supported files at once - let paths = vec![ - "generated_interval", - "generated_datetime", - "generated_dictionary", - "generated_map", - "generated_nested", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - "generated_decimal", - ]; - paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", - testdata, version, path - )) - .unwrap(); - - let mut reader = FileReader::try_new(file, None).unwrap(); - - // read expected JSON output - let arrow_json = read_gzip_json(version, path); - assert!(arrow_json.equals_reader(&mut reader)); - }); - } - - #[test] - #[should_panic(expected = "Big Endian is not supported for Decimal!")] - fn read_decimal_be_file_should_panic() { - let testdata = crate::util::test_util::arrow_test_data(); - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/1.0.0-bigendian/generated_decimal.arrow_file", - testdata - )) - .unwrap(); - FileReader::try_new(file, None).unwrap(); - } - - #[test] - #[should_panic( - expected = "Last offset 687865856 of Utf8 is larger than values length 41" - )] - fn read_dictionary_be_not_implemented() { - // The offsets are not translated for big-endian files - // https://github.com/apache/arrow-rs/issues/859 - let testdata = crate::util::test_util::arrow_test_data(); - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/1.0.0-bigendian/generated_dictionary.arrow_file", - testdata - )) - 
.unwrap(); - FileReader::try_new(file, None).unwrap(); - } - - #[test] - fn read_generated_be_files_should_work() { - // complementary to the previous test - let testdata = crate::util::test_util::arrow_test_data(); - let paths = vec![ - "generated_interval", - "generated_datetime", - "generated_map", - "generated_nested", - "generated_null_trivial", - "generated_null", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - ]; - paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/1.0.0-bigendian/{}.arrow_file", - testdata, path - )) - .unwrap(); - - FileReader::try_new(file, None).unwrap(); - }); - } - - #[test] - fn projection_should_work() { - // complementary to the previous test - let testdata = crate::util::test_util::arrow_test_data(); - let paths = vec![ - "generated_interval", - "generated_datetime", - "generated_map", - "generated_nested", - "generated_null_trivial", - "generated_null", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - ]; - paths.iter().for_each(|path| { - // We must use littleendian files here. - // The offsets are not translated for big-endian files - // https://github.com/apache/arrow-rs/issues/859 - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/1.0.0-littleendian/{}.arrow_file", - testdata, path - )) - .unwrap(); - - let reader = FileReader::try_new(file, Some(vec![0])).unwrap(); - let datatype_0 = reader.schema().fields()[0].data_type().clone(); - reader.for_each(|batch| { - let batch = batch.unwrap(); - assert_eq!(batch.columns().len(), 1); - assert_eq!(datatype_0, batch.schema().fields()[0].data_type().clone()); - }); - }); - } - - #[test] - #[cfg(not(feature = "force_validate"))] - fn read_generated_streams_014() { - let testdata = crate::util::test_util::arrow_test_data(); - let version = "0.14.1"; - // the test is repetitive, thus we can read all supported files at once - let paths = vec![ - "generated_interval", - "generated_datetime", - "generated_dictionary", - "generated_map", - "generated_nested", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - "generated_decimal", - ]; - paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.stream", - testdata, version, path - )) - .unwrap(); - - let mut reader = StreamReader::try_new(file, None).unwrap(); - - // read expected JSON output - let arrow_json = read_gzip_json(version, path); - assert!(arrow_json.equals_reader(&mut reader)); - // the next batch must be empty - assert!(reader.next().is_none()); - // the stream must indicate that it's finished - assert!(reader.is_finished()); - }); - } - - #[test] - fn read_generated_files_100() { - let testdata = crate::util::test_util::arrow_test_data(); - let version = "1.0.0-littleendian"; - // the test is repetitive, thus we can read all supported files at once - let paths = vec![ - "generated_interval", - "generated_datetime", - "generated_dictionary", - "generated_map", - // "generated_map_non_canonical", - "generated_nested", - "generated_null_trivial", - "generated_null", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - ]; - paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", - testdata, version, path - )) - .unwrap(); - - let mut reader = FileReader::try_new(file, None).unwrap(); - - // read 
expected JSON output - let arrow_json = read_gzip_json(version, path); - assert!(arrow_json.equals_reader(&mut reader)); - }); - } - - #[test] - fn read_generated_streams_100() { - let testdata = crate::util::test_util::arrow_test_data(); - let version = "1.0.0-littleendian"; - // the test is repetitive, thus we can read all supported files at once - let paths = vec![ - "generated_interval", - "generated_datetime", - "generated_dictionary", - "generated_map", - // "generated_map_non_canonical", - "generated_nested", - "generated_null_trivial", - "generated_null", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - ]; - paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.stream", - testdata, version, path - )) - .unwrap(); - - let mut reader = StreamReader::try_new(file, None).unwrap(); - - // read expected JSON output - let arrow_json = read_gzip_json(version, path); - assert!(arrow_json.equals_reader(&mut reader)); - // the next batch must be empty - assert!(reader.next().is_none()); - // the stream must indicate that it's finished - assert!(reader.is_finished()); - }); - } fn create_test_projection_schema() -> Schema { // define field types @@ -1469,7 +1231,7 @@ mod tests { let array1 = StringArray::from(vec!["foo", "bar", "baz"]); let array2 = BooleanArray::from(vec![true, false, true]); - let mut union_builder = UnionBuilder::new_dense(3); + let mut union_builder = UnionBuilder::new_dense(); union_builder.append::("a", 1).unwrap(); union_builder.append::("b", 10.1).unwrap(); union_builder.append_null::("b").unwrap(); @@ -1718,28 +1480,12 @@ mod tests { #[test] fn test_roundtrip_dense_union() { - check_union_with_builder(UnionBuilder::new_dense(6)); + check_union_with_builder(UnionBuilder::new_dense()); } #[test] fn test_roundtrip_sparse_union() { - check_union_with_builder(UnionBuilder::new_sparse(6)); - } - - /// Read gzipped JSON file - fn read_gzip_json(version: &str, path: &str) -> ArrowJson { - let testdata = crate::util::test_util::arrow_test_data(); - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.json.gz", - testdata, version, path - )) - .unwrap(); - let mut gz = GzDecoder::new(&file); - let mut s = String::new(); - gz.read_to_string(&mut s).unwrap(); - // convert to Arrow JSON - let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); - arrow_json + check_union_with_builder(UnionBuilder::new_sparse()); } #[test] diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs index 3847661db5ec..63f1520a5e9c 100644 --- a/arrow/src/ipc/writer.rs +++ b/arrow/src/ipc/writer.rs @@ -39,6 +39,7 @@ use crate::ipc; use crate::record_batch::RecordBatch; use crate::util::bit_util; +use crate::ipc::compression::CompressionCodec; use ipc::CONTINUATION_MARKER; /// IPC write options used to control the behaviour of the writer @@ -58,9 +59,30 @@ pub struct IpcWriteOptions { /// version 2.0.0: V4, with legacy format enabled /// version 4.0.0: V5 metadata_version: ipc::MetadataVersion, + /// Compression, if desired. Only supported when `ipc_compression` + /// feature is enabled + batch_compression_type: Option, } impl IpcWriteOptions { + /// Configures compression when writing IPC files. Requires the + /// `ipc_compression` feature of the crate to be activated. 
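+ ///
+ /// A minimal usage sketch (mirrors the compression round-trip tests further
+ /// below; assumes the `ipc_compression` feature is enabled):
+ ///
+ /// ```ignore
+ /// let options = IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5)
+ ///     .unwrap()
+ ///     .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME))
+ ///     .unwrap();
+ /// ```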
+ #[cfg(feature = "ipc_compression")] + pub fn try_with_compression( + mut self, + batch_compression_type: Option, + ) -> Result { + self.batch_compression_type = batch_compression_type; + + if self.batch_compression_type.is_some() + && self.metadata_version < ipc::MetadataVersion::V5 + { + return Err(ArrowError::InvalidArgumentError( + "Compression only supported in metadata v5 and above".to_string(), + )); + } + Ok(self) + } /// Try create IpcWriteOptions, checking for incompatible settings pub fn try_new( alignment: usize, @@ -82,6 +104,7 @@ impl IpcWriteOptions { alignment, write_legacy_ipc_format, metadata_version, + batch_compression_type: None, }), ipc::MetadataVersion::V5 => { if write_legacy_ipc_format { @@ -94,10 +117,14 @@ impl IpcWriteOptions { alignment, write_legacy_ipc_format, metadata_version, + batch_compression_type: None, }) } } - z => panic!("Unsupported ipc::MetadataVersion {:?}", z), + z => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported ipc::MetadataVersion {:?}", + z + ))), } } } @@ -108,6 +135,7 @@ impl Default for IpcWriteOptions { alignment: 8, write_legacy_ipc_format: false, metadata_version: ipc::MetadataVersion::V5, + batch_compression_type: None, } } } @@ -226,7 +254,7 @@ impl IpcDataGenerator { } DataType::Union(fields, _, _) => { let union = as_union_array(column); - for (field, ref column) in fields + for (field, column) in fields .iter() .enumerate() .map(|(n, f)| (f, union.child(n as i8))) @@ -278,7 +306,7 @@ impl IpcDataGenerator { dict_id, dict_values, write_options, - )); + )?); } } _ => self._encode_dictionaries( @@ -312,7 +340,7 @@ impl IpcDataGenerator { )?; } - let encoded_message = self.record_batch_to_bytes(batch, write_options); + let encoded_message = self.record_batch_to_bytes(batch, write_options)?; Ok((encoded_dictionaries, encoded_message)) } @@ -322,13 +350,27 @@ impl IpcDataGenerator { &self, batch: &RecordBatch, write_options: &IpcWriteOptions, - ) -> EncodedData { + ) -> Result { let mut fbb = FlatBufferBuilder::new(); let mut nodes: Vec = vec![]; let mut buffers: Vec = vec![]; let mut arrow_data: Vec = vec![]; let mut offset = 0; + + // get the type of compression + let batch_compression_type = write_options.batch_compression_type; + + let compression = batch_compression_type.map(|batch_compression_type| { + let mut c = ipc::BodyCompressionBuilder::new(&mut fbb); + c.add_method(ipc::BodyCompressionMethod::BUFFER); + c.add_codec(batch_compression_type); + c.finish() + }); + + let compression_codec: Option = + batch_compression_type.map(TryInto::try_into).transpose()?; + for array in batch.columns() { let array_data = array.data(); offset = write_array_data( @@ -339,19 +381,26 @@ impl IpcDataGenerator { offset, array.len(), array.null_count(), + &compression_codec, write_options, - ); + )?; } + // pad the tail of body data + let len = arrow_data.len(); + let pad_len = pad_to_8(len as u32); + arrow_data.extend_from_slice(&vec![0u8; pad_len][..]); // write data let buffers = fbb.create_vector(&buffers); let nodes = fbb.create_vector(&nodes); - let root = { let mut batch_builder = ipc::RecordBatchBuilder::new(&mut fbb); batch_builder.add_length(batch.num_rows() as i64); batch_builder.add_nodes(nodes); batch_builder.add_buffers(buffers); + if let Some(c) = compression { + batch_builder.add_compression(c); + } let b = batch_builder.finish(); b.as_union_value() }; @@ -365,10 +414,10 @@ impl IpcDataGenerator { fbb.finish(root, None); let finished_data = fbb.finished_data(); - EncodedData { + Ok(EncodedData { ipc_message: 
finished_data.to_vec(), arrow_data, - } + }) } /// Write dictionary values into two sets of bytes, one for the header (ipc::Message) and the @@ -378,13 +427,27 @@ impl IpcDataGenerator { dict_id: i64, array_data: &ArrayData, write_options: &IpcWriteOptions, - ) -> EncodedData { + ) -> Result { let mut fbb = FlatBufferBuilder::new(); let mut nodes: Vec = vec![]; let mut buffers: Vec = vec![]; let mut arrow_data: Vec = vec![]; + // get the type of compression + let batch_compression_type = write_options.batch_compression_type; + + let compression = batch_compression_type.map(|batch_compression_type| { + let mut c = ipc::BodyCompressionBuilder::new(&mut fbb); + c.add_method(ipc::BodyCompressionMethod::BUFFER); + c.add_codec(batch_compression_type); + c.finish() + }); + + let compression_codec: Option = batch_compression_type + .map(|batch_compression_type| batch_compression_type.try_into()) + .transpose()?; + write_array_data( array_data, &mut buffers, @@ -393,8 +456,14 @@ impl IpcDataGenerator { 0, array_data.len(), array_data.null_count(), + &compression_codec, write_options, - ); + )?; + + // pad the tail of body data + let len = arrow_data.len(); + let pad_len = pad_to_8(len as u32); + arrow_data.extend_from_slice(&vec![0u8; pad_len][..]); // write data let buffers = fbb.create_vector(&buffers); @@ -405,6 +474,9 @@ impl IpcDataGenerator { batch_builder.add_length(array_data.len() as i64); batch_builder.add_nodes(nodes); batch_builder.add_buffers(buffers); + if let Some(c) = compression { + batch_builder.add_compression(c); + } batch_builder.finish() }; @@ -427,10 +499,10 @@ impl IpcDataGenerator { fbb.finish(root, None); let finished_data = fbb.finished_data(); - EncodedData { + Ok(EncodedData { ipc_message: finished_data.to_vec(), arrow_data, - } + }) } } @@ -519,9 +591,10 @@ impl FileWriter { ) -> Result { let data_gen = IpcDataGenerator::default(); let mut writer = BufWriter::new(writer); - // write magic to header + // write magic to header aligned on 8 byte boundary + let header_size = super::ARROW_MAGIC.len() + 2; + assert_eq!(header_size, 8); writer.write_all(&super::ARROW_MAGIC[..])?; - // create an 8-byte boundary after the header writer.write_all(&[0, 0])?; // write the schema, set the written bytes to the schema + header let encoded_message = data_gen.schema_to_bytes(schema, &write_options); @@ -530,7 +603,7 @@ impl FileWriter { writer, write_options, schema: schema.clone(), - block_offsets: meta + data + 8, + block_offsets: meta + data + header_size, dictionary_blocks: vec![], record_blocks: vec![], finished: false, @@ -884,6 +957,16 @@ fn get_buffer_element_width(spec: &BufferSpec) -> usize { } } +/// Returns byte width for binary value_offset buffer spec. +#[inline] +fn get_value_offset_byte_width(data_type: &DataType) -> usize { + match data_type { + DataType::Binary | DataType::Utf8 => 4, + DataType::LargeBinary | DataType::LargeUtf8 => 8, + _ => unreachable!(), + } +} + /// Returns the number of total bytes in base binary arrays. 
fn get_binary_buffer_len(array_data: &ArrayData) -> usize { if array_data.is_empty() { @@ -974,8 +1057,9 @@ fn write_array_data( offset: i64, num_rows: usize, null_count: usize, + compression_codec: &Option, write_options: &IpcWriteOptions, -) -> i64 { +) -> Result { let mut offset = offset; if !matches!(array_data.data_type(), DataType::Null) { nodes.push(ipc::FieldNode::new(num_rows as i64, null_count as i64)); @@ -997,7 +1081,13 @@ fn write_array_data( Some(buffer) => buffer.bit_slice(array_data.offset(), array_data.len()), }; - offset = write_buffer(null_buffer.as_slice(), buffers, arrow_data, offset); + offset = write_buffer( + null_buffer.as_slice(), + buffers, + arrow_data, + offset, + compression_codec, + )?; } let data_type = array_data.data_type(); @@ -1005,13 +1095,16 @@ fn write_array_data( data_type, DataType::Binary | DataType::LargeBinary | DataType::Utf8 | DataType::LargeUtf8 ) { - let total_bytes = get_binary_buffer_len(array_data); - let value_buffer = &array_data.buffers()[1]; + let offset_buffer = &array_data.buffers()[0]; + let value_offset_byte_width = get_value_offset_byte_width(data_type); + let min_length = (array_data.len() + 1) * value_offset_byte_width; if buffer_need_truncate( array_data.offset(), - value_buffer, - &BufferSpec::VariableWidth, - total_bytes, + offset_buffer, + &BufferSpec::FixedWidth { + byte_width: value_offset_byte_width, + }, + min_length, ) { // Rebase offsets and truncate values let (new_offsets, byte_offset) = @@ -1027,16 +1120,36 @@ fn write_array_data( ) }; - offset = write_buffer(new_offsets.as_slice(), buffers, arrow_data, offset); + offset = write_buffer( + new_offsets.as_slice(), + buffers, + arrow_data, + offset, + compression_codec, + )?; + let total_bytes = get_binary_buffer_len(array_data); + let value_buffer = &array_data.buffers()[1]; let buffer_length = min(total_bytes, value_buffer.len() - byte_offset); let buffer_slice = &value_buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]; - offset = write_buffer(buffer_slice, buffers, arrow_data, offset); + offset = write_buffer( + buffer_slice, + buffers, + arrow_data, + offset, + compression_codec, + )?; } else { - array_data.buffers().iter().for_each(|buffer| { - offset = write_buffer(buffer.as_slice(), buffers, arrow_data, offset); - }); + for buffer in array_data.buffers() { + offset = write_buffer( + buffer.as_slice(), + buffers, + arrow_data, + offset, + compression_codec, + )?; + } } } else if DataType::is_numeric(data_type) || DataType::is_temporal(data_type) @@ -1059,19 +1172,32 @@ fn write_array_data( let buffer_length = min(min_length, buffer.len() - byte_offset); let buffer_slice = &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]; - offset = write_buffer(buffer_slice, buffers, arrow_data, offset); + offset = write_buffer( + buffer_slice, + buffers, + arrow_data, + offset, + compression_codec, + )?; } else { - offset = write_buffer(buffer.as_slice(), buffers, arrow_data, offset); + offset = write_buffer( + buffer.as_slice(), + buffers, + arrow_data, + offset, + compression_codec, + )?; } } else { - array_data.buffers().iter().for_each(|buffer| { - offset = write_buffer(buffer, buffers, arrow_data, offset); - }); + for buffer in array_data.buffers() { + offset = + write_buffer(buffer, buffers, arrow_data, offset, compression_codec)?; + } } if !matches!(array_data.data_type(), DataType::Dictionary(_, _)) { // recursively write out nested structures - array_data.child_data().iter().for_each(|data_ref| { + for data_ref in array_data.child_data() { // 
write the nested data (e.g list data) offset = write_array_data( data_ref, @@ -1081,29 +1207,56 @@ fn write_array_data( offset, data_ref.len(), data_ref.null_count(), + compression_codec, write_options, - ); - }); + )?; + } } - offset + Ok(offset) } -/// Write a buffer to a vector of bytes, and add its ipc::Buffer to a vector +/// Write a buffer into `arrow_data`, a vector of bytes, and adds its +/// [`ipc::Buffer`] to `buffers`. Returns the new offset in `arrow_data` +/// +/// +/// From +/// Each constituent buffer is first compressed with the indicated +/// compressor, and then written with the uncompressed length in the first 8 +/// bytes as a 64-bit little-endian signed integer followed by the compressed +/// buffer bytes (and then padding as required by the protocol). The +/// uncompressed length may be set to -1 to indicate that the data that +/// follows is not compressed, which can be useful for cases where +/// compression does not yield appreciable savings. fn write_buffer( - buffer: &[u8], - buffers: &mut Vec, - arrow_data: &mut Vec, - offset: i64, -) -> i64 { - let len = buffer.len(); - let pad_len = pad_to_8(len as u32); - let total_len: i64 = (len + pad_len) as i64; - // assert_eq!(len % 8, 0, "Buffer width not a multiple of 8 bytes"); - buffers.push(ipc::Buffer::new(offset, total_len)); - arrow_data.extend_from_slice(buffer); - arrow_data.extend_from_slice(&vec![0u8; pad_len][..]); - offset + total_len + buffer: &[u8], // input + buffers: &mut Vec, // output buffer descriptors + arrow_data: &mut Vec, // output stream + offset: i64, // current output stream offset + compression_codec: &Option, +) -> Result { + let len: i64 = match compression_codec { + Some(compressor) => compressor.compress_to_vec(buffer, arrow_data)?, + None => { + arrow_data.extend_from_slice(buffer); + buffer.len() + } + } + .try_into() + .map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Could not convert compressed size to i64: {}", + e + )) + })?; + + // make new index entry + buffers.push(ipc::Buffer::new(offset, len)); + // padding and make offset 8 bytes aligned + let pad_len = pad_to_8(len as u32) as i64; + arrow_data.extend_from_slice(&vec![0u8; pad_len as usize][..]); + + Ok(offset + len + pad_len) } /// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes @@ -1117,16 +1270,170 @@ mod tests { use super::*; use std::fs::File; - use std::io::Read; + use std::io::Seek; use std::sync::Arc; - use flate2::read::GzDecoder; use ipc::MetadataVersion; use crate::array::*; use crate::datatypes::Field; use crate::ipc::reader::*; - use crate::util::integration_util::*; + + #[test] + #[cfg(feature = "ipc_compression")] + fn test_write_empty_record_batch_lz4_compression() { + let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]); + let values: Vec> = vec![]; + let array = Int32Array::from(values); + let record_batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + { + let write_option = + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME)) + .unwrap(); + + let mut writer = + FileWriter::try_new_with_options(&mut file, &schema, write_option) + .unwrap(); + writer.write(&record_batch).unwrap(); + writer.finish().unwrap(); + } + file.rewind().unwrap(); + { + // read file + let mut reader = FileReader::try_new(file, None).unwrap(); + loop { + match reader.next() { + 
Some(Ok(read_batch)) => { + read_batch + .columns() + .iter() + .zip(record_batch.columns()) + .for_each(|(a, b)| { + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.len(), b.len()); + assert_eq!(a.null_count(), b.null_count()); + }); + } + Some(Err(e)) => { + panic!("{}", e); + } + None => { + break; + } + } + } + } + } + + #[test] + #[cfg(feature = "ipc_compression")] + fn test_write_file_with_lz4_compression() { + let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]); + let values: Vec> = vec![Some(12), Some(1)]; + let array = Int32Array::from(values); + let record_batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + { + let write_option = + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME)) + .unwrap(); + + let mut writer = + FileWriter::try_new_with_options(&mut file, &schema, write_option) + .unwrap(); + writer.write(&record_batch).unwrap(); + writer.finish().unwrap(); + } + file.rewind().unwrap(); + { + // read file + let mut reader = FileReader::try_new(file, None).unwrap(); + loop { + match reader.next() { + Some(Ok(read_batch)) => { + read_batch + .columns() + .iter() + .zip(record_batch.columns()) + .for_each(|(a, b)| { + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.len(), b.len()); + assert_eq!(a.null_count(), b.null_count()); + }); + } + Some(Err(e)) => { + panic!("{}", e); + } + None => { + break; + } + } + } + } + } + + #[test] + #[cfg(feature = "ipc_compression")] + fn test_write_file_with_zstd_compression() { + let schema = Schema::new(vec![Field::new("field1", DataType::Int32, true)]); + let values: Vec> = vec![Some(12), Some(1)]; + let array = Int32Array::from(values); + let record_batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)]) + .unwrap(); + let mut file = tempfile::tempfile().unwrap(); + { + let write_option = + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::ZSTD)) + .unwrap(); + + let mut writer = + FileWriter::try_new_with_options(&mut file, &schema, write_option) + .unwrap(); + writer.write(&record_batch).unwrap(); + writer.finish().unwrap(); + } + file.rewind().unwrap(); + { + // read file + let mut reader = FileReader::try_new(file, None).unwrap(); + loop { + match reader.next() { + Some(Ok(read_batch)) => { + read_batch + .columns() + .iter() + .zip(record_batch.columns()) + .for_each(|(a, b)| { + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.len(), b.len()); + assert_eq!(a.null_count(), b.null_count()); + }); + } + Some(Err(e)) => { + panic!("{}", e); + } + None => { + break; + } + } + } + } + } #[test] fn test_write_file() { @@ -1148,18 +1455,16 @@ mod tests { vec![Arc::new(array1) as ArrayRef], ) .unwrap(); + let mut file = tempfile::tempfile().unwrap(); { - let file = File::create("target/debug/testdata/arrow.arrow_file").unwrap(); - let mut writer = FileWriter::try_new(file, &schema).unwrap(); + let mut writer = FileWriter::try_new(&mut file, &schema).unwrap(); writer.write(&batch).unwrap(); writer.finish().unwrap(); } + file.rewind().unwrap(); { - let file = - File::open(format!("target/debug/testdata/{}.arrow_file", "arrow")) - .unwrap(); let mut reader = FileReader::try_new(file, None).unwrap(); while let Some(Ok(read_batch)) = reader.next() { read_batch @@ -1255,251 +1560,6 @@ mod tests { ); } - #[test] - 
#[cfg(not(feature = "force_validate"))] - fn read_and_rewrite_generated_files_014() { - let testdata = crate::util::test_util::arrow_test_data(); - let version = "0.14.1"; - // the test is repetitive, thus we can read all supported files at once - let paths = vec![ - "generated_interval", - "generated_datetime", - "generated_dictionary", - "generated_map", - "generated_nested", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - "generated_decimal", - ]; - paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", - testdata, version, path - )) - .unwrap(); - - let mut reader = FileReader::try_new(file, None).unwrap(); - - // read and rewrite the file to a temp location - { - let file = File::create(format!( - "target/debug/testdata/{}-{}.arrow_file", - version, path - )) - .unwrap(); - let mut writer = FileWriter::try_new(file, &reader.schema()).unwrap(); - while let Some(Ok(batch)) = reader.next() { - writer.write(&batch).unwrap(); - } - writer.finish().unwrap(); - } - - let file = File::open(format!( - "target/debug/testdata/{}-{}.arrow_file", - version, path - )) - .unwrap(); - let mut reader = FileReader::try_new(file, None).unwrap(); - - // read expected JSON output - let arrow_json = read_gzip_json(version, path); - assert!(arrow_json.equals_reader(&mut reader)); - }); - } - - #[test] - #[cfg(not(feature = "force_validate"))] - fn read_and_rewrite_generated_streams_014() { - let testdata = crate::util::test_util::arrow_test_data(); - let version = "0.14.1"; - // the test is repetitive, thus we can read all supported files at once - let paths = vec![ - "generated_interval", - "generated_datetime", - "generated_dictionary", - "generated_map", - "generated_nested", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - "generated_decimal", - ]; - paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.stream", - testdata, version, path - )) - .unwrap(); - - let reader = StreamReader::try_new(file, None).unwrap(); - - // read and rewrite the stream to a temp location - { - let file = File::create(format!( - "target/debug/testdata/{}-{}.stream", - version, path - )) - .unwrap(); - let mut writer = StreamWriter::try_new(file, &reader.schema()).unwrap(); - reader.for_each(|batch| { - writer.write(&batch.unwrap()).unwrap(); - }); - writer.finish().unwrap(); - } - - let file = - File::open(format!("target/debug/testdata/{}-{}.stream", version, path)) - .unwrap(); - let mut reader = StreamReader::try_new(file, None).unwrap(); - - // read expected JSON output - let arrow_json = read_gzip_json(version, path); - assert!(arrow_json.equals_reader(&mut reader)); - }); - } - - #[test] - fn read_and_rewrite_generated_files_100() { - let testdata = crate::util::test_util::arrow_test_data(); - let version = "1.0.0-littleendian"; - // the test is repetitive, thus we can read all supported files at once - let paths = vec![ - "generated_custom_metadata", - "generated_datetime", - "generated_dictionary_unsigned", - "generated_dictionary", - // "generated_duplicate_fieldnames", - "generated_interval", - "generated_map", - "generated_nested", - // "generated_nested_large_offsets", - "generated_null_trivial", - "generated_null", - "generated_primitive_large_offsets", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - // "generated_recursive_nested", - ]; - 
paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", - testdata, version, path - )) - .unwrap(); - - let mut reader = FileReader::try_new(file, None).unwrap(); - - // read and rewrite the file to a temp location - { - let file = File::create(format!( - "target/debug/testdata/{}-{}.arrow_file", - version, path - )) - .unwrap(); - // write IPC version 5 - let options = - IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5).unwrap(); - let mut writer = - FileWriter::try_new_with_options(file, &reader.schema(), options) - .unwrap(); - while let Some(Ok(batch)) = reader.next() { - writer.write(&batch).unwrap(); - } - writer.finish().unwrap(); - } - - let file = File::open(format!( - "target/debug/testdata/{}-{}.arrow_file", - version, path - )) - .unwrap(); - let mut reader = FileReader::try_new(file, None).unwrap(); - - // read expected JSON output - let arrow_json = read_gzip_json(version, path); - assert!(arrow_json.equals_reader(&mut reader)); - }); - } - - #[test] - fn read_and_rewrite_generated_streams_100() { - let testdata = crate::util::test_util::arrow_test_data(); - let version = "1.0.0-littleendian"; - // the test is repetitive, thus we can read all supported files at once - let paths = vec![ - "generated_custom_metadata", - "generated_datetime", - "generated_dictionary_unsigned", - "generated_dictionary", - // "generated_duplicate_fieldnames", - "generated_interval", - "generated_map", - "generated_nested", - // "generated_nested_large_offsets", - "generated_null_trivial", - "generated_null", - "generated_primitive_large_offsets", - "generated_primitive_no_batches", - "generated_primitive_zerolength", - "generated_primitive", - // "generated_recursive_nested", - ]; - paths.iter().for_each(|path| { - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.stream", - testdata, version, path - )) - .unwrap(); - - let reader = StreamReader::try_new(file, None).unwrap(); - - // read and rewrite the stream to a temp location - { - let file = File::create(format!( - "target/debug/testdata/{}-{}.stream", - version, path - )) - .unwrap(); - let options = - IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5).unwrap(); - let mut writer = - StreamWriter::try_new_with_options(file, &reader.schema(), options) - .unwrap(); - reader.for_each(|batch| { - writer.write(&batch.unwrap()).unwrap(); - }); - writer.finish().unwrap(); - } - - let file = - File::open(format!("target/debug/testdata/{}-{}.stream", version, path)) - .unwrap(); - let mut reader = StreamReader::try_new(file, None).unwrap(); - - // read expected JSON output - let arrow_json = read_gzip_json(version, path); - assert!(arrow_json.equals_reader(&mut reader)); - }); - } - - /// Read gzipped JSON file - fn read_gzip_json(version: &str, path: &str) -> ArrowJson { - let testdata = crate::util::test_util::arrow_test_data(); - let file = File::open(format!( - "{}/arrow-ipc-stream/integration/{}/{}.json.gz", - testdata, version, path - )) - .unwrap(); - let mut gz = GzDecoder::new(&file); - let mut s = String::new(); - gz.read_to_string(&mut s).unwrap(); - // convert to Arrow JSON - let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); - arrow_json - } - #[test] fn track_union_nested_dict() { let inner: DictionaryArray = vec!["a", "b", "a"].into_iter().collect(); @@ -1567,7 +1627,6 @@ mod tests { #[test] fn read_union_017() { let testdata = crate::util::test_util::arrow_test_data(); - let version = "0.17.1"; let data_file = 
File::open(format!( "{}/arrow-ipc-stream/integration/0.17.1/generated_union.stream", testdata, @@ -1576,26 +1635,18 @@ mod tests { let reader = StreamReader::try_new(data_file, None).unwrap(); + let mut file = tempfile::tempfile().unwrap(); // read and rewrite the stream to a temp location { - let file = File::create(format!( - "target/debug/testdata/{}-generated_union.stream", - version - )) - .unwrap(); - let mut writer = StreamWriter::try_new(file, &reader.schema()).unwrap(); + let mut writer = StreamWriter::try_new(&mut file, &reader.schema()).unwrap(); reader.for_each(|batch| { writer.write(&batch.unwrap()).unwrap(); }); writer.finish().unwrap(); } + file.rewind().unwrap(); // Compare original file and rewrote file - let file = File::open(format!( - "target/debug/testdata/{}-generated_union.stream", - version - )) - .unwrap(); let rewrite_reader = StreamReader::try_new(file, None).unwrap(); let data_file = File::open(format!( @@ -1625,7 +1676,7 @@ mod tests { ), true, )]); - let mut builder = UnionBuilder::new_sparse(5); + let mut builder = UnionBuilder::with_capacity_sparse(5); builder.append::("a", 1).unwrap(); builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); @@ -1638,18 +1689,18 @@ mod tests { vec![Arc::new(union) as ArrayRef], ) .unwrap(); - let file_name = "target/debug/testdata/union.arrow_file"; + + let mut file = tempfile::tempfile().unwrap(); { - let file = File::create(&file_name).unwrap(); let mut writer = - FileWriter::try_new_with_options(file, &schema, options).unwrap(); + FileWriter::try_new_with_options(&mut file, &schema, options).unwrap(); writer.write(&batch).unwrap(); writer.finish().unwrap(); } + file.rewind().unwrap(); { - let file = File::open(&file_name).unwrap(); let reader = FileReader::try_new(file, None).unwrap(); reader.for_each(|maybe_batch| { maybe_batch @@ -1832,4 +1883,21 @@ mod tests { assert!(structs.column(1).is_null(1)); assert_eq!(record_batch_slice, deserialized_batch); } + + #[test] + fn truncate_ipc_string_array_with_all_empty_string() { + fn create_batch() -> RecordBatch { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let a = + StringArray::from(vec![Some(""), Some(""), Some(""), Some(""), Some("")]); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap() + } + + let record_batch = create_batch(); + let record_batch_slice = record_batch.slice(0, 1); + let deserialized_batch = deserialize(serialize(&record_batch_slice)); + + assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len()); + assert_eq!(record_batch_slice, deserialized_batch); + } } diff --git a/arrow/src/json/mod.rs b/arrow/src/json/mod.rs index 6b3df188a476..836145bb08e4 100644 --- a/arrow/src/json/mod.rs +++ b/arrow/src/json/mod.rs @@ -25,3 +25,58 @@ pub mod writer; pub use self::reader::Reader; pub use self::reader::ReaderBuilder; pub use self::writer::{ArrayWriter, LineDelimitedWriter, Writer}; +use half::f16; +use serde_json::{Number, Value}; + +/// Trait declaring any type that is serializable to JSON. This includes all primitive types (bool, i32, etc.). +pub trait JsonSerializable: 'static { + fn into_json_value(self) -> Option; +} + +macro_rules! 
json_serializable { + ($t:ty) => { + impl JsonSerializable for $t { + fn into_json_value(self) -> Option { + Some(self.into()) + } + } + }; +} + +json_serializable!(bool); +json_serializable!(u8); +json_serializable!(u16); +json_serializable!(u32); +json_serializable!(u64); +json_serializable!(i8); +json_serializable!(i16); +json_serializable!(i32); +json_serializable!(i64); + +impl JsonSerializable for i128 { + fn into_json_value(self) -> Option { + // Serialize as string to avoid issues with arbitrary_precision serde_json feature + // - https://github.com/serde-rs/json/issues/559 + // - https://github.com/serde-rs/json/issues/845 + // - https://github.com/serde-rs/json/issues/846 + Some(self.to_string().into()) + } +} + +impl JsonSerializable for f16 { + fn into_json_value(self) -> Option { + Number::from_f64(f64::round(f64::from(self) * 1000.0) / 1000.0).map(Value::Number) + } +} + +impl JsonSerializable for f32 { + fn into_json_value(self) -> Option { + Number::from_f64(f64::round(self as f64 * 1000.0) / 1000.0).map(Value::Number) + } +} + +impl JsonSerializable for f64 { + fn into_json_value(self) -> Option { + Number::from_f64(self).map(Value::Number) + } +} diff --git a/arrow/src/json/reader.rs b/arrow/src/json/reader.rs index 260d185dad18..fb8f6cfab477 100644 --- a/arrow/src/json/reader.rs +++ b/arrow/src/json/reader.rs @@ -58,7 +58,7 @@ use serde_json::{map::Map as JsonMap, Value}; use crate::buffer::MutableBuffer; use crate::datatypes::*; use crate::error::{ArrowError, Result}; -use crate::record_batch::RecordBatch; +use crate::record_batch::{RecordBatch, RecordBatchOptions}; use crate::util::bit_util; use crate::util::reader_parser::Parser; use crate::{array::*, buffer::Buffer}; @@ -590,7 +590,7 @@ pub struct Decoder { options: DecoderOptions, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] /// Options for JSON decoding pub struct DecoderOptions { /// Batch size (number of records to load each time), defaults to 1024 records @@ -698,22 +698,34 @@ impl Decoder { } let rows = &rows[..]; - let projection = self.options.projection.clone().unwrap_or_default(); - let arrays = self.build_struct_array(rows, self.schema.fields(), &projection); - let projected_fields: Vec = if projection.is_empty() { - self.schema.fields().to_vec() - } else { + let arrays = + self.build_struct_array(rows, self.schema.fields(), &self.options.projection); + + let projected_fields = if let Some(projection) = self.options.projection.as_ref() + { projection .iter() .filter_map(|name| self.schema.column_with_name(name)) .map(|(_, field)| field.clone()) .collect() + } else { + self.schema.fields().to_vec() }; let projected_schema = Arc::new(Schema::new(projected_fields)); - arrays.and_then(|arr| RecordBatch::try_new(projected_schema, arr).map(Some)) + arrays.and_then(|arr| { + RecordBatch::try_new_with_options( + projected_schema, + arr, + &RecordBatchOptions { + match_field_names: true, + row_count: Some(rows.len()), + }, + ) + .map(Some) + }) } fn build_wrapped_list_array( @@ -798,12 +810,13 @@ impl Decoder { { let mut builder: Box = match data_type { DataType::Utf8 => { - let values_builder = StringBuilder::new(rows.len() * 5); + let values_builder = + StringBuilder::with_capacity(rows.len(), rows.len() * 5); Box::new(ListBuilder::new(values_builder)) } DataType::Dictionary(_, _) => { let values_builder = - self.build_string_dictionary_builder::
(rows.len() * 5)?; + self.build_string_dictionary_builder::
(rows.len() * 5); Box::new(ListBuilder::new(values_builder)) } e => { @@ -855,14 +868,14 @@ impl Decoder { ))?; for val in vals { if let Some(v) = val { - builder.values().append_value(&v)? + builder.values().append_value(&v); } else { - builder.values().append_null()? + builder.values().append_null(); }; } // Append to the list - builder.append(true)?; + builder.append(true); } DataType::Dictionary(_, _) => { let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::JsonError( @@ -870,14 +883,14 @@ impl Decoder { ))?; for val in vals { if let Some(v) = val { - let _ = builder.values().append(&v)?; + let _ = builder.values().append(&v); } else { - builder.values().append_null()? + builder.values().append_null(); }; } // Append to the list - builder.append(true)?; + builder.append(true); } e => { return Err(ArrowError::JsonError(format!( @@ -897,13 +910,13 @@ impl Decoder { fn build_string_dictionary_builder( &self, row_len: usize, - ) -> Result> + ) -> StringDictionaryBuilder where T: ArrowPrimitiveType + ArrowDictionaryKeyType, { - let key_builder = PrimitiveBuilder::::new(row_len); - let values_builder = StringBuilder::new(row_len * 5); - Ok(StringDictionaryBuilder::new(key_builder, values_builder)) + let key_builder = PrimitiveBuilder::::with_capacity(row_len); + let values_builder = StringBuilder::with_capacity(row_len, row_len * 5); + StringDictionaryBuilder::new(key_builder, values_builder) } #[inline(always)] @@ -950,16 +963,16 @@ impl Decoder { } fn build_boolean_array(&self, rows: &[Value], col_name: &str) -> Result { - let mut builder = BooleanBuilder::new(rows.len()); + let mut builder = BooleanBuilder::with_capacity(rows.len()); for row in rows { if let Some(value) = row.get(&col_name) { if let Some(boolean) = value.as_bool() { - builder.append_value(boolean)? + builder.append_value(boolean); } else { - builder.append_null()?; + builder.append_null(); } } else { - builder.append_null()?; + builder.append_null(); } } Ok(Arc::new(builder.finish())) @@ -1137,7 +1150,7 @@ impl Decoder { }) .collect(); let arrays = - self.build_struct_array(rows.as_slice(), fields.as_slice(), &[])?; + self.build_struct_array(rows.as_slice(), fields.as_slice(), &None)?; let data_type = DataType::Struct(fields.clone()); let buf = null_buffer.into(); unsafe { @@ -1170,18 +1183,23 @@ impl Decoder { /// /// *Note*: The function is recursive, and will read nested structs. /// - /// If `projection` is not empty, then all values are returned. The first level of projection + /// If `projection` is &None, then all values are returned. The first level of projection /// occurs at the `RecordBatch` level. No further projection currently occurs, but would be /// useful if plucking values from a struct, e.g. getting `a.b.c.e` from `a.b.c.{d, e}`. 
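+ ///
+ /// A sketch of the intended semantics (field names here are hypothetical):
+ /// `projection = &None` builds every field in `struct_fields`, while
+ /// `projection = &Some(vec!["a".to_string()])` builds only the top-level
+ /// field named `a`; an empty `Some(vec![])` therefore yields no columns,
+ /// as exercised by `test_json_empty_projection` below.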
fn build_struct_array( &self, rows: &[Value], struct_fields: &[Field], - projection: &[String], + projection: &Option>, ) -> Result> { let arrays: Result> = struct_fields .iter() - .filter(|field| projection.is_empty() || projection.contains(field.name())) + .filter(|field| { + projection + .as_ref() + .map(|p| p.contains(field.name())) + .unwrap_or(true) + }) .map(|field| { match field.data_type() { DataType::Null => { @@ -1344,7 +1362,7 @@ impl Decoder { }) .collect::>(); let arrays = - self.build_struct_array(&struct_rows, fields, &[])?; + self.build_struct_array(&struct_rows, fields, &None)?; // construct a struct array's data in order to set null buffer let data_type = DataType::Struct(fields.clone()); let data = ArrayDataBuilder::new(data_type) @@ -1441,7 +1459,7 @@ impl Decoder { let struct_children = self.build_struct_array( struct_rows.as_slice(), &[key_field.clone(), value_field.clone()], - &[], + &None, )?; unsafe { @@ -1479,16 +1497,16 @@ impl Decoder { T: ArrowPrimitiveType + ArrowDictionaryKeyType, { let mut builder: StringDictionaryBuilder = - self.build_string_dictionary_builder(rows.len())?; + self.build_string_dictionary_builder(rows.len()); for row in rows { if let Some(value) = row.get(&col_name) { if let Some(str_v) = value.as_str() { builder.append(str_v).map(drop)? } else { - builder.append_null()? + builder.append_null(); } } else { - builder.append_null()? + builder.append_null(); } } Ok(Arc::new(builder.finish()) as ArrayRef) @@ -1805,6 +1823,21 @@ mod tests { assert_eq!("text", dd.value(8)); } + #[test] + fn test_json_empty_projection() { + let builder = ReaderBuilder::new() + .infer_schema(None) + .with_batch_size(64) + .with_projection(vec![]); + let mut reader: Reader = builder + .build::(File::open("test/data/basic.json").unwrap()) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + + assert_eq!(0, batch.num_columns()); + assert_eq!(12, batch.num_rows()); + } + #[test] fn test_json_basic_with_nulls() { let builder = ReaderBuilder::new().infer_schema(None).with_batch_size(64); @@ -2623,7 +2656,7 @@ mod tests { let re = builder.build(Cursor::new(json_content)); assert_eq!( re.err().unwrap().to_string(), - r#"Json error: Expected JSON record to be an object, found Array([Number(1), String("hello")])"#, + r#"Json error: Expected JSON record to be an object, found Array [Number(1), String("hello")]"#, ); } diff --git a/arrow/src/json/writer.rs b/arrow/src/json/writer.rs index 0755a5758e4e..bf40b31b494e 100644 --- a/arrow/src/json/writer.rs +++ b/arrow/src/json/writer.rs @@ -111,11 +111,14 @@ use serde_json::Value; use crate::array::*; use crate::datatypes::*; use crate::error::{ArrowError, Result}; +use crate::json::JsonSerializable; use crate::record_batch::RecordBatch; -fn primitive_array_to_json( - array: &ArrayRef, -) -> Result> { +fn primitive_array_to_json(array: &ArrayRef) -> Result> +where + T: ArrowPrimitiveType, + T::Native: JsonSerializable, +{ Ok(as_primitive_array::(array) .iter() .map(|maybe_value| match maybe_value { @@ -239,12 +242,15 @@ macro_rules! 
set_temporal_column_by_array_type { }; } -fn set_column_by_primitive_type( +fn set_column_by_primitive_type( rows: &mut [JsonMap], row_count: usize, array: &ArrayRef, col_name: &str, -) { +) where + T: ArrowPrimitiveType, + T::Native: JsonSerializable, +{ let primitive_arr = as_primitive_array::(array); rows.iter_mut() @@ -745,6 +751,21 @@ mod tests { use super::*; + /// Asserts that the NDJSON `input` is semantically identical to `expected` + fn assert_json_eq(input: &[u8], expected: &str) { + let expected: Vec> = expected + .split('\n') + .map(|s| (!s.is_empty()).then(|| serde_json::from_str(s).unwrap())) + .collect(); + + let actual: Vec> = input + .split(|b| *b == b'\n') + .map(|s| (!s.is_empty()).then(|| serde_json::from_slice(s).unwrap())) + .collect(); + + assert_eq!(expected, actual); + } + #[test] fn write_simple_rows() { let schema = Schema::new(vec![ @@ -765,14 +786,14 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"c1":1,"c2":"a"} {"c1":2,"c2":"b"} {"c1":3,"c2":"c"} {"c2":"d"} {"c1":5} -"# +"#, ); } @@ -796,14 +817,14 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"c1":"a","c2":"a"} {"c2":"b"} {"c1":"c"} {"c1":"d","c2":"d"} {} -"# +"#, ); } @@ -846,14 +867,14 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"c1":"cupcakes","c2":"sdsd"} {"c1":"foo","c2":"sdsd"} {"c1":"foo"} {"c2":"sd"} {"c1":"cupcakes","c2":"sdsd"} -"# +"#, ); } @@ -905,11 +926,11 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"nanos":"2018-11-13 17:11:10.011375885","micros":"2018-11-13 17:11:10.011375","millis":"2018-11-13 17:11:10.011","secs":"2018-11-13 17:11:10","name":"a"} {"name":"b"} -"# +"#, ); } @@ -951,11 +972,11 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"date32":"2018-11-13","date64":"2018-11-13","name":"a"} {"name":"b"} -"# +"#, ); } @@ -994,11 +1015,11 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"time32sec":"00:02:00","time32msec":"00:00:00.120","time64usec":"00:00:00.000120","time64nsec":"00:00:00.000000120","name":"a"} {"name":"b"} -"# +"#, ); } @@ -1037,11 +1058,11 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"duration_sec":"PT120S","duration_msec":"PT0.120S","duration_usec":"PT0.000120S","duration_nsec":"PT0.000000120S","name":"a"} {"name":"b"} -"# +"#, ); } @@ -1093,12 +1114,12 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"c1":{"c11":1,"c12":{"c121":"e"}},"c2":"a"} {"c1":{"c12":{"c121":"f"}},"c2":"b"} {"c1":{"c11":5,"c12":{"c121":"g"}},"c2":"c"} -"# +"#, ); } @@ -1136,14 +1157,14 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"c1":["a","a1"],"c2":1} {"c1":["b"],"c2":2} {"c1":["c"],"c2":3} {"c1":["d"],"c2":4} {"c1":["e"],"c2":5} -"# +"#, ); } @@ -1196,12 +1217,12 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + 
&buf, r#"{"c1":[[1,2],[3]],"c2":"foo"} {"c1":[],"c2":"bar"} {"c1":[[4,5,6]]} -"# +"#, ); } @@ -1271,12 +1292,12 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"c1":[{"c11":1,"c12":{"c121":"e"}},{"c12":{"c121":"f"}}],"c2":1} {"c2":2} {"c1":[{"c11":5,"c12":{"c121":"g"}}],"c2":3} -"# +"#, ); } @@ -1396,15 +1417,15 @@ mod tests { // that implementations differ on the treatment of a null struct. // It would be more accurate to return a null struct, so this can be done // as a follow up. - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"list":[{"ints":1}]} {"list":[{}]} {"list":[]} {} {"list":[{}]} {"list":[{}]} -"# +"#, ); } @@ -1455,15 +1476,15 @@ mod tests { writer.write_batches(&[batch]).unwrap(); } - assert_eq!( - String::from_utf8(buf).unwrap(), + assert_json_eq( + &buf, r#"{"map":{"foo":10}} {"map":null} {"map":{}} {"map":{"bar":20,"baz":30,"qux":40}} {"map":{"quux":50}} {"map":{}} -"# +"#, ); } diff --git a/arrow/src/lib.rs b/arrow/src/lib.rs index 95c69ca0be6d..d1fb0cae0da2 100644 --- a/arrow/src/lib.rs +++ b/arrow/src/lib.rs @@ -18,6 +18,9 @@ //! A complete, safe, native Rust implementation of [Apache Arrow](https://arrow.apache.org), a cross-language //! development platform for in-memory data. //! +//! Please see the [arrow crates.io](https://crates.io/crates/arrow) +//! page for feature flags and tips to improve performance. +//! //! # Columnar Format //! //! The [`array`] module provides statically typed implementations of all the array @@ -57,6 +60,23 @@ //! assert_eq!(sum(&TimestampNanosecondArray::from(vec![1, 2, 3])), 6); //! ``` //! +//! And the following is generic over all arrays with comparable values +//! +//! ```rust +//! # use arrow::array::{ArrayAccessor, ArrayIter, Int32Array, StringArray}; +//! # use arrow::datatypes::ArrowPrimitiveType; +//! # +//! fn min(array: T) -> Option +//! where +//! T::Item: Ord +//! { +//! ArrayIter::new(array).filter_map(|v| v).min() +//! } +//! +//! assert_eq!(min(&Int32Array::from(vec![4, 2, 1, 6])), Some(1)); +//! assert_eq!(min(&StringArray::from(vec!["b", "a", "c"])), Some("a")); +//! ``` +//! //! For more examples, consult the [`array`] docs. //! //! # Type Erasure / Trait Objects @@ -238,10 +258,13 @@ pub mod compute; pub mod csv; pub mod datatypes; pub mod error; +#[cfg(feature = "ffi")] pub mod ffi; +#[cfg(feature = "ffi")] pub mod ffi_stream; #[cfg(feature = "ipc")] pub mod ipc; +#[cfg(feature = "serde_json")] pub mod json; #[cfg(feature = "pyarrow")] pub mod pyarrow; diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 3ae5b3b9987f..89463e4c8fd3 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! This library demonstrates a minimal usage of Rust's C data interface to pass +//! This module demonstrates a minimal usage of Rust's C data interface to pass //! arrays from and to Python. 
use std::convert::{From, TryFrom}; diff --git a/arrow/src/temporal_conversions.rs b/arrow/src/temporal_conversions.rs index 2d6d6776f59e..14fa82f6e7dc 100644 --- a/arrow/src/temporal_conversions.rs +++ b/arrow/src/temporal_conversions.rs @@ -20,13 +20,18 @@ use chrono::{Duration, NaiveDateTime, NaiveTime}; /// Number of seconds in a day -const SECONDS_IN_DAY: i64 = 86_400; +pub(crate) const SECONDS_IN_DAY: i64 = 86_400; /// Number of milliseconds in a second -const MILLISECONDS: i64 = 1_000; +pub(crate) const MILLISECONDS: i64 = 1_000; /// Number of microseconds in a second -const MICROSECONDS: i64 = 1_000_000; +pub(crate) const MICROSECONDS: i64 = 1_000_000; /// Number of nanoseconds in a second -const NANOSECONDS: i64 = 1_000_000_000; +pub(crate) const NANOSECONDS: i64 = 1_000_000_000; + +/// Number of milliseconds in a day +pub(crate) const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; +/// Number of days between 0001-01-01 and 1970-01-01 +pub(crate) const EPOCH_DAYS_FROM_CE: i32 = 719_163; /// converts a `i32` representing a `date32` to [`NaiveDateTime`] #[inline] @@ -37,11 +42,13 @@ pub fn date32_to_datetime(v: i32) -> NaiveDateTime { /// converts a `i64` representing a `date64` to [`NaiveDateTime`] #[inline] pub fn date64_to_datetime(v: i64) -> NaiveDateTime { + let (sec, milli_sec) = split_second(v, MILLISECONDS); + NaiveDateTime::from_timestamp( // extract seconds from milliseconds - v / MILLISECONDS, + sec, // discard extracted seconds and convert milliseconds to nanoseconds - (v % MILLISECONDS * MICROSECONDS) as u32, + milli_sec * MICROSECONDS as u32, ) } @@ -96,36 +103,46 @@ pub fn timestamp_s_to_datetime(v: i64) -> NaiveDateTime { /// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] #[inline] pub fn timestamp_ms_to_datetime(v: i64) -> NaiveDateTime { + let (sec, milli_sec) = split_second(v, MILLISECONDS); + NaiveDateTime::from_timestamp( // extract seconds from milliseconds - v / MILLISECONDS, + sec, // discard extracted seconds and convert milliseconds to nanoseconds - (v % MILLISECONDS * MICROSECONDS) as u32, + milli_sec * MICROSECONDS as u32, ) } /// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] #[inline] pub fn timestamp_us_to_datetime(v: i64) -> NaiveDateTime { + let (sec, micro_sec) = split_second(v, MICROSECONDS); + NaiveDateTime::from_timestamp( // extract seconds from microseconds - v / MICROSECONDS, + sec, // discard extracted seconds and convert microseconds to nanoseconds - (v % MICROSECONDS * MILLISECONDS) as u32, + micro_sec * MILLISECONDS as u32, ) } /// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] #[inline] pub fn timestamp_ns_to_datetime(v: i64) -> NaiveDateTime { + let (sec, nano_sec) = split_second(v, NANOSECONDS); + NaiveDateTime::from_timestamp( // extract seconds from nanoseconds - v / NANOSECONDS, - // discard extracted seconds - (v % NANOSECONDS) as u32, + sec, // discard extracted seconds + nano_sec, ) } +#[inline] +pub(crate) fn split_second(v: i64, base: i64) -> (i64, u32) { + (v.div_euclid(base), v.rem_euclid(base) as u32) +} + /// converts a `i64` representing a `duration(s)` to [`Duration`] #[inline] pub fn duration_s_to_duration(v: i64) -> Duration { @@ -149,3 +166,83 @@ pub fn duration_us_to_duration(v: i64) -> Duration { pub fn duration_ns_to_duration(v: i64) -> Duration { Duration::nanoseconds(v) } + +#[cfg(test)] +mod tests { + use crate::temporal_conversions::{ + date64_to_datetime, split_second, timestamp_ms_to_datetime, + timestamp_ns_to_datetime, 
timestamp_us_to_datetime, NANOSECONDS, + }; + use chrono::NaiveDateTime; + + #[test] + fn negative_input_timestamp_ns_to_datetime() { + assert_eq!( + timestamp_ns_to_datetime(-1), + NaiveDateTime::from_timestamp(-1, 999_999_999) + ); + + assert_eq!( + timestamp_ns_to_datetime(-1_000_000_001), + NaiveDateTime::from_timestamp(-2, 999_999_999) + ); + } + + #[test] + fn negative_input_timestamp_us_to_datetime() { + assert_eq!( + timestamp_us_to_datetime(-1), + NaiveDateTime::from_timestamp(-1, 999_999_000) + ); + + assert_eq!( + timestamp_us_to_datetime(-1_000_001), + NaiveDateTime::from_timestamp(-2, 999_999_000) + ); + } + + #[test] + fn negative_input_timestamp_ms_to_datetime() { + assert_eq!( + timestamp_ms_to_datetime(-1), + NaiveDateTime::from_timestamp(-1, 999_000_000) + ); + + assert_eq!( + timestamp_ms_to_datetime(-1_001), + NaiveDateTime::from_timestamp(-2, 999_000_000) + ); + } + + #[test] + fn negative_input_date64_to_datetime() { + assert_eq!( + date64_to_datetime(-1), + NaiveDateTime::from_timestamp(-1, 999_000_000) + ); + + assert_eq!( + date64_to_datetime(-1_001), + NaiveDateTime::from_timestamp(-2, 999_000_000) + ); + } + + #[test] + fn test_split_seconds() { + let (sec, nano_sec) = split_second(100, NANOSECONDS); + assert_eq!(sec, 0); + assert_eq!(nano_sec, 100); + + let (sec, nano_sec) = split_second(123_000_000_456, NANOSECONDS); + assert_eq!(sec, 123); + assert_eq!(nano_sec, 456); + + let (sec, nano_sec) = split_second(-1, NANOSECONDS); + assert_eq!(sec, -1); + assert_eq!(nano_sec, 999_999_999); + + let (sec, nano_sec) = split_second(-123_000_000_001, NANOSECONDS); + assert_eq!(sec, -124); + assert_eq!(nano_sec, 999_999_999); + } +} diff --git a/arrow/src/util/data_gen.rs b/arrow/src/util/data_gen.rs index 21b8ee8c9fd1..4d974409a0ee 100644 --- a/arrow/src/util/data_gen.rs +++ b/arrow/src/util/data_gen.rs @@ -143,6 +143,17 @@ pub fn create_random_array( }) .collect::>>()?, )?), + d @ Dictionary(_, value_type) + if crate::compute::can_cast_types(value_type, d) => + { + let f = Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ); + let v = create_random_array(&f, size, null_density, true_density)?; + crate::compute::cast(&v, d)? + } other => { return Err(ArrowError::NotYetImplemented(format!( "Generating random arrays not yet implemented for {:?}", diff --git a/arrow/src/util/decimal.rs b/arrow/src/util/decimal.rs index 4d67245647d6..421942df5c1b 100644 --- a/arrow/src/util/decimal.rs +++ b/arrow/src/util/decimal.rs @@ -17,25 +17,84 @@ //! Decimal related utils +use crate::datatypes::{ + DataType, Decimal128Type, Decimal256Type, DecimalType, DECIMAL256_MAX_PRECISION, + DECIMAL_DEFAULT_SCALE, +}; use crate::error::{ArrowError, Result}; use num::bigint::BigInt; +use num::Signed; use std::cmp::{min, Ordering}; -pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { - /// The bit-width of the internal representation. 
- const BIT_WIDTH: usize; +/// [`Decimal`] is the generic representation of a single decimal value +/// +/// See [`Decimal128`] and [`Decimal256`] for the value types of [`Decimal128Array`] +/// and [`Decimal256Array`] respectively +/// +/// [`Decimal128Array`]: [crate::array::Decimal128Array] +/// [`Decimal256Array`]: [crate::array::Decimal256Array] +pub struct Decimal { + precision: u8, + scale: u8, + value: T::Native, +} + +/// Manually implement to avoid `T: Debug` bound +impl std::fmt::Debug for Decimal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Decimal") + .field("scale", &self.precision) + .field("precision", &self.precision) + // TODO: Could format this better + .field("value", &self.value.as_ref()) + .finish() + } +} + +/// Manually implement to avoid `T: Debug` bound +impl Clone for Decimal { + fn clone(&self) -> Self { + Self { + precision: self.precision, + scale: self.scale, + value: self.value, + } + } +} + +impl Copy for Decimal {} + +impl Decimal { + pub const MAX_PRECISION: u8 = T::MAX_PRECISION; + pub const MAX_SCALE: u8 = T::MAX_SCALE; + pub const TYPE_CONSTRUCTOR: fn(u8, u8) -> DataType = T::TYPE_CONSTRUCTOR; + pub const DEFAULT_TYPE: DataType = T::DEFAULT_TYPE; /// Tries to create a decimal value from precision, scale and bytes. - /// If the length of bytes isn't same as the bit width of this decimal, - /// returning an error. The bytes should be stored in little-endian order. + /// The bytes should be stored in little-endian order. /// /// Safety: /// This method doesn't validate if the decimal value represented by the bytes /// can be fitted into the specified precision. - fn try_new_from_bytes(precision: usize, scale: usize, bytes: &[u8]) -> Result + pub fn try_new_from_bytes(precision: u8, scale: u8, bytes: &T::Native) -> Result where Self: Sized, { + if precision > Self::MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "precision {} is greater than max {}", + precision, + Self::MAX_PRECISION + ))); + } + if scale > Self::MAX_SCALE { + return Err(ArrowError::InvalidArgumentError(format!( + "scale {} is greater than max {}", + scale, + Self::MAX_SCALE + ))); + } + if precision < scale { return Err(ArrowError::InvalidArgumentError(format!( "Precision {} is less than scale {}", @@ -43,72 +102,109 @@ pub trait BasicDecimal: PartialOrd + Ord + PartialEq + Eq { ))); } - if bytes.len() == Self::BIT_WIDTH / 8 { - Ok(Self::new(precision, scale, bytes)) - } else { - Err(ArrowError::InvalidArgumentError(format!( - "Input to Decimal{} must be {} bytes", - Self::BIT_WIDTH, - Self::BIT_WIDTH / 8 - ))) - } + Ok(Self::new(precision, scale, bytes)) } /// Creates a decimal value from precision, scale, and bytes. /// /// Safety: - /// This method doesn't check if the length of bytes is compatible with this decimal. + /// This method doesn't check if the precision and scale are valid. /// Use `try_new_from_bytes` for safe constructor. - fn new(precision: usize, scale: usize, bytes: &[u8]) -> Self; - + pub fn new(precision: u8, scale: u8, bytes: &T::Native) -> Self { + Self { + precision, + scale, + value: *bytes, + } + } /// Returns the raw bytes of the integer representation of the decimal. - fn raw_value(&self) -> &[u8]; + pub fn raw_value(&self) -> &T::Native { + &self.value + } /// Returns the precision of the decimal. - fn precision(&self) -> usize; + pub fn precision(&self) -> u8 { + self.precision + } /// Returns the scale of the decimal. 
- fn scale(&self) -> usize; + pub fn scale(&self) -> u8 { + self.scale + } /// Returns the string representation of the decimal. /// If the string representation cannot be fitted with the precision of the decimal, /// the string will be truncated. - fn to_string(&self) -> String { + #[allow(clippy::inherent_to_string)] + pub fn to_string(&self) -> String { let raw_bytes = self.raw_value(); - let integer = BigInt::from_signed_bytes_le(raw_bytes); + let integer = BigInt::from_signed_bytes_le(raw_bytes.as_ref()); let value_str = integer.to_string(); let (sign, rest) = value_str.split_at(if integer >= BigInt::from(0) { 0 } else { 1 }); - let bound = min(self.precision(), rest.len()) + sign.len(); + let bound = min(usize::from(self.precision()), rest.len()) + sign.len(); let value_str = &value_str[0..bound]; + let scale_usize = usize::from(self.scale()); if self.scale() == 0 { value_str.to_string() - } else if rest.len() > self.scale() { + } else if rest.len() > scale_usize { // Decimal separator is in the middle of the string - let (whole, decimal) = value_str.split_at(value_str.len() - self.scale()); + let (whole, decimal) = value_str.split_at(value_str.len() - scale_usize); format!("{}.{}", whole, decimal) } else { // String has to be padded - format!("{}0.{:0>width$}", sign, rest, width = self.scale()) + format!("{}0.{:0>width$}", sign, rest, width = scale_usize) } } } +impl PartialOrd for Decimal { + fn partial_cmp(&self, other: &Self) -> Option { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimals with different scale: {}, {}", + self.scale, other.scale + ); + Some(singed_cmp_le_bytes( + self.value.as_ref(), + other.value.as_ref(), + )) + } +} + +impl Ord for Decimal { + fn cmp(&self, other: &Self) -> Ordering { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimals with different scale: {}, {}", + self.scale, other.scale + ); + singed_cmp_le_bytes(self.value.as_ref(), other.value.as_ref()) + } +} + +impl PartialEq for Decimal { + fn eq(&self, other: &Self) -> bool { + assert_eq!( + self.scale, other.scale, + "Cannot compare two Decimals with different scale: {}, {}", + self.scale, other.scale + ); + self.value.as_ref().eq(other.value.as_ref()) + } +} + +impl Eq for Decimal {} + /// Represents a decimal value with precision and scale. /// The decimal value could represented by a signed 128-bit integer. -#[derive(Debug)] -pub struct Decimal128 { - #[allow(dead_code)] - precision: usize, - scale: usize, - value: [u8; 16], -} +pub type Decimal128 = Decimal; impl Decimal128 { /// Creates `Decimal128` from an `i128` value. #[allow(dead_code)] - pub(crate) fn new_from_i128(precision: usize, scale: usize, value: i128) -> Self { + pub(crate) fn new_from_i128(precision: u8, scale: u8, value: i128) -> Self { Decimal128 { precision, scale, @@ -130,83 +226,81 @@ impl From for i128 { /// Represents a decimal value with precision and scale. /// The decimal value could be represented by a signed 256-bit integer. -#[derive(Debug)] -pub struct Decimal256 { - #[allow(dead_code)] - precision: usize, - scale: usize, - value: [u8; 32], -} +pub type Decimal256 = Decimal; -macro_rules! def_decimal { - ($ty:ident, $bit:expr) => { - impl BasicDecimal for $ty { - const BIT_WIDTH: usize = $bit; - - fn new(precision: usize, scale: usize, bytes: &[u8]) -> Self { - $ty { - precision, - scale, - value: bytes.try_into().unwrap(), - } - } - - fn raw_value(&self) -> &[u8] { - &self.value - } +impl Decimal256 { + /// Constructs a `Decimal256` value from a `BigInt`. 
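The `PartialOrd`/`Ord`/`PartialEq` impls above deliberately assert that both operands share the same scale, so comparing decimals of mixed scales panics instead of silently comparing raw integers. An illustrative sketch under the same `[u8; 16]` assumption as above:

```rust
use arrow::util::decimal::Decimal128;

fn compare_same_scale() {
    // 1.23 < 4.56 once both are interpreted with scale 2.
    let a = Decimal128::try_new_from_bytes(5, 2, &123_i128.to_le_bytes()).unwrap();
    let b = Decimal128::try_new_from_bytes(5, 2, &456_i128.to_le_bytes()).unwrap();
    assert!(a < b);
    // Building one of them with a different scale and comparing would trip
    // the `assert_eq!(self.scale, other.scale, ...)` guard and panic.
}
```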
+ pub fn from_big_int(num: &BigInt, precision: u8, scale: u8) -> Result { + let mut bytes = if num.is_negative() { + [255_u8; 32] + } else { + [0; 32] + }; + let num_bytes = &num.to_signed_bytes_le(); + bytes[0..num_bytes.len()].clone_from_slice(num_bytes); + Decimal256::try_new_from_bytes(precision, scale, &bytes) + } - fn precision(&self) -> usize { - self.precision - } + /// Constructs a `BigInt` from this `Decimal256` value. + pub(crate) fn to_big_int(self) -> BigInt { + BigInt::from_signed_bytes_le(&self.value) + } +} - fn scale(&self) -> usize { - self.scale - } - } +impl From for Decimal256 { + fn from(bigint: BigInt) -> Self { + Decimal256::from_big_int(&bigint, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE) + .unwrap() + } +} - impl PartialOrd for $ty { - fn partial_cmp(&self, other: &Self) -> Option { - assert_eq!( - self.scale, other.scale, - "Cannot compare two Decimals with different scale: {}, {}", - self.scale, other.scale - ); - self.value.partial_cmp(&other.value) +// compare two signed integer which are encoded with little endian. +// left bytes and right bytes must have the same length. +#[inline] +pub(crate) fn singed_cmp_le_bytes(left: &[u8], right: &[u8]) -> Ordering { + assert_eq!( + left.len(), + right.len(), + "Can't compare bytes array with different len: {}, {}", + left.len(), + right.len() + ); + assert_ne!(left.len(), 0, "Can't compare bytes array of length 0"); + let len = left.len(); + // the sign bit is 1, the value is negative + let left_negative = left[len - 1] >= 0x80_u8; + let right_negative = right[len - 1] >= 0x80_u8; + if left_negative != right_negative { + return match left_negative { + true => { + // left is negative value + // right is positive value + Ordering::Less } - } - - impl Ord for $ty { - fn cmp(&self, other: &Self) -> Ordering { - assert_eq!( - self.scale, other.scale, - "Cannot compare two Decimals with different scale: {}, {}", - self.scale, other.scale - ); - self.value.cmp(&other.value) + false => Ordering::Greater, + }; + } + for i in 0..len { + let l_byte = left[len - 1 - i]; + let r_byte = right[len - 1 - i]; + match l_byte.cmp(&r_byte) { + Ordering::Less => { + return Ordering::Less; } - } - - impl PartialEq for $ty { - fn eq(&self, other: &Self) -> bool { - assert_eq!( - self.scale, other.scale, - "Cannot compare two Decimals with different scale: {}, {}", - self.scale, other.scale - ); - self.value.eq(&other.value) + Ordering::Greater => { + return Ordering::Greater; } + Ordering::Equal => {} } - - impl Eq for $ty {} - }; + } + Ordering::Equal } -def_decimal!(Decimal128, 128); -def_decimal!(Decimal256, 256); - #[cfg(test)] mod tests { - use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256}; + use super::*; + use num::{BigInt, Num}; + use rand::random; #[test] fn decimal_128_to_string() { @@ -257,9 +351,9 @@ mod tests { #[test] fn decimal_256_from_bytes() { - let mut bytes = vec![0; 32]; + let mut bytes = [0_u8; 32]; bytes[0..16].clone_from_slice(&100_i128.to_le_bytes()); - let value = Decimal256::try_new_from_bytes(5, 2, bytes.as_slice()).unwrap(); + let value = Decimal256::try_new_from_bytes(5, 2, &bytes).unwrap(); assert_eq!(value.to_string(), "1.00"); bytes[0..16].clone_from_slice(&i128::MAX.to_le_bytes()); @@ -279,15 +373,15 @@ mod tests { ); // smaller than i128 minimum - bytes = vec![255; 32]; + bytes = [255; 32]; bytes[31] = 128; - let value = Decimal256::try_new_from_bytes(79, 4, &bytes).unwrap(); + let value = Decimal256::try_new_from_bytes(76, 4, &bytes).unwrap(); assert_eq!( value.to_string(), - 
"-5744373177007483132341216834415376678658315645522012356644966081642565415.7313" + "-574437317700748313234121683441537667865831564552201235664496608164256541.5731" ); - bytes = vec![255; 32]; + bytes = [255; 32]; let value = Decimal256::try_new_from_bytes(5, 2, &bytes).unwrap(); assert_eq!(value.to_string(), "-0.01"); } @@ -302,4 +396,79 @@ mod tests { let integer = i128_func(value); assert_eq!(integer, 100); } + + #[test] + fn bigint_to_decimal256() { + let num = BigInt::from_str_radix("123456789", 10).unwrap(); + let value = Decimal256::from_big_int(&num, 30, 2).unwrap(); + assert_eq!(value.to_string(), "1234567.89"); + + let num = BigInt::from_str_radix("-5744373177007483132341216834415376678658315645522012356644966081642565415731", 10).unwrap(); + let value = Decimal256::from_big_int(&num, 76, 4).unwrap(); + assert_eq!(value.to_string(), "-574437317700748313234121683441537667865831564552201235664496608164256541.5731"); + } + + #[test] + fn test_lt_cmp_byte() { + for _i in 0..100 { + let left = random::(); + let right = random::(); + let result = singed_cmp_le_bytes( + left.to_le_bytes().as_slice(), + right.to_le_bytes().as_slice(), + ); + assert_eq!(left.cmp(&right), result); + } + for _i in 0..100 { + let left = random::(); + let right = random::(); + let result = singed_cmp_le_bytes( + left.to_le_bytes().as_slice(), + right.to_le_bytes().as_slice(), + ); + assert_eq!(left.cmp(&right), result); + } + } + + #[test] + fn compare_decimal128() { + let v1 = -100_i128; + let v2 = 10000_i128; + let right = Decimal128::new_from_i128(20, 3, v2); + for v in v1..v2 { + let left = Decimal128::new_from_i128(20, 3, v); + assert!(left < right); + } + + for _i in 0..100 { + let left = random::(); + let right = random::(); + let left_decimal = Decimal128::new_from_i128(38, 2, left); + let right_decimal = Decimal128::new_from_i128(38, 2, right); + assert_eq!(left < right, left_decimal < right_decimal); + assert_eq!(left == right, left_decimal == right_decimal) + } + } + + #[test] + fn compare_decimal256() { + let v1 = -100_i128; + let v2 = 10000_i128; + let right = Decimal256::from_big_int(&BigInt::from(v2), 75, 2).unwrap(); + for v in v1..v2 { + let left = Decimal256::from_big_int(&BigInt::from(v), 75, 2).unwrap(); + assert!(left < right); + } + + for _i in 0..100 { + let left = random::(); + let right = random::(); + let left_decimal = + Decimal256::from_big_int(&BigInt::from(left), 75, 2).unwrap(); + let right_decimal = + Decimal256::from_big_int(&BigInt::from(right), 75, 2).unwrap(); + assert_eq!(left < right, left_decimal < right_decimal); + assert_eq!(left == right, left_decimal == right_decimal) + } + } } diff --git a/arrow/src/util/display.rs b/arrow/src/util/display.rs index 7a7da8ccb0f5..aa4fd4200870 100644 --- a/arrow/src/util/display.rs +++ b/arrow/src/util/display.rs @@ -23,7 +23,6 @@ use std::fmt::Write; use std::sync::Arc; use crate::array::Array; -use crate::array::BasicDecimalArray; use crate::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, Int64Type, Int8Type, TimeUnit, UInt16Type, UInt32Type, UInt64Type, UInt8Type, @@ -256,7 +255,7 @@ macro_rules! 
make_string_from_fixed_size_list { pub fn make_string_from_decimal(column: &Arc, row: usize) -> Result { let array = column .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); let formatted_decimal = array.value_as_string(row); @@ -319,7 +318,7 @@ pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result make_string!(array::Float16Array, column, row), DataType::Float32 => make_string!(array::Float32Array, column, row), DataType::Float64 => make_string!(array::Float64Array, column, row), - DataType::Decimal(..) => make_string_from_decimal(column, row), + DataType::Decimal128(..) => make_string_from_decimal(column, row), DataType::Timestamp(unit, _) if *unit == TimeUnit::Second => { make_string_datetime!(array::TimestampSecondArray, column, row) } @@ -434,7 +433,7 @@ fn union_to_string( let name = fields.get(field_idx).unwrap().name(); let value = array_value_to_string( - &list.child(type_id), + list.child(type_id), match mode { UnionMode::Dense => list.value_offset(row) as usize, UnionMode::Sparse => row, diff --git a/arrow/src/util/integration_util.rs b/arrow/src/util/integration_util.rs deleted file mode 100644 index ee32f0c39902..000000000000 --- a/arrow/src/util/integration_util.rs +++ /dev/null @@ -1,1010 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Utils for JSON integration testing -//! -//! These utilities define structs that read the integration JSON format for integration testing purposes. - -use serde_derive::{Deserialize, Serialize}; -use serde_json::{Map as SJMap, Number as VNumber, Value}; - -use crate::array::*; -use crate::datatypes::*; -use crate::error::Result; -use crate::record_batch::{RecordBatch, RecordBatchReader}; - -/// A struct that represents an Arrow file with a schema and record batches -#[derive(Deserialize, Serialize, Debug)] -pub struct ArrowJson { - pub schema: ArrowJsonSchema, - pub batches: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub dictionaries: Option>, -} - -/// A struct that partially reads the Arrow JSON schema. 
-/// -/// Fields are left as JSON `Value` as they vary by `DataType` -#[derive(Deserialize, Serialize, Debug)] -pub struct ArrowJsonSchema { - pub fields: Vec, -} - -/// Fields are left as JSON `Value` as they vary by `DataType` -#[derive(Deserialize, Serialize, Debug)] -pub struct ArrowJsonField { - pub name: String, - #[serde(rename = "type")] - pub field_type: Value, - pub nullable: bool, - pub children: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub dictionary: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub metadata: Option, -} - -impl From<&Field> for ArrowJsonField { - fn from(field: &Field) -> Self { - let metadata_value = match field.metadata() { - Some(kv_list) => { - let mut array = Vec::new(); - for (k, v) in kv_list { - let mut kv_map = SJMap::new(); - kv_map.insert(k.clone(), Value::String(v.clone())); - array.push(Value::Object(kv_map)); - } - if !array.is_empty() { - Some(Value::Array(array)) - } else { - None - } - } - _ => None, - }; - - Self { - name: field.name().to_string(), - field_type: field.data_type().to_json(), - nullable: field.is_nullable(), - children: vec![], - dictionary: None, // TODO: not enough info - metadata: metadata_value, - } - } -} - -#[derive(Deserialize, Serialize, Debug)] -pub struct ArrowJsonFieldDictionary { - pub id: i64, - #[serde(rename = "indexType")] - pub index_type: DictionaryIndexType, - #[serde(rename = "isOrdered")] - pub is_ordered: bool, -} - -#[derive(Deserialize, Serialize, Debug)] -pub struct DictionaryIndexType { - pub name: String, - #[serde(rename = "isSigned")] - pub is_signed: bool, - #[serde(rename = "bitWidth")] - pub bit_width: i64, -} - -/// A struct that partially reads the Arrow JSON record batch -#[derive(Deserialize, Serialize, Debug)] -pub struct ArrowJsonBatch { - count: usize, - pub columns: Vec, -} - -/// A struct that partially reads the Arrow JSON dictionary batch -#[derive(Deserialize, Serialize, Debug)] -#[allow(non_snake_case)] -pub struct ArrowJsonDictionaryBatch { - pub id: i64, - pub data: ArrowJsonBatch, -} - -/// A struct that partially reads the Arrow JSON column/array -#[derive(Deserialize, Serialize, Clone, Debug)] -pub struct ArrowJsonColumn { - name: String, - pub count: usize, - #[serde(rename = "VALIDITY")] - pub validity: Option>, - #[serde(rename = "DATA")] - pub data: Option>, - #[serde(rename = "OFFSET")] - pub offset: Option>, // leaving as Value as 64-bit offsets are strings - #[serde(rename = "TYPE_ID")] - pub type_id: Option>, - pub children: Option>, -} - -impl ArrowJson { - /// Compare the Arrow JSON with a record batch reader - pub fn equals_reader(&self, reader: &mut dyn RecordBatchReader) -> bool { - if !self.schema.equals_schema(&reader.schema()) { - return false; - } - self.batches.iter().all(|col| { - let batch = reader.next(); - match batch { - Some(Ok(batch)) => col.equals_batch(&batch), - _ => false, - } - }) - } -} - -impl ArrowJsonSchema { - /// Compare the Arrow JSON schema with the Arrow `Schema` - fn equals_schema(&self, schema: &Schema) -> bool { - let field_len = self.fields.len(); - if field_len != schema.fields().len() { - return false; - } - for i in 0..field_len { - let json_field = &self.fields[i]; - let field = schema.field(i); - if !json_field.equals_field(field) { - return false; - } - } - true - } -} - -impl ArrowJsonField { - /// Compare the Arrow JSON field with the Arrow `Field` - fn equals_field(&self, field: &Field) -> bool { - // convert to a field - match self.to_arrow_field() { - Ok(self_field) => { - 
assert_eq!(&self_field, field, "Arrow fields not the same"); - true - } - Err(e) => { - eprintln!( - "Encountered error while converting JSON field to Arrow field: {:?}", - e - ); - false - } - } - } - - /// Convert to an Arrow Field - /// TODO: convert to use an Into - fn to_arrow_field(&self) -> Result { - // a bit regressive, but we have to convert the field to JSON in order to convert it - let field = serde_json::to_value(self)?; - Field::from(&field) - } -} - -impl ArrowJsonBatch { - /// Compare the Arrow JSON record batch with a `RecordBatch` - fn equals_batch(&self, batch: &RecordBatch) -> bool { - if self.count != batch.num_rows() { - return false; - } - let num_columns = self.columns.len(); - if num_columns != batch.num_columns() { - return false; - } - let schema = batch.schema(); - self.columns - .iter() - .zip(batch.columns()) - .zip(schema.fields()) - .all(|((col, arr), field)| { - // compare each column based on its type - if &col.name != field.name() { - return false; - } - let json_array: Vec = json_from_col(col, field.data_type()); - match field.data_type() { - DataType::Null => { - let arr: &NullArray = - arr.as_any().downcast_ref::().unwrap(); - // NullArrays should have the same length, json_array is empty - arr.len() == col.count - } - DataType::Boolean => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Int8 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Int16 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { - let arr = Int32Array::from(arr.data().clone()); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Int64 - | DataType::Date64 - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) => { - let arr = Int64Array::from(arr.data().clone()); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Interval(IntervalUnit::YearMonth) => { - let arr = IntervalYearMonthArray::from(arr.data().clone()); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Interval(IntervalUnit::DayTime) => { - let arr = IntervalDayTimeArray::from(arr.data().clone()); - let x = json_array - .iter() - .map(|v| { - match v { - Value::Null => Value::Null, - Value::Object(v) => { - // interval has days and milliseconds - let days: i32 = - v.get("days").unwrap().as_i64().unwrap() - as i32; - let milliseconds: i32 = v - .get("milliseconds") - .unwrap() - .as_i64() - .unwrap() - as i32; - let value: i64 = unsafe { - std::mem::transmute::<[i32; 2], i64>([ - days, - milliseconds, - ]) - }; - Value::Number(VNumber::from(value)) - } - // return null if Value is not an object - _ => Value::Null, - } - }) - .collect::>(); - arr.equals_json(&x.iter().collect::>()[..]) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let arr = IntervalMonthDayNanoArray::from(arr.data().clone()); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::UInt8 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::UInt16 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::UInt32 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - 
DataType::UInt64 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Float32 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Float64 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Binary => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::LargeBinary => { - let arr = - arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::FixedSizeBinary(_) => { - let arr = - arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Utf8 => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::LargeUtf8 => { - let arr = - arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::List(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::LargeList(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::FixedSizeList(_, _) => { - let arr = - arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Map(_, _) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Decimal(_, _) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - arr.equals_json(&json_array.iter().collect::>()[..]) - } - DataType::Dictionary(ref key_type, _) => match key_type.as_ref() { - DataType::Int8 => { - let arr = arr - .as_any() - .downcast_ref::() - .unwrap(); - arr.equals_json( - &json_array.iter().collect::>()[..], - ) - } - DataType::Int16 => { - let arr = arr - .as_any() - .downcast_ref::() - .unwrap(); - arr.equals_json( - &json_array.iter().collect::>()[..], - ) - } - DataType::Int32 => { - let arr = arr - .as_any() - .downcast_ref::() - .unwrap(); - arr.equals_json( - &json_array.iter().collect::>()[..], - ) - } - DataType::Int64 => { - let arr = arr - .as_any() - .downcast_ref::() - .unwrap(); - arr.equals_json( - &json_array.iter().collect::>()[..], - ) - } - DataType::UInt8 => { - let arr = arr - .as_any() - .downcast_ref::() - .unwrap(); - arr.equals_json( - &json_array.iter().collect::>()[..], - ) - } - DataType::UInt16 => { - let arr = arr - .as_any() - .downcast_ref::() - .unwrap(); - arr.equals_json( - &json_array.iter().collect::>()[..], - ) - } - DataType::UInt32 => { - let arr = arr - .as_any() - .downcast_ref::() - .unwrap(); - arr.equals_json( - &json_array.iter().collect::>()[..], - ) - } - DataType::UInt64 => { - let arr = arr - .as_any() - .downcast_ref::() - .unwrap(); - arr.equals_json( - &json_array.iter().collect::>()[..], - ) - } - t => panic!("Unsupported dictionary comparison for {:?}", t), - }, - t => panic!("Unsupported comparison for {:?}", t), - } - }) - } - - pub fn from_batch(batch: &RecordBatch) -> ArrowJsonBatch { - let mut json_batch = ArrowJsonBatch { - count: batch.num_rows(), - columns: Vec::with_capacity(batch.num_columns()), - }; - - for 
(col, field) in batch.columns().iter().zip(batch.schema().fields.iter()) { - let json_col = match field.data_type() { - DataType::Int8 => { - let col = col.as_any().downcast_ref::().unwrap(); - - let mut validity: Vec = Vec::with_capacity(col.len()); - let mut data: Vec = Vec::with_capacity(col.len()); - - for i in 0..col.len() { - if col.is_null(i) { - validity.push(1); - data.push(0i8.into()); - } else { - validity.push(0); - data.push(col.value(i).into()); - } - } - - ArrowJsonColumn { - name: field.name().clone(), - count: col.len(), - validity: Some(validity), - data: Some(data), - offset: None, - type_id: None, - children: None, - } - } - _ => ArrowJsonColumn { - name: field.name().clone(), - count: col.len(), - validity: None, - data: None, - offset: None, - type_id: None, - children: None, - }, - }; - - json_batch.columns.push(json_col); - } - - json_batch - } -} - -/// Convert an Arrow JSON column/array into a vector of `Value` -fn json_from_col(col: &ArrowJsonColumn, data_type: &DataType) -> Vec { - match data_type { - DataType::List(field) => json_from_list_col(col, field.data_type()), - DataType::FixedSizeList(field, list_size) => { - json_from_fixed_size_list_col(col, field.data_type(), *list_size as usize) - } - DataType::Struct(fields) => json_from_struct_col(col, fields), - DataType::Map(field, keys_sorted) => json_from_map_col(col, field, *keys_sorted), - DataType::Int64 - | DataType::UInt64 - | DataType::Date64 - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) => { - // convert int64 data from strings to numbers - let converted_col: Vec = col - .data - .clone() - .unwrap() - .iter() - .map(|v| { - Value::Number(match v { - Value::Number(number) => number.clone(), - Value::String(string) => VNumber::from( - string - .parse::() - .expect("Unable to parse string as i64"), - ), - t => panic!("Cannot convert {} to number", t), - }) - }) - .collect(); - merge_json_array( - col.validity.as_ref().unwrap().as_slice(), - converted_col.as_slice(), - ) - } - DataType::Null => vec![], - _ => merge_json_array( - col.validity.as_ref().unwrap().as_slice(), - &col.data.clone().unwrap(), - ), - } -} - -/// Merge VALIDITY and DATA vectors from a primitive data type into a `Value` vector with nulls -fn merge_json_array(validity: &[u8], data: &[Value]) -> Vec { - validity - .iter() - .zip(data) - .map(|(v, d)| match v { - 0 => Value::Null, - 1 => d.clone(), - _ => panic!("Validity data should be 0 or 1"), - }) - .collect() -} - -/// Convert an Arrow JSON column/array of a `DataType::Struct` into a vector of `Value` -fn json_from_struct_col(col: &ArrowJsonColumn, fields: &[Field]) -> Vec { - let mut values = Vec::with_capacity(col.count); - - let children: Vec> = col - .children - .clone() - .unwrap() - .iter() - .zip(fields) - .map(|(child, field)| json_from_col(child, field.data_type())) - .collect(); - - // create a struct from children - for j in 0..col.count { - let mut map = serde_json::map::Map::new(); - for i in 0..children.len() { - map.insert(fields[i].name().to_string(), children[i][j].clone()); - } - values.push(Value::Object(map)); - } - - values -} - -/// Convert an Arrow JSON column/array of a `DataType::List` into a vector of `Value` -fn json_from_list_col(col: &ArrowJsonColumn, data_type: &DataType) -> Vec { - let mut values = Vec::with_capacity(col.count); - - // get the inner array - let child = &col.children.clone().expect("list type must have children")[0]; - let offsets: Vec = col - .offset - .clone() - .unwrap() - .iter() - .map(|o| match o 
{ - Value::String(s) => s.parse::().unwrap(), - Value::Number(n) => n.as_u64().unwrap() as usize, - _ => panic!( - "Offsets should be numbers or strings that are convertible to numbers" - ), - }) - .collect(); - let inner = match data_type { - DataType::List(ref field) => json_from_col(child, field.data_type()), - DataType::Struct(fields) => json_from_struct_col(col, fields), - _ => merge_json_array( - child.validity.as_ref().unwrap().as_slice(), - &child.data.clone().unwrap(), - ), - }; - - for i in 0..col.count { - match &col.validity { - Some(validity) => match &validity[i] { - 0 => values.push(Value::Null), - 1 => { - values.push(Value::Array(inner[offsets[i]..offsets[i + 1]].to_vec())) - } - _ => panic!("Validity data should be 0 or 1"), - }, - None => { - // Null type does not have a validity vector - } - } - } - - values -} - -/// Convert an Arrow JSON column/array of a `DataType::List` into a vector of `Value` -fn json_from_fixed_size_list_col( - col: &ArrowJsonColumn, - data_type: &DataType, - list_size: usize, -) -> Vec { - let mut values = Vec::with_capacity(col.count); - - // get the inner array - let child = &col.children.clone().expect("list type must have children")[0]; - let inner = match data_type { - DataType::List(ref field) => json_from_col(child, field.data_type()), - DataType::FixedSizeList(ref field, _) => json_from_col(child, field.data_type()), - DataType::Struct(fields) => json_from_struct_col(col, fields), - _ => merge_json_array( - child.validity.as_ref().unwrap().as_slice(), - &child.data.clone().unwrap(), - ), - }; - - for i in 0..col.count { - match &col.validity { - Some(validity) => match &validity[i] { - 0 => values.push(Value::Null), - 1 => values.push(Value::Array( - inner[(list_size * i)..(list_size * (i + 1))].to_vec(), - )), - _ => panic!("Validity data should be 0 or 1"), - }, - None => {} - } - } - - values -} - -fn json_from_map_col( - col: &ArrowJsonColumn, - field: &Field, - _keys_sorted: bool, -) -> Vec { - let mut values = Vec::with_capacity(col.count); - - // get the inner array - let child = &col.children.clone().expect("list type must have children")[0]; - let offsets: Vec = col - .offset - .clone() - .unwrap() - .iter() - .map(|o| match o { - Value::String(s) => s.parse::().unwrap(), - Value::Number(n) => n.as_u64().unwrap() as usize, - _ => panic!( - "Offsets should be numbers or strings that are convertible to numbers" - ), - }) - .collect(); - - let inner = match field.data_type() { - DataType::Struct(fields) => json_from_struct_col(child, fields), - _ => panic!("Map child must be Struct"), - }; - - for i in 0..col.count { - match &col.validity { - Some(validity) => match &validity[i] { - 0 => values.push(Value::Null), - 1 => { - values.push(Value::Array(inner[offsets[i]..offsets[i + 1]].to_vec())) - } - _ => panic!("Validity data should be 0 or 1"), - }, - None => { - // Null type does not have a validity vector - } - } - } - - values -} -#[cfg(test)] -mod tests { - use super::*; - - use std::fs::File; - use std::io::Read; - use std::sync::Arc; - - use crate::buffer::Buffer; - - #[test] - fn test_schema_equality() { - let json = r#" - { - "fields": [ - { - "name": "c1", - "type": {"name": "int", "isSigned": true, "bitWidth": 32}, - "nullable": true, - "children": [] - }, - { - "name": "c2", - "type": {"name": "floatingpoint", "precision": "DOUBLE"}, - "nullable": true, - "children": [] - }, - { - "name": "c3", - "type": {"name": "utf8"}, - "nullable": true, - "children": [] - }, - { - "name": "c4", - "type": { - "name": "list" - }, - 
"nullable": true, - "children": [ - { - "name": "custom_item", - "type": { - "name": "int", - "isSigned": true, - "bitWidth": 32 - }, - "nullable": false, - "children": [] - } - ] - } - ] - }"#; - let json_schema: ArrowJsonSchema = serde_json::from_str(json).unwrap(); - let schema = Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Float64, true), - Field::new("c3", DataType::Utf8, true), - Field::new( - "c4", - DataType::List(Box::new(Field::new( - "custom_item", - DataType::Int32, - false, - ))), - true, - ), - ]); - assert!(json_schema.equals_schema(&schema)); - } - - #[test] - #[cfg_attr(miri, ignore)] // running forever - fn test_arrow_data_equality() { - let secs_tz = Some("Europe/Budapest".to_string()); - let millis_tz = Some("America/New_York".to_string()); - let micros_tz = Some("UTC".to_string()); - let nanos_tz = Some("Africa/Johannesburg".to_string()); - - let schema = - Schema::new(vec![ - Field::new("bools-with-metadata-map", DataType::Boolean, true) - .with_metadata(Some( - [("k".to_string(), "v".to_string())] - .iter() - .cloned() - .collect(), - )), - Field::new("bools-with-metadata-vec", DataType::Boolean, true) - .with_metadata(Some( - [("k2".to_string(), "v2".to_string())] - .iter() - .cloned() - .collect(), - )), - Field::new("bools", DataType::Boolean, true), - Field::new("int8s", DataType::Int8, true), - Field::new("int16s", DataType::Int16, true), - Field::new("int32s", DataType::Int32, true), - Field::new("int64s", DataType::Int64, true), - Field::new("uint8s", DataType::UInt8, true), - Field::new("uint16s", DataType::UInt16, true), - Field::new("uint32s", DataType::UInt32, true), - Field::new("uint64s", DataType::UInt64, true), - Field::new("float32s", DataType::Float32, true), - Field::new("float64s", DataType::Float64, true), - Field::new("date_days", DataType::Date32, true), - Field::new("date_millis", DataType::Date64, true), - Field::new("time_secs", DataType::Time32(TimeUnit::Second), true), - Field::new("time_millis", DataType::Time32(TimeUnit::Millisecond), true), - Field::new("time_micros", DataType::Time64(TimeUnit::Microsecond), true), - Field::new("time_nanos", DataType::Time64(TimeUnit::Nanosecond), true), - Field::new("ts_secs", DataType::Timestamp(TimeUnit::Second, None), true), - Field::new( - "ts_millis", - DataType::Timestamp(TimeUnit::Millisecond, None), - true, - ), - Field::new( - "ts_micros", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - Field::new( - "ts_nanos", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - Field::new( - "ts_secs_tz", - DataType::Timestamp(TimeUnit::Second, secs_tz.clone()), - true, - ), - Field::new( - "ts_millis_tz", - DataType::Timestamp(TimeUnit::Millisecond, millis_tz.clone()), - true, - ), - Field::new( - "ts_micros_tz", - DataType::Timestamp(TimeUnit::Microsecond, micros_tz.clone()), - true, - ), - Field::new( - "ts_nanos_tz", - DataType::Timestamp(TimeUnit::Nanosecond, nanos_tz.clone()), - true, - ), - Field::new("utf8s", DataType::Utf8, true), - Field::new( - "lists", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - ), - Field::new( - "structs", - DataType::Struct(vec![ - Field::new("int32s", DataType::Int32, true), - Field::new("utf8s", DataType::Utf8, true), - ]), - true, - ), - ]); - - let bools_with_metadata_map = - BooleanArray::from(vec![Some(true), None, Some(false)]); - let bools_with_metadata_vec = - BooleanArray::from(vec![Some(true), None, Some(false)]); - let bools = 
BooleanArray::from(vec![Some(true), None, Some(false)]); - let int8s = Int8Array::from(vec![Some(1), None, Some(3)]); - let int16s = Int16Array::from(vec![Some(1), None, Some(3)]); - let int32s = Int32Array::from(vec![Some(1), None, Some(3)]); - let int64s = Int64Array::from(vec![Some(1), None, Some(3)]); - let uint8s = UInt8Array::from(vec![Some(1), None, Some(3)]); - let uint16s = UInt16Array::from(vec![Some(1), None, Some(3)]); - let uint32s = UInt32Array::from(vec![Some(1), None, Some(3)]); - let uint64s = UInt64Array::from(vec![Some(1), None, Some(3)]); - let float32s = Float32Array::from(vec![Some(1.0), None, Some(3.0)]); - let float64s = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); - let date_days = Date32Array::from(vec![Some(1196848), None, None]); - let date_millis = Date64Array::from(vec![ - Some(167903550396207), - Some(29923997007884), - Some(30612271819236), - ]); - let time_secs = - Time32SecondArray::from(vec![Some(27974), Some(78592), Some(43207)]); - let time_millis = Time32MillisecondArray::from(vec![ - Some(6613125), - Some(74667230), - Some(52260079), - ]); - let time_micros = - Time64MicrosecondArray::from(vec![Some(62522958593), None, None]); - let time_nanos = Time64NanosecondArray::from(vec![ - Some(73380123595985), - None, - Some(16584393546415), - ]); - let ts_secs = TimestampSecondArray::from_opt_vec( - vec![None, Some(193438817552), None], - None, - ); - let ts_millis = TimestampMillisecondArray::from_opt_vec( - vec![None, Some(38606916383008), Some(58113709376587)], - None, - ); - let ts_micros = - TimestampMicrosecondArray::from_opt_vec(vec![None, None, None], None); - let ts_nanos = TimestampNanosecondArray::from_opt_vec( - vec![None, None, Some(-6473623571954960143)], - None, - ); - let ts_secs_tz = TimestampSecondArray::from_opt_vec( - vec![None, Some(193438817552), None], - secs_tz, - ); - let ts_millis_tz = TimestampMillisecondArray::from_opt_vec( - vec![None, Some(38606916383008), Some(58113709376587)], - millis_tz, - ); - let ts_micros_tz = - TimestampMicrosecondArray::from_opt_vec(vec![None, None, None], micros_tz); - let ts_nanos_tz = TimestampNanosecondArray::from_opt_vec( - vec![None, None, Some(-6473623571954960143)], - nanos_tz, - ); - let utf8s = StringArray::from(vec![Some("aa"), None, Some("bbb")]); - - let value_data = Int32Array::from(vec![None, Some(2), None, None]); - let value_offsets = Buffer::from_slice_ref(&[0, 3, 4, 4]); - let list_data_type = - DataType::List(Box::new(Field::new("item", DataType::Int32, true))); - let list_data = ArrayData::builder(list_data_type) - .len(3) - .add_buffer(value_offsets) - .add_child_data(value_data.into_data()) - .build() - .unwrap(); - let lists = ListArray::from(list_data); - - let structs_int32s = Int32Array::from(vec![None, Some(-2), None]); - let structs_utf8s = StringArray::from(vec![None, None, Some("aaaaaa")]); - let structs = StructArray::from(vec![ - ( - Field::new("int32s", DataType::Int32, true), - Arc::new(structs_int32s) as ArrayRef, - ), - ( - Field::new("utf8s", DataType::Utf8, true), - Arc::new(structs_utf8s) as ArrayRef, - ), - ]); - - let record_batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![ - Arc::new(bools_with_metadata_map), - Arc::new(bools_with_metadata_vec), - Arc::new(bools), - Arc::new(int8s), - Arc::new(int16s), - Arc::new(int32s), - Arc::new(int64s), - Arc::new(uint8s), - Arc::new(uint16s), - Arc::new(uint32s), - Arc::new(uint64s), - Arc::new(float32s), - Arc::new(float64s), - Arc::new(date_days), - Arc::new(date_millis), - Arc::new(time_secs), 
- Arc::new(time_millis), - Arc::new(time_micros), - Arc::new(time_nanos), - Arc::new(ts_secs), - Arc::new(ts_millis), - Arc::new(ts_micros), - Arc::new(ts_nanos), - Arc::new(ts_secs_tz), - Arc::new(ts_millis_tz), - Arc::new(ts_micros_tz), - Arc::new(ts_nanos_tz), - Arc::new(utf8s), - Arc::new(lists), - Arc::new(structs), - ], - ) - .unwrap(); - let mut file = File::open("test/data/integration.json").unwrap(); - let mut json = String::new(); - file.read_to_string(&mut json).unwrap(); - let arrow_json: ArrowJson = serde_json::from_str(&json).unwrap(); - // test schemas - assert!(arrow_json.schema.equals_schema(&schema)); - // test record batch - assert!(arrow_json.batches[0].equals_batch(&record_batch)); - } -} diff --git a/arrow/src/util/mod.rs b/arrow/src/util/mod.rs index 86253da8d777..6f68398e7703 100644 --- a/arrow/src/util/mod.rs +++ b/arrow/src/util/mod.rs @@ -24,13 +24,11 @@ pub mod bit_util; #[cfg(feature = "test_utils")] pub mod data_gen; pub mod display; -#[cfg(feature = "test_utils")] -pub mod integration_util; #[cfg(feature = "prettyprint")] pub mod pretty; pub(crate) mod serialization; pub mod string_writer; -#[cfg(feature = "test_utils")] +#[cfg(any(test, feature = "test_utils"))] pub mod test_util; mod trusted_len; diff --git a/arrow/src/util/pretty.rs b/arrow/src/util/pretty.rs index 124de6127ddd..b0013619b50c 100644 --- a/arrow/src/util/pretty.rs +++ b/arrow/src/util/pretty.rs @@ -19,9 +19,8 @@ //! available unless `feature = "prettyprint"` is enabled. use crate::{array::ArrayRef, record_batch::RecordBatch}; -use std::fmt::Display; - use comfy_table::{Cell, Table}; +use std::fmt::Display; use crate::error::Result; @@ -120,7 +119,7 @@ mod tests { }; use super::*; - use crate::array::{DecimalArray, FixedSizeListBuilder}; + use crate::array::{Decimal128Array, FixedSizeListBuilder}; use std::fmt::Write; use std::sync::Arc; @@ -242,12 +241,12 @@ mod tests { DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); - let keys_builder = PrimitiveBuilder::::new(10); - let values_builder = StringBuilder::new(10); + let keys_builder = PrimitiveBuilder::::with_capacity(10); + let values_builder = StringBuilder::new(); let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); builder.append("one")?; - builder.append_null()?; + builder.append_null(); builder.append("three")?; let array = Arc::new(builder.finish()); @@ -284,12 +283,12 @@ mod tests { let keys_builder = Int32Array::builder(3); let mut builder = FixedSizeListBuilder::new(keys_builder, 3); - builder.values().append_slice(&[1, 2, 3]).unwrap(); - builder.append(true).unwrap(); - builder.values().append_slice(&[4, 5, 6]).unwrap(); - builder.append(false).unwrap(); - builder.values().append_slice(&[7, 8, 9]).unwrap(); - builder.append(true).unwrap(); + builder.values().append_slice(&[1, 2, 3]); + builder.append(true); + builder.values().append_slice(&[4, 5, 6]); + builder.append(false); + builder.values().append_slice(&[7, 8, 9]); + builder.append(true); let array = Arc::new(builder.finish()); @@ -318,10 +317,10 @@ mod tests { let field_type = DataType::FixedSizeBinary(3); let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); - let mut builder = FixedSizeBinaryBuilder::new(3, 3); + let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 3); builder.append_value(&[1, 2, 3]).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append_value(&[7, 8, 9]).unwrap(); let 
array = Arc::new(builder.finish()); @@ -351,8 +350,8 @@ mod tests { macro_rules! check_datetime { ($ARRAYTYPE:ident, $VALUE:expr, $EXPECTED_RESULT:expr) => { let mut builder = $ARRAYTYPE::builder(10); - builder.append_value($VALUE).unwrap(); - builder.append_null().unwrap(); + builder.append_value($VALUE); + builder.append_null(); let array = builder.finish(); let schema = Arc::new(Schema::new(vec![Field::new( @@ -523,7 +522,7 @@ mod tests { let array = [Some(101), None, Some(200), Some(3040)] .into_iter() - .collect::() + .collect::() .with_precision_and_scale(precision, scale) .unwrap(); @@ -563,7 +562,7 @@ mod tests { let array = [Some(101), None, Some(200), Some(3040)] .into_iter() - .collect::() + .collect::() .with_precision_and_scale(precision, scale) .unwrap(); @@ -650,7 +649,7 @@ mod tests { #[test] fn test_pretty_format_dense_union() -> Result<()> { - let mut builder = UnionBuilder::new_dense(4); + let mut builder = UnionBuilder::new_dense(); builder.append::("a", 1).unwrap(); builder.append::("b", 3.2234).unwrap(); builder.append_null::("b").unwrap(); @@ -691,7 +690,7 @@ mod tests { #[test] fn test_pretty_format_sparse_union() -> Result<()> { - let mut builder = UnionBuilder::new_sparse(4); + let mut builder = UnionBuilder::new_sparse(); builder.append::("a", 1).unwrap(); builder.append::("b", 3.2234).unwrap(); builder.append_null::("b").unwrap(); @@ -733,7 +732,7 @@ mod tests { #[test] fn test_pretty_format_nested_union() -> Result<()> { //Inner UnionArray - let mut builder = UnionBuilder::new_dense(5); + let mut builder = UnionBuilder::new_dense(); builder.append::("b", 1).unwrap(); builder.append::("c", 3.2234).unwrap(); builder.append_null::("c").unwrap(); diff --git a/arrow/tests/schema.rs b/arrow/tests/schema.rs new file mode 100644 index 000000000000..ff544b68937b --- /dev/null +++ b/arrow/tests/schema.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
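The pretty-printer tests above also track several builder API migrations: capacities now go through `with_capacity` style constructors, and `append_null`, `append_slice`, and `append` on these builders no longer return `Result`. A short migration sketch assembled from the calls in those tests:

```rust
use arrow::array::{FixedSizeBinaryBuilder, FixedSizeListBuilder, Int32Array};

fn builders_after_migration() {
    // Fixed-size binary: the value length is still validated, so
    // append_value stays fallible, but append_null is now infallible.
    let mut binary = FixedSizeBinaryBuilder::with_capacity(3, 3);
    binary.append_value(&[1, 2, 3]).unwrap();
    binary.append_null();
    let _binary = binary.finish();

    // Fixed-size list of Int32: append_slice and append are now infallible.
    let mut list = FixedSizeListBuilder::new(Int32Array::builder(3), 3);
    list.values().append_slice(&[1, 2, 3]);
    list.append(true);
    let _list = list.finish();
}
```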
+ +use arrow::datatypes::{DataType, Field, Schema}; +use std::collections::HashMap; +/// The tests in this file ensure a `Schema` can be manipulated +/// outside of the arrow crate + +#[test] +fn schema_destructure() { + let meta = [("foo".to_string(), "baz".to_string())] + .into_iter() + .collect::>(); + + let field = Field::new("c1", DataType::Utf8, false); + let schema = Schema::new(vec![field]).with_metadata(meta); + + // Destructuring a Schema allows rewriting fields and metadata + // without copying + // + // Model this usecase below: + + let Schema { + mut fields, + metadata, + } = schema; + fields.push(Field::new("c2", DataType::Utf8, false)); + + let new_schema = Schema::new(fields).with_metadata(metadata); + + assert_eq!(new_schema.fields().len(), 2); +} diff --git a/dev/release/README.md b/dev/release/README.md index a6315dc92615..3783301e9bed 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -21,10 +21,23 @@ ## Overview -We try to release a new version of Arrow every two weeks. This cadence balances getting new features into arrow without overwhelming downstream projects with too frequent changes. +This file documents the release process for: + +1. The "Rust Arrow Crates": `arrow`, `arrow-flight`, `parquet`, and `parquet-derive`. +2. The `object_store` crate. + +### The Rust Arrow Crates + +The Rust Arrow Crates are interconnected (e.g. `parquet` has an optional dependency on `arrow`) so we increment and release all of them together. We try to release a new version of "Rust Arrow Crates" every two weeks. This cadence balances getting new features into the community without overwhelming downstream projects with too frequent changes or overly burdening maintainers. If any code has been merged to master that has a breaking API change, as defined in [Rust RFC 1105](https://github.com/rust-lang/rfcs/blob/master/text/1105-api-evolution.md), the major version number incremented changed (e.g. `9.0.2` to `9.0.2`). Otherwise the new minor version incremented (e.g. `9.0.2` to `7.1.0`). +### `object_store` crate + +At the time of writing, we release a new version of `object_store` on demand rather than on a regular schedule. + +As we are still in an early phase, we use the 0.x version scheme. If any code has been merged to master that has a breaking API change, as defined in [Rust RFC 1105](https://github.com/rust-lang/rfcs/blob/master/text/1105-api-evolution.md), the minor version number incremented changed (e.g. `0.3.0` to `0.4.0`). Otherwise the patch version is incremented (e.g. `0.3.0` to `0.3.1`). + # Release Mechanics ## Process Overview @@ -47,13 +60,17 @@ labels associated with them. Now prepare a PR to update `CHANGELOG.md` and versions on `master` to reflect the planned release. -See [#1141](https://github.com/apache/arrow-rs/pull/1141) for an example. +For the Rust Arrow crates, do this in the root of this repository. For example [#2323](https://github.com/apache/arrow-rs/pull/2323) + +For `object_store` the same process is done in the `object_store` directory. Examples TBD ```bash git checkout master git pull git checkout -b make-release +# Copy the content of CHANGELOG.md to the beginning of CHANGELOG-old.md + # manully edit ./dev/release/update_change_log.sh to reflect the release version # create the changelog CHANGELOG_GITHUB_TOKEN= ./dev/release/update_change_log.sh @@ -61,7 +78,7 @@ CHANGELOG_GITHUB_TOKEN= ./dev/release/update_change_log.sh git commit -a -m 'Create changelog' # update versions -sed -i '' -e 's/14.0.0/18.0.0/g' `find . 
-name 'Cargo.toml' -or -name '*.md' | grep -v CHANGELOG.md` +sed -i '' -e 's/14.0.0/22.0.0/g' `find . -name 'Cargo.toml' -or -name '*.md' | grep -v CHANGELOG.md` git commit -a -m 'Update version' ``` @@ -82,7 +99,11 @@ distribution servers. While the official release artifact is a signed tarball, we also tag the commit it was created for convenience and code archaeology. -Using a string such as `4.0.1` as the ``, create and push the tag thusly: +For a Rust Arrow Crates release, use a string such as `4.0.1` as the ``. + +For `object_store` releases, use a string such as `object_store_0.4.0` as the ``. + +Create and push the tag thusly: ```shell git fetch apache @@ -97,12 +118,20 @@ Pick numbers in sequential order, with `1` for `rc1`, `2` for `rc2`, etc. ### Create, sign, and upload tarball -Run `create-tarball.sh` with the `` tag and `` and you found in previous steps: +Run `create-tarball.sh` with the `` tag and `` and you found in previous steps. + +Rust Arrow Crates: ```shell ./dev/release/create-tarball.sh 4.1.0 2 ``` +`object_store`: + +```shell +./object_store/dev/release/create-tarball.sh 4.1.0 2 +``` + The `create-tarball.sh` script 1. creates and uploads a release candidate tarball to the [arrow @@ -114,7 +143,7 @@ The `create-tarball.sh` script ### Vote on Release Candidate tarball -Send the email output from the script to dev@arrow.apache.org. The email should look like +Send an email, based on the output from the script to dev@arrow.apache.org. The email should look like ``` To: dev@arrow.apache.org @@ -144,11 +173,11 @@ The vote will be open for at least 72 hours. [3]: https://github.com/apache/arrow-rs/blob/a5dd428f57e62db20a945e8b1895de91405958c4/CHANGELOG.md ``` -For the release to become "official" it needs at least three PMC members to vote +1 on it. +For the release to become "official" it needs at least three Apache Arrow PMC members to vote +1 on it. ## Verifying release candidates -The `dev/release/verify-release-candidate.sh` is a script in this repository that can assist in the verification process. Run it like: +The `dev/release/verify-release-candidate.sh` or `object_store/dev/release/verify-release-candidate.sh` are scripts in this repository that can assist in the verification process. Run it like: ``` ./dev/release/verify-release-candidate.sh 4.1.0 2 @@ -162,10 +191,18 @@ If the release is not approved, fix whatever the problem is and try again with t Move tarball to the release location in SVN, e.g. https://dist.apache.org/repos/dist/release/arrow/arrow-4.1.0/, using the `release-tarball.sh` script: +Rust Arrow Crates: + ```shell ./dev/release/release-tarball.sh 4.1.0 2 ``` +`object_store` + +```shell +./object_store/dev/release/release-tarball.sh 4.1.0 2 +``` + Congratulations! The release is now offical! ### Publish on Crates.io @@ -188,9 +225,17 @@ Verify that the Cargo.toml in the tarball contains the correct version (e.g. `version = "0.11.0"`) and then publish the crate with the following commands +Rust Arrow Crates: + ```shell (cd arrow && cargo publish) (cd arrow-flight && cargo publish) (cd parquet && cargo publish) (cd parquet_derive && cargo publish) ``` + +`object_store` + +```shell +cargo publish +``` diff --git a/dev/release/create-tarball.sh b/dev/release/create-tarball.sh index 06320171bb23..0463f89f77ae 100755 --- a/dev/release/create-tarball.sh +++ b/dev/release/create-tarball.sh @@ -53,6 +53,17 @@ fi tag=$1 rc=$2 + +# mac tar doesn't have --delete, so use gnutar +# e.g. 
brew install gtar +if command -v gtar &> /dev/null +then + echo "using gtar (gnu)tar" + tar=gtar +else + tar=tar +fi + release_hash=$(cd "${SOURCE_TOP_DIR}" && git rev-list --max-count=1 ${tag}) release=apache-arrow-rs-${tag} @@ -103,10 +114,18 @@ MAIL echo "---------------------------------------------------------" + # create containing the files in git at $release_hash # the files in the tarball are prefixed with {tag} (e.g. 4.0.1) +# use --delete to filter out: +# 1. `object_store` files +# 2. Workspace `Cargo.toml` file (which refers to object_store) mkdir -p ${distdir} -(cd "${SOURCE_TOP_DIR}" && git archive ${release_hash} --prefix ${release}/ | gzip > ${tarball}) +(cd "${SOURCE_TOP_DIR}" && \ + git archive ${release_hash} --prefix ${release}/ \ + | $tar --delete ${release}/'object_store' \ + | $tar --delete ${release}/'Cargo.toml' \ + | gzip > ${tarball}) echo "Running rat license checker on ${tarball}" ${SOURCE_DIR}/run-rat.sh ${tarball} diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 466f6fa45267..609a5851cad3 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -4,6 +4,7 @@ target/* dev/release/rat_exclude_files.txt arrow/test/data/* arrow/test/dependency/* +integration-testing/data/* parquet_derive/test/dependency/* .gitattributes **.gitignore diff --git a/dev/release/update_change_log.sh b/dev/release/update_change_log.sh index 93d674e9aff0..252cd285d92b 100755 --- a/dev/release/update_change_log.sh +++ b/dev/release/update_change_log.sh @@ -29,8 +29,8 @@ set -e -SINCE_TAG="17.0.0" -FUTURE_RELEASE="18.0.0" +SINCE_TAG="21.0.0" +FUTURE_RELEASE="22.0.0" SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" @@ -40,6 +40,8 @@ OUTPUT_PATH="${SOURCE_TOP_DIR}/CHANGELOG.md" # remove license header so github-changelog-generator has a clean base to append sed -i.bak '1,18d' "${OUTPUT_PATH}" +# use exclude-tags-regex to filter out tags used for object_store +# crates and only only look at tags that DO NOT begin with `object_store_` pushd "${SOURCE_TOP_DIR}" docker run -it --rm -e CHANGELOG_GITHUB_TOKEN="$CHANGELOG_GITHUB_TOKEN" -v "$(pwd)":/usr/local/src/your-app githubchangeloggenerator/github-changelog-generator \ --user apache \ @@ -48,6 +50,7 @@ docker run -it --rm -e CHANGELOG_GITHUB_TOKEN="$CHANGELOG_GITHUB_TOKEN" -v "$(pw --cache-log=.githubchangeloggenerator.cache.log \ --http-cache \ --max-issues=300 \ + --exclude-tags-regex "^object_store_\d+\.\d+\.\d+$" \ --since-tag ${SINCE_TAG} \ --future-release ${FUTURE_RELEASE} diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index a5ed04c6f8b8..cf8050c1c9f2 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -72,24 +72,6 @@ fetch_archive() { ${sha512_verify} ${dist_name}.tar.gz.sha512 } -verify_dir_artifact_signatures() { - # verify the signature and the checksums of each artifact - find $1 -name '*.asc' | while read sigfile; do - artifact=${sigfile/.asc/} - gpg --verify $sigfile $artifact || exit 1 - - # go into the directory because the checksum files contain only the - # basename of the artifact - pushd $(dirname $artifact) - base_artifact=$(basename $artifact) - if [ -f $base_artifact.sha256 ]; then - ${sha256_verify} $base_artifact.sha256 || exit 1 - fi - ${sha512_verify} $base_artifact.sha512 || exit 1 - popd - done -} - setup_tempdir() { cleanup() { if [ "${TEST_SUCCESS}" = "yes" ]; then @@ -123,7 +105,10 
@@ test_source_distribution() { # raises on any formatting errors rustup component add rustfmt --toolchain stable - cargo fmt --all -- --check + (cd arrow && cargo fmt --check) + (cd arrow-flight && cargo fmt --check) + (cd parquet && cargo fmt --check) + (cd parquet_derive && cargo fmt --check) # Clone testing repositories if not cloned already git clone https://github.com/apache/arrow-testing.git arrow-testing-data @@ -139,8 +124,10 @@ test_source_distribution() { -e 's/^parquet = "([^"]*)"/parquet = { version = "\1", path = "..\/parquet" }/g' \ */Cargo.toml - cargo build - cargo test --all + (cd arrow && cargo build && cargo test) + (cd arrow-flight && cargo build && cargo test) + (cd parquet && cargo build && cargo test) + (cd parquet_derive && cargo build && cargo test) # verify that the crates can be published to crates.io pushd arrow diff --git a/integration-testing/Cargo.toml b/integration-testing/Cargo.toml index 4cff73aa7011..b9f6cf81855e 100644 --- a/integration-testing/Cargo.toml +++ b/integration-testing/Cargo.toml @@ -18,29 +18,33 @@ [package] name = "arrow-integration-testing" description = "Binaries used in the Arrow integration tests" -version = "18.0.0" +version = "22.0.0" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] license = "Apache-2.0" edition = "2021" publish = false -rust-version = "1.57" +rust-version = "1.62" [features] logging = ["tracing-subscriber"] [dependencies] -arrow = { path = "../arrow", default-features = false, features = [ "test_utils" ] } +arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json"] } arrow-flight = { path = "../arrow-flight", default-features = false } async-trait = { version = "0.1.41", default-features = false } clap = { version = "3", default-features = false, features = ["std", "derive"] } futures = { version = "0.3", default-features = false } -hex = { version = "0.4", default-features = false } -prost = { version = "0.10", default-features = false } -serde = { version = "1.0", default-features = false, features = ["rc"] } -serde_derive = { version = "1.0", default-features = false } -serde_json = { version = "1.0", default-features = false, features = ["preserve_order"] } +hex = { version = "0.4", default-features = false, features = ["std"] } +prost = { version = "0.11", default-features = false } +serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } +serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.0", default-features = false } -tonic = { version = "0.7", default-features = false } +tonic = { version = "0.8", default-features = false } tracing-subscriber = { version = "0.3.1", default-features = false, features = ["fmt"], optional = true } +num = { version = "0.4", default-features = false, features = ["std"] } +flate2 = { version = "1", default-features = false, features = ["rust_backend"] } + +[dev-dependencies] +tempfile = { version = "3", default-features = false } diff --git a/arrow/test/data/integration.json b/integration-testing/data/integration.json similarity index 100% rename from arrow/test/data/integration.json rename to integration-testing/data/integration.json diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs index 69b73b19f222..a7d7cf6ee7cb 100644 --- a/integration-testing/src/bin/arrow-json-integration-test.rs +++ 
b/integration-testing/src/bin/arrow-json-integration-test.rs @@ -20,8 +20,7 @@ use arrow::datatypes::{DataType, Field}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; -use arrow::util::integration_util::*; -use arrow_integration_testing::read_json_file; +use arrow_integration_testing::{read_json_file, util::*}; use clap::Parser; use std::fs::File; @@ -91,7 +90,10 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> for f in reader.schema().fields() { fields.push(ArrowJsonField::from(f)); } - let schema = ArrowJsonSchema { fields }; + let schema = ArrowJsonSchema { + fields, + metadata: None, + }; let batches = reader .map(|batch| Ok(ArrowJsonBatch::from_batch(&batch?))) diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index 62fe2b85d262..c01baa09a1f7 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -20,6 +20,7 @@ use std::collections::HashMap; use arrow::{ array::ArrayRef, + buffer::Buffer, datatypes::SchemaRef, ipc::{self, reader, writer}, record_batch::RecordBatch, @@ -264,7 +265,7 @@ async fn receive_batch_flight_data( while message.header_type() == ipc::MessageHeader::DictionaryBatch { reader::read_dictionary( - &data.data_body, + &Buffer::from(&data.data_body), message .header_as_dictionary_batch() .expect("Error parsing dictionary"), diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index 7ad3d18eb5ba..dee2fda3be3d 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use arrow::{ array::ArrayRef, + buffer::Buffer, datatypes::Schema, datatypes::SchemaRef, ipc::{self, reader}, @@ -282,7 +283,7 @@ async fn send_app_metadata( async fn record_batch_from_message( message: ipc::Message<'_>, - data_body: &[u8], + data_body: &Buffer, schema_ref: SchemaRef, dictionaries_by_id: &HashMap, ) -> Result { @@ -306,7 +307,7 @@ async fn record_batch_from_message( async fn dictionary_from_message( message: ipc::Message<'_>, - data_body: &[u8], + data_body: &Buffer, schema_ref: SchemaRef, dictionaries_by_id: &mut HashMap, ) -> Result<(), Status> { @@ -354,7 +355,7 @@ async fn save_uploaded_chunks( let batch = record_batch_from_message( message, - &data.data_body, + &Buffer::from(data.data_body), schema_ref.clone(), &dictionaries_by_id, ) @@ -365,7 +366,7 @@ async fn save_uploaded_chunks( ipc::MessageHeader::DictionaryBatch => { dictionary_from_message( message, - &data.data_body, + &Buffer::from(data.data_body), schema_ref.clone(), &mut dictionaries_by_id, ) diff --git a/integration-testing/src/lib.rs b/integration-testing/src/lib.rs index e4cc872ffd17..ffe112af72cd 100644 --- a/integration-testing/src/lib.rs +++ b/integration-testing/src/lib.rs @@ -17,26 +17,17 @@ //! 
Common code used in the integration test binaries -use hex::decode; use serde_json::Value; -use arrow::util::integration_util::ArrowJsonBatch; +use util::*; -use arrow::array::*; -use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; -use arrow::error::{ArrowError, Result}; +use arrow::datatypes::Schema; +use arrow::error::Result; use arrow::record_batch::RecordBatch; -use arrow::{ - buffer::Buffer, - buffer::MutableBuffer, - datatypes::ToByteSlice, - util::{bit_util, integration_util::*}, -}; - +use arrow::util::test_util::arrow_test_data; use std::collections::HashMap; use std::fs::File; use std::io::BufReader; -use std::sync::Arc; /// The expected username for the basic auth integration test. pub const AUTH_USERNAME: &str = "arrow"; @@ -45,6 +36,7 @@ pub const AUTH_PASSWORD: &str = "flight"; pub mod flight_client_scenarios; pub mod flight_server_scenarios; +pub mod util; pub struct ArrowFile { pub schema: Schema, @@ -86,679 +78,21 @@ pub fn read_json_file(json_name: &str) -> Result { }) } -fn record_batch_from_json( - schema: &Schema, - json_batch: ArrowJsonBatch, - json_dictionaries: Option<&HashMap>, -) -> Result { - let mut columns = vec![]; - - for (field, json_col) in schema.fields().iter().zip(json_batch.columns) { - let col = array_from_json(field, json_col, json_dictionaries)?; - columns.push(col); - } - - RecordBatch::try_new(Arc::new(schema.clone()), columns) -} - -/// Construct an Arrow array from a partially typed JSON column -fn array_from_json( - field: &Field, - json_col: ArrowJsonColumn, - dictionaries: Option<&HashMap>, -) -> Result { - match field.data_type() { - DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))), - DataType::Boolean => { - let mut b = BooleanBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_bool().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int8 => { - let mut b = Int8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to get {:?} as int64", - value - )) - })? 
as i8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int16 => { - let mut b = Int16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - let mut b = Int32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i32), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::Int64 - | DataType::Date64 - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) - | DataType::Interval(IntervalUnit::DayTime) => { - let mut b = Int64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(match value { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => { - s.parse().expect("Unable to parse string as i64") - } - Value::Object(ref map) - if map.contains_key("days") - && map.contains_key("milliseconds") => - { - match field.data_type() { - DataType::Interval(IntervalUnit::DayTime) => { - let days = map.get("days").unwrap(); - let milliseconds = map.get("milliseconds").unwrap(); - - match (days, milliseconds) { - (Value::Number(d), Value::Number(m)) => { - let mut bytes = [0_u8; 8]; - let m = (m.as_i64().unwrap() as i32) - .to_le_bytes(); - let d = (d.as_i64().unwrap() as i32) - .to_le_bytes(); - - let c = [d, m].concat(); - bytes.copy_from_slice(c.as_slice()); - i64::from_le_bytes(bytes) - } - _ => panic!( - "Unable to parse {:?} as interval daytime", - value - ), - } - } - _ => panic!( - "Unable to parse {:?} as interval daytime", - value - ), - } - } - _ => panic!("Unable to parse {:?} as number", value), - }), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::UInt8 => { - let mut b = UInt8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt16 => { - let mut b = UInt16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt32 => { - let mut b = UInt32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt64 => { - let mut b = UInt64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid 
{ - 1 => b.append_value( - value - .as_str() - .unwrap() - .parse() - .expect("Unable to parse string as u64"), - ), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let mut b = IntervalMonthDayNanoBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(match value { - Value::Object(v) => { - let months = v.get("months").unwrap(); - let days = v.get("days").unwrap(); - let nanoseconds = v.get("nanoseconds").unwrap(); - match (months, days, nanoseconds) { - ( - Value::Number(months), - Value::Number(days), - Value::Number(nanoseconds), - ) => { - let months = months.as_i64().unwrap() as i32; - let days = days.as_i64().unwrap() as i32; - let nanoseconds = nanoseconds.as_i64().unwrap(); - let months_days_ns: i128 = ((nanoseconds as i128) - & 0xFFFFFFFFFFFFFFFF) - << 64 - | ((days as i128) & 0xFFFFFFFF) << 32 - | ((months as i128) & 0xFFFFFFFF); - months_days_ns - } - (_, _, _) => { - panic!("Unable to parse {:?} as MonthDayNano", v) - } - } - } - _ => panic!("Unable to parse {:?} as MonthDayNano", value), - }), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float32 => { - let mut b = Float32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_f64().unwrap() as f32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float64 => { - let mut b = Float64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_f64().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Binary => { - let mut b = BinaryBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::LargeBinary => { - let mut b = LargeBinaryBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Utf8 => { - let mut b = StringBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::LargeUtf8 => { - let mut b = LargeStringBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::FixedSizeBinary(len) => { - let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { 
- match is_valid { - 1 => { - let v = hex::decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::List(child_field) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| v.as_i64().unwrap() as i32) - .collect(); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.into_data()) - .null_bit_buffer(Some(null_buf)) - .build() - .unwrap(); - Ok(Arc::new(ListArray::from(list_data))) - } - DataType::LargeList(child_field) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| match v { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => s.parse::().unwrap(), - _ => panic!("64-bit offset must be either string or number"), - }) - .collect(); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.into_data()) - .null_bit_buffer(Some(null_buf)) - .build() - .unwrap(); - Ok(Arc::new(LargeListArray::from(list_data))) - } - DataType::FixedSizeList(child_field, _) => { - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let null_buf = create_null_buf(&json_col); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .add_child_data(child_array.into_data()) - .null_bit_buffer(Some(null_buf)) - .build() - .unwrap(); - Ok(Arc::new(FixedSizeListArray::from(list_data))) - } - DataType::Struct(fields) => { - // construct struct with null data - let null_buf = create_null_buf(&json_col); - let mut array_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .null_bit_buffer(Some(null_buf)); - - for (field, col) in fields.iter().zip(json_col.children.unwrap()) { - let array = array_from_json(field, col, dictionaries)?; - array_data = array_data.add_child_data(array.into_data()); - } - - let array = StructArray::from(array_data.build().unwrap()); - Ok(Arc::new(array)) - } - DataType::Dictionary(key_type, value_type) => { - let dict_id = field.dict_id().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find dict_id for field {:?}", - field - )) - })?; - // find dictionary - let dictionary = dictionaries - .ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find any dictionaries for field {:?}", - field - )) - })? 
- .get(&dict_id); - match dictionary { - Some(dictionary) => dictionary_array_from_json( - field, - json_col, - key_type, - value_type, - dictionary, - dictionaries, - ), - None => Err(ArrowError::JsonError(format!( - "Unable to find dictionary for field {:?}", - field - ))), - } - } - DataType::Decimal(precision, scale) => { - let mut b = DecimalBuilder::new(json_col.count, *precision, *scale); - // C++ interop tests involve incompatible decimal values - unsafe { - b.disable_value_validation(); - } - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap().parse::().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Map(child_field, _) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| v.as_i64().unwrap() as i32) - .collect(); - let array_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.into_data()) - .null_bit_buffer(Some(null_buf)) - .build() - .unwrap(); - - let array = MapArray::from(array_data); - Ok(Arc::new(array)) - } - DataType::Union(fields, field_type_ids, _) => { - let type_ids = if let Some(type_id) = json_col.type_id { - type_id - } else { - return Err(ArrowError::JsonError( - "Cannot find expected type_id in json column".to_string(), - )); - }; - - let offset: Option = json_col.offset.map(|offsets| { - let offsets: Vec = - offsets.iter().map(|v| v.as_i64().unwrap() as i32).collect(); - Buffer::from(&offsets.to_byte_slice()) - }); - - let mut children: Vec<(Field, Arc)> = vec![]; - for (field, col) in fields.iter().zip(json_col.children.unwrap()) { - let array = array_from_json(field, col, dictionaries)?; - children.push((field.clone(), array)); - } - - let array = UnionArray::try_new( - field_type_ids, - Buffer::from(&type_ids.to_byte_slice()), - offset, - children, - ) - .unwrap(); - Ok(Arc::new(array)) - } - t => Err(ArrowError::JsonError(format!( - "data type {:?} not supported", - t - ))), - } -} - -fn dictionary_array_from_json( - field: &Field, - json_col: ArrowJsonColumn, - dict_key: &DataType, - dict_value: &DataType, - dictionary: &ArrowJsonDictionaryBatch, - dictionaries: Option<&HashMap>, -) -> Result { - match dict_key { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => { - let null_buf = create_null_buf(&json_col); - - // build the key data into a buffer, then construct values separately - let key_field = Field::new_dict( - "key", - dict_key.clone(), - field.is_nullable(), - field - .dict_id() - .expect("Dictionary fields must have a dict_id value"), - field - .dict_is_ordered() - .expect("Dictionary fields must have a dict_is_ordered value"), - ); - let keys = array_from_json(&key_field, json_col, None)?; - // note: not enough info on nullability of dictionary - let value_field = Field::new("value", dict_value.clone(), true); - let values = array_from_json( - &value_field, - dictionary.data.columns[0].clone(), - dictionaries, - )?; - - // convert key and value to dictionary data - let dict_data = ArrayData::builder(field.data_type().clone()) - 
.len(keys.len()) - .add_buffer(keys.data().buffers()[0].clone()) - .null_bit_buffer(Some(null_buf)) - .add_child_data(values.into_data()) - .build() - .unwrap(); - - let array = match dict_key { - DataType::Int8 => { - Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef - } - DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), - DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), - DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), - DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)), - DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)), - DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)), - DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)), - _ => unreachable!(), - }; - Ok(array) - } - _ => Err(ArrowError::JsonError(format!( - "Dictionary key type {:?} not supported", - dict_key - ))), - } -} - -/// A helper to create a null buffer from a Vec -fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { - let num_bytes = bit_util::ceil(json_col.count, 8); - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); - json_col - .validity - .clone() - .unwrap() - .iter() - .enumerate() - .for_each(|(i, v)| { - let null_slice = null_buf.as_slice_mut(); - if *v != 0 { - bit_util::set_bit(null_slice, i); - } - }); - null_buf.into() +/// Read gzipped JSON test file +pub fn read_gzip_json(version: &str, path: &str) -> ArrowJson { + use flate2::read::GzDecoder; + use std::io::Read; + + let testdata = arrow_test_data(); + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.json.gz", + testdata, version, path + )) + .unwrap(); + let mut gz = GzDecoder::new(&file); + let mut s = String::new(); + gz.read_to_string(&mut s).unwrap(); + // convert to Arrow JSON + let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); + arrow_json } diff --git a/integration-testing/src/util.rs b/integration-testing/src/util.rs new file mode 100644 index 000000000000..e098c4e1491a --- /dev/null +++ b/integration-testing/src/util.rs @@ -0,0 +1,1344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utils for JSON integration testing +//! +//! These utilities define structs that read the integration JSON format for integration testing purposes. 
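For orientation, here is a minimal usage sketch (illustrative only, not part of the patch; the field name and the helper function are invented for the example) showing how the structs defined below serialize a RecordBatch into the integration JSON column layout via the public ArrowJsonBatch::from_batch constructor:

    use std::sync::Arc;

    use arrow::array::Int8Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use arrow_integration_testing::util::ArrowJsonBatch;

    fn batch_to_integration_json() {
        // One nullable Int8 column, matching the Int8 branch of from_batch.
        let schema = Schema::new(vec![Field::new("int8s", DataType::Int8, true)]);
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(Int8Array::from(vec![Some(1), None, Some(3)]))],
        )
        .unwrap();

        // from_batch produces the VALIDITY / DATA vectors used by the JSON format.
        let json_batch = ArrowJsonBatch::from_batch(&batch);
        println!("{}", serde_json::to_string(&json_batch).unwrap());
    }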
+ +use hex::decode; +use num::BigInt; +use num::Signed; +use serde::{Deserialize, Serialize}; +use serde_json::{Map as SJMap, Value}; +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::*; +use arrow::buffer::{Buffer, MutableBuffer}; +use arrow::compute; +use arrow::datatypes::*; +use arrow::error::{ArrowError, Result}; +use arrow::record_batch::{RecordBatch, RecordBatchReader}; +use arrow::util::bit_util; +use arrow::util::decimal::Decimal256; + +/// A struct that represents an Arrow file with a schema and record batches +#[derive(Deserialize, Serialize, Debug)] +pub struct ArrowJson { + pub schema: ArrowJsonSchema, + pub batches: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub dictionaries: Option>, +} + +/// A struct that partially reads the Arrow JSON schema. +/// +/// Fields are left as JSON `Value` as they vary by `DataType` +#[derive(Deserialize, Serialize, Debug)] +pub struct ArrowJsonSchema { + pub fields: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>>, +} + +/// Fields are left as JSON `Value` as they vary by `DataType` +#[derive(Deserialize, Serialize, Debug)] +pub struct ArrowJsonField { + pub name: String, + #[serde(rename = "type")] + pub field_type: Value, + pub nullable: bool, + pub children: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub dictionary: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +impl From<&Field> for ArrowJsonField { + fn from(field: &Field) -> Self { + let metadata_value = match field.metadata() { + Some(kv_list) => { + let mut array = Vec::new(); + for (k, v) in kv_list { + let mut kv_map = SJMap::new(); + kv_map.insert(k.clone(), Value::String(v.clone())); + array.push(Value::Object(kv_map)); + } + if !array.is_empty() { + Some(Value::Array(array)) + } else { + None + } + } + _ => None, + }; + + Self { + name: field.name().to_string(), + field_type: field.data_type().to_json(), + nullable: field.is_nullable(), + children: vec![], + dictionary: None, // TODO: not enough info + metadata: metadata_value, + } + } +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct ArrowJsonFieldDictionary { + pub id: i64, + #[serde(rename = "indexType")] + pub index_type: DictionaryIndexType, + #[serde(rename = "isOrdered")] + pub is_ordered: bool, +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct DictionaryIndexType { + pub name: String, + #[serde(rename = "isSigned")] + pub is_signed: bool, + #[serde(rename = "bitWidth")] + pub bit_width: i64, +} + +/// A struct that partially reads the Arrow JSON record batch +#[derive(Deserialize, Serialize, Debug, Clone)] +pub struct ArrowJsonBatch { + count: usize, + pub columns: Vec, +} + +/// A struct that partially reads the Arrow JSON dictionary batch +#[derive(Deserialize, Serialize, Debug, Clone)] +#[allow(non_snake_case)] +pub struct ArrowJsonDictionaryBatch { + pub id: i64, + pub data: ArrowJsonBatch, +} + +/// A struct that partially reads the Arrow JSON column/array +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct ArrowJsonColumn { + name: String, + pub count: usize, + #[serde(rename = "VALIDITY")] + pub validity: Option>, + #[serde(rename = "DATA")] + pub data: Option>, + #[serde(rename = "OFFSET")] + pub offset: Option>, // leaving as Value as 64-bit offsets are strings + #[serde(rename = "TYPE_ID")] + pub type_id: Option>, + pub children: Option>, +} + +impl ArrowJson { + /// Compare the Arrow JSON with a record batch reader + pub fn equals_reader(&self, reader: 
&mut dyn RecordBatchReader) -> Result { + if !self.schema.equals_schema(&reader.schema()) { + return Ok(false); + } + + for json_batch in self.get_record_batches()?.into_iter() { + let batch = reader.next(); + match batch { + Some(Ok(batch)) => { + if json_batch != batch { + println!("json: {:?}", json_batch); + println!("batch: {:?}", batch); + return Ok(false); + } + } + _ => return Ok(false), + } + } + + Ok(true) + } + + pub fn get_record_batches(&self) -> Result> { + let schema = self.schema.to_arrow_schema()?; + + let mut dictionaries = HashMap::new(); + self.dictionaries.iter().for_each(|dict_batches| { + dict_batches.iter().for_each(|d| { + dictionaries.insert(d.id, d.clone()); + }); + }); + + let batches: Result> = self + .batches + .iter() + .map(|col| record_batch_from_json(&schema, col.clone(), Some(&dictionaries))) + .collect(); + + batches + } +} + +impl ArrowJsonSchema { + /// Compare the Arrow JSON schema with the Arrow `Schema` + fn equals_schema(&self, schema: &Schema) -> bool { + let field_len = self.fields.len(); + if field_len != schema.fields().len() { + return false; + } + for i in 0..field_len { + let json_field = &self.fields[i]; + let field = schema.field(i); + if !json_field.equals_field(field) { + return false; + } + } + true + } + + fn to_arrow_schema(&self) -> Result { + let arrow_fields: Result> = self + .fields + .iter() + .map(|field| field.to_arrow_field()) + .collect(); + + if let Some(metadatas) = &self.metadata { + let mut metadata: HashMap = HashMap::new(); + + metadatas.iter().for_each(|pair| { + let key = pair.get("key").unwrap(); + let value = pair.get("value").unwrap(); + metadata.insert(key.clone(), value.clone()); + }); + + Ok(Schema::new_with_metadata(arrow_fields?, metadata)) + } else { + Ok(Schema::new(arrow_fields?)) + } + } +} + +impl ArrowJsonField { + /// Compare the Arrow JSON field with the Arrow `Field` + fn equals_field(&self, field: &Field) -> bool { + // convert to a field + match self.to_arrow_field() { + Ok(self_field) => { + assert_eq!(&self_field, field, "Arrow fields not the same"); + true + } + Err(e) => { + eprintln!( + "Encountered error while converting JSON field to Arrow field: {:?}", + e + ); + false + } + } + } + + /// Convert to an Arrow Field + /// TODO: convert to use an Into + fn to_arrow_field(&self) -> Result { + // a bit regressive, but we have to convert the field to JSON in order to convert it + let field = serde_json::to_value(self)?; + Field::from(&field) + } +} + +pub fn record_batch_from_json( + schema: &Schema, + json_batch: ArrowJsonBatch, + json_dictionaries: Option<&HashMap>, +) -> Result { + let mut columns = vec![]; + + for (field, json_col) in schema.fields().iter().zip(json_batch.columns) { + let col = array_from_json(field, json_col, json_dictionaries)?; + columns.push(col); + } + + RecordBatch::try_new(Arc::new(schema.clone()), columns) +} + +/// Construct an Arrow array from a partially typed JSON column +pub fn array_from_json( + field: &Field, + json_col: ArrowJsonColumn, + dictionaries: Option<&HashMap>, +) -> Result { + match field.data_type() { + DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))), + DataType::Boolean => { + let mut b = BooleanBuilder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_bool().unwrap()), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Int8 => { + let mut b = 
Int8Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to get {:?} as int64", + value + )) + })? as i8), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Int16 => { + let mut b = Int16Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i16), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + let mut b = Int32Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i32), + _ => b.append_null(), + }; + } + let array = Arc::new(b.finish()) as ArrayRef; + compute::cast(&array, field.data_type()) + } + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => { + let mut b = Int64Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(match value { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => { + s.parse().expect("Unable to parse string as i64") + } + Value::Object(ref map) + if map.contains_key("days") + && map.contains_key("milliseconds") => + { + match field.data_type() { + DataType::Interval(IntervalUnit::DayTime) => { + let days = map.get("days").unwrap(); + let milliseconds = map.get("milliseconds").unwrap(); + + match (days, milliseconds) { + (Value::Number(d), Value::Number(m)) => { + let mut bytes = [0_u8; 8]; + let m = (m.as_i64().unwrap() as i32) + .to_le_bytes(); + let d = (d.as_i64().unwrap() as i32) + .to_le_bytes(); + + let c = [d, m].concat(); + bytes.copy_from_slice(c.as_slice()); + i64::from_le_bytes(bytes) + } + _ => panic!( + "Unable to parse {:?} as interval daytime", + value + ), + } + } + _ => panic!( + "Unable to parse {:?} as interval daytime", + value + ), + } + } + _ => panic!("Unable to parse {:?} as number", value), + }), + _ => b.append_null(), + }; + } + let array = Arc::new(b.finish()) as ArrayRef; + compute::cast(&array, field.data_type()) + } + DataType::UInt8 => { + let mut b = UInt8Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u8), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt16 => { + let mut b = UInt16Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u16), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt32 => { + let mut b = UInt32Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + 
.zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u32), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt64 => { + let mut b = UInt64Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + if value.is_string() { + b.append_value( + value + .as_str() + .unwrap() + .parse() + .expect("Unable to parse string as u64"), + ) + } else if value.is_number() { + b.append_value( + value.as_u64().expect("Unable to read number as u64"), + ) + } else { + panic!("Unable to parse value {:?} as u64", value) + } + } + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let mut b = IntervalMonthDayNanoBuilder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(match value { + Value::Object(v) => { + let months = v.get("months").unwrap(); + let days = v.get("days").unwrap(); + let nanoseconds = v.get("nanoseconds").unwrap(); + match (months, days, nanoseconds) { + ( + Value::Number(months), + Value::Number(days), + Value::Number(nanoseconds), + ) => { + let months = months.as_i64().unwrap() as i32; + let days = days.as_i64().unwrap() as i32; + let nanoseconds = nanoseconds.as_i64().unwrap(); + let months_days_ns: i128 = ((nanoseconds as i128) + & 0xFFFFFFFFFFFFFFFF) + << 64 + | ((days as i128) & 0xFFFFFFFF) << 32 + | ((months as i128) & 0xFFFFFFFF); + months_days_ns + } + (_, _, _) => { + panic!("Unable to parse {:?} as MonthDayNano", v) + } + } + } + _ => panic!("Unable to parse {:?} as MonthDayNano", value), + }), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Float32 => { + let mut b = Float32Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_f64().unwrap() as f32), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Float64 => { + let mut b = Float64Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_f64().unwrap()), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Binary => { + let mut b = BinaryBuilder::with_capacity(json_col.count, 1024); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::LargeBinary => { + let mut b = LargeBinaryBuilder::with_capacity(json_col.count, 1024); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Utf8 => { + let mut b = StringBuilder::with_capacity(json_col.count, 1024); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match 
is_valid { + 1 => b.append_value(value.as_str().unwrap()), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::LargeUtf8 => { + let mut b = LargeStringBuilder::with_capacity(json_col.count, 1024); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap()), + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::FixedSizeBinary(len) => { + let mut b = FixedSizeBinaryBuilder::with_capacity(json_col.count, *len); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = hex::decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v)? + } + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::List(child_field) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| v.as_i64().unwrap() as i32) + .collect(); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .offset(0) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.into_data()) + .null_bit_buffer(Some(null_buf)) + .build() + .unwrap(); + Ok(Arc::new(ListArray::from(list_data))) + } + DataType::LargeList(child_field) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| match v { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => s.parse::().unwrap(), + _ => panic!("64-bit offset must be either string or number"), + }) + .collect(); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .offset(0) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.into_data()) + .null_bit_buffer(Some(null_buf)) + .build() + .unwrap(); + Ok(Arc::new(LargeListArray::from(list_data))) + } + DataType::FixedSizeList(child_field, _) => { + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let null_buf = create_null_buf(&json_col); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .add_child_data(child_array.into_data()) + .null_bit_buffer(Some(null_buf)) + .build() + .unwrap(); + Ok(Arc::new(FixedSizeListArray::from(list_data))) + } + DataType::Struct(fields) => { + // construct struct with null data + let null_buf = create_null_buf(&json_col); + let mut array_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .null_bit_buffer(Some(null_buf)); + + for (field, col) in fields.iter().zip(json_col.children.unwrap()) { + let array = array_from_json(field, col, dictionaries)?; + array_data = array_data.add_child_data(array.into_data()); + } + + let array = StructArray::from(array_data.build().unwrap()); + Ok(Arc::new(array)) + } + DataType::Dictionary(key_type, value_type) => { + let dict_id = field.dict_id().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find dict_id for field {:?}", + field + 
)) + })?; + // find dictionary + let dictionary = dictionaries + .ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find any dictionaries for field {:?}", + field + )) + })? + .get(&dict_id); + match dictionary { + Some(dictionary) => dictionary_array_from_json( + field, + json_col, + key_type, + value_type, + dictionary, + dictionaries, + ), + None => Err(ArrowError::JsonError(format!( + "Unable to find dictionary for field {:?}", + field + ))), + } + } + DataType::Decimal128(precision, scale) => { + let mut b = + Decimal128Builder::with_capacity(json_col.count, *precision, *scale); + // C++ interop tests involve incompatible decimal values + unsafe { + b.disable_value_validation(); + } + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + b.append_value(value.as_str().unwrap().parse::().unwrap())? + } + _ => b.append_null(), + }; + } + Ok(Arc::new(b.finish())) + } + DataType::Decimal256(precision, scale) => { + let mut b = + Decimal256Builder::with_capacity(json_col.count, *precision, *scale); + // C++ interop tests involve incompatible decimal values + unsafe { + b.disable_value_validation(); + } + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let str = value.as_str().unwrap(); + let integer = BigInt::parse_bytes(str.as_bytes(), 10).unwrap(); + let integer_bytes = integer.to_signed_bytes_le(); + let mut bytes = if integer.is_positive() { + [0_u8; 32] + } else { + [255_u8; 32] + }; + bytes[0..integer_bytes.len()] + .copy_from_slice(integer_bytes.as_slice()); + let decimal = + Decimal256::try_new_from_bytes(*precision, *scale, &bytes) + .unwrap(); + b.append_value(&decimal)?; + } + _ => b.append_null(), + } + } + Ok(Arc::new(b.finish())) + } + DataType::Map(child_field, _) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| v.as_i64().unwrap() as i32) + .collect(); + let array_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.into_data()) + .null_bit_buffer(Some(null_buf)) + .build() + .unwrap(); + + let array = MapArray::from(array_data); + Ok(Arc::new(array)) + } + DataType::Union(fields, field_type_ids, _) => { + let type_ids = if let Some(type_id) = json_col.type_id { + type_id + } else { + return Err(ArrowError::JsonError( + "Cannot find expected type_id in json column".to_string(), + )); + }; + + let offset: Option = json_col.offset.map(|offsets| { + let offsets: Vec = + offsets.iter().map(|v| v.as_i64().unwrap() as i32).collect(); + Buffer::from(&offsets.to_byte_slice()) + }); + + let mut children: Vec<(Field, Arc)> = vec![]; + for (field, col) in fields.iter().zip(json_col.children.unwrap()) { + let array = array_from_json(field, col, dictionaries)?; + children.push((field.clone(), array)); + } + + let array = UnionArray::try_new( + field_type_ids, + Buffer::from(&type_ids.to_byte_slice()), + offset, + children, + ) + .unwrap(); + Ok(Arc::new(array)) + } + t => Err(ArrowError::JsonError(format!( + "data type {:?} not supported", + t + ))), + } +} + +pub fn dictionary_array_from_json( + field: &Field, + json_col: ArrowJsonColumn, + dict_key: 
&DataType, + dict_value: &DataType, + dictionary: &ArrowJsonDictionaryBatch, + dictionaries: Option<&HashMap>, +) -> Result { + match dict_key { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + let null_buf = create_null_buf(&json_col); + + // build the key data into a buffer, then construct values separately + let key_field = Field::new_dict( + "key", + dict_key.clone(), + field.is_nullable(), + field + .dict_id() + .expect("Dictionary fields must have a dict_id value"), + field + .dict_is_ordered() + .expect("Dictionary fields must have a dict_is_ordered value"), + ); + let keys = array_from_json(&key_field, json_col, None)?; + // note: not enough info on nullability of dictionary + let value_field = Field::new("value", dict_value.clone(), true); + let values = array_from_json( + &value_field, + dictionary.data.columns[0].clone(), + dictionaries, + )?; + + // convert key and value to dictionary data + let dict_data = ArrayData::builder(field.data_type().clone()) + .len(keys.len()) + .add_buffer(keys.data().buffers()[0].clone()) + .null_bit_buffer(Some(null_buf)) + .add_child_data(values.into_data()) + .build() + .unwrap(); + + let array = match dict_key { + DataType::Int8 => { + Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef + } + DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), + DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), + DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), + DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)), + DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)), + DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)), + DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)), + _ => unreachable!(), + }; + Ok(array) + } + _ => Err(ArrowError::JsonError(format!( + "Dictionary key type {:?} not supported", + dict_key + ))), + } +} + +/// A helper to create a null buffer from a Vec +fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { + let num_bytes = bit_util::ceil(json_col.count, 8); + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + json_col + .validity + .clone() + .unwrap() + .iter() + .enumerate() + .for_each(|(i, v)| { + let null_slice = null_buf.as_slice_mut(); + if *v != 0 { + bit_util::set_bit(null_slice, i); + } + }); + null_buf.into() +} + +impl ArrowJsonBatch { + pub fn from_batch(batch: &RecordBatch) -> ArrowJsonBatch { + let mut json_batch = ArrowJsonBatch { + count: batch.num_rows(), + columns: Vec::with_capacity(batch.num_columns()), + }; + + for (col, field) in batch.columns().iter().zip(batch.schema().fields.iter()) { + let json_col = match field.data_type() { + DataType::Int8 => { + let col = col.as_any().downcast_ref::().unwrap(); + + let mut validity: Vec = Vec::with_capacity(col.len()); + let mut data: Vec = Vec::with_capacity(col.len()); + + for i in 0..col.len() { + if col.is_null(i) { + validity.push(1); + data.push(0i8.into()); + } else { + validity.push(0); + data.push(col.value(i).into()); + } + } + + ArrowJsonColumn { + name: field.name().clone(), + count: col.len(), + validity: Some(validity), + data: Some(data), + offset: None, + type_id: None, + children: None, + } + } + _ => ArrowJsonColumn { + name: field.name().clone(), + count: col.len(), + validity: None, + data: None, + offset: None, + type_id: None, + children: None, + }, + }; + + 
json_batch.columns.push(json_col); + } + + json_batch + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::fs::File; + use std::io::Read; + use std::sync::Arc; + + use arrow::buffer::Buffer; + + #[test] + fn test_schema_equality() { + let json = r#" + { + "fields": [ + { + "name": "c1", + "type": {"name": "int", "isSigned": true, "bitWidth": 32}, + "nullable": true, + "children": [] + }, + { + "name": "c2", + "type": {"name": "floatingpoint", "precision": "DOUBLE"}, + "nullable": true, + "children": [] + }, + { + "name": "c3", + "type": {"name": "utf8"}, + "nullable": true, + "children": [] + }, + { + "name": "c4", + "type": { + "name": "list" + }, + "nullable": true, + "children": [ + { + "name": "custom_item", + "type": { + "name": "int", + "isSigned": true, + "bitWidth": 32 + }, + "nullable": false, + "children": [] + } + ] + } + ] + }"#; + let json_schema: ArrowJsonSchema = serde_json::from_str(json).unwrap(); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::Utf8, true), + Field::new( + "c4", + DataType::List(Box::new(Field::new( + "custom_item", + DataType::Int32, + false, + ))), + true, + ), + ]); + assert!(json_schema.equals_schema(&schema)); + } + + #[test] + fn test_arrow_data_equality() { + let secs_tz = Some("Europe/Budapest".to_string()); + let millis_tz = Some("America/New_York".to_string()); + let micros_tz = Some("UTC".to_string()); + let nanos_tz = Some("Africa/Johannesburg".to_string()); + + let schema = + Schema::new(vec![ + Field::new("bools-with-metadata-map", DataType::Boolean, true) + .with_metadata(Some( + [("k".to_string(), "v".to_string())] + .iter() + .cloned() + .collect(), + )), + Field::new("bools-with-metadata-vec", DataType::Boolean, true) + .with_metadata(Some( + [("k2".to_string(), "v2".to_string())] + .iter() + .cloned() + .collect(), + )), + Field::new("bools", DataType::Boolean, true), + Field::new("int8s", DataType::Int8, true), + Field::new("int16s", DataType::Int16, true), + Field::new("int32s", DataType::Int32, true), + Field::new("int64s", DataType::Int64, true), + Field::new("uint8s", DataType::UInt8, true), + Field::new("uint16s", DataType::UInt16, true), + Field::new("uint32s", DataType::UInt32, true), + Field::new("uint64s", DataType::UInt64, true), + Field::new("float32s", DataType::Float32, true), + Field::new("float64s", DataType::Float64, true), + Field::new("date_days", DataType::Date32, true), + Field::new("date_millis", DataType::Date64, true), + Field::new("time_secs", DataType::Time32(TimeUnit::Second), true), + Field::new("time_millis", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micros", DataType::Time64(TimeUnit::Microsecond), true), + Field::new("time_nanos", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("ts_secs", DataType::Timestamp(TimeUnit::Second, None), true), + Field::new( + "ts_millis", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micros", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "ts_nanos", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new( + "ts_secs_tz", + DataType::Timestamp(TimeUnit::Second, secs_tz.clone()), + true, + ), + Field::new( + "ts_millis_tz", + DataType::Timestamp(TimeUnit::Millisecond, millis_tz.clone()), + true, + ), + Field::new( + "ts_micros_tz", + DataType::Timestamp(TimeUnit::Microsecond, micros_tz.clone()), + true, + ), + Field::new( + "ts_nanos_tz", 
+ DataType::Timestamp(TimeUnit::Nanosecond, nanos_tz.clone()), + true, + ), + Field::new("utf8s", DataType::Utf8, true), + Field::new( + "lists", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "structs", + DataType::Struct(vec![ + Field::new("int32s", DataType::Int32, true), + Field::new("utf8s", DataType::Utf8, true), + ]), + true, + ), + ]); + + let bools_with_metadata_map = + BooleanArray::from(vec![Some(true), None, Some(false)]); + let bools_with_metadata_vec = + BooleanArray::from(vec![Some(true), None, Some(false)]); + let bools = BooleanArray::from(vec![Some(true), None, Some(false)]); + let int8s = Int8Array::from(vec![Some(1), None, Some(3)]); + let int16s = Int16Array::from(vec![Some(1), None, Some(3)]); + let int32s = Int32Array::from(vec![Some(1), None, Some(3)]); + let int64s = Int64Array::from(vec![Some(1), None, Some(3)]); + let uint8s = UInt8Array::from(vec![Some(1), None, Some(3)]); + let uint16s = UInt16Array::from(vec![Some(1), None, Some(3)]); + let uint32s = UInt32Array::from(vec![Some(1), None, Some(3)]); + let uint64s = UInt64Array::from(vec![Some(1), None, Some(3)]); + let float32s = Float32Array::from(vec![Some(1.0), None, Some(3.0)]); + let float64s = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); + let date_days = Date32Array::from(vec![Some(1196848), None, None]); + let date_millis = Date64Array::from(vec![ + Some(167903550396207), + Some(29923997007884), + Some(30612271819236), + ]); + let time_secs = + Time32SecondArray::from(vec![Some(27974), Some(78592), Some(43207)]); + let time_millis = Time32MillisecondArray::from(vec![ + Some(6613125), + Some(74667230), + Some(52260079), + ]); + let time_micros = + Time64MicrosecondArray::from(vec![Some(62522958593), None, None]); + let time_nanos = Time64NanosecondArray::from(vec![ + Some(73380123595985), + None, + Some(16584393546415), + ]); + let ts_secs = TimestampSecondArray::from_opt_vec( + vec![None, Some(193438817552), None], + None, + ); + let ts_millis = TimestampMillisecondArray::from_opt_vec( + vec![None, Some(38606916383008), Some(58113709376587)], + None, + ); + let ts_micros = + TimestampMicrosecondArray::from_opt_vec(vec![None, None, None], None); + let ts_nanos = TimestampNanosecondArray::from_opt_vec( + vec![None, None, Some(-6473623571954960143)], + None, + ); + let ts_secs_tz = TimestampSecondArray::from_opt_vec( + vec![None, Some(193438817552), None], + secs_tz, + ); + let ts_millis_tz = TimestampMillisecondArray::from_opt_vec( + vec![None, Some(38606916383008), Some(58113709376587)], + millis_tz, + ); + let ts_micros_tz = + TimestampMicrosecondArray::from_opt_vec(vec![None, None, None], micros_tz); + let ts_nanos_tz = TimestampNanosecondArray::from_opt_vec( + vec![None, None, Some(-6473623571954960143)], + nanos_tz, + ); + let utf8s = StringArray::from(vec![Some("aa"), None, Some("bbb")]); + + let value_data = Int32Array::from(vec![None, Some(2), None, None]); + let value_offsets = Buffer::from_slice_ref(&[0, 3, 4, 4]); + let list_data_type = + DataType::List(Box::new(Field::new("item", DataType::Int32, true))); + let list_data = ArrayData::builder(list_data_type) + .len(3) + .add_buffer(value_offsets) + .add_child_data(value_data.into_data()) + .null_bit_buffer(Some(Buffer::from([0b00000011]))) + .build() + .unwrap(); + let lists = ListArray::from(list_data); + + let structs_int32s = Int32Array::from(vec![None, Some(-2), None]); + let structs_utf8s = StringArray::from(vec![None, None, Some("aaaaaa")]); + let struct_data_type = 
DataType::Struct(vec![ + Field::new("int32s", DataType::Int32, true), + Field::new("utf8s", DataType::Utf8, true), + ]); + let struct_data = ArrayData::builder(struct_data_type) + .len(3) + .add_child_data(structs_int32s.data().clone()) + .add_child_data(structs_utf8s.data().clone()) + .null_bit_buffer(Some(Buffer::from([0b00000011]))) + .build() + .unwrap(); + let structs = StructArray::from(struct_data); + + let record_batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(bools_with_metadata_map), + Arc::new(bools_with_metadata_vec), + Arc::new(bools), + Arc::new(int8s), + Arc::new(int16s), + Arc::new(int32s), + Arc::new(int64s), + Arc::new(uint8s), + Arc::new(uint16s), + Arc::new(uint32s), + Arc::new(uint64s), + Arc::new(float32s), + Arc::new(float64s), + Arc::new(date_days), + Arc::new(date_millis), + Arc::new(time_secs), + Arc::new(time_millis), + Arc::new(time_micros), + Arc::new(time_nanos), + Arc::new(ts_secs), + Arc::new(ts_millis), + Arc::new(ts_micros), + Arc::new(ts_nanos), + Arc::new(ts_secs_tz), + Arc::new(ts_millis_tz), + Arc::new(ts_micros_tz), + Arc::new(ts_nanos_tz), + Arc::new(utf8s), + Arc::new(lists), + Arc::new(structs), + ], + ) + .unwrap(); + let mut file = File::open("data/integration.json").unwrap(); + let mut json = String::new(); + file.read_to_string(&mut json).unwrap(); + let arrow_json: ArrowJson = serde_json::from_str(&json).unwrap(); + // test schemas + assert!(arrow_json.schema.equals_schema(&schema)); + // test record batch + assert_eq!(arrow_json.get_record_batches().unwrap()[0], record_batch); + } +} diff --git a/integration-testing/tests/ipc_reader.rs b/integration-testing/tests/ipc_reader.rs new file mode 100644 index 000000000000..778d1ee77d3f --- /dev/null +++ b/integration-testing/tests/ipc_reader.rs @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
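The tests in this new file all follow the same pattern; roughly (a condensed sketch, assuming the arrow-testing data layout resolved by arrow_test_data(); the helper name is invented):

    use arrow::ipc::reader::StreamReader;
    use arrow::util::test_util::arrow_test_data;
    use arrow_integration_testing::read_gzip_json;
    use std::fs::File;

    fn check_generated_stream(version: &str, path: &str) {
        // Open one generated .stream file from the arrow-testing data.
        let testdata = arrow_test_data();
        let file = File::open(format!(
            "{}/arrow-ipc-stream/integration/{}/{}.stream",
            testdata, version, path
        ))
        .unwrap();
        let mut reader = StreamReader::try_new(file, None).unwrap();

        // Compare every decoded batch against the expected gzipped JSON.
        let arrow_json = read_gzip_json(version, path);
        assert!(arrow_json.equals_reader(&mut reader).unwrap());

        // The stream should then be exhausted and marked finished.
        assert!(reader.next().is_none());
        assert!(reader.is_finished());
    }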
+ +use arrow::ipc::reader::{FileReader, StreamReader}; +use arrow::util::test_util::arrow_test_data; +use arrow_integration_testing::read_gzip_json; +use std::fs::File; + +#[test] +fn read_generated_files_014() { + let testdata = arrow_test_data(); + let version = "0.14.1"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + "generated_nested", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + "generated_decimal", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", + testdata, version, path + )) + .unwrap(); + + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} + +#[test] +#[should_panic(expected = "Big Endian is not supported for Decimal!")] +fn read_decimal_be_file_should_panic() { + let testdata = arrow_test_data(); + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/1.0.0-bigendian/generated_decimal.arrow_file", + testdata + )) + .unwrap(); + FileReader::try_new(file, None).unwrap(); +} + +#[test] +#[should_panic( + expected = "Last offset 687865856 of Utf8 is larger than values length 41" +)] +fn read_dictionary_be_not_implemented() { + // The offsets are not translated for big-endian files + // https://github.com/apache/arrow-rs/issues/859 + let testdata = arrow_test_data(); + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/1.0.0-bigendian/generated_dictionary.arrow_file", + testdata + )) + .unwrap(); + FileReader::try_new(file, None).unwrap(); +} + +#[test] +fn read_generated_be_files_should_work() { + // complementary to the previous test + let testdata = arrow_test_data(); + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_map", + "generated_nested", + "generated_null_trivial", + "generated_null", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/1.0.0-bigendian/{}.arrow_file", + testdata, path + )) + .unwrap(); + + FileReader::try_new(file, None).unwrap(); + }); +} + +#[test] +fn projection_should_work() { + // complementary to the previous test + let testdata = arrow_test_data(); + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_map", + "generated_nested", + "generated_null_trivial", + "generated_null", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + ]; + paths.iter().for_each(|path| { + // We must use littleendian files here. 
+ // The offsets are not translated for big-endian files + // https://github.com/apache/arrow-rs/issues/859 + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/1.0.0-littleendian/{}.arrow_file", + testdata, path + )) + .unwrap(); + + let reader = FileReader::try_new(file, Some(vec![0])).unwrap(); + let datatype_0 = reader.schema().fields()[0].data_type().clone(); + reader.for_each(|batch| { + let batch = batch.unwrap(); + assert_eq!(batch.columns().len(), 1); + assert_eq!(datatype_0, batch.schema().fields()[0].data_type().clone()); + }); + }); +} + +#[test] +fn read_generated_streams_014() { + let testdata = arrow_test_data(); + let version = "0.14.1"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + "generated_nested", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + "generated_decimal", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.stream", + testdata, version, path + )) + .unwrap(); + + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + // the next batch must be empty + assert!(reader.next().is_none()); + // the stream must indicate that it's finished + assert!(reader.is_finished()); + }); +} + +#[test] +fn read_generated_files_100() { + let testdata = arrow_test_data(); + let version = "1.0.0-littleendian"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + // "generated_map_non_canonical", + "generated_nested", + "generated_null_trivial", + "generated_null", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", + testdata, version, path + )) + .unwrap(); + + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} + +#[test] +fn read_generated_streams_100() { + let testdata = arrow_test_data(); + let version = "1.0.0-littleendian"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + // "generated_map_non_canonical", + "generated_nested", + "generated_null_trivial", + "generated_null", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.stream", + testdata, version, path + )) + .unwrap(); + + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + // the next batch must be empty + assert!(reader.next().is_none()); + // the stream must indicate that it's finished + assert!(reader.is_finished()); + }); +} + +#[test] +fn read_generated_streams_200() { + let testdata = arrow_test_data(); 
+ let version = "2.0.0-compression"; + + // the test is repetitive, thus we can read all supported files at once + let paths = vec!["generated_lz4", "generated_zstd"]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.stream", + testdata, version, path + )) + .unwrap(); + + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + // the next batch must be empty + assert!(reader.next().is_none()); + // the stream must indicate that it's finished + assert!(reader.is_finished()); + }); +} + +#[test] +fn read_generated_files_200() { + let testdata = arrow_test_data(); + let version = "2.0.0-compression"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec!["generated_lz4", "generated_zstd"]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", + testdata, version, path + )) + .unwrap(); + + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} diff --git a/integration-testing/tests/ipc_writer.rs b/integration-testing/tests/ipc_writer.rs new file mode 100644 index 000000000000..0aa17cd05c35 --- /dev/null +++ b/integration-testing/tests/ipc_writer.rs @@ -0,0 +1,314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::ipc; +use arrow::ipc::reader::{FileReader, StreamReader}; +use arrow::ipc::writer::{FileWriter, IpcWriteOptions, StreamWriter}; +use arrow::util::test_util::arrow_test_data; +use arrow_integration_testing::read_gzip_json; +use std::fs::File; +use std::io::Seek; + +#[test] +fn read_and_rewrite_generated_files_014() { + let testdata = arrow_test_data(); + let version = "0.14.1"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + "generated_nested", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + "generated_decimal", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", + testdata, version, path + )) + .unwrap(); + + let mut reader = FileReader::try_new(file, None).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + // read and rewrite the file to a temp location + { + let mut writer = FileWriter::try_new(&mut file, &reader.schema()).unwrap(); + while let Some(Ok(batch)) = reader.next() { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + } + file.rewind().unwrap(); + + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} + +#[test] +fn read_and_rewrite_generated_streams_014() { + let testdata = arrow_test_data(); + let version = "0.14.1"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec![ + "generated_interval", + "generated_datetime", + "generated_dictionary", + "generated_map", + "generated_nested", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + "generated_decimal", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.stream", + testdata, version, path + )) + .unwrap(); + + let reader = StreamReader::try_new(file, None).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + // read and rewrite the stream to a temp location + { + let mut writer = StreamWriter::try_new(&mut file, &reader.schema()).unwrap(); + reader.for_each(|batch| { + writer.write(&batch.unwrap()).unwrap(); + }); + writer.finish().unwrap(); + } + + file.rewind().unwrap(); + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} + +#[test] +fn read_and_rewrite_generated_files_100() { + let testdata = arrow_test_data(); + let version = "1.0.0-littleendian"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec![ + "generated_custom_metadata", + "generated_datetime", + "generated_dictionary_unsigned", + "generated_dictionary", + // "generated_duplicate_fieldnames", + "generated_interval", + "generated_map", + "generated_nested", + // "generated_nested_large_offsets", + "generated_null_trivial", + "generated_null", + "generated_primitive_large_offsets", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + // "generated_recursive_nested", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", + testdata, version, 
path + )) + .unwrap(); + + let mut reader = FileReader::try_new(file, None).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + // read and rewrite the file to a temp location + { + // write IPC version 5 + let options = + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5).unwrap(); + let mut writer = + FileWriter::try_new_with_options(&mut file, &reader.schema(), options) + .unwrap(); + while let Some(Ok(batch)) = reader.next() { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + } + + file.rewind().unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} + +#[test] +fn read_and_rewrite_generated_streams_100() { + let testdata = arrow_test_data(); + let version = "1.0.0-littleendian"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec![ + "generated_custom_metadata", + "generated_datetime", + "generated_dictionary_unsigned", + "generated_dictionary", + // "generated_duplicate_fieldnames", + "generated_interval", + "generated_map", + "generated_nested", + // "generated_nested_large_offsets", + "generated_null_trivial", + "generated_null", + "generated_primitive_large_offsets", + "generated_primitive_no_batches", + "generated_primitive_zerolength", + "generated_primitive", + // "generated_recursive_nested", + ]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.stream", + testdata, version, path + )) + .unwrap(); + + let reader = StreamReader::try_new(file, None).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + // read and rewrite the stream to a temp location + { + let options = + IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5).unwrap(); + let mut writer = + StreamWriter::try_new_with_options(&mut file, &reader.schema(), options) + .unwrap(); + reader.for_each(|batch| { + writer.write(&batch.unwrap()).unwrap(); + }); + writer.finish().unwrap(); + } + + file.rewind().unwrap(); + + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} + +#[test] +fn read_and_rewrite_compression_files_200() { + let testdata = arrow_test_data(); + let version = "2.0.0-compression"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec!["generated_lz4", "generated_zstd"]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.arrow_file", + testdata, version, path + )) + .unwrap(); + + let mut reader = FileReader::try_new(file, None).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + // read and rewrite the file to a temp location + { + // write IPC version 5 + let options = IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::LZ4_FRAME)) + .unwrap(); + + let mut writer = + FileWriter::try_new_with_options(&mut file, &reader.schema(), options) + .unwrap(); + while let Some(Ok(batch)) = reader.next() { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + } + + file.rewind().unwrap(); + let mut reader = FileReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + 
assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} + +#[test] +fn read_and_rewrite_compression_stream_200() { + let testdata = arrow_test_data(); + let version = "2.0.0-compression"; + // the test is repetitive, thus we can read all supported files at once + let paths = vec!["generated_lz4", "generated_zstd"]; + paths.iter().for_each(|path| { + let file = File::open(format!( + "{}/arrow-ipc-stream/integration/{}/{}.stream", + testdata, version, path + )) + .unwrap(); + + let reader = StreamReader::try_new(file, None).unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + // read and rewrite the stream to a temp location + { + let options = IpcWriteOptions::try_new(8, false, ipc::MetadataVersion::V5) + .unwrap() + .try_with_compression(Some(ipc::CompressionType::ZSTD)) + .unwrap(); + + let mut writer = + StreamWriter::try_new_with_options(&mut file, &reader.schema(), options) + .unwrap(); + reader.for_each(|batch| { + writer.write(&batch.unwrap()).unwrap(); + }); + writer.finish().unwrap(); + } + + file.rewind().unwrap(); + + let mut reader = StreamReader::try_new(file, None).unwrap(); + + // read expected JSON output + let arrow_json = read_gzip_json(version, path); + assert!(arrow_json.equals_reader(&mut reader).unwrap()); + }); +} diff --git a/object_store/.github_changelog_generator b/object_store/.github_changelog_generator new file mode 100644 index 000000000000..cbd8aa0c4b48 --- /dev/null +++ b/object_store/.github_changelog_generator @@ -0,0 +1,27 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +# Add special sections for documentation, security and performance +add-sections={"documentation":{"prefix":"**Documentation updates:**","labels":["documentation"]},"security":{"prefix":"**Security updates:**","labels":["security"]},"performance":{"prefix":"**Performance improvements:**","labels":["performance"]}} +# so that the component is shown associated with the issue +issue-line-labels=object-store +# skip non object_store issues +exclude-labels=development-process,invalid,arrow,parquet,arrow-flight +breaking_labels=api-change diff --git a/object_store/CHANGELOG.md b/object_store/CHANGELOG.md new file mode 100644 index 000000000000..93faa678ffa8 --- /dev/null +++ b/object_store/CHANGELOG.md @@ -0,0 +1,70 @@ + + +# Changelog + +## [object_store_0.4.0](https://github.com/apache/arrow-rs/tree/object_store_0.4.0) (2022-08-10) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/object_store_0.3.0...object_store_0.4.0) + +**Implemented enhancements:** + +- Relax Path Validation to Allow Any Percent-Encoded Sequence [\#2355](https://github.com/apache/arrow-rs/issues/2355) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Support get\_multi\_ranges in ObjectStore [\#2293](https://github.com/apache/arrow-rs/issues/2293) +- object\_store: Create explicit test for symlinks [\#2206](https://github.com/apache/arrow-rs/issues/2206) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Make builder style configuration for object stores [\#2203](https://github.com/apache/arrow-rs/issues/2203) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- object\_store: Add example in the main documentation readme [\#2202](https://github.com/apache/arrow-rs/issues/2202) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Fixed bugs:** + +- Azure/S3 Storage Fails to Copy Blob with URL-encoded Path [\#2353](https://github.com/apache/arrow-rs/issues/2353) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] +- Accessing a file with a percent-encoded name on the filesystem with ObjectStore LocalFileSystem [\#2349](https://github.com/apache/arrow-rs/issues/2349) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] + +**Documentation updates:** + +- Improve `object_store crate` documentation [\#2260](https://github.com/apache/arrow-rs/pull/2260) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) + +**Merged pull requests:** + +- Canonicalize filesystem paths in user-facing APIs \(\#2370\) [\#2371](https://github.com/apache/arrow-rs/pull/2371) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix object\_store lint [\#2367](https://github.com/apache/arrow-rs/pull/2367) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Relax path validation \(\#2355\) [\#2356](https://github.com/apache/arrow-rs/pull/2356) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Fix Copy from percent-encoded path \(\#2353\) [\#2354](https://github.com/apache/arrow-rs/pull/2354) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add ObjectStore::get\_ranges \(\#2293\) [\#2336](https://github.com/apache/arrow-rs/pull/2336) 
[[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Remove vestigal ` object_store/.circleci/` [\#2337](https://github.com/apache/arrow-rs/pull/2337) ([alamb](https://github.com/alamb)) +- Handle symlinks in LocalFileSystem \(\#2206\) [\#2269](https://github.com/apache/arrow-rs/pull/2269) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Retry GCP requests on server error [\#2243](https://github.com/apache/arrow-rs/pull/2243) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Add LimitStore \(\#2175\) [\#2242](https://github.com/apache/arrow-rs/pull/2242) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([tustvold](https://github.com/tustvold)) +- Only trigger `arrow` CI on changes to arrow [\#2227](https://github.com/apache/arrow-rs/pull/2227) ([alamb](https://github.com/alamb)) +- Update instructions on how to join the Slack channel [\#2219](https://github.com/apache/arrow-rs/pull/2219) ([HaoYang670](https://github.com/HaoYang670)) +- Add Builder style config objects for object\_store [\#2204](https://github.com/apache/arrow-rs/pull/2204) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Ignore broken symlinks for LocalFileSystem object store [\#2195](https://github.com/apache/arrow-rs/pull/2195) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([jccampagne](https://github.com/jccampagne)) +- Change CI names to match crate names [\#2189](https://github.com/apache/arrow-rs/pull/2189) ([alamb](https://github.com/alamb)) +- Split most arrow specific CI checks into their own workflows \(reduce common CI time to 21 minutes\) [\#2168](https://github.com/apache/arrow-rs/pull/2168) ([alamb](https://github.com/alamb)) +- Remove another attempt to cache target directory in action.yaml [\#2167](https://github.com/apache/arrow-rs/pull/2167) ([alamb](https://github.com/alamb)) +- Run actions on push to master, pull requests [\#2166](https://github.com/apache/arrow-rs/pull/2166) ([alamb](https://github.com/alamb)) +- Break parquet\_derive and arrow\_flight tests into their own workflows [\#2165](https://github.com/apache/arrow-rs/pull/2165) ([alamb](https://github.com/alamb)) +- Only run integration tests when `arrow` changes [\#2152](https://github.com/apache/arrow-rs/pull/2152) ([alamb](https://github.com/alamb)) +- Break out docs CI job to its own github action [\#2151](https://github.com/apache/arrow-rs/pull/2151) ([alamb](https://github.com/alamb)) +- Do not pretend to cache rust build artifacts, speed up CI by ~20% [\#2150](https://github.com/apache/arrow-rs/pull/2150) ([alamb](https://github.com/alamb)) +- Port `object_store` integration tests, use github actions [\#2148](https://github.com/apache/arrow-rs/pull/2148) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Port Add stream upload \(multi-part upload\) [\#2147](https://github.com/apache/arrow-rs/pull/2147) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([alamb](https://github.com/alamb)) +- Increase upper wait time to reduce flakyness of object store test [\#2142](https://github.com/apache/arrow-rs/pull/2142) [[object-store](https://github.com/apache/arrow-rs/labels/object-store)] ([viirya](https://github.com/viirya)) + +\* *This Changelog was 
automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* diff --git a/object_store/CONTRIBUTING.md b/object_store/CONTRIBUTING.md new file mode 100644 index 000000000000..7c2832cf7ef1 --- /dev/null +++ b/object_store/CONTRIBUTING.md @@ -0,0 +1,113 @@ + + +# Development instructions + +## Running Tests + +Tests can be run using `cargo` + +```shell +cargo test +``` + +## Running Integration Tests + +By default, integration tests are not run. To run them you will need to set `TEST_INTEGRATION=1` and then provide the +necessary configuration for that object store + +### AWS + +To test the S3 integration against [localstack](https://localstack.cloud/) + +First start up a container running localstack + +``` +$ podman run --rm -it -p 4566:4566 -p 4510-4559:4510-4559 localstack/localstack +``` + +Setup environment + +``` +export TEST_INTEGRATION=1 +export AWS_DEFAULT_REGION=us-east-1 +export AWS_ACCESS_KEY_ID=test +export AWS_SECRET_ACCESS_KEY=test +export AWS_ENDPOINT=http://127.0.0.1:4566 +export OBJECT_STORE_BUCKET=test-bucket +``` + +Create a bucket using the AWS CLI + +``` +podman run --net=host --env-host amazon/aws-cli --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket +``` + +Run tests + +``` +$ cargo test --features aws +``` + +### Azure + +To test the Azure integration +against [azurite](https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio) + +Startup azurite + +``` +$ podman run -p 10000:10000 -p 10001:10001 -p 10002:10002 mcr.microsoft.com/azure-storage/azurite +``` + +Create a bucket + +``` +$ podman run --net=host mcr.microsoft.com/azure-cli az storage container create -n test-bucket --connection-string 'DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;QueueEndpoint=http://127.0.0.1:10001/devstoreaccount1;' +``` + +Run tests + +``` +$ cargo test --features azure +``` + +### GCP + +To test the GCS integration, we use [Fake GCS Server](https://github.com/fsouza/fake-gcs-server) + +Startup the fake server: + +```shell +docker run -p 4443:4443 fsouza/fake-gcs-server +``` + +Configure the account: +```shell +curl --insecure -v -X POST --data-binary '{"name":"test-bucket"}' -H "Content-Type: application/json" "https://localhost:4443/storage/v1/b" +echo '{"gcs_base_url": "https://localhost:4443", "disable_oauth": true, "client_email": "", "private_key": ""}' > /tmp/gcs.json +``` + +Now run the tests: +```shell +TEST_INTEGRATION=1 \ +OBJECT_STORE_BUCKET=test-bucket \ +GOOGLE_SERVICE_ACCOUNT=/tmp/gcs.json \ +cargo test -p object_store --features=gcp +``` diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml new file mode 100644 index 000000000000..b0201e2af983 --- /dev/null +++ b/object_store/Cargo.toml @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "object_store" +version = "0.4.0" +edition = "2021" +license = "MIT/Apache-2.0" +readme = "README.md" +description = "A generic object store interface for uniformly interacting with AWS S3, Google Cloud Storage, Azure Blob Storage and local files." +keywords = ["object", "storage", "cloud"] +repository = "https://github.com/apache/arrow-rs" + +[package.metadata.docs.rs] +all-features = true + +[dependencies] # In alphabetical order +async-trait = "0.1.53" +bytes = "1.0" +chrono = { version = "0.4", default-features = false, features = ["clock"] } +futures = "0.3" +itertools = "0.10.1" +parking_lot = { version = "0.12" } +percent-encoding = "2.1" +snafu = "0.7" +tokio = { version = "1.18", features = ["sync", "macros", "parking_lot", "rt-multi-thread", "time", "io-util"] } +tracing = { version = "0.1" } +url = "2.2" +walkdir = "2" + +# Cloud storage support +base64 = { version = "0.13", default-features = false, optional = true } +quick-xml = { version = "0.24.0", features = ["serialize"], optional = true } +serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } +serde_json = { version = "1.0", default-features = false, optional = true } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true } +reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"], optional = true } +ring = { version = "0.16", default-features = false, features = ["std"], optional = true } +rustls-pemfile = { version = "1.0", default-features = false, optional = true } + +[features] +cloud = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "base64", "rand", "ring"] +azure = ["cloud"] +gcp = ["cloud", "rustls-pemfile"] +aws = ["cloud"] + +[dev-dependencies] # In alphabetical order +dotenv = "0.15.0" +tempfile = "3.1.0" +futures-test = "0.3" +rand = "0.8" +hyper = { version = "0.14", features = ["server"] } diff --git a/object_store/README.md b/object_store/README.md new file mode 100644 index 000000000000..fd10414a9285 --- /dev/null +++ b/object_store/README.md @@ -0,0 +1,39 @@ + + +# Rust Object Store + +A focused, easy to use, idiomatic, high performance, `async` object +store library interacting with object stores. + +Using this crate, the same binary and code can easily run in multiple +clouds and local test environments, via a simple runtime configuration +change. Supported object stores include: + +* [AWS S3](https://aws.amazon.com/s3/) +* [Azure Blob Storage](https://azure.microsoft.com/en-us/services/storage/blobs/) +* [Google Cloud Storage](https://cloud.google.com/storage) +* Local files +* Memory +* Custom implementations + + +Originally developed for [InfluxDB IOx](https://github.com/influxdata/influxdb_iox/) and later split out and donated to [Apache Arrow](https://arrow.apache.org/). 
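A minimal usage sketch is below. It assumes the async `ObjectStore` trait methods (`put`, `get`) and the `InMemory` backend as documented on docs.rs; treat it as illustrative rather than normative.

```rust
use bytes::Bytes;
use object_store::{memory::InMemory, path::Path, ObjectStore};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Swap InMemory for an S3/Azure/GCS/local store via configuration;
    // the rest of the code stays the same.
    let store: Box<dyn ObjectStore> = Box::new(InMemory::new());

    // Write an object, then read it back.
    let location = Path::from("data/example.txt");
    store
        .put(&location, Bytes::from("hello object store"))
        .await?;

    let bytes = store.get(&location).await?.bytes().await?;
    assert_eq!(bytes.as_ref(), b"hello object store");
    Ok(())
}
```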
+ +See [docs.rs](https://docs.rs/object_store) for usage instructions diff --git a/object_store/deny.toml b/object_store/deny.toml new file mode 100644 index 000000000000..bfd060a0b94d --- /dev/null +++ b/object_store/deny.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Configuration documentation: +#  https://embarkstudios.github.io/cargo-deny/index.html + +[advisories] +vulnerability = "deny" +yanked = "deny" +unmaintained = "warn" +notice = "warn" +ignore = [ +] +git-fetch-with-cli = true + +[licenses] +default = "allow" +unlicensed = "allow" +copyleft = "allow" + +[bans] +multiple-versions = "warn" +deny = [ + # We are using rustls as the TLS implementation, so we shouldn't be linking + # in OpenSSL too. + # + # If you're hitting this, you might want to take a look at what new + # dependencies you have introduced and check if there's a way to depend on + # rustls instead of OpenSSL (tip: check the crate's feature flags). + { name = "openssl-sys" } +] diff --git a/object_store/dev/release/README.md b/object_store/dev/release/README.md new file mode 100644 index 000000000000..89f6e579b23d --- /dev/null +++ b/object_store/dev/release/README.md @@ -0,0 +1,20 @@ + + +See instructions in [`/dev/release/README.md`](../../../dev/release/README.md) diff --git a/object_store/dev/release/create-tarball.sh b/object_store/dev/release/create-tarball.sh new file mode 100755 index 000000000000..bbffde89b043 --- /dev/null +++ b/object_store/dev/release/create-tarball.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This script creates a signed tarball in +# dev/dist/apache-arrow-object-store-rs--.tar.gz and uploads it to +# the "dev" area of the dist.apache.arrow repository and prepares an +# email for sending to the dev@arrow.apache.org list for a formal +# vote. +# +# Note the tags are expected to be `object_sore_` +# +# See release/README.md for full release instructions +# +# Requirements: +# +# 1. 
gpg setup for signing and have uploaded your public +# signature to https://pgp.mit.edu/ +# +# 2. Logged into the apache svn server with the appropriate +# credentials +# +# +# Based in part on 02-source.sh from apache/arrow +# + +set -e + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "ex. $0 0.4.0 1" + exit +fi + +object_store_version=$1 +rc=$2 + +tag=object_store_${object_store_version} + +release=apache-arrow-object-store-rs-${object_store_version} +distdir=${SOURCE_TOP_DIR}/dev/dist/${release}-rc${rc} +tarname=${release}.tar.gz +tarball=${distdir}/${tarname} +url="https://dist.apache.org/repos/dist/dev/arrow/${release}-rc${rc}" + +echo "Attempting to create ${tarball} from tag ${tag}" + +release_hash=$(cd "${SOURCE_TOP_DIR}" && git rev-list --max-count=1 ${tag}) + +if [ -z "$release_hash" ]; then + echo "Cannot continue: unknown git tag: $tag" +fi + +echo "Draft email for dev@arrow.apache.org mailing list" +echo "" +echo "---------------------------------------------------------" +cat < containing the files in git at $release_hash +# the files in the tarball are prefixed with {object_store_version=} (e.g. 0.4.0) +mkdir -p ${distdir} +(cd "${SOURCE_TOP_DIR}" && git archive ${release_hash} --prefix ${release}/ | gzip > ${tarball}) + +echo "Running rat license checker on ${tarball}" +${SOURCE_DIR}/../../../dev/release/run-rat.sh ${tarball} + +echo "Signing tarball and creating checksums" +gpg --armor --output ${tarball}.asc --detach-sig ${tarball} +# create signing with relative path of tarball +# so that they can be verified with a command such as +# shasum --check apache-arrow-rs-4.1.0-rc2.tar.gz.sha512 +(cd ${distdir} && shasum -a 256 ${tarname}) > ${tarball}.sha256 +(cd ${distdir} && shasum -a 512 ${tarname}) > ${tarball}.sha512 + +echo "Uploading to apache dist/dev to ${url}" +svn co --depth=empty https://dist.apache.org/repos/dist/dev/arrow ${SOURCE_TOP_DIR}/dev/dist +svn add ${distdir} +svn ci -m "Apache Arrow Rust ${object_store_version=} ${rc}" ${distdir} diff --git a/object_store/dev/release/release-tarball.sh b/object_store/dev/release/release-tarball.sh new file mode 100755 index 000000000000..75ff886c6b1e --- /dev/null +++ b/object_store/dev/release/release-tarball.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This script copies a tarball from the "dev" area of the +# dist.apache.arrow repository to the "release" area +# +# This script should only be run after the release has been approved +# by the arrow PMC committee. 
+# +# See release/README.md for full release instructions +# +# Based in part on post-01-upload.sh from apache/arrow + + +set -e +set -u + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "ex. $0 0.4.0 1" + exit +fi + +version=$1 +rc=$2 + +tmp_dir=tmp-apache-arrow-dist + +echo "Recreate temporary directory: ${tmp_dir}" +rm -rf ${tmp_dir} +mkdir -p ${tmp_dir} + +echo "Clone dev dist repository" +svn \ + co \ + https://dist.apache.org/repos/dist/dev/arrow/apache-arrow-object-store-rs-${version}-rc${rc} \ + ${tmp_dir}/dev + +echo "Clone release dist repository" +svn co https://dist.apache.org/repos/dist/release/arrow ${tmp_dir}/release + +echo "Copy ${version}-rc${rc} to release working copy" +release_version=arrow-object-store-rs-${version} +mkdir -p ${tmp_dir}/release/${release_version} +cp -r ${tmp_dir}/dev/* ${tmp_dir}/release/${release_version}/ +svn add ${tmp_dir}/release/${release_version} + +echo "Commit release" +svn ci -m "Apache Arrow Rust Object Store ${version}" ${tmp_dir}/release + +echo "Clean up" +rm -rf ${tmp_dir} + +echo "Success!" +echo "The release is available here:" +echo " https://dist.apache.org/repos/dist/release/arrow/${release_version}" diff --git a/object_store/dev/release/update_change_log.sh b/object_store/dev/release/update_change_log.sh new file mode 100755 index 000000000000..ebd50df7ffc0 --- /dev/null +++ b/object_store/dev/release/update_change_log.sh @@ -0,0 +1,79 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +# invokes the changelog generator from +# https://github.com/github-changelog-generator/github-changelog-generator +# +# With the config located in +# arrow-rs/object_store/.github_changelog_generator +# +# Usage: +# CHANGELOG_GITHUB_TOKEN= ./update_change_log.sh + +set -e + +SINCE_TAG="object_store_0.3.0" +FUTURE_RELEASE="object_store_0.4.0" + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" + +OUTPUT_PATH="${SOURCE_TOP_DIR}/CHANGELOG.md" + +# remove license header so github-changelog-generator has a clean base to append +sed -i.bak '1,18d' "${OUTPUT_PATH}" + +# use exclude-tags-regex to filter out tags used for arrow +# crates and only look at tags that begin with `object_store_` +pushd "${SOURCE_TOP_DIR}" +docker run -it --rm -e CHANGELOG_GITHUB_TOKEN="$CHANGELOG_GITHUB_TOKEN" -v "$(pwd)":/usr/local/src/your-app githubchangeloggenerator/github-changelog-generator \ + --user apache \ + --project arrow-rs \ + --cache-file=.githubchangeloggenerator.cache \ + --cache-log=.githubchangeloggenerator.cache.log \ + --http-cache \ + --max-issues=300 \ + --exclude-tags-regex "^\d+\.\d+\.\d+$" \ + --since-tag ${SINCE_TAG} \ + --future-release ${FUTURE_RELEASE} + +sed -i.bak "s/\\\n/\n\n/" "${OUTPUT_PATH}" + +# Put license header back on +echo ' +' | cat - "${OUTPUT_PATH}" > "${OUTPUT_PATH}".tmp +mv "${OUTPUT_PATH}".tmp "${OUTPUT_PATH}" diff --git a/object_store/dev/release/verify-release-candidate.sh b/object_store/dev/release/verify-release-candidate.sh new file mode 100755 index 000000000000..06a5d8bcb838 --- /dev/null +++ b/object_store/dev/release/verify-release-candidate.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +case $# in + 2) VERSION="$1" + RC_NUMBER="$2" + ;; + *) echo "Usage: $0 X.Y.Z RC_NUMBER" + exit 1 + ;; +esac + +set -e +set -x +set -o pipefail + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" +ARROW_DIR="$(dirname $(dirname ${SOURCE_DIR}))" +ARROW_DIST_URL='https://dist.apache.org/repos/dist/dev/arrow' + +download_dist_file() { + curl \ + --silent \ + --show-error \ + --fail \ + --location \ + --remote-name $ARROW_DIST_URL/$1 +} + +download_rc_file() { + download_dist_file apache-arrow-object-store-rs-${VERSION}-rc${RC_NUMBER}/$1 +} + +import_gpg_keys() { + download_dist_file KEYS + gpg --import KEYS +} + +if type shasum >/dev/null 2>&1; then + sha256_verify="shasum -a 256 -c" + sha512_verify="shasum -a 512 -c" +else + sha256_verify="sha256sum -c" + sha512_verify="sha512sum -c" +fi + +fetch_archive() { + local dist_name=$1 + download_rc_file ${dist_name}.tar.gz + download_rc_file ${dist_name}.tar.gz.asc + download_rc_file ${dist_name}.tar.gz.sha256 + download_rc_file ${dist_name}.tar.gz.sha512 + gpg --verify ${dist_name}.tar.gz.asc ${dist_name}.tar.gz + ${sha256_verify} ${dist_name}.tar.gz.sha256 + ${sha512_verify} ${dist_name}.tar.gz.sha512 +} + +setup_tempdir() { + cleanup() { + if [ "${TEST_SUCCESS}" = "yes" ]; then + rm -fr "${ARROW_TMPDIR}" + else + echo "Failed to verify release candidate. See ${ARROW_TMPDIR} for details." + fi + } + + if [ -z "${ARROW_TMPDIR}" ]; then + # clean up automatically if ARROW_TMPDIR is not defined + ARROW_TMPDIR=$(mktemp -d -t "$1.XXXXX") + trap cleanup EXIT + else + # don't clean up automatically + mkdir -p "${ARROW_TMPDIR}" + fi +} + +test_source_distribution() { + # install rust toolchain in a similar fashion like test-miniconda + export RUSTUP_HOME=$PWD/test-rustup + export CARGO_HOME=$PWD/test-rustup + + curl https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path + + export PATH=$RUSTUP_HOME/bin:$PATH + source $RUSTUP_HOME/env + + # build and test rust + cargo build + cargo test --all + + # verify that the crate can be published to crates.io + cargo publish --dry-run +} + +TEST_SUCCESS=no + +setup_tempdir "arrow-${VERSION}" +echo "Working in sandbox ${ARROW_TMPDIR}" +cd ${ARROW_TMPDIR} + +dist_name="apache-arrow-object-store-rs-${VERSION}" +import_gpg_keys +fetch_archive ${dist_name} +tar xf ${dist_name}.tar.gz +pushd ${dist_name} +test_source_distribution +popd + +TEST_SUCCESS=yes +echo 'Release candidate looks good!' +exit 0 diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs new file mode 100644 index 000000000000..d8ab3bba8f20 --- /dev/null +++ b/object_store/src/aws/client.rs @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider}; +use crate::client::pagination::stream_paginated; +use crate::client::retry::RetryExt; +use crate::multipart::UploadPart; +use crate::path::DELIMITER; +use crate::util::{format_http_range, format_prefix}; +use crate::{ + BoxStream, ListResult, MultipartId, ObjectMeta, Path, Result, RetryConfig, StreamExt, +}; +use bytes::{Buf, Bytes}; +use chrono::{DateTime, Utc}; +use percent_encoding::{utf8_percent_encode, AsciiSet, PercentEncode, NON_ALPHANUMERIC}; +use reqwest::{Client as ReqwestClient, Method, Response, StatusCode}; +use serde::{Deserialize, Serialize}; +use snafu::{ResultExt, Snafu}; +use std::ops::Range; +use std::sync::Arc; + +// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html +// +// Do not URI-encode any of the unreserved characters that RFC 3986 defines: +// A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ). +const STRICT_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC + .remove(b'-') + .remove(b'.') + .remove(b'_') + .remove(b'~'); + +/// This struct is used to maintain the URI path encoding +const STRICT_PATH_ENCODE_SET: AsciiSet = STRICT_ENCODE_SET.remove(b'/'); + +/// A specialized `Error` for object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub(crate) enum Error { + #[snafu(display("Error performing get request {}: {}", path, source))] + GetRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error fetching get response body {}: {}", path, source))] + GetResponseBody { + source: reqwest::Error, + path: String, + }, + + #[snafu(display("Error performing put request {}: {}", path, source))] + PutRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing delete request {}: {}", path, source))] + DeleteRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing copy request {}: {}", path, source))] + CopyRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing list request: {}", source))] + ListRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting list response body: {}", source))] + ListResponseBody { source: reqwest::Error }, + + #[snafu(display("Error performing create multipart request: {}", source))] + CreateMultipartRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting create multipart response body: {}", source))] + CreateMultipartResponseBody { source: reqwest::Error }, + + #[snafu(display("Error performing complete multipart request: {}", source))] + CompleteMultipartRequest { source: crate::client::retry::Error }, + + #[snafu(display("Got invalid list response: {}", source))] + InvalidListResponse { source: quick_xml::de::DeError }, + + #[snafu(display("Got invalid multipart response: {}", source))] + InvalidMultipartResponse { source: quick_xml::de::DeError }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::GetRequest { source, path } + | Error::DeleteRequest { source, path } + | Error::CopyRequest { source, path } + | Error::PutRequest { source, path } + if matches!(source.status(), Some(StatusCode::NOT_FOUND)) => + { + Self::NotFound { + path, + source: Box::new(source), + } + } + _ => Self::Generic { + store: "S3", + source: Box::new(err), + }, + } + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = 
"PascalCase")] +pub struct ListResponse { + #[serde(default)] + pub contents: Vec, + #[serde(default)] + pub common_prefixes: Vec, + #[serde(default)] + pub next_continuation_token: Option, +} + +impl TryFrom for ListResult { + type Error = crate::Error; + + fn try_from(value: ListResponse) -> Result { + let common_prefixes = value + .common_prefixes + .into_iter() + .map(|x| Ok(Path::parse(&x.prefix)?)) + .collect::>()?; + + let objects = value + .contents + .into_iter() + .map(TryFrom::try_from) + .collect::>()?; + + Ok(Self { + common_prefixes, + objects, + }) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct ListPrefix { + pub prefix: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct ListContents { + pub key: String, + pub size: usize, + pub last_modified: DateTime, +} + +impl TryFrom for ObjectMeta { + type Error = crate::Error; + + fn try_from(value: ListContents) -> Result { + Ok(Self { + location: Path::parse(value.key)?, + last_modified: value.last_modified, + size: value.size, + }) + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct InitiateMultipart { + upload_id: String, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "PascalCase", rename = "CompleteMultipartUpload")] +struct CompleteMultipart { + part: Vec, +} + +#[derive(Debug, Serialize)] +struct MultipartPart { + #[serde(rename = "$unflatten=ETag")] + e_tag: String, + #[serde(rename = "$unflatten=PartNumber")] + part_number: usize, +} + +#[derive(Debug)] +pub struct S3Config { + pub region: String, + pub endpoint: String, + pub bucket: String, + pub credentials: CredentialProvider, + pub retry_config: RetryConfig, + pub allow_http: bool, +} + +impl S3Config { + fn path_url(&self, path: &Path) -> String { + format!("{}/{}/{}", self.endpoint, self.bucket, encode_path(path)) + } +} + +#[derive(Debug)] +pub(crate) struct S3Client { + config: S3Config, + client: ReqwestClient, +} + +impl S3Client { + pub fn new(config: S3Config) -> Self { + let client = reqwest::ClientBuilder::new() + .https_only(!config.allow_http) + .build() + .unwrap(); + + Self { config, client } + } + + /// Returns the config + pub fn config(&self) -> &S3Config { + &self.config + } + + async fn get_credential(&self) -> Result> { + self.config.credentials.get_credential().await + } + + /// Make an S3 GET request + pub async fn get_request( + &self, + path: &Path, + range: Option>, + head: bool, + ) -> Result { + use reqwest::header::RANGE; + + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + let method = match head { + true => Method::HEAD, + false => Method::GET, + }; + + let mut builder = self.client.request(method, url); + + if let Some(range) = range { + builder = builder.header(RANGE, format_http_range(range)); + } + + let response = builder + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(GetRequestSnafu { + path: path.as_ref(), + })?; + + Ok(response) + } + + /// Make an S3 PUT request + pub async fn put_request( + &self, + path: &Path, + bytes: Option, + query: &T, + ) -> Result { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + + let mut builder = self.client.request(Method::PUT, url); + if let Some(bytes) = bytes { + builder = builder.body(bytes) + } + + let response = builder + .query(query) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + 
.send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: path.as_ref(), + })?; + + Ok(response) + } + + /// Make an S3 Delete request + pub async fn delete_request( + &self, + path: &Path, + query: &T, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + + self.client + .request(Method::DELETE, url) + .query(query) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(DeleteRequestSnafu { + path: path.as_ref(), + })?; + + Ok(()) + } + + /// Make an S3 Copy request + pub async fn copy_request(&self, from: &Path, to: &Path) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.config.path_url(to); + let source = format!("{}/{}", self.config.bucket, encode_path(from)); + + self.client + .request(Method::PUT, url) + .header("x-amz-copy-source", source) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(CopyRequestSnafu { + path: from.as_ref(), + })?; + + Ok(()) + } + + /// Make an S3 List request + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + token: Option<&str>, + ) -> Result<(ListResult, Option)> { + let credential = self.get_credential().await?; + let url = format!("{}/{}", self.config.endpoint, self.config.bucket); + + let mut query = Vec::with_capacity(4); + + // Note: the order of these matters to ensure the generated URL is canonical + if let Some(token) = token { + query.push(("continuation-token", token)) + } + + if delimiter { + query.push(("delimiter", DELIMITER)) + } + + query.push(("list-type", "2")); + + if let Some(prefix) = prefix { + query.push(("prefix", prefix)) + } + + let response = self + .client + .request(Method::GET, &url) + .query(&query) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(ListRequestSnafu)? + .bytes() + .await + .context(ListResponseBodySnafu)?; + + let mut response: ListResponse = quick_xml::de::from_reader(response.reader()) + .context(InvalidListResponseSnafu)?; + let token = response.next_continuation_token.take(); + + Ok((response.try_into()?, token)) + } + + /// Perform a list operation automatically handling pagination + pub fn list_paginated( + &self, + prefix: Option<&Path>, + delimiter: bool, + ) -> BoxStream<'_, Result> { + let prefix = format_prefix(prefix); + stream_paginated(prefix, move |prefix, token| async move { + let (r, next_token) = self + .list_request(prefix.as_deref(), delimiter, token.as_deref()) + .await?; + Ok((r, prefix, next_token)) + }) + .boxed() + } + + pub async fn create_multipart(&self, location: &Path) -> Result { + let credential = self.get_credential().await?; + let url = format!( + "{}/{}/{}?uploads", + self.config.endpoint, + self.config.bucket, + encode_path(location) + ); + + let response = self + .client + .request(Method::POST, url) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(CreateMultipartRequestSnafu)? 
+ .bytes() + .await + .context(CreateMultipartResponseBodySnafu)?; + + let response: InitiateMultipart = quick_xml::de::from_reader(response.reader()) + .context(InvalidMultipartResponseSnafu)?; + + Ok(response.upload_id) + } + + pub async fn complete_multipart( + &self, + location: &Path, + upload_id: &str, + parts: Vec, + ) -> Result<()> { + let parts = parts + .into_iter() + .enumerate() + .map(|(part_idx, part)| MultipartPart { + e_tag: part.content_id, + part_number: part_idx + 1, + }) + .collect(); + + let request = CompleteMultipart { part: parts }; + let body = quick_xml::se::to_string(&request).unwrap(); + + let credential = self.get_credential().await?; + let url = self.config.path_url(location); + + self.client + .request(Method::POST, url) + .query(&[("uploadId", upload_id)]) + .body(body) + .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3") + .send_retry(&self.config.retry_config) + .await + .context(CompleteMultipartRequestSnafu)?; + + Ok(()) + } +} + +fn encode_path(path: &Path) -> PercentEncode<'_> { + utf8_percent_encode(path.as_ref(), &STRICT_PATH_ENCODE_SET) +} diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs new file mode 100644 index 000000000000..1abf42be9103 --- /dev/null +++ b/object_store/src/aws/credential.rs @@ -0,0 +1,699 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
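`encode_path` above percent-encodes object keys with `STRICT_PATH_ENCODE_SET`, i.e. everything except the RFC 3986 unreserved characters and the `/` separator, so the URLs it builds line up with the SigV4 canonical-request rules referenced at the top of the module. A standalone sketch of that encoding behaviour (the constant name here is illustrative):

```rust
use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};

// Mirrors STRICT_PATH_ENCODE_SET above: keep the RFC 3986 unreserved
// characters (A-Z, a-z, 0-9, '-', '.', '_', '~') and the '/' separator.
const PATH_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC
    .remove(b'-')
    .remove(b'.')
    .remove(b'_')
    .remove(b'~')
    .remove(b'/');

fn main() {
    let key = "data dir/file name+1.parquet";
    let encoded = utf8_percent_encode(key, &PATH_ENCODE_SET).to_string();
    // Spaces and '+' are escaped; '/' and unreserved characters are not.
    assert_eq!(encoded, "data%20dir/file%20name%2B1.parquet");
}
```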
+ +use crate::client::retry::RetryExt; +use crate::client::token::{TemporaryToken, TokenCache}; +use crate::util::hmac_sha256; +use crate::{Result, RetryConfig}; +use bytes::Buf; +use chrono::{DateTime, Utc}; +use futures::TryFutureExt; +use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest::{Client, Method, Request, RequestBuilder, StatusCode}; +use serde::Deserialize; +use std::collections::BTreeMap; +use std::sync::Arc; +use std::time::Instant; +use tracing::warn; + +type StdError = Box; + +/// SHA256 hash of empty string +static EMPTY_SHA256_HASH: &str = + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + +#[derive(Debug)] +pub struct AwsCredential { + pub key_id: String, + pub secret_key: String, + pub token: Option, +} + +impl AwsCredential { + /// Signs a string + /// + /// + fn sign( + &self, + to_sign: &str, + date: DateTime, + region: &str, + service: &str, + ) -> String { + let date_string = date.format("%Y%m%d").to_string(); + let date_hmac = hmac_sha256(format!("AWS4{}", self.secret_key), date_string); + let region_hmac = hmac_sha256(date_hmac, region); + let service_hmac = hmac_sha256(region_hmac, service); + let signing_hmac = hmac_sha256(service_hmac, b"aws4_request"); + hex_encode(hmac_sha256(signing_hmac, to_sign).as_ref()) + } +} + +struct RequestSigner<'a> { + date: DateTime, + credential: &'a AwsCredential, + service: &'a str, + region: &'a str, +} + +const DATE_HEADER: &str = "x-amz-date"; +const HASH_HEADER: &str = "x-amz-content-sha256"; +const TOKEN_HEADER: &str = "x-amz-security-token"; +const AUTH_HEADER: &str = "authorization"; + +const ALL_HEADERS: &[&str; 4] = &[DATE_HEADER, HASH_HEADER, TOKEN_HEADER, AUTH_HEADER]; + +impl<'a> RequestSigner<'a> { + fn sign(&self, request: &mut Request) { + if let Some(ref token) = self.credential.token { + let token_val = HeaderValue::from_str(token).unwrap(); + request.headers_mut().insert(TOKEN_HEADER, token_val); + } + + let host_val = HeaderValue::from_str( + &request.url()[url::Position::BeforeHost..url::Position::AfterPort], + ) + .unwrap(); + request.headers_mut().insert("host", host_val); + + let date_str = self.date.format("%Y%m%dT%H%M%SZ").to_string(); + let date_val = HeaderValue::from_str(&date_str).unwrap(); + request.headers_mut().insert(DATE_HEADER, date_val); + + let digest = match request.body() { + None => EMPTY_SHA256_HASH.to_string(), + Some(body) => hex_digest(body.as_bytes().unwrap()), + }; + + let header_digest = HeaderValue::from_str(&digest).unwrap(); + request.headers_mut().insert(HASH_HEADER, header_digest); + + let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); + + // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + request.method().as_str(), + request.url().path(), // S3 doesn't percent encode this like other services + request.url().query().unwrap_or(""), // This assumes the query pairs are in order + canonical_headers, + signed_headers, + digest + ); + + let hashed_canonical_request = hex_digest(canonical_request.as_bytes()); + let scope = format!( + "{}/{}/{}/aws4_request", + self.date.format("%Y%m%d"), + self.region, + self.service + ); + + let string_to_sign = format!( + "AWS4-HMAC-SHA256\n{}\n{}\n{}", + self.date.format("%Y%m%dT%H%M%SZ"), + scope, + hashed_canonical_request + ); + + // sign the string + let signature = + self.credential + .sign(&string_to_sign, self.date, self.region, self.service); + + // build the actual auth header + 
let authorisation = format!( + "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + self.credential.key_id, scope, signed_headers, signature + ); + + let authorization_val = HeaderValue::from_str(&authorisation).unwrap(); + request.headers_mut().insert(AUTH_HEADER, authorization_val); + } +} + +pub trait CredentialExt { + /// Sign a request + fn with_aws_sigv4( + self, + credential: &AwsCredential, + region: &str, + service: &str, + ) -> Self; +} + +impl CredentialExt for RequestBuilder { + fn with_aws_sigv4( + mut self, + credential: &AwsCredential, + region: &str, + service: &str, + ) -> Self { + // Hack around lack of access to underlying request + // https://github.com/seanmonstar/reqwest/issues/1212 + let mut request = self + .try_clone() + .expect("not stream") + .build() + .expect("request valid"); + + let date = Utc::now(); + let signer = RequestSigner { + date, + credential, + service, + region, + }; + + signer.sign(&mut request); + + for header in ALL_HEADERS { + if let Some(val) = request.headers_mut().remove(*header) { + self = self.header(*header, val) + } + } + self + } +} + +/// Computes the SHA256 digest of `body` returned as a hex encoded string +fn hex_digest(bytes: &[u8]) -> String { + let digest = ring::digest::digest(&ring::digest::SHA256, bytes); + hex_encode(digest.as_ref()) +} + +/// Returns `bytes` as a lower-case hex encoded string +fn hex_encode(bytes: &[u8]) -> String { + use std::fmt::Write; + let mut out = String::with_capacity(bytes.len() * 2); + for byte in bytes { + // String writing is infallible + let _ = write!(out, "{:02x}", byte); + } + out +} + +/// Canonicalizes headers into the AWS Canonical Form. +/// +/// +fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) { + let mut headers = BTreeMap::<&str, Vec<&str>>::new(); + let mut value_count = 0; + let mut value_bytes = 0; + let mut key_bytes = 0; + + for (key, value) in header_map { + let key = key.as_str(); + if ["authorization", "content-length", "user-agent"].contains(&key) { + continue; + } + + let value = std::str::from_utf8(value.as_bytes()).unwrap(); + key_bytes += key.len(); + value_bytes += value.len(); + value_count += 1; + headers.entry(key).or_default().push(value); + } + + let mut signed_headers = String::with_capacity(key_bytes + headers.len()); + let mut canonical_headers = + String::with_capacity(key_bytes + value_bytes + headers.len() + value_count); + + for (header_idx, (name, values)) in headers.into_iter().enumerate() { + if header_idx != 0 { + signed_headers.push(';'); + } + + signed_headers.push_str(name); + canonical_headers.push_str(name); + canonical_headers.push(':'); + for (value_idx, value) in values.into_iter().enumerate() { + if value_idx != 0 { + canonical_headers.push(','); + } + canonical_headers.push_str(value.trim()); + } + canonical_headers.push('\n'); + } + + (signed_headers, canonical_headers) +} + +/// Provides credentials for use when signing requests +#[derive(Debug)] +pub enum CredentialProvider { + Static(StaticCredentialProvider), + Instance(InstanceCredentialProvider), + WebIdentity(WebIdentityProvider), +} + +impl CredentialProvider { + pub async fn get_credential(&self) -> Result> { + match self { + Self::Static(s) => Ok(Arc::clone(&s.credential)), + Self::Instance(c) => c.get_credential().await, + Self::WebIdentity(c) => c.get_credential().await, + } + } +} + +/// A static set of credentials +#[derive(Debug)] +pub struct StaticCredentialProvider { + pub credential: Arc, +} + +/// Credentials sourced from the instance 
metadata service +/// +/// +#[derive(Debug)] +pub struct InstanceCredentialProvider { + pub cache: TokenCache>, + pub client: Client, + pub retry_config: RetryConfig, + pub imdsv1_fallback: bool, +} + +impl InstanceCredentialProvider { + async fn get_credential(&self) -> Result> { + self.cache + .get_or_insert_with(|| { + const METADATA_ENDPOINT: &str = "http://169.254.169.254"; + instance_creds( + &self.client, + &self.retry_config, + METADATA_ENDPOINT, + self.imdsv1_fallback, + ) + .map_err(|source| crate::Error::Generic { + store: "S3", + source, + }) + }) + .await + } +} + +/// Credentials sourced using AssumeRoleWithWebIdentity +/// +/// +#[derive(Debug)] +pub struct WebIdentityProvider { + pub cache: TokenCache>, + pub token: String, + pub role_arn: String, + pub session_name: String, + pub endpoint: String, + pub client: Client, + pub retry_config: RetryConfig, +} + +impl WebIdentityProvider { + async fn get_credential(&self) -> Result> { + self.cache + .get_or_insert_with(|| { + web_identity( + &self.client, + &self.retry_config, + &self.token, + &self.role_arn, + &self.session_name, + &self.endpoint, + ) + .map_err(|source| crate::Error::Generic { + store: "S3", + source, + }) + }) + .await + } +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct InstanceCredentials { + access_key_id: String, + secret_access_key: String, + token: String, + expiration: DateTime, +} + +impl From for AwsCredential { + fn from(s: InstanceCredentials) -> Self { + Self { + key_id: s.access_key_id, + secret_key: s.secret_access_key, + token: Some(s.token), + } + } +} + +/// +async fn instance_creds( + client: &Client, + retry_config: &RetryConfig, + endpoint: &str, + imdsv1_fallback: bool, +) -> Result>, StdError> { + const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials"; + const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token"; + + let token_url = format!("{}/latest/api/token", endpoint); + + let token_result = client + .request(Method::PUT, token_url) + .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL + .send_retry(retry_config) + .await; + + let token = match token_result { + Ok(t) => Some(t.text().await?), + Err(e) + if imdsv1_fallback && matches!(e.status(), Some(StatusCode::FORBIDDEN)) => + { + warn!("received 403 from metadata endpoint, falling back to IMDSv1"); + None + } + Err(e) => return Err(e.into()), + }; + + let role_url = format!("{}/{}/", endpoint, CREDENTIALS_PATH); + let mut role_request = client.request(Method::GET, role_url); + + if let Some(token) = &token { + role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token); + } + + let role = role_request.send_retry(retry_config).await?.text().await?; + + let creds_url = format!("{}/{}/{}", endpoint, CREDENTIALS_PATH, role); + let mut creds_request = client.request(Method::GET, creds_url); + if let Some(token) = &token { + creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token); + } + + let creds: InstanceCredentials = + creds_request.send_retry(retry_config).await?.json().await?; + + let now = Utc::now(); + let ttl = (creds.expiration - now).to_std().unwrap_or_default(); + Ok(TemporaryToken { + token: Arc::new(creds.into()), + expiry: Instant::now() + ttl, + }) +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct AssumeRoleResponse { + assume_role_with_web_identity_result: AssumeRoleResult, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct AssumeRoleResult { + 
credentials: AssumeRoleCredentials, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct AssumeRoleCredentials { + session_token: String, + secret_access_key: String, + access_key_id: String, + expiration: DateTime, +} + +impl From for AwsCredential { + fn from(s: AssumeRoleCredentials) -> Self { + Self { + key_id: s.access_key_id, + secret_key: s.secret_access_key, + token: Some(s.session_token), + } + } +} + +/// +async fn web_identity( + client: &Client, + retry_config: &RetryConfig, + token: &str, + role_arn: &str, + session_name: &str, + endpoint: &str, +) -> Result>, StdError> { + let bytes = client + .request(Method::POST, endpoint) + .query(&[ + ("Action", "AssumeRoleWithWebIdentity"), + ("DurationSeconds", "3600"), + ("RoleArn", role_arn), + ("RoleSessionName", session_name), + ("Version", "2011-06-15"), + ("WebIdentityToken", token), + ]) + .send_retry(retry_config) + .await? + .bytes() + .await?; + + let resp: AssumeRoleResponse = quick_xml::de::from_reader(bytes.reader()) + .map_err(|e| format!("Invalid AssumeRoleWithWebIdentity response: {}", e))?; + + let creds = resp.assume_role_with_web_identity_result.credentials; + let now = Utc::now(); + let ttl = (creds.expiration - now).to_std().unwrap_or_default(); + + Ok(TemporaryToken { + token: Arc::new(creds.into()), + expiry: Instant::now() + ttl, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::mock_server::MockServer; + use hyper::{Body, Response}; + use reqwest::{Client, Method}; + use std::env; + + // Test generated using https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html + #[test] + fn test_sign() { + let client = Client::new(); + + // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + // method = 'GET' + // service = 'ec2' + // host = 'ec2.amazonaws.com' + // region = 'us-east-1' + // endpoint = 'https://ec2.amazonaws.com' + // request_parameters = '' + let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "https://ec2.amazon.com/") + .build() + .unwrap(); + + let signer = RequestSigner { + date, + credential: &credential, + service: "ec2", + region: "us-east-1", + }; + + signer.sign(&mut request); + assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4") + } + + #[test] + fn test_sign_port() { + let client = Client::new(); + + let credential = AwsCredential { + key_id: "H20ABqCkLZID4rLe".to_string(), + secret_key: "jMqRDgxSsBqqznfmddGdu1TmmZOJQxdM".to_string(), + token: None, + }; + + let date = DateTime::parse_from_rfc3339("2022-08-09T13:05:25Z") + .unwrap() + .with_timezone(&Utc); + + let mut request = client + .request(Method::GET, "http://localhost:9000/tsm-schemas") + .query(&[ + ("delimiter", "/"), + ("encoding-type", "url"), + ("list-type", "2"), + ("prefix", ""), + ]) + .build() + .unwrap(); + + let signer = RequestSigner { + date, + credential: &credential, + service: "s3", + region: "us-east-1", + }; + + signer.sign(&mut request); + assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), 
"AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d") + } + + #[tokio::test] + async fn test_instance_metadata() { + if env::var("TEST_INTEGRATION").is_err() { + eprintln!("skipping AWS integration test"); + } + + // For example https://github.com/aws/amazon-ec2-metadata-mock + let endpoint = env::var("EC2_METADATA_ENDPOINT").unwrap(); + let client = Client::new(); + let retry_config = RetryConfig::default(); + + // Verify only allows IMDSv2 + let resp = client + .request(Method::GET, format!("{}/latest/meta-data/ami-id", endpoint)) + .send() + .await + .unwrap(); + + assert_eq!( + resp.status(), + StatusCode::UNAUTHORIZED, + "Ensure metadata endpoint is set to only allow IMDSv2" + ); + + let creds = instance_creds(&client, &retry_config, &endpoint, false) + .await + .unwrap(); + + let id = &creds.token.key_id; + let secret = &creds.token.secret_key; + let token = creds.token.token.as_ref().unwrap(); + + assert!(!id.is_empty()); + assert!(!secret.is_empty()); + assert!(!token.is_empty()) + } + + #[tokio::test] + async fn test_mock() { + let server = MockServer::new(); + + const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token"; + + let secret_access_key = "SECRET"; + let access_key_id = "KEYID"; + let token = "TOKEN"; + + let endpoint = server.url(); + let client = Client::new(); + let retry_config = RetryConfig::default(); + + // Test IMDSv2 + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/api/token"); + assert_eq!(req.method(), &Method::PUT); + Response::new(Body::from("cupcakes")) + }); + server.push_fn(|req| { + assert_eq!( + req.uri().path(), + "/latest/meta-data/iam/security-credentials/" + ); + assert_eq!(req.method(), &Method::GET); + let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap(); + assert_eq!(t, "cupcakes"); + Response::new(Body::from("myrole")) + }); + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole"); + assert_eq!(req.method(), &Method::GET); + let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap(); + assert_eq!(t, "cupcakes"); + Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#)) + }); + + let creds = instance_creds(&client, &retry_config, endpoint, true) + .await + .unwrap(); + + assert_eq!(creds.token.token.as_deref().unwrap(), token); + assert_eq!(&creds.token.key_id, access_key_id); + assert_eq!(&creds.token.secret_key, secret_access_key); + + // Test IMDSv1 fallback + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/api/token"); + assert_eq!(req.method(), &Method::PUT); + Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Body::empty()) + .unwrap() + }); + server.push_fn(|req| { + assert_eq!( + req.uri().path(), + "/latest/meta-data/iam/security-credentials/" + ); + assert_eq!(req.method(), &Method::GET); + assert!(req.headers().get(IMDSV2_HEADER).is_none()); + Response::new(Body::from("myrole")) + }); + server.push_fn(|req| { + assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole"); + assert_eq!(req.method(), &Method::GET); + assert!(req.headers().get(IMDSV2_HEADER).is_none()); + 
Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#)) + }); + + let creds = instance_creds(&client, &retry_config, endpoint, true) + .await + .unwrap(); + + assert_eq!(creds.token.token.as_deref().unwrap(), token); + assert_eq!(&creds.token.key_id, access_key_id); + assert_eq!(&creds.token.secret_key, secret_access_key); + + // Test IMDSv1 fallback disabled + server.push( + Response::builder() + .status(StatusCode::FORBIDDEN) + .body(Body::empty()) + .unwrap(), + ); + + // Should fail + instance_creds(&client, &retry_config, endpoint, false) + .await + .unwrap_err(); + } +} diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs new file mode 100644 index 000000000000..d1d0a12cdaf9 --- /dev/null +++ b/object_store/src/aws/mod.rs @@ -0,0 +1,746 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store implementation for S3 +//! +//! ## Multi-part uploads +//! +//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart] method. +//! Data passed to the writer is automatically buffered to meet the minimum size +//! requirements for a part. Multiple parts are uploaded concurrently. +//! +//! If the writer fails for any reason, you may have parts uploaded to AWS but not +//! used that you may be charged for. Use the [ObjectStore::abort_multipart] method +//! to abort the upload and drop those unneeded parts. In addition, you may wish to +//! consider implementing [automatic cleanup] of unused parts that are older than one +//! week. +//! +//! 
[automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/ + +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{DateTime, Utc}; +use futures::stream::BoxStream; +use futures::TryStreamExt; +use reqwest::Client; +use snafu::{OptionExt, ResultExt, Snafu}; +use std::collections::BTreeSet; +use std::ops::Range; +use std::sync::Arc; +use tokio::io::AsyncWrite; +use tracing::info; + +use crate::aws::client::{S3Client, S3Config}; +use crate::aws::credential::{ + AwsCredential, CredentialProvider, InstanceCredentialProvider, + StaticCredentialProvider, WebIdentityProvider, +}; +use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}; +use crate::{ + GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, Result, + RetryConfig, StreamExt, +}; + +mod client; +mod credential; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +enum Error { + #[snafu(display("Last-Modified Header missing from response"))] + MissingLastModified, + + #[snafu(display("Content-Length Header missing from response"))] + MissingContentLength, + + #[snafu(display("Invalid last modified '{}': {}", last_modified, source))] + InvalidLastModified { + last_modified: String, + source: chrono::ParseError, + }, + + #[snafu(display("Invalid content length '{}': {}", content_length, source))] + InvalidContentLength { + content_length: String, + source: std::num::ParseIntError, + }, + + #[snafu(display("Missing region"))] + MissingRegion, + + #[snafu(display("Missing bucket name"))] + MissingBucketName, + + #[snafu(display("Missing AccessKeyId"))] + MissingAccessKeyId, + + #[snafu(display("Missing SecretAccessKey"))] + MissingSecretAccessKey, + + #[snafu(display("ETag Header missing from response"))] + MissingEtag, + + #[snafu(display("Received header containing non-ASCII data"))] + BadHeader { source: reqwest::header::ToStrError }, + + #[snafu(display("Error reading token file: {}", source))] + ReadTokenFile { source: std::io::Error }, +} + +impl From for super::Error { + fn from(err: Error) -> Self { + Self::Generic { + store: "S3", + source: Box::new(err), + } + } +} + +/// Interface for [Amazon S3](https://aws.amazon.com/s3/). 
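+///
+/// # Example: streaming a multi-part upload
+///
+/// A minimal sketch (illustrative only; assumes the bucket exists and the
+/// given credentials can write to it):
+///
+/// ```no_run
+/// # async fn example() -> object_store::Result<()> {
+/// use object_store::{aws::AmazonS3Builder, path::Path, ObjectStore};
+/// use tokio::io::AsyncWriteExt;
+///
+/// let store = AmazonS3Builder::new()
+///     .with_region("us-east-1")
+///     .with_bucket_name("my-bucket")
+///     .with_access_key_id("ACCESS_KEY_ID")
+///     .with_secret_access_key("SECRET_KEY")
+///     .build()?;
+///
+/// // Data is buffered into parts and uploaded concurrently;
+/// // shutting the writer down completes the upload
+/// let (_id, mut writer) = store.put_multipart(&Path::from("data/large.bin")).await?;
+/// writer.write_all(b"arbitrary data").await?;
+/// writer.shutdown().await?;
+/// # Ok(())
+/// # }
+/// ```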
+#[derive(Debug)] +pub struct AmazonS3 { + client: Arc, +} + +impl std::fmt::Display for AmazonS3 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "AmazonS3({})", self.client.config().bucket) + } +} + +#[async_trait] +impl ObjectStore for AmazonS3 { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + self.client.put_request(location, Some(bytes), &()).await?; + Ok(()) + } + + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let id = self.client.create_multipart(location).await?; + + let upload = S3MultiPartUpload { + location: location.clone(), + upload_id: id.clone(), + client: Arc::clone(&self.client), + }; + + Ok((id, Box::new(CloudMultiPartUpload::new(upload, 8)))) + } + + async fn abort_multipart( + &self, + location: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + self.client + .delete_request(location, &[("uploadId", multipart_id)]) + .await + } + + async fn get(&self, location: &Path) -> Result { + let response = self.client.get_request(location, None, false).await?; + let stream = response + .bytes_stream() + .map_err(|source| crate::Error::Generic { + store: "S3", + source: Box::new(source), + }) + .boxed(); + + Ok(GetResult::Stream(stream)) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let bytes = self + .client + .get_request(location, Some(range), false) + .await? + .bytes() + .await + .map_err(|source| client::Error::GetResponseBody { + source, + path: location.to_string(), + })?; + Ok(bytes) + } + + async fn head(&self, location: &Path) -> Result { + use reqwest::header::{CONTENT_LENGTH, LAST_MODIFIED}; + + // Extract meta from headers + // https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html#API_HeadObject_ResponseSyntax + let response = self.client.get_request(location, None, true).await?; + let headers = response.headers(); + + let last_modified = headers + .get(LAST_MODIFIED) + .context(MissingLastModifiedSnafu)?; + + let content_length = headers + .get(CONTENT_LENGTH) + .context(MissingContentLengthSnafu)?; + + let last_modified = last_modified.to_str().context(BadHeaderSnafu)?; + let last_modified = DateTime::parse_from_rfc2822(last_modified) + .context(InvalidLastModifiedSnafu { last_modified })? 
+ .with_timezone(&Utc); + + let content_length = content_length.to_str().context(BadHeaderSnafu)?; + let content_length = content_length + .parse() + .context(InvalidContentLengthSnafu { content_length })?; + Ok(ObjectMeta { + location: location.clone(), + last_modified, + size: content_length, + }) + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.client.delete_request(location, &()).await + } + + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>> { + let stream = self + .client + .list_paginated(prefix, false) + .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) + .try_flatten() + .boxed(); + + Ok(stream) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let mut stream = self.client.list_paginated(prefix, true); + + let mut common_prefixes = BTreeSet::new(); + let mut objects = Vec::new(); + + while let Some(result) = stream.next().await { + let response = result?; + common_prefixes.extend(response.common_prefixes.into_iter()); + objects.extend(response.objects.into_iter()); + } + + Ok(ListResult { + common_prefixes: common_prefixes.into_iter().collect(), + objects, + }) + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to).await + } + + async fn copy_if_not_exists(&self, _source: &Path, _dest: &Path) -> Result<()> { + // Will need dynamodb_lock + Err(crate::Error::NotImplemented) + } +} + +struct S3MultiPartUpload { + location: Path, + upload_id: String, + client: Arc, +} + +#[async_trait] +impl CloudMultiPartUploadImpl for S3MultiPartUpload { + async fn put_multipart_part( + &self, + buf: Vec, + part_idx: usize, + ) -> Result { + use reqwest::header::ETAG; + let part = (part_idx + 1).to_string(); + + let response = self + .client + .put_request( + &self.location, + Some(buf.into()), + &[("partNumber", &part), ("uploadId", &self.upload_id)], + ) + .await?; + + let etag = response + .headers() + .get(ETAG) + .context(MissingEtagSnafu) + .map_err(crate::Error::from)?; + + let etag = etag + .to_str() + .context(BadHeaderSnafu) + .map_err(crate::Error::from)?; + + Ok(UploadPart { + content_id: etag.to_string(), + }) + } + + async fn complete( + &self, + completed_parts: Vec, + ) -> Result<(), std::io::Error> { + self.client + .complete_multipart(&self.location, &self.upload_id, completed_parts) + .await?; + Ok(()) + } +} + +/// Configure a connection to Amazon S3 using the specified credentials in +/// the specified Amazon region and bucket. +/// +/// # Example +/// ``` +/// # let REGION = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY_ID = "foo"; +/// # let SECRET_KEY = "foo"; +/// # use object_store::aws::AmazonS3Builder; +/// let s3 = AmazonS3Builder::new() +/// .with_region(REGION) +/// .with_bucket_name(BUCKET_NAME) +/// .with_access_key_id(ACCESS_KEY_ID) +/// .with_secret_access_key(SECRET_KEY) +/// .build(); +/// ``` +#[derive(Debug, Default)] +pub struct AmazonS3Builder { + access_key_id: Option, + secret_access_key: Option, + region: Option, + bucket_name: Option, + endpoint: Option, + token: Option, + retry_config: RetryConfig, + allow_http: bool, + imdsv1_fallback: bool, +} + +impl AmazonS3Builder { + /// Create a new [`AmazonS3Builder`] with default values. 
+ pub fn new() -> Self { + Default::default() + } + + /// Fill the [`AmazonS3Builder`] with regular AWS environment variables + /// + /// Variables extracted from environment: + /// * AWS_ACCESS_KEY_ID -> access_key_id + /// * AWS_SECRET_ACCESS_KEY -> secret_access_key + /// * AWS_DEFAULT_REGION -> region + /// * AWS_ENDPOINT -> endpoint + /// * AWS_SESSION_TOKEN -> token + /// # Example + /// ``` + /// use object_store::aws::AmazonS3Builder; + /// + /// let s3 = AmazonS3Builder::from_env() + /// .with_bucket_name("foo") + /// .build(); + /// ``` + pub fn from_env() -> Self { + let mut builder: Self = Default::default(); + + if let Ok(access_key_id) = std::env::var("AWS_ACCESS_KEY_ID") { + builder.access_key_id = Some(access_key_id); + } + + if let Ok(secret_access_key) = std::env::var("AWS_SECRET_ACCESS_KEY") { + builder.secret_access_key = Some(secret_access_key); + } + + if let Ok(secret) = std::env::var("AWS_DEFAULT_REGION") { + builder.region = Some(secret); + } + + if let Ok(endpoint) = std::env::var("AWS_ENDPOINT") { + builder.endpoint = Some(endpoint); + } + + if let Ok(token) = std::env::var("AWS_SESSION_TOKEN") { + builder.token = Some(token); + } + + builder + } + + /// Set the AWS Access Key (required) + pub fn with_access_key_id(mut self, access_key_id: impl Into) -> Self { + self.access_key_id = Some(access_key_id.into()); + self + } + + /// Set the AWS Secret Access Key (required) + pub fn with_secret_access_key( + mut self, + secret_access_key: impl Into, + ) -> Self { + self.secret_access_key = Some(secret_access_key.into()); + self + } + + /// Set the region (e.g. `us-east-1`) (required) + pub fn with_region(mut self, region: impl Into) -> Self { + self.region = Some(region.into()); + self + } + + /// Set the bucket_name (required) + pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { + self.bucket_name = Some(bucket_name.into()); + self + } + + /// Sets the endpoint for communicating with AWS S3. Default value + /// is based on region. + /// + /// For example, this might be set to `"http://localhost:4566:` + /// for testing against a localstack instance. + pub fn with_endpoint(mut self, endpoint: impl Into) -> Self { + self.endpoint = Some(endpoint.into()); + self + } + + /// Set the token to use for requests (passed to underlying provider) + pub fn with_token(mut self, token: impl Into) -> Self { + self.token = Some(token.into()); + self + } + + /// Sets what protocol is allowed. If `allow_http` is : + /// * false (default): Only HTTPS are allowed + /// * true: HTTP and HTTPS are allowed + pub fn with_allow_http(mut self, allow_http: bool) -> Self { + self.allow_http = allow_http; + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// By default instance credentials will only be fetched over [IMDSv2], as AWS recommends + /// against having IMDSv1 enabled on EC2 instances as it is vulnerable to [SSRF attack] + /// + /// However, certain deployment environments, such as those running old versions of kube2iam, + /// may not support IMDSv2. This option will enable automatic fallback to using IMDSv1 + /// if the token endpoint returns a 403 error indicating that IMDSv2 is not supported. 
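+    ///
+    /// An illustrative sketch (assumes no explicit keys are supplied, so the
+    /// instance credential provider is selected):
+    ///
+    /// ```no_run
+    /// use object_store::aws::AmazonS3Builder;
+    ///
+    /// let s3 = AmazonS3Builder::new()
+    ///     .with_region("us-east-1")
+    ///     .with_bucket_name("my-bucket")
+    ///     .with_imdsv1_fallback()
+    ///     .build();
+    /// ```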
+ /// + /// This option has no effect if not using instance credentials + /// + /// [IMDSv2]: [https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html] + /// [SSRF attack]: [https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/] + /// + pub fn with_imdsv1_fallback(mut self) -> Self { + self.imdsv1_fallback = true; + self + } + + /// Create a [`AmazonS3`] instance from the provided values, + /// consuming `self`. + pub fn build(self) -> Result { + let bucket = self.bucket_name.context(MissingBucketNameSnafu)?; + let region = self.region.context(MissingRegionSnafu)?; + + let credentials = match (self.access_key_id, self.secret_access_key, self.token) { + (Some(key_id), Some(secret_key), token) => { + info!("Using Static credential provider"); + CredentialProvider::Static(StaticCredentialProvider { + credential: Arc::new(AwsCredential { + key_id, + secret_key, + token, + }), + }) + } + (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()), + (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()), + // TODO: Replace with `AmazonS3Builder::credentials_from_env` + _ => match ( + std::env::var_os("AWS_WEB_IDENTITY_TOKEN_FILE"), + std::env::var("AWS_ROLE_ARN"), + ) { + (Some(token_file), Ok(role_arn)) => { + info!("Using WebIdentity credential provider"); + let token = std::fs::read_to_string(token_file) + .context(ReadTokenFileSnafu)?; + + let session_name = std::env::var("AWS_ROLE_SESSION_NAME") + .unwrap_or_else(|_| "WebIdentitySession".to_string()); + + let endpoint = format!("https://sts.{}.amazonaws.com", region); + + // Disallow non-HTTPs requests + let client = Client::builder().https_only(true).build().unwrap(); + + CredentialProvider::WebIdentity(WebIdentityProvider { + cache: Default::default(), + token, + session_name, + role_arn, + endpoint, + client, + retry_config: self.retry_config.clone(), + }) + } + _ => { + info!("Using Instance credential provider"); + + // The instance metadata endpoint is access over HTTP + let client = Client::builder().https_only(false).build().unwrap(); + + CredentialProvider::Instance(InstanceCredentialProvider { + cache: Default::default(), + client, + retry_config: self.retry_config.clone(), + imdsv1_fallback: self.imdsv1_fallback, + }) + } + }, + }; + + let endpoint = self + .endpoint + .unwrap_or_else(|| format!("https://s3.{}.amazonaws.com", region)); + + let config = S3Config { + region, + endpoint, + bucket, + credentials, + retry_config: self.retry_config, + allow_http: self.allow_http, + }; + + let client = Arc::new(S3Client::new(config)); + + Ok(AmazonS3 { client }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{ + get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter, + put_get_delete_list, rename_and_copy, stream_get, + }; + use bytes::Bytes; + use std::env; + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + // Helper macro to skip tests if TEST_INTEGRATION and the AWS + // environment variables are not set. Returns a configured + // AmazonS3Builder + macro_rules! 
maybe_skip_integration { + () => {{ + dotenv::dotenv().ok(); + + let required_vars = [ + "OBJECT_STORE_AWS_DEFAULT_REGION", + "OBJECT_STORE_BUCKET", + "OBJECT_STORE_AWS_ACCESS_KEY_ID", + "OBJECT_STORE_AWS_SECRET_ACCESS_KEY", + ]; + let unset_vars: Vec<_> = required_vars + .iter() + .filter_map(|&name| match env::var(name) { + Ok(_) => None, + Err(_) => Some(name), + }) + .collect(); + let unset_var_names = unset_vars.join(", "); + + let force = env::var("TEST_INTEGRATION"); + + if force.is_ok() && !unset_var_names.is_empty() { + panic!( + "TEST_INTEGRATION is set, \ + but variable(s) {} need to be set", + unset_var_names + ); + } else if force.is_err() { + eprintln!( + "skipping AWS integration test - set {}TEST_INTEGRATION to run", + if unset_var_names.is_empty() { + String::new() + } else { + format!("{} and ", unset_var_names) + } + ); + return; + } else { + let config = AmazonS3Builder::new() + .with_access_key_id( + env::var("OBJECT_STORE_AWS_ACCESS_KEY_ID") + .expect("already checked OBJECT_STORE_AWS_ACCESS_KEY_ID"), + ) + .with_secret_access_key( + env::var("OBJECT_STORE_AWS_SECRET_ACCESS_KEY") + .expect("already checked OBJECT_STORE_AWS_SECRET_ACCESS_KEY"), + ) + .with_region( + env::var("OBJECT_STORE_AWS_DEFAULT_REGION") + .expect("already checked OBJECT_STORE_AWS_DEFAULT_REGION"), + ) + .with_bucket_name( + env::var("OBJECT_STORE_BUCKET") + .expect("already checked OBJECT_STORE_BUCKET"), + ) + .with_allow_http(true); + + let config = + if let Some(endpoint) = env::var("OBJECT_STORE_AWS_ENDPOINT").ok() { + config.with_endpoint(endpoint) + } else { + config + }; + + let config = if let Some(token) = + env::var("OBJECT_STORE_AWS_SESSION_TOKEN").ok() + { + config.with_token(token) + } else { + config + }; + + config + } + }}; + } + + #[test] + fn s3_test_config_from_env() { + let aws_access_key_id = env::var("AWS_ACCESS_KEY_ID") + .unwrap_or_else(|_| "object_store:fake_access_key_id".into()); + let aws_secret_access_key = env::var("AWS_SECRET_ACCESS_KEY") + .unwrap_or_else(|_| "object_store:fake_secret_key".into()); + + let aws_default_region = env::var("AWS_DEFAULT_REGION") + .unwrap_or_else(|_| "object_store:fake_default_region".into()); + + let aws_endpoint = env::var("AWS_ENDPOINT") + .unwrap_or_else(|_| "object_store:fake_endpoint".into()); + let aws_session_token = env::var("AWS_SESSION_TOKEN") + .unwrap_or_else(|_| "object_store:fake_session_token".into()); + + // required + env::set_var("AWS_ACCESS_KEY_ID", &aws_access_key_id); + env::set_var("AWS_SECRET_ACCESS_KEY", &aws_secret_access_key); + env::set_var("AWS_DEFAULT_REGION", &aws_default_region); + + // optional + env::set_var("AWS_ENDPOINT", &aws_endpoint); + env::set_var("AWS_SESSION_TOKEN", &aws_session_token); + + let builder = AmazonS3Builder::from_env(); + assert_eq!(builder.access_key_id.unwrap(), aws_access_key_id.as_str()); + assert_eq!( + builder.secret_access_key.unwrap(), + aws_secret_access_key.as_str() + ); + assert_eq!(builder.region.unwrap(), aws_default_region); + + assert_eq!(builder.endpoint.unwrap(), aws_endpoint); + assert_eq!(builder.token.unwrap(), aws_session_token); + } + + #[tokio::test] + async fn s3_test() { + let config = maybe_skip_integration!(); + let integration = config.build().unwrap(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + stream_get(&integration).await; + } + + #[tokio::test] + async fn s3_test_get_nonexistent_location() { + let config = 
maybe_skip_integration!(); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } + + #[tokio::test] + async fn s3_test_get_nonexistent_bucket() { + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.get(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } + + #[tokio::test] + async fn s3_test_put_nonexistent_bucket() { + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + let data = Bytes::from("arbitrary data"); + + let err = integration.put(&location, data).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } + + #[tokio::test] + async fn s3_test_delete_nonexistent_location() { + let config = maybe_skip_integration!(); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + integration.delete(&location).await.unwrap(); + } + + #[tokio::test] + async fn s3_test_delete_nonexistent_bucket() { + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.delete(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + } +} diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs new file mode 100644 index 000000000000..ece07853a1b6 --- /dev/null +++ b/object_store/src/azure/client.rs @@ -0,0 +1,725 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::credential::{AzureCredential, CredentialProvider}; +use crate::azure::credential::*; +use crate::client::pagination::stream_paginated; +use crate::client::retry::RetryExt; +use crate::path::DELIMITER; +use crate::util::{format_http_range, format_prefix}; +use crate::{BoxStream, ListResult, ObjectMeta, Path, Result, RetryConfig, StreamExt}; +use bytes::{Buf, Bytes}; +use chrono::{DateTime, TimeZone, Utc}; +use itertools::Itertools; +use reqwest::{ + header::{HeaderValue, CONTENT_LENGTH, IF_NONE_MATCH, RANGE}, + Client as ReqwestClient, Method, Response, StatusCode, +}; +use serde::{Deserialize, Deserializer, Serialize}; +use snafu::{ResultExt, Snafu}; +use std::collections::HashMap; +use std::ops::Range; +use url::Url; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub(crate) enum Error { + #[snafu(display("Error performing get request {}: {}", path, source))] + GetRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error getting get response body {}: {}", path, source))] + GetResponseBody { + source: reqwest::Error, + path: String, + }, + + #[snafu(display("Error performing put request {}: {}", path, source))] + PutRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing delete request {}: {}", path, source))] + DeleteRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing copy request {}: {}", path, source))] + CopyRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing list request: {}", source))] + ListRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting list response body: {}", source))] + ListResponseBody { source: reqwest::Error }, + + #[snafu(display("Got invalid list response: {}", source))] + InvalidListResponse { source: quick_xml::de::DeError }, + + #[snafu(display("Error authorizing request: {}", source))] + Authorization { + source: crate::azure::credential::Error, + }, +} + +impl From for crate::Error { + fn from(err: Error) -> Self { + match err { + Error::GetRequest { source, path } + | Error::DeleteRequest { source, path } + | Error::CopyRequest { source, path } + | Error::PutRequest { source, path } + if matches!(source.status(), Some(StatusCode::NOT_FOUND)) => + { + Self::NotFound { + path, + source: Box::new(source), + } + } + Error::CopyRequest { source, path } + if matches!(source.status(), Some(StatusCode::CONFLICT)) => + { + Self::AlreadyExists { + path, + source: Box::new(source), + } + } + _ => Self::Generic { + store: "MicrosoftAzure", + source: Box::new(err), + }, + } + } +} + +/// Configuration for [AzureClient] +#[derive(Debug)] +pub struct AzureConfig { + pub account: String, + pub container: String, + pub credentials: CredentialProvider, + pub retry_config: RetryConfig, + pub allow_http: bool, + pub service: Url, + pub is_emulator: bool, +} + +impl AzureConfig { + fn path_url(&self, path: &Path) -> Url { + let mut url = self.service.clone(); + { + let mut path_mut = url.path_segments_mut().unwrap(); + if self.is_emulator { + path_mut.push(&self.account); + } + path_mut.push(&self.container).extend(path.parts()); + } + url + } +} + +#[derive(Debug)] +pub(crate) struct AzureClient { + config: AzureConfig, + client: ReqwestClient, +} + +impl AzureClient { + /// create a new instance of [AzureClient] + pub fn new(config: AzureConfig) -> Self { + let client = 
reqwest::ClientBuilder::new() + .https_only(!config.allow_http) + .build() + .unwrap(); + + Self { config, client } + } + + /// Returns the config + pub fn config(&self) -> &AzureConfig { + &self.config + } + + async fn get_credential(&self) -> Result { + match &self.config.credentials { + CredentialProvider::AccessKey(key) => { + Ok(AzureCredential::AccessKey(key.to_owned())) + } + CredentialProvider::ClientSecret(cred) => { + let token = cred + .fetch_token(&self.client, &self.config.retry_config) + .await + .context(AuthorizationSnafu)?; + Ok(AzureCredential::AuthorizationToken( + // we do the conversion to a HeaderValue here, since it is fallible + // and we wna to use it in an infallible function + HeaderValue::from_str(&format!("Bearer {}", token)).map_err( + |err| crate::Error::Generic { + store: "MicrosoftAzure", + source: Box::new(err), + }, + )?, + )) + } + CredentialProvider::SASToken(sas) => { + Ok(AzureCredential::SASToken(sas.clone())) + } + } + } + + /// Make an Azure PUT request + pub async fn put_request( + &self, + path: &Path, + bytes: Option, + is_block_op: bool, + query: &T, + ) -> Result { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + + let mut builder = self.client.request(Method::PUT, url); + + if !is_block_op { + builder = builder.header(&BLOB_TYPE, "BlockBlob").query(query); + } else { + builder = builder.query(query); + } + + if let Some(bytes) = bytes { + builder = builder + .header(CONTENT_LENGTH, HeaderValue::from(bytes.len())) + .body(bytes) + } else { + builder = builder.header(CONTENT_LENGTH, HeaderValue::from_static("0")); + } + + let response = builder + .with_azure_authorization(&credential, &self.config.account) + .send_retry(&self.config.retry_config) + .await + .context(PutRequestSnafu { + path: path.as_ref(), + })?; + + Ok(response) + } + + /// Make an Azure GET request + /// + /// + pub async fn get_request( + &self, + path: &Path, + range: Option>, + head: bool, + ) -> Result { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + let method = match head { + true => Method::HEAD, + false => Method::GET, + }; + + let mut builder = self + .client + .request(method, url) + .header(CONTENT_LENGTH, HeaderValue::from_static("0")) + .body(Bytes::new()); + + if let Some(range) = range { + builder = builder.header(RANGE, format_http_range(range)); + } + + let response = builder + .with_azure_authorization(&credential, &self.config.account) + .send_retry(&self.config.retry_config) + .await + .context(GetRequestSnafu { + path: path.as_ref(), + })?; + + Ok(response) + } + + /// Make an Azure Delete request + pub async fn delete_request( + &self, + path: &Path, + query: &T, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.config.path_url(path); + + self.client + .request(Method::DELETE, url) + .query(query) + .header(&DELETE_SNAPSHOTS, "include") + .with_azure_authorization(&credential, &self.config.account) + .send_retry(&self.config.retry_config) + .await + .context(DeleteRequestSnafu { + path: path.as_ref(), + })?; + + Ok(()) + } + + /// Make an Azure Copy request + pub async fn copy_request( + &self, + from: &Path, + to: &Path, + overwrite: bool, + ) -> Result<()> { + let credential = self.get_credential().await?; + let url = self.config.path_url(to); + let mut source = self.config.path_url(from); + + // If using SAS authorization must include the headers in the URL + // + if let AzureCredential::SASToken(pairs) = &credential { + 
source.query_pairs_mut().extend_pairs(pairs); + } + + let mut builder = self + .client + .request(Method::PUT, url) + .header(©_SOURCE, source.to_string()) + .header(CONTENT_LENGTH, HeaderValue::from_static("0")); + + if !overwrite { + builder = builder.header(IF_NONE_MATCH, "*"); + } + + builder + .with_azure_authorization(&credential, &self.config.account) + .send_retry(&self.config.retry_config) + .await + .context(CopyRequestSnafu { + path: from.as_ref(), + })?; + + Ok(()) + } + + /// Make an Azure List request + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + token: Option<&str>, + ) -> Result<(ListResult, Option)> { + let credential = self.get_credential().await?; + let url = self.config.path_url(&Path::default()); + + let mut query = Vec::with_capacity(5); + query.push(("restype", "container")); + query.push(("comp", "list")); + + if let Some(prefix) = prefix { + query.push(("prefix", prefix)) + } + + if delimiter { + query.push(("delimiter", DELIMITER)) + } + + if let Some(token) = token { + query.push(("marker", token)) + } + + let response = self + .client + .request(Method::GET, url) + .query(&query) + .with_azure_authorization(&credential, &self.config.account) + .send_retry(&self.config.retry_config) + .await + .context(ListRequestSnafu)? + .bytes() + .await + .context(ListResponseBodySnafu)?; + + let mut response: ListResultInternal = + quick_xml::de::from_reader(response.reader()) + .context(InvalidListResponseSnafu)?; + let token = response.next_marker.take(); + + Ok((response.try_into()?, token)) + } + + /// Perform a list operation automatically handling pagination + pub fn list_paginated( + &self, + prefix: Option<&Path>, + delimiter: bool, + ) -> BoxStream<'_, Result> { + let prefix = format_prefix(prefix); + stream_paginated(prefix, move |prefix, token| async move { + let (r, next_token) = self + .list_request(prefix.as_deref(), delimiter, token.as_deref()) + .await?; + Ok((r, prefix, next_token)) + }) + .boxed() + } +} + +/// Raw / internal response from list requests +#[derive(Debug, Clone, PartialEq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct ListResultInternal { + pub prefix: Option, + pub max_results: Option, + pub delimiter: Option, + pub next_marker: Option, + pub blobs: Blobs, +} + +impl TryFrom for ListResult { + type Error = crate::Error; + + fn try_from(value: ListResultInternal) -> Result { + let common_prefixes = value + .blobs + .blob_prefix + .unwrap_or_default() + .into_iter() + .map(|x| Ok(Path::parse(&x.name)?)) + .collect::>()?; + + let objects = value + .blobs + .blobs + .into_iter() + .map(ObjectMeta::try_from) + // Note: workaround for gen2 accounts with hierarchical namespaces. These accounts also + // return path segments as "directories". When we cant directories, its always via + // the BlobPrefix mechanics. + .filter_map_ok(|obj| if obj.size > 0 { Some(obj) } else { None }) + .collect::>()?; + + Ok(Self { + common_prefixes, + objects, + }) + } +} + +/// Collection of blobs and potentially shared prefixes returned from list requests. 
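+///
+/// This corresponds to the `<Blobs>` element of a List Blobs response, which
+/// (simplified, for illustration) looks like:
+///
+/// ```text
+/// <Blobs>
+///   <Blob><Name>blob0.txt</Name><Properties>...</Properties></Blob>
+///   <BlobPrefix><Name>dir/</Name></BlobPrefix>
+/// </Blobs>
+/// ```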
+#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct Blobs { + pub blob_prefix: Option>, + #[serde(rename = "Blob", default)] + pub blobs: Vec, +} + +/// Common prefix in list blobs response +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct BlobPrefix { + pub name: String, +} + +/// Details for a specific blob +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct Blob { + pub name: String, + pub version_id: Option, + pub is_current_version: Option, + pub deleted: Option, + pub properties: BlobProperties, + pub metadata: Option>, +} + +impl TryFrom for ObjectMeta { + type Error = crate::Error; + + fn try_from(value: Blob) -> Result { + Ok(Self { + location: Path::parse(value.name)?, + last_modified: value.properties.last_modified, + size: value.properties.content_length as usize, + }) + } +} + +/// Properties associated with individual blobs. The actual list +/// of returned properties is much more exhaustive, but we limit +/// the parsed fields to the ones relevant in this crate. +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "PascalCase")] +struct BlobProperties { + #[serde(deserialize_with = "deserialize_http_date", rename = "Last-Modified")] + pub last_modified: DateTime, + pub etag: String, + #[serde(rename = "Content-Length")] + pub content_length: u64, + #[serde(rename = "Content-Type")] + pub content_type: String, + #[serde(rename = "Content-Encoding")] + pub content_encoding: Option, + #[serde(rename = "Content-Language")] + pub content_language: Option, +} + +// deserialize dates used in Azure payloads according to rfc1123 +fn deserialize_http_date<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + Utc.datetime_from_str(&s, RFC1123_FMT) + .map_err(serde::de::Error::custom) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct BlockId(Bytes); + +impl BlockId { + pub fn new(block_id: impl Into) -> Self { + Self(block_id.into()) + } +} + +impl From for BlockId +where + B: Into, +{ + fn from(v: B) -> Self { + Self::new(v) + } +} + +impl AsRef<[u8]> for BlockId { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub(crate) struct BlockList { + pub blocks: Vec, +} + +impl BlockList { + pub fn to_xml(&self) -> String { + let mut s = String::new(); + s.push_str("\n\n"); + for block_id in &self.blocks { + let node = format!( + "\t{}\n", + base64::encode(block_id) + ); + s.push_str(&node); + } + + s.push_str(""); + s + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::*; + + #[test] + fn deserde_azure() { + const S: &str = " + + + + blob0.txt + + Thu, 01 Jul 2021 10:44:59 GMT + Thu, 01 Jul 2021 10:44:59 GMT + Thu, 07 Jul 2022 14:38:48 GMT + 0x8D93C7D4629C227 + 8 + text/plain + + + + rvr3UC1SmUw7AZV2NqPN0g== + + + BlockBlob + Hot + true + unlocked + available + true + + uservalue + + + + blob1.txt + + Thu, 01 Jul 2021 10:44:59 GMT + Thu, 01 Jul 2021 10:44:59 GMT + 0x8D93C7D463004D6 + 8 + text/plain + + + + rvr3UC1SmUw7AZV2NqPN0g== + + + BlockBlob + Hot + true + unlocked + available + true + + + + + blob2.txt + + Thu, 01 Jul 2021 10:44:59 GMT + Thu, 01 Jul 2021 10:44:59 GMT + 0x8D93C7D4636478A + 8 + text/plain + + + + rvr3UC1SmUw7AZV2NqPN0g== + + + BlockBlob + Hot + true + unlocked + available + true + + + + + +"; + + let mut _list_blobs_response_internal: 
ListResultInternal = + quick_xml::de::from_str(S).unwrap(); + } + + #[test] + fn deserde_azurite() { + const S: &str = " + + + + 5000 + + + + blob0.txt + + Thu, 01 Jul 2021 10:45:02 GMT + Thu, 01 Jul 2021 10:45:02 GMT + 0x228281B5D517B20 + 8 + text/plain + rvr3UC1SmUw7AZV2NqPN0g== + BlockBlob + unlocked + available + true + Hot + true + Thu, 01 Jul 2021 10:45:02 GMT + + + + blob1.txt + + Thu, 01 Jul 2021 10:45:02 GMT + Thu, 01 Jul 2021 10:45:02 GMT + 0x1DD959381A8A860 + 8 + text/plain + rvr3UC1SmUw7AZV2NqPN0g== + BlockBlob + unlocked + available + true + Hot + true + Thu, 01 Jul 2021 10:45:02 GMT + + + + blob2.txt + + Thu, 01 Jul 2021 10:45:02 GMT + Thu, 01 Jul 2021 10:45:02 GMT + 0x1FBE9C9B0C7B650 + 8 + text/plain + rvr3UC1SmUw7AZV2NqPN0g== + BlockBlob + unlocked + available + true + Hot + true + Thu, 01 Jul 2021 10:45:02 GMT + + + + +"; + + let mut _list_blobs_response_internal: ListResultInternal = + quick_xml::de::from_str(S).unwrap(); + } + + #[test] + fn to_xml() { + const S: &str = " + +\tbnVtZXJvMQ== +\tbnVtZXJvMg== +\tbnVtZXJvMw== +"; + let mut blocks = BlockList { blocks: Vec::new() }; + blocks.blocks.push(Bytes::from_static(b"numero1").into()); + blocks.blocks.push("numero2".into()); + blocks.blocks.push("numero3".into()); + + let res: &str = &blocks.to_xml(); + + assert_eq!(res, S) + } +} diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs new file mode 100644 index 000000000000..721fcaea46f0 --- /dev/null +++ b/object_store/src/azure/credential.rs @@ -0,0 +1,350 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::client::retry::RetryExt; +use crate::client::token::{TemporaryToken, TokenCache}; +use crate::util::hmac_sha256; +use crate::RetryConfig; +use chrono::Utc; +use reqwest::header::ACCEPT; +use reqwest::{ + header::{ + HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_ENCODING, + CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, + IF_MODIFIED_SINCE, IF_NONE_MATCH, IF_UNMODIFIED_SINCE, RANGE, + }, + Client, Method, RequestBuilder, +}; +use snafu::{ResultExt, Snafu}; +use std::borrow::Cow; +use std::str; +use std::time::{Duration, Instant}; +use url::Url; + +static AZURE_VERSION: HeaderValue = HeaderValue::from_static("2021-08-06"); +static VERSION: HeaderName = HeaderName::from_static("x-ms-version"); +pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-type"); +pub(crate) static DELETE_SNAPSHOTS: HeaderName = + HeaderName::from_static("x-ms-delete-snapshots"); +pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source"); +static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5"); +pub(crate) static RFC1123_FMT: &str = "%a, %d %h %Y %T GMT"; +const CONTENT_TYPE_JSON: &str = "application/json"; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Error performing token request: {}", source))] + TokenRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting token response body: {}", source))] + TokenResponseBody { source: reqwest::Error }, +} + +pub type Result = std::result::Result; + +/// Provides credentials for use when signing requests +#[derive(Debug)] +pub enum CredentialProvider { + AccessKey(String), + SASToken(Vec<(String, String)>), + ClientSecret(ClientSecretOAuthProvider), +} + +pub(crate) enum AzureCredential { + AccessKey(String), + SASToken(Vec<(String, String)>), + AuthorizationToken(HeaderValue), +} + +/// A list of known Azure authority hosts +pub mod authority_hosts { + /// China-based Azure Authority Host + pub const AZURE_CHINA: &str = "https://login.chinacloudapi.cn"; + /// Germany-based Azure Authority Host + pub const AZURE_GERMANY: &str = "https://login.microsoftonline.de"; + /// US Government Azure Authority Host + pub const AZURE_GOVERNMENT: &str = "https://login.microsoftonline.us"; + /// Public Cloud Azure Authority Host + pub const AZURE_PUBLIC_CLOUD: &str = "https://login.microsoftonline.com"; +} + +pub(crate) trait CredentialExt { + /// Apply authorization to requests against azure storage accounts + /// + fn with_azure_authorization( + self, + credential: &AzureCredential, + account: &str, + ) -> Self; +} + +impl CredentialExt for RequestBuilder { + fn with_azure_authorization( + mut self, + credential: &AzureCredential, + account: &str, + ) -> Self { + // rfc2822 string should never contain illegal characters + let date = Utc::now(); + let date_str = date.format(RFC1123_FMT).to_string(); + // we formatted the data string ourselves, so unwrapping should be fine + let date_val = HeaderValue::from_str(&date_str).unwrap(); + self = self + .header(DATE, &date_val) + .header(&VERSION, &AZURE_VERSION); + + // Hack around lack of access to underlying request + // https://github.com/seanmonstar/reqwest/issues/1212 + let request = self + .try_clone() + .expect("not stream") + .build() + .expect("request valid"); + + match credential { + AzureCredential::AccessKey(key) => { + let signature = generate_authorization( + request.headers(), + request.url(), + request.method(), + account, + key.as_str(), + ); + self = self + // "signature" 
is a base 64 encoded string so it should never contain illegal characters. + .header( + AUTHORIZATION, + HeaderValue::from_str(signature.as_str()).unwrap(), + ); + } + AzureCredential::AuthorizationToken(token) => { + self = self.header(AUTHORIZATION, token); + } + AzureCredential::SASToken(query_pairs) => { + self = self.query(&query_pairs); + } + }; + + self + } +} + +/// Generate signed key for authorization via access keys +/// +fn generate_authorization( + h: &HeaderMap, + u: &Url, + method: &Method, + account: &str, + key: &str, +) -> String { + let str_to_sign = string_to_sign(h, u, method, account); + let auth = hmac_sha256(base64::decode(key).unwrap(), &str_to_sign); + format!("SharedKey {}:{}", account, base64::encode(auth)) +} + +fn add_if_exists<'a>(h: &'a HeaderMap, key: &HeaderName) -> &'a str { + h.get(key) + .map(|s| s.to_str()) + .transpose() + .ok() + .flatten() + .unwrap_or_default() +} + +/// +fn string_to_sign(h: &HeaderMap, u: &Url, method: &Method, account: &str) -> String { + // content length must only be specified if != 0 + // this is valid from 2015-02-21 + let content_length = h + .get(&CONTENT_LENGTH) + .map(|s| s.to_str()) + .transpose() + .ok() + .flatten() + .filter(|&v| v != "0") + .unwrap_or_default(); + format!( + "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}", + method.as_ref(), + add_if_exists(h, &CONTENT_ENCODING), + add_if_exists(h, &CONTENT_LANGUAGE), + content_length, + add_if_exists(h, &CONTENT_MD5), + add_if_exists(h, &CONTENT_TYPE), + add_if_exists(h, &DATE), + add_if_exists(h, &IF_MODIFIED_SINCE), + add_if_exists(h, &IF_MATCH), + add_if_exists(h, &IF_NONE_MATCH), + add_if_exists(h, &IF_UNMODIFIED_SINCE), + add_if_exists(h, &RANGE), + canonicalize_header(h), + canonicalized_resource(account, u) + ) +} + +/// +fn canonicalize_header(headers: &HeaderMap) -> String { + let mut names = headers + .iter() + .filter_map(|(k, _)| { + (k.as_str().starts_with("x-ms")) + // TODO remove unwraps + .then(|| (k.as_str(), headers.get(k).unwrap().to_str().unwrap())) + }) + .collect::>(); + names.sort_unstable(); + + let mut result = String::new(); + for (name, value) in names { + result.push_str(name); + result.push(':'); + result.push_str(value); + result.push('\n'); + } + result +} + +/// +fn canonicalized_resource(account: &str, uri: &Url) -> String { + let mut can_res: String = String::new(); + can_res.push('/'); + can_res.push_str(account); + can_res.push_str(uri.path().to_string().as_str()); + can_res.push('\n'); + + // query parameters + let query_pairs = uri.query_pairs(); + { + let mut qps: Vec = Vec::new(); + for (q, _) in query_pairs { + if !(qps.iter().any(|x| x == &*q)) { + qps.push(q.into_owned()); + } + } + + qps.sort(); + + for qparam in qps { + // find correct parameter + let ret = lexy_sort(query_pairs, &qparam); + + can_res = can_res + &qparam.to_lowercase() + ":"; + + for (i, item) in ret.iter().enumerate() { + if i > 0 { + can_res.push(','); + } + can_res.push_str(item); + } + + can_res.push('\n'); + } + }; + + can_res[0..can_res.len() - 1].to_owned() +} + +fn lexy_sort<'a>( + vec: impl Iterator, Cow<'a, str>)> + 'a, + query_param: &str, +) -> Vec> { + let mut values = vec + .filter(|(k, _)| *k == query_param) + .map(|(_, v)| v) + .collect::>(); + values.sort_unstable(); + values +} + +#[derive(serde::Deserialize, Debug)] +struct TokenResponse { + access_token: String, + expires_in: u64, +} + +/// Encapsulates the logic to perform an OAuth token challenge +#[derive(Debug)] +pub struct ClientSecretOAuthProvider { + scope: String, + 
token_url: String, + client_id: String, + client_secret: String, + cache: TokenCache, +} + +impl ClientSecretOAuthProvider { + /// Create a new [`ClientSecretOAuthProvider`] for an azure backed store + pub fn new( + client_id: String, + client_secret: String, + tenant_id: String, + authority_host: Option, + ) -> Self { + let authority_host = authority_host + .unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned()); + + Self { + scope: "https://storage.azure.com/.default".to_owned(), + token_url: format!("{}/{}/oauth2/v2.0/token", authority_host, tenant_id), + client_id, + client_secret, + cache: TokenCache::default(), + } + } + + /// Fetch a token + pub async fn fetch_token( + &self, + client: &Client, + retry: &RetryConfig, + ) -> Result { + self.cache + .get_or_insert_with(|| self.fetch_token_inner(client, retry)) + .await + } + + /// Fetch a fresh token + async fn fetch_token_inner( + &self, + client: &Client, + retry: &RetryConfig, + ) -> Result> { + let response: TokenResponse = client + .request(Method::POST, &self.token_url) + .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) + .form(&[ + ("client_id", self.client_id.as_str()), + ("client_secret", self.client_secret.as_str()), + ("scope", self.scope.as_str()), + ("grant_type", "client_credentials"), + ]) + .send_retry(retry) + .await + .context(TokenRequestSnafu)? + .json() + .await + .context(TokenResponseBodySnafu)?; + + let token = TemporaryToken { + token: response.access_token, + expiry: Instant::now() + Duration::from_secs(response.expires_in), + }; + + Ok(token) + } +} diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs new file mode 100644 index 000000000000..dd1cde9c7a2a --- /dev/null +++ b/object_store/src/azure/mod.rs @@ -0,0 +1,707 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store implementation for Azure blob storage +//! +//! ## Streaming uploads +//! +//! [ObjectStore::put_multipart] will upload data in blocks and write a blob from those +//! blocks. Data is buffered internally to make blocks of at least 5MB and blocks +//! are uploaded concurrently. +//! +//! [ObjectStore::abort_multipart] is a no-op, since Azure Blob Store doesn't provide +//! a way to drop old blocks. Instead unused blocks are automatically cleaned up +//! after 7 days. 
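To make the streaming-upload behaviour described above concrete, here is a minimal sketch; the account, access key, container name, path, and the `streaming_upload` helper are placeholders for illustration, not values or APIs introduced by this change:

```rust
use object_store::azure::MicrosoftAzureBuilder;
use object_store::{path::Path, ObjectStore};
use tokio::io::AsyncWriteExt;

async fn streaming_upload() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder credentials; real values would come from configuration
    // or the environment.
    let store = MicrosoftAzureBuilder::new()
        .with_account("my-account")
        .with_access_key("my-access-key")
        .with_container_name("my-container")
        .build()?;

    let location = Path::from("data/large-file.bin");
    let (_multipart_id, mut writer) = store.put_multipart(&location).await?;

    // Writes are buffered internally into blocks of at least 5MB,
    // which are uploaded concurrently.
    writer.write_all(b"example payload").await?;

    // Shutting the writer down flushes any remaining buffered data and
    // commits the block list, making the blob visible.
    writer.shutdown().await?;
    Ok(())
}
```

The returned `MultipartId` is ignored here because, as noted above, `abort_multipart` is a no-op for Azure and unused blocks expire on their own.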
+use self::client::{BlockId, BlockList}; +use crate::{ + multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, + path::Path, + GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig, +}; +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{TimeZone, Utc}; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use snafu::{ResultExt, Snafu}; +use std::collections::BTreeSet; +use std::fmt::{Debug, Formatter}; +use std::io; +use std::ops::Range; +use std::sync::Arc; +use tokio::io::AsyncWrite; +use url::Url; + +pub use credential::authority_hosts; + +mod client; +mod credential; + +/// The well-known account used by Azurite and the legacy Azure Storage Emulator. +/// +const EMULATOR_ACCOUNT: &str = "devstoreaccount1"; + +/// The well-known account key used by Azurite and the legacy Azure Storage Emulator. +/// +const EMULATOR_ACCOUNT_KEY: &str = + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="; + +/// A specialized `Error` for Azure object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +enum Error { + #[snafu(display("Last-Modified Header missing from response"))] + MissingLastModified, + + #[snafu(display("Content-Length Header missing from response"))] + MissingContentLength, + + #[snafu(display("Invalid last modified '{}': {}", last_modified, source))] + InvalidLastModified { + last_modified: String, + source: chrono::ParseError, + }, + + #[snafu(display("Invalid content length '{}': {}", content_length, source))] + InvalidContentLength { + content_length: String, + source: std::num::ParseIntError, + }, + + #[snafu(display("Received header containing non-ASCII data"))] + BadHeader { source: reqwest::header::ToStrError }, + + #[snafu(display("Unable parse source url. Url: {}, Error: {}", url, source))] + UnableToParseUrl { + source: url::ParseError, + url: String, + }, + + #[snafu(display( + "Unable parse emulator url {}={}, Error: {}", + env_name, + env_value, + source + ))] + UnableToParseEmulatorUrl { + env_name: String, + env_value: String, + source: url::ParseError, + }, + + #[snafu(display("Account must be specified"))] + MissingAccount {}, + + #[snafu(display("Container name must be specified"))] + MissingContainerName {}, + + #[snafu(display("At least one authorization option must be specified"))] + MissingCredentials {}, + + #[snafu(display("Azure credential error: {}", source), context(false))] + Credential { source: credential::Error }, +} + +impl From for super::Error { + fn from(source: Error) -> Self { + Self::Generic { + store: "MicrosoftAzure", + source: Box::new(source), + } + } +} + +/// Interface for [Microsoft Azure Blob Storage](https://azure.microsoft.com/en-us/services/storage/blobs/). 
+#[derive(Debug)] +pub struct MicrosoftAzure { + client: Arc, +} + +impl std::fmt::Display for MicrosoftAzure { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "MicrosoftAzure {{ account: {}, container: {} }}", + self.client.config().account, + self.client.config().container + ) + } +} + +#[async_trait] +impl ObjectStore for MicrosoftAzure { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + self.client + .put_request(location, Some(bytes), false, &()) + .await?; + Ok(()) + } + + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let inner = AzureMultiPartUpload { + client: Arc::clone(&self.client), + location: location.to_owned(), + }; + Ok((String::new(), Box::new(CloudMultiPartUpload::new(inner, 8)))) + } + + async fn abort_multipart( + &self, + _location: &Path, + _multipart_id: &MultipartId, + ) -> Result<()> { + // There is no way to drop blocks that have been uploaded. Instead, they simply + // expire in 7 days. + Ok(()) + } + + async fn get(&self, location: &Path) -> Result { + let response = self.client.get_request(location, None, false).await?; + let stream = response + .bytes_stream() + .map_err(|source| crate::Error::Generic { + store: "MicrosoftAzure", + source: Box::new(source), + }) + .boxed(); + + Ok(GetResult::Stream(stream)) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let bytes = self + .client + .get_request(location, Some(range), false) + .await? + .bytes() + .await + .map_err(|source| client::Error::GetResponseBody { + source, + path: location.to_string(), + })?; + Ok(bytes) + } + + async fn head(&self, location: &Path) -> Result { + use reqwest::header::{CONTENT_LENGTH, LAST_MODIFIED}; + + // Extract meta from headers + // https://docs.microsoft.com/en-us/rest/api/storageservices/get-blob-properties + let response = self.client.get_request(location, None, true).await?; + let headers = response.headers(); + + let last_modified = headers + .get(LAST_MODIFIED) + .ok_or(Error::MissingLastModified)? + .to_str() + .context(BadHeaderSnafu)?; + let last_modified = Utc + .datetime_from_str(last_modified, credential::RFC1123_FMT) + .context(InvalidLastModifiedSnafu { last_modified })?; + + let content_length = headers + .get(CONTENT_LENGTH) + .ok_or(Error::MissingContentLength)? 
+ .to_str() + .context(BadHeaderSnafu)?; + let content_length = content_length + .parse() + .context(InvalidContentLengthSnafu { content_length })?; + + Ok(ObjectMeta { + location: location.clone(), + last_modified, + size: content_length, + }) + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.client.delete_request(location, &()).await + } + + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>> { + let stream = self + .client + .list_paginated(prefix, false) + .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok))) + .try_flatten() + .boxed(); + + Ok(stream) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let mut stream = self.client.list_paginated(prefix, true); + + let mut common_prefixes = BTreeSet::new(); + let mut objects = Vec::new(); + + while let Some(result) = stream.next().await { + let response = result?; + common_prefixes.extend(response.common_prefixes.into_iter()); + objects.extend(response.objects.into_iter()); + } + + Ok(ListResult { + common_prefixes: common_prefixes.into_iter().collect(), + objects, + }) + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, true).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, false).await + } +} + +/// Relevant docs: +/// In Azure Blob Store, parts are "blocks" +/// put_multipart_part -> PUT block +/// complete -> PUT block list +/// abort -> No equivalent; blocks are simply dropped after 7 days +#[derive(Debug, Clone)] +struct AzureMultiPartUpload { + client: Arc, + location: Path, +} + +#[async_trait] +impl CloudMultiPartUploadImpl for AzureMultiPartUpload { + async fn put_multipart_part( + &self, + buf: Vec, + part_idx: usize, + ) -> Result { + let content_id = format!("{:20}", part_idx); + let block_id: BlockId = content_id.clone().into(); + + self.client + .put_request( + &self.location, + Some(buf.into()), + true, + &[("comp", "block"), ("blockid", &base64::encode(block_id))], + ) + .await?; + + Ok(UploadPart { content_id }) + } + + async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error> { + let blocks = completed_parts + .into_iter() + .map(|part| BlockId::from(part.content_id)) + .collect(); + + let block_list = BlockList { blocks }; + let block_xml = block_list.to_xml(); + + self.client + .put_request( + &self.location, + Some(block_xml.into()), + true, + &[("comp", "blocklist")], + ) + .await?; + + Ok(()) + } +} + +/// Configure a connection to Microsoft Azure Blob Storage container using +/// the specified credentials. 
+/// +/// # Example +/// ``` +/// # let ACCOUNT = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY = "foo"; +/// # use object_store::azure::MicrosoftAzureBuilder; +/// let azure = MicrosoftAzureBuilder::new() +/// .with_account(ACCOUNT) +/// .with_access_key(ACCESS_KEY) +/// .with_container_name(BUCKET_NAME) +/// .build(); +/// ``` +#[derive(Default)] +pub struct MicrosoftAzureBuilder { + account_name: Option, + access_key: Option, + container_name: Option, + bearer_token: Option, + client_id: Option, + client_secret: Option, + tenant_id: Option, + sas_query_pairs: Option>, + authority_host: Option, + use_emulator: bool, + retry_config: RetryConfig, + allow_http: bool, +} + +impl Debug for MicrosoftAzureBuilder { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "MicrosoftAzureBuilder {{ account: {:?}, container_name: {:?} }}", + self.account_name, self.container_name + ) + } +} + +impl MicrosoftAzureBuilder { + /// Create a new [`MicrosoftAzureBuilder`] with default values. + pub fn new() -> Self { + Default::default() + } + + /// Create an instance of [MicrosoftAzureBuilder] with values pre-populated from environment variables. + /// + /// Variables extracted from environment: + /// * AZURE_STORAGE_ACCOUNT_NAME: storage account name + /// * AZURE_STORAGE_ACCOUNT_KEY: storage account master key + /// * AZURE_STORAGE_ACCESS_KEY: alias for AZURE_STORAGE_ACCOUNT_KEY + /// * AZURE_STORAGE_CLIENT_ID -> client id for service principal authorization + /// * AZURE_STORAGE_CLIENT_SECRET -> client secret for service principal authorization + /// * AZURE_STORAGE_TENANT_ID -> tenant id used in oauth flows + /// # Example + /// ``` + /// use object_store::azure::MicrosoftAzureBuilder; + /// + /// let azure = MicrosoftAzureBuilder::from_env() + /// .with_container_name("foo") + /// .build(); + /// ``` + pub fn from_env() -> Self { + let mut builder = Self::default(); + + if let Ok(account_name) = std::env::var("AZURE_STORAGE_ACCOUNT_NAME") { + builder.account_name = Some(account_name); + } + + if let Ok(access_key) = std::env::var("AZURE_STORAGE_ACCOUNT_KEY") { + builder.access_key = Some(access_key); + } else if let Ok(access_key) = std::env::var("AZURE_STORAGE_ACCESS_KEY") { + builder.access_key = Some(access_key); + } + + if let Ok(client_id) = std::env::var("AZURE_STORAGE_CLIENT_ID") { + builder.client_id = Some(client_id); + } + + if let Ok(client_secret) = std::env::var("AZURE_STORAGE_CLIENT_SECRET") { + builder.client_secret = Some(client_secret); + } + + if let Ok(tenant_id) = std::env::var("AZURE_STORAGE_TENANT_ID") { + builder.tenant_id = Some(tenant_id); + } + + builder + } + + /// Set the Azure Account (required) + pub fn with_account(mut self, account: impl Into) -> Self { + self.account_name = Some(account.into()); + self + } + + /// Set the Azure Container Name (required) + pub fn with_container_name(mut self, container_name: impl Into) -> Self { + self.container_name = Some(container_name.into()); + self + } + + /// Set the Azure Access Key (required - one of access key, bearer token, or client credentials) + pub fn with_access_key(mut self, access_key: impl Into) -> Self { + self.access_key = Some(access_key.into()); + self + } + + /// Set a static bearer token to be used for authorizing requests + pub fn with_bearer_token_authorization( + mut self, + bearer_token: impl Into, + ) -> Self { + self.bearer_token = Some(bearer_token.into()); + self + } + + /// Set a client secret used for client secret authorization + pub fn 
with_client_secret_authorization( + mut self, + client_id: impl Into, + client_secret: impl Into, + tenant_id: impl Into, + ) -> Self { + self.client_id = Some(client_id.into()); + self.client_secret = Some(client_secret.into()); + self.tenant_id = Some(tenant_id.into()); + self + } + + /// Set query pairs appended to the url for shared access signature authorization + pub fn with_sas_authorization( + mut self, + query_pairs: impl Into>, + ) -> Self { + self.sas_query_pairs = Some(query_pairs.into()); + self + } + + /// Set if the Azure emulator should be used (defaults to false) + pub fn with_use_emulator(mut self, use_emulator: bool) -> Self { + self.use_emulator = use_emulator; + self + } + + /// Sets what protocol is allowed. If `allow_http` is : + /// * false (default): Only HTTPS is allowed + /// * true: HTTP and HTTPS are allowed + pub fn with_allow_http(mut self, allow_http: bool) -> Self { + self.allow_http = allow_http; + self + } + + /// Sets an alternative authority host for OAuth based authorization + /// common hosts for azure clouds are defined in [authority_hosts]. + /// Defaults to + pub fn with_authority_host(mut self, authority_host: String) -> Self { + self.authority_host = Some(authority_host); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// Configure a connection to container with given name on Microsoft Azure + /// Blob store. + pub fn build(self) -> Result { + let Self { + account_name, + access_key, + container_name, + bearer_token, + client_id, + client_secret, + tenant_id, + sas_query_pairs, + use_emulator, + retry_config, + allow_http, + authority_host, + } = self; + + let container = container_name.ok_or(Error::MissingContainerName {})?; + + let (is_emulator, allow_http, storage_url, auth, account) = if use_emulator { + let account_name = + account_name.unwrap_or_else(|| EMULATOR_ACCOUNT.to_string()); + // Allow overriding defaults. 
Values taken from + // from https://docs.rs/azure_storage/0.2.0/src/azure_storage/core/clients/storage_account_client.rs.html#129-141 + let url = url_from_env("AZURITE_BLOB_STORAGE_URL", "http://127.0.0.1:10000")?; + let account_key = + access_key.unwrap_or_else(|| EMULATOR_ACCOUNT_KEY.to_string()); + let credential = credential::CredentialProvider::AccessKey(account_key); + (true, true, url, credential, account_name) + } else { + let account_name = account_name.ok_or(Error::MissingAccount {})?; + let account_url = format!("https://{}.blob.core.windows.net", &account_name); + let url = Url::parse(&account_url) + .context(UnableToParseUrlSnafu { url: account_url })?; + let credential = if let Some(bearer_token) = bearer_token { + Ok(credential::CredentialProvider::AccessKey(bearer_token)) + } else if let Some(access_key) = access_key { + Ok(credential::CredentialProvider::AccessKey(access_key)) + } else if let (Some(client_id), Some(client_secret), Some(tenant_id)) = + (client_id, client_secret, tenant_id) + { + let client_credential = credential::ClientSecretOAuthProvider::new( + client_id, + client_secret, + tenant_id, + authority_host, + ); + Ok(credential::CredentialProvider::ClientSecret( + client_credential, + )) + } else if let Some(query_pairs) = sas_query_pairs { + Ok(credential::CredentialProvider::SASToken(query_pairs)) + } else { + Err(Error::MissingCredentials {}) + }?; + (false, allow_http, url, credential, account_name) + }; + + let config = client::AzureConfig { + account, + allow_http, + retry_config, + service: storage_url, + container, + credentials: auth, + is_emulator, + }; + + let client = Arc::new(client::AzureClient::new(config)); + + Ok(MicrosoftAzure { client }) + } +} + +/// Parses the contents of the environment variable `env_name` as a URL +/// if present, otherwise falls back to default_url +fn url_from_env(env_name: &str, default_url: &str) -> Result { + let url = match std::env::var(env_name) { + Ok(env_value) => { + Url::parse(&env_value).context(UnableToParseEmulatorUrlSnafu { + env_name, + env_value, + })? + } + Err(_) => Url::parse(default_url).expect("Failed to parse default URL"), + }; + Ok(url) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::{ + copy_if_not_exists, list_uses_directories_correctly, list_with_delimiter, + put_get_delete_list, rename_and_copy, stream_get, + }; + use std::env; + + // Helper macro to skip tests if TEST_INTEGRATION and the Azure environment + // variables are not set. + macro_rules! 
maybe_skip_integration { + () => {{ + dotenv::dotenv().ok(); + + let use_emulator = std::env::var("AZURE_USE_EMULATOR").is_ok(); + + let mut required_vars = vec!["OBJECT_STORE_BUCKET"]; + if !use_emulator { + required_vars.push("AZURE_STORAGE_ACCOUNT"); + required_vars.push("AZURE_STORAGE_ACCESS_KEY"); + } + let unset_vars: Vec<_> = required_vars + .iter() + .filter_map(|&name| match env::var(name) { + Ok(_) => None, + Err(_) => Some(name), + }) + .collect(); + let unset_var_names = unset_vars.join(", "); + + let force = std::env::var("TEST_INTEGRATION"); + + if force.is_ok() && !unset_var_names.is_empty() { + panic!( + "TEST_INTEGRATION is set, \ + but variable(s) {} need to be set", + unset_var_names + ) + } else if force.is_err() { + eprintln!( + "skipping Azure integration test - set {}TEST_INTEGRATION to run", + if unset_var_names.is_empty() { + String::new() + } else { + format!("{} and ", unset_var_names) + } + ); + return; + } else { + let builder = MicrosoftAzureBuilder::new() + .with_container_name( + env::var("OBJECT_STORE_BUCKET") + .expect("already checked OBJECT_STORE_BUCKET"), + ) + .with_use_emulator(use_emulator); + if !use_emulator { + builder + .with_account( + env::var("AZURE_STORAGE_ACCOUNT").unwrap_or_default(), + ) + .with_access_key( + env::var("AZURE_STORAGE_ACCESS_KEY").unwrap_or_default(), + ) + } else { + builder + } + } + }}; + } + + #[tokio::test] + async fn azure_blob_test() { + let integration = maybe_skip_integration!().build().unwrap(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + + // test for running integration test against actual blob service with service principal + // credentials. To run make sure all environment variables are set and remove the ignore + #[tokio::test] + #[ignore] + async fn azure_blob_test_sp() { + dotenv::dotenv().ok(); + let builder = MicrosoftAzureBuilder::new() + .with_account( + env::var("AZURE_STORAGE_ACCOUNT") + .expect("must be set AZURE_STORAGE_ACCOUNT"), + ) + .with_container_name( + env::var("OBJECT_STORE_BUCKET").expect("must be set OBJECT_STORE_BUCKET"), + ) + .with_client_secret_authorization( + env::var("AZURE_STORAGE_CLIENT_ID") + .expect("must be set AZURE_STORAGE_CLIENT_ID"), + env::var("AZURE_STORAGE_CLIENT_SECRET") + .expect("must be set AZURE_STORAGE_CLIENT_SECRET"), + env::var("AZURE_STORAGE_TENANT_ID") + .expect("must be set AZURE_STORAGE_TENANT_ID"), + ); + let integration = builder.build().unwrap(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } +} diff --git a/object_store/src/client/backoff.rs b/object_store/src/client/backoff.rs new file mode 100644 index 000000000000..5a6126cc45c6 --- /dev/null +++ b/object_store/src/client/backoff.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use rand::prelude::*; +use std::time::Duration; + +/// Exponential backoff with jitter +/// +/// See +#[allow(missing_copy_implementations)] +#[derive(Debug, Clone)] +pub struct BackoffConfig { + /// The initial backoff duration + pub init_backoff: Duration, + /// The maximum backoff duration + pub max_backoff: Duration, + /// The base of the exponential to use + pub base: f64, +} + +impl Default for BackoffConfig { + fn default() -> Self { + Self { + init_backoff: Duration::from_millis(100), + max_backoff: Duration::from_secs(15), + base: 2., + } + } +} + +/// [`Backoff`] can be created from a [`BackoffConfig`] +/// +/// Consecutive calls to [`Backoff::next`] will return the next backoff interval +/// +pub struct Backoff { + init_backoff: f64, + next_backoff_secs: f64, + max_backoff_secs: f64, + base: f64, + rng: Option>, +} + +impl std::fmt::Debug for Backoff { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Backoff") + .field("init_backoff", &self.init_backoff) + .field("next_backoff_secs", &self.next_backoff_secs) + .field("max_backoff_secs", &self.max_backoff_secs) + .field("base", &self.base) + .finish() + } +} + +impl Backoff { + /// Create a new [`Backoff`] from the provided [`BackoffConfig`] + pub fn new(config: &BackoffConfig) -> Self { + Self::new_with_rng(config, None) + } + + /// Creates a new `Backoff` with the optional `rng` + /// + /// Used [`rand::thread_rng()`] if no rng provided + pub fn new_with_rng( + config: &BackoffConfig, + rng: Option>, + ) -> Self { + let init_backoff = config.init_backoff.as_secs_f64(); + Self { + init_backoff, + next_backoff_secs: init_backoff, + max_backoff_secs: config.max_backoff.as_secs_f64(), + base: config.base, + rng, + } + } + + /// Returns the next backoff duration to wait for + pub fn next(&mut self) -> Duration { + let range = self.init_backoff..(self.next_backoff_secs * self.base); + + let rand_backoff = match self.rng.as_mut() { + Some(rng) => rng.gen_range(range), + None => thread_rng().gen_range(range), + }; + + let next_backoff = self.max_backoff_secs.min(rand_backoff); + Duration::from_secs_f64(std::mem::replace( + &mut self.next_backoff_secs, + next_backoff, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::mock::StepRng; + + #[test] + fn test_backoff() { + let init_backoff_secs = 1.; + let max_backoff_secs = 500.; + let base = 3.; + + let config = BackoffConfig { + init_backoff: Duration::from_secs_f64(init_backoff_secs), + max_backoff: Duration::from_secs_f64(max_backoff_secs), + base, + }; + + let assert_fuzzy_eq = + |a: f64, b: f64| assert!((b - a).abs() < 0.0001, "{} != {}", a, b); + + // Create a static rng that takes the minimum of the range + let rng = Box::new(StepRng::new(0, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + for _ in 0..20 { + assert_eq!(backoff.next().as_secs_f64(), init_backoff_secs); + } + + // Create a static rng that takes the maximum of the range + let rng = Box::new(StepRng::new(u64::MAX, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + for i in 0..20 { + let value = 
(base.powi(i) * init_backoff_secs).min(max_backoff_secs); + assert_fuzzy_eq(backoff.next().as_secs_f64(), value); + } + + // Create a static rng that takes the mid point of the range + let rng = Box::new(StepRng::new(u64::MAX / 2, 0)); + let mut backoff = Backoff::new_with_rng(&config, Some(rng)); + + let mut value = init_backoff_secs; + for _ in 0..20 { + assert_fuzzy_eq(backoff.next().as_secs_f64(), value); + value = (init_backoff_secs + (value * base - init_backoff_secs) / 2.) + .min(max_backoff_secs); + } + } +} diff --git a/object_store/src/client/mock_server.rs b/object_store/src/client/mock_server.rs new file mode 100644 index 000000000000..adb7e0fff779 --- /dev/null +++ b/object_store/src/client/mock_server.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Request, Response, Server}; +use parking_lot::Mutex; +use std::collections::VecDeque; +use std::convert::Infallible; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; + +pub type ResponseFn = Box) -> Response + Send>; + +/// A mock server +pub struct MockServer { + responses: Arc>>, + shutdown: oneshot::Sender<()>, + handle: JoinHandle<()>, + url: String, +} + +impl MockServer { + pub fn new() -> Self { + let responses: Arc>> = + Arc::new(Mutex::new(VecDeque::with_capacity(10))); + + let r = Arc::clone(&responses); + let make_service = make_service_fn(move |_conn| { + let r = Arc::clone(&r); + async move { + Ok::<_, Infallible>(service_fn(move |req| { + let r = Arc::clone(&r); + async move { + Ok::<_, Infallible>(match r.lock().pop_front() { + Some(r) => r(req), + None => Response::new(Body::from("Hello World")), + }) + } + })) + } + }); + + let (shutdown, rx) = oneshot::channel::<()>(); + let server = + Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service); + + let url = format!("http://{}", server.local_addr()); + + let handle = tokio::spawn(async move { + server + .with_graceful_shutdown(async { + rx.await.ok(); + }) + .await + .unwrap() + }); + + Self { + responses, + shutdown, + handle, + url, + } + } + + /// The url of the mock server + pub fn url(&self) -> &str { + &self.url + } + + /// Add a response + pub fn push(&self, response: Response) { + self.push_fn(|_| response) + } + + /// Add a response function + pub fn push_fn(&self, f: F) + where + F: FnOnce(Request) -> Response + Send + 'static, + { + self.responses.lock().push_back(Box::new(f)) + } + + /// Shutdown the mock server + pub async fn shutdown(self) { + let _ = self.shutdown.send(()); + self.handle.await.unwrap() + } +} diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs new file mode 100644 index 000000000000..c93c68a1faa4 
--- /dev/null +++ b/object_store/src/client/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generic utilities reqwest based ObjectStore implementations + +pub mod backoff; +#[cfg(test)] +pub mod mock_server; +pub mod pagination; +pub mod retry; +pub mod token; diff --git a/object_store/src/client/pagination.rs b/object_store/src/client/pagination.rs new file mode 100644 index 000000000000..1febe3ae0a90 --- /dev/null +++ b/object_store/src/client/pagination.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use futures::Stream; +use std::future::Future; + +/// Takes a paginated operation `op` that when called with: +/// +/// - A state `S` +/// - An optional next token `Option` +/// +/// Returns +/// +/// - A response value `T` +/// - The next state `S` +/// - The next continuation token `Option` +/// +/// And converts it into a `Stream>` which will first call `op(state, None)`, and yield +/// the returned response `T`. 
If the returned continuation token was `None` the stream will then +/// finish, otherwise it will continue to call `op(state, token)` with the values returned by the +/// previous call to `op`, until a continuation token of `None` is returned +/// +pub fn stream_paginated(state: S, op: F) -> impl Stream> +where + F: Fn(S, Option) -> Fut + Copy, + Fut: Future)>>, +{ + enum PaginationState { + Start(T), + HasMore(T, String), + Done, + } + + futures::stream::unfold(PaginationState::Start(state), move |state| async move { + let (s, page_token) = match state { + PaginationState::Start(s) => (s, None), + PaginationState::HasMore(s, page_token) if !page_token.is_empty() => { + (s, Some(page_token)) + } + _ => { + return None; + } + }; + + let (resp, s, continuation) = match op(s, page_token).await { + Ok(resp) => resp, + Err(e) => return Some((Err(e), PaginationState::Done)), + }; + + let next_state = match continuation { + Some(token) => PaginationState::HasMore(s, token), + None => PaginationState::Done, + }; + + Some((Ok(resp), next_state)) + }) +} diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs new file mode 100644 index 000000000000..d66628aec458 --- /dev/null +++ b/object_store/src/client/retry.rs @@ -0,0 +1,286 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
A shared HTTP client implementation incorporating retries + +use crate::client::backoff::{Backoff, BackoffConfig}; +use futures::future::BoxFuture; +use futures::FutureExt; +use reqwest::{Response, StatusCode}; +use snafu::Snafu; +use std::time::{Duration, Instant}; +use tracing::info; + +/// Retry request error +#[derive(Debug, Snafu)] +#[snafu(display( + "response error \"{}\", after {} retries: {}", + message, + retries, + source +))] +pub struct Error { + retries: usize, + message: String, + source: reqwest::Error, +} + +impl Error { + /// Returns the status code associated with this error if any + pub fn status(&self) -> Option { + self.source.status() + } +} + +impl From for std::io::Error { + fn from(err: Error) -> Self { + use std::io::ErrorKind; + if err.source.is_builder() || err.source.is_request() { + Self::new(ErrorKind::InvalidInput, err) + } else if let Some(s) = err.source.status() { + match s { + StatusCode::NOT_FOUND => Self::new(ErrorKind::NotFound, err), + StatusCode::BAD_REQUEST => Self::new(ErrorKind::InvalidInput, err), + _ => Self::new(ErrorKind::Other, err), + } + } else if err.source.is_timeout() { + Self::new(ErrorKind::TimedOut, err) + } else if err.source.is_connect() { + Self::new(ErrorKind::NotConnected, err) + } else { + Self::new(ErrorKind::Other, err) + } + } +} + +pub type Result = std::result::Result; + +/// Contains the configuration for how to respond to server errors +/// +/// By default they will be retried up to some limit, using exponential +/// backoff with jitter. See [`BackoffConfig`] for more information +/// +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// The backoff configuration + pub backoff: BackoffConfig, + + /// The maximum number of times to retry a request + /// + /// Set to 0 to disable retries + pub max_retries: usize, + + /// The maximum length of time from the initial request + /// after which no further retries will be attempted + /// + /// This not only bounds the length of time before a server + /// error will be surfaced to the application, but also bounds + /// the length of time a request's credentials must remain valid. 
+ /// + /// As requests are retried without renewing credentials or + /// regenerating request payloads, this number should be kept + /// below 5 minutes to avoid errors due to expired credentials + /// and/or request payloads + pub retry_timeout: Duration, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + backoff: Default::default(), + max_retries: 10, + retry_timeout: Duration::from_secs(3 * 60), + } + } +} + +pub trait RetryExt { + /// Dispatch a request with the given retry configuration + /// + /// # Panic + /// + /// This will panic if the request body is a stream + fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result>; +} + +impl RetryExt for reqwest::RequestBuilder { + fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result> { + let mut backoff = Backoff::new(&config.backoff); + let max_retries = config.max_retries; + let retry_timeout = config.retry_timeout; + + async move { + let mut retries = 0; + let now = Instant::now(); + + loop { + let s = self.try_clone().expect("request body must be cloneable"); + match s.send().await { + Ok(r) => match r.error_for_status_ref() { + Ok(_) => return Ok(r), + Err(e) => { + let status = r.status(); + + if retries == max_retries + || now.elapsed() > retry_timeout + || !status.is_server_error() { + + // Get the response message if returned a client error + let message = match status.is_client_error() { + true => match r.text().await { + Ok(message) if !message.is_empty() => message, + Ok(_) => "No Body".to_string(), + Err(e) => format!("error getting response body: {}", e) + } + false => status.to_string(), + }; + + return Err(Error{ + message, + retries, + source: e, + }) + + } + + let sleep = backoff.next(); + retries += 1; + info!("Encountered server error, backing off for {} seconds, retry {} of {}", sleep.as_secs_f32(), retries, max_retries); + tokio::time::sleep(sleep).await; + } + }, + Err(e) => + { + return Err(Error{ + retries, + message: "request error".to_string(), + source: e + }) + } + } + } + } + .boxed() + } +} + +#[cfg(test)] +mod tests { + use crate::client::mock_server::MockServer; + use crate::client::retry::RetryExt; + use crate::RetryConfig; + use hyper::header::LOCATION; + use hyper::{Body, Response}; + use reqwest::{Client, Method, StatusCode}; + use std::time::Duration; + + #[tokio::test] + async fn test_retry() { + let mock = MockServer::new(); + + let retry = RetryConfig { + backoff: Default::default(), + max_retries: 2, + retry_timeout: Duration::from_secs(1000), + }; + + let client = Client::new(); + let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry); + + // Simple request should work + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + + // Returns client errors immediately with status message + mock.push( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("cupcakes")) + .unwrap(), + ); + + let e = do_request().await.unwrap_err(); + assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); + assert_eq!(e.retries, 0); + assert_eq!(&e.message, "cupcakes"); + + // Handles client errors with no payload + mock.push( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::empty()) + .unwrap(), + ); + + let e = do_request().await.unwrap_err(); + assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); + assert_eq!(e.retries, 0); + assert_eq!(&e.message, "No Body"); + + // Should retry server error request + mock.push( + Response::builder() + 
.status(StatusCode::BAD_GATEWAY) + .body(Body::empty()) + .unwrap(), + ); + + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + + // Accepts 204 status code + mock.push( + Response::builder() + .status(StatusCode::NO_CONTENT) + .body(Body::empty()) + .unwrap(), + ); + + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::NO_CONTENT); + + // Follows redirects + mock.push( + Response::builder() + .status(StatusCode::FOUND) + .header(LOCATION, "/foo") + .body(Body::empty()) + .unwrap(), + ); + + let r = do_request().await.unwrap(); + assert_eq!(r.status(), StatusCode::OK); + assert_eq!(r.url().path(), "/foo"); + + // Gives up after the retrying the specified number of times + for _ in 0..=retry.max_retries { + mock.push( + Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(Body::from("ignored")) + .unwrap(), + ); + } + + let e = do_request().await.unwrap_err(); + assert_eq!(e.retries, retry.max_retries); + assert_eq!(e.message, "502 Bad Gateway"); + + // Shutdown + mock.shutdown().await + } +} diff --git a/object_store/src/client/token.rs b/object_store/src/client/token.rs new file mode 100644 index 000000000000..2ff28616e608 --- /dev/null +++ b/object_store/src/client/token.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::future::Future; +use std::time::Instant; +use tokio::sync::Mutex; + +/// A temporary authentication token with an associated expiry +#[derive(Debug, Clone)] +pub struct TemporaryToken { + /// The temporary credential + pub token: T, + /// The instant at which this credential is no longer valid + pub expiry: Instant, +} + +/// Provides [`TokenCache::get_or_insert_with`] which can be used to cache a +/// [`TemporaryToken`] based on its expiry +#[derive(Debug)] +pub struct TokenCache { + cache: Mutex>>, +} + +impl Default for TokenCache { + fn default() -> Self { + Self { + cache: Default::default(), + } + } +} + +impl TokenCache { + pub async fn get_or_insert_with(&self, f: F) -> Result + where + F: FnOnce() -> Fut + Send, + Fut: Future, E>> + Send, + { + let now = Instant::now(); + let mut locked = self.cache.lock().await; + + if let Some(cached) = locked.as_ref() { + let delta = cached + .expiry + .checked_duration_since(now) + .unwrap_or_default(); + + if delta.as_secs() > 300 { + return Ok(cached.token.clone()); + } + } + + let cached = f().await?; + let token = cached.token.clone(); + *locked = Some(cached); + + Ok(token) + } +} diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs new file mode 100644 index 000000000000..5b8cdb8480b4 --- /dev/null +++ b/object_store/src/gcp/credential.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::client::retry::RetryExt; +use crate::client::token::TemporaryToken; +use crate::RetryConfig; +use reqwest::{Client, Method}; +use ring::signature::RsaKeyPair; +use snafu::{ResultExt, Snafu}; +use std::time::{Duration, Instant}; + +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("No RSA key found in pem file"))] + MissingKey, + + #[snafu(display("Invalid RSA key: {}", source), context(false))] + InvalidKey { source: ring::error::KeyRejected }, + + #[snafu(display("Error signing jwt: {}", source))] + Sign { source: ring::error::Unspecified }, + + #[snafu(display("Error encoding jwt payload: {}", source))] + Encode { source: serde_json::Error }, + + #[snafu(display("Unsupported key encoding: {}", encoding))] + UnsupportedKey { encoding: String }, + + #[snafu(display("Error performing token request: {}", source))] + TokenRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting token response body: {}", source))] + TokenResponseBody { source: reqwest::Error }, +} + +pub type Result = std::result::Result; + +#[derive(Debug, Default, serde::Serialize)] +pub struct JwtHeader { + /// The type of JWS: it can only be "JWT" here + /// + /// Defined in [RFC7515#4.1.9](https://tools.ietf.org/html/rfc7515#section-4.1.9). 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub typ: Option, + /// The algorithm used + /// + /// Defined in [RFC7515#4.1.1](https://tools.ietf.org/html/rfc7515#section-4.1.1). + pub alg: String, + /// Content type + /// + /// Defined in [RFC7519#5.2](https://tools.ietf.org/html/rfc7519#section-5.2). + #[serde(skip_serializing_if = "Option::is_none")] + pub cty: Option, + /// JSON Key URL + /// + /// Defined in [RFC7515#4.1.2](https://tools.ietf.org/html/rfc7515#section-4.1.2). + #[serde(skip_serializing_if = "Option::is_none")] + pub jku: Option, + /// Key ID + /// + /// Defined in [RFC7515#4.1.4](https://tools.ietf.org/html/rfc7515#section-4.1.4). + #[serde(skip_serializing_if = "Option::is_none")] + pub kid: Option, + /// X.509 URL + /// + /// Defined in [RFC7515#4.1.5](https://tools.ietf.org/html/rfc7515#section-4.1.5). + #[serde(skip_serializing_if = "Option::is_none")] + pub x5u: Option, + /// X.509 certificate thumbprint + /// + /// Defined in [RFC7515#4.1.7](https://tools.ietf.org/html/rfc7515#section-4.1.7). + #[serde(skip_serializing_if = "Option::is_none")] + pub x5t: Option, +} + +#[derive(serde::Serialize)] +struct TokenClaims<'a> { + iss: &'a str, + scope: &'a str, + aud: &'a str, + exp: u64, + iat: u64, +} + +#[derive(serde::Deserialize, Debug)] +struct TokenResponse { + access_token: String, + expires_in: u64, +} + +/// Encapsulates the logic to perform an OAuth token challenge +#[derive(Debug)] +pub struct OAuthProvider { + issuer: String, + scope: String, + audience: String, + key_pair: RsaKeyPair, + jwt_header: String, + random: ring::rand::SystemRandom, +} + +impl OAuthProvider { + /// Create a new [`OAuthProvider`] + pub fn new( + issuer: String, + private_key_pem: String, + scope: String, + audience: String, + ) -> Result { + let key_pair = decode_first_rsa_key(private_key_pem)?; + let jwt_header = b64_encode_obj(&JwtHeader { + alg: "RS256".to_string(), + ..Default::default() + })?; + + Ok(Self { + issuer, + key_pair, + scope, + audience, + jwt_header, + random: ring::rand::SystemRandom::new(), + }) + } + + /// Fetch a fresh token + pub async fn fetch_token( + &self, + client: &Client, + retry: &RetryConfig, + ) -> Result> { + let now = seconds_since_epoch(); + let exp = now + 3600; + + let claims = TokenClaims { + iss: &self.issuer, + scope: &self.scope, + aud: &self.audience, + exp, + iat: now, + }; + + let claim_str = b64_encode_obj(&claims)?; + let message = [self.jwt_header.as_ref(), claim_str.as_ref()].join("."); + let mut sig_bytes = vec![0; self.key_pair.public_modulus_len()]; + self.key_pair + .sign( + &ring::signature::RSA_PKCS1_SHA256, + &self.random, + message.as_bytes(), + &mut sig_bytes, + ) + .context(SignSnafu)?; + + let signature = base64::encode_config(&sig_bytes, base64::URL_SAFE_NO_PAD); + let jwt = [message, signature].join("."); + + let body = [ + ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), + ("assertion", &jwt), + ]; + + let response: TokenResponse = client + .request(Method::POST, &self.audience) + .form(&body) + .send_retry(retry) + .await + .context(TokenRequestSnafu)? 
+ .json() + .await + .context(TokenResponseBodySnafu)?; + + let token = TemporaryToken { + token: response.access_token, + expiry: Instant::now() + Duration::from_secs(response.expires_in), + }; + + Ok(token) + } +} + +/// Returns the number of seconds since unix epoch +fn seconds_since_epoch() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs() +} + +fn decode_first_rsa_key(private_key_pem: String) -> Result { + use rustls_pemfile::Item; + use std::io::{BufReader, Cursor}; + + let mut cursor = Cursor::new(private_key_pem); + let mut reader = BufReader::new(&mut cursor); + + // Reading from string is infallible + match rustls_pemfile::read_one(&mut reader).unwrap() { + Some(Item::PKCS8Key(key)) => Ok(RsaKeyPair::from_pkcs8(&key)?), + Some(Item::RSAKey(key)) => Ok(RsaKeyPair::from_der(&key)?), + _ => Err(Error::MissingKey), + } +} + +fn b64_encode_obj(obj: &T) -> Result { + let string = serde_json::to_string(obj).context(EncodeSnafu)?; + Ok(base64::encode_config(string, base64::URL_SAFE_NO_PAD)) +} diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs new file mode 100644 index 000000000000..0ef4d3564b64 --- /dev/null +++ b/object_store/src/gcp/mod.rs @@ -0,0 +1,1018 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store implementation for Google Cloud Storage +//! +//! ## Multi-part uploads +//! +//! [Multi-part uploads](https://cloud.google.com/storage/docs/multipart-uploads) +//! can be initiated with the [ObjectStore::put_multipart] method. +//! Data passed to the writer is automatically buffered to meet the minimum size +//! requirements for a part. Multiple parts are uploaded concurrently. +//! +//! If the writer fails for any reason, you may have parts uploaded to GCS but not +//! used that you may be charged for. Use the [ObjectStore::abort_multipart] method +//! to abort the upload and drop those unneeded parts. In addition, you may wish to +//! consider implementing automatic clean up of unused parts that are older than one +//! week. 
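As a rough sketch of the clean-up pattern recommended above, a caller might abort the upload when the writer fails; the `upload_or_cleanup` helper and its arguments are illustrative names rather than part of this crate:

```rust
use object_store::{path::Path, ObjectStore};
use tokio::io::AsyncWriteExt;

async fn upload_or_cleanup(
    store: &dyn ObjectStore,
    location: &Path,
    data: &[u8],
) -> Result<(), Box<dyn std::error::Error>> {
    let (multipart_id, mut writer) = store.put_multipart(location).await?;

    // Attempt to write all data and finalize the upload.
    let write_result = async {
        writer.write_all(data).await?;
        writer.shutdown().await
    }
    .await;

    if let Err(e) = write_result {
        // Discard any parts already uploaded so they are not left
        // behind (and, on GCS, charged for).
        store.abort_multipart(location, &multipart_id).await?;
        return Err(e.into());
    }

    Ok(())
}
```

For this store, `abort_multipart` corresponds to the `multipart_cleanup` request implemented later in this module, which deletes the pending upload by its `uploadId`.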
+use std::collections::BTreeSet; +use std::fs::File; +use std::io::{self, BufReader}; +use std::ops::Range; +use std::sync::Arc; + +use async_trait::async_trait; +use bytes::{Buf, Bytes}; +use chrono::{DateTime, Utc}; +use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use percent_encoding::{percent_encode, NON_ALPHANUMERIC}; +use reqwest::header::RANGE; +use reqwest::{header, Client, Method, Response, StatusCode}; +use snafu::{ResultExt, Snafu}; +use tokio::io::AsyncWrite; + +use crate::client::pagination::stream_paginated; +use crate::client::retry::RetryExt; +use crate::{ + client::token::TokenCache, + multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart}, + path::{Path, DELIMITER}, + util::{format_http_range, format_prefix}, + GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig, +}; + +use credential::OAuthProvider; + +mod credential; + +#[derive(Debug, Snafu)] +enum Error { + #[snafu(display("Unable to open service account file: {}", source))] + OpenCredentials { source: std::io::Error }, + + #[snafu(display("Unable to decode service account file: {}", source))] + DecodeCredentials { source: serde_json::Error }, + + #[snafu(display("Got invalid XML response for {} {}: {}", method, url, source))] + InvalidXMLResponse { + source: quick_xml::de::DeError, + method: String, + url: String, + data: Bytes, + }, + + #[snafu(display("Error performing list request: {}", source))] + ListRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting list response body: {}", source))] + ListResponseBody { source: reqwest::Error }, + + #[snafu(display("Error performing get request {}: {}", path, source))] + GetRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error getting get response body {}: {}", path, source))] + GetResponseBody { + source: reqwest::Error, + path: String, + }, + + #[snafu(display("Error performing delete request {}: {}", path, source))] + DeleteRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing copy request {}: {}", path, source))] + CopyRequest { + source: crate::client::retry::Error, + path: String, + }, + + #[snafu(display("Error performing put request: {}", source))] + PutRequest { source: crate::client::retry::Error }, + + #[snafu(display("Error getting put response body: {}", source))] + PutResponseBody { source: reqwest::Error }, + + #[snafu(display("Error decoding object size: {}", source))] + InvalidSize { source: std::num::ParseIntError }, + + #[snafu(display("Missing bucket name"))] + MissingBucketName {}, + + #[snafu(display("Missing service account path"))] + MissingServiceAccountPath, + + #[snafu(display("GCP credential error: {}", source))] + Credential { source: credential::Error }, +} + +impl From for super::Error { + fn from(err: Error) -> Self { + match err { + Error::GetRequest { source, path } + | Error::DeleteRequest { source, path } + | Error::CopyRequest { source, path } + if matches!(source.status(), Some(StatusCode::NOT_FOUND)) => + { + Self::NotFound { + path, + source: Box::new(source), + } + } + _ => Self::Generic { + store: "GCS", + source: Box::new(err), + }, + } + } +} + +/// A deserialized `service-account-********.json`-file. +#[derive(serde::Deserialize, Debug)] +struct ServiceAccountCredentials { + /// The private key in RSA format. + pub private_key: String, + + /// The email address associated with the service account. 
+ pub client_email: String, + + /// Base URL for GCS + #[serde(default = "default_gcs_base_url")] + pub gcs_base_url: String, + + /// Disable oauth and use empty tokens. + #[serde(default = "default_disable_oauth")] + pub disable_oauth: bool, +} + +fn default_gcs_base_url() -> String { + "https://storage.googleapis.com".to_owned() +} + +fn default_disable_oauth() -> bool { + false +} + +#[derive(serde::Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +struct ListResponse { + next_page_token: Option, + #[serde(default)] + prefixes: Vec, + #[serde(default)] + items: Vec, +} + +#[derive(serde::Deserialize, Debug)] +struct Object { + name: String, + size: String, + updated: DateTime, +} + +#[derive(serde::Deserialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct InitiateMultipartUploadResult { + upload_id: String, +} + +#[derive(serde::Serialize, Debug)] +#[serde(rename_all = "PascalCase", rename(serialize = "Part"))] +struct MultipartPart { + #[serde(rename = "$unflatten=PartNumber")] + part_number: usize, + #[serde(rename = "$unflatten=ETag")] + e_tag: String, +} + +#[derive(serde::Serialize, Debug)] +#[serde(rename_all = "PascalCase")] +struct CompleteMultipartUpload { + #[serde(rename = "Part", default)] + parts: Vec, +} + +/// Interface for [Google Cloud Storage](https://cloud.google.com/storage/). +#[derive(Debug)] +pub struct GoogleCloudStorage { + client: Arc, +} + +impl std::fmt::Display for GoogleCloudStorage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "GoogleCloudStorage({})", self.client.bucket_name) + } +} + +#[derive(Debug)] +struct GoogleCloudStorageClient { + client: Client, + base_url: String, + + oauth_provider: Option, + token_cache: TokenCache, + + bucket_name: String, + bucket_name_encoded: String, + + retry_config: RetryConfig, + + // TODO: Hook this up in tests + max_list_results: Option, +} + +impl GoogleCloudStorageClient { + async fn get_token(&self) -> Result { + if let Some(oauth_provider) = &self.oauth_provider { + Ok(self + .token_cache + .get_or_insert_with(|| { + oauth_provider.fetch_token(&self.client, &self.retry_config) + }) + .await + .context(CredentialSnafu)?) 
+ } else { + Ok("".to_owned()) + } + } + + fn object_url(&self, path: &Path) -> String { + let encoded = + percent_encoding::utf8_percent_encode(path.as_ref(), NON_ALPHANUMERIC); + format!( + "{}/storage/v1/b/{}/o/{}", + self.base_url, self.bucket_name_encoded, encoded + ) + } + + /// Perform a get request + async fn get_request( + &self, + path: &Path, + range: Option>, + head: bool, + ) -> Result { + let token = self.get_token().await?; + let url = self.object_url(path); + + let mut builder = self.client.request(Method::GET, url); + + if let Some(range) = range { + builder = builder.header(RANGE, format_http_range(range)); + } + + let alt = match head { + true => "json", + false => "media", + }; + + let response = builder + .bearer_auth(token) + .query(&[("alt", alt)]) + .send_retry(&self.retry_config) + .await + .context(GetRequestSnafu { + path: path.as_ref(), + })?; + + Ok(response) + } + + /// Perform a put request + async fn put_request(&self, path: &Path, payload: Bytes) -> Result<()> { + let token = self.get_token().await?; + let url = format!( + "{}/upload/storage/v1/b/{}/o", + self.base_url, self.bucket_name_encoded + ); + + self.client + .request(Method::POST, url) + .bearer_auth(token) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, payload.len()) + .query(&[("uploadType", "media"), ("name", path.as_ref())]) + .body(payload) + .send_retry(&self.retry_config) + .await + .context(PutRequestSnafu)?; + + Ok(()) + } + + /// Initiate a multi-part upload + async fn multipart_initiate(&self, path: &Path) -> Result { + let token = self.get_token().await?; + let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path); + + let response = self + .client + .request(Method::POST, &url) + .bearer_auth(token) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, "0") + .query(&[("uploads", "")]) + .send_retry(&self.retry_config) + .await + .context(PutRequestSnafu)?; + + let data = response.bytes().await.context(PutResponseBodySnafu)?; + let result: InitiateMultipartUploadResult = quick_xml::de::from_reader( + data.as_ref().reader(), + ) + .context(InvalidXMLResponseSnafu { + method: "POST".to_string(), + url, + data, + })?; + + Ok(result.upload_id) + } + + /// Cleanup unused parts + async fn multipart_cleanup( + &self, + path: &str, + multipart_id: &MultipartId, + ) -> Result<()> { + let token = self.get_token().await?; + let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path); + + self.client + .request(Method::DELETE, &url) + .bearer_auth(token) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, "0") + .query(&[("uploadId", multipart_id)]) + .send_retry(&self.retry_config) + .await + .context(PutRequestSnafu)?; + + Ok(()) + } + + /// Perform a delete request + async fn delete_request(&self, path: &Path) -> Result<()> { + let token = self.get_token().await?; + let url = self.object_url(path); + + let builder = self.client.request(Method::DELETE, url); + builder + .bearer_auth(token) + .send_retry(&self.retry_config) + .await + .context(DeleteRequestSnafu { + path: path.as_ref(), + })?; + + Ok(()) + } + + /// Perform a copy request + async fn copy_request( + &self, + from: &Path, + to: &Path, + if_not_exists: bool, + ) -> Result<()> { + let token = self.get_token().await?; + + let source = + percent_encoding::utf8_percent_encode(from.as_ref(), NON_ALPHANUMERIC); + let destination = + 
percent_encoding::utf8_percent_encode(to.as_ref(), NON_ALPHANUMERIC); + let url = format!( + "{}/storage/v1/b/{}/o/{}/copyTo/b/{}/o/{}", + self.base_url, + self.bucket_name_encoded, + source, + self.bucket_name_encoded, + destination + ); + + let mut builder = self.client.request(Method::POST, url); + + if if_not_exists { + builder = builder.query(&[("ifGenerationMatch", "0")]); + } + + builder + .bearer_auth(token) + .send_retry(&self.retry_config) + .await + .context(CopyRequestSnafu { + path: from.as_ref(), + })?; + + Ok(()) + } + + /// Perform a list request + async fn list_request( + &self, + prefix: Option<&str>, + delimiter: bool, + page_token: Option<&str>, + ) -> Result { + let token = self.get_token().await?; + + let url = format!( + "{}/storage/v1/b/{}/o", + self.base_url, self.bucket_name_encoded + ); + + let mut query = Vec::with_capacity(4); + if delimiter { + query.push(("delimiter", DELIMITER)) + } + + if let Some(prefix) = &prefix { + query.push(("prefix", prefix)) + } + + if let Some(page_token) = page_token { + query.push(("pageToken", page_token)) + } + + if let Some(max_results) = &self.max_list_results { + query.push(("maxResults", max_results)) + } + + let response: ListResponse = self + .client + .request(Method::GET, url) + .query(&query) + .bearer_auth(token) + .send_retry(&self.retry_config) + .await + .context(ListRequestSnafu)? + .json() + .await + .context(ListResponseBodySnafu)?; + + Ok(response) + } + + /// Perform a list operation automatically handling pagination + fn list_paginated( + &self, + prefix: Option<&Path>, + delimiter: bool, + ) -> BoxStream<'_, Result> { + let prefix = format_prefix(prefix); + stream_paginated(prefix, move |prefix, token| async move { + let mut r = self + .list_request(prefix.as_deref(), delimiter, token.as_deref()) + .await?; + let next_token = r.next_page_token.take(); + Ok((r, prefix, next_token)) + }) + .boxed() + } +} + +struct GCSMultipartUpload { + client: Arc, + encoded_path: String, + multipart_id: MultipartId, +} + +#[async_trait] +impl CloudMultiPartUploadImpl for GCSMultipartUpload { + /// Upload an object part + async fn put_multipart_part( + &self, + buf: Vec, + part_idx: usize, + ) -> Result { + let upload_id = self.multipart_id.clone(); + let url = format!( + "{}/{}/{}", + self.client.base_url, self.client.bucket_name_encoded, self.encoded_path + ); + + let token = self + .client + .get_token() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + let response = self + .client + .client + .request(Method::PUT, &url) + .bearer_auth(token) + .query(&[ + ("partNumber", format!("{}", part_idx + 1)), + ("uploadId", upload_id), + ]) + .header(header::CONTENT_TYPE, "application/octet-stream") + .header(header::CONTENT_LENGTH, format!("{}", buf.len())) + .body(buf) + .send_retry(&self.client.retry_config) + .await?; + + let content_id = response + .headers() + .get("ETag") + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "response headers missing ETag", + ) + })? + .to_str() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? 
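+            // The ETag returned by the server identifies this part; it is kept as the
+            // part's `content_id` and echoed back in the CompleteMultipartUpload body.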
+ .to_string(); + + Ok(UploadPart { content_id }) + } + + /// Complete a multipart upload + async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error> { + let upload_id = self.multipart_id.clone(); + let url = format!( + "{}/{}/{}", + self.client.base_url, self.client.bucket_name_encoded, self.encoded_path + ); + + let parts = completed_parts + .into_iter() + .enumerate() + .map(|(part_number, part)| MultipartPart { + e_tag: part.content_id, + part_number: part_number + 1, + }) + .collect(); + + let token = self + .client + .get_token() + .await + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?; + + let upload_info = CompleteMultipartUpload { parts }; + + let data = quick_xml::se::to_string(&upload_info) + .map_err(|err| io::Error::new(io::ErrorKind::Other, err))? + // We cannot disable the escaping that transforms "/" to ""e;" :( + // https://github.com/tafia/quick-xml/issues/362 + // https://github.com/tafia/quick-xml/issues/350 + .replace(""", "\""); + + self.client + .client + .request(Method::POST, &url) + .bearer_auth(token) + .query(&[("uploadId", upload_id)]) + .body(data) + .send_retry(&self.client.retry_config) + .await?; + + Ok(()) + } +} + +#[async_trait] +impl ObjectStore for GoogleCloudStorage { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + self.client.put_request(location, bytes).await + } + + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let upload_id = self.client.multipart_initiate(location).await?; + + let encoded_path = + percent_encode(location.to_string().as_bytes(), NON_ALPHANUMERIC).to_string(); + + let inner = GCSMultipartUpload { + client: Arc::clone(&self.client), + encoded_path, + multipart_id: upload_id.clone(), + }; + + Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8)))) + } + + async fn abort_multipart( + &self, + location: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + self.client + .multipart_cleanup(location.as_ref(), multipart_id) + .await?; + + Ok(()) + } + + async fn get(&self, location: &Path) -> Result { + let response = self.client.get_request(location, None, false).await?; + let stream = response + .bytes_stream() + .map_err(|source| crate::Error::Generic { + store: "GCS", + source: Box::new(source), + }) + .boxed(); + + Ok(GetResult::Stream(stream)) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let response = self + .client + .get_request(location, Some(range), false) + .await?; + Ok(response.bytes().await.context(GetResponseBodySnafu { + path: location.as_ref(), + })?) 
+ } + + async fn head(&self, location: &Path) -> Result { + let response = self.client.get_request(location, None, true).await?; + let object = response.json().await.context(GetResponseBodySnafu { + path: location.as_ref(), + })?; + convert_object_meta(&object) + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.client.delete_request(location).await + } + + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>> { + let stream = self + .client + .list_paginated(prefix, false) + .map_ok(|r| { + futures::stream::iter( + r.items.into_iter().map(|x| convert_object_meta(&x)), + ) + }) + .try_flatten() + .boxed(); + + Ok(stream) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let mut stream = self.client.list_paginated(prefix, true); + + let mut common_prefixes = BTreeSet::new(); + let mut objects = Vec::new(); + + while let Some(result) = stream.next().await { + let response = result?; + + for p in response.prefixes { + common_prefixes.insert(Path::parse(p)?); + } + + objects.reserve(response.items.len()); + for object in &response.items { + objects.push(convert_object_meta(object)?); + } + } + + Ok(ListResult { + common_prefixes: common_prefixes.into_iter().collect(), + objects, + }) + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, false).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.client.copy_request(from, to, true).await + } +} + +fn reader_credentials_file( + service_account_path: impl AsRef, +) -> Result { + let file = File::open(service_account_path).context(OpenCredentialsSnafu)?; + let reader = BufReader::new(file); + Ok(serde_json::from_reader(reader).context(DecodeCredentialsSnafu)?) +} + +/// Configure a connection to Google Cloud Storage using the specified +/// credentials. +/// +/// # Example +/// ``` +/// # let BUCKET_NAME = "foo"; +/// # let SERVICE_ACCOUNT_PATH = "/tmp/foo.json"; +/// # use object_store::gcp::GoogleCloudStorageBuilder; +/// let gcs = GoogleCloudStorageBuilder::new() +/// .with_service_account_path(SERVICE_ACCOUNT_PATH) +/// .with_bucket_name(BUCKET_NAME) +/// .build(); +/// ``` +#[derive(Debug, Default)] +pub struct GoogleCloudStorageBuilder { + bucket_name: Option, + service_account_path: Option, + client: Option, + retry_config: RetryConfig, +} + +impl GoogleCloudStorageBuilder { + /// Create a new [`GoogleCloudStorageBuilder`] with default values. + pub fn new() -> Self { + Default::default() + } + + /// Set the bucket name (required) + pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { + self.bucket_name = Some(bucket_name.into()); + self + } + + /// Set the path to the service account file (required). 
Example + /// `"/tmp/gcs.json"` + /// + /// Example contents of `gcs.json`: + /// + /// ```json + /// { + /// "gcs_base_url": "https://localhost:4443", + /// "disable_oauth": true, + /// "client_email": "", + /// "private_key": "" + /// } + /// ``` + pub fn with_service_account_path( + mut self, + service_account_path: impl Into, + ) -> Self { + self.service_account_path = Some(service_account_path.into()); + self + } + + /// Set the retry configuration + pub fn with_retry(mut self, retry_config: RetryConfig) -> Self { + self.retry_config = retry_config; + self + } + + /// Configure a connection to Google Cloud Storage, returning a + /// new [`GoogleCloudStorage`] and consuming `self` + pub fn build(self) -> Result { + let Self { + bucket_name, + service_account_path, + client, + retry_config, + } = self; + + let bucket_name = bucket_name.ok_or(Error::MissingBucketName {})?; + let service_account_path = + service_account_path.ok_or(Error::MissingServiceAccountPath)?; + let client = client.unwrap_or_else(Client::new); + + let credentials = reader_credentials_file(service_account_path)?; + + // TODO: https://cloud.google.com/storage/docs/authentication#oauth-scopes + let scope = "https://www.googleapis.com/auth/devstorage.full_control"; + let audience = "https://www.googleapis.com/oauth2/v4/token".to_string(); + + let oauth_provider = (!credentials.disable_oauth) + .then(|| { + OAuthProvider::new( + credentials.client_email, + credentials.private_key, + scope.to_string(), + audience, + ) + }) + .transpose() + .context(CredentialSnafu)?; + + let encoded_bucket_name = + percent_encode(bucket_name.as_bytes(), NON_ALPHANUMERIC).to_string(); + + // The cloud storage crate currently only supports authentication via + // environment variables. Set the environment variable explicitly so + // that we can optionally accept command line arguments instead. + Ok(GoogleCloudStorage { + client: Arc::new(GoogleCloudStorageClient { + client, + base_url: credentials.gcs_base_url, + oauth_provider, + token_cache: Default::default(), + bucket_name, + bucket_name_encoded: encoded_bucket_name, + retry_config, + max_list_results: None, + }), + }) + } +} + +fn convert_object_meta(object: &Object) -> Result { + let location = Path::parse(&object.name)?; + let last_modified = object.updated; + let size = object.size.parse().context(InvalidSizeSnafu)?; + + Ok(ObjectMeta { + location, + last_modified, + size, + }) +} + +#[cfg(test)] +mod test { + use std::env; + + use bytes::Bytes; + + use crate::{ + tests::{ + get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter, + put_get_delete_list, rename_and_copy, stream_get, + }, + Error as ObjectStoreError, ObjectStore, + }; + + use super::*; + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + // Helper macro to skip tests if TEST_INTEGRATION and the GCP environment variables are not set. + macro_rules! 
maybe_skip_integration { + () => {{ + dotenv::dotenv().ok(); + + let required_vars = ["OBJECT_STORE_BUCKET", "GOOGLE_SERVICE_ACCOUNT"]; + let unset_vars: Vec<_> = required_vars + .iter() + .filter_map(|&name| match env::var(name) { + Ok(_) => None, + Err(_) => Some(name), + }) + .collect(); + let unset_var_names = unset_vars.join(", "); + + let force = std::env::var("TEST_INTEGRATION"); + + if force.is_ok() && !unset_var_names.is_empty() { + panic!( + "TEST_INTEGRATION is set, \ + but variable(s) {} need to be set", + unset_var_names + ) + } else if force.is_err() { + eprintln!( + "skipping Google Cloud integration test - set {}TEST_INTEGRATION to run", + if unset_var_names.is_empty() { + String::new() + } else { + format!("{} and ", unset_var_names) + } + ); + return; + } else { + GoogleCloudStorageBuilder::new() + .with_bucket_name( + env::var("OBJECT_STORE_BUCKET") + .expect("already checked OBJECT_STORE_BUCKET") + ) + .with_service_account_path( + env::var("GOOGLE_SERVICE_ACCOUNT") + .expect("already checked GOOGLE_SERVICE_ACCOUNT") + ) + } + }}; + } + + #[tokio::test] + async fn gcs_test() { + let integration = maybe_skip_integration!().build().unwrap(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + if integration.client.base_url == default_gcs_base_url() { + // Fake GCS server does not yet implement XML Multipart uploads + // https://github.com/fsouza/fake-gcs-server/issues/852 + stream_get(&integration).await; + } + } + + #[tokio::test] + async fn gcs_test_get_nonexistent_location() { + let integration = maybe_skip_integration!().build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.get(&location).await.unwrap_err(); + + assert!( + matches!(err, ObjectStoreError::NotFound { .. }), + "unexpected error type: {}", + err + ); + } + + #[tokio::test] + async fn gcs_test_get_nonexistent_bucket() { + let integration = maybe_skip_integration!() + .with_bucket_name(NON_EXISTENT_NAME) + .build() + .unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + + assert!( + matches!(err, ObjectStoreError::NotFound { .. }), + "unexpected error type: {}", + err + ); + } + + #[tokio::test] + async fn gcs_test_delete_nonexistent_location() { + let integration = maybe_skip_integration!().build().unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.delete(&location).await.unwrap_err(); + assert!( + matches!(err, ObjectStoreError::NotFound { .. }), + "unexpected error type: {}", + err + ); + } + + #[tokio::test] + async fn gcs_test_delete_nonexistent_bucket() { + let integration = maybe_skip_integration!() + .with_bucket_name(NON_EXISTENT_NAME) + .build() + .unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + + let err = integration.delete(&location).await.unwrap_err(); + assert!( + matches!(err, ObjectStoreError::NotFound { .. 
}), + "unexpected error type: {}", + err + ); + } + + #[tokio::test] + async fn gcs_test_put_nonexistent_bucket() { + let integration = maybe_skip_integration!() + .with_bucket_name(NON_EXISTENT_NAME) + .build() + .unwrap(); + + let location = Path::from_iter([NON_EXISTENT_NAME]); + let data = Bytes::from("arbitrary data"); + + let err = integration + .put(&location, data) + .await + .unwrap_err() + .to_string(); + assert!( + err.contains("HTTP status client error (404 Not Found)"), + "{}", + err + ) + } +} diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs new file mode 100644 index 000000000000..9ed9db9e928c --- /dev/null +++ b/object_store/src/lib.rs @@ -0,0 +1,973 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![deny(rustdoc::broken_intra_doc_links, rustdoc::bare_urls, rust_2018_idioms)] +#![warn( + missing_copy_implementations, + missing_debug_implementations, + missing_docs, + clippy::explicit_iter_loop, + clippy::future_not_send, + clippy::use_self, + clippy::clone_on_ref_ptr +)] + +//! # object_store +//! +//! This crate provides a uniform API for interacting with object storage services and +//! local files via the the [`ObjectStore`] trait. +//! +//! # Create an [`ObjectStore`] implementation: +//! +//! * [Google Cloud Storage](https://cloud.google.com/storage/): [`GoogleCloudStorageBuilder`](gcp::GoogleCloudStorageBuilder) +//! * [Amazon S3](https://aws.amazon.com/s3/): [`AmazonS3Builder`](aws::AmazonS3Builder) +//! * [Azure Blob Storage](https://azure.microsoft.com/en-gb/services/storage/blobs/):: [`MicrosoftAzureBuilder`](azure::MicrosoftAzureBuilder) +//! * In Memory: [`InMemory`](memory::InMemory) +//! * Local filesystem: [`LocalFileSystem`](local::LocalFileSystem) +//! +//! # Adapters +//! +//! [`ObjectStore`] instances can be composed with various adapters +//! which add additional functionality: +//! +//! * Rate Throttling: [`ThrottleConfig`](throttle::ThrottleConfig) +//! * Concurrent Request Limit: [`LimitStore`](limit::LimitStore) +//! +//! +//! # Listing objects: +//! +//! Use the [`ObjectStore::list`] method to iterate over objects in +//! remote storage or files in the local filesystem: +//! +//! ``` +//! # use object_store::local::LocalFileSystem; +//! # // use LocalFileSystem for example +//! # fn get_object_store() -> LocalFileSystem { +//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # } +//! +//! # async fn example() { +//! use std::sync::Arc; +//! use object_store::{path::Path, ObjectStore}; +//! use futures::stream::StreamExt; +//! +//! // create an ObjectStore +//! let object_store: Arc = Arc::new(get_object_store()); +//! +//! // Recursively list all files below the 'data' path. +//! // 1. On AWS S3 this would be the 'data/' prefix +//! // 2. 
On a local filesystem, this would be the 'data' directory +//! let prefix: Path = "data".try_into().unwrap(); +//! +//! // Get an `async` stream of Metadata objects: +//! let list_stream = object_store +//! .list(Some(&prefix)) +//! .await +//! .expect("Error listing files"); +//! +//! // Print a line about each object based on its metadata +//! // using for_each from `StreamExt` trait. +//! list_stream +//! .for_each(move |meta| { +//! async { +//! let meta = meta.expect("Error listing"); +//! println!("Name: {}, size: {}", meta.location, meta.size); +//! } +//! }) +//! .await; +//! # } +//! ``` +//! +//! Which will print out something like the following: +//! +//! ```text +//! Name: data/file01.parquet, size: 112832 +//! Name: data/file02.parquet, size: 143119 +//! Name: data/child/file03.parquet, size: 100 +//! ... +//! ``` +//! +//! # Fetching objects +//! +//! Use the [`ObjectStore::get`] method to fetch the data bytes +//! from remote storage or files in the local filesystem as a stream. +//! +//! ``` +//! # use object_store::local::LocalFileSystem; +//! # // use LocalFileSystem for example +//! # fn get_object_store() -> LocalFileSystem { +//! # LocalFileSystem::new_with_prefix("/tmp").unwrap() +//! # } +//! +//! # async fn example() { +//! use std::sync::Arc; +//! use object_store::{path::Path, ObjectStore}; +//! use futures::stream::StreamExt; +//! +//! // create an ObjectStore +//! let object_store: Arc = Arc::new(get_object_store()); +//! +//! // Retrieve a specific file +//! let path: Path = "data/file01.parquet".try_into().unwrap(); +//! +//! // fetch the bytes from object store +//! let stream = object_store +//! .get(&path) +//! .await +//! .unwrap() +//! .into_stream(); +//! +//! // Count the '0's using `map` from `StreamExt` trait +//! let num_zeros = stream +//! .map(|bytes| { +//! let bytes = bytes.unwrap(); +//! bytes.iter().filter(|b| **b == 0).count() +//! }) +//! .collect::>() +//! .await +//! .into_iter() +//! .sum::(); +//! +//! println!("Num zeros in {} is {}", path, num_zeros); +//! # } +//! ``` +//! +//! Which will print out something like the following: +//! +//! ```text +//! Num zeros in data/file01.parquet is 657 +//! ``` +//! + +#[cfg(feature = "aws")] +pub mod aws; +#[cfg(feature = "azure")] +pub mod azure; +#[cfg(feature = "gcp")] +pub mod gcp; +pub mod limit; +pub mod local; +pub mod memory; +pub mod path; +pub mod throttle; + +#[cfg(any(feature = "gcp", feature = "aws", feature = "azure"))] +mod client; + +#[cfg(any(feature = "gcp", feature = "aws", feature = "azure"))] +pub use client::{backoff::BackoffConfig, retry::RetryConfig}; + +#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))] +mod multipart; +mod util; + +use crate::path::Path; +use crate::util::{ + coalesce_ranges, collect_bytes, maybe_spawn_blocking, OBJECT_STORE_COALESCE_DEFAULT, +}; +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{DateTime, Utc}; +use futures::{stream::BoxStream, StreamExt}; +use snafu::Snafu; +use std::fmt::{Debug, Formatter}; +use std::io::{Read, Seek, SeekFrom}; +use std::ops::Range; +use tokio::io::AsyncWrite; + +/// An alias for a dynamically dispatched object store implementation. +pub type DynObjectStore = dyn ObjectStore; + +/// Id type for multi-part uploads. +pub type MultipartId = String; + +/// Universal API to multiple object store services. +#[async_trait] +pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { + /// Save the provided bytes to the specified location. 
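+    ///
+    /// A minimal sketch of a call, assuming the in-memory store and a
+    /// hypothetical `data.bin` path:
+    ///
+    /// ```
+    /// # use object_store::{memory::InMemory, path::Path, ObjectStore};
+    /// # use bytes::Bytes;
+    /// # async fn example() -> object_store::Result<()> {
+    /// let store = InMemory::new();
+    /// store.put(&Path::from("data.bin"), Bytes::from("hello world")).await?;
+    /// # Ok(())
+    /// # }
+    /// ```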
+ async fn put(&self, location: &Path, bytes: Bytes) -> Result<()>; + + /// Get a multi-part upload that allows writing data in chunks + /// + /// Most cloud-based uploads will buffer and upload parts in parallel. + /// + /// To complete the upload, [AsyncWrite::poll_shutdown] must be called + /// to completion. + /// + /// For some object stores (S3, GCS, and local in particular), if the + /// writer fails or panics, you must call [ObjectStore::abort_multipart] + /// to clean up partially written data. + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)>; + + /// Cleanup an aborted upload. + /// + /// See documentation for individual stores for exact behavior, as capabilities + /// vary by object store. + async fn abort_multipart( + &self, + location: &Path, + multipart_id: &MultipartId, + ) -> Result<()>; + + /// Return the bytes that are stored at the specified location. + async fn get(&self, location: &Path) -> Result; + + /// Return the bytes that are stored at the specified location + /// in the given byte range + async fn get_range(&self, location: &Path, range: Range) -> Result; + + /// Return the bytes that are stored at the specified location + /// in the given byte ranges + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> Result> { + coalesce_ranges( + ranges, + |range| self.get_range(location, range), + OBJECT_STORE_COALESCE_DEFAULT, + ) + .await + } + + /// Return the metadata for the specified location + async fn head(&self, location: &Path) -> Result; + + /// Delete the object at the specified location. + async fn delete(&self, location: &Path) -> Result<()>; + + /// List all the objects with the given prefix. + /// + /// Prefixes are evaluated on a path segment basis, i.e. `foo/bar/` is a prefix of `foo/bar/x` but not of + /// `foo/bar_baz/x`. + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>>; + + /// List objects with the given prefix and an implementation specific + /// delimiter. Returns common prefixes (directories) in addition to object + /// metadata. + /// + /// Prefixes are evaluated on a path segment basis, i.e. `foo/bar/` is a prefix of `foo/bar/x` but not of + /// `foo/bar_baz/x`. + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result; + + /// Copy an object from one path to another in the same object store. + /// + /// If there exists an object at the destination, it will be overwritten. + async fn copy(&self, from: &Path, to: &Path) -> Result<()>; + + /// Move an object from one path to another in the same object store. + /// + /// By default, this is implemented as a copy and then delete source. It may not + /// check when deleting source that it was the same object that was originally copied. + /// + /// If there exists an object at the destination, it will be overwritten. + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + self.copy(from, to).await?; + self.delete(from).await + } + + /// Copy an object from one path to another, only if destination is empty. + /// + /// Will return an error if the destination already has an object. + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()>; + + /// Move an object from one path to another in the same object store. + /// + /// Will return an error if the destination already has an object. 
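+    ///
+    /// By default this is implemented as `copy_if_not_exists` followed by
+    /// deleting the source object.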
+ async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + self.copy_if_not_exists(from, to).await?; + self.delete(from).await + } +} + +/// Result of a list call that includes objects, prefixes (directories) and a +/// token for the next set of results. Individual result sets may be limited to +/// 1,000 objects based on the underlying object storage's limitations. +#[derive(Debug)] +pub struct ListResult { + /// Prefixes that are common (like directories) + pub common_prefixes: Vec, + /// Object metadata for the listing + pub objects: Vec, +} + +/// The metadata that describes an object. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ObjectMeta { + /// The full path to the object + pub location: Path, + /// The last modified time + pub last_modified: DateTime, + /// The size in bytes of the object + pub size: usize, +} + +/// Result for a get request +/// +/// This special cases the case of a local file, as some systems may +/// be able to optimise the case of a file already present on local disk +pub enum GetResult { + /// A file and its path on the local filesystem + File(std::fs::File, std::path::PathBuf), + /// An asynchronous stream + Stream(BoxStream<'static, Result>), +} + +impl Debug for GetResult { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::File(_, _) => write!(f, "GetResult(File)"), + Self::Stream(_) => write!(f, "GetResult(Stream)"), + } + } +} + +impl GetResult { + /// Collects the data into a [`Bytes`] + pub async fn bytes(self) -> Result { + match self { + Self::File(mut file, path) => { + maybe_spawn_blocking(move || { + let len = file.seek(SeekFrom::End(0)).map_err(|source| { + local::Error::Seek { + source, + path: path.clone(), + } + })?; + + file.seek(SeekFrom::Start(0)).map_err(|source| { + local::Error::Seek { + source, + path: path.clone(), + } + })?; + + let mut buffer = Vec::with_capacity(len as usize); + file.read_to_end(&mut buffer).map_err(|source| { + local::Error::UnableToReadBytes { source, path } + })?; + + Ok(buffer.into()) + }) + .await + } + Self::Stream(s) => collect_bytes(s, None).await, + } + } + + /// Converts this into a byte stream + /// + /// If the result is [`Self::File`] will perform chunked reads of the file, otherwise + /// will return the [`Self::Stream`]. + /// + /// # Tokio Compatibility + /// + /// Tokio discourages performing blocking IO on a tokio worker thread, however, + /// no major operating systems have stable async file APIs. Therefore if called from + /// a tokio context, this will use [`tokio::runtime::Handle::spawn_blocking`] to dispatch + /// IO to a blocking thread pool, much like `tokio::fs` does under-the-hood. 
+ /// + /// If not called from a tokio context, this will perform IO on the current thread with + /// no additional complexity or overheads + pub fn into_stream(self) -> BoxStream<'static, Result> { + match self { + Self::File(file, path) => { + const CHUNK_SIZE: usize = 8 * 1024; + + futures::stream::try_unfold( + (file, path, false), + |(mut file, path, finished)| { + maybe_spawn_blocking(move || { + if finished { + return Ok(None); + } + + let mut buffer = Vec::with_capacity(CHUNK_SIZE); + let read = file + .by_ref() + .take(CHUNK_SIZE as u64) + .read_to_end(&mut buffer) + .map_err(|e| local::Error::UnableToReadBytes { + source: e, + path: path.clone(), + })?; + + Ok(Some((buffer.into(), (file, path, read != CHUNK_SIZE)))) + }) + }, + ) + .boxed() + } + Self::Stream(s) => s, + } + } +} + +/// A specialized `Result` for object store-related errors +pub type Result = std::result::Result; + +/// A specialized `Error` for object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("Generic {} error: {}", store, source))] + Generic { + store: &'static str, + source: Box, + }, + + #[snafu(display("Object at location {} not found: {}", path, source))] + NotFound { + path: String, + source: Box, + }, + + #[snafu( + display("Encountered object with invalid path: {}", source), + context(false) + )] + InvalidPath { source: path::Error }, + + #[snafu(display("Error joining spawned task: {}", source), context(false))] + JoinError { source: tokio::task::JoinError }, + + #[snafu(display("Operation not supported: {}", source))] + NotSupported { + source: Box, + }, + + #[snafu(display("Object at location {} already exists: {}", path, source))] + AlreadyExists { + path: String, + source: Box, + }, + + #[snafu(display("Operation not yet implemented."))] + NotImplemented, +} + +impl From for std::io::Error { + fn from(e: Error) -> Self { + let kind = match &e { + Error::NotFound { .. } => std::io::ErrorKind::NotFound, + _ => std::io::ErrorKind::Other, + }; + Self::new(kind, e) + } +} + +#[cfg(test)] +mod test_util { + use super::*; + use futures::TryStreamExt; + + pub async fn flatten_list_stream( + storage: &DynObjectStore, + prefix: Option<&Path>, + ) -> Result> { + storage + .list(prefix) + .await? 
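+            // Keep only the locations so tests can compare plain `Vec<Path>` values.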
+ .map_ok(|meta| meta.location) + .try_collect::>() + .await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_util::flatten_list_stream; + use tokio::io::AsyncWriteExt; + + pub(crate) async fn put_get_delete_list(storage: &DynObjectStore) { + delete_fixtures(storage).await; + + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!( + content_list.is_empty(), + "Expected list to be empty; found: {:?}", + content_list + ); + + let location = Path::from("test_dir/test_file.json"); + + let data = Bytes::from("arbitrary data"); + let expected_data = data.clone(); + storage.put(&location, data).await.unwrap(); + + let root = Path::from("/"); + + // List everything + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert_eq!(content_list, &[location.clone()]); + + // Should behave the same as no prefix + let content_list = flatten_list_stream(storage, Some(&root)).await.unwrap(); + assert_eq!(content_list, &[location.clone()]); + + // List with delimiter + let result = storage.list_with_delimiter(None).await.unwrap(); + assert_eq!(&result.objects, &[]); + assert_eq!(result.common_prefixes.len(), 1); + assert_eq!(result.common_prefixes[0], Path::from("test_dir")); + + // Should behave the same as no prefix + let result = storage.list_with_delimiter(Some(&root)).await.unwrap(); + assert!(result.objects.is_empty()); + assert_eq!(result.common_prefixes.len(), 1); + assert_eq!(result.common_prefixes[0], Path::from("test_dir")); + + // List everything starting with a prefix that should return results + let prefix = Path::from("test_dir"); + let content_list = flatten_list_stream(storage, Some(&prefix)).await.unwrap(); + assert_eq!(content_list, &[location.clone()]); + + // List everything starting with a prefix that shouldn't return results + let prefix = Path::from("something"); + let content_list = flatten_list_stream(storage, Some(&prefix)).await.unwrap(); + assert!(content_list.is_empty()); + + let read_data = storage.get(&location).await.unwrap().bytes().await.unwrap(); + assert_eq!(&*read_data, expected_data); + + // Test range request + let range = 3..7; + let range_result = storage.get_range(&location, range.clone()).await; + + let out_of_range = 200..300; + let out_of_range_result = storage.get_range(&location, out_of_range).await; + + let bytes = range_result.unwrap(); + assert_eq!(bytes, expected_data.slice(range)); + + // Should be a non-fatal error + out_of_range_result.unwrap_err(); + + let ranges = vec![0..1, 2..3, 0..5]; + let bytes = storage.get_ranges(&location, &ranges).await.unwrap(); + for (range, bytes) in ranges.iter().zip(bytes) { + assert_eq!(bytes, expected_data.slice(range.clone())) + } + + let head = storage.head(&location).await.unwrap(); + assert_eq!(head.size, expected_data.len()); + + storage.delete(&location).await.unwrap(); + + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!(content_list.is_empty()); + + let err = storage.get(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err); + + let err = storage.head(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. 
}), "{}", err); + + // Test handling of paths containing an encoded delimiter + + let file_with_delimiter = Path::from_iter(["a", "b/c", "foo.file"]); + storage + .put(&file_with_delimiter, Bytes::from("arbitrary")) + .await + .unwrap(); + + let files = flatten_list_stream(storage, None).await.unwrap(); + assert_eq!(files, vec![file_with_delimiter.clone()]); + + let files = flatten_list_stream(storage, Some(&Path::from("a/b"))) + .await + .unwrap(); + assert!(files.is_empty()); + + let files = storage + .list_with_delimiter(Some(&Path::from("a/b"))) + .await + .unwrap(); + assert!(files.common_prefixes.is_empty()); + assert!(files.objects.is_empty()); + + let files = storage + .list_with_delimiter(Some(&Path::from("a"))) + .await + .unwrap(); + assert_eq!(files.common_prefixes, vec![Path::from_iter(["a", "b/c"])]); + assert!(files.objects.is_empty()); + + let files = storage + .list_with_delimiter(Some(&Path::from_iter(["a", "b/c"]))) + .await + .unwrap(); + assert!(files.common_prefixes.is_empty()); + assert_eq!(files.objects.len(), 1); + assert_eq!(files.objects[0].location, file_with_delimiter); + + storage.delete(&file_with_delimiter).await.unwrap(); + + // Test handling of paths containing non-ASCII characters, e.g. emoji + + let emoji_prefix = Path::from("🙀"); + let emoji_file = Path::from("🙀/😀.parquet"); + storage + .put(&emoji_file, Bytes::from("arbitrary")) + .await + .unwrap(); + + storage.head(&emoji_file).await.unwrap(); + storage + .get(&emoji_file) + .await + .unwrap() + .bytes() + .await + .unwrap(); + + let files = flatten_list_stream(storage, Some(&emoji_prefix)) + .await + .unwrap(); + + assert_eq!(files, vec![emoji_file.clone()]); + + let dst = Path::from("foo.parquet"); + storage.copy(&emoji_file, &dst).await.unwrap(); + let mut files = flatten_list_stream(storage, None).await.unwrap(); + files.sort_unstable(); + assert_eq!(files, vec![emoji_file.clone(), dst.clone()]); + + storage.delete(&emoji_file).await.unwrap(); + storage.delete(&dst).await.unwrap(); + let files = flatten_list_stream(storage, Some(&emoji_prefix)) + .await + .unwrap(); + assert!(files.is_empty()); + + // Test handling of paths containing percent-encoded sequences + + // "HELLO" percent encoded + let hello_prefix = Path::parse("%48%45%4C%4C%4F").unwrap(); + let path = hello_prefix.child("foo.parquet"); + + storage.put(&path, Bytes::from(vec![0, 1])).await.unwrap(); + let files = flatten_list_stream(storage, Some(&hello_prefix)) + .await + .unwrap(); + assert_eq!(files, vec![path.clone()]); + + // Cannot list by decoded representation + let files = flatten_list_stream(storage, Some(&Path::from("HELLO"))) + .await + .unwrap(); + assert!(files.is_empty()); + + // Cannot access by decoded representation + let err = storage + .head(&Path::from("HELLO/foo.parquet")) + .await + .unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. 
}), "{}", err); + + storage.delete(&path).await.unwrap(); + + // Can also write non-percent encoded sequences + let path = Path::parse("%Q.parquet").unwrap(); + storage.put(&path, Bytes::from(vec![0, 1])).await.unwrap(); + + let files = flatten_list_stream(storage, None).await.unwrap(); + assert_eq!(files, vec![path.clone()]); + + storage.delete(&path).await.unwrap(); + } + + fn get_vec_of_bytes(chunk_length: usize, num_chunks: usize) -> Vec { + std::iter::repeat(Bytes::from_iter(std::iter::repeat(b'x').take(chunk_length))) + .take(num_chunks) + .collect() + } + + pub(crate) async fn stream_get(storage: &DynObjectStore) { + let location = Path::from("test_dir/test_upload_file.txt"); + + // Can write to storage + let data = get_vec_of_bytes(5_000, 10); + let bytes_expected = data.concat(); + let (_, mut writer) = storage.put_multipart(&location).await.unwrap(); + for chunk in &data { + writer.write_all(chunk).await.unwrap(); + } + + // Object should not yet exist in store + let meta_res = storage.head(&location).await; + assert!(meta_res.is_err()); + assert!(matches!( + meta_res.unwrap_err(), + crate::Error::NotFound { .. } + )); + + writer.shutdown().await.unwrap(); + let bytes_written = storage.get(&location).await.unwrap().bytes().await.unwrap(); + assert_eq!(bytes_expected, bytes_written); + + // Can overwrite some storage + let data = get_vec_of_bytes(5_000, 5); + let bytes_expected = data.concat(); + let (_, mut writer) = storage.put_multipart(&location).await.unwrap(); + for chunk in &data { + writer.write_all(chunk).await.unwrap(); + } + writer.shutdown().await.unwrap(); + let bytes_written = storage.get(&location).await.unwrap().bytes().await.unwrap(); + assert_eq!(bytes_expected, bytes_written); + + // We can abort an empty write + let location = Path::from("test_dir/test_abort_upload.txt"); + let (upload_id, writer) = storage.put_multipart(&location).await.unwrap(); + drop(writer); + storage + .abort_multipart(&location, &upload_id) + .await + .unwrap(); + let get_res = storage.get(&location).await; + assert!(get_res.is_err()); + assert!(matches!( + get_res.unwrap_err(), + crate::Error::NotFound { .. } + )); + + // We can abort an in-progress write + let (upload_id, mut writer) = storage.put_multipart(&location).await.unwrap(); + if let Some(chunk) = data.get(0) { + writer.write_all(chunk).await.unwrap(); + let _ = writer.write(chunk).await.unwrap(); + } + drop(writer); + + storage + .abort_multipart(&location, &upload_id) + .await + .unwrap(); + let get_res = storage.get(&location).await; + assert!(get_res.is_err()); + assert!(matches!( + get_res.unwrap_err(), + crate::Error::NotFound { .. 
} + )); + } + + pub(crate) async fn list_uses_directories_correctly(storage: &DynObjectStore) { + delete_fixtures(storage).await; + + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!( + content_list.is_empty(), + "Expected list to be empty; found: {:?}", + content_list + ); + + let location1 = Path::from("foo/x.json"); + let location2 = Path::from("foo.bar/y.json"); + + let data = Bytes::from("arbitrary data"); + storage.put(&location1, data.clone()).await.unwrap(); + storage.put(&location2, data).await.unwrap(); + + let prefix = Path::from("foo"); + let content_list = flatten_list_stream(storage, Some(&prefix)).await.unwrap(); + assert_eq!(content_list, &[location1.clone()]); + + let prefix = Path::from("foo/x"); + let content_list = flatten_list_stream(storage, Some(&prefix)).await.unwrap(); + assert_eq!(content_list, &[]); + } + + pub(crate) async fn list_with_delimiter(storage: &DynObjectStore) { + delete_fixtures(storage).await; + + // ==================== check: store is empty ==================== + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!(content_list.is_empty()); + + // ==================== do: create files ==================== + let data = Bytes::from("arbitrary data"); + + let files: Vec<_> = [ + "test_file", + "mydb/wb/000/000/000.segment", + "mydb/wb/000/000/001.segment", + "mydb/wb/000/000/002.segment", + "mydb/wb/001/001/000.segment", + "mydb/wb/foo.json", + "mydb/wbwbwb/111/222/333.segment", + "mydb/data/whatevs", + ] + .iter() + .map(|&s| Path::from(s)) + .collect(); + + for f in &files { + let data = data.clone(); + storage.put(f, data).await.unwrap(); + } + + // ==================== check: prefix-list `mydb/wb` (directory) ==================== + let prefix = Path::from("mydb/wb"); + + let expected_000 = Path::from("mydb/wb/000"); + let expected_001 = Path::from("mydb/wb/001"); + let expected_location = Path::from("mydb/wb/foo.json"); + + let result = storage.list_with_delimiter(Some(&prefix)).await.unwrap(); + + assert_eq!(result.common_prefixes, vec![expected_000, expected_001]); + assert_eq!(result.objects.len(), 1); + + let object = &result.objects[0]; + + assert_eq!(object.location, expected_location); + assert_eq!(object.size, data.len()); + + // ==================== check: prefix-list `mydb/wb/000/000/001` (partial filename doesn't match) ==================== + let prefix = Path::from("mydb/wb/000/000/001"); + + let result = storage.list_with_delimiter(Some(&prefix)).await.unwrap(); + assert!(result.common_prefixes.is_empty()); + assert_eq!(result.objects.len(), 0); + + // ==================== check: prefix-list `not_there` (non-existing prefix) ==================== + let prefix = Path::from("not_there"); + + let result = storage.list_with_delimiter(Some(&prefix)).await.unwrap(); + assert!(result.common_prefixes.is_empty()); + assert!(result.objects.is_empty()); + + // ==================== do: remove all files ==================== + for f in &files { + storage.delete(f).await.unwrap(); + } + + // ==================== check: store is empty ==================== + let content_list = flatten_list_stream(storage, None).await.unwrap(); + assert!(content_list.is_empty()); + } + + pub(crate) async fn get_nonexistent_object( + storage: &DynObjectStore, + location: Option, + ) -> crate::Result { + let location = + location.unwrap_or_else(|| Path::from("this_file_should_not_exist")); + + let err = storage.head(&location).await.unwrap_err(); + assert!(matches!(err, crate::Error::NotFound { .. 
})); + + storage.get(&location).await?.bytes().await + } + + pub(crate) async fn rename_and_copy(storage: &DynObjectStore) { + // Create two objects + let path1 = Path::from("test1"); + let path2 = Path::from("test2"); + let contents1 = Bytes::from("cats"); + let contents2 = Bytes::from("dogs"); + + // copy() make both objects identical + storage.put(&path1, contents1.clone()).await.unwrap(); + storage.put(&path2, contents2.clone()).await.unwrap(); + storage.copy(&path1, &path2).await.unwrap(); + let new_contents = storage.get(&path2).await.unwrap().bytes().await.unwrap(); + assert_eq!(&new_contents, &contents1); + + // rename() copies contents and deletes original + storage.put(&path1, contents1.clone()).await.unwrap(); + storage.put(&path2, contents2.clone()).await.unwrap(); + storage.rename(&path1, &path2).await.unwrap(); + let new_contents = storage.get(&path2).await.unwrap().bytes().await.unwrap(); + assert_eq!(&new_contents, &contents1); + let result = storage.get(&path1).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. })); + + // Clean up + storage.delete(&path2).await.unwrap(); + } + + pub(crate) async fn copy_if_not_exists(storage: &DynObjectStore) { + // Create two objects + let path1 = Path::from("test1"); + let path2 = Path::from("test2"); + let contents1 = Bytes::from("cats"); + let contents2 = Bytes::from("dogs"); + + // copy_if_not_exists() errors if destination already exists + storage.put(&path1, contents1.clone()).await.unwrap(); + storage.put(&path2, contents2.clone()).await.unwrap(); + let result = storage.copy_if_not_exists(&path1, &path2).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + crate::Error::AlreadyExists { .. } + )); + + // copy_if_not_exists() copies contents and allows deleting original + storage.delete(&path2).await.unwrap(); + storage.copy_if_not_exists(&path1, &path2).await.unwrap(); + storage.delete(&path1).await.unwrap(); + let new_contents = storage.get(&path2).await.unwrap().bytes().await.unwrap(); + assert_eq!(&new_contents, &contents1); + let result = storage.get(&path1).await; + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), crate::Error::NotFound { .. })); + + // Clean up + storage.delete(&path2).await.unwrap(); + } + + async fn delete_fixtures(storage: &DynObjectStore) { + let paths = flatten_list_stream(storage, None).await.unwrap(); + + for f in &paths { + let _ = storage.delete(f).await; + } + } + + /// Test that the returned stream does not borrow the lifetime of Path + async fn list_store<'a, 'b>( + store: &'a dyn ObjectStore, + path_str: &'b str, + ) -> super::Result>> { + let path = Path::from(path_str); + store.list(Some(&path)).await + } + + #[tokio::test] + async fn test_list_lifetimes() { + let store = memory::InMemory::new(); + let mut stream = list_store(&store, "path").await.unwrap(); + assert!(stream.next().await.is_none()); + } + + // Tests TODO: + // GET nonexisting location (in_memory/file) + // DELETE nonexisting location + // PUT overwriting +} diff --git a/object_store/src/limit.rs b/object_store/src/limit.rs new file mode 100644 index 000000000000..09c88aa2a4bc --- /dev/null +++ b/object_store/src/limit.rs @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An object store that limits the maximum concurrency of the wrapped implementation + +use crate::{ + BoxStream, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, Result, + StreamExt, +}; +use async_trait::async_trait; +use bytes::Bytes; +use futures::Stream; +use std::io::{Error, IoSlice}; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +/// Store wrapper that wraps an inner store and limits the maximum number of concurrent +/// object store operations. Where each call to an [`ObjectStore`] member function is +/// considered a single operation, even if it may result in more than one network call +/// +/// ``` +/// # use object_store::memory::InMemory; +/// # use object_store::limit::LimitStore; +/// +/// // Create an in-memory `ObjectStore` limited to 20 concurrent requests +/// let store = LimitStore::new(InMemory::new(), 20); +/// ``` +/// +#[derive(Debug)] +pub struct LimitStore { + inner: T, + max_requests: usize, + semaphore: Arc, +} + +impl LimitStore { + /// Create new limit store that will limit the maximum + /// number of outstanding concurrent requests to + /// `max_requests` + pub fn new(inner: T, max_requests: usize) -> Self { + Self { + inner, + max_requests, + semaphore: Arc::new(Semaphore::new(max_requests)), + } + } +} + +impl std::fmt::Display for LimitStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "LimitStore({}, {})", self.max_requests, self.inner) + } +} + +#[async_trait] +impl ObjectStore for LimitStore { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.put(location, bytes).await + } + + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); + let (id, write) = self.inner.put_multipart(location).await?; + Ok((id, Box::new(PermitWrapper::new(write, permit)))) + } + + async fn abort_multipart( + &self, + location: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.abort_multipart(location, multipart_id).await + } + + async fn get(&self, location: &Path) -> Result { + let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); + match self.inner.get(location).await? 
{ + r @ GetResult::File(_, _) => Ok(r), + GetResult::Stream(s) => { + Ok(GetResult::Stream(PermitWrapper::new(s, permit).boxed())) + } + } + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.get_range(location, range).await + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> Result> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.get_ranges(location, ranges).await + } + + async fn head(&self, location: &Path) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.head(location).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.delete(location).await + } + + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>> { + let permit = Arc::clone(&self.semaphore).acquire_owned().await.unwrap(); + let s = self.inner.list(prefix).await?; + Ok(PermitWrapper::new(s, permit).boxed()) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.copy(from, to).await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.rename(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.copy_if_not_exists(from, to).await + } + + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let _permit = self.semaphore.acquire().await.unwrap(); + self.inner.rename_if_not_exists(from, to).await + } +} + +/// Combines an [`OwnedSemaphorePermit`] with some other type +struct PermitWrapper { + inner: T, + #[allow(dead_code)] + permit: OwnedSemaphorePermit, +} + +impl PermitWrapper { + fn new(inner: T, permit: OwnedSemaphorePermit) -> Self { + Self { inner, permit } + } +} + +impl Stream for PermitWrapper { + type Item = T::Item; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl AsyncWrite for PermitWrapper { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } +} + +#[cfg(test)] +mod tests { + use crate::limit::LimitStore; + use crate::memory::InMemory; + use crate::tests::{ + list_uses_directories_correctly, list_with_delimiter, put_get_delete_list, + rename_and_copy, stream_get, + }; + use crate::ObjectStore; + use std::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn 
limit_test() { + let max_requests = 10; + let memory = InMemory::new(); + let integration = LimitStore::new(memory, max_requests); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + stream_get(&integration).await; + + let mut streams = Vec::with_capacity(max_requests); + for _ in 0..max_requests { + let stream = integration.list(None).await.unwrap(); + streams.push(stream); + } + + let t = Duration::from_millis(20); + + // Expect to not be able to make another request + assert!(timeout(t, integration.list(None)).await.is_err()); + + // Drop one of the streams + streams.pop(); + + // Can now make another request + integration.list(None).await.unwrap(); + } +} diff --git a/object_store/src/local.rs b/object_store/src/local.rs new file mode 100644 index 000000000000..fd3c3592ab56 --- /dev/null +++ b/object_store/src/local.rs @@ -0,0 +1,1314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
An object store implementation for a local filesystem +use crate::{ + maybe_spawn_blocking, + path::{absolute_path_to_url, Path}, + GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, +}; +use async_trait::async_trait; +use bytes::Bytes; +use futures::future::BoxFuture; +use futures::FutureExt; +use futures::{stream::BoxStream, StreamExt}; +use snafu::{ensure, OptionExt, ResultExt, Snafu}; +use std::fs::{metadata, symlink_metadata, File}; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; +use std::{collections::BTreeSet, convert::TryFrom, io}; +use std::{collections::VecDeque, path::PathBuf}; +use tokio::io::AsyncWrite; +use url::Url; +use walkdir::{DirEntry, WalkDir}; + +/// A specialized `Error` for filesystem object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub(crate) enum Error { + #[snafu(display("File size for {} did not fit in a usize: {}", path, source))] + FileSizeOverflowedUsize { + source: std::num::TryFromIntError, + path: String, + }, + + #[snafu(display("Unable to walk dir: {}", source))] + UnableToWalkDir { + source: walkdir::Error, + }, + + #[snafu(display("Unable to access metadata for {}: {}", path, source))] + UnableToAccessMetadata { + source: Box, + path: String, + }, + + #[snafu(display("Unable to copy data to file: {}", source))] + UnableToCopyDataToFile { + source: io::Error, + }, + + #[snafu(display("Unable to create dir {}: {}", path.display(), source))] + UnableToCreateDir { + source: io::Error, + path: PathBuf, + }, + + #[snafu(display("Unable to create file {}: {}", path.display(), err))] + UnableToCreateFile { + path: PathBuf, + err: io::Error, + }, + + #[snafu(display("Unable to delete file {}: {}", path.display(), source))] + UnableToDeleteFile { + source: io::Error, + path: PathBuf, + }, + + #[snafu(display("Unable to open file {}: {}", path.display(), source))] + UnableToOpenFile { + source: io::Error, + path: PathBuf, + }, + + #[snafu(display("Unable to read data from file {}: {}", path.display(), source))] + UnableToReadBytes { + source: io::Error, + path: PathBuf, + }, + + #[snafu(display("Out of range of file {}, expected: {}, actual: {}", path.display(), expected, actual))] + OutOfRange { + path: PathBuf, + expected: usize, + actual: usize, + }, + + #[snafu(display("Unable to copy file from {} to {}: {}", from.display(), to.display(), source))] + UnableToCopyFile { + from: PathBuf, + to: PathBuf, + source: io::Error, + }, + + NotFound { + path: PathBuf, + source: io::Error, + }, + + #[snafu(display("Error seeking file {}: {}", path.display(), source))] + Seek { + source: io::Error, + path: PathBuf, + }, + + #[snafu(display("Unable to convert URL \"{}\" to filesystem path", url))] + InvalidUrl { + url: Url, + }, + + AlreadyExists { + path: String, + source: io::Error, + }, + + #[snafu(display("Unable to canonicalize filesystem root: {}", path.display()))] + UnableToCanonicalize { + path: PathBuf, + source: io::Error, + }, +} + +impl From for super::Error { + fn from(source: Error) -> Self { + match source { + Error::NotFound { path, source } => Self::NotFound { + path: path.to_string_lossy().to_string(), + source: source.into(), + }, + Error::AlreadyExists { path, source } => Self::AlreadyExists { + path, + source: source.into(), + }, + _ => Self::Generic { + store: "LocalFileSystem", + source: Box::new(source), + }, + } + } +} + +/// Local filesystem storage providing an [`ObjectStore`] interface to files on +/// local disk. 
Can optionally be created with a directory prefix +/// +/// # Path Semantics +/// +/// This implementation follows the [file URI] scheme outlined in [RFC 3986]. In +/// particular paths are delimited by `/` +/// +/// [file URI]: https://en.wikipedia.org/wiki/File_URI_scheme +/// [RFC 3986]: https://www.rfc-editor.org/rfc/rfc3986 +/// +/// # Tokio Compatibility +/// +/// Tokio discourages performing blocking IO on a tokio worker thread, however, +/// no major operating systems have stable async file APIs. Therefore if called from +/// a tokio context, this will use [`tokio::runtime::Handle::spawn_blocking`] to dispatch +/// IO to a blocking thread pool, much like `tokio::fs` does under-the-hood. +/// +/// If not called from a tokio context, this will perform IO on the current thread with +/// no additional complexity or overheads +/// +/// # Symlinks +/// +/// [`LocalFileSystem`] will follow symlinks as normal, however, it is worth noting: +/// +/// * Broken symlinks will be silently ignored by listing operations +/// * No effort is made to prevent breaking symlinks when deleting files +/// * Symlinks that resolve to paths outside the root **will** be followed +/// * Mutating a file through one or more symlinks will mutate the underlying file +/// * Deleting a path that resolves to a symlink will only delete the symlink +/// +#[derive(Debug)] +pub struct LocalFileSystem { + config: Arc, +} + +#[derive(Debug)] +struct Config { + root: Url, +} + +impl std::fmt::Display for LocalFileSystem { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "LocalFileSystem({})", self.config.root) + } +} + +impl Default for LocalFileSystem { + fn default() -> Self { + Self::new() + } +} + +impl LocalFileSystem { + /// Create new filesystem storage with no prefix + pub fn new() -> Self { + Self { + config: Arc::new(Config { + root: Url::parse("file:///").unwrap(), + }), + } + } + + /// Create new filesystem storage with `prefix` applied to all paths + /// + /// Returns an error if the path does not exist + /// + pub fn new_with_prefix(prefix: impl AsRef) -> Result { + let path = std::fs::canonicalize(&prefix).context(UnableToCanonicalizeSnafu { + path: prefix.as_ref(), + })?; + + Ok(Self { + config: Arc::new(Config { + root: absolute_path_to_url(path)?, + }), + }) + } +} + +impl Config { + /// Return an absolute filesystem path of the given location + fn path_to_filesystem(&self, location: &Path) -> Result { + let mut url = self.root.clone(); + url.path_segments_mut() + .expect("url path") + // technically not necessary as Path ignores empty segments + // but avoids creating paths with "//" which look odd in error messages. + .pop_if_empty() + .extend(location.parts()); + + url.to_file_path() + .map_err(|_| Error::InvalidUrl { url }.into()) + } + + /// Resolves the provided absolute filesystem path to a [`Path`] prefix + fn filesystem_to_path(&self, location: &std::path::Path) -> Result { + Ok(Path::from_absolute_path_with_base( + location, + Some(&self.root), + )?) 
+ } +} + +#[async_trait] +impl ObjectStore for LocalFileSystem { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + let path = self.config.path_to_filesystem(location)?; + + maybe_spawn_blocking(move || { + let mut file = open_writable_file(&path)?; + + file.write_all(&bytes) + .context(UnableToCopyDataToFileSnafu)?; + + Ok(()) + }) + .await + } + + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + let dest = self.config.path_to_filesystem(location)?; + + // Generate an id in case of concurrent writes + let mut multipart_id = 1; + + // Will write to a temporary path + let staging_path = loop { + let staging_path = get_upload_stage_path(&dest, &multipart_id.to_string()); + + match std::fs::metadata(&staging_path) { + Err(err) if err.kind() == io::ErrorKind::NotFound => break staging_path, + Err(err) => { + return Err(Error::UnableToCopyDataToFile { source: err }.into()) + } + Ok(_) => multipart_id += 1, + } + }; + let multipart_id = multipart_id.to_string(); + + let file = open_writable_file(&staging_path)?; + + Ok(( + multipart_id.clone(), + Box::new(LocalUpload::new(dest, multipart_id, Arc::new(file))), + )) + } + + async fn abort_multipart( + &self, + location: &Path, + multipart_id: &MultipartId, + ) -> Result<()> { + let dest = self.config.path_to_filesystem(location)?; + let staging_path: PathBuf = get_upload_stage_path(&dest, multipart_id); + + maybe_spawn_blocking(move || { + std::fs::remove_file(&staging_path) + .context(UnableToDeleteFileSnafu { path: staging_path })?; + Ok(()) + }) + .await + } + + async fn get(&self, location: &Path) -> Result { + let path = self.config.path_to_filesystem(location)?; + maybe_spawn_blocking(move || { + let file = open_file(&path)?; + Ok(GetResult::File(file, path)) + }) + .await + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let path = self.config.path_to_filesystem(location)?; + maybe_spawn_blocking(move || { + let mut file = open_file(&path)?; + read_range(&mut file, &path, range) + }) + .await + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> Result> { + let path = self.config.path_to_filesystem(location)?; + let ranges = ranges.to_vec(); + maybe_spawn_blocking(move || { + // Vectored IO might be faster + let mut file = open_file(&path)?; + ranges + .into_iter() + .map(|r| read_range(&mut file, &path, r)) + .collect() + }) + .await + } + + async fn head(&self, location: &Path) -> Result { + let path = self.config.path_to_filesystem(location)?; + let location = location.clone(); + + maybe_spawn_blocking(move || { + let file = open_file(&path)?; + let metadata = + file.metadata().map_err(|e| Error::UnableToAccessMetadata { + source: e.into(), + path: location.to_string(), + })?; + + convert_metadata(metadata, location) + }) + .await + } + + async fn delete(&self, location: &Path) -> Result<()> { + let path = self.config.path_to_filesystem(location)?; + maybe_spawn_blocking(move || { + std::fs::remove_file(&path).context(UnableToDeleteFileSnafu { path })?; + Ok(()) + }) + .await + } + + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>> { + let config = Arc::clone(&self.config); + + let root_path = match prefix { + Some(prefix) => config.path_to_filesystem(prefix)?, + None => self.config.root.to_file_path().unwrap(), + }; + + let walkdir = WalkDir::new(&root_path) + // Don't include the root directory itself + .min_depth(1) + .follow_links(true); + + let s = walkdir.into_iter().flat_map(move 
|result_dir_entry| { + match convert_walkdir_result(result_dir_entry) { + Err(e) => Some(Err(e)), + Ok(None) => None, + Ok(entry @ Some(_)) => entry + .filter(|dir_entry| { + dir_entry.file_type().is_file() + // Ignore file names with # in them, since they might be in-progress uploads. + // They would be rejected anyways by filesystem_to_path below. + && !dir_entry.file_name().to_string_lossy().contains('#') + }) + .map(|entry| { + let location = config.filesystem_to_path(entry.path())?; + convert_entry(entry, location) + }), + } + }); + + // If no tokio context, return iterator directly as no + // need to perform chunked spawn_blocking reads + if tokio::runtime::Handle::try_current().is_err() { + return Ok(futures::stream::iter(s).boxed()); + } + + // Otherwise list in batches of CHUNK_SIZE + const CHUNK_SIZE: usize = 1024; + + let buffer = VecDeque::with_capacity(CHUNK_SIZE); + let stream = + futures::stream::try_unfold((s, buffer), |(mut s, mut buffer)| async move { + if buffer.is_empty() { + (s, buffer) = tokio::task::spawn_blocking(move || { + for _ in 0..CHUNK_SIZE { + match s.next() { + Some(r) => buffer.push_back(r), + None => break, + } + } + (s, buffer) + }) + .await?; + } + + match buffer.pop_front() { + Some(Err(e)) => Err(e), + Some(Ok(meta)) => Ok(Some((meta, (s, buffer)))), + None => Ok(None), + } + }); + + Ok(stream.boxed()) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let config = Arc::clone(&self.config); + + let prefix = prefix.cloned().unwrap_or_default(); + let resolved_prefix = config.path_to_filesystem(&prefix)?; + + maybe_spawn_blocking(move || { + let walkdir = WalkDir::new(&resolved_prefix) + .min_depth(1) + .max_depth(1) + .follow_links(true); + + let mut common_prefixes = BTreeSet::new(); + let mut objects = Vec::new(); + + for entry_res in walkdir.into_iter().map(convert_walkdir_result) { + if let Some(entry) = entry_res? { + if entry.file_type().is_file() + // Ignore file names with # in them, since they might be in-progress uploads. + // They would be rejected anyways by filesystem_to_path below. 
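+                        // (the streaming `list` above applies the same filter)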
+ && entry.file_name().to_string_lossy().contains('#') + { + continue; + } + let is_directory = entry.file_type().is_dir(); + let entry_location = config.filesystem_to_path(entry.path())?; + + let mut parts = match entry_location.prefix_match(&prefix) { + Some(parts) => parts, + None => continue, + }; + + let common_prefix = match parts.next() { + Some(p) => p, + None => continue, + }; + + drop(parts); + + if is_directory { + common_prefixes.insert(prefix.child(common_prefix)); + } else { + objects.push(convert_entry(entry, entry_location)?); + } + } + } + + Ok(ListResult { + common_prefixes: common_prefixes.into_iter().collect(), + objects, + }) + }) + .await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + let from = self.config.path_to_filesystem(from)?; + let to = self.config.path_to_filesystem(to)?; + + maybe_spawn_blocking(move || { + std::fs::copy(&from, &to).context(UnableToCopyFileSnafu { from, to })?; + Ok(()) + }) + .await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + let from = self.config.path_to_filesystem(from)?; + let to = self.config.path_to_filesystem(to)?; + maybe_spawn_blocking(move || { + std::fs::rename(&from, &to).context(UnableToCopyFileSnafu { from, to })?; + Ok(()) + }) + .await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let from = self.config.path_to_filesystem(from)?; + let to = self.config.path_to_filesystem(to)?; + + maybe_spawn_blocking(move || { + std::fs::hard_link(&from, &to).map_err(|err| match err.kind() { + io::ErrorKind::AlreadyExists => Error::AlreadyExists { + path: to.to_str().unwrap().to_string(), + source: err, + } + .into(), + _ => Error::UnableToCopyFile { + from, + to, + source: err, + } + .into(), + }) + }) + .await + } +} + +fn get_upload_stage_path(dest: &std::path::Path, multipart_id: &MultipartId) -> PathBuf { + let mut staging_path = dest.as_os_str().to_owned(); + staging_path.push(format!("#{}", multipart_id)); + staging_path.into() +} + +enum LocalUploadState { + /// Upload is ready to send new data + Idle(Arc), + /// In the middle of a write + Writing( + Arc, + BoxFuture<'static, Result>, + ), + /// In the middle of syncing data and closing file. + /// + /// Future will contain last reference to file, so it will call drop on completion. 
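+    /// When running inside a tokio runtime, the sync is dispatched to a blocking thread.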
+    ShuttingDown(BoxFuture<'static, Result<(), io::Error>>),
+    /// File is being moved from its temporary location to the final location
+    Committing(BoxFuture<'static, Result<(), io::Error>>),
+    /// Upload is complete
+    Complete,
+}
+
+struct LocalUpload {
+    inner_state: LocalUploadState,
+    dest: PathBuf,
+    multipart_id: MultipartId,
+}
+
+impl LocalUpload {
+    pub fn new(
+        dest: PathBuf,
+        multipart_id: MultipartId,
+        file: Arc<std::fs::File>,
+    ) -> Self {
+        Self {
+            inner_state: LocalUploadState::Idle(file),
+            dest,
+            multipart_id,
+        }
+    }
+}
+
+impl AsyncWrite for LocalUpload {
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<Result<usize, io::Error>> {
+        let invalid_state =
+            |condition: &str| -> std::task::Poll<Result<usize, io::Error>> {
+                Poll::Ready(Err(io::Error::new(
+                    io::ErrorKind::InvalidInput,
+                    format!("Tried to write to file {}.", condition),
+                )))
+            };
+
+        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
+            let mut data: Vec<u8> = buf.to_vec();
+            let data_len = data.len();
+
+            loop {
+                match &mut self.inner_state {
+                    LocalUploadState::Idle(file) => {
+                        let file = Arc::clone(file);
+                        let file2 = Arc::clone(&file);
+                        let data: Vec<u8> = std::mem::take(&mut data);
+                        self.inner_state = LocalUploadState::Writing(
+                            file,
+                            Box::pin(
+                                runtime
+                                    .spawn_blocking(move || (&*file2).write_all(&data))
+                                    .map(move |res| match res {
+                                        Err(err) => {
+                                            Err(io::Error::new(io::ErrorKind::Other, err))
+                                        }
+                                        Ok(res) => res.map(move |_| data_len),
+                                    }),
+                            ),
+                        );
+                    }
+                    LocalUploadState::Writing(file, inner_write) => {
+                        match inner_write.poll_unpin(cx) {
+                            Poll::Ready(res) => {
+                                self.inner_state =
+                                    LocalUploadState::Idle(Arc::clone(file));
+                                return Poll::Ready(res);
+                            }
+                            Poll::Pending => {
+                                return Poll::Pending;
+                            }
+                        }
+                    }
+                    LocalUploadState::ShuttingDown(_) => {
+                        return invalid_state("when writer is shutting down");
+                    }
+                    LocalUploadState::Committing(_) => {
+                        return invalid_state("when writer is committing data");
+                    }
+                    LocalUploadState::Complete => {
+                        return invalid_state("when writer is complete");
+                    }
+                }
+            }
+        } else if let LocalUploadState::Idle(file) = &self.inner_state {
+            let file = Arc::clone(file);
+            (&*file).write_all(buf)?;
+            Poll::Ready(Ok(buf.len()))
+        } else {
+            // If we are running on this thread, then only possible states are Idle and Complete.
+            invalid_state("when writer is already complete.")
+        }
+    }
+
+    fn poll_flush(
+        self: Pin<&mut Self>,
+        _cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), io::Error>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), io::Error>> {
+        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
+            loop {
+                match &mut self.inner_state {
+                    LocalUploadState::Idle(file) => {
+                        // We are moving file into the future, and it will be dropped on its completion, closing the file.
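+                        // sync_all() runs on a blocking thread; the rename to the final destination happens in the Committing state below.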
+ let file = Arc::clone(file); + self.inner_state = LocalUploadState::ShuttingDown(Box::pin( + runtime.spawn_blocking(move || (*file).sync_all()).map( + move |res| match res { + Err(err) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + } + Ok(res) => res, + }, + ), + )); + } + LocalUploadState::ShuttingDown(fut) => match fut.poll_unpin(cx) { + Poll::Ready(res) => { + res?; + let staging_path = + get_upload_stage_path(&self.dest, &self.multipart_id); + let dest = self.dest.clone(); + self.inner_state = LocalUploadState::Committing(Box::pin( + runtime + .spawn_blocking(move || { + std::fs::rename(&staging_path, &dest) + }) + .map(move |res| match res { + Err(err) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + } + Ok(res) => res, + }), + )); + } + Poll::Pending => { + return Poll::Pending; + } + }, + LocalUploadState::Writing(_, _) => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Tried to commit a file where a write is in progress.", + ))); + } + LocalUploadState::Committing(fut) => match fut.poll_unpin(cx) { + Poll::Ready(res) => { + self.inner_state = LocalUploadState::Complete; + return Poll::Ready(res); + } + Poll::Pending => return Poll::Pending, + }, + LocalUploadState::Complete => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Already complete", + ))) + } + } + } + } else { + let staging_path = get_upload_stage_path(&self.dest, &self.multipart_id); + match &mut self.inner_state { + LocalUploadState::Idle(file) => { + let file = Arc::clone(file); + self.inner_state = LocalUploadState::Complete; + file.sync_all()?; + std::mem::drop(file); + std::fs::rename(&staging_path, &self.dest)?; + Poll::Ready(Ok(())) + } + _ => { + // If we are running on this thread, then only possible states are Idle and Complete. 
+ Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "Already complete", + ))) + } + } + } + } +} + +fn read_range(file: &mut File, path: &PathBuf, range: Range) -> Result { + let to_read = range.end - range.start; + file.seek(SeekFrom::Start(range.start as u64)) + .context(SeekSnafu { path })?; + + let mut buf = Vec::with_capacity(to_read); + let read = file + .take(to_read as u64) + .read_to_end(&mut buf) + .context(UnableToReadBytesSnafu { path })?; + + ensure!( + read == to_read, + OutOfRangeSnafu { + path, + expected: to_read, + actual: read + } + ); + Ok(buf.into()) +} + +fn open_file(path: &PathBuf) -> Result { + let file = File::open(path).map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + Error::NotFound { + path: path.clone(), + source: e, + } + } else { + Error::UnableToOpenFile { + path: path.clone(), + source: e, + } + } + })?; + Ok(file) +} + +fn open_writable_file(path: &PathBuf) -> Result { + match File::create(&path) { + Ok(f) => Ok(f), + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + let parent = path + .parent() + .context(UnableToCreateFileSnafu { path: &path, err })?; + std::fs::create_dir_all(&parent) + .context(UnableToCreateDirSnafu { path: parent })?; + + match File::create(&path) { + Ok(f) => Ok(f), + Err(err) => Err(Error::UnableToCreateFile { + path: path.to_path_buf(), + err, + } + .into()), + } + } + Err(err) => Err(Error::UnableToCreateFile { + path: path.to_path_buf(), + err, + } + .into()), + } +} + +fn convert_entry(entry: DirEntry, location: Path) -> Result { + let metadata = entry + .metadata() + .map_err(|e| Error::UnableToAccessMetadata { + source: e.into(), + path: location.to_string(), + })?; + convert_metadata(metadata, location) +} + +fn convert_metadata(metadata: std::fs::Metadata, location: Path) -> Result { + let last_modified = metadata + .modified() + .expect("Modified file time should be supported on this platform") + .into(); + + let size = usize::try_from(metadata.len()).context(FileSizeOverflowedUsizeSnafu { + path: location.as_ref(), + })?; + + Ok(ObjectMeta { + location, + last_modified, + size, + }) +} + +/// Convert walkdir results and converts not-found errors into `None`. +/// Convert broken symlinks to `None`. +fn convert_walkdir_result( + res: std::result::Result, +) -> Result> { + match res { + Ok(entry) => { + // To check for broken symlink: call symlink_metadata() - it does not traverse symlinks); + // if ok: check if entry is symlink; and try to read it by calling metadata(). 
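+            // metadata() follows the link, so an error here means the target cannot be resolved and the entry is skipped.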
+ match symlink_metadata(entry.path()) { + Ok(attr) => { + if attr.is_symlink() { + let target_metadata = metadata(entry.path()); + match target_metadata { + Ok(_) => { + // symlink is valid + Ok(Some(entry)) + } + Err(_) => { + // this is a broken symlink, return None + Ok(None) + } + } + } else { + Ok(Some(entry)) + } + } + Err(_) => Ok(None), + } + } + + Err(walkdir_err) => match walkdir_err.io_error() { + Some(io_err) => match io_err.kind() { + io::ErrorKind::NotFound => Ok(None), + _ => Err(Error::UnableToWalkDir { + source: walkdir_err, + } + .into()), + }, + None => Err(Error::UnableToWalkDir { + source: walkdir_err, + } + .into()), + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_util::flatten_list_stream; + use crate::{ + tests::{ + copy_if_not_exists, get_nonexistent_object, list_uses_directories_correctly, + list_with_delimiter, put_get_delete_list, rename_and_copy, stream_get, + }, + Error as ObjectStoreError, ObjectStore, + }; + use futures::TryStreamExt; + use tempfile::{NamedTempFile, TempDir}; + use tokio::io::AsyncWriteExt; + + #[tokio::test] + async fn file_test() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + + #[test] + fn test_non_tokio() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + futures::executor::block_on(async move { + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + stream_get(&integration).await; + }); + } + + #[tokio::test] + async fn creates_dir_if_not_present() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from("nested/file/test_file"); + + let data = Bytes::from("arbitrary data"); + let expected_data = data.clone(); + + integration.put(&location, data).await.unwrap(); + + let read_data = integration + .get(&location) + .await + .unwrap() + .bytes() + .await + .unwrap(); + assert_eq!(&*read_data, expected_data); + } + + #[tokio::test] + async fn unknown_length() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from("some_file"); + + let data = Bytes::from("arbitrary data"); + let expected_data = data.clone(); + + integration.put(&location, data).await.unwrap(); + + let read_data = integration + .get(&location) + .await + .unwrap() + .bytes() + .await + .unwrap(); + assert_eq!(&*read_data, expected_data); + } + + #[tokio::test] + #[cfg(target_family = "unix")] + // Fails on github actions runner (which runs the tests as root) + #[ignore] + async fn bubble_up_io_errors() { + use std::{fs::set_permissions, os::unix::prelude::PermissionsExt}; + + let root = TempDir::new().unwrap(); + + // make non-readable + let metadata = root.path().metadata().unwrap(); + let mut permissions = metadata.permissions(); + permissions.set_mode(0o000); + set_permissions(root.path(), permissions).unwrap(); + + let store = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + // `list` must fail + match store.list(None).await { + Err(_) => { + // ok, error found + } + Ok(mut 
stream) => { + let mut any_err = false; + while let Some(res) = stream.next().await { + if res.is_err() { + any_err = true; + } + } + assert!(any_err); + } + } + + // `list_with_delimiter + assert!(store.list_with_delimiter(None).await.is_err()); + } + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + #[tokio::test] + async fn get_nonexistent_location() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let location = Path::from(NON_EXISTENT_NAME); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + if let ObjectStoreError::NotFound { path, source } = err { + let source_variant = source.downcast_ref::(); + assert!( + matches!(source_variant, Some(std::io::Error { .. }),), + "got: {:?}", + source_variant + ); + assert!(path.ends_with(NON_EXISTENT_NAME), "{}", path); + } else { + panic!("unexpected error type: {:?}", err); + } + } + + #[tokio::test] + async fn root() { + let integration = LocalFileSystem::new(); + + let canonical = std::path::Path::new("Cargo.toml").canonicalize().unwrap(); + let url = Url::from_directory_path(&canonical).unwrap(); + let path = Path::parse(url.path()).unwrap(); + + let roundtrip = integration.config.path_to_filesystem(&path).unwrap(); + + // Needed as on Windows canonicalize returns extended length path syntax + // C:\Users\circleci -> \\?\C:\Users\circleci + let roundtrip = roundtrip.canonicalize().unwrap(); + + assert_eq!(roundtrip, canonical); + + integration.head(&path).await.unwrap(); + } + + #[tokio::test] + async fn test_list_root() { + let integration = LocalFileSystem::new(); + let result = integration.list_with_delimiter(None).await; + if cfg!(target_family = "windows") { + let r = result.unwrap_err().to_string(); + assert!( + r.contains("Unable to convert URL \"file:///\" to filesystem path"), + "{}", + r + ); + } else { + result.unwrap(); + } + } + + async fn check_list( + integration: &LocalFileSystem, + prefix: Option<&Path>, + expected: &[&str], + ) { + let result: Vec<_> = integration + .list(prefix) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let mut strings: Vec<_> = result.iter().map(|x| x.location.as_ref()).collect(); + strings.sort_unstable(); + assert_eq!(&strings, expected) + } + + #[tokio::test] + #[cfg(target_family = "unix")] + async fn test_symlink() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + + let subdir = root.path().join("a"); + std::fs::create_dir(&subdir).unwrap(); + let file = subdir.join("file.parquet"); + std::fs::write(file, "test").unwrap(); + + check_list(&integration, None, &["a/file.parquet"]).await; + integration + .head(&Path::from("a/file.parquet")) + .await + .unwrap(); + + // Follow out of tree symlink + let other = NamedTempFile::new().unwrap(); + std::os::unix::fs::symlink(other.path(), root.path().join("test.parquet")) + .unwrap(); + + // Should return test.parquet even though out of tree + check_list(&integration, None, &["a/file.parquet", "test.parquet"]).await; + + // Can fetch test.parquet + integration.head(&Path::from("test.parquet")).await.unwrap(); + + // Follow in tree symlink + std::os::unix::fs::symlink(&subdir, root.path().join("b")).unwrap(); + check_list( + &integration, + None, + &["a/file.parquet", "b/file.parquet", "test.parquet"], + ) + .await; + check_list(&integration, Some(&Path::from("b")), &["b/file.parquet"]).await; + + // Can fetch through symlink + integration + 
.head(&Path::from("b/file.parquet")) + .await + .unwrap(); + + // Ignore broken symlink + std::os::unix::fs::symlink( + root.path().join("foo.parquet"), + root.path().join("c"), + ) + .unwrap(); + + check_list( + &integration, + None, + &["a/file.parquet", "b/file.parquet", "test.parquet"], + ) + .await; + + let mut r = integration.list_with_delimiter(None).await.unwrap(); + r.common_prefixes.sort_unstable(); + assert_eq!(r.common_prefixes.len(), 2); + assert_eq!(r.common_prefixes[0].as_ref(), "a"); + assert_eq!(r.common_prefixes[1].as_ref(), "b"); + assert_eq!(r.objects.len(), 1); + assert_eq!(r.objects[0].location.as_ref(), "test.parquet"); + + let r = integration + .list_with_delimiter(Some(&Path::from("a"))) + .await + .unwrap(); + assert_eq!(r.common_prefixes.len(), 0); + assert_eq!(r.objects.len(), 1); + assert_eq!(r.objects[0].location.as_ref(), "a/file.parquet"); + + // Deleting a symlink doesn't delete the source file + integration + .delete(&Path::from("test.parquet")) + .await + .unwrap(); + assert!(other.path().exists()); + + check_list(&integration, None, &["a/file.parquet", "b/file.parquet"]).await; + + // Deleting through a symlink deletes both files + integration + .delete(&Path::from("b/file.parquet")) + .await + .unwrap(); + + check_list(&integration, None, &[]).await; + + // Adding a file through a symlink creates in both paths + integration + .put(&Path::from("b/file.parquet"), Bytes::from(vec![0, 1, 2])) + .await + .unwrap(); + + check_list(&integration, None, &["a/file.parquet", "b/file.parquet"]).await; + } + + #[tokio::test] + async fn invalid_path() { + let root = TempDir::new().unwrap(); + let root = root.path().join("🙀"); + std::fs::create_dir(root.clone()).unwrap(); + + // Invalid paths supported above root of store + let integration = LocalFileSystem::new_with_prefix(root.clone()).unwrap(); + + let directory = Path::from("directory"); + let object = directory.child("child.txt"); + let data = Bytes::from("arbitrary"); + integration.put(&object, data.clone()).await.unwrap(); + integration.head(&object).await.unwrap(); + let result = integration.get(&object).await.unwrap(); + assert_eq!(result.bytes().await.unwrap(), data); + + flatten_list_stream(&integration, None).await.unwrap(); + flatten_list_stream(&integration, Some(&directory)) + .await + .unwrap(); + + let result = integration + .list_with_delimiter(Some(&directory)) + .await + .unwrap(); + assert_eq!(result.objects.len(), 1); + assert!(result.common_prefixes.is_empty()); + assert_eq!(result.objects[0].location, object); + + let illegal = root.join("💀"); + std::fs::write(illegal, "foo").unwrap(); + + // Can list directory that doesn't contain illegal path + flatten_list_stream(&integration, Some(&directory)) + .await + .unwrap(); + + // Cannot list illegal file + let err = flatten_list_stream(&integration, None) + .await + .unwrap_err() + .to_string(); + + assert!( + err.contains("Encountered illegal character sequence \"💀\" whilst parsing path segment \"💀\""), + "{}", + err + ); + } + + #[tokio::test] + async fn list_hides_incomplete_uploads() { + let root = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(root.path()).unwrap(); + let location = Path::from("some_file"); + + let data = Bytes::from("arbitrary data"); + let (multipart_id, mut writer) = + integration.put_multipart(&location).await.unwrap(); + writer.write_all(&data).await.unwrap(); + + let (multipart_id_2, mut writer_2) = + integration.put_multipart(&location).await.unwrap(); + assert_ne!(multipart_id, 
multipart_id_2); + writer_2.write_all(&data).await.unwrap(); + + let list = flatten_list_stream(&integration, None).await.unwrap(); + assert_eq!(list.len(), 0); + + assert_eq!( + integration + .list_with_delimiter(None) + .await + .unwrap() + .objects + .len(), + 0 + ); + } + + #[tokio::test] + async fn filesystem_filename_with_percent() { + let temp_dir = TempDir::new().unwrap(); + let integration = LocalFileSystem::new_with_prefix(temp_dir.path()).unwrap(); + let filename = "L%3ABC.parquet"; + + std::fs::write(temp_dir.path().join(filename), "foo").unwrap(); + + let list_stream = integration.list(None).await.unwrap(); + let res: Vec<_> = list_stream.try_collect().await.unwrap(); + assert_eq!(res.len(), 1); + assert_eq!(res[0].location.as_ref(), filename); + + let res = integration.list_with_delimiter(None).await.unwrap(); + assert_eq!(res.objects.len(), 1); + assert_eq!(res.objects[0].location.as_ref(), filename); + } + + #[tokio::test] + async fn relative_paths() { + LocalFileSystem::new_with_prefix(".").unwrap(); + LocalFileSystem::new_with_prefix("..").unwrap(); + LocalFileSystem::new_with_prefix("../..").unwrap(); + + let integration = LocalFileSystem::new(); + let path = Path::from_filesystem_path(".").unwrap(); + integration.list_with_delimiter(Some(&path)).await.unwrap(); + } +} diff --git a/object_store/src/memory.rs b/object_store/src/memory.rs new file mode 100644 index 000000000000..e4be5b2afddf --- /dev/null +++ b/object_store/src/memory.rs @@ -0,0 +1,376 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An in-memory object store implementation +use crate::MultipartId; +use crate::{path::Path, GetResult, ListResult, ObjectMeta, ObjectStore, Result}; +use async_trait::async_trait; +use bytes::Bytes; +use chrono::Utc; +use futures::{stream::BoxStream, StreamExt}; +use parking_lot::RwLock; +use snafu::{ensure, OptionExt, Snafu}; +use std::collections::BTreeMap; +use std::collections::BTreeSet; +use std::io; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; +use tokio::io::AsyncWrite; + +/// A specialized `Error` for in-memory object store-related errors +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +enum Error { + #[snafu(display("No data in memory found. 
Location: {path}"))] + NoDataInMemory { path: String }, + + #[snafu(display("Out of range"))] + OutOfRange, + + #[snafu(display("Bad range"))] + BadRange, + + #[snafu(display("Object already exists at that location: {path}"))] + AlreadyExists { path: String }, +} + +impl From for super::Error { + fn from(source: Error) -> Self { + match source { + Error::NoDataInMemory { ref path } => Self::NotFound { + path: path.into(), + source: source.into(), + }, + Error::AlreadyExists { ref path } => Self::AlreadyExists { + path: path.into(), + source: source.into(), + }, + _ => Self::Generic { + store: "InMemory", + source: Box::new(source), + }, + } + } +} + +/// In-memory storage suitable for testing or for opting out of using a cloud +/// storage provider. +#[derive(Debug, Default)] +pub struct InMemory { + storage: Arc>>, +} + +impl std::fmt::Display for InMemory { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "InMemory") + } +} + +#[async_trait] +impl ObjectStore for InMemory { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + self.storage.write().insert(location.clone(), bytes); + Ok(()) + } + + async fn put_multipart( + &self, + location: &Path, + ) -> Result<(MultipartId, Box)> { + Ok(( + String::new(), + Box::new(InMemoryUpload { + location: location.clone(), + data: Vec::new(), + storage: Arc::clone(&self.storage), + }), + )) + } + + async fn abort_multipart( + &self, + _location: &Path, + _multipart_id: &MultipartId, + ) -> Result<()> { + // Nothing to clean up + Ok(()) + } + + async fn get(&self, location: &Path) -> Result { + let data = self.get_bytes(location).await?; + + Ok(GetResult::Stream( + futures::stream::once(async move { Ok(data) }).boxed(), + )) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let data = self.get_bytes(location).await?; + ensure!(range.end <= data.len(), OutOfRangeSnafu); + ensure!(range.start <= range.end, BadRangeSnafu); + + Ok(data.slice(range)) + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> Result> { + let data = self.get_bytes(location).await?; + ranges + .iter() + .map(|range| { + ensure!(range.end <= data.len(), OutOfRangeSnafu); + ensure!(range.start <= range.end, BadRangeSnafu); + Ok(data.slice(range.clone())) + }) + .collect() + } + + async fn head(&self, location: &Path) -> Result { + let last_modified = Utc::now(); + let bytes = self.get_bytes(location).await?; + Ok(ObjectMeta { + location: location.clone(), + last_modified, + size: bytes.len(), + }) + } + + async fn delete(&self, location: &Path) -> Result<()> { + self.storage.write().remove(location); + Ok(()) + } + + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>> { + let last_modified = Utc::now(); + + let storage = self.storage.read(); + let values: Vec<_> = storage + .iter() + .filter(move |(key, _)| prefix.map(|p| key.prefix_matches(p)).unwrap_or(true)) + .map(move |(key, value)| { + Ok(ObjectMeta { + location: key.clone(), + last_modified, + size: value.len(), + }) + }) + .collect(); + + Ok(futures::stream::iter(values).boxed()) + } + + /// The memory implementation returns all results, as opposed to the cloud + /// versions which limit their results to 1k or more because of API + /// limitations. 
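+    /// Results are returned in path order, as objects are stored in a `BTreeMap`.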
+ async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + let root = Path::default(); + let prefix = prefix.unwrap_or(&root); + + let mut common_prefixes = BTreeSet::new(); + let last_modified = Utc::now(); + + // Only objects in this base level should be returned in the + // response. Otherwise, we just collect the common prefixes. + let mut objects = vec![]; + for (k, v) in self.storage.read().range((prefix)..) { + let mut parts = match k.prefix_match(prefix) { + Some(parts) => parts, + None => break, + }; + + // Pop first element + let common_prefix = match parts.next() { + Some(p) => p, + None => continue, + }; + + if parts.next().is_some() { + common_prefixes.insert(prefix.child(common_prefix)); + } else { + let object = ObjectMeta { + location: k.clone(), + last_modified, + size: v.len(), + }; + objects.push(object); + } + } + + Ok(ListResult { + objects, + common_prefixes: common_prefixes.into_iter().collect(), + }) + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + let data = self.get_bytes(from).await?; + self.storage.write().insert(to.clone(), data); + Ok(()) + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + let data = self.get_bytes(from).await?; + let mut storage = self.storage.write(); + if storage.contains_key(to) { + return Err(Error::AlreadyExists { + path: to.to_string(), + } + .into()); + } + storage.insert(to.clone(), data); + Ok(()) + } +} + +impl InMemory { + /// Create new in-memory storage. + pub fn new() -> Self { + Self::default() + } + + /// Creates a clone of the store + pub async fn clone(&self) -> Self { + let storage = self.storage.read(); + let storage = storage.clone(); + + Self { + storage: Arc::new(RwLock::new(storage)), + } + } + + async fn get_bytes(&self, location: &Path) -> Result { + let storage = self.storage.read(); + let bytes = storage + .get(location) + .cloned() + .context(NoDataInMemorySnafu { + path: location.to_string(), + })?; + Ok(bytes) + } +} + +struct InMemoryUpload { + location: Path, + data: Vec, + storage: Arc>>, +} + +impl AsyncWrite for InMemoryUpload { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + self.data.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let data = Bytes::from(std::mem::take(&mut self.data)); + self.storage.write().insert(self.location.clone(), data); + Poll::Ready(Ok(())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{ + tests::{ + copy_if_not_exists, get_nonexistent_object, list_uses_directories_correctly, + list_with_delimiter, put_get_delete_list, rename_and_copy, stream_get, + }, + Error as ObjectStoreError, ObjectStore, + }; + + #[tokio::test] + async fn in_memory_test() { + let integration = InMemory::new(); + + put_get_delete_list(&integration).await; + list_uses_directories_correctly(&integration).await; + list_with_delimiter(&integration).await; + rename_and_copy(&integration).await; + copy_if_not_exists(&integration).await; + stream_get(&integration).await; + } + + #[tokio::test] + async fn unknown_length() { + let integration = InMemory::new(); + + let location = Path::from("some_file"); + + let data = Bytes::from("arbitrary data"); + let expected_data = data.clone(); + + 
integration.put(&location, data).await.unwrap(); + + let read_data = integration + .get(&location) + .await + .unwrap() + .bytes() + .await + .unwrap(); + assert_eq!(&*read_data, expected_data); + } + + const NON_EXISTENT_NAME: &str = "nonexistentname"; + + #[tokio::test] + async fn nonexistent_location() { + let integration = InMemory::new(); + + let location = Path::from(NON_EXISTENT_NAME); + + let err = get_nonexistent_object(&integration, Some(location)) + .await + .unwrap_err(); + if let ObjectStoreError::NotFound { path, source } = err { + let source_variant = source.downcast_ref::(); + assert!( + matches!(source_variant, Some(Error::NoDataInMemory { .. }),), + "got: {:?}", + source_variant + ); + assert_eq!(path, NON_EXISTENT_NAME); + } else { + panic!("unexpected error type: {:?}", err); + } + } +} diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs new file mode 100644 index 000000000000..1985d8694e50 --- /dev/null +++ b/object_store/src/multipart.rs @@ -0,0 +1,212 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use async_trait::async_trait; +use futures::{stream::FuturesUnordered, Future, StreamExt}; +use std::{io, pin::Pin, sync::Arc, task::Poll}; +use tokio::io::AsyncWrite; + +use crate::Result; + +type BoxedTryFuture = Pin> + Send>>; + +/// A trait that can be implemented by cloud-based object stores +/// and used in combination with [`CloudMultiPartUpload`] to provide +/// multipart upload support +#[async_trait] +pub(crate) trait CloudMultiPartUploadImpl: 'static { + /// Upload a single part + async fn put_multipart_part( + &self, + buf: Vec, + part_idx: usize, + ) -> Result; + + /// Complete the upload with the provided parts + /// + /// `completed_parts` is in order of part number + async fn complete(&self, completed_parts: Vec) -> Result<(), io::Error>; +} + +#[derive(Debug, Clone)] +pub(crate) struct UploadPart { + pub content_id: String, +} + +pub(crate) struct CloudMultiPartUpload +where + T: CloudMultiPartUploadImpl, +{ + inner: Arc, + /// A list of completed parts, in sequential order. + completed_parts: Vec>, + /// Part upload tasks currently running + tasks: FuturesUnordered>, + /// Maximum number of upload tasks to run concurrently + max_concurrency: usize, + /// Buffer that will be sent in next upload. 
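+    /// Contents are submitted as a new part once they exceed `min_part_size`, or on flush/shutdown.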
+ current_buffer: Vec, + /// Minimum size of a part in bytes + min_part_size: usize, + /// Index of current part + current_part_idx: usize, + /// The completion task + completion_task: Option>, +} + +impl CloudMultiPartUpload +where + T: CloudMultiPartUploadImpl, +{ + pub fn new(inner: T, max_concurrency: usize) -> Self { + Self { + inner: Arc::new(inner), + completed_parts: Vec::new(), + tasks: FuturesUnordered::new(), + max_concurrency, + current_buffer: Vec::new(), + // TODO: Should self vary by provider? + // TODO: Should we automatically increase then when part index gets large? + min_part_size: 5_000_000, + current_part_idx: 0, + completion_task: None, + } + } + + pub fn poll_tasks( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Result<(), io::Error> { + if self.tasks.is_empty() { + return Ok(()); + } + let total_parts = self.completed_parts.len(); + while let Poll::Ready(Some(res)) = self.tasks.poll_next_unpin(cx) { + let (part_idx, part) = res?; + self.completed_parts + .resize(std::cmp::max(part_idx + 1, total_parts), None); + self.completed_parts[part_idx] = Some(part); + } + Ok(()) + } +} + +impl AsyncWrite for CloudMultiPartUpload +where + T: CloudMultiPartUploadImpl + Send + Sync, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + // Poll current tasks + self.as_mut().poll_tasks(cx)?; + + // If adding buf to pending buffer would trigger send, check + // whether we have capacity for another task. + let enough_to_send = (buf.len() + self.current_buffer.len()) > self.min_part_size; + if enough_to_send && self.tasks.len() < self.max_concurrency { + // If we do, copy into the buffer and submit the task, and return ready. + self.current_buffer.extend_from_slice(buf); + + let out_buffer = std::mem::take(&mut self.current_buffer); + let inner = Arc::clone(&self.inner); + let part_idx = self.current_part_idx; + self.tasks.push(Box::pin(async move { + let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?; + Ok((part_idx, upload_part)) + })); + self.current_part_idx += 1; + + // We need to poll immediately after adding to setup waker + self.as_mut().poll_tasks(cx)?; + + Poll::Ready(Ok(buf.len())) + } else if !enough_to_send { + self.current_buffer.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } else { + // Waker registered by call to poll_tasks at beginning + Poll::Pending + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Poll current tasks + self.as_mut().poll_tasks(cx)?; + + // If current_buffer is not empty, see if it can be submitted + if !self.current_buffer.is_empty() && self.tasks.len() < self.max_concurrency { + let out_buffer: Vec = std::mem::take(&mut self.current_buffer); + let inner = Arc::clone(&self.inner); + let part_idx = self.current_part_idx; + self.tasks.push(Box::pin(async move { + let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?; + Ok((part_idx, upload_part)) + })); + } + + self.as_mut().poll_tasks(cx)?; + + // If tasks and current_buffer are empty, return Ready + if self.tasks.is_empty() && self.current_buffer.is_empty() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // First, poll flush + match self.as_mut().poll_flush(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(res) => res?, + }; + + // If shutdown task is not 
set, set it + let parts = std::mem::take(&mut self.completed_parts); + let parts = parts + .into_iter() + .enumerate() + .map(|(idx, part)| { + part.ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + format!("Missing information for upload part {}", idx), + ) + }) + }) + .collect::>()?; + + let inner = Arc::clone(&self.inner); + let completion_task = self.completion_task.get_or_insert_with(|| { + Box::pin(async move { + inner.complete(parts).await?; + Ok(()) + }) + }); + + Pin::new(completion_task).poll(cx) + } +} diff --git a/object_store/src/path/mod.rs b/object_store/src/path/mod.rs new file mode 100644 index 000000000000..e5a7b6443bb1 --- /dev/null +++ b/object_store/src/path/mod.rs @@ -0,0 +1,537 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Path abstraction for Object Storage + +use itertools::Itertools; +use percent_encoding::percent_decode; +use snafu::{ensure, ResultExt, Snafu}; +use std::fmt::Formatter; +use url::Url; + +/// The delimiter to separate object namespaces, creating a directory structure. +pub const DELIMITER: &str = "/"; + +/// The path delimiter as a single byte +pub const DELIMITER_BYTE: u8 = DELIMITER.as_bytes()[0]; + +mod parts; + +pub use parts::{InvalidPart, PathPart}; + +/// Error returned by [`Path::parse`] +#[derive(Debug, Snafu)] +#[allow(missing_docs)] +pub enum Error { + #[snafu(display("Path \"{}\" contained empty path segment", path))] + EmptySegment { path: String }, + + #[snafu(display("Error parsing Path \"{}\": {}", path, source))] + BadSegment { path: String, source: InvalidPart }, + + #[snafu(display("Failed to canonicalize path \"{}\": {}", path.display(), source))] + Canonicalize { + path: std::path::PathBuf, + source: std::io::Error, + }, + + #[snafu(display("Unable to convert path \"{}\" to URL", path.display()))] + InvalidPath { path: std::path::PathBuf }, + + #[snafu(display("Path \"{}\" contained non-unicode characters: {}", path, source))] + NonUnicode { + path: String, + source: std::str::Utf8Error, + }, + + #[snafu(display("Path {} does not start with prefix {}", path, prefix))] + PrefixMismatch { path: String, prefix: String }, +} + +/// A parsed path representation that can be safely written to object storage +/// +/// # Path Safety +/// +/// In theory object stores support any UTF-8 character sequence, however, certain character +/// sequences cause compatibility problems with some applications and protocols. As such the +/// naming guidelines for [S3], [GCS] and [Azure Blob Storage] all recommend sticking to a +/// limited character subset. 
+/// +/// [S3]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html +/// [GCS]: https://cloud.google.com/storage/docs/naming-objects +/// [Azure Blob Storage]: https://docs.microsoft.com/en-us/rest/api/storageservices/Naming-and-Referencing-Containers--Blobs--and-Metadata#blob-names +/// +/// This presents libraries with two options for consistent path handling: +/// +/// 1. Allow constructing unsafe paths, allowing for both reading and writing of data to paths +/// that may not be consistently understood or supported +/// 2. Disallow constructing unsafe paths, ensuring data written can be consistently handled by +/// all other systems, but preventing interaction with objects at unsafe paths +/// +/// This library takes the second approach, in particular: +/// +/// * Paths are delimited by `/` +/// * Paths do not start with a `/` +/// * Empty path segments are discarded (e.g. `//` is treated as though it were `/`) +/// * Relative path segments, i.e. `.` and `..` are percent encoded +/// * Unsafe characters are percent encoded, as described by [RFC 1738] +/// * All paths are relative to the root of the object store +/// +/// In order to provide these guarantees there are two ways to safely construct a [`Path`] +/// +/// # Encode +/// +/// A string containing potentially illegal path segments can be encoded to a [`Path`] +/// using [`Path::from`] or [`Path::from_iter`]. +/// +/// ``` +/// # use object_store::path::Path; +/// assert_eq!(Path::from("foo/bar").as_ref(), "foo/bar"); +/// assert_eq!(Path::from("foo//bar").as_ref(), "foo/bar"); +/// assert_eq!(Path::from("foo/../bar").as_ref(), "foo/%2E%2E/bar"); +/// assert_eq!(Path::from_iter(["foo", "foo/bar"]).as_ref(), "foo/foo%2Fbar"); +/// ``` +/// +/// Note: if provided with an already percent encoded string, this will encode it again +/// +/// ``` +/// # use object_store::path::Path; +/// assert_eq!(Path::from("foo/foo%2Fbar").as_ref(), "foo/foo%252Fbar"); +/// ``` +/// +/// # Parse +/// +/// Alternatively a [`Path`] can be created from an existing string, returning an +/// error if it is invalid. Unlike the encoding methods, this will permit +/// valid percent encoded sequences. 
+/// +/// ``` +/// # use object_store::path::Path; +/// +/// assert_eq!(Path::parse("/foo/foo%2Fbar").unwrap().as_ref(), "foo/foo%2Fbar"); +/// Path::parse("..").unwrap_err(); +/// Path::parse("/foo//").unwrap_err(); +/// Path::parse("😀").unwrap_err(); +/// ``` +/// +/// [RFC 1738]: https://www.ietf.org/rfc/rfc1738.txt +#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct Path { + /// The raw path with no leading or trailing delimiters + raw: String, +} + +impl Path { + /// Parse a string as a [`Path`], returning a [`Error`] if invalid, + /// as defined on the docstring for [`Path`] + /// + /// Note: this will strip any leading `/` or trailing `/` + pub fn parse(path: impl AsRef) -> Result { + let path = path.as_ref(); + + let stripped = path.strip_prefix(DELIMITER).unwrap_or(path); + if stripped.is_empty() { + return Ok(Default::default()); + } + + let stripped = stripped.strip_suffix(DELIMITER).unwrap_or(stripped); + + for segment in stripped.split(DELIMITER) { + ensure!(!segment.is_empty(), EmptySegmentSnafu { path }); + PathPart::parse(segment).context(BadSegmentSnafu { path })?; + } + + Ok(Self { + raw: stripped.to_string(), + }) + } + + /// Convert a filesystem path to a [`Path`] relative to the filesystem root + /// + /// This will return an error if the path contains illegal character sequences + /// as defined by [`Path::parse`] or does not exist + /// + /// Note: this will canonicalize the provided path, resolving any symlinks + pub fn from_filesystem_path( + path: impl AsRef, + ) -> Result { + let absolute = std::fs::canonicalize(&path).context(CanonicalizeSnafu { + path: path.as_ref(), + })?; + + Self::from_absolute_path(absolute) + } + + /// Convert an absolute filesystem path to a [`Path`] relative to the filesystem root + /// + /// This will return an error if the path contains illegal character sequences + /// as defined by [`Path::parse`], or `base` is not an absolute path + pub fn from_absolute_path(path: impl AsRef) -> Result { + Self::from_absolute_path_with_base(path, None) + } + + /// Convert a filesystem path to a [`Path`] relative to the provided base + /// + /// This will return an error if the path contains illegal character sequences + /// as defined by [`Path::parse`], or `base` does not refer to a parent path of `path`, + /// or `base` is not an absolute path + pub(crate) fn from_absolute_path_with_base( + path: impl AsRef, + base: Option<&Url>, + ) -> Result { + let url = absolute_path_to_url(path)?; + let path = match base { + Some(prefix) => url.path().strip_prefix(prefix.path()).ok_or_else(|| { + Error::PrefixMismatch { + path: url.path().to_string(), + prefix: prefix.to_string(), + } + })?, + None => url.path(), + }; + + // Reverse any percent encoding performed by conversion to URL + let decoded = percent_decode(path.as_bytes()) + .decode_utf8() + .context(NonUnicodeSnafu { path })?; + + Self::parse(decoded) + } + + /// Returns the [`PathPart`] of this [`Path`] + pub fn parts(&self) -> impl Iterator> { + match self.raw.is_empty() { + true => itertools::Either::Left(std::iter::empty()), + false => itertools::Either::Right( + self.raw + .split(DELIMITER) + .map(|s| PathPart { raw: s.into() }), + ), + } + } + + /// Returns an iterator of the [`PathPart`] of this [`Path`] after `prefix` + /// + /// Returns `None` if the prefix does not match + pub fn prefix_match( + &self, + prefix: &Self, + ) -> Option> + '_> { + let diff = itertools::diff_with(self.parts(), prefix.parts(), |a, b| a == b); + + match diff { + // Both were equal + 
None => Some(itertools::Either::Left(std::iter::empty())), + // Mismatch or prefix was longer => None + Some( + itertools::Diff::FirstMismatch(_, _, _) | itertools::Diff::Longer(_, _), + ) => None, + // Match with remaining + Some(itertools::Diff::Shorter(_, back)) => { + Some(itertools::Either::Right(back)) + } + } + } + + /// Returns true if this [`Path`] starts with `prefix` + pub fn prefix_matches(&self, prefix: &Self) -> bool { + self.prefix_match(prefix).is_some() + } + + /// Creates a new child of this [`Path`] + pub fn child<'a>(&self, child: impl Into>) -> Self { + let raw = match self.raw.is_empty() { + true => format!("{}", child.into().raw), + false => format!("{}{}{}", self.raw, DELIMITER, child.into().raw), + }; + + Self { raw } + } +} + +impl AsRef for Path { + fn as_ref(&self) -> &str { + &self.raw + } +} + +impl From<&str> for Path { + fn from(path: &str) -> Self { + Self::from_iter(path.split(DELIMITER)) + } +} + +impl From for Path { + fn from(path: String) -> Self { + Self::from_iter(path.split(DELIMITER)) + } +} + +impl From for String { + fn from(path: Path) -> Self { + path.raw + } +} + +impl std::fmt::Display for Path { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.raw.fmt(f) + } +} + +impl<'a, I> FromIterator for Path +where + I: Into>, +{ + fn from_iter>(iter: T) -> Self { + let raw = T::into_iter(iter) + .map(|s| s.into()) + .filter(|s| !s.raw.is_empty()) + .map(|s| s.raw) + .join(DELIMITER); + + Self { raw } + } +} + +/// Given an absolute filesystem path convert it to a URL representation without canonicalization +pub(crate) fn absolute_path_to_url( + path: impl AsRef, +) -> Result { + Url::from_file_path(&path).map_err(|_| Error::InvalidPath { + path: path.as_ref().into(), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cloud_prefix_with_trailing_delimiter() { + // Use case: files exist in object storage named `foo/bar.json` and + // `foo_test.json`. A search for the prefix `foo/` should return + // `foo/bar.json` but not `foo_test.json'. + let prefix = Path::from_iter(["test"]); + assert_eq!(prefix.as_ref(), "test"); + } + + #[test] + fn push_encodes() { + let location = Path::from_iter(["foo/bar", "baz%2Ftest"]); + assert_eq!(location.as_ref(), "foo%2Fbar/baz%252Ftest"); + } + + #[test] + fn test_parse() { + assert_eq!(Path::parse("/").unwrap().as_ref(), ""); + assert_eq!(Path::parse("").unwrap().as_ref(), ""); + + let err = Path::parse("//").unwrap_err(); + assert!(matches!(err, Error::EmptySegment { .. })); + + assert_eq!(Path::parse("/foo/bar/").unwrap().as_ref(), "foo/bar"); + assert_eq!(Path::parse("foo/bar/").unwrap().as_ref(), "foo/bar"); + assert_eq!(Path::parse("foo/bar").unwrap().as_ref(), "foo/bar"); + + let err = Path::parse("foo///bar").unwrap_err(); + assert!(matches!(err, Error::EmptySegment { .. 
})); + } + + #[test] + fn convert_raw_before_partial_eq() { + // dir and file_name + let cloud = Path::from("test_dir/test_file.json"); + let built = Path::from_iter(["test_dir", "test_file.json"]); + + assert_eq!(built, cloud); + + // dir and file_name w/o dot + let cloud = Path::from("test_dir/test_file"); + let built = Path::from_iter(["test_dir", "test_file"]); + + assert_eq!(built, cloud); + + // dir, no file + let cloud = Path::from("test_dir/"); + let built = Path::from_iter(["test_dir"]); + assert_eq!(built, cloud); + + // file_name, no dir + let cloud = Path::from("test_file.json"); + let built = Path::from_iter(["test_file.json"]); + assert_eq!(built, cloud); + + // empty + let cloud = Path::from(""); + let built = Path::from_iter(["", ""]); + + assert_eq!(built, cloud); + } + + #[test] + fn parts_after_prefix_behavior() { + let existing_path = Path::from("apple/bear/cow/dog/egg.json"); + + // Prefix with one directory + let prefix = Path::from("apple"); + let expected_parts: Vec> = vec!["bear", "cow", "dog", "egg.json"] + .into_iter() + .map(Into::into) + .collect(); + let parts: Vec<_> = existing_path.prefix_match(&prefix).unwrap().collect(); + assert_eq!(parts, expected_parts); + + // Prefix with two directories + let prefix = Path::from("apple/bear"); + let expected_parts: Vec> = vec!["cow", "dog", "egg.json"] + .into_iter() + .map(Into::into) + .collect(); + let parts: Vec<_> = existing_path.prefix_match(&prefix).unwrap().collect(); + assert_eq!(parts, expected_parts); + + // Not a prefix + let prefix = Path::from("cow"); + assert!(existing_path.prefix_match(&prefix).is_none()); + + // Prefix with a partial directory + let prefix = Path::from("ap"); + assert!(existing_path.prefix_match(&prefix).is_none()); + + // Prefix matches but there aren't any parts after it + let existing_path = Path::from("apple/bear/cow/dog"); + + let prefix = existing_path.clone(); + assert_eq!(existing_path.prefix_match(&prefix).unwrap().count(), 0); + } + + #[test] + fn prefix_matches() { + let haystack = Path::from_iter(["foo/bar", "baz%2Ftest", "something"]); + let needle = haystack.clone(); + // self starts with self + assert!( + haystack.prefix_matches(&haystack), + "{:?} should have started with {:?}", + haystack, + haystack + ); + + // a longer prefix doesn't match + let needle = needle.child("longer now"); + assert!( + !haystack.prefix_matches(&needle), + "{:?} shouldn't have started with {:?}", + haystack, + needle + ); + + // one dir prefix matches + let needle = Path::from_iter(["foo/bar"]); + assert!( + haystack.prefix_matches(&needle), + "{:?} should have started with {:?}", + haystack, + needle + ); + + // two dir prefix matches + let needle = needle.child("baz%2Ftest"); + assert!( + haystack.prefix_matches(&needle), + "{:?} should have started with {:?}", + haystack, + needle + ); + + // partial dir prefix doesn't match + let needle = Path::from_iter(["f"]); + assert!( + !haystack.prefix_matches(&needle), + "{:?} should not have started with {:?}", + haystack, + needle + ); + + // one dir and one partial dir doesn't match + let needle = Path::from_iter(["foo/bar", "baz"]); + assert!( + !haystack.prefix_matches(&needle), + "{:?} should not have started with {:?}", + haystack, + needle + ); + + // empty prefix matches + let needle = Path::from(""); + assert!( + haystack.prefix_matches(&needle), + "{:?} should have started with {:?}", + haystack, + needle + ); + } + + #[test] + fn prefix_matches_with_file_name() { + let haystack = + Path::from_iter(["foo/bar", "baz%2Ftest", "something", 
"foo.segment"]); + + // All directories match and file name is a prefix + let needle = Path::from_iter(["foo/bar", "baz%2Ftest", "something", "foo"]); + + assert!( + !haystack.prefix_matches(&needle), + "{:?} should not have started with {:?}", + haystack, + needle + ); + + // All directories match but file name is not a prefix + let needle = Path::from_iter(["foo/bar", "baz%2Ftest", "something", "e"]); + + assert!( + !haystack.prefix_matches(&needle), + "{:?} should not have started with {:?}", + haystack, + needle + ); + + // Not all directories match; file name is a prefix of the next directory; this + // does not match + let needle = Path::from_iter(["foo/bar", "baz%2Ftest", "s"]); + + assert!( + !haystack.prefix_matches(&needle), + "{:?} should not have started with {:?}", + haystack, + needle + ); + + // Not all directories match; file name is NOT a prefix of the next directory; + // no match + let needle = Path::from_iter(["foo/bar", "baz%2Ftest", "p"]); + + assert!( + !haystack.prefix_matches(&needle), + "{:?} should not have started with {:?}", + haystack, + needle + ); + } +} diff --git a/object_store/src/path/parts.rs b/object_store/src/path/parts.rs new file mode 100644 index 000000000000..9da4815712db --- /dev/null +++ b/object_store/src/path/parts.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use percent_encoding::{percent_encode, AsciiSet, CONTROLS}; +use std::borrow::Cow; + +use crate::path::DELIMITER_BYTE; +use snafu::Snafu; + +/// Error returned by [`PathPart::parse`] +#[derive(Debug, Snafu)] +#[snafu(display( + "Encountered illegal character sequence \"{}\" whilst parsing path segment \"{}\"", + illegal, + segment +))] +#[allow(missing_copy_implementations)] +pub struct InvalidPart { + segment: String, + illegal: String, +} + +/// The PathPart type exists to validate the directory/file names that form part +/// of a path. +/// +/// A PathPart instance is guaranteed to to contain no illegal characters (e.g. `/`) +/// as it can only be constructed by going through the `from` impl. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] +pub struct PathPart<'a> { + pub(super) raw: Cow<'a, str>, +} + +impl<'a> PathPart<'a> { + /// Parse the provided path segment as a [`PathPart`] returning an error if invalid + pub fn parse(segment: &'a str) -> Result { + if segment == "." || segment == ".." 
{ + return Err(InvalidPart { + segment: segment.to_string(), + illegal: segment.to_string(), + }); + } + + for (idx, b) in segment.as_bytes().iter().cloned().enumerate() { + // A percent character is always valid, even if not + // followed by a valid 2-digit hex code + // https://url.spec.whatwg.org/#percent-encoded-bytes + if b == b'%' { + continue; + } + + if !b.is_ascii() || should_percent_encode(b) { + return Err(InvalidPart { + segment: segment.to_string(), + // This is correct as only single byte characters up to this point + illegal: segment.chars().nth(idx).unwrap().to_string(), + }); + } + } + + Ok(Self { + raw: segment.into(), + }) + } +} + +fn should_percent_encode(c: u8) -> bool { + percent_encode(&[c], INVALID).next().unwrap().len() != 1 +} + +/// Characters we want to encode. +const INVALID: &AsciiSet = &CONTROLS + // The delimiter we are reserving for internal hierarchy + .add(DELIMITER_BYTE) + // Characters AWS recommends avoiding for object keys + // https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingMetadata.html + .add(b'\\') + .add(b'{') + .add(b'^') + .add(b'}') + .add(b'%') + .add(b'`') + .add(b']') + .add(b'"') // " <-- my editor is confused about double quotes within single quotes + .add(b'>') + .add(b'[') + .add(b'~') + .add(b'<') + .add(b'#') + .add(b'|') + // Characters Google Cloud Storage recommends avoiding for object names + // https://cloud.google.com/storage/docs/naming-objects + .add(b'\r') + .add(b'\n') + .add(b'*') + .add(b'?'); + +impl<'a> From<&'a [u8]> for PathPart<'a> { + fn from(v: &'a [u8]) -> Self { + let inner = match v { + // We don't want to encode `.` generally, but we do want to disallow parts of paths + // to be equal to `.` or `..` to prevent file system traversal shenanigans. + b"." => "%2E".into(), + b".." => "%2E%2E".into(), + other => percent_encode(other, INVALID).into(), + }; + Self { raw: inner } + } +} + +impl<'a> From<&'a str> for PathPart<'a> { + fn from(v: &'a str) -> Self { + Self::from(v.as_bytes()) + } +} + +impl From for PathPart<'static> { + fn from(s: String) -> Self { + Self { + raw: Cow::Owned(PathPart::from(s.as_str()).raw.into_owned()), + } + } +} + +impl<'a> AsRef for PathPart<'a> { + fn as_ref(&self) -> &str { + self.raw.as_ref() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn path_part_delimiter_gets_encoded() { + let part: PathPart<'_> = "foo/bar".into(); + assert_eq!(part.raw, "foo%2Fbar"); + } + + #[test] + fn path_part_given_already_encoded_string() { + let part: PathPart<'_> = "foo%2Fbar".into(); + assert_eq!(part.raw, "foo%252Fbar"); + } + + #[test] + fn path_part_cant_be_one_dot() { + let part: PathPart<'_> = ".".into(); + assert_eq!(part.raw, "%2E"); + } + + #[test] + fn path_part_cant_be_two_dots() { + let part: PathPart<'_> = "..".into(); + assert_eq!(part.raw, "%2E%2E"); + } + + #[test] + fn path_part_parse() { + PathPart::parse("foo").unwrap(); + PathPart::parse("foo/bar").unwrap_err(); + + // Test percent-encoded path + PathPart::parse("foo%2Fbar").unwrap(); + PathPart::parse("L%3ABC.parquet").unwrap(); + + // Test path containing bad escape sequence + PathPart::parse("%Z").unwrap(); + PathPart::parse("%%").unwrap(); + } +} diff --git a/object_store/src/throttle.rs b/object_store/src/throttle.rs new file mode 100644 index 000000000000..90f427cc2651 --- /dev/null +++ b/object_store/src/throttle.rs @@ -0,0 +1,589 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A throttling object store wrapper +use parking_lot::Mutex; +use std::ops::Range; +use std::{convert::TryInto, sync::Arc}; + +use crate::MultipartId; +use crate::{path::Path, GetResult, ListResult, ObjectMeta, ObjectStore, Result}; +use async_trait::async_trait; +use bytes::Bytes; +use futures::{stream::BoxStream, StreamExt}; +use std::time::Duration; +use tokio::io::AsyncWrite; + +/// Configuration settings for throttled store +#[derive(Debug, Default, Clone, Copy)] +pub struct ThrottleConfig { + /// Sleep duration for every call to [`delete`](ThrottledStore::delete). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. + pub wait_delete_per_call: Duration, + + /// Sleep duration for every byte received during [`get`](ThrottledStore::get). + /// + /// Sleeping is performed after the underlying store returned and only for successful gets. The + /// sleep duration is additive to [`wait_get_per_call`](Self::wait_get_per_call). + /// + /// Note that the per-byte sleep only happens as the user consumes the output bytes. Should + /// there be an intermediate failure (i.e. after partly consuming the output bytes), the + /// resulting sleep time will be partial as well. + pub wait_get_per_byte: Duration, + + /// Sleep duration for every call to [`get`](ThrottledStore::get). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. The sleep duration is additive to + /// [`wait_get_per_byte`](Self::wait_get_per_byte). + pub wait_get_per_call: Duration, + + /// Sleep duration for every call to [`list`](ThrottledStore::list). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. The sleep duration is additive to + /// [`wait_list_per_entry`](Self::wait_list_per_entry). + pub wait_list_per_call: Duration, + + /// Sleep duration for every entry received during [`list`](ThrottledStore::list). + /// + /// Sleeping is performed after the underlying store returned and only for successful lists. + /// The sleep duration is additive to [`wait_list_per_call`](Self::wait_list_per_call). + /// + /// Note that the per-entry sleep only happens as the user consumes the output entries. Should + /// there be an intermediate failure (i.e. after partly consuming the output entries), the + /// resulting sleep time will be partial as well. + pub wait_list_per_entry: Duration, + + /// Sleep duration for every call to + /// [`list_with_delimiter`](ThrottledStore::list_with_delimiter). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. 
The sleep duration is additive to + /// [`wait_list_with_delimiter_per_entry`](Self::wait_list_with_delimiter_per_entry). + pub wait_list_with_delimiter_per_call: Duration, + + /// Sleep duration for every entry received during + /// [`list_with_delimiter`](ThrottledStore::list_with_delimiter). + /// + /// Sleeping is performed after the underlying store returned and only for successful gets. The + /// sleep duration is additive to + /// [`wait_list_with_delimiter_per_call`](Self::wait_list_with_delimiter_per_call). + pub wait_list_with_delimiter_per_entry: Duration, + + /// Sleep duration for every call to [`put`](ThrottledStore::put). + /// + /// Sleeping is done before the underlying store is called and independently of the success of + /// the operation. + pub wait_put_per_call: Duration, +} + +/// Sleep only if non-zero duration +async fn sleep(duration: Duration) { + if !duration.is_zero() { + tokio::time::sleep(duration).await + } +} + +/// Store wrapper that wraps an inner store with some `sleep` calls. +/// +/// This can be used for performance testing. +/// +/// **Note that the behavior of the wrapper is deterministic and might not reflect real-world +/// conditions!** +#[derive(Debug)] +pub struct ThrottledStore { + inner: T, + config: Arc>, +} + +impl ThrottledStore { + /// Create new wrapper with zero waiting times. + pub fn new(inner: T, config: ThrottleConfig) -> Self { + Self { + inner, + config: Arc::new(Mutex::new(config)), + } + } + + /// Mutate config. + pub fn config_mut(&self, f: F) + where + F: Fn(&mut ThrottleConfig), + { + let mut guard = self.config.lock(); + f(&mut guard) + } + + /// Return copy of current config. + pub fn config(&self) -> ThrottleConfig { + *self.config.lock() + } +} + +impl std::fmt::Display for ThrottledStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ThrottledStore({})", self.inner) + } +} + +#[async_trait] +impl ObjectStore for ThrottledStore { + async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.put(location, bytes).await + } + + async fn put_multipart( + &self, + _location: &Path, + ) -> Result<(MultipartId, Box)> { + Err(super::Error::NotImplemented) + } + + async fn abort_multipart( + &self, + _location: &Path, + _multipart_id: &MultipartId, + ) -> Result<()> { + Err(super::Error::NotImplemented) + } + + async fn get(&self, location: &Path) -> Result { + sleep(self.config().wait_get_per_call).await; + + // need to copy to avoid moving / referencing `self` + let wait_get_per_byte = self.config().wait_get_per_byte; + + self.inner.get(location).await.map(|result| { + let s = match result { + GetResult::Stream(s) => s, + GetResult::File(_, _) => unimplemented!(), + }; + + GetResult::Stream( + s.then(move |bytes_result| async move { + match bytes_result { + Ok(bytes) => { + let bytes_len: u32 = usize_to_u32_saturate(bytes.len()); + sleep(wait_get_per_byte * bytes_len).await; + Ok(bytes) + } + Err(err) => Err(err), + } + }) + .boxed(), + ) + }) + } + + async fn get_range(&self, location: &Path, range: Range) -> Result { + let config = self.config(); + + let sleep_duration = config.wait_get_per_call + + config.wait_get_per_byte * (range.end - range.start) as u32; + + sleep(sleep_duration).await; + + self.inner.get_range(location, range).await + } + + async fn get_ranges( + &self, + location: &Path, + ranges: &[Range], + ) -> Result> { + let config = self.config(); + + let total_bytes: usize = ranges.iter().map(|range| 
range.end - range.start).sum(); + let sleep_duration = + config.wait_get_per_call + config.wait_get_per_byte * total_bytes as u32; + + sleep(sleep_duration).await; + + self.inner.get_ranges(location, ranges).await + } + + async fn head(&self, location: &Path) -> Result { + sleep(self.config().wait_put_per_call).await; + self.inner.head(location).await + } + + async fn delete(&self, location: &Path) -> Result<()> { + sleep(self.config().wait_delete_per_call).await; + + self.inner.delete(location).await + } + + async fn list( + &self, + prefix: Option<&Path>, + ) -> Result>> { + sleep(self.config().wait_list_per_call).await; + + // need to copy to avoid moving / referencing `self` + let wait_list_per_entry = self.config().wait_list_per_entry; + + self.inner.list(prefix).await.map(|stream| { + stream + .then(move |result| async move { + match result { + Ok(entry) => { + sleep(wait_list_per_entry).await; + Ok(entry) + } + Err(err) => Err(err), + } + }) + .boxed() + }) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result { + sleep(self.config().wait_list_with_delimiter_per_call).await; + + match self.inner.list_with_delimiter(prefix).await { + Ok(list_result) => { + let entries_len = usize_to_u32_saturate(list_result.objects.len()); + sleep(self.config().wait_list_with_delimiter_per_entry * entries_len) + .await; + Ok(list_result) + } + Err(err) => Err(err), + } + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.copy(from, to).await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.rename(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.copy_if_not_exists(from, to).await + } + + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> { + sleep(self.config().wait_put_per_call).await; + + self.inner.rename_if_not_exists(from, to).await + } +} + +/// Saturated `usize` to `u32` cast. +fn usize_to_u32_saturate(x: usize) -> u32 { + x.try_into().unwrap_or(u32::MAX) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + memory::InMemory, + tests::{ + copy_if_not_exists, list_uses_directories_correctly, list_with_delimiter, + put_get_delete_list, rename_and_copy, + }, + }; + use bytes::Bytes; + use futures::TryStreamExt; + use tokio::time::Duration; + use tokio::time::Instant; + + const WAIT_TIME: Duration = Duration::from_millis(100); + const ZERO: Duration = Duration::from_millis(0); // Duration::default isn't constant + + macro_rules! 
assert_bounds { + ($d:expr, $lower:expr) => { + assert_bounds!($d, $lower, $lower + 2); + }; + ($d:expr, $lower:expr, $upper:expr) => { + let d = $d; + let lower = $lower * WAIT_TIME; + let upper = $upper * WAIT_TIME; + assert!(d >= lower, "{:?} must be >= than {:?}", d, lower); + assert!(d < upper, "{:?} must be < than {:?}", d, upper); + }; + } + + #[tokio::test] + async fn throttle_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + put_get_delete_list(&store).await; + list_uses_directories_correctly(&store).await; + list_with_delimiter(&store).await; + rename_and_copy(&store).await; + copy_if_not_exists(&store).await; + } + + #[tokio::test] + async fn delete_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_delete(&store, None).await, 0); + assert_bounds!(measure_delete(&store, Some(0)).await, 0); + assert_bounds!(measure_delete(&store, Some(10)).await, 0); + + store.config_mut(|cfg| cfg.wait_delete_per_call = WAIT_TIME); + assert_bounds!(measure_delete(&store, None).await, 1); + assert_bounds!(measure_delete(&store, Some(0)).await, 1); + assert_bounds!(measure_delete(&store, Some(10)).await, 1); + } + + #[tokio::test] + // macos github runner is so slow it can't complete within WAIT_TIME*2 + #[cfg(target_os = "linux")] + async fn get_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_get(&store, None).await, 0); + assert_bounds!(measure_get(&store, Some(0)).await, 0); + assert_bounds!(measure_get(&store, Some(10)).await, 0); + + store.config_mut(|cfg| cfg.wait_get_per_call = WAIT_TIME); + assert_bounds!(measure_get(&store, None).await, 1); + assert_bounds!(measure_get(&store, Some(0)).await, 1); + assert_bounds!(measure_get(&store, Some(10)).await, 1); + + store.config_mut(|cfg| { + cfg.wait_get_per_call = ZERO; + cfg.wait_get_per_byte = WAIT_TIME; + }); + assert_bounds!(measure_get(&store, Some(2)).await, 2); + + store.config_mut(|cfg| { + cfg.wait_get_per_call = WAIT_TIME; + cfg.wait_get_per_byte = WAIT_TIME; + }); + assert_bounds!(measure_get(&store, Some(2)).await, 3); + } + + #[tokio::test] + // macos github runner is so slow it can't complete within WAIT_TIME*2 + #[cfg(target_os = "linux")] + async fn list_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_list(&store, 0).await, 0); + assert_bounds!(measure_list(&store, 10).await, 0); + + store.config_mut(|cfg| cfg.wait_list_per_call = WAIT_TIME); + assert_bounds!(measure_list(&store, 0).await, 1); + assert_bounds!(measure_list(&store, 10).await, 1); + + store.config_mut(|cfg| { + cfg.wait_list_per_call = ZERO; + cfg.wait_list_per_entry = WAIT_TIME; + }); + assert_bounds!(measure_list(&store, 2).await, 2); + + store.config_mut(|cfg| { + cfg.wait_list_per_call = WAIT_TIME; + cfg.wait_list_per_entry = WAIT_TIME; + }); + assert_bounds!(measure_list(&store, 2).await, 3); + } + + #[tokio::test] + // macos github runner is so slow it can't complete within WAIT_TIME*2 + #[cfg(target_os = "linux")] + async fn list_with_delimiter_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_list_with_delimiter(&store, 0).await, 0); + assert_bounds!(measure_list_with_delimiter(&store, 10).await, 0); + + store.config_mut(|cfg| 
cfg.wait_list_with_delimiter_per_call = WAIT_TIME); + assert_bounds!(measure_list_with_delimiter(&store, 0).await, 1); + assert_bounds!(measure_list_with_delimiter(&store, 10).await, 1); + + store.config_mut(|cfg| { + cfg.wait_list_with_delimiter_per_call = ZERO; + cfg.wait_list_with_delimiter_per_entry = WAIT_TIME; + }); + assert_bounds!(measure_list_with_delimiter(&store, 2).await, 2); + + store.config_mut(|cfg| { + cfg.wait_list_with_delimiter_per_call = WAIT_TIME; + cfg.wait_list_with_delimiter_per_entry = WAIT_TIME; + }); + assert_bounds!(measure_list_with_delimiter(&store, 2).await, 3); + } + + #[tokio::test] + async fn put_test() { + let inner = InMemory::new(); + let store = ThrottledStore::new(inner, ThrottleConfig::default()); + + assert_bounds!(measure_put(&store, 0).await, 0); + assert_bounds!(measure_put(&store, 10).await, 0); + + store.config_mut(|cfg| cfg.wait_put_per_call = WAIT_TIME); + assert_bounds!(measure_put(&store, 0).await, 1); + assert_bounds!(measure_put(&store, 10).await, 1); + + store.config_mut(|cfg| cfg.wait_put_per_call = ZERO); + assert_bounds!(measure_put(&store, 0).await, 0); + } + + async fn place_test_object( + store: &ThrottledStore, + n_bytes: Option, + ) -> Path { + let path = Path::from("foo"); + + if let Some(n_bytes) = n_bytes { + let data: Vec<_> = std::iter::repeat(1u8).take(n_bytes).collect(); + let bytes = Bytes::from(data); + store.put(&path, bytes).await.unwrap(); + } else { + // ensure object is absent + store.delete(&path).await.unwrap(); + } + + path + } + + #[allow(dead_code)] + async fn place_test_objects( + store: &ThrottledStore, + n_entries: usize, + ) -> Path { + let prefix = Path::from("foo"); + + // clean up store + let entries: Vec<_> = store + .list(Some(&prefix)) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + for entry in entries { + store.delete(&entry.location).await.unwrap(); + } + + // create new entries + for i in 0..n_entries { + let path = prefix.child(i.to_string().as_str()); + + let data = Bytes::from("bar"); + store.put(&path, data).await.unwrap(); + } + + prefix + } + + async fn measure_delete( + store: &ThrottledStore, + n_bytes: Option, + ) -> Duration { + let path = place_test_object(store, n_bytes).await; + + let t0 = Instant::now(); + store.delete(&path).await.unwrap(); + + t0.elapsed() + } + + #[allow(dead_code)] + async fn measure_get( + store: &ThrottledStore, + n_bytes: Option, + ) -> Duration { + let path = place_test_object(store, n_bytes).await; + + let t0 = Instant::now(); + let res = store.get(&path).await; + if n_bytes.is_some() { + // need to consume bytes to provoke sleep times + let s = match res.unwrap() { + GetResult::Stream(s) => s, + GetResult::File(_, _) => unimplemented!(), + }; + + s.map_ok(|b| bytes::BytesMut::from(&b[..])) + .try_concat() + .await + .unwrap(); + } else { + assert!(res.is_err()); + } + + t0.elapsed() + } + + #[allow(dead_code)] + async fn measure_list( + store: &ThrottledStore, + n_entries: usize, + ) -> Duration { + let prefix = place_test_objects(store, n_entries).await; + + let t0 = Instant::now(); + store + .list(Some(&prefix)) + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + + t0.elapsed() + } + + #[allow(dead_code)] + async fn measure_list_with_delimiter( + store: &ThrottledStore, + n_entries: usize, + ) -> Duration { + let prefix = place_test_objects(store, n_entries).await; + + let t0 = Instant::now(); + store.list_with_delimiter(Some(&prefix)).await.unwrap(); + + t0.elapsed() + } + + async fn measure_put(store: &ThrottledStore, 
n_bytes: usize) -> Duration {
+        let data: Vec<_> = std::iter::repeat(1u8).take(n_bytes).collect();
+        let bytes = Bytes::from(data);
+
+        let t0 = Instant::now();
+        store.put(&Path::from("foo"), bytes).await.unwrap();
+
+        t0.elapsed()
+    }
+}
diff --git a/object_store/src/util.rs b/object_store/src/util.rs
new file mode 100644
index 000000000000..2814ca244c39
--- /dev/null
+++ b/object_store/src/util.rs
@@ -0,0 +1,271 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Common logic for interacting with remote object stores
+use super::Result;
+use bytes::Bytes;
+use futures::{stream::StreamExt, Stream, TryStreamExt};
+
+/// Returns the prefix to be passed to an object store
+#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
+pub fn format_prefix(prefix: Option<&crate::path::Path>) -> Option<String> {
+    prefix
+        .filter(|x| !x.as_ref().is_empty())
+        .map(|p| format!("{}{}", p.as_ref(), crate::path::DELIMITER))
+}
+
+/// Returns a formatted HTTP range header as per
+/// <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range>
+#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
+pub fn format_http_range(range: std::ops::Range<usize>) -> String {
+    format!("bytes={}-{}", range.start, range.end.saturating_sub(1))
+}
+
+#[cfg(any(feature = "aws", feature = "azure"))]
+pub(crate) fn hmac_sha256(
+    secret: impl AsRef<[u8]>,
+    bytes: impl AsRef<[u8]>,
+) -> ring::hmac::Tag {
+    let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, secret.as_ref());
+    ring::hmac::sign(&key, bytes.as_ref())
+}
+
+/// Collect a stream into [`Bytes`] avoiding copying in the event of a single chunk
+pub async fn collect_bytes<S>(mut stream: S, size_hint: Option<usize>) -> Result<Bytes>
+where
+    S: Stream<Item = Result<Bytes>> + Send + Unpin,
+{
+    let first = stream.next().await.transpose()?.unwrap_or_default();
+
+    // Avoid copying if single response
+    match stream.next().await.transpose()?
{ + None => Ok(first), + Some(second) => { + let size_hint = size_hint.unwrap_or_else(|| first.len() + second.len()); + + let mut buf = Vec::with_capacity(size_hint); + buf.extend_from_slice(&first); + buf.extend_from_slice(&second); + while let Some(maybe_bytes) = stream.next().await { + buf.extend_from_slice(&maybe_bytes?); + } + + Ok(buf.into()) + } + } +} + +/// Takes a function and spawns it to a tokio blocking pool if available +pub async fn maybe_spawn_blocking(f: F) -> Result +where + F: FnOnce() -> Result + Send + 'static, + T: Send + 'static, +{ + match tokio::runtime::Handle::try_current() { + Ok(runtime) => runtime.spawn_blocking(f).await?, + Err(_) => f(), + } +} + +/// Range requests with a gap less than or equal to this, +/// will be coalesced into a single request by [`coalesce_ranges`] +pub const OBJECT_STORE_COALESCE_DEFAULT: usize = 1024 * 1024; + +/// Up to this number of range requests will be performed in parallel by [`coalesce_ranges`] +pub const OBJECT_STORE_COALESCE_PARALLEL: usize = 10; + +/// Takes a function `fetch` that can fetch a range of bytes and uses this to +/// fetch the provided byte `ranges` +/// +/// To improve performance it will: +/// +/// * Combine ranges less than `coalesce` bytes apart into a single call to `fetch` +/// * Make multiple `fetch` requests in parallel (up to maximum of 10) +/// +pub async fn coalesce_ranges( + ranges: &[std::ops::Range], + fetch: F, + coalesce: usize, +) -> Result> +where + F: Send + FnMut(std::ops::Range) -> Fut, + Fut: std::future::Future> + Send, +{ + let fetch_ranges = merge_ranges(ranges, coalesce); + + let fetched: Vec<_> = futures::stream::iter(fetch_ranges.iter().cloned()) + .map(fetch) + .buffered(OBJECT_STORE_COALESCE_PARALLEL) + .try_collect() + .await?; + + Ok(ranges + .iter() + .map(|range| { + let idx = fetch_ranges.partition_point(|v| v.start <= range.start) - 1; + let fetch_range = &fetch_ranges[idx]; + let fetch_bytes = &fetched[idx]; + + let start = range.start - fetch_range.start; + let end = range.end - fetch_range.start; + fetch_bytes.slice(start..end) + }) + .collect()) +} + +/// Returns a sorted list of ranges that cover `ranges` +fn merge_ranges( + ranges: &[std::ops::Range], + coalesce: usize, +) -> Vec> { + if ranges.is_empty() { + return vec![]; + } + + let mut ranges = ranges.to_vec(); + ranges.sort_unstable_by_key(|range| range.start); + + let mut ret = Vec::with_capacity(ranges.len()); + let mut start_idx = 0; + let mut end_idx = 1; + + while start_idx != ranges.len() { + let mut range_end = ranges[start_idx].end; + + while end_idx != ranges.len() + && ranges[end_idx] + .start + .checked_sub(range_end) + .map(|delta| delta <= coalesce) + .unwrap_or(true) + { + range_end = range_end.max(ranges[end_idx].end); + end_idx += 1; + } + + let start = ranges[start_idx].start; + let end = range_end; + ret.push(start..end); + + start_idx = end_idx; + end_idx += 1; + } + + ret +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{thread_rng, Rng}; + use std::ops::Range; + + /// Calls coalesce_ranges and validates the returned data is correct + /// + /// Returns the fetched ranges + async fn do_fetch(ranges: Vec>, coalesce: usize) -> Vec> { + let max = ranges.iter().map(|x| x.end).max().unwrap_or(0); + let src: Vec<_> = (0..max).map(|x| x as u8).collect(); + + let mut fetches = vec![]; + let coalesced = coalesce_ranges( + &ranges, + |range| { + fetches.push(range.clone()); + futures::future::ready(Ok(Bytes::from(src[range].to_vec()))) + }, + coalesce, + ) + .await + .unwrap(); + + 
assert_eq!(ranges.len(), coalesced.len()); + for (range, bytes) in ranges.iter().zip(coalesced) { + assert_eq!(bytes.as_ref(), &src[range.clone()]); + } + fetches + } + + #[tokio::test] + async fn test_coalesce_ranges() { + let fetches = do_fetch(vec![], 0).await; + assert_eq!(fetches, vec![]); + + let fetches = do_fetch(vec![0..3], 0).await; + assert_eq!(fetches, vec![0..3]); + + let fetches = do_fetch(vec![0..2, 3..5], 0).await; + assert_eq!(fetches, vec![0..2, 3..5]); + + let fetches = do_fetch(vec![0..1, 1..2], 0).await; + assert_eq!(fetches, vec![0..2]); + + let fetches = do_fetch(vec![0..1, 2..72], 1).await; + assert_eq!(fetches, vec![0..72]); + + let fetches = do_fetch(vec![0..1, 56..72, 73..75], 1).await; + assert_eq!(fetches, vec![0..1, 56..75]); + + let fetches = do_fetch(vec![0..1, 5..6, 7..9, 2..3, 4..6], 1).await; + assert_eq!(fetches, vec![0..9]); + + let fetches = do_fetch(vec![0..1, 5..6, 7..9, 2..3, 4..6], 1).await; + assert_eq!(fetches, vec![0..9]); + + let fetches = do_fetch(vec![0..1, 6..7, 8..9, 10..14, 9..10], 4).await; + assert_eq!(fetches, vec![0..1, 6..14]); + } + + #[tokio::test] + async fn test_coalesce_fuzz() { + let mut rand = thread_rng(); + for _ in 0..100 { + let object_len = rand.gen_range(10..250); + let range_count = rand.gen_range(0..10); + let ranges: Vec<_> = (0..range_count) + .map(|_| { + let start = rand.gen_range(0..object_len); + let max_len = 20.min(object_len - start); + let len = rand.gen_range(0..max_len); + start..start + len + }) + .collect(); + + let coalesce = rand.gen_range(1..5); + let fetches = do_fetch(ranges.clone(), coalesce).await; + + for fetch in fetches.windows(2) { + assert!( + fetch[0].start <= fetch[1].start, + "fetches should be sorted, {:?} vs {:?}", + fetch[0], + fetch[1] + ); + + let delta = fetch[1].end - fetch[0].end; + assert!( + delta > coalesce, + "fetches should not overlap by {}, {:?} vs {:?} for {:?}", + coalesce, + fetch[0], + fetch[1], + ranges + ); + } + } + } +} diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index 64819077a744..eb03033c52df 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "parquet" -version = "18.0.0" +version = "22.0.0" license = "Apache-2.0" description = "Apache Parquet implementation in Rust" homepage = "https://github.com/apache/arrow-rs" @@ -27,12 +27,12 @@ keywords = ["arrow", "parquet", "hadoop"] readme = "README.md" build = "build.rs" edition = "2021" -rust-version = "1.57" +rust-version = "1.62" [dependencies] +ahash = "0.8" parquet-format = { version = "4.0.0", default-features = false } bytes = { version = "1.1", default-features = false, features = ["std"] } -byteorder = { version = "1", default-features = false } thrift = { version = "0.13", default-features = false } snap = { version = "1.0", default-features = false, optional = true } brotli = { version = "3.3", default-features = false, features = ["std"], optional = true } @@ -42,25 +42,27 @@ zstd = { version = "0.11.1", optional = true, default-features = false } chrono = { version = "0.4", default-features = false, features = ["alloc"] } num = { version = "0.4", default-features = false } num-bigint = { version = "0.4", default-features = false } -arrow = { path = "../arrow", version = "18.0.0", optional = true, default-features = false, features = ["ipc"] } +arrow = { path = "../arrow", version = "22.0.0", optional = true, default-features = false, features = ["ipc"] } base64 = { version = "0.13", default-features = false, features = ["std"], optional = true } clap = { 
version = "3", default-features = false, features = ["std", "derive", "env"], optional = true } -serde_json = { version = "1.0", default-features = false, optional = true } +serde_json = { version = "1.0", default-features = false, features = ["std"], optional = true } +seq-macro = { version = "0.3", default-features = false } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } -futures = { version = "0.3", default-features = false, features = ["std" ], optional = true } +futures = { version = "0.3", default-features = false, features = ["std"], optional = true } tokio = { version = "1.0", optional = true, default-features = false, features = ["macros", "fs", "rt", "io-util"] } +hashbrown = { version = "0.12", default-features = false } [dev-dependencies] base64 = { version = "0.13", default-features = false, features = ["std"] } criterion = { version = "0.3", default-features = false } snap = { version = "1.0", default-features = false } tempfile = { version = "3.0", default-features = false } -brotli = { version = "3.3", default-features = false, features = [ "std" ] } -flate2 = { version = "1.0", default-features = false, features = [ "rust_backend" ] } +brotli = { version = "3.3", default-features = false, features = ["std"] } +flate2 = { version = "1.0", default-features = false, features = ["rust_backend"] } lz4 = { version = "1.23", default-features = false } zstd = { version = "0.11", default-features = false } -serde_json = { version = "1.0", default-features = false, features = ["preserve_order"] } -arrow = { path = "../arrow", version = "18.0.0", default-features = false, features = ["ipc", "test_utils", "prettyprint"] } +serde_json = { version = "1.0", features = ["std"], default-features = false } +arrow = { path = "../arrow", version = "22.0.0", default-features = false, features = ["ipc", "test_utils", "prettyprint", "json"] } [package.metadata.docs.rs] all-features = true @@ -70,7 +72,9 @@ default = ["arrow", "snap", "brotli", "flate2", "lz4", "zstd", "base64"] # Enable arrow reader/writer APIs arrow = ["dep:arrow", "base64"] # Enable CLI tools -cli = ["serde_json", "base64", "clap","arrow/csv"] +cli = ["json", "base64", "clap", "arrow/csv"] +# Enable JSON APIs +json = ["serde_json"] # Enable internal testing APIs test_common = ["arrow/test_utils"] # Experimental, unstable functionality primarily used for testing diff --git a/parquet/README.md b/parquet/README.md index fbb6e3e1b5d5..689a664b6326 100644 --- a/parquet/README.md +++ b/parquet/README.md @@ -19,17 +19,38 @@ # Apache Parquet Official Native Rust Implementation -[![Crates.io](https://img.shields.io/crates/v/parquet.svg)](https://crates.io/crates/parquet) +[![crates.io](https://img.shields.io/crates/v/parquet.svg)](https://crates.io/crates/parquet) +[![docs.rs](https://img.shields.io/docsrs/parquet.svg)](https://docs.rs/parquet/latest/parquet/) This crate contains the official Native Rust implementation of [Apache Parquet](https://parquet.apache.org/), which is part of the [Apache Arrow](https://arrow.apache.org/) project. See [crate documentation](https://docs.rs/parquet/latest/parquet/) for examples and the full API. -## Rust Version Compatbility +## Rust Version Compatibility This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler. 
-## Features +## Versioning / Releases + +The arrow crate follows the [SemVer standard](https://doc.rust-lang.org/cargo/reference/semver.html) defined by Cargo and works well within the Rust crate ecosystem. + +However, for historical reasons, this crate uses versions with major numbers greater than `0.x` (e.g. `19.0.0`), unlike many other crates in the Rust ecosystem which spend extended time releasing versions `0.x` to signal planned ongoing API changes. Minor arrow releases contain only compatible changes, while major releases may contain breaking API changes. + +## Feature Flags + +The `parquet` crate provides the following features which may be enabled in your `Cargo.toml`: + +- `arrow` (default) - support for reading / writing [`arrow`](https://crates.io/crates/arrow) arrays to / from parquet +- `async` - support `async` APIs for reading parquet +- `json` - support for reading / writing `json` data to / from parquet +- `brotli` (default) - support for parquet using `brotli` compression +- `flate2` (default) - support for parquet using `gzip` compression +- `lz4` (default) - support for parquet using `lz4` compression +- `zstd` (default) - support for parquet using `zstd` compression +- `cli` - parquet [CLI tools](https://github.com/apache/arrow-rs/tree/master/parquet/src/bin) +- `experimental` - Experimental APIs which may change, even between minor releases + +## Parquet Feature Status - [x] All encodings supported - [x] All compression codecs supported diff --git a/parquet/benches/arrow_reader.rs b/parquet/benches/arrow_reader.rs index 647a8dc6f393..d8a7f07fba25 100644 --- a/parquet/benches/arrow_reader.rs +++ b/parquet/benches/arrow_reader.rs @@ -20,7 +20,12 @@ use arrow::datatypes::DataType; use criterion::measurement::WallTime; use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion}; use num::FromPrimitive; +use num_bigint::BigInt; +use parquet::arrow::array_reader::{ + make_byte_array_reader, make_fixed_len_byte_array_reader, +}; use parquet::basic::Type; +use parquet::data_type::FixedLenByteArrayType; use parquet::util::{DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator}; use parquet::{ arrow::array_reader::ArrayReader, @@ -43,6 +48,14 @@ fn build_test_schema() -> SchemaDescPtr { OPTIONAL BYTE_ARRAY optional_string_leaf (UTF8); REQUIRED INT64 mandatory_int64_leaf; OPTIONAL INT64 optional_int64_leaf; + REQUIRED INT32 mandatory_decimal1_leaf (DECIMAL(8,2)); + OPTIONAL INT32 optional_decimal1_leaf (DECIMAL(8,2)); + REQUIRED INT64 mandatory_decimal2_leaf (DECIMAL(16,2)); + OPTIONAL INT64 optional_decimal2_leaf (DECIMAL(16,2)); + REQUIRED BYTE_ARRAY mandatory_decimal3_leaf (DECIMAL(16,2)); + OPTIONAL BYTE_ARRAY optional_decimal3_leaf (DECIMAL(16,2)); + REQUIRED FIXED_LEN_BYTE_ARRAY (16) mandatory_decimal4_leaf (DECIMAL(16,2)); + OPTIONAL FIXED_LEN_BYTE_ARRAY (16) optional_decimal4_leaf (DECIMAL(16,2)); } "; parse_message_type(message_type) @@ -61,11 +74,78 @@ pub fn seedable_rng() -> StdRng { StdRng::seed_from_u64(42) } +// support byte array for decimal +fn build_encoded_decimal_bytes_page_iterator( + schema: SchemaDescPtr, + column_desc: ColumnDescPtr, + null_density: f32, + encoding: Encoding, + min: i128, + max: i128, +) -> impl PageIterator + Clone +where + T: parquet::data_type::DataType, + T::T: From>, +{ + let max_def_level = column_desc.max_def_level(); + let max_rep_level = column_desc.max_rep_level(); + let rep_levels = vec![0; VALUES_PER_PAGE]; + let mut rng = seedable_rng(); + let mut pages: Vec> = Vec::new(); + for _i in 0..NUM_ROW_GROUPS 
{ + let mut column_chunk_pages = Vec::new(); + for _j in 0..PAGES_PER_GROUP { + // generate page + let mut values = Vec::with_capacity(VALUES_PER_PAGE); + let mut def_levels = Vec::with_capacity(VALUES_PER_PAGE); + for _k in 0..VALUES_PER_PAGE { + let def_level = if rng.gen::() < null_density { + max_def_level - 1 + } else { + max_def_level + }; + if def_level == max_def_level { + // create the decimal value + let value = rng.gen_range(min..max); + // decimal of parquet use the big-endian to store + let bytes = match column_desc.physical_type() { + Type::BYTE_ARRAY => { + // byte array use the unfixed size + let big_int = BigInt::from(value); + big_int.to_signed_bytes_be() + } + Type::FIXED_LEN_BYTE_ARRAY => { + assert_eq!(column_desc.type_length(), 16); + // fixed length byte array use the fixed size + // the size is 16 + value.to_be_bytes().to_vec() + } + _ => unimplemented!(), + }; + let value = T::T::from(bytes); + values.push(value); + } + def_levels.push(def_level); + } + let mut page_builder = + DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); + page_builder.add_rep_levels(max_rep_level, &rep_levels); + page_builder.add_def_levels(max_def_level, &def_levels); + page_builder.add_values::(encoding, &values); + column_chunk_pages.push(page_builder.consume()); + } + pages.push(column_chunk_pages); + } + InMemoryPageIterator::new(schema, column_desc, pages) +} + fn build_encoded_primitive_page_iterator( schema: SchemaDescPtr, column_desc: ColumnDescPtr, null_density: f32, encoding: Encoding, + min: usize, + max: usize, ) -> impl PageIterator + Clone where T: parquet::data_type::DataType, @@ -90,7 +170,7 @@ where }; if def_level == max_def_level { let value = - FromPrimitive::from_usize(rng.gen_range(0..1000)).unwrap(); + FromPrimitive::from_usize(rng.gen_range(min..max)).unwrap(); values.push(value); } def_levels.push(def_level); @@ -300,6 +380,27 @@ fn bench_array_reader(mut array_reader: Box) -> usize { total_count } +fn bench_array_reader_skip(mut array_reader: Box) -> usize { + // test procedure: read data in batches of 8192 until no more data + let mut total_count = 0; + let mut skip = false; + let mut array_len; + loop { + if skip { + array_len = array_reader.skip_records(BATCH_SIZE).unwrap(); + } else { + let array = array_reader.next_batch(BATCH_SIZE); + array_len = array.unwrap().len(); + } + total_count += array_len; + skip = !skip; + if array_len < BATCH_SIZE { + break; + } + } + total_count +} + fn create_primitive_array_reader( page_iterator: impl PageIterator + 'static, column_desc: ColumnDescPtr, @@ -307,21 +408,19 @@ fn create_primitive_array_reader( use parquet::arrow::array_reader::PrimitiveArrayReader; match column_desc.physical_type() { Type::INT32 => { - let reader = PrimitiveArrayReader::::new_with_options( + let reader = PrimitiveArrayReader::::new( Box::new(page_iterator), column_desc, None, - true, ) .unwrap(); Box::new(reader) } Type::INT64 => { - let reader = PrimitiveArrayReader::::new_with_options( + let reader = PrimitiveArrayReader::::new( Box::new(page_iterator), column_desc, None, - true, ) .unwrap(); Box::new(reader) @@ -330,12 +429,28 @@ fn create_primitive_array_reader( } } +fn create_decimal_by_bytes_reader( + page_iterator: impl PageIterator + 'static, + column_desc: ColumnDescPtr, +) -> Box { + let physical_type = column_desc.physical_type(); + match physical_type { + Type::BYTE_ARRAY => { + make_byte_array_reader(Box::new(page_iterator), column_desc, None).unwrap() + } + Type::FIXED_LEN_BYTE_ARRAY => { + 
make_fixed_len_byte_array_reader(Box::new(page_iterator), column_desc, None) + .unwrap() + } + _ => unimplemented!(), + } +} + fn create_string_byte_array_reader( page_iterator: impl PageIterator + 'static, column_desc: ColumnDescPtr, ) -> Box { - use parquet::arrow::array_reader::make_byte_array_reader; - make_byte_array_reader(Box::new(page_iterator), column_desc, None, true).unwrap() + make_byte_array_reader(Box::new(page_iterator), column_desc, None).unwrap() } fn create_string_byte_array_dictionary_reader( @@ -350,16 +465,91 @@ fn create_string_byte_array_dictionary_reader( Box::new(page_iterator), column_desc, Some(arrow_type), - true, ) .unwrap() } +fn bench_byte_decimal( + group: &mut BenchmarkGroup, + schema: &SchemaDescPtr, + mandatory_column_desc: &ColumnDescPtr, + optional_column_desc: &ColumnDescPtr, + min: i128, + max: i128, +) where + T: parquet::data_type::DataType, + T::T: From>, +{ + // all are plain encoding + let mut count: usize = 0; + + // plain encoded, no NULLs + let data = build_encoded_decimal_bytes_page_iterator::( + schema.clone(), + mandatory_column_desc.clone(), + 0.0, + Encoding::PLAIN, + min, + max, + ); + group.bench_function("plain encoded, mandatory, no NULLs", |b| { + b.iter(|| { + let array_reader = create_decimal_by_bytes_reader( + data.clone(), + mandatory_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + let data = build_encoded_decimal_bytes_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.0, + Encoding::PLAIN, + min, + max, + ); + group.bench_function("plain encoded, optional, no NULLs", |b| { + b.iter(|| { + let array_reader = create_decimal_by_bytes_reader( + data.clone(), + optional_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + // half null + let data = build_encoded_decimal_bytes_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.5, + Encoding::PLAIN, + min, + max, + ); + group.bench_function("plain encoded, optional, half NULLs", |b| { + b.iter(|| { + let array_reader = create_decimal_by_bytes_reader( + data.clone(), + optional_column_desc.clone(), + ); + count = bench_array_reader(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); +} + fn bench_primitive( group: &mut BenchmarkGroup, schema: &SchemaDescPtr, mandatory_column_desc: &ColumnDescPtr, optional_column_desc: &ColumnDescPtr, + min: usize, + max: usize, ) where T: parquet::data_type::DataType, T::T: SampleUniform + FromPrimitive + Copy, @@ -372,6 +562,8 @@ fn bench_primitive( mandatory_column_desc.clone(), 0.0, Encoding::PLAIN, + min, + max, ); group.bench_function("plain encoded, mandatory, no NULLs", |b| { b.iter(|| { @@ -389,6 +581,8 @@ fn bench_primitive( optional_column_desc.clone(), 0.0, Encoding::PLAIN, + min, + max, ); group.bench_function("plain encoded, optional, no NULLs", |b| { b.iter(|| { @@ -405,6 +599,8 @@ fn bench_primitive( optional_column_desc.clone(), 0.5, Encoding::PLAIN, + min, + max, ); group.bench_function("plain encoded, optional, half NULLs", |b| { b.iter(|| { @@ -421,6 +617,8 @@ fn bench_primitive( mandatory_column_desc.clone(), 0.0, Encoding::DELTA_BINARY_PACKED, + min, + max, ); group.bench_function("binary packed, mandatory, no NULLs", |b| { b.iter(|| { @@ -438,6 +636,8 @@ fn bench_primitive( optional_column_desc.clone(), 0.0, Encoding::DELTA_BINARY_PACKED, + min, + max, ); group.bench_function("binary packed, optional, no NULLs", |b| { 
b.iter(|| { @@ -448,12 +648,51 @@ fn bench_primitive( assert_eq!(count, EXPECTED_VALUE_COUNT); }); + // binary packed skip , no NULLs + let data = build_encoded_primitive_page_iterator::( + schema.clone(), + mandatory_column_desc.clone(), + 0.0, + Encoding::DELTA_BINARY_PACKED, + min, + max, + ); + group.bench_function("binary packed skip, mandatory, no NULLs", |b| { + b.iter(|| { + let array_reader = create_primitive_array_reader( + data.clone(), + mandatory_column_desc.clone(), + ); + count = bench_array_reader_skip(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + + let data = build_encoded_primitive_page_iterator::( + schema.clone(), + optional_column_desc.clone(), + 0.0, + Encoding::DELTA_BINARY_PACKED, + min, + max, + ); + group.bench_function("binary packed skip, optional, no NULLs", |b| { + b.iter(|| { + let array_reader = + create_primitive_array_reader(data.clone(), optional_column_desc.clone()); + count = bench_array_reader_skip(array_reader); + }); + assert_eq!(count, EXPECTED_VALUE_COUNT); + }); + // binary packed, half NULLs let data = build_encoded_primitive_page_iterator::( schema.clone(), optional_column_desc.clone(), 0.5, Encoding::DELTA_BINARY_PACKED, + min, + max, ); group.bench_function("binary packed, optional, half NULLs", |b| { b.iter(|| { @@ -511,6 +750,69 @@ fn bench_primitive( }); } +fn decimal_benches(c: &mut Criterion) { + let schema = build_test_schema(); + // parquet int32, logical type decimal(8,2) + let mandatory_decimal1_leaf_desc = schema.column(6); + let optional_decimal1_leaf_desc = schema.column(7); + let mut group = c.benchmark_group("arrow_array_reader/INT32/Decimal128Array"); + bench_primitive::( + &mut group, + &schema, + &mandatory_decimal1_leaf_desc, + &optional_decimal1_leaf_desc, + // precision is 8: the max is 99999999 + 9999000, + 9999999, + ); + group.finish(); + + // parquet int64, logical type decimal(16,2) + let mut group = c.benchmark_group("arrow_array_reader/INT64/Decimal128Array"); + let mandatory_decimal2_leaf_desc = schema.column(8); + let optional_decimal2_leaf_desc = schema.column(9); + bench_primitive::( + &mut group, + &schema, + &mandatory_decimal2_leaf_desc, + &optional_decimal2_leaf_desc, + // precision is 16: the max is 9999999999999999 + 9999999999999000, + 9999999999999999, + ); + group.finish(); + + // parquet BYTE_ARRAY, logical type decimal(16,2) + let mut group = c.benchmark_group("arrow_array_reader/BYTE_ARRAY/Decimal128Array"); + let mandatory_decimal3_leaf_desc = schema.column(10); + let optional_decimal3_leaf_desc = schema.column(11); + bench_byte_decimal::( + &mut group, + &schema, + &mandatory_decimal3_leaf_desc, + &optional_decimal3_leaf_desc, + // precision is 16: the max is 9999999999999999 + 9999999999999000, + 9999999999999999, + ); + group.finish(); + + let mut group = + c.benchmark_group("arrow_array_reader/FIXED_LENGTH_BYTE_ARRAY/Decimal128Array"); + let mandatory_decimal4_leaf_desc = schema.column(12); + let optional_decimal4_leaf_desc = schema.column(13); + bench_byte_decimal::( + &mut group, + &schema, + &mandatory_decimal4_leaf_desc, + &optional_decimal4_leaf_desc, + // precision is 16: the max is 9999999999999999 + 9999999999999000, + 9999999999999999, + ); + group.finish(); +} + fn add_benches(c: &mut Criterion) { let mut count: usize = 0; @@ -530,6 +832,8 @@ fn add_benches(c: &mut Criterion) { &schema, &mandatory_int32_column_desc, &optional_int32_column_desc, + 0, + 1000, ); group.finish(); @@ -542,6 +846,8 @@ fn add_benches(c: &mut Criterion) { &schema, 
&mandatory_int64_column_desc, &optional_int64_column_desc, + 0, + 1000, ); group.finish(); @@ -693,5 +999,5 @@ fn add_benches(c: &mut Criterion) { group.finish(); } -criterion_group!(benches, add_benches); +criterion_group!(benches, add_benches, decimal_benches,); criterion_main!(benches); diff --git a/parquet/benches/arrow_writer.rs b/parquet/benches/arrow_writer.rs index 25ff1ca90dc6..ddca1e53c6de 100644 --- a/parquet/benches/arrow_writer.rs +++ b/parquet/benches/arrow_writer.rs @@ -92,6 +92,25 @@ fn create_string_bench_batch( )?) } +fn create_string_dictionary_bench_batch( + size: usize, + null_density: f32, + true_density: f32, +) -> Result { + let fields = vec![Field::new( + "_1", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + )]; + let schema = Schema::new(fields); + Ok(create_random_batch( + Arc::new(schema), + size, + null_density, + true_density, + )?) +} + fn create_string_bench_batch_non_null( size: usize, null_density: f32, @@ -346,6 +365,18 @@ fn bench_primitive_writer(c: &mut Criterion) { b.iter(|| write_batch(&batch).unwrap()) }); + let batch = create_string_dictionary_bench_batch(4096, 0.25, 0.75).unwrap(); + group.throughput(Throughput::Bytes( + batch + .columns() + .iter() + .map(|f| f.get_array_memory_size() as u64) + .sum(), + )); + group.bench_function("4096 values string dictionary", |b| { + b.iter(|| write_batch(&batch).unwrap()) + }); + let batch = create_string_bench_batch_non_null(4096, 0.25, 0.75).unwrap(); group.throughput(Throughput::Bytes( batch diff --git a/parquet/src/arrow/array_reader/builder.rs b/parquet/src/arrow/array_reader/builder.rs index e8c22f95aa0a..5f3ce75824ae 100644 --- a/parquet/src/arrow/array_reader/builder.rs +++ b/parquet/src/arrow/array_reader/builder.rs @@ -17,42 +17,34 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, IntervalUnit, SchemaRef}; +use arrow::datatypes::{DataType, SchemaRef}; use crate::arrow::array_reader::empty_array::make_empty_array_reader; +use crate::arrow::array_reader::fixed_len_byte_array::make_fixed_len_byte_array_reader; use crate::arrow::array_reader::{ make_byte_array_dictionary_reader, make_byte_array_reader, ArrayReader, - ComplexObjectArrayReader, ListArrayReader, MapArrayReader, NullArrayReader, - PrimitiveArrayReader, RowGroupCollection, StructArrayReader, -}; -use crate::arrow::buffer::converter::{ - DecimalArrayConverter, DecimalConverter, FixedLenBinaryConverter, - FixedSizeArrayConverter, Int96ArrayConverter, Int96Converter, - IntervalDayTimeArrayConverter, IntervalDayTimeConverter, - IntervalYearMonthArrayConverter, IntervalYearMonthConverter, + ListArrayReader, MapArrayReader, NullArrayReader, PrimitiveArrayReader, + RowGroupCollection, StructArrayReader, }; use crate::arrow::schema::{convert_schema, ParquetField, ParquetFieldType}; use crate::arrow::ProjectionMask; use crate::basic::Type as PhysicalType; use crate::data_type::{ - BoolType, DoubleType, FixedLenByteArrayType, FloatType, Int32Type, Int64Type, - Int96Type, + BoolType, DoubleType, FloatType, Int32Type, Int64Type, Int96Type, }; use crate::errors::Result; -use crate::schema::types::{ColumnDescriptor, ColumnPath, SchemaDescPtr, Type}; +use crate::schema::types::{ColumnDescriptor, ColumnPath, Type}; /// Create array reader from parquet schema, projection mask, and parquet file reader. 
pub fn build_array_reader( - parquet_schema: SchemaDescPtr, arrow_schema: SchemaRef, mask: ProjectionMask, - row_groups: Box, + row_groups: &dyn RowGroupCollection, ) -> Result> { - let field = - convert_schema(parquet_schema.as_ref(), mask, Some(arrow_schema.as_ref()))?; + let field = convert_schema(&row_groups.schema(), mask, Some(arrow_schema.as_ref()))?; match &field { - Some(field) => build_reader(field, row_groups.as_ref()), + Some(field) => build_reader(field, row_groups), None => Ok(make_empty_array_reader(row_groups.num_rows())), } } @@ -90,6 +82,7 @@ fn build_map_reader( field.arrow_type.clone(), field.def_level, field.rep_level, + field.nullable, ))) } @@ -104,13 +97,11 @@ fn build_list_reader( let data_type = field.arrow_type.clone(); let item_reader = build_reader(&children[0], row_groups)?; - let item_type = item_reader.get_data_type().clone(); match is_large { false => Ok(Box::new(ListArrayReader::::new( item_reader, data_type, - item_type, field.def_level, field.rep_level, field.nullable, @@ -118,7 +109,6 @@ fn build_list_reader( true => Ok(Box::new(ListArrayReader::::new( item_reader, data_type, - item_type, field.def_level, field.rep_level, field.nullable, @@ -131,14 +121,12 @@ fn build_primitive_reader( field: &ParquetField, row_groups: &dyn RowGroupCollection, ) -> Result> { - let (col_idx, primitive_type, type_len) = match &field.field_type { + let (col_idx, primitive_type) = match &field.field_type { ParquetFieldType::Primitive { col_idx, primitive_type, } => match primitive_type.as_ref() { - Type::PrimitiveType { type_length, .. } => { - (*col_idx, primitive_type.clone(), *type_length) - } + Type::PrimitiveType { .. } => (*col_idx, primitive_type.clone()), Type::GroupType { .. } => unreachable!(), }, _ => unreachable!(), @@ -160,18 +148,14 @@ fn build_primitive_reader( )); let page_iterator = row_groups.column_chunks(col_idx)?; - let null_mask_only = field.def_level == 1 && field.nullable; let arrow_type = Some(field.arrow_type.clone()); match physical_type { - PhysicalType::BOOLEAN => Ok(Box::new( - PrimitiveArrayReader::::new_with_options( - page_iterator, - column_desc, - arrow_type, - null_mask_only, - )?, - )), + PhysicalType::BOOLEAN => Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)), PhysicalType::INT32 => { if let Some(DataType::Null) = arrow_type { Ok(Box::new(NullArrayReader::::new( @@ -179,130 +163,42 @@ fn build_primitive_reader( column_desc, )?)) } else { - Ok(Box::new( - PrimitiveArrayReader::::new_with_options( - page_iterator, - column_desc, - arrow_type, - null_mask_only, - )?, - )) - } - } - PhysicalType::INT64 => Ok(Box::new( - PrimitiveArrayReader::::new_with_options( - page_iterator, - column_desc, - arrow_type, - null_mask_only, - )?, - )), - PhysicalType::INT96 => { - // get the optional timezone information from arrow type - let timezone = arrow_type.as_ref().and_then(|data_type| { - if let DataType::Timestamp(_, tz) = data_type { - tz.clone() - } else { - None - } - }); - let converter = Int96Converter::new(Int96ArrayConverter { timezone }); - Ok(Box::new(ComplexObjectArrayReader::< - Int96Type, - Int96Converter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - PhysicalType::FLOAT => Ok(Box::new( - PrimitiveArrayReader::::new_with_options( - page_iterator, - column_desc, - arrow_type, - null_mask_only, - )?, - )), - PhysicalType::DOUBLE => Ok(Box::new( - PrimitiveArrayReader::::new_with_options( - page_iterator, - column_desc, - arrow_type, - null_mask_only, - )?, 
- )), - PhysicalType::BYTE_ARRAY => match arrow_type { - Some(DataType::Dictionary(_, _)) => make_byte_array_dictionary_reader( - page_iterator, - column_desc, - arrow_type, - null_mask_only, - ), - _ => make_byte_array_reader( - page_iterator, - column_desc, - arrow_type, - null_mask_only, - ), - }, - PhysicalType::FIXED_LEN_BYTE_ARRAY => match field.arrow_type { - DataType::Decimal(precision, scale) => { - let converter = DecimalConverter::new(DecimalArrayConverter::new( - precision as i32, - scale as i32, - )); - Ok(Box::new(ComplexObjectArrayReader::< - FixedLenByteArrayType, - DecimalConverter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - DataType::Interval(IntervalUnit::DayTime) => { - let converter = - IntervalDayTimeConverter::new(IntervalDayTimeArrayConverter {}); - Ok(Box::new(ComplexObjectArrayReader::< - FixedLenByteArrayType, - _, - >::new( + Ok(Box::new(PrimitiveArrayReader::::new( page_iterator, column_desc, - converter, arrow_type, )?)) } - DataType::Interval(IntervalUnit::YearMonth) => { - let converter = - IntervalYearMonthConverter::new(IntervalYearMonthArrayConverter {}); - Ok(Box::new(ComplexObjectArrayReader::< - FixedLenByteArrayType, - _, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) - } - _ => { - let converter = - FixedLenBinaryConverter::new(FixedSizeArrayConverter::new(type_len)); - Ok(Box::new(ComplexObjectArrayReader::< - FixedLenByteArrayType, - FixedLenBinaryConverter, - >::new( - page_iterator, - column_desc, - converter, - arrow_type, - )?)) + } + PhysicalType::INT64 => Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)), + PhysicalType::INT96 => Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)), + PhysicalType::FLOAT => Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)), + PhysicalType::DOUBLE => Ok(Box::new(PrimitiveArrayReader::::new( + page_iterator, + column_desc, + arrow_type, + )?)), + PhysicalType::BYTE_ARRAY => match arrow_type { + Some(DataType::Dictionary(_, _)) => { + make_byte_array_dictionary_reader(page_iterator, column_desc, arrow_type) } + _ => make_byte_array_reader(page_iterator, column_desc, arrow_type), }, + PhysicalType::FIXED_LEN_BYTE_ARRAY => { + make_fixed_len_byte_array_reader(page_iterator, column_desc, arrow_type) + } } } @@ -330,7 +226,7 @@ mod tests { use super::*; use crate::arrow::parquet_to_arrow_schema; use crate::file::reader::{FileReader, SerializedFileReader}; - use crate::util::test_common::get_test_file; + use crate::util::test_common::file_util::get_test_file; use arrow::datatypes::Field; use std::sync::Arc; @@ -348,13 +244,8 @@ mod tests { ) .unwrap(); - let array_reader = build_array_reader( - file_reader.metadata().file_metadata().schema_descr_ptr(), - Arc::new(arrow_schema), - mask, - Box::new(file_reader), - ) - .unwrap(); + let array_reader = + build_array_reader(Arc::new(arrow_schema), mask, &file_reader).unwrap(); // Create arrow types let arrow_type = DataType::Struct(vec![Field::new( diff --git a/parquet/src/arrow/array_reader/byte_array.rs b/parquet/src/arrow/array_reader/byte_array.rs index 853bc2b18989..4bf4dee0d0b2 100644 --- a/parquet/src/arrow/array_reader/byte_array.rs +++ b/parquet/src/arrow/array_reader/byte_array.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. 
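The rewritten `byte_array.rs` below converts BYTE_ARRAY-backed decimals by sign-extending each big-endian value before calling `i128::from_be_bytes` (the patch imports `sign_extend_be` from `crate::arrow::buffer::bit_util` for this). A minimal standalone sketch of that step, using a hypothetical helper name `sign_extend_be_i128` that is not part of the change:

fn sign_extend_be_i128(b: &[u8]) -> i128 {
    assert!(b.len() <= 16, "decimal value wider than 16 bytes");
    // Propagate the sign bit of the most significant byte, then copy the
    // big-endian payload into the least significant positions.
    let fill = if b.first().map_or(false, |v| v & 0x80 != 0) { 0xFF } else { 0x00 };
    let mut buf = [fill; 16];
    buf[16 - b.len()..].copy_from_slice(b);
    i128::from_be_bytes(buf)
}

// e.g. sign_extend_be_i128(&[0xFF, 0x85]) == -123
//      sign_extend_be_i128(&[0x01, 0x00]) == 256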
-use crate::arrow::array_reader::{read_records, ArrayReader}; +use crate::arrow::array_reader::{read_records, skip_records, ArrayReader}; +use crate::arrow::buffer::bit_util::sign_extend_be; use crate::arrow::buffer::offset_buffer::OffsetBuffer; +use crate::arrow::decoder::{DeltaByteArrayDecoder, DictIndexDecoder}; use crate::arrow::record_reader::buffer::ScalarValue; use crate::arrow::record_reader::GenericRecordReader; use crate::arrow::schema::parquet_to_arrow_field; @@ -24,25 +26,22 @@ use crate::basic::{ConvertedType, Encoding}; use crate::column::page::PageIterator; use crate::column::reader::decoder::ColumnValueDecoder; use crate::data_type::Int32Type; -use crate::encodings::{ - decoding::{Decoder, DeltaBitPackDecoder}, - rle::RleDecoder, -}; +use crate::encodings::decoding::{Decoder, DeltaBitPackDecoder}; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; use crate::util::memory::ByteBufferPtr; -use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, BinaryArray, Decimal128Array, OffsetSizeTrait}; use arrow::buffer::Buffer; use arrow::datatypes::DataType as ArrowType; use std::any::Any; use std::ops::Range; +use std::sync::Arc; /// Returns an [`ArrayReader`] that decodes the provided byte array column pub fn make_byte_array_reader( pages: Box, column_desc: ColumnDescPtr, arrow_type: Option, - null_mask_only: bool, ) -> Result> { // Check if Arrow type is specified, else create it from Parquet type let data_type = match arrow_type { @@ -53,16 +52,14 @@ pub fn make_byte_array_reader( }; match data_type { - ArrowType::Binary | ArrowType::Utf8 => { - let reader = - GenericRecordReader::new_with_options(column_desc, null_mask_only); + ArrowType::Binary | ArrowType::Utf8 | ArrowType::Decimal128(_, _) => { + let reader = GenericRecordReader::new(column_desc); Ok(Box::new(ByteArrayReader::::new( pages, data_type, reader, ))) } ArrowType::LargeUtf8 | ArrowType::LargeBinary => { - let reader = - GenericRecordReader::new_with_options(column_desc, null_mask_only); + let reader = GenericRecordReader::new(column_desc); Ok(Box::new(ByteArrayReader::::new( pages, data_type, reader, ))) @@ -111,31 +108,45 @@ impl ArrayReader for ByteArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { - read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } + + fn consume_batch(&mut self) -> Result { let buffer = self.record_reader.consume_record_data(); let null_buffer = self.record_reader.consume_bitmap_buffer(); self.def_levels_buffer = self.record_reader.consume_def_levels(); self.rep_levels_buffer = self.record_reader.consume_rep_levels(); self.record_reader.reset(); - Ok(buffer.into_array(null_buffer, self.data_type.clone())) + let array = match self.data_type { + ArrowType::Decimal128(p, s) => { + let array = buffer.into_array(null_buffer, ArrowType::Binary); + let binary = array.as_any().downcast_ref::().unwrap(); + let decimal = binary + .iter() + .map(|opt| Some(i128::from_be_bytes(sign_extend_be(opt?)))) + .collect::() + .with_precision_and_scale(p, s)?; + + Arc::new(decimal) + } + _ => buffer.into_array(null_buffer, self.data_type.clone()), + }; + + Ok(array) } fn skip_records(&mut self, num_records: usize) -> Result { - self.record_reader.skip_records(num_records) + skip_records(&mut self.record_reader, self.pages.as_mut(), num_records) } fn 
get_def_levels(&self) -> Option<&[i16]> { - self.def_levels_buffer - .as_ref() - .map(|buf| buf.typed_data()) + self.def_levels_buffer.as_ref().map(|buf| buf.typed_data()) } fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_levels_buffer - .as_ref() - .map(|buf| buf.typed_data()) + self.rep_levels_buffer.as_ref().map(|buf| buf.typed_data()) } } @@ -388,11 +399,8 @@ impl ByteArrayDecoderPlain { Ok(to_read) } - pub fn skip( - &mut self, - to_skip: usize, - ) -> Result { - let to_skip = to_skip.min( self.max_remaining_values); + pub fn skip(&mut self, to_skip: usize) -> Result { + let to_skip = to_skip.min(self.max_remaining_values); let mut skip = 0; let buf = self.buf.as_ref(); @@ -406,6 +414,7 @@ impl ByteArrayDecoderPlain { skip += 1; self.offset = self.offset + 4 + len; } + self.max_remaining_values -= skip; Ok(skip) } } @@ -477,10 +486,7 @@ impl ByteArrayDecoderDeltaLength { Ok(to_read) } - fn skip( - &mut self, - to_skip: usize, - ) -> Result { + fn skip(&mut self, to_skip: usize) -> Result { let remain_values = self.lengths.len() - self.length_offset; let to_skip = remain_values.min(to_skip); @@ -495,45 +501,14 @@ impl ByteArrayDecoderDeltaLength { /// Decoder from [`Encoding::DELTA_BYTE_ARRAY`] to [`OffsetBuffer`] pub struct ByteArrayDecoderDelta { - prefix_lengths: Vec, - suffix_lengths: Vec, - data: ByteBufferPtr, - length_offset: usize, - data_offset: usize, - last_value: Vec, + decoder: DeltaByteArrayDecoder, validate_utf8: bool, } impl ByteArrayDecoderDelta { fn new(data: ByteBufferPtr, validate_utf8: bool) -> Result { - let mut prefix = DeltaBitPackDecoder::::new(); - prefix.set_data(data.all(), 0)?; - - let num_prefix = prefix.values_left(); - let mut prefix_lengths = vec![0; num_prefix]; - assert_eq!(prefix.get(&mut prefix_lengths)?, num_prefix); - - let mut suffix = DeltaBitPackDecoder::::new(); - suffix.set_data(data.start_from(prefix.get_offset()), 0)?; - - let num_suffix = suffix.values_left(); - let mut suffix_lengths = vec![0; num_suffix]; - assert_eq!(suffix.get(&mut suffix_lengths)?, num_suffix); - - if num_prefix != num_suffix { - return Err(general_err!(format!( - "inconsistent DELTA_BYTE_ARRAY lengths, prefixes: {}, suffixes: {}", - num_prefix, num_suffix - ))); - } - Ok(Self { - prefix_lengths, - suffix_lengths, - data, - length_offset: 0, - data_offset: prefix.get_offset() + suffix.get_offset(), - last_value: vec![], + decoder: DeltaByteArrayDecoder::new(data)?, validate_utf8, }) } @@ -544,101 +519,32 @@ impl ByteArrayDecoderDelta { len: usize, ) -> Result { let initial_values_length = output.values.len(); - assert_eq!(self.prefix_lengths.len(), self.suffix_lengths.len()); - - let to_read = len.min(self.prefix_lengths.len() - self.length_offset); - - output.offsets.reserve(to_read); - - let length_range = self.length_offset..self.length_offset + to_read; - let iter = self.prefix_lengths[length_range.clone()] - .iter() - .zip(&self.suffix_lengths[length_range]); - - let data = self.data.as_ref(); - - for (prefix_length, suffix_length) in iter { - let prefix_length = *prefix_length as usize; - let suffix_length = *suffix_length as usize; - - if self.data_offset + suffix_length > self.data.len() { - return Err(ParquetError::EOF("eof decoding byte array".into())); - } + output.offsets.reserve(len.min(self.decoder.remaining())); - self.last_value.truncate(prefix_length); - self.last_value.extend_from_slice( - &data[self.data_offset..self.data_offset + suffix_length], - ); - output.try_push(&self.last_value, self.validate_utf8)?; - - self.data_offset += 
suffix_length; - } - - self.length_offset += to_read; + let read = self + .decoder + .read(len, |bytes| output.try_push(bytes, self.validate_utf8))?; if self.validate_utf8 { output.check_valid_utf8(initial_values_length)?; } - Ok(to_read) + Ok(read) } - fn skip( - &mut self, - to_skip: usize, - ) -> Result { - let to_skip = to_skip.min(self.prefix_lengths.len() - self.length_offset); - - let length_range = self.length_offset..self.length_offset + to_skip; - let iter = self.prefix_lengths[length_range.clone()] - .iter() - .zip(&self.suffix_lengths[length_range]); - - let data = self.data.as_ref(); - - for (prefix_length, suffix_length) in iter { - let prefix_length = *prefix_length as usize; - let suffix_length = *suffix_length as usize; - - if self.data_offset + suffix_length > self.data.len() { - return Err(ParquetError::EOF("eof decoding byte array".into())); - } - - self.last_value.truncate(prefix_length); - self.last_value.extend_from_slice( - &data[self.data_offset..self.data_offset + suffix_length], - ); - self.data_offset += suffix_length; - } - self.length_offset += to_skip; - Ok(to_skip) + fn skip(&mut self, to_skip: usize) -> Result { + self.decoder.skip(to_skip) } } /// Decoder from [`Encoding::RLE_DICTIONARY`] to [`OffsetBuffer`] pub struct ByteArrayDecoderDictionary { - decoder: RleDecoder, - - index_buf: Box<[i32; 1024]>, - index_buf_len: usize, - index_offset: usize, - - /// This is a maximum as the null count is not always known, e.g. value data from - /// a v1 data page - max_remaining_values: usize, + decoder: DictIndexDecoder, } impl ByteArrayDecoderDictionary { fn new(data: ByteBufferPtr, num_levels: usize, num_values: Option) -> Self { - let bit_width = data[0]; - let mut decoder = RleDecoder::new(bit_width); - decoder.set_data(data.start_from(1)); - Self { - decoder, - index_buf: Box::new([0; 1024]), - index_buf_len: 0, - index_offset: 0, - max_remaining_values: num_values.unwrap_or(num_levels), + decoder: DictIndexDecoder::new(data, num_levels, num_values), } } @@ -648,37 +554,18 @@ impl ByteArrayDecoderDictionary { dict: &OffsetBuffer, len: usize, ) -> Result { + // All data must be NULL if dict.is_empty() { - return Ok(0); // All data must be NULL + return Ok(0); } - let mut values_read = 0; - - while values_read != len && self.max_remaining_values != 0 { - if self.index_offset == self.index_buf_len { - let read = self.decoder.get_batch(self.index_buf.as_mut())?; - if read == 0 { - break; - } - self.index_buf_len = read; - self.index_offset = 0; - } - - let to_read = (len - values_read) - .min(self.index_buf_len - self.index_offset) - .min(self.max_remaining_values); - + self.decoder.read(len, |keys| { output.extend_from_dictionary( - &self.index_buf[self.index_offset..self.index_offset + to_read], + keys, dict.offsets.as_slice(), dict.values.as_slice(), - )?; - - self.index_offset += to_read; - values_read += to_read; - self.max_remaining_values -= to_read; - } - Ok(values_read) + ) + }) } fn skip( @@ -686,31 +573,12 @@ impl ByteArrayDecoderDictionary { dict: &OffsetBuffer, to_skip: usize, ) -> Result { - let to_skip = to_skip.min(self.max_remaining_values); // All data must be NULL if dict.is_empty() { return Ok(0); } - let mut values_skip = 0; - while values_skip < to_skip { - if self.index_offset == self.index_buf_len { - let read = self.decoder.get_batch(self.index_buf.as_mut())?; - if read == 0 { - break; - } - self.index_buf_len = read; - self.index_offset = 0; - } - - let skip = (to_skip - values_skip) - .min(self.index_buf_len - self.index_offset); - - 
self.index_offset += skip; - self.max_remaining_values -= skip; - values_skip += skip; - } - Ok(values_skip) + self.decoder.skip(to_skip) } } @@ -815,14 +683,7 @@ mod tests { assert_eq!( strings.iter().collect::>(), - vec![ - None, - None, - Some("hello"), - Some("b"), - None, - None, - ] + vec![None, None, Some("hello"), Some("b"), None, None,] ); } } diff --git a/parquet/src/arrow/array_reader/byte_array_dictionary.rs b/parquet/src/arrow/array_reader/byte_array_dictionary.rs index bfe557499914..0a5d94fa6ae8 100644 --- a/parquet/src/arrow/array_reader/byte_array_dictionary.rs +++ b/parquet/src/arrow/array_reader/byte_array_dictionary.rs @@ -25,7 +25,7 @@ use arrow::buffer::Buffer; use arrow::datatypes::{ArrowNativeType, DataType as ArrowType}; use crate::arrow::array_reader::byte_array::{ByteArrayDecoder, ByteArrayDecoderPlain}; -use crate::arrow::array_reader::{read_records, ArrayReader}; +use crate::arrow::array_reader::{read_records, skip_records, ArrayReader}; use crate::arrow::buffer::{ dictionary_buffer::DictionaryBuffer, offset_buffer::OffsetBuffer, }; @@ -44,17 +44,14 @@ use crate::util::memory::ByteBufferPtr; /// A macro to reduce verbosity of [`make_byte_array_dictionary_reader`] macro_rules! make_reader { ( - ($pages:expr, $column_desc:expr, $data_type:expr, $null_mask_only:expr) => match ($k:expr, $v:expr) { + ($pages:expr, $column_desc:expr, $data_type:expr) => match ($k:expr, $v:expr) { $(($key_arrow:pat, $value_arrow:pat) => ($key_type:ty, $value_type:ty),)+ } ) => { match (($k, $v)) { $( ($key_arrow, $value_arrow) => { - let reader = GenericRecordReader::new_with_options( - $column_desc, - $null_mask_only, - ); + let reader = GenericRecordReader::new($column_desc); Ok(Box::new(ByteArrayDictionaryReader::<$key_type, $value_type>::new( $pages, $data_type, reader, ))) @@ -84,7 +81,6 @@ pub fn make_byte_array_dictionary_reader( pages: Box, column_desc: ColumnDescPtr, arrow_type: Option, - null_mask_only: bool, ) -> Result> { // Check if Arrow type is specified, else create it from Parquet type let data_type = match arrow_type { @@ -97,7 +93,7 @@ pub fn make_byte_array_dictionary_reader( match &data_type { ArrowType::Dictionary(key_type, value_type) => { make_reader! 
{ - (pages, column_desc, data_type, null_mask_only) => match (key_type.as_ref(), value_type.as_ref()) { + (pages, column_desc, data_type) => match (key_type.as_ref(), value_type.as_ref()) { (ArrowType::UInt8, ArrowType::Binary | ArrowType::Utf8) => (u8, i32), (ArrowType::UInt8, ArrowType::LargeBinary | ArrowType::LargeUtf8) => (u8, i64), (ArrowType::Int8, ArrowType::Binary | ArrowType::Utf8) => (i8, i32), @@ -171,8 +167,11 @@ where &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { - read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } + + fn consume_batch(&mut self) -> Result { let buffer = self.record_reader.consume_record_data(); let null_buffer = self.record_reader.consume_bitmap_buffer(); let array = buffer.into_array(null_buffer, &self.data_type)?; @@ -185,7 +184,7 @@ where } fn skip_records(&mut self, num_records: usize) -> Result { - self.record_reader.skip_records(num_records) + skip_records(&mut self.record_reader, self.pages.as_mut(), num_records) } fn get_def_levels(&self) -> Option<&[i16]> { @@ -346,6 +345,7 @@ where // Keys will be validated on conversion to arrow let keys_slice = keys.spare_capacity_mut(range.start + len); let len = decoder.get_batch(&mut keys_slice[range.start..])?; + *max_remaining_values -= len; Ok(len) } None => { @@ -368,7 +368,7 @@ where dict_offsets, dict_values, )?; - + *max_remaining_values -= len; Ok(len) } } @@ -376,8 +376,20 @@ where } } - fn skip_values(&mut self, _num_values: usize) -> Result { - Err(nyi_err!("https://github.com/apache/arrow-rs/issues/1792")) + fn skip_values(&mut self, num_values: usize) -> Result { + match self.decoder.as_mut().expect("decoder set") { + MaybeDictionaryDecoder::Fallback(decoder) => { + decoder.skip::(num_values, None) + } + MaybeDictionaryDecoder::Dict { + decoder, + max_remaining_values, + } => { + let num_values = num_values.min(*max_remaining_values); + *max_remaining_values -= num_values; + decoder.skip(num_values) + } + } } } @@ -464,6 +476,68 @@ mod tests { ) } + #[test] + fn test_dictionary_preservation_skip() { + let data_type = utf8_dictionary(); + + let data: Vec<_> = vec!["0", "1", "0", "1", "2", "1", "2"] + .into_iter() + .map(ByteArray::from) + .collect(); + let (dict, encoded) = encode_dictionary(&data); + + let column_desc = utf8_column(); + let mut decoder = DictionaryDecoder::::new(&column_desc); + + decoder + .set_dict(dict, 3, Encoding::RLE_DICTIONARY, false) + .unwrap(); + + decoder + .set_data(Encoding::RLE_DICTIONARY, encoded, 7, Some(data.len())) + .unwrap(); + + let mut output = DictionaryBuffer::::default(); + + // read two skip one + assert_eq!(decoder.read(&mut output, 0..2).unwrap(), 2); + assert_eq!(decoder.skip_values(1).unwrap(), 1); + + assert!(matches!(output, DictionaryBuffer::Dict { .. })); + + // read two skip one + assert_eq!(decoder.read(&mut output, 2..4).unwrap(), 2); + assert_eq!(decoder.skip_values(1).unwrap(), 1); + + // read one and test on skip at the end + assert_eq!(decoder.read(&mut output, 4..5).unwrap(), 1); + assert_eq!(decoder.skip_values(4).unwrap(), 0); + + let valid = vec![true, true, true, true, true]; + let valid_buffer = Buffer::from_iter(valid.iter().cloned()); + output.pad_nulls(0, 5, 5, valid_buffer.as_slice()); + + assert!(matches!(output, DictionaryBuffer::Dict { .. 
})); + + let array = output.into_array(Some(valid_buffer), &data_type).unwrap(); + assert_eq!(array.data_type(), &data_type); + + let array = cast(&array, &ArrowType::Utf8).unwrap(); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!(strings.len(), 5); + + assert_eq!( + strings.iter().collect::>(), + vec![ + Some("0"), + Some("1"), + Some("1"), + Some("2"), + Some("2"), + ] + ) + } + #[test] fn test_dictionary_fallback() { let data_type = utf8_dictionary(); @@ -507,6 +581,51 @@ mod tests { } } + #[test] + fn test_dictionary_skip_fallback() { + let data_type = utf8_dictionary(); + let data = vec!["hello", "world", "a", "b"]; + + let (pages, encoded_dictionary) = byte_array_all_encodings(data.clone()); + let num_encodings = pages.len(); + + let column_desc = utf8_column(); + let mut decoder = DictionaryDecoder::::new(&column_desc); + + decoder + .set_dict(encoded_dictionary, 4, Encoding::RLE_DICTIONARY, false) + .unwrap(); + + // Read all pages into single buffer + let mut output = DictionaryBuffer::::default(); + + for (encoding, page) in pages { + decoder.set_data(encoding, page, 4, Some(4)).unwrap(); + decoder.skip_values(2).expect("skipping two values"); + assert_eq!(decoder.read(&mut output, 0..1024).unwrap(), 2); + } + let array = output.into_array(None, &data_type).unwrap(); + assert_eq!(array.data_type(), &data_type); + + let array = cast(&array, &ArrowType::Utf8).unwrap(); + let strings = array.as_any().downcast_ref::().unwrap(); + assert_eq!(strings.len(), (data.len() - 2) * num_encodings); + + // Should have a copy of `data` for each encoding + for i in 0..num_encodings { + assert_eq!( + &strings + .iter() + .skip(i * (data.len() - 2)) + .take(data.len() - 2) + .map(|x| x.unwrap()) + .collect::>(), + &data[2..] + ) + } + } + + #[test] fn test_too_large_dictionary() { let data: Vec<_> = (0..128) @@ -542,7 +661,7 @@ mod tests { .set_dict(encoded_dictionary, 4, Encoding::PLAIN_DICTIONARY, false) .unwrap(); - for (encoding, page) in pages { + for (encoding, page) in pages.clone() { let mut output = DictionaryBuffer::::default(); decoder.set_data(encoding, page, 8, None).unwrap(); assert_eq!(decoder.read(&mut output, 0..1024).unwrap(), 0); @@ -555,5 +674,19 @@ mod tests { assert_eq!(array.len(), 8); assert_eq!(array.null_count(), 8); } + + for (encoding, page) in pages { + let mut output = DictionaryBuffer::::default(); + decoder.set_data(encoding, page, 8, None).unwrap(); + assert_eq!(decoder.skip_values(1024).unwrap(), 0); + + output.pad_nulls(0, 0, 8, &[0]); + let array = output + .into_array(Some(Buffer::from(&[0])), &data_type) + .unwrap(); + + assert_eq!(array.len(), 8); + assert_eq!(array.null_count(), 8); + } } } diff --git a/parquet/src/arrow/array_reader/complex_object_array.rs b/parquet/src/arrow/array_reader/complex_object_array.rs deleted file mode 100644 index 6e7585ff944c..000000000000 --- a/parquet/src/arrow/array_reader/complex_object_array.rs +++ /dev/null @@ -1,539 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::arrow::array_reader::ArrayReader; -use crate::arrow::buffer::converter::Converter; -use crate::arrow::schema::parquet_to_arrow_field; -use crate::column::page::PageIterator; -use crate::column::reader::ColumnReaderImpl; -use crate::data_type::DataType; -use crate::errors::Result; -use crate::schema::types::ColumnDescPtr; -use arrow::array::ArrayRef; -use arrow::datatypes::DataType as ArrowType; -use std::any::Any; -use std::marker::PhantomData; - -/// Primitive array readers are leaves of array reader tree. They accept page iterator -/// and read them into primitive arrays. -pub struct ComplexObjectArrayReader -where - T: DataType, - C: Converter>, ArrayRef> + 'static, -{ - data_type: ArrowType, - pages: Box, - def_levels_buffer: Option>, - rep_levels_buffer: Option>, - column_desc: ColumnDescPtr, - column_reader: Option>, - converter: C, - _parquet_type_marker: PhantomData, - _converter_marker: PhantomData, -} - -impl ArrayReader for ComplexObjectArrayReader -where - T: DataType, - C: Converter>, ArrayRef> + Send + 'static, -{ - fn as_any(&self) -> &dyn Any { - self - } - - fn get_data_type(&self) -> &ArrowType { - &self.data_type - } - - fn next_batch(&mut self, batch_size: usize) -> Result { - // Try to initialize column reader - if self.column_reader.is_none() { - self.next_column_reader()?; - } - - let mut data_buffer: Vec = Vec::with_capacity(batch_size); - data_buffer.resize_with(batch_size, T::T::default); - - let mut def_levels_buffer = if self.column_desc.max_def_level() > 0 { - let mut buf: Vec = Vec::with_capacity(batch_size); - buf.resize_with(batch_size, || 0); - Some(buf) - } else { - None - }; - - let mut rep_levels_buffer = if self.column_desc.max_rep_level() > 0 { - let mut buf: Vec = Vec::with_capacity(batch_size); - buf.resize_with(batch_size, || 0); - Some(buf) - } else { - None - }; - - let mut num_read = 0; - - while self.column_reader.is_some() && num_read < batch_size { - let num_to_read = batch_size - num_read; - let cur_data_buf = &mut data_buffer[num_read..]; - let cur_def_levels_buf = - def_levels_buffer.as_mut().map(|b| &mut b[num_read..]); - let cur_rep_levels_buf = - rep_levels_buffer.as_mut().map(|b| &mut b[num_read..]); - let (data_read, levels_read) = - self.column_reader.as_mut().unwrap().read_batch( - num_to_read, - cur_def_levels_buf, - cur_rep_levels_buf, - cur_data_buf, - )?; - - // Fill space - if levels_read > data_read { - def_levels_buffer.iter().for_each(|def_levels_buffer| { - let (mut level_pos, mut data_pos) = (levels_read, data_read); - while level_pos > 0 && data_pos > 0 { - if def_levels_buffer[num_read + level_pos - 1] - == self.column_desc.max_def_level() - { - cur_data_buf.swap(level_pos - 1, data_pos - 1); - level_pos -= 1; - data_pos -= 1; - } else { - level_pos -= 1; - } - } - }); - } - - let values_read = levels_read.max(data_read); - num_read += values_read; - // current page exhausted && page iterator exhausted - if values_read < num_to_read && !self.next_column_reader()? 
{ - break; - } - } - - data_buffer.truncate(num_read); - def_levels_buffer - .iter_mut() - .for_each(|buf| buf.truncate(num_read)); - rep_levels_buffer - .iter_mut() - .for_each(|buf| buf.truncate(num_read)); - - self.def_levels_buffer = def_levels_buffer; - self.rep_levels_buffer = rep_levels_buffer; - - let data: Vec> = if self.def_levels_buffer.is_some() { - data_buffer - .into_iter() - .zip(self.def_levels_buffer.as_ref().unwrap().iter()) - .map(|(t, def_level)| { - if *def_level == self.column_desc.max_def_level() { - Some(t) - } else { - None - } - }) - .collect() - } else { - data_buffer.into_iter().map(Some).collect() - }; - - let mut array = self.converter.convert(data)?; - - if let ArrowType::Dictionary(_, _) = self.data_type { - array = arrow::compute::cast(&array, &self.data_type)?; - } - - Ok(array) - } - - fn skip_records(&mut self, num_records: usize) -> Result { - match self.column_reader.as_mut() { - Some(reader) => reader.skip_records(num_records), - None => Ok(0), - } - } - - fn get_def_levels(&self) -> Option<&[i16]> { - self.def_levels_buffer.as_deref() - } - - fn get_rep_levels(&self) -> Option<&[i16]> { - self.rep_levels_buffer.as_deref() - } -} - -impl ComplexObjectArrayReader -where - T: DataType, - C: Converter>, ArrayRef> + 'static, -{ - pub fn new( - pages: Box, - column_desc: ColumnDescPtr, - converter: C, - arrow_type: Option, - ) -> Result { - let data_type = match arrow_type { - Some(t) => t, - None => parquet_to_arrow_field(column_desc.as_ref())? - .data_type() - .clone(), - }; - - Ok(Self { - data_type, - pages, - def_levels_buffer: None, - rep_levels_buffer: None, - column_desc, - column_reader: None, - converter, - _parquet_type_marker: PhantomData, - _converter_marker: PhantomData, - }) - } - - fn next_column_reader(&mut self) -> Result { - Ok(match self.pages.next() { - Some(page) => { - self.column_reader = - Some(ColumnReaderImpl::::new(self.column_desc.clone(), page?)); - true - } - None => false, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::arrow::buffer::converter::{Utf8ArrayConverter, Utf8Converter}; - use crate::basic::Encoding; - use crate::column::page::Page; - use crate::data_type::{ByteArray, ByteArrayType}; - use crate::schema::parser::parse_message_type; - use crate::schema::types::SchemaDescriptor; - use crate::util::{DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator}; - use arrow::array::StringArray; - use rand::{thread_rng, Rng}; - use std::sync::Arc; - - #[test] - fn test_complex_array_reader_no_pages() { - let message_type = " - message test_schema { - REPEATED Group test_mid { - OPTIONAL BYTE_ARRAY leaf (UTF8); - } - } - "; - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - let column_desc = schema.column(0); - let pages: Vec> = Vec::new(); - let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); - - let converter = Utf8Converter::new(Utf8ArrayConverter {}); - let mut array_reader = - ComplexObjectArrayReader::::new( - Box::new(page_iterator), - column_desc, - converter, - None, - ) - .unwrap(); - - let values_per_page = 100; // this value is arbitrary in this test - the result should always be an array of 0 length - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), 0); - } - - #[test] - fn test_complex_array_reader_def_and_rep_levels() { - // Construct column schema - let message_type = " - message test_schema { - REPEATED Group test_mid { - OPTIONAL BYTE_ARRAY leaf 
(UTF8); - } - } - "; - let num_pages = 2; - let values_per_page = 100; - let str_base = "Hello World"; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - - let max_def_level = schema.column(0).max_def_level(); - let max_rep_level = schema.column(0).max_rep_level(); - - assert_eq!(max_def_level, 2); - assert_eq!(max_rep_level, 1); - - let mut rng = thread_rng(); - let column_desc = schema.column(0); - let mut pages: Vec> = Vec::new(); - - let mut rep_levels = Vec::with_capacity(num_pages * values_per_page); - let mut def_levels = Vec::with_capacity(num_pages * values_per_page); - let mut all_values = Vec::with_capacity(num_pages * values_per_page); - - for i in 0..num_pages { - let mut values = Vec::with_capacity(values_per_page); - - for _ in 0..values_per_page { - let def_level = rng.gen_range(0..max_def_level + 1); - let rep_level = rng.gen_range(0..max_rep_level + 1); - if def_level == max_def_level { - let len = rng.gen_range(1..str_base.len()); - let slice = &str_base[..len]; - values.push(ByteArray::from(slice)); - all_values.push(Some(slice.to_string())); - } else { - all_values.push(None) - } - rep_levels.push(rep_level); - def_levels.push(def_level) - } - - let range = i * values_per_page..(i + 1) * values_per_page; - let mut pb = - DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); - - pb.add_rep_levels(max_rep_level, &rep_levels.as_slice()[range.clone()]); - pb.add_def_levels(max_def_level, &def_levels.as_slice()[range]); - pb.add_values::(Encoding::PLAIN, values.as_slice()); - - let data_page = pb.consume(); - pages.push(vec![data_page]); - } - - let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); - - let converter = Utf8Converter::new(Utf8ArrayConverter {}); - let mut array_reader = - ComplexObjectArrayReader::::new( - Box::new(page_iterator), - column_desc, - converter, - None, - ) - .unwrap(); - - let mut accu_len: usize = 0; - - let array = array_reader.next_batch(values_per_page / 2).unwrap(); - assert_eq!(array.len(), values_per_page / 2); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - accu_len += array.len(); - - // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, - // and the last values_per_page/2 ones are from the second column chunk - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - let strings = array.as_any().downcast_ref::().unwrap(); - for i in 0..array.len() { - if array.is_valid(i) { - assert_eq!( - all_values[i + accu_len].as_ref().unwrap().as_str(), - strings.value(i) - ) - } else { - assert_eq!(all_values[i + accu_len], None) - } - } - accu_len += array.len(); - - // Try to read values_per_page values, however there are only values_per_page/2 values - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page / 2); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - 
array_reader.get_rep_levels() - ); - } - - #[test] - fn test_complex_array_reader_dict_enc_string() { - use crate::encodings::encoding::{DictEncoder, Encoder}; - // Construct column schema - let message_type = " - message test_schema { - REPEATED Group test_mid { - OPTIONAL BYTE_ARRAY leaf (UTF8); - } - } - "; - let num_pages = 2; - let values_per_page = 100; - let str_base = "Hello World"; - - let schema = parse_message_type(message_type) - .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) - .unwrap(); - let column_desc = schema.column(0); - let max_def_level = column_desc.max_def_level(); - let max_rep_level = column_desc.max_rep_level(); - - assert_eq!(max_def_level, 2); - assert_eq!(max_rep_level, 1); - - let mut rng = thread_rng(); - let mut pages: Vec> = Vec::new(); - - let mut rep_levels = Vec::with_capacity(num_pages * values_per_page); - let mut def_levels = Vec::with_capacity(num_pages * values_per_page); - let mut all_values = Vec::with_capacity(num_pages * values_per_page); - - for i in 0..num_pages { - let mut dict_encoder = DictEncoder::::new(column_desc.clone()); - // add data page - let mut values = Vec::with_capacity(values_per_page); - - for _ in 0..values_per_page { - let def_level = rng.gen_range(0..max_def_level + 1); - let rep_level = rng.gen_range(0..max_rep_level + 1); - if def_level == max_def_level { - let len = rng.gen_range(1..str_base.len()); - let slice = &str_base[..len]; - values.push(ByteArray::from(slice)); - all_values.push(Some(slice.to_string())); - } else { - all_values.push(None) - } - rep_levels.push(rep_level); - def_levels.push(def_level) - } - - let range = i * values_per_page..(i + 1) * values_per_page; - let mut pb = - DataPageBuilderImpl::new(column_desc.clone(), values.len() as u32, true); - pb.add_rep_levels(max_rep_level, &rep_levels.as_slice()[range.clone()]); - pb.add_def_levels(max_def_level, &def_levels.as_slice()[range]); - let _ = dict_encoder.put(&values); - let indices = dict_encoder - .write_indices() - .expect("write_indices() should be OK"); - pb.add_indices(indices); - let data_page = pb.consume(); - // for each page log num_values vs actual values in page - // println!("page num_values: {}, values.len(): {}", data_page.num_values(), values.len()); - // add dictionary page - let dict = dict_encoder - .write_dict() - .expect("write_dict() should be OK"); - let dict_page = Page::DictionaryPage { - buf: dict, - num_values: dict_encoder.num_entries() as u32, - encoding: Encoding::RLE_DICTIONARY, - is_sorted: false, - }; - pages.push(vec![dict_page, data_page]); - } - - let page_iterator = InMemoryPageIterator::new(schema, column_desc.clone(), pages); - let converter = Utf8Converter::new(Utf8ArrayConverter {}); - let mut array_reader = - ComplexObjectArrayReader::::new( - Box::new(page_iterator), - column_desc, - converter, - None, - ) - .unwrap(); - - let mut accu_len: usize = 0; - - // println!("---------- reading a batch of {} values ----------", values_per_page / 2); - let array = array_reader.next_batch(values_per_page / 2).unwrap(); - assert_eq!(array.len(), values_per_page / 2); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - accu_len += array.len(); - - // Read next values_per_page values, the first values_per_page/2 ones are from the first column chunk, - // and the last values_per_page/2 ones are from the second column chunk - // println!("---------- 
reading a batch of {} values ----------", values_per_page); - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - let strings = array.as_any().downcast_ref::().unwrap(); - for i in 0..array.len() { - if array.is_valid(i) { - assert_eq!( - all_values[i + accu_len].as_ref().unwrap().as_str(), - strings.value(i) - ) - } else { - assert_eq!(all_values[i + accu_len], None) - } - } - accu_len += array.len(); - - // Try to read values_per_page values, however there are only values_per_page/2 values - // println!("---------- reading a batch of {} values ----------", values_per_page); - let array = array_reader.next_batch(values_per_page).unwrap(); - assert_eq!(array.len(), values_per_page / 2); - assert_eq!( - Some(&def_levels[accu_len..(accu_len + array.len())]), - array_reader.get_def_levels() - ); - assert_eq!( - Some(&rep_levels[accu_len..(accu_len + array.len())]), - array_reader.get_rep_levels() - ); - } -} diff --git a/parquet/src/arrow/array_reader/empty_array.rs b/parquet/src/arrow/array_reader/empty_array.rs index b06646cc1c6e..abe839b9dc29 100644 --- a/parquet/src/arrow/array_reader/empty_array.rs +++ b/parquet/src/arrow/array_reader/empty_array.rs @@ -33,6 +33,7 @@ pub fn make_empty_array_reader(row_count: usize) -> Box { struct EmptyArrayReader { data_type: ArrowType, remaining_rows: usize, + need_consume_records: usize, } impl EmptyArrayReader { @@ -40,6 +41,7 @@ impl EmptyArrayReader { Self { data_type: ArrowType::Struct(vec![]), remaining_rows: row_count, + need_consume_records: 0, } } } @@ -53,15 +55,19 @@ impl ArrayReader for EmptyArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { + fn read_records(&mut self, batch_size: usize) -> Result { let len = self.remaining_rows.min(batch_size); self.remaining_rows -= len; + self.need_consume_records += len; + Ok(len) + } + fn consume_batch(&mut self) -> Result { let data = ArrayDataBuilder::new(self.data_type.clone()) - .len(len) + .len(self.need_consume_records) .build() .unwrap(); - + self.need_consume_records = 0; Ok(Arc::new(StructArray::from(data))) } diff --git a/parquet/src/arrow/array_reader/fixed_len_byte_array.rs b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs new file mode 100644 index 000000000000..ba3a02c4f6b7 --- /dev/null +++ b/parquet/src/arrow/array_reader/fixed_len_byte_array.rs @@ -0,0 +1,475 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
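The new `fixed_len_byte_array.rs` reader below turns 12-byte FIXED_LEN_BYTE_ARRAY interval values into `IntervalYearMonthArray` / `IntervalDayTimeArray` by reading little-endian components. A minimal sketch of that byte layout, using a hypothetical helper `decode_interval` (the reader itself slices the bytes inline):

fn decode_interval(b: &[u8; 12]) -> (i32, i32, i32) {
    // Parquet stores an interval as three little-endian 32-bit values:
    // months, days, and milliseconds.
    let months = i32::from_le_bytes(b[0..4].try_into().unwrap());
    let days = i32::from_le_bytes(b[4..8].try_into().unwrap());
    let millis = i32::from_le_bytes(b[8..12].try_into().unwrap());
    (months, days, millis)
}

// YearMonth arrays keep only the months component; DayTime arrays pack days
// and milliseconds into a single i64, which is why the reader below decodes
// b[4..12] with i64::from_le_bytes.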
+ +use crate::arrow::array_reader::{read_records, skip_records, ArrayReader}; +use crate::arrow::buffer::bit_util::{iter_set_bits_rev, sign_extend_be}; +use crate::arrow::decoder::{DeltaByteArrayDecoder, DictIndexDecoder}; +use crate::arrow::record_reader::buffer::{BufferQueue, ScalarBuffer, ValuesBuffer}; +use crate::arrow::record_reader::GenericRecordReader; +use crate::arrow::schema::parquet_to_arrow_field; +use crate::basic::{Encoding, Type}; +use crate::column::page::PageIterator; +use crate::column::reader::decoder::{ColumnValueDecoder, ValuesBufferSlice}; +use crate::errors::{ParquetError, Result}; +use crate::schema::types::ColumnDescPtr; +use crate::util::memory::ByteBufferPtr; +use arrow::array::{ + ArrayDataBuilder, ArrayRef, Decimal128Array, FixedSizeBinaryArray, + IntervalDayTimeArray, IntervalYearMonthArray, +}; +use arrow::buffer::Buffer; +use arrow::datatypes::{DataType as ArrowType, IntervalUnit}; +use std::any::Any; +use std::ops::Range; +use std::sync::Arc; + +/// Returns an [`ArrayReader`] that decodes the provided fixed length byte array column +pub fn make_fixed_len_byte_array_reader( + pages: Box, + column_desc: ColumnDescPtr, + arrow_type: Option, +) -> Result> { + // Check if Arrow type is specified, else create it from Parquet type + let data_type = match arrow_type { + Some(t) => t, + None => parquet_to_arrow_field(column_desc.as_ref())? + .data_type() + .clone(), + }; + + let byte_length = match column_desc.physical_type() { + Type::FIXED_LEN_BYTE_ARRAY => column_desc.type_length() as usize, + t => { + return Err(general_err!( + "invalid physical type for fixed length byte array reader - {}", + t + )) + } + }; + + match &data_type { + ArrowType::FixedSizeBinary(_) => {} + ArrowType::Decimal128(_, _) => { + if byte_length > 16 { + return Err(general_err!( + "decimal 128 type too large, must be less than 16 bytes, got {}", + byte_length + )); + } + } + ArrowType::Interval(_) => { + if byte_length != 12 { + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#interval + return Err(general_err!( + "interval type must consist of 12 bytes got {}", + byte_length + )); + } + } + _ => { + return Err(general_err!( + "invalid data type for fixed length byte array reader - {}", + data_type + )) + } + } + + Ok(Box::new(FixedLenByteArrayReader::new( + pages, + column_desc, + data_type, + byte_length, + ))) +} + +struct FixedLenByteArrayReader { + data_type: ArrowType, + byte_length: usize, + pages: Box, + def_levels_buffer: Option, + rep_levels_buffer: Option, + record_reader: GenericRecordReader, +} + +impl FixedLenByteArrayReader { + fn new( + pages: Box, + column_desc: ColumnDescPtr, + data_type: ArrowType, + byte_length: usize, + ) -> Self { + Self { + data_type, + byte_length, + pages, + def_levels_buffer: None, + rep_levels_buffer: None, + record_reader: GenericRecordReader::new_with_records( + column_desc, + FixedLenByteArrayBuffer { + buffer: Default::default(), + byte_length, + }, + ), + } + } +} + +impl ArrayReader for FixedLenByteArrayReader { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_data_type(&self) -> &ArrowType { + &self.data_type + } + + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } + + fn consume_batch(&mut self) -> Result { + let record_data = self.record_reader.consume_record_data(); + + let array_data = + ArrayDataBuilder::new(ArrowType::FixedSizeBinary(self.byte_length as i32)) + .len(self.record_reader.num_values()) + 
.add_buffer(record_data) + .null_bit_buffer(self.record_reader.consume_bitmap_buffer()); + + let binary = FixedSizeBinaryArray::from(unsafe { array_data.build_unchecked() }); + + // TODO: An improvement might be to do this conversion on read + let array = match &self.data_type { + ArrowType::Decimal128(p, s) => { + let decimal = binary + .iter() + .map(|opt| Some(i128::from_be_bytes(sign_extend_be(opt?)))) + .collect::() + .with_precision_and_scale(*p, *s)?; + + Arc::new(decimal) + } + ArrowType::Interval(unit) => { + // An interval is stored as 3x 32-bit unsigned integers storing months, days, + // and milliseconds + match unit { + IntervalUnit::YearMonth => Arc::new( + binary + .iter() + .map(|o| { + o.map(|b| i32::from_le_bytes(b[0..4].try_into().unwrap())) + }) + .collect::(), + ) as ArrayRef, + IntervalUnit::DayTime => Arc::new( + binary + .iter() + .map(|o| { + o.map(|b| { + i64::from_le_bytes(b[4..12].try_into().unwrap()) + }) + }) + .collect::(), + ) as ArrayRef, + IntervalUnit::MonthDayNano => { + return Err(nyi_err!("MonthDayNano intervals not supported")); + } + } + } + _ => Arc::new(binary) as ArrayRef, + }; + + self.def_levels_buffer = self.record_reader.consume_def_levels(); + self.rep_levels_buffer = self.record_reader.consume_rep_levels(); + self.record_reader.reset(); + + Ok(array) + } + + fn skip_records(&mut self, num_records: usize) -> Result { + skip_records(&mut self.record_reader, self.pages.as_mut(), num_records) + } + + fn get_def_levels(&self) -> Option<&[i16]> { + self.def_levels_buffer.as_ref().map(|buf| buf.typed_data()) + } + + fn get_rep_levels(&self) -> Option<&[i16]> { + self.rep_levels_buffer.as_ref().map(|buf| buf.typed_data()) + } +} + +struct FixedLenByteArrayBuffer { + buffer: ScalarBuffer, + /// The length of each element in bytes + byte_length: usize, +} + +impl ValuesBufferSlice for FixedLenByteArrayBuffer { + fn capacity(&self) -> usize { + usize::MAX + } +} + +impl BufferQueue for FixedLenByteArrayBuffer { + type Output = Buffer; + type Slice = Self; + + fn split_off(&mut self, len: usize) -> Self::Output { + self.buffer.split_off(len * self.byte_length) + } + + fn spare_capacity_mut(&mut self, _batch_size: usize) -> &mut Self::Slice { + self + } + + fn set_len(&mut self, len: usize) { + assert_eq!(self.buffer.len(), len * self.byte_length); + } +} + +impl ValuesBuffer for FixedLenByteArrayBuffer { + fn pad_nulls( + &mut self, + read_offset: usize, + values_read: usize, + levels_read: usize, + valid_mask: &[u8], + ) { + assert_eq!( + self.buffer.len(), + (read_offset + values_read) * self.byte_length + ); + self.buffer + .resize((read_offset + levels_read) * self.byte_length); + + let slice = self.buffer.as_slice_mut(); + + let values_range = read_offset..read_offset + values_read; + for (value_pos, level_pos) in + values_range.rev().zip(iter_set_bits_rev(valid_mask)) + { + debug_assert!(level_pos >= value_pos); + if level_pos <= value_pos { + break; + } + + let level_pos_bytes = level_pos * self.byte_length; + let value_pos_bytes = value_pos * self.byte_length; + + for i in 0..self.byte_length { + slice[level_pos_bytes + i] = slice[value_pos_bytes + i] + } + } + } +} + +struct ValueDecoder { + byte_length: usize, + dict_page: Option, + decoder: Option, +} + +impl ColumnValueDecoder for ValueDecoder { + type Slice = FixedLenByteArrayBuffer; + + fn new(col: &ColumnDescPtr) -> Self { + Self { + byte_length: col.type_length() as usize, + dict_page: None, + decoder: None, + } + } + + fn set_dict( + &mut self, + buf: ByteBufferPtr, + num_values: u32, + 
encoding: Encoding, + _is_sorted: bool, + ) -> Result<()> { + if !matches!( + encoding, + Encoding::PLAIN | Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY + ) { + return Err(nyi_err!( + "Invalid/Unsupported encoding type for dictionary: {}", + encoding + )); + } + let expected_len = num_values as usize * self.byte_length; + if expected_len > buf.len() { + return Err(general_err!( + "too few bytes in dictionary page, expected {} got {}", + expected_len, + buf.len() + )); + } + + self.dict_page = Some(buf); + Ok(()) + } + + fn set_data( + &mut self, + encoding: Encoding, + data: ByteBufferPtr, + num_levels: usize, + num_values: Option, + ) -> Result<()> { + self.decoder = Some(match encoding { + Encoding::PLAIN => Decoder::Plain { + buf: data, + offset: 0, + }, + Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => Decoder::Dict { + decoder: DictIndexDecoder::new(data, num_levels, num_values), + }, + Encoding::DELTA_BYTE_ARRAY => Decoder::Delta { + decoder: DeltaByteArrayDecoder::new(data)?, + }, + _ => { + return Err(general_err!( + "unsupported encoding for fixed length byte array: {}", + encoding + )) + } + }); + Ok(()) + } + + fn read(&mut self, out: &mut Self::Slice, range: Range) -> Result { + assert_eq!(self.byte_length, out.byte_length); + + let len = range.end - range.start; + match self.decoder.as_mut().unwrap() { + Decoder::Plain { offset, buf } => { + let to_read = + (len * self.byte_length).min(buf.len() - *offset) / self.byte_length; + let end_offset = *offset + to_read * self.byte_length; + out.buffer + .extend_from_slice(&buf.as_ref()[*offset..end_offset]); + *offset = end_offset; + Ok(to_read) + } + Decoder::Dict { decoder } => { + let dict = self.dict_page.as_ref().unwrap(); + // All data must be NULL + if dict.is_empty() { + return Ok(0); + } + + decoder.read(len, |keys| { + out.buffer.reserve(keys.len() * self.byte_length); + for key in keys { + let offset = *key as usize * self.byte_length; + let val = &dict.as_ref()[offset..offset + self.byte_length]; + out.buffer.extend_from_slice(val); + } + Ok(()) + }) + } + Decoder::Delta { decoder } => { + let to_read = len.min(decoder.remaining()); + out.buffer.reserve(to_read * self.byte_length); + + decoder.read(to_read, |slice| { + if slice.len() != self.byte_length { + return Err(general_err!( + "encountered array with incorrect length, got {} expected {}", + slice.len(), + self.byte_length + )); + } + out.buffer.extend_from_slice(slice); + Ok(()) + }) + } + } + } + + fn skip_values(&mut self, num_values: usize) -> Result { + match self.decoder.as_mut().unwrap() { + Decoder::Plain { offset, buf } => { + let to_read = num_values.min((buf.len() - *offset) / self.byte_length); + *offset += to_read * self.byte_length; + Ok(to_read) + } + Decoder::Dict { decoder } => decoder.skip(num_values), + Decoder::Delta { decoder } => decoder.skip(num_values), + } + } +} + +enum Decoder { + Plain { buf: ByteBufferPtr, offset: usize }, + Dict { decoder: DictIndexDecoder }, + Delta { decoder: DeltaByteArrayDecoder }, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::arrow::arrow_reader::ParquetRecordBatchReader; + use crate::arrow::ArrowWriter; + use arrow::array::{Array, Decimal128Array, ListArray}; + use arrow::datatypes::Field; + use arrow::error::Result as ArrowResult; + use arrow::record_batch::RecordBatch; + use bytes::Bytes; + use std::sync::Arc; + + #[test] + fn test_decimal_list() { + let decimals = Decimal128Array::from_iter_values([1, 2, 3, 4, 5, 6, 7, 8]); + + // [[], [1], [2, 3], null, [4], null, [6, 7, 8]] + 
let data = ArrayDataBuilder::new(ArrowType::List(Box::new(Field::new( + "item", + decimals.data_type().clone(), + false, + )))) + .len(7) + .add_buffer(Buffer::from_iter([0_i32, 0, 1, 3, 3, 4, 5, 8])) + .null_bit_buffer(Some(Buffer::from(&[0b01010111]))) + .child_data(vec![decimals.into_data()]) + .build() + .unwrap(); + + let written = RecordBatch::try_from_iter([( + "list", + Arc::new(ListArray::from(data)) as ArrayRef, + )]) + .unwrap(); + + let mut buffer = Vec::with_capacity(1024); + let mut writer = + ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); + writer.write(&written).unwrap(); + writer.close().unwrap(); + + let read = ParquetRecordBatchReader::try_new(Bytes::from(buffer), 3) + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(&written.slice(0, 3), &read[0]); + assert_eq!(&written.slice(3, 3), &read[1]); + assert_eq!(&written.slice(6, 1), &read[2]); + } +} diff --git a/parquet/src/arrow/array_reader/list_array.rs b/parquet/src/arrow/array_reader/list_array.rs index 33bd9772a16e..d2fa94611906 100644 --- a/parquet/src/arrow/array_reader/list_array.rs +++ b/parquet/src/arrow/array_reader/list_array.rs @@ -34,7 +34,6 @@ use std::sync::Arc; pub struct ListArrayReader { item_reader: Box, data_type: ArrowType, - item_type: ArrowType, /// The definition level at which this list is not null def_level: i16, /// The repetition level that corresponds to a new value in this array @@ -49,7 +48,6 @@ impl ListArrayReader { pub fn new( item_reader: Box, data_type: ArrowType, - item_type: ArrowType, def_level: i16, rep_level: i16, nullable: bool, @@ -57,7 +55,6 @@ impl ListArrayReader { Self { item_reader, data_type, - item_type, def_level, rep_level, nullable, @@ -78,9 +75,13 @@ impl ArrayReader for ListArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { - let next_batch_array = self.item_reader.next_batch(batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + let size = self.item_reader.read_records(batch_size)?; + Ok(size) + } + fn consume_batch(&mut self) -> Result { + let next_batch_array = self.item_reader.consume_batch()?; if next_batch_array.len() == 0 { return Ok(new_empty_array(&self.data_type)); } @@ -264,10 +265,7 @@ mod tests { item_nullable: bool, ) -> ArrowType { let field = Box::new(Field::new("item", data_type, item_nullable)); - match OffsetSize::IS_LARGE { - true => ArrowType::LargeList(field), - false => ArrowType::List(field), - } + GenericListArray::::DATA_TYPE_CONSTRUCTOR(field) } fn downcast( @@ -303,13 +301,13 @@ mod tests { // ] let l3_item_type = ArrowType::Int32; - let l3_type = list_type::(l3_item_type.clone(), true); + let l3_type = list_type::(l3_item_type, true); let l2_item_type = l3_type.clone(); - let l2_type = list_type::(l2_item_type.clone(), true); + let l2_type = list_type::(l2_item_type, true); let l1_item_type = l2_type.clone(); - let l1_type = list_type::(l1_item_type.clone(), false); + let l1_type = list_type::(l1_item_type, false); let leaf = PrimitiveArray::::from_iter(vec![ Some(1), @@ -386,7 +384,6 @@ mod tests { let l3 = ListArrayReader::::new( Box::new(item_array_reader), l3_type, - l3_item_type, 5, 3, true, @@ -395,7 +392,6 @@ mod tests { let l2 = ListArrayReader::::new( Box::new(l3), l2_type, - l2_item_type, 3, 2, false, @@ -404,7 +400,6 @@ mod tests { let mut l1 = ListArrayReader::::new( Box::new(l2), l1_type, - l1_item_type, 2, 1, true, @@ -455,7 +450,6 @@ mod tests { let mut list_array_reader = ListArrayReader::::new( Box::new(item_array_reader), 
list_type::(ArrowType::Int32, true), - ArrowType::Int32, 1, 1, false, @@ -508,7 +502,6 @@ mod tests { let mut list_array_reader = ListArrayReader::::new( Box::new(item_array_reader), list_type::(ArrowType::Int32, true), - ArrowType::Int32, 2, 1, true, @@ -589,13 +582,9 @@ mod tests { let schema = file_metadata.schema_descr_ptr(); let mask = ProjectionMask::leaves(&schema, vec![0]); - let mut array_reader = build_array_reader( - schema, - Arc::new(arrow_schema), - mask, - Box::new(file_reader), - ) - .unwrap(); + let mut array_reader = + build_array_reader(Arc::new(arrow_schema), mask, &file_reader) + .unwrap(); let batch = array_reader.next_batch(100).unwrap(); assert_eq!(batch.data_type(), array_reader.get_data_type()); diff --git a/parquet/src/arrow/array_reader/map_array.rs b/parquet/src/arrow/array_reader/map_array.rs index 00c3db41a37c..bb80fdbdc5f7 100644 --- a/parquet/src/arrow/array_reader/map_array.rs +++ b/parquet/src/arrow/array_reader/map_array.rs @@ -15,41 +15,67 @@ // specific language governing permissions and limitations // under the License. -use crate::arrow::array_reader::ArrayReader; -use crate::errors::ParquetError::ArrowError; -use crate::errors::{ParquetError, Result}; -use arrow::array::{Array, ArrayDataBuilder, ArrayRef, MapArray}; -use arrow::buffer::{Buffer, MutableBuffer}; +use crate::arrow::array_reader::{ArrayReader, ListArrayReader, StructArrayReader}; +use crate::errors::Result; +use arrow::array::{Array, ArrayRef, MapArray}; use arrow::datatypes::DataType as ArrowType; -use arrow::datatypes::ToByteSlice; -use arrow::util::bit_util; use std::any::Any; use std::sync::Arc; /// Implementation of a map array reader. pub struct MapArrayReader { - key_reader: Box, - value_reader: Box, data_type: ArrowType, - map_def_level: i16, - map_rep_level: i16, + reader: ListArrayReader, } impl MapArrayReader { + /// Creates a new [`MapArrayReader`] with a `def_level`, `rep_level` and `nullable` + /// as defined on [`ParquetField`][crate::arrow::schema::ParquetField] pub fn new( key_reader: Box, value_reader: Box, data_type: ArrowType, def_level: i16, rep_level: i16, + nullable: bool, ) -> Self { - Self { - key_reader, - value_reader, - data_type, - map_def_level: rep_level, - map_rep_level: def_level, - } + let struct_def_level = match nullable { + true => def_level + 2, + false => def_level + 1, + }; + let struct_rep_level = rep_level + 1; + + let element = match &data_type { + ArrowType::Map(element, _) => match element.data_type() { + ArrowType::Struct(fields) if fields.len() == 2 => { + // Parquet cannot represent nullability at this level (#1697) + // and so encountering nullability here indicates some manner + // of schema inconsistency / inference bug + assert!(!element.is_nullable(), "map struct cannot be nullable"); + element + } + _ => unreachable!("expected struct with two fields"), + }, + _ => unreachable!("expected map type"), + }; + + let struct_reader = StructArrayReader::new( + element.data_type().clone(), + vec![key_reader, value_reader], + struct_def_level, + struct_rep_level, + false, + ); + + let reader = ListArrayReader::new( + Box::new(struct_reader), + ArrowType::List(element.clone()), + def_level, + rep_level, + nullable, + ); + + Self { data_type, reader } } } @@ -62,120 +88,129 @@ impl ArrayReader for MapArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { - let key_array = self.key_reader.next_batch(batch_size)?; - let value_array = self.value_reader.next_batch(batch_size)?; - - // Check that key and value have 
the same lengths - let key_length = key_array.len(); - if key_length != value_array.len() { - return Err(general_err!( - "Map key and value should have the same lengths." - )); - } - - let def_levels = self - .key_reader - .get_def_levels() - .ok_or_else(|| ArrowError("item_reader def levels are None.".to_string()))?; - let rep_levels = self - .key_reader - .get_rep_levels() - .ok_or_else(|| ArrowError("item_reader rep levels are None.".to_string()))?; - - if !((def_levels.len() == rep_levels.len()) && (rep_levels.len() == key_length)) { - return Err(ArrowError( - "Expected item_reader def_levels and rep_levels to be same length as batch".to_string(), - )); - } - - let entry_data_type = if let ArrowType::Map(field, _) = &self.data_type { - field.data_type().clone() - } else { - return Err(ArrowError("Expected a map arrow type".to_string())); - }; - - let entry_data = ArrayDataBuilder::new(entry_data_type) - .len(key_length) - .add_child_data(key_array.into_data()) - .add_child_data(value_array.into_data()); - let entry_data = unsafe { entry_data.build_unchecked() }; - - let entry_len = rep_levels.iter().filter(|level| **level == 0).count(); - - // first item in each list has rep_level = 0, subsequent items have rep_level = 1 - let mut offsets: Vec = Vec::new(); - let mut cur_offset = 0; - def_levels.iter().zip(rep_levels).for_each(|(d, r)| { - if *r == 0 || d == &self.map_def_level { - offsets.push(cur_offset); - } - if d > &self.map_def_level { - cur_offset += 1; - } - }); - offsets.push(cur_offset); - - let num_bytes = bit_util::ceil(offsets.len(), 8); - // TODO: A useful optimization is to use the null count to fill with - // 0 or null, to reduce individual bits set in a loop. - // To favour dense data, set every slot to true, then unset - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); - let null_slice = null_buf.as_slice_mut(); - let mut list_index = 0; - for i in 0..rep_levels.len() { - // If the level is lower than empty, then the slot is null. - // When a list is non-nullable, its empty level = null level, - // so this automatically factors that in. 
- if rep_levels[i] == 0 && def_levels[i] < self.map_def_level { - // should be empty list - bit_util::unset_bit(null_slice, list_index); - } - if rep_levels[i] == 0 { - list_index += 1; - } - } - let value_offsets = Buffer::from(&offsets.to_byte_slice()); - - // Now we can build array data - let array_data = ArrayDataBuilder::new(self.data_type.clone()) - .len(entry_len) - .add_buffer(value_offsets) - .null_bit_buffer(Some(null_buf.into())) - .add_child_data(entry_data); - - let array_data = unsafe { array_data.build_unchecked() }; + fn read_records(&mut self, batch_size: usize) -> Result { + self.reader.read_records(batch_size) + } - Ok(Arc::new(MapArray::from(array_data))) + fn consume_batch(&mut self) -> Result { + // A MapArray is just a ListArray with a StructArray child + // we can therefore just alter the ArrayData + let array = self.reader.consume_batch().unwrap(); + let data = array.data().clone(); + let builder = data.into_builder().data_type(self.data_type.clone()); + + // SAFETY - we can assume that ListArrayReader produces valid ListArray + // of the expected type, and as such its output can be reinterpreted as + // a MapArray without validation + Ok(Arc::new(MapArray::from(unsafe { + builder.build_unchecked() + }))) } fn skip_records(&mut self, num_records: usize) -> Result { - let key_skipped = self.key_reader.skip_records(num_records)?; - let value_skipped = self.value_reader.skip_records(num_records)?; - if key_skipped != value_skipped { - return Err(general_err!( - "MapArrayReader out of sync, skipped {} keys and {} values", - key_skipped, - value_skipped - )); - } - Ok(key_skipped) + self.reader.skip_records(num_records) } fn get_def_levels(&self) -> Option<&[i16]> { - // Children definition levels should describe the same parent structure, - // so return key_reader only - self.key_reader.get_def_levels() + self.reader.get_def_levels() } fn get_rep_levels(&self) -> Option<&[i16]> { - // Children repetition levels should describe the same parent structure, - // so return key_reader only - self.key_reader.get_rep_levels() + self.reader.get_rep_levels() } } #[cfg(test)] mod tests { - //TODO: Add unit tests (#1561) + use super::*; + use crate::arrow::arrow_reader::ParquetRecordBatchReader; + use crate::arrow::ArrowWriter; + use arrow::array; + use arrow::array::{MapBuilder, PrimitiveBuilder, StringBuilder}; + use arrow::datatypes::{Field, Int32Type, Schema}; + use arrow::record_batch::RecordBatch; + use bytes::Bytes; + + #[test] + // This test writes a parquet file with the following data: + // +--------------------------------------------------------+ + // |map | + // +--------------------------------------------------------+ + // |null | + // |null | + // |{three -> 3, four -> 4, five -> 5, six -> 6, seven -> 7}| + // +--------------------------------------------------------+ + // + // It then attempts to read the data back and checks that the third record + // contains the expected values. 
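// The reader above works because, at the buffer level, a MapArray is just a
// ListArray whose item is a two-field StructArray: MapArrayReader drives a
// ListArrayReader over Struct(keys, values) and then relabels the resulting
// ArrayData with the Map data type in consume_batch. For the column written
// by this test, that logical type is (sketched with the same field names as
// the schema below):
//
//   Map(
//       Field::new(
//           "entries",
//           Struct([
//               Field::new("keys", Utf8, false),
//               Field::new("values", Int32, true),
//           ]),
//           false,        // the entries struct itself is never null (#1697)
//       ),
//       false,            // keys are not sorted
//   )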
+ fn read_map_array_column() { + // Schema for single map of string to int32 + let schema = Schema::new(vec![Field::new( + "map", + ArrowType::Map( + Box::new(Field::new( + "entries", + ArrowType::Struct(vec![ + Field::new("keys", ArrowType::Utf8, false), + Field::new("values", ArrowType::Int32, true), + ]), + false, + )), + false, // Map field not sorted + ), + true, + )]); + + // Create builders for map + let string_builder = StringBuilder::new(); + let ints_builder: PrimitiveBuilder = PrimitiveBuilder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, ints_builder); + + // Add two null records and one record with five entries + map_builder.append(false).expect("adding null map entry"); + map_builder.append(false).expect("adding null map entry"); + map_builder.keys().append_value("three"); + map_builder.keys().append_value("four"); + map_builder.keys().append_value("five"); + map_builder.keys().append_value("six"); + map_builder.keys().append_value("seven"); + + map_builder.values().append_value(3); + map_builder.values().append_value(4); + map_builder.values().append_value(5); + map_builder.values().append_value(6); + map_builder.values().append_value(7); + map_builder.append(true).expect("adding map entry"); + + // Create record batch + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(map_builder.finish())]) + .expect("create record batch"); + + // Write record batch to file + let mut buffer = Vec::with_capacity(1024); + let mut writer = ArrowWriter::try_new(&mut buffer, batch.schema(), None) + .expect("creat file writer"); + writer.write(&batch).expect("writing file"); + writer.close().expect("close writer"); + + // Read file + let reader = Bytes::from(buffer); + let record_batch_reader = + ParquetRecordBatchReader::try_new(reader, 1024).unwrap(); + for maybe_record_batch in record_batch_reader { + let record_batch = maybe_record_batch.expect("Getting current batch"); + let col = record_batch.column(0); + assert!(col.is_null(0)); + assert!(col.is_null(1)); + let map_entry = array::as_map_array(col).value(2); + let struct_col = array::as_struct_array(&map_entry); + let key_col = array::as_string_array(struct_col.column(0)); // Key column + assert_eq!(key_col.value(0), "three"); + assert_eq!(key_col.value(1), "four"); + assert_eq!(key_col.value(2), "five"); + assert_eq!(key_col.value(3), "six"); + assert_eq!(key_col.value(4), "seven"); + } + } } diff --git a/parquet/src/arrow/array_reader/mod.rs b/parquet/src/arrow/array_reader/mod.rs index e30c33bba35c..3740f0faea69 100644 --- a/parquet/src/arrow/array_reader/mod.rs +++ b/parquet/src/arrow/array_reader/mod.rs @@ -33,8 +33,8 @@ use crate::schema::types::SchemaDescPtr; mod builder; mod byte_array; mod byte_array_dictionary; -mod complex_object_array; mod empty_array; +mod fixed_len_byte_array; mod list_array; mod map_array; mod null_array; @@ -47,7 +47,7 @@ mod test_util; pub use builder::build_array_reader; pub use byte_array::make_byte_array_reader; pub use byte_array_dictionary::make_byte_array_dictionary_reader; -pub use complex_object_array::ComplexObjectArrayReader; +pub use fixed_len_byte_array::make_fixed_len_byte_array_reader; pub use list_array::ListArrayReader; pub use map_array::MapArrayReader; pub use null_array::NullArrayReader; @@ -62,7 +62,20 @@ pub trait ArrayReader: Send { fn get_data_type(&self) -> &ArrowType; /// Reads at most `batch_size` records into an arrow array and return it. 
- fn next_batch(&mut self, batch_size: usize) -> Result; + fn next_batch(&mut self, batch_size: usize) -> Result { + self.read_records(batch_size)?; + self.consume_batch() + } + + /// Reads at most `batch_size` records' bytes into buffer + /// + /// Returns the number of records read, which can be less than `batch_size` if + /// pages is exhausted. + fn read_records(&mut self, batch_size: usize) -> Result; + + /// Consume all currently stored buffer data + /// into an arrow array and return it. + fn consume_batch(&mut self) -> Result; /// Skips over `num_records` records, returning the number of rows skipped fn skip_records(&mut self, num_records: usize) -> Result; @@ -87,7 +100,7 @@ pub trait ArrayReader: Send { /// A collection of row groups pub trait RowGroupCollection { /// Get schema of parquet file. - fn schema(&self) -> Result; + fn schema(&self) -> SchemaDescPtr; /// Get the numer of rows in this collection fn num_rows(&self) -> usize; @@ -97,8 +110,8 @@ pub trait RowGroupCollection { } impl RowGroupCollection for Arc { - fn schema(&self) -> Result { - Ok(self.metadata().file_metadata().schema_descr_ptr()) + fn schema(&self) -> SchemaDescPtr { + self.metadata().file_metadata().schema_descr_ptr() } fn num_rows(&self) -> usize { @@ -111,9 +124,56 @@ impl RowGroupCollection for Arc { } } +pub(crate) struct FileReaderRowGroupCollection { + /// The underling file reader + reader: Arc, + /// Optional list of row group indices to scan + row_groups: Option>, +} + +impl FileReaderRowGroupCollection { + /// Creates a new [`RowGroupCollection`] from a `FileReader` and an optional + /// list of row group indexes to scan + pub fn new(reader: Arc, row_groups: Option>) -> Self { + Self { reader, row_groups } + } +} + +impl RowGroupCollection for FileReaderRowGroupCollection { + fn schema(&self) -> SchemaDescPtr { + self.reader.metadata().file_metadata().schema_descr_ptr() + } + + fn num_rows(&self) -> usize { + match &self.row_groups { + None => self.reader.metadata().file_metadata().num_rows() as usize, + Some(row_groups) => { + let meta = self.reader.metadata().row_groups(); + row_groups + .iter() + .map(|x| meta[*x].num_rows() as usize) + .sum() + } + } + } + + fn column_chunks(&self, i: usize) -> Result> { + let iterator = match &self.row_groups { + Some(row_groups) => FilePageIterator::with_row_groups( + i, + Box::new(row_groups.clone().into_iter()), + Arc::clone(&self.reader), + )?, + None => FilePageIterator::new(i, Arc::clone(&self.reader))?, + }; + + Ok(Box::new(iterator)) + } +} + /// Uses `record_reader` to read up to `batch_size` records from `pages` /// -/// Returns the number of records read, which can be less than batch_size if +/// Returns the number of records read, which can be less than `batch_size` if /// pages is exhausted. 
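// A minimal sketch (not part of this diff) of why the trait is split into
// read_records / consume_batch: records can be buffered across several reads
// and skips and only materialised into an arrow array once. Assuming a
// hypothetical caller holding a `&mut dyn ArrayReader`:
//
//     fn read_with_gap(reader: &mut dyn ArrayReader) -> Result<ArrayRef> {
//         reader.read_records(100)?;   // buffer 100 records
//         reader.skip_records(50)?;    // skip 50 without buffering them
//         reader.read_records(100)?;   // buffer 100 more
//         reader.consume_batch()       // build one array from the 200 buffered records
//     }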
fn read_records( record_reader: &mut GenericRecordReader, @@ -121,7 +181,7 @@ fn read_records( batch_size: usize, ) -> Result where - V: ValuesBuffer + Default, + V: ValuesBuffer, CV: ColumnValueDecoder, { let mut records_read = 0usize; @@ -144,3 +204,37 @@ where } Ok(records_read) } + +/// Uses `record_reader` to skip up to `batch_size` records from`pages` +/// +/// Returns the number of records skipped, which can be less than `batch_size` if +/// pages is exhausted +fn skip_records( + record_reader: &mut GenericRecordReader, + pages: &mut dyn PageIterator, + batch_size: usize, +) -> Result +where + V: ValuesBuffer, + CV: ColumnValueDecoder, +{ + let mut records_skipped = 0usize; + while records_skipped < batch_size { + let records_to_read = batch_size - records_skipped; + + let records_skipped_once = record_reader.skip_records(records_to_read)?; + records_skipped += records_skipped_once; + + // Record reader exhausted + if records_skipped_once < records_to_read { + if let Some(page_reader) = pages.next() { + // Read from new page reader (i.e. column chunk) + record_reader.set_page_reader(page_reader?)?; + } else { + // Page reader also exhausted + break; + } + } + } + Ok(records_skipped) +} diff --git a/parquet/src/arrow/array_reader/null_array.rs b/parquet/src/arrow/array_reader/null_array.rs index b207d8b2c56e..405633f0a823 100644 --- a/parquet/src/arrow/array_reader/null_array.rs +++ b/parquet/src/arrow/array_reader/null_array.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::arrow::array_reader::{read_records, ArrayReader}; +use crate::arrow::array_reader::{read_records, skip_records, ArrayReader}; use crate::arrow::record_reader::buffer::ScalarValue; use crate::arrow::record_reader::RecordReader; use crate::column::page::PageIterator; @@ -39,7 +39,6 @@ where pages: Box, def_levels_buffer: Option, rep_levels_buffer: Option, - column_desc: ColumnDescPtr, record_reader: RecordReader, } @@ -50,14 +49,13 @@ where { /// Construct null array reader. pub fn new(pages: Box, column_desc: ColumnDescPtr) -> Result { - let record_reader = RecordReader::::new(column_desc.clone()); + let record_reader = RecordReader::::new(column_desc); Ok(Self { data_type: ArrowType::Null, pages, def_levels_buffer: None, rep_levels_buffer: None, - column_desc, record_reader, }) } @@ -78,10 +76,11 @@ where &self.data_type } - /// Reads at most `batch_size` records into array. - fn next_batch(&mut self, batch_size: usize) -> Result { - read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } + fn consume_batch(&mut self) -> Result { // convert to arrays let array = arrow::array::NullArray::new(self.record_reader.num_values()); @@ -97,7 +96,7 @@ where } fn skip_records(&mut self, num_records: usize) -> Result { - self.record_reader.skip_records(num_records) + skip_records(&mut self.record_reader, self.pages.as_mut(), num_records) } fn get_def_levels(&self) -> Option<&[i16]> { diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs index cb41d1fba9c2..d4f96e6a8d60 100644 --- a/parquet/src/arrow/array_reader/primitive_array.rs +++ b/parquet/src/arrow/array_reader/primitive_array.rs @@ -15,21 +15,21 @@ // specific language governing permissions and limitations // under the License. 
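// The PrimitiveArrayReader changes below map parquet INT96 values onto
// TimestampNanosecondArray via `Int96::to_nanos`. As a rough sketch of what
// that conversion computes (an illustration, not the crate's code): an INT96
// is three little-endian u32 words, the first two holding the nanoseconds
// elapsed within the day and the third the Julian day number.
fn int96_to_unix_nanos(words: [u32; 3]) -> i64 {
    // Julian day number of the Unix epoch (1970-01-01)
    const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588;
    const NANOS_PER_DAY: i64 = 86_400_000_000_000;

    let nanos_of_day = ((words[1] as i64) << 32) | words[0] as i64;
    let julian_day = words[2] as i64;
    (julian_day - JULIAN_DAY_OF_EPOCH) * NANOS_PER_DAY + nanos_of_day
}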
-use crate::arrow::array_reader::{read_records, ArrayReader}; +use crate::arrow::array_reader::{read_records, skip_records, ArrayReader}; use crate::arrow::record_reader::buffer::ScalarValue; use crate::arrow::record_reader::RecordReader; use crate::arrow::schema::parquet_to_arrow_field; use crate::basic::Type as PhysicalType; use crate::column::page::PageIterator; -use crate::data_type::DataType; +use crate::data_type::{DataType, Int96}; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; use arrow::array::{ - ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, DecimalArray, - Float32Array, Float64Array, Int32Array, Int64Array, + ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, Decimal128Array, + Float32Array, Float64Array, Int32Array, Int64Array,TimestampNanosecondArray, TimestampNanosecondBufferBuilder, }; use arrow::buffer::Buffer; -use arrow::datatypes::DataType as ArrowType; +use arrow::datatypes::{DataType as ArrowType, TimeUnit}; use std::any::Any; use std::sync::Arc; @@ -44,7 +44,6 @@ where pages: Box, def_levels_buffer: Option, rep_levels_buffer: Option, - column_desc: ColumnDescPtr, record_reader: RecordReader, } @@ -58,17 +57,6 @@ where pages: Box, column_desc: ColumnDescPtr, arrow_type: Option, - ) -> Result { - Self::new_with_options(pages, column_desc, arrow_type, false) - } - - /// Construct primitive array reader with ability to only compute null mask and not - /// buffer level data - pub fn new_with_options( - pages: Box, - column_desc: ColumnDescPtr, - arrow_type: Option, - null_mask_only: bool, ) -> Result { // Check if Arrow type is specified, else create it from Parquet type let data_type = match arrow_type { @@ -78,15 +66,13 @@ where .clone(), }; - let record_reader = - RecordReader::::new_with_options(column_desc.clone(), null_mask_only); + let record_reader = RecordReader::::new(column_desc); Ok(Self { data_type, pages, def_levels_buffer: None, rep_levels_buffer: None, - column_desc, record_reader, }) } @@ -107,11 +93,12 @@ where &self.data_type } - /// Reads at most `batch_size` records into array. - fn next_batch(&mut self, batch_size: usize) -> Result { - read_records(&mut self.record_reader, self.pages.as_mut(), batch_size)?; + fn read_records(&mut self, batch_size: usize) -> Result { + read_records(&mut self.record_reader, self.pages.as_mut(), batch_size) + } - let target_type = self.get_data_type().clone(); + fn consume_batch(&mut self) -> Result { + let target_type = &self.data_type; let arrow_data_type = match T::get_physical_type() { PhysicalType::BOOLEAN => ArrowType::Boolean, PhysicalType::INT32 => { @@ -136,9 +123,11 @@ where } PhysicalType::FLOAT => ArrowType::Float32, PhysicalType::DOUBLE => ArrowType::Float64, - PhysicalType::INT96 - | PhysicalType::BYTE_ARRAY - | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + PhysicalType::INT96 => match target_type { + ArrowType::Timestamp(TimeUnit::Nanosecond, _) => target_type.clone(), + _ => unreachable!("INT96 must be timestamp nanosecond"), + }, + PhysicalType::BYTE_ARRAY | PhysicalType::FIXED_LEN_BYTE_ARRAY => { unreachable!( "PrimitiveArrayReaders don't support complex physical types" ); @@ -148,16 +137,31 @@ where // Convert to arrays by using the Parquet physical type. 
// The physical types are then cast to Arrow types if necessary - let mut record_data = self.record_reader.consume_record_data(); + let record_data = self.record_reader.consume_record_data(); + let record_data = match T::get_physical_type() { + PhysicalType::BOOLEAN => { + let mut boolean_buffer = BooleanBufferBuilder::new(record_data.len()); - if T::get_physical_type() == PhysicalType::BOOLEAN { - let mut boolean_buffer = BooleanBufferBuilder::new(record_data.len()); + for e in record_data.as_slice() { + boolean_buffer.append(*e > 0); + } + boolean_buffer.finish() + } + PhysicalType::INT96 => { + // SAFETY - record_data is an aligned buffer of Int96 + let (prefix, slice, suffix) = + unsafe { record_data.as_slice().align_to::() }; + assert!(prefix.is_empty() && suffix.is_empty()); + + let mut builder = TimestampNanosecondBufferBuilder::new(slice.len()); + for v in slice { + builder.append(v.to_nanos()) + } - for e in record_data.as_slice() { - boolean_buffer.append(*e > 0); + builder.finish() } - record_data = boolean_buffer.finish(); - } + _ => record_data, + }; let array_data = ArrayDataBuilder::new(arrow_data_type) .len(self.record_reader.num_values()) @@ -171,9 +175,10 @@ where PhysicalType::INT64 => Arc::new(Int64Array::from(array_data)) as ArrayRef, PhysicalType::FLOAT => Arc::new(Float32Array::from(array_data)) as ArrayRef, PhysicalType::DOUBLE => Arc::new(Float64Array::from(array_data)) as ArrayRef, - PhysicalType::INT96 - | PhysicalType::BYTE_ARRAY - | PhysicalType::FIXED_LEN_BYTE_ARRAY => { + PhysicalType::INT96 => { + Arc::new(TimestampNanosecondArray::from(array_data)) as ArrayRef + } + PhysicalType::BYTE_ARRAY | PhysicalType::FIXED_LEN_BYTE_ARRAY => { unreachable!( "PrimitiveArrayReaders don't support complex physical types" ); @@ -189,41 +194,42 @@ where // are datatypes which we must convert explicitly. // These are: // - date64: we should cast int32 to date32, then date32 to date64. + // - decimal: cast in32 to decimal, int64 to decimal let array = match target_type { ArrowType::Date64 => { // this is cheap as it internally reinterprets the data let a = arrow::compute::cast(&array, &ArrowType::Date32)?; - arrow::compute::cast(&a, &target_type)? + arrow::compute::cast(&a, target_type)? 
} - ArrowType::Decimal(p, s) => { + ArrowType::Decimal128(p, s) => { let array = match array.data_type() { ArrowType::Int32 => array .as_any() .downcast_ref::() .unwrap() .iter() - .map(|v| v.map(|v| v.into())) - .collect::(), + .map(|v| v.map(|v| v as i128)) + .collect::(), ArrowType::Int64 => array .as_any() .downcast_ref::() .unwrap() .iter() - .map(|v| v.map(|v| v.into())) - .collect::(), + .map(|v| v.map(|v| v as i128)) + .collect::(), _ => { return Err(arrow_err!( "Cannot convert {:?} to decimal", array.data_type() - )) + )); } } - .with_precision_and_scale(p, s)?; + .with_precision_and_scale(*p, *s)?; Arc::new(array) as ArrayRef } - _ => arrow::compute::cast(&array, &target_type)?, + _ => arrow::compute::cast(&array, target_type)?, }; // save definition and repetition buffers @@ -234,7 +240,7 @@ where } fn skip_records(&mut self, num_records: usize) -> Result { - self.record_reader.skip_records(num_records) + skip_records(&mut self.record_reader, self.pages.as_mut(), num_records) } fn get_def_levels(&self) -> Option<&[i16]> { @@ -252,17 +258,19 @@ mod tests { use crate::arrow::array_reader::test_util::EmptyPageIterator; use crate::basic::Encoding; use crate::column::page::Page; - use crate::data_type::Int32Type; + use crate::data_type::{Int32Type, Int64Type}; use crate::schema::parser::parse_message_type; use crate::schema::types::SchemaDescriptor; - use crate::util::test_common::make_pages; + use crate::util::test_common::rand_gen::make_pages; use crate::util::InMemoryPageIterator; - use arrow::array::PrimitiveArray; + use arrow::array::{Array, PrimitiveArray}; use arrow::datatypes::ArrowPrimitiveType; + use arrow::datatypes::DataType::Decimal128; use rand::distributions::uniform::SampleUniform; use std::collections::VecDeque; + #[allow(clippy::too_many_arguments)] fn make_column_chunks( column_desc: ColumnDescPtr, encoding: Encoding, @@ -614,4 +622,133 @@ mod tests { ); } } + + #[test] + fn test_primitive_array_reader_decimal_types() { + // parquet `INT32` to decimal + let message_type = " + message test_schema { + REQUIRED INT32 decimal1 (DECIMAL(8,2)); + } + "; + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + let column_desc = schema.column(0); + + // create the array reader + { + let mut data = Vec::new(); + let mut page_lists = Vec::new(); + make_column_chunks::( + column_desc.clone(), + Encoding::PLAIN, + 100, + -99999999, + 99999999, + &mut Vec::new(), + &mut Vec::new(), + &mut data, + &mut page_lists, + true, + 2, + ); + let page_iterator = + InMemoryPageIterator::new(schema, column_desc.clone(), page_lists); + + let mut array_reader = PrimitiveArrayReader::::new( + Box::new(page_iterator), + column_desc, + None, + ) + .unwrap(); + + // read data from the reader + // the data type is decimal(8,2) + let array = array_reader.next_batch(50).unwrap(); + assert_eq!(array.data_type(), &Decimal128(8, 2)); + let array = array.as_any().downcast_ref::().unwrap(); + let data_decimal_array = data[0..50] + .iter() + .copied() + .map(|v| Some(v as i128)) + .collect::() + .with_precision_and_scale(8, 2) + .unwrap(); + assert_eq!(array, &data_decimal_array); + + // not equal with different data type(precision and scale) + let data_decimal_array = data[0..50] + .iter() + .copied() + .map(|v| Some(v as i128)) + .collect::() + .with_precision_and_scale(9, 0) + .unwrap(); + assert_ne!(array, &data_decimal_array) + } + + // parquet `INT64` to decimal + let message_type = " + message test_schema { + REQUIRED INT64 decimal1 
(DECIMAL(18,4)); + } + "; + let schema = parse_message_type(message_type) + .map(|t| Arc::new(SchemaDescriptor::new(Arc::new(t)))) + .unwrap(); + let column_desc = schema.column(0); + + // create the array reader + { + let mut data = Vec::new(); + let mut page_lists = Vec::new(); + make_column_chunks::( + column_desc.clone(), + Encoding::PLAIN, + 100, + -999999999999999999, + 999999999999999999, + &mut Vec::new(), + &mut Vec::new(), + &mut data, + &mut page_lists, + true, + 2, + ); + let page_iterator = + InMemoryPageIterator::new(schema, column_desc.clone(), page_lists); + + let mut array_reader = PrimitiveArrayReader::::new( + Box::new(page_iterator), + column_desc, + None, + ) + .unwrap(); + + // read data from the reader + // the data type is decimal(18,4) + let array = array_reader.next_batch(50).unwrap(); + assert_eq!(array.data_type(), &Decimal128(18, 4)); + let array = array.as_any().downcast_ref::().unwrap(); + let data_decimal_array = data[0..50] + .iter() + .copied() + .map(|v| Some(v as i128)) + .collect::() + .with_precision_and_scale(18, 4) + .unwrap(); + assert_eq!(array, &data_decimal_array); + + // not equal with different data type(precision and scale) + let data_decimal_array = data[0..50] + .iter() + .copied() + .map(|v| Some(v as i128)) + .collect::() + .with_precision_and_scale(34, 0) + .unwrap(); + assert_ne!(array, &data_decimal_array) + } + } } diff --git a/parquet/src/arrow/array_reader/struct_array.rs b/parquet/src/arrow/array_reader/struct_array.rs index 602c598f8269..f682f146c721 100644 --- a/parquet/src/arrow/array_reader/struct_array.rs +++ b/parquet/src/arrow/array_reader/struct_array.rs @@ -63,7 +63,27 @@ impl ArrayReader for StructArrayReader { &self.data_type } - /// Read `batch_size` struct records. + fn read_records(&mut self, batch_size: usize) -> Result { + let mut read = None; + for child in self.children.iter_mut() { + let child_read = child.read_records(batch_size)?; + match read { + Some(expected) => { + if expected != child_read { + return Err(general_err!( + "StructArrayReader out of sync in read_records, expected {} skipped, got {}", + expected, + child_read + )); + } + } + None => read = Some(child_read), + } + } + Ok(read.unwrap_or(0)) + } + + /// Consume struct records. 
/// /// Definition levels of struct array is calculated as following: /// ```ignore @@ -80,7 +100,8 @@ impl ArrayReader for StructArrayReader { /// ```ignore /// null_bitmap[i] = (def_levels[i] >= self.def_level); /// ``` - fn next_batch(&mut self, batch_size: usize) -> Result { + /// + fn consume_batch(&mut self) -> Result { if self.children.is_empty() { return Ok(Arc::new(StructArray::from(Vec::new()))); } @@ -88,7 +109,7 @@ impl ArrayReader for StructArrayReader { let children_array = self .children .iter_mut() - .map(|reader| reader.next_batch(batch_size)) + .map(|reader| reader.consume_batch()) .collect::>>()?; // check that array child data has same size @@ -293,7 +314,6 @@ mod tests { let list_reader = ListArrayReader::::new( Box::new(reader), expected_l.data_type().clone(), - ArrowType::Int32, 3, 1, true, diff --git a/parquet/src/arrow/array_reader/test_util.rs b/parquet/src/arrow/array_reader/test_util.rs index 04c0f6c68f3f..ca1aabfd4aa1 100644 --- a/parquet/src/arrow/array_reader/test_util.rs +++ b/parquet/src/arrow/array_reader/test_util.rs @@ -48,8 +48,7 @@ pub fn utf8_column() -> ColumnDescPtr { /// Encode `data` with the provided `encoding` pub fn encode_byte_array(encoding: Encoding, data: &[ByteArray]) -> ByteBufferPtr { - let descriptor = utf8_column(); - let mut encoder = get_encoder::(descriptor, encoding).unwrap(); + let mut encoder = get_encoder::(encoding).unwrap(); encoder.put(data).unwrap(); encoder.flush_buffer().unwrap() @@ -101,6 +100,7 @@ pub struct InMemoryArrayReader { rep_levels: Option>, last_idx: usize, cur_idx: usize, + need_consume_records: usize, } impl InMemoryArrayReader { @@ -127,6 +127,7 @@ impl InMemoryArrayReader { rep_levels, cur_idx: 0, last_idx: 0, + need_consume_records: 0, } } } @@ -140,7 +141,7 @@ impl ArrayReader for InMemoryArrayReader { &self.data_type } - fn next_batch(&mut self, batch_size: usize) -> Result { + fn read_records(&mut self, batch_size: usize) -> Result { assert_ne!(batch_size, 0); // This replicates the logical normally performed by // RecordReader to delimit semantic records @@ -164,10 +165,17 @@ impl ArrayReader for InMemoryArrayReader { } None => batch_size.min(self.array.len() - self.cur_idx), }; + self.need_consume_records += read; + Ok(read) + } + fn consume_batch(&mut self) -> Result { + let batch_size = self.need_consume_records; + assert_ne!(batch_size, 0); self.last_idx = self.cur_idx; - self.cur_idx += read; - Ok(self.array.slice(self.last_idx, read)) + self.cur_idx += batch_size; + self.need_consume_records = 0; + Ok(self.array.slice(self.last_idx, batch_size)) } fn skip_records(&mut self, num_records: usize) -> Result { diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs deleted file mode 100644 index ebbb864d6309..000000000000 --- a/parquet/src/arrow/arrow_reader.rs +++ /dev/null @@ -1,1589 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Contains reader which reads parquet data into arrow [`RecordBatch`] - -use std::collections::VecDeque; -use std::sync::Arc; - -use arrow::array::Array; -use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef}; -use arrow::error::Result as ArrowResult; -use arrow::record_batch::{RecordBatch, RecordBatchReader}; -use arrow::{array::StructArray, error::ArrowError}; - -use crate::arrow::array_reader::{build_array_reader, ArrayReader}; -use crate::arrow::schema::parquet_to_arrow_schema; -use crate::arrow::schema::parquet_to_arrow_schema_by_columns; -use crate::arrow::ProjectionMask; -use crate::errors::{ParquetError, Result}; -use crate::file::metadata::{KeyValue, ParquetMetaData}; -use crate::file::reader::{ChunkReader, FileReader, SerializedFileReader}; -use crate::schema::types::SchemaDescriptor; - -/// Arrow reader api. -/// With this api, user can get arrow schema from parquet file, and read parquet data -/// into arrow arrays. -pub trait ArrowReader { - type RecordReader: RecordBatchReader; - - /// Read parquet schema and convert it into arrow schema. - fn get_schema(&mut self) -> Result; - - /// Read parquet schema and convert it into arrow schema. - /// This schema only includes columns identified by `mask`. - fn get_schema_by_columns(&mut self, mask: ProjectionMask) -> Result; - - /// Returns record batch reader from whole parquet file. - /// - /// # Arguments - /// - /// `batch_size`: The size of each record batch returned from this reader. Only the - /// last batch may contain records less than this size, otherwise record batches - /// returned from this reader should contains exactly `batch_size` elements. - fn get_record_reader(&mut self, batch_size: usize) -> Result; - - /// Returns record batch reader whose record batch contains columns identified by - /// `mask`. - /// - /// # Arguments - /// - /// `mask`: The columns that should be included in record batches. - /// `batch_size`: Please refer to `get_record_reader`. - fn get_record_reader_by_columns( - &mut self, - mask: ProjectionMask, - batch_size: usize, - ) -> Result; -} - -/// [`RowSelection`] allows selecting or skipping a provided number of rows -/// when scanning the parquet file -#[derive(Debug, Clone, Copy)] -pub(crate) struct RowSelection { - /// The number of rows - pub row_count: usize, - - /// If true, skip `row_count` rows - pub skip: bool, -} - -impl RowSelection { - /// Select `row_count` rows - pub fn select(row_count: usize) -> Self { - Self { - row_count, - skip: false, - } - } - - /// Skip `row_count` rows - pub fn skip(row_count: usize) -> Self { - Self { - row_count, - skip: true, - } - } -} - -#[derive(Debug, Clone, Default)] -pub struct ArrowReaderOptions { - skip_arrow_metadata: bool, - selection: Option>, -} - -impl ArrowReaderOptions { - /// Create a new [`ArrowReaderOptions`] with the default settings - fn new() -> Self { - Self::default() - } - - /// Parquet files generated by some writers may contain embedded arrow - /// schema and metadata. This may not be correct or compatible with your system. 
- /// - /// For example:[ARROW-16184](https://issues.apache.org/jira/browse/ARROW-16184) - /// - /// Set `skip_arrow_metadata` to true, to skip decoding this - pub fn with_skip_arrow_metadata(self, skip_arrow_metadata: bool) -> Self { - Self { - skip_arrow_metadata, - ..self - } - } - - /// Scan rows from the parquet file according to the provided `selection` - /// - /// TODO: Make public once row selection fully implemented (#1792) - pub(crate) fn with_row_selection( - self, - selection: impl Into>, - ) -> Self { - Self { - selection: Some(selection.into()), - ..self - } - } -} - -pub struct ParquetFileArrowReader { - file_reader: Arc, - - options: ArrowReaderOptions, -} - -impl ArrowReader for ParquetFileArrowReader { - type RecordReader = ParquetRecordBatchReader; - - fn get_schema(&mut self) -> Result { - let file_metadata = self.file_reader.metadata().file_metadata(); - parquet_to_arrow_schema(file_metadata.schema_descr(), self.get_kv_metadata()) - } - - fn get_schema_by_columns(&mut self, mask: ProjectionMask) -> Result { - let file_metadata = self.file_reader.metadata().file_metadata(); - parquet_to_arrow_schema_by_columns( - file_metadata.schema_descr(), - mask, - self.get_kv_metadata(), - ) - } - - fn get_record_reader( - &mut self, - batch_size: usize, - ) -> Result { - self.get_record_reader_by_columns(ProjectionMask::all(), batch_size) - } - - fn get_record_reader_by_columns( - &mut self, - mask: ProjectionMask, - batch_size: usize, - ) -> Result { - let array_reader = build_array_reader( - self.file_reader - .metadata() - .file_metadata() - .schema_descr_ptr(), - Arc::new(self.get_schema()?), - mask, - Box::new(self.file_reader.clone()), - )?; - - let selection = self.options.selection.clone().map(Into::into); - Ok(ParquetRecordBatchReader::new( - batch_size, - array_reader, - selection, - )) - } -} - -impl ParquetFileArrowReader { - /// Create a new [`ParquetFileArrowReader`] with the provided [`ChunkReader`] - /// - /// ```no_run - /// # use std::fs::File; - /// # use bytes::Bytes; - /// # use parquet::arrow::ParquetFileArrowReader; - /// - /// let file = File::open("file.parquet").unwrap(); - /// let reader = ParquetFileArrowReader::try_new(file).unwrap(); - /// - /// let bytes = Bytes::from(vec![]); - /// let reader = ParquetFileArrowReader::try_new(bytes).unwrap(); - /// ``` - pub fn try_new(chunk_reader: R) -> Result { - Self::try_new_with_options(chunk_reader, Default::default()) - } - - /// Create a new [`ParquetFileArrowReader`] with the provided [`ChunkReader`] - /// and [`ArrowReaderOptions`] - pub fn try_new_with_options( - chunk_reader: R, - options: ArrowReaderOptions, - ) -> Result { - let file_reader = Arc::new(SerializedFileReader::new(chunk_reader)?); - Ok(Self::new_with_options(file_reader, options)) - } - - /// Create a new [`ParquetFileArrowReader`] with the provided [`Arc`] - pub fn new(file_reader: Arc) -> Self { - Self::new_with_options(file_reader, Default::default()) - } - - /// Create a new [`ParquetFileArrowReader`] with the provided [`Arc`] - /// and [`ArrowReaderOptions`] - pub fn new_with_options( - file_reader: Arc, - options: ArrowReaderOptions, - ) -> Self { - Self { - file_reader, - options, - } - } - - /// Expose the reader metadata - #[deprecated = "use metadata() instead"] - pub fn get_metadata(&mut self) -> ParquetMetaData { - self.file_reader.metadata().clone() - } - - /// Returns the parquet metadata - pub fn metadata(&self) -> &ParquetMetaData { - self.file_reader.metadata() - } - - /// Returns the parquet schema - pub fn 
parquet_schema(&self) -> &SchemaDescriptor { - self.file_reader.metadata().file_metadata().schema_descr() - } - - /// Returns the key value metadata, returns `None` if [`ArrowReaderOptions::skip_arrow_metadata`] - fn get_kv_metadata(&self) -> Option<&Vec> { - if self.options.skip_arrow_metadata { - return None; - } - - self.file_reader - .metadata() - .file_metadata() - .key_value_metadata() - } -} - -pub struct ParquetRecordBatchReader { - batch_size: usize, - array_reader: Box, - schema: SchemaRef, - selection: Option>, -} - -impl Iterator for ParquetRecordBatchReader { - type Item = ArrowResult; - - fn next(&mut self) -> Option { - let to_read = match self.selection.as_mut() { - Some(selection) => loop { - let front = selection.pop_front()?; - if front.skip { - let skipped = match self.array_reader.skip_records(front.row_count) { - Ok(skipped) => skipped, - Err(e) => return Some(Err(e.into())), - }; - - if skipped != front.row_count { - return Some(Err(general_err!( - "failed to skip rows, expected {}, got {}", - front.row_count, - skipped - ) - .into())); - } - continue; - } - - let to_read = match front.row_count.checked_sub(self.batch_size) { - Some(remaining) => { - selection.push_front(RowSelection::skip(remaining)); - self.batch_size - } - None => front.row_count, - }; - - break to_read; - }, - None => self.batch_size, - }; - - match self.array_reader.next_batch(to_read) { - Err(error) => Some(Err(error.into())), - Ok(array) => { - let struct_array = - array.as_any().downcast_ref::().ok_or_else(|| { - ArrowError::ParquetError( - "Struct array reader should return struct array".to_string(), - ) - }); - - match struct_array { - Err(err) => Some(Err(err)), - Ok(e) => (e.len() > 0).then(|| Ok(RecordBatch::from(e))), - } - } - } - } -} - -impl RecordBatchReader for ParquetRecordBatchReader { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl ParquetRecordBatchReader { - pub fn try_new( - batch_size: usize, - array_reader: Box, - ) -> Result { - Ok(Self::new(batch_size, array_reader, None)) - } - - /// Create a new [`ParquetRecordBatchReader`] that will read at most `batch_size` rows at - /// a time from [`ArrayReader`] based on the configured `selection`. 
If `selection` is `None` - /// all rows will be returned - /// - /// TODO: Make public once row selection fully implemented (#1792) - pub(crate) fn new( - batch_size: usize, - array_reader: Box, - selection: Option>, - ) -> Self { - let schema = match array_reader.get_data_type() { - ArrowType::Struct(ref fields) => Schema::new(fields.clone()), - _ => unreachable!("Struct array reader's data type is not struct!"), - }; - - Self { - batch_size, - array_reader, - schema: Arc::new(schema), - selection, - } - } -} - -#[cfg(test)] -mod tests { - use bytes::Bytes; - use std::cmp::min; - use std::convert::TryFrom; - use std::fs::File; - use std::io::Seek; - use std::path::PathBuf; - use std::sync::Arc; - - use rand::{thread_rng, RngCore}; - use serde_json::json; - use serde_json::Value::{Array as JArray, Null as JNull, Object as JObject}; - use tempfile::tempfile; - - use arrow::array::*; - use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; - use arrow::error::Result as ArrowResult; - use arrow::record_batch::{RecordBatch, RecordBatchReader}; - - use crate::arrow::arrow_reader::{ - ArrowReader, ArrowReaderOptions, ParquetFileArrowReader, - }; - use crate::arrow::buffer::converter::{ - BinaryArrayConverter, Converter, FixedSizeArrayConverter, FromConverter, - IntervalDayTimeArrayConverter, LargeUtf8ArrayConverter, Utf8ArrayConverter, - }; - use crate::arrow::schema::add_encoded_arrow_schema_to_metadata; - use crate::arrow::{ArrowWriter, ProjectionMask}; - use crate::basic::{ConvertedType, Encoding, Repetition, Type as PhysicalType}; - use crate::data_type::{ - BoolType, ByteArray, ByteArrayType, DataType, FixedLenByteArray, - FixedLenByteArrayType, Int32Type, Int64Type, - }; - use crate::errors::Result; - use crate::file::properties::{WriterProperties, WriterVersion}; - use crate::file::reader::{FileReader, SerializedFileReader}; - use crate::file::writer::SerializedFileWriter; - use crate::schema::parser::parse_message_type; - use crate::schema::types::{Type, TypePtr}; - use crate::util::test_common::RandGen; - - #[test] - fn test_arrow_reader_all_columns() { - let json_values = get_json_array("parquet/generated_simple_numerics/blogs.json"); - - let parquet_file_reader = - get_test_reader("parquet/generated_simple_numerics/blogs.parquet"); - - let max_len = parquet_file_reader.metadata().file_metadata().num_rows() as usize; - - let mut arrow_reader = ParquetFileArrowReader::new(parquet_file_reader); - - let mut record_batch_reader = arrow_reader - .get_record_reader(60) - .expect("Failed to read into array!"); - - // Verify that the schema was correctly parsed - let original_schema = arrow_reader.get_schema().unwrap().fields().clone(); - assert_eq!(original_schema, *record_batch_reader.schema().fields()); - - compare_batch_json(&mut record_batch_reader, json_values, max_len); - } - - #[test] - fn test_arrow_reader_single_column() { - let json_values = get_json_array("parquet/generated_simple_numerics/blogs.json"); - - let projected_json_values = json_values - .into_iter() - .map(|value| match value { - JObject(fields) => { - json!({ "blog_id": fields.get("blog_id").unwrap_or(&JNull).clone()}) - } - _ => panic!("Input should be json object array!"), - }) - .collect::>(); - - let parquet_file_reader = - get_test_reader("parquet/generated_simple_numerics/blogs.parquet"); - - let file_metadata = parquet_file_reader.metadata().file_metadata(); - let max_len = file_metadata.num_rows() as usize; - - let mask = ProjectionMask::leaves(file_metadata.schema_descr(), [2]); - let mut 
arrow_reader = ParquetFileArrowReader::new(parquet_file_reader); - - let mut record_batch_reader = arrow_reader - .get_record_reader_by_columns(mask, 60) - .expect("Failed to read into array!"); - - // Verify that the schema was correctly parsed - let original_schema = arrow_reader.get_schema().unwrap().fields().clone(); - assert_eq!(1, record_batch_reader.schema().fields().len()); - assert_eq!(original_schema[1], record_batch_reader.schema().fields()[0]); - - compare_batch_json(&mut record_batch_reader, projected_json_values, max_len); - } - - #[test] - fn test_null_column_reader_test() { - let mut file = tempfile::tempfile().unwrap(); - - let schema = " - message message { - OPTIONAL INT32 int32; - } - "; - let schema = Arc::new(parse_message_type(schema).unwrap()); - - let def_levels = vec![vec![0, 0, 0], vec![0, 0, 0, 0]]; - generate_single_column_file_with_data::( - &[vec![], vec![]], - Some(&def_levels), - file.try_clone().unwrap(), // Cannot use &mut File (#1163) - schema, - Some(Field::new("int32", ArrowDataType::Null, true)), - &Default::default(), - ) - .unwrap(); - - file.rewind().unwrap(); - - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - let record_reader = arrow_reader.get_record_reader(2).unwrap(); - - let batches = record_reader.collect::>>().unwrap(); - - assert_eq!(batches.len(), 4); - for batch in &batches[0..3] { - assert_eq!(batch.num_rows(), 2); - assert_eq!(batch.num_columns(), 1); - assert_eq!(batch.column(0).null_count(), 2); - } - - assert_eq!(batches[3].num_rows(), 1); - assert_eq!(batches[3].num_columns(), 1); - assert_eq!(batches[3].column(0).null_count(), 1); - } - - #[test] - fn test_primitive_single_column_reader_test() { - run_single_column_reader_tests::( - 2, - ConvertedType::NONE, - None, - &FromConverter::new(), - &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], - ); - run_single_column_reader_tests::( - 2, - ConvertedType::NONE, - None, - &FromConverter::new(), - &[ - Encoding::PLAIN, - Encoding::RLE_DICTIONARY, - Encoding::DELTA_BINARY_PACKED, - ], - ); - run_single_column_reader_tests::( - 2, - ConvertedType::NONE, - None, - &FromConverter::new(), - &[ - Encoding::PLAIN, - Encoding::RLE_DICTIONARY, - Encoding::DELTA_BINARY_PACKED, - ], - ); - } - - struct RandFixedLenGen {} - - impl RandGen for RandFixedLenGen { - fn gen(len: i32) -> FixedLenByteArray { - let mut v = vec![0u8; len as usize]; - rand::thread_rng().fill_bytes(&mut v); - ByteArray::from(v).into() - } - } - - #[test] - fn test_fixed_length_binary_column_reader() { - let converter = FixedSizeArrayConverter::new(20); - run_single_column_reader_tests::< - FixedLenByteArrayType, - FixedSizeBinaryArray, - FixedSizeArrayConverter, - RandFixedLenGen, - >( - 20, - ConvertedType::NONE, - None, - &converter, - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY], - ); - } - - #[test] - fn test_interval_day_time_column_reader() { - let converter = IntervalDayTimeArrayConverter {}; - run_single_column_reader_tests::< - FixedLenByteArrayType, - IntervalDayTimeArray, - IntervalDayTimeArrayConverter, - RandFixedLenGen, - >( - 12, - ConvertedType::INTERVAL, - None, - &converter, - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY], - ); - } - - struct RandUtf8Gen {} - - impl RandGen for RandUtf8Gen { - fn gen(len: i32) -> ByteArray { - Int32Type::gen(len).to_string().as_str().into() - } - } - - #[test] - fn test_utf8_single_column_reader_test() { - let encodings = &[ - Encoding::PLAIN, - Encoding::RLE_DICTIONARY, - Encoding::DELTA_LENGTH_BYTE_ARRAY, - Encoding::DELTA_BYTE_ARRAY, 
- ]; - - let converter = BinaryArrayConverter {}; - run_single_column_reader_tests::< - ByteArrayType, - BinaryArray, - BinaryArrayConverter, - RandUtf8Gen, - >(2, ConvertedType::NONE, None, &converter, encodings); - - let utf8_converter = Utf8ArrayConverter {}; - run_single_column_reader_tests::< - ByteArrayType, - StringArray, - Utf8ArrayConverter, - RandUtf8Gen, - >(2, ConvertedType::UTF8, None, &utf8_converter, encodings); - - run_single_column_reader_tests::< - ByteArrayType, - StringArray, - Utf8ArrayConverter, - RandUtf8Gen, - >( - 2, - ConvertedType::UTF8, - Some(ArrowDataType::Utf8), - &utf8_converter, - encodings, - ); - - let large_utf8_converter = LargeUtf8ArrayConverter {}; - run_single_column_reader_tests::< - ByteArrayType, - LargeStringArray, - LargeUtf8ArrayConverter, - RandUtf8Gen, - >( - 2, - ConvertedType::UTF8, - Some(ArrowDataType::LargeUtf8), - &large_utf8_converter, - encodings, - ); - - let small_key_types = [ArrowDataType::Int8, ArrowDataType::UInt8]; - for key in &small_key_types { - for encoding in encodings { - let mut opts = TestOptions::new(2, 20, 15).with_null_percent(50); - opts.encoding = *encoding; - - // Cannot run full test suite as keys overflow, run small test instead - single_column_reader_test::< - ByteArrayType, - StringArray, - Utf8ArrayConverter, - RandUtf8Gen, - >( - opts, - 2, - ConvertedType::UTF8, - Some(ArrowDataType::Dictionary( - Box::new(key.clone()), - Box::new(ArrowDataType::Utf8), - )), - &utf8_converter, - ); - } - } - - let key_types = [ - ArrowDataType::Int16, - ArrowDataType::UInt16, - ArrowDataType::Int32, - ArrowDataType::UInt32, - ArrowDataType::Int64, - ArrowDataType::UInt64, - ]; - - for key in &key_types { - run_single_column_reader_tests::< - ByteArrayType, - StringArray, - Utf8ArrayConverter, - RandUtf8Gen, - >( - 2, - ConvertedType::UTF8, - Some(ArrowDataType::Dictionary( - Box::new(key.clone()), - Box::new(ArrowDataType::Utf8), - )), - &utf8_converter, - encodings, - ); - - // https://github.com/apache/arrow-rs/issues/1179 - // run_single_column_reader_tests::< - // ByteArrayType, - // LargeStringArray, - // LargeUtf8ArrayConverter, - // RandUtf8Gen, - // >( - // 2, - // ConvertedType::UTF8, - // Some(ArrowDataType::Dictionary( - // Box::new(key.clone()), - // Box::new(ArrowDataType::LargeUtf8), - // )), - // &large_utf8_converter, - // encodings - // ); - } - } - - #[test] - fn test_read_decimal_file() { - use arrow::array::DecimalArray; - let testdata = arrow::util::test_util::parquet_test_data(); - let file_variants = vec![("fixed_length", 25), ("int32", 4), ("int64", 10)]; - for (prefix, target_precision) in file_variants { - let path = format!("{}/{}_decimal.parquet", testdata, prefix); - let file = File::open(&path).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - - let mut record_reader = arrow_reader.get_record_reader(32).unwrap(); - - let batch = record_reader.next().unwrap().unwrap(); - assert_eq!(batch.num_rows(), 24); - let col = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - - let expected = 1..25; - - assert_eq!(col.precision(), target_precision); - assert_eq!(col.scale(), 2); - - for (i, v) in expected.enumerate() { - assert_eq!(col.value(i).as_i128(), v * 100_i128); - } - } - } - - /// Parameters for single_column_reader_test - #[derive(Debug, Clone)] - struct TestOptions { - /// Number of row group to write to parquet (row group size = - /// num_row_groups / num_rows) - num_row_groups: usize, - /// Total number of rows per row group - num_rows: usize, 
- /// Size of batches to read back - record_batch_size: usize, - /// Percentage of nulls in column or None if required - null_percent: Option, - /// Set write batch size - /// - /// This is the number of rows that are written at once to a page and - /// therefore acts as a bound on the page granularity of a row group - write_batch_size: usize, - /// Maximum size of page in bytes - max_data_page_size: usize, - /// Maximum size of dictionary page in bytes - max_dict_page_size: usize, - /// Writer version - writer_version: WriterVersion, - /// Encoding - encoding: Encoding, - } - - impl Default for TestOptions { - fn default() -> Self { - Self { - num_row_groups: 2, - num_rows: 100, - record_batch_size: 15, - null_percent: None, - write_batch_size: 64, - max_data_page_size: 1024 * 1024, - max_dict_page_size: 1024 * 1024, - writer_version: WriterVersion::PARQUET_1_0, - encoding: Encoding::PLAIN, - } - } - } - - impl TestOptions { - fn new(num_row_groups: usize, num_rows: usize, record_batch_size: usize) -> Self { - Self { - num_row_groups, - num_rows, - record_batch_size, - ..Default::default() - } - } - - fn with_null_percent(self, null_percent: usize) -> Self { - Self { - null_percent: Some(null_percent), - ..self - } - } - - fn with_max_data_page_size(self, max_data_page_size: usize) -> Self { - Self { - max_data_page_size, - ..self - } - } - - fn with_max_dict_page_size(self, max_dict_page_size: usize) -> Self { - Self { - max_dict_page_size, - ..self - } - } - - fn writer_props(&self) -> WriterProperties { - let builder = WriterProperties::builder() - .set_data_pagesize_limit(self.max_data_page_size) - .set_write_batch_size(self.write_batch_size) - .set_writer_version(self.writer_version); - - let builder = match self.encoding { - Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => builder - .set_dictionary_enabled(true) - .set_dictionary_pagesize_limit(self.max_dict_page_size), - _ => builder - .set_dictionary_enabled(false) - .set_encoding(self.encoding), - }; - - builder.build() - } - } - - /// Create a parquet file and then read it using - /// `ParquetFileArrowReader` using a standard set of parameters - /// `opts`. - /// - /// `rand_max` represents the maximum size of value to pass to to - /// value generator - fn run_single_column_reader_tests( - rand_max: i32, - converted_type: ConvertedType, - arrow_type: Option, - converter: &C, - encodings: &[Encoding], - ) where - T: DataType, - G: RandGen, - A: Array + 'static, - C: Converter>, A> + 'static, - { - let all_options = vec![ - // choose record_batch_batch (15) so batches cross row - // group boundaries (50 rows in 2 row groups) cases. - TestOptions::new(2, 100, 15), - // choose record_batch_batch (5) so batches sometime fall - // on row group boundaries and (25 rows in 3 row groups - // --> row groups of 10, 10, and 5). Tests buffer - // refilling edge cases. - TestOptions::new(3, 25, 5), - // Choose record_batch_size (25) so all batches fall - // exactly on row group boundary (25). Tests buffer - // refilling edge cases. 
- TestOptions::new(4, 100, 25), - // Set maximum page size so row groups have multiple pages - TestOptions::new(3, 256, 73).with_max_data_page_size(128), - // Set small dictionary page size to test dictionary fallback - TestOptions::new(3, 256, 57).with_max_dict_page_size(128), - // Test optional but with no nulls - TestOptions::new(2, 256, 127).with_null_percent(0), - // Test optional with nulls - TestOptions::new(2, 256, 93).with_null_percent(25), - ]; - - all_options.into_iter().for_each(|opts| { - for writer_version in [WriterVersion::PARQUET_1_0, WriterVersion::PARQUET_2_0] - { - for encoding in encodings { - let opts = TestOptions { - writer_version, - encoding: *encoding, - ..opts - }; - - single_column_reader_test::( - opts, - rand_max, - converted_type, - arrow_type.clone(), - converter, - ) - } - } - }); - } - - /// Create a parquet file and then read it using - /// `ParquetFileArrowReader` using the parameters described in - /// `opts`. - fn single_column_reader_test( - opts: TestOptions, - rand_max: i32, - converted_type: ConvertedType, - arrow_type: Option, - converter: &C, - ) where - T: DataType, - G: RandGen, - A: Array + 'static, - C: Converter>, A> + 'static, - { - // Print out options to facilitate debugging failures on CI - println!( - "Running single_column_reader_test ConvertedType::{}/ArrowType::{:?} with Options: {:?}", - converted_type, arrow_type, opts - ); - - let (repetition, def_levels) = match opts.null_percent.as_ref() { - Some(null_percent) => { - let mut rng = thread_rng(); - - let def_levels: Vec> = (0..opts.num_row_groups) - .map(|_| { - std::iter::from_fn(|| { - Some((rng.next_u32() as usize % 100 >= *null_percent) as i16) - }) - .take(opts.num_rows) - .collect() - }) - .collect(); - (Repetition::OPTIONAL, Some(def_levels)) - } - None => (Repetition::REQUIRED, None), - }; - - let values: Vec> = (0..opts.num_row_groups) - .map(|idx| { - let null_count = match def_levels.as_ref() { - Some(d) => d[idx].iter().filter(|x| **x == 0).count(), - None => 0, - }; - G::gen_vec(rand_max, opts.num_rows - null_count) - }) - .collect(); - - let len = match T::get_physical_type() { - crate::basic::Type::FIXED_LEN_BYTE_ARRAY => rand_max, - crate::basic::Type::INT96 => 12, - _ => -1, - }; - - let mut fields = vec![Arc::new( - Type::primitive_type_builder("leaf", T::get_physical_type()) - .with_repetition(repetition) - .with_converted_type(converted_type) - .with_length(len) - .build() - .unwrap(), - )]; - - let schema = Arc::new( - Type::group_type_builder("test_schema") - .with_fields(&mut fields) - .build() - .unwrap(), - ); - - let arrow_field = arrow_type - .clone() - .map(|t| arrow::datatypes::Field::new("leaf", t, false)); - - let mut file = tempfile::tempfile().unwrap(); - - generate_single_column_file_with_data::( - &values, - def_levels.as_ref(), - file.try_clone().unwrap(), // Cannot use &mut File (#1163) - schema, - arrow_field, - &opts, - ) - .unwrap(); - - file.rewind().unwrap(); - - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - let mut record_reader = arrow_reader - .get_record_reader(opts.record_batch_size) - .unwrap(); - - let expected_data: Vec> = match def_levels { - Some(levels) => { - let mut values_iter = values.iter().flatten(); - levels - .iter() - .flatten() - .map(|d| match d { - 1 => Some(values_iter.next().cloned().unwrap()), - 0 => None, - _ => unreachable!(), - }) - .collect() - } - None => values.iter().flatten().map(|b| Some(b.clone())).collect(), - }; - - assert_eq!(expected_data.len(), opts.num_rows * 
opts.num_row_groups); - - let mut total_read = 0; - loop { - let maybe_batch = record_reader.next(); - if total_read < expected_data.len() { - let end = min(total_read + opts.record_batch_size, expected_data.len()); - let batch = maybe_batch.unwrap().unwrap(); - assert_eq!(end - total_read, batch.num_rows()); - - let mut data = vec![]; - data.extend_from_slice(&expected_data[total_read..end]); - - let a = converter.convert(data).unwrap(); - let mut b = Arc::clone(batch.column(0)); - - if let Some(arrow_type) = arrow_type.as_ref() { - assert_eq!(b.data_type(), arrow_type); - if let ArrowDataType::Dictionary(_, v) = arrow_type { - assert_eq!(a.data_type(), v.as_ref()); - b = arrow::compute::cast(&b, v.as_ref()).unwrap() - } - } - assert_eq!(a.data_type(), b.data_type()); - assert_eq!(a.data(), b.data(), "{:#?} vs {:#?}", a.data(), b.data()); - - total_read = end; - } else { - assert!(maybe_batch.is_none()); - break; - } - } - } - - fn generate_single_column_file_with_data( - values: &[Vec], - def_levels: Option<&Vec>>, - file: File, - schema: TypePtr, - field: Option, - opts: &TestOptions, - ) -> Result { - let mut writer_props = opts.writer_props(); - if let Some(field) = field { - let arrow_schema = arrow::datatypes::Schema::new(vec![field]); - add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut writer_props); - } - - let mut writer = SerializedFileWriter::new(file, schema, Arc::new(writer_props))?; - - for (idx, v) in values.iter().enumerate() { - let def_levels = def_levels.map(|d| d[idx].as_slice()); - let mut row_group_writer = writer.next_row_group()?; - { - let mut column_writer = row_group_writer - .next_column()? - .expect("Column writer is none!"); - - column_writer - .typed::() - .write_batch(v, def_levels, None)?; - - column_writer.close()?; - } - row_group_writer.close()?; - } - - writer.close() - } - - fn get_test_reader(file_name: &str) -> Arc> { - let file = get_test_file(file_name); - - let reader = - SerializedFileReader::new(file).expect("Failed to create serialized reader"); - - Arc::new(reader) - } - - fn get_test_file(file_name: &str) -> File { - let mut path = PathBuf::new(); - path.push(arrow::util::test_util::arrow_test_data()); - path.push(file_name); - - File::open(path.as_path()).expect("File not found!") - } - - fn get_json_array(filename: &str) -> Vec { - match serde_json::from_reader(get_test_file(filename)) - .expect("Failed to read json value from file!") - { - JArray(values) => values, - _ => panic!("Input should be json array!"), - } - } - - fn compare_batch_json( - record_batch_reader: &mut dyn RecordBatchReader, - json_values: Vec, - max_len: usize, - ) { - for i in 0..20 { - let array: Option = record_batch_reader - .next() - .map(|r| r.expect("Failed to read record batch!").into()); - - let (start, end) = (i * 60_usize, (i + 1) * 60_usize); - - if start < max_len { - assert!(array.is_some()); - assert_ne!(0, array.as_ref().unwrap().len()); - let end = min(end, max_len); - let json = JArray(Vec::from(&json_values[start..end])); - assert_eq!(array.unwrap(), json) - } else { - assert!(array.is_none()); - } - } - } - - #[test] - fn test_read_structs() { - // This particular test file has columns of struct types where there is - // a column that has the same name as one of the struct fields - // (see: ARROW-11452) - let testdata = arrow::util::test_util::parquet_test_data(); - let path = format!("{}/nested_structs.rust.parquet", testdata); - let file = File::open(&path).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); 
- let record_batch_reader = arrow_reader - .get_record_reader(60) - .expect("Failed to read into array!"); - - for batch in record_batch_reader { - batch.unwrap(); - } - - let mask = ProjectionMask::leaves(arrow_reader.parquet_schema(), [3, 8, 10]); - let projected_reader = arrow_reader - .get_record_reader_by_columns(mask.clone(), 60) - .unwrap(); - let projected_schema = arrow_reader.get_schema_by_columns(mask).unwrap(); - - let expected_schema = Schema::new(vec![ - Field::new( - "roll_num", - ArrowDataType::Struct(vec![Field::new( - "count", - ArrowDataType::UInt64, - false, - )]), - false, - ), - Field::new( - "PC_CUR", - ArrowDataType::Struct(vec![ - Field::new("mean", ArrowDataType::Int64, false), - Field::new("sum", ArrowDataType::Int64, false), - ]), - false, - ), - ]); - - // Tests for #1652 and #1654 - assert_eq!(projected_reader.schema().as_ref(), &projected_schema); - assert_eq!(expected_schema, projected_schema); - - for batch in projected_reader { - let batch = batch.unwrap(); - assert_eq!(batch.schema().as_ref(), &projected_schema); - } - } - - #[test] - fn test_read_maps() { - let testdata = arrow::util::test_util::parquet_test_data(); - let path = format!("{}/nested_maps.snappy.parquet", testdata); - let file = File::open(&path).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - let record_batch_reader = arrow_reader - .get_record_reader(60) - .expect("Failed to read into array!"); - - for batch in record_batch_reader { - batch.unwrap(); - } - } - - #[test] - fn test_nested_nullability() { - let message_type = "message nested { - OPTIONAL Group group { - REQUIRED INT32 leaf; - } - }"; - - let file = tempfile::tempfile().unwrap(); - let schema = Arc::new(parse_message_type(message_type).unwrap()); - - { - // Write using low-level parquet API (#1167) - let writer_props = Arc::new(WriterProperties::builder().build()); - let mut writer = SerializedFileWriter::new( - file.try_clone().unwrap(), - schema, - writer_props, - ) - .unwrap(); - - { - let mut row_group_writer = writer.next_row_group().unwrap(); - let mut column_writer = row_group_writer.next_column().unwrap().unwrap(); - - column_writer - .typed::() - .write_batch(&[34, 76], Some(&[0, 1, 0, 1]), None) - .unwrap(); - - column_writer.close().unwrap(); - row_group_writer.close().unwrap(); - } - - writer.close().unwrap(); - } - - let mut reader = ParquetFileArrowReader::try_new(file).unwrap(); - let mask = ProjectionMask::leaves(reader.parquet_schema(), [0]); - - let reader = reader.get_record_reader_by_columns(mask, 1024).unwrap(); - - let expected_schema = Schema::new(vec![Field::new( - "group", - ArrowDataType::Struct(vec![Field::new("leaf", ArrowDataType::Int32, false)]), - true, - )]); - - let batch = reader.into_iter().next().unwrap().unwrap(); - assert_eq!(batch.schema().as_ref(), &expected_schema); - assert_eq!(batch.num_rows(), 4); - assert_eq!(batch.column(0).data().null_count(), 2); - } - - #[test] - fn test_invalid_utf8() { - // a parquet file with 1 column with invalid utf8 - let data = vec![ - 80, 65, 82, 49, 21, 6, 21, 22, 21, 22, 92, 21, 2, 21, 0, 21, 2, 21, 0, 21, 4, - 21, 0, 18, 28, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, - 108, 111, 0, 0, 0, 3, 1, 5, 0, 0, 0, 104, 101, 255, 108, 111, 38, 110, 28, - 21, 12, 25, 37, 6, 0, 25, 24, 2, 99, 49, 21, 0, 22, 2, 22, 102, 22, 102, 38, - 8, 60, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, 108, 111, - 0, 0, 0, 21, 4, 25, 44, 72, 4, 114, 111, 111, 116, 21, 2, 0, 21, 12, 37, 2, - 24, 2, 99, 49, 37, 
0, 76, 28, 0, 0, 0, 22, 2, 25, 28, 25, 28, 38, 110, 28, - 21, 12, 25, 37, 6, 0, 25, 24, 2, 99, 49, 21, 0, 22, 2, 22, 102, 22, 102, 38, - 8, 60, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, 108, 111, - 0, 0, 0, 22, 102, 22, 2, 0, 40, 44, 65, 114, 114, 111, 119, 50, 32, 45, 32, - 78, 97, 116, 105, 118, 101, 32, 82, 117, 115, 116, 32, 105, 109, 112, 108, - 101, 109, 101, 110, 116, 97, 116, 105, 111, 110, 32, 111, 102, 32, 65, 114, - 114, 111, 119, 0, 130, 0, 0, 0, 80, 65, 82, 49, - ]; - - let file = Bytes::from(data); - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - let mut record_batch_reader = arrow_reader - .get_record_reader_by_columns(ProjectionMask::all(), 10) - .unwrap(); - - let error = record_batch_reader.next().unwrap().unwrap_err(); - - assert!( - error.to_string().contains("invalid utf-8 sequence"), - "{}", - error - ); - } - - #[test] - fn test_dictionary_preservation() { - let mut fields = vec![Arc::new( - Type::primitive_type_builder("leaf", PhysicalType::BYTE_ARRAY) - .with_repetition(Repetition::OPTIONAL) - .with_converted_type(ConvertedType::UTF8) - .build() - .unwrap(), - )]; - - let schema = Arc::new( - Type::group_type_builder("test_schema") - .with_fields(&mut fields) - .build() - .unwrap(), - ); - - let dict_type = ArrowDataType::Dictionary( - Box::new(ArrowDataType::Int32), - Box::new(ArrowDataType::Utf8), - ); - - let arrow_field = Field::new("leaf", dict_type, true); - - let mut file = tempfile::tempfile().unwrap(); - - let values = vec![ - vec![ - ByteArray::from("hello"), - ByteArray::from("a"), - ByteArray::from("b"), - ByteArray::from("d"), - ], - vec![ - ByteArray::from("c"), - ByteArray::from("a"), - ByteArray::from("b"), - ], - ]; - - let def_levels = vec![ - vec![1, 0, 0, 1, 0, 0, 1, 1], - vec![0, 0, 1, 1, 0, 0, 1, 0, 0], - ]; - - let opts = TestOptions { - encoding: Encoding::RLE_DICTIONARY, - ..Default::default() - }; - - generate_single_column_file_with_data::( - &values, - Some(&def_levels), - file.try_clone().unwrap(), // Cannot use &mut File (#1163) - schema, - Some(arrow_field), - &opts, - ) - .unwrap(); - - file.rewind().unwrap(); - - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - - let record_reader = arrow_reader.get_record_reader(3).unwrap(); - - let batches = record_reader - .collect::>>() - .unwrap(); - - assert_eq!(batches.len(), 6); - assert!(batches.iter().all(|x| x.num_columns() == 1)); - - let row_counts = batches - .iter() - .map(|x| (x.num_rows(), x.column(0).null_count())) - .collect::>(); - - assert_eq!( - row_counts, - vec![(3, 2), (3, 2), (3, 1), (3, 1), (3, 2), (2, 2)] - ); - - let get_dict = - |batch: &RecordBatch| batch.column(0).data().child_data()[0].clone(); - - // First and second batch in same row group -> same dictionary - assert_eq!(get_dict(&batches[0]), get_dict(&batches[1])); - // Third batch spans row group -> computed dictionary - assert_ne!(get_dict(&batches[1]), get_dict(&batches[2])); - assert_ne!(get_dict(&batches[2]), get_dict(&batches[3])); - // Fourth, fifth and sixth from same row group -> same dictionary - assert_eq!(get_dict(&batches[3]), get_dict(&batches[4])); - assert_eq!(get_dict(&batches[4]), get_dict(&batches[5])); - } - - #[test] - fn test_read_null_list() { - let testdata = arrow::util::test_util::parquet_test_data(); - let path = format!("{}/null_list.parquet", testdata); - let file = File::open(&path).unwrap(); - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - let mut record_batch_reader = arrow_reader - 
.get_record_reader(60) - .expect("Failed to read into array!"); - - let batch = record_batch_reader.next().unwrap().unwrap(); - assert_eq!(batch.num_rows(), 1); - assert_eq!(batch.num_columns(), 1); - assert_eq!(batch.column(0).len(), 1); - - let list = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(list.len(), 1); - assert!(list.is_valid(0)); - - let val = list.value(0); - assert_eq!(val.len(), 0); - } - - #[test] - fn test_null_schema_inference() { - let testdata = arrow::util::test_util::parquet_test_data(); - let path = format!("{}/null_list.parquet", testdata); - let reader = - Arc::new(SerializedFileReader::try_from(File::open(&path).unwrap()).unwrap()); - - let arrow_field = Field::new( - "emptylist", - ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Null, true))), - true, - ); - - let options = ArrowReaderOptions::default().with_skip_arrow_metadata(true); - let mut arrow_reader = ParquetFileArrowReader::new_with_options(reader, options); - let schema = arrow_reader.get_schema().unwrap(); - assert_eq!(schema.fields().len(), 1); - assert_eq!(schema.field(0), &arrow_field); - } - - #[test] - fn test_skip_metadata() { - let col = Arc::new(TimestampNanosecondArray::from_iter_values(vec![0, 1, 2])); - let field = Field::new("col", col.data_type().clone(), true); - - let schema_without_metadata = Arc::new(Schema::new(vec![field.clone()])); - - let metadata = [("key".to_string(), "value".to_string())] - .into_iter() - .collect(); - - let schema_with_metadata = - Arc::new(Schema::new(vec![field.with_metadata(Some(metadata))])); - - assert_ne!(schema_with_metadata, schema_without_metadata); - - let batch = - RecordBatch::try_new(schema_with_metadata.clone(), vec![col as ArrayRef]) - .unwrap(); - - let file = |version: WriterVersion| { - let props = WriterProperties::builder() - .set_writer_version(version) - .build(); - - let file = tempfile().unwrap(); - let mut writer = ArrowWriter::try_new( - file.try_clone().unwrap(), - batch.schema(), - Some(props), - ) - .unwrap(); - writer.write(&batch).unwrap(); - writer.close().unwrap(); - file - }; - - let v1_reader = Arc::new( - SerializedFileReader::new(file(WriterVersion::PARQUET_1_0)).unwrap(), - ); - let v2_reader = Arc::new( - SerializedFileReader::new(file(WriterVersion::PARQUET_2_0)).unwrap(), - ); - - let mut arrow_reader = ParquetFileArrowReader::new(v1_reader.clone()); - assert_eq!( - &arrow_reader.get_schema().unwrap(), - schema_with_metadata.as_ref() - ); - - let options = ArrowReaderOptions::new().with_skip_arrow_metadata(true); - let mut arrow_reader = - ParquetFileArrowReader::new_with_options(v1_reader, options); - assert_eq!( - &arrow_reader.get_schema().unwrap(), - schema_without_metadata.as_ref() - ); - - let mut arrow_reader = ParquetFileArrowReader::new(v2_reader.clone()); - assert_eq!( - &arrow_reader.get_schema().unwrap(), - schema_with_metadata.as_ref() - ); - - let options = ArrowReaderOptions::new().with_skip_arrow_metadata(true); - let mut arrow_reader = - ParquetFileArrowReader::new_with_options(v2_reader, options); - assert_eq!( - &arrow_reader.get_schema().unwrap(), - schema_without_metadata.as_ref() - ); - } - - #[test] - fn test_empty_projection() { - let testdata = arrow::util::test_util::parquet_test_data(); - let path = format!("{}/alltypes_plain.parquet", testdata); - let file = File::open(&path).unwrap(); - - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - let file_metadata = arrow_reader.metadata().file_metadata(); - let expected_rows = 
file_metadata.num_rows() as usize; - let schema = file_metadata.schema_descr_ptr(); - - let mask = ProjectionMask::leaves(&schema, []); - let batch_reader = arrow_reader.get_record_reader_by_columns(mask, 2).unwrap(); - - let mut total_rows = 0; - for maybe_batch in batch_reader { - let batch = maybe_batch.unwrap(); - total_rows += batch.num_rows(); - assert_eq!(batch.num_columns(), 0); - assert!(batch.num_rows() <= 2); - } - - assert_eq!(total_rows, expected_rows); - } - - fn test_row_group_batch(row_group_size: usize, batch_size: usize) { - let schema = Arc::new(Schema::new(vec![Field::new( - "list", - ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), - true, - )])); - - let mut buf = Vec::with_capacity(1024); - - let mut writer = ArrowWriter::try_new( - &mut buf, - schema.clone(), - Some( - WriterProperties::builder() - .set_max_row_group_size(row_group_size) - .build(), - ), - ) - .unwrap(); - for _ in 0..2 { - let mut list_builder = ListBuilder::new(Int32Builder::new(batch_size)); - for _ in 0..(batch_size) { - list_builder.append(true).unwrap(); - } - let batch = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(list_builder.finish())], - ) - .unwrap(); - writer.write(&batch).unwrap(); - } - writer.close().unwrap(); - - let mut file_reader = ParquetFileArrowReader::try_new(Bytes::from(buf)).unwrap(); - let mut record_reader = file_reader.get_record_reader(batch_size).unwrap(); - assert_eq!( - batch_size, - record_reader.next().unwrap().unwrap().num_rows() - ); - assert_eq!( - batch_size, - record_reader.next().unwrap().unwrap().num_rows() - ); - } - - #[test] - fn test_row_group_exact_multiple() { - use crate::arrow::record_reader::MIN_BATCH_SIZE; - test_row_group_batch(8, 8); - test_row_group_batch(10, 8); - test_row_group_batch(8, 10); - test_row_group_batch(MIN_BATCH_SIZE, MIN_BATCH_SIZE); - test_row_group_batch(MIN_BATCH_SIZE + 1, MIN_BATCH_SIZE); - test_row_group_batch(MIN_BATCH_SIZE, MIN_BATCH_SIZE + 1); - test_row_group_batch(MIN_BATCH_SIZE, MIN_BATCH_SIZE - 1); - test_row_group_batch(MIN_BATCH_SIZE - 1, MIN_BATCH_SIZE); - } -} diff --git a/parquet/src/arrow/arrow_reader/filter.rs b/parquet/src/arrow/arrow_reader/filter.rs new file mode 100644 index 000000000000..8945ccde4248 --- /dev/null +++ b/parquet/src/arrow/arrow_reader/filter.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arrow::ProjectionMask; +use arrow::array::BooleanArray; +use arrow::error::Result as ArrowResult; +use arrow::record_batch::RecordBatch; + +/// A predicate operating on [`RecordBatch`] +pub trait ArrowPredicate: Send + 'static { + /// Returns the [`ProjectionMask`] that describes the columns required + /// to evaluate this predicate. 
All projected columns will be provided in the `batch` + /// passed to [`evaluate`](Self::evaluate) + fn projection(&self) -> &ProjectionMask; + + /// Evaluate this predicate for the given [`RecordBatch`] containing the columns + /// identified by [`Self::projection`] + /// + /// Rows that are `true` in the returned [`BooleanArray`] will be returned by the + /// parquet reader, whereas rows that are `false` or `Null` will not be + fn evaluate(&mut self, batch: RecordBatch) -> ArrowResult; +} + +/// An [`ArrowPredicate`] created from an [`FnMut`] +pub struct ArrowPredicateFn { + f: F, + projection: ProjectionMask, +} + +impl ArrowPredicateFn +where + F: FnMut(RecordBatch) -> ArrowResult + Send + 'static, +{ + /// Create a new [`ArrowPredicateFn`]. `f` will be passed batches + /// that contains the columns specified in `projection` + /// and returns a [`BooleanArray`] that describes which rows should + /// be passed along + pub fn new(projection: ProjectionMask, f: F) -> Self { + Self { f, projection } + } +} + +impl ArrowPredicate for ArrowPredicateFn +where + F: FnMut(RecordBatch) -> ArrowResult + Send + 'static, +{ + fn projection(&self) -> &ProjectionMask { + &self.projection + } + + fn evaluate(&mut self, batch: RecordBatch) -> ArrowResult { + (self.f)(batch) + } +} + +/// A [`RowFilter`] allows pushing down a filter predicate to skip IO and decode +/// +/// This consists of a list of [`ArrowPredicate`] where only the rows that satisfy all +/// of the predicates will be returned. Any [`RowSelection`] will be applied prior +/// to the first predicate, and each predicate in turn will then be used to compute +/// a more refined [`RowSelection`] to use when evaluating the subsequent predicates. +/// +/// Once all predicates have been evaluated, the final [`RowSelection`] is applied +/// to the top-level [`ProjectionMask`] to produce the final output [`RecordBatch`]. +/// +/// This design has a couple of implications: +/// +/// * [`RowFilter`] can be used to skip entire pages, and thus IO, in addition to CPU decode overheads +/// * Columns may be decoded multiple times if they appear in multiple [`ProjectionMask`] +/// * IO will be deferred until needed by a [`ProjectionMask`] +/// +/// As such there is a trade-off between a single large predicate, or multiple predicates, +/// that will depend on the shape of the data. Whilst multiple smaller predicates may +/// minimise the amount of data scanned/decoded, it may not be faster overall. +/// +/// For example, if a predicate that needs a single column of data filters out all but +/// 1% of the rows, applying it as one of the early `ArrowPredicateFn` will likely significantly +/// improve performance. +/// +/// As a counter example, if a predicate needs several columns of data to evaluate but +/// leaves 99% of the rows, it may be better to not filter the data from parquet and +/// apply the filter after the RecordBatch has been fully decoded. 
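As a rough usage sketch of the predicate pushdown contract described above (not a definitive implementation: the file name, column index, and threshold are hypothetical, and it assumes leaf column 0 is an `Int32` column), an `ArrowPredicateFn` can be combined into a `RowFilter` and handed to the synchronous builder that appears later in this module:

```rust
use std::fs::File;

use arrow::array::{Array, BooleanArray, Int32Array};
use parquet::arrow::arrow_reader::{ArrowPredicateFn, ParquetRecordBatchReaderBuilder, RowFilter};
use parquet::arrow::ProjectionMask;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("data.parquet")?; // hypothetical file path

    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;

    // The predicate only needs leaf column 0, so only that column is decoded while filtering
    let predicate_mask = ProjectionMask::leaves(builder.parquet_schema(), [0]);

    // Keep rows where the (assumed Int32) column is greater than 10
    let predicate = ArrowPredicateFn::new(predicate_mask, |batch| {
        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("column 0 is assumed to be Int32 in this sketch");
        Ok(BooleanArray::from_iter(
            col.iter().map(|v| Some(matches!(v, Some(x) if x > 10))),
        ))
    });

    // Predicates are evaluated in order; cheap, selective predicates should come first
    let filter = RowFilter::new(vec![Box::new(predicate)]);

    let reader = builder.with_row_filter(filter).build()?;
    for batch in reader {
        println!("{} rows passed the filter", batch?.num_rows());
    }
    Ok(())
}
```

Note the trade-off discussed above still applies: if the predicate is cheap and selective it saves IO and decode work, but an unselective predicate over many columns may be better applied after decoding.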
+/// +/// [`RowSelection`]: [super::selection::RowSelection] +pub struct RowFilter { + /// A list of [`ArrowPredicate`] + pub(crate) predicates: Vec>, +} + +impl RowFilter { + /// Create a new [`RowFilter`] from an array of [`ArrowPredicate`] + pub fn new(predicates: Vec>) -> Self { + Self { predicates } + } +} diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs new file mode 100644 index 000000000000..76e247ae1f1f --- /dev/null +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -0,0 +1,2310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains reader which reads parquet data into arrow [`RecordBatch`] + +use std::collections::VecDeque; +use std::sync::Arc; + +use arrow::array::Array; +use arrow::compute::prep_null_mask_filter; +use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef}; +use arrow::error::Result as ArrowResult; +use arrow::record_batch::{RecordBatch, RecordBatchReader}; +use arrow::{array::StructArray, error::ArrowError}; + +use crate::arrow::array_reader::{ + build_array_reader, ArrayReader, FileReaderRowGroupCollection, RowGroupCollection, +}; +use crate::arrow::schema::parquet_to_arrow_schema; +use crate::arrow::schema::parquet_to_arrow_schema_by_columns; +use crate::arrow::ProjectionMask; +use crate::errors::{ParquetError, Result}; +use crate::file::metadata::{KeyValue, ParquetMetaData}; +use crate::file::reader::{ChunkReader, FileReader, SerializedFileReader}; +use crate::file::serialized_reader::ReadOptionsBuilder; +use crate::schema::types::SchemaDescriptor; + +mod filter; +mod selection; + +pub use filter::{ArrowPredicate, ArrowPredicateFn, RowFilter}; +pub use selection::{RowSelection, RowSelector}; + +/// A generic builder for constructing sync or async arrow parquet readers. 
+/// This is not intended to be used directly; instead, use the specialization
+/// for the type of reader you wish to use:
+///
+/// * For a synchronous API - [`ParquetRecordBatchReaderBuilder`]
+/// * For an asynchronous API - [`ParquetRecordBatchStreamBuilder`]
+///
+/// [`ParquetRecordBatchStreamBuilder`]: [crate::arrow::async_reader::ParquetRecordBatchStreamBuilder]
+pub struct ArrowReaderBuilder<T> {
+    pub(crate) input: T,
+
+    pub(crate) metadata: Arc<ParquetMetaData>,
+
+    pub(crate) schema: SchemaRef,
+
+    pub(crate) batch_size: usize,
+
+    pub(crate) row_groups: Option<Vec<usize>>,
+
+    pub(crate) projection: ProjectionMask,
+
+    pub(crate) filter: Option<RowFilter>,
+
+    pub(crate) selection: Option<RowSelection>,
+}
+
+impl<T> ArrowReaderBuilder<T> {
+    pub(crate) fn new_builder(
+        input: T,
+        metadata: Arc<ParquetMetaData>,
+        options: ArrowReaderOptions,
+    ) -> Result<Self> {
+        let kv_metadata = match options.skip_arrow_metadata {
+            true => None,
+            false => metadata.file_metadata().key_value_metadata(),
+        };
+
+        let schema = Arc::new(parquet_to_arrow_schema(
+            metadata.file_metadata().schema_descr(),
+            kv_metadata,
+        )?);
+
+        Ok(Self {
+            input,
+            metadata,
+            schema,
+            batch_size: 1024,
+            row_groups: None,
+            projection: ProjectionMask::all(),
+            filter: None,
+            selection: None,
+        })
+    }
+
+    /// Returns a reference to the [`ParquetMetaData`] for this parquet file
+    pub fn metadata(&self) -> &Arc<ParquetMetaData> {
+        &self.metadata
+    }
+
+    /// Returns the parquet [`SchemaDescriptor`] for this parquet file
+    pub fn parquet_schema(&self) -> &SchemaDescriptor {
+        self.metadata.file_metadata().schema_descr()
+    }
+
+    /// Returns the arrow [`SchemaRef`] for this parquet file
+    pub fn schema(&self) -> &SchemaRef {
+        &self.schema
+    }
+
+    /// Set the size of [`RecordBatch`] to produce. Defaults to 1024.
+    /// If `batch_size` is larger than the number of rows in the file, the file's
+    /// row count is used instead.
+    pub fn with_batch_size(self, batch_size: usize) -> Self {
+        // Try to avoid allocating a large buffer
+        let batch_size =
+            batch_size.min(self.metadata.file_metadata().num_rows() as usize);
+        Self { batch_size, ..self }
+    }
+
+    /// Only read data from the provided row group indexes
+    pub fn with_row_groups(self, row_groups: Vec<usize>) -> Self {
+        Self {
+            row_groups: Some(row_groups),
+            ..self
+        }
+    }
+
+    /// Only read data from the provided column indexes
+    pub fn with_projection(self, mask: ProjectionMask) -> Self {
+        Self {
+            projection: mask,
+            ..self
+        }
+    }
+
+    /// Provide a [`RowSelection`] to filter out rows, and avoid fetching their
+    /// data into memory.
+    ///
+    /// Row group filtering is applied prior to this, and therefore rows from skipped
+    /// row groups should not be included in the [`RowSelection`]
+    ///
+    /// An example use case of this would be applying a selection determined by
+    /// evaluating predicates against the [`Index`]
+    ///
+    /// [`Index`]: [parquet::file::page_index::index::Index]
+    pub fn with_row_selection(self, selection: RowSelection) -> Self {
+        Self {
+            selection: Some(selection),
+            ..self
+        }
+    }
+
+    /// Provide a [`RowFilter`] to skip decoding rows
+    ///
+    /// Row filters are applied after row group selection and row selection
+    pub fn with_row_filter(self, filter: RowFilter) -> Self {
+        Self {
+            filter: Some(filter),
+            ..self
+        }
+    }
+}
+
+/// Arrow reader API.
+/// With this API, users can get the Arrow schema from a Parquet file and read
+/// Parquet data into Arrow arrays.
+#[deprecated(note = "Use ParquetRecordBatchReaderBuilder instead")]
+pub trait ArrowReader {
+    type RecordReader: RecordBatchReader;
+
+    /// Read parquet schema and convert it into arrow schema.
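A minimal sketch of how the builder methods above might be combined (the file name, column indexes, row group index, and row counts are hypothetical, and the `RowSelector::skip`/`select` constructors are assumed from the `selection` module re-exported earlier):

```rust
use std::fs::File;

use parquet::arrow::arrow_reader::{ParquetRecordBatchReaderBuilder, RowSelection, RowSelector};
use parquet::arrow::ProjectionMask;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("data.parquet")?; // hypothetical file path

    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;

    // Read only leaf columns 0 and 2 (hypothetical indexes)
    let mask = ProjectionMask::leaves(builder.parquet_schema(), [0, 2]);

    // Within the selected row groups, skip the first 100 rows and read the next 200
    let selection = RowSelection::from(vec![
        RowSelector::skip(100),
        RowSelector::select(200),
    ]);

    let reader = builder
        .with_batch_size(64)
        .with_row_groups(vec![0]) // only the first row group
        .with_projection(mask)
        .with_row_selection(selection)
        .build()?;

    for batch in reader {
        let batch = batch?;
        assert!(batch.num_rows() <= 64);
    }
    Ok(())
}
```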
+    fn get_schema(&mut self) -> Result<Schema>;
+
+    /// Read parquet schema and convert it into arrow schema.
+    /// This schema only includes columns identified by `mask`.
+    fn get_schema_by_columns(&mut self, mask: ProjectionMask) -> Result<Schema>;
+
+    /// Returns record batch reader from whole parquet file.
+    ///
+    /// # Arguments
+    ///
+    /// `batch_size`: The size of each record batch returned from this reader. Only the
+    /// last batch may contain fewer records than this size; otherwise record batches
+    /// returned from this reader should contain exactly `batch_size` elements.
+    fn get_record_reader(&mut self, batch_size: usize) -> Result<Self::RecordReader>;
+
+    /// Returns record batch reader whose record batches contain columns identified by
+    /// `mask`.
+    ///
+    /// # Arguments
+    ///
+    /// `mask`: The columns that should be included in record batches.
+    /// `batch_size`: Please refer to `get_record_reader`.
+    fn get_record_reader_by_columns(
+        &mut self,
+        mask: ProjectionMask,
+        batch_size: usize,
+    ) -> Result<Self::RecordReader>;
+}
+
+/// Options that control how metadata is read for a parquet file
+///
+/// See [`ArrowReaderBuilder`] for how to configure how the column data
+/// is then read from the file, including projection and filter pushdown
+#[derive(Debug, Clone, Default)]
+pub struct ArrowReaderOptions {
+    skip_arrow_metadata: bool,
+    pub(crate) page_index: bool,
+}
+
+impl ArrowReaderOptions {
+    /// Create a new [`ArrowReaderOptions`] with the default settings
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Parquet files generated by some writers may contain embedded arrow
+    /// schema and metadata. This may not be correct or compatible with your system.
+    ///
+    /// For example: [ARROW-16184](https://issues.apache.org/jira/browse/ARROW-16184)
+    ///
+    /// Set `skip_arrow_metadata` to true to skip decoding it
+    pub fn with_skip_arrow_metadata(self, skip_arrow_metadata: bool) -> Self {
+        Self {
+            skip_arrow_metadata,
+            ..self
+        }
+    }
+
+    /// Set this to true to enable decoding of the [PageIndex] if present.
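As a short sketch of how these options might be supplied when building a reader (the file name is hypothetical; the page index option is documented further below, where its doc comment continues):

```rust
use std::fs::File;

use parquet::arrow::arrow_reader::{ArrowReaderOptions, ParquetRecordBatchReaderBuilder};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("data.parquet")?; // hypothetical file path

    // Ignore any embedded arrow schema and decode the page index if it was written
    let options = ArrowReaderOptions::new()
        .with_skip_arrow_metadata(true)
        .with_page_index(true);

    let builder = ParquetRecordBatchReaderBuilder::try_new_with_options(file, options)?;

    // The arrow schema is now derived purely from the parquet schema
    println!("{:#?}", builder.schema());
    Ok(())
}
```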
This can be used + /// to push down predicates to the parquet scan, potentially eliminating unnecessary IO + /// + /// [PageIndex]: [https://github.com/apache/parquet-format/blob/master/PageIndex.md] + pub fn with_page_index(self, page_index: bool) -> Self { + Self { page_index, ..self } + } +} + +/// An `ArrowReader` that can be used to synchronously read parquet data as [`RecordBatch`] +/// +/// See [`crate::arrow::async_reader`] for an asynchronous interface +#[deprecated(note = "Use ParquetRecordBatchReaderBuilder instead")] +pub struct ParquetFileArrowReader { + file_reader: Arc, + + #[allow(deprecated)] + options: ArrowReaderOptions, +} + +#[allow(deprecated)] +impl ArrowReader for ParquetFileArrowReader { + type RecordReader = ParquetRecordBatchReader; + + fn get_schema(&mut self) -> Result { + let file_metadata = self.file_reader.metadata().file_metadata(); + parquet_to_arrow_schema(file_metadata.schema_descr(), self.get_kv_metadata()) + } + + fn get_schema_by_columns(&mut self, mask: ProjectionMask) -> Result { + let file_metadata = self.file_reader.metadata().file_metadata(); + parquet_to_arrow_schema_by_columns( + file_metadata.schema_descr(), + mask, + self.get_kv_metadata(), + ) + } + + fn get_record_reader( + &mut self, + batch_size: usize, + ) -> Result { + self.get_record_reader_by_columns(ProjectionMask::all(), batch_size) + } + + fn get_record_reader_by_columns( + &mut self, + mask: ProjectionMask, + batch_size: usize, + ) -> Result { + let array_reader = + build_array_reader(Arc::new(self.get_schema()?), mask, &self.file_reader)?; + + // Try to avoid allocate large buffer + let batch_size = self.file_reader.num_rows().min(batch_size); + Ok(ParquetRecordBatchReader::new( + batch_size, + array_reader, + None, + )) + } +} + +#[allow(deprecated)] +impl ParquetFileArrowReader { + /// Create a new [`ParquetFileArrowReader`] with the provided [`ChunkReader`] + /// + /// ```no_run + /// # use std::fs::File; + /// # use bytes::Bytes; + /// # use parquet::arrow::ParquetFileArrowReader; + /// + /// let file = File::open("file.parquet").unwrap(); + /// let reader = ParquetFileArrowReader::try_new(file).unwrap(); + /// + /// let bytes = Bytes::from(vec![]); + /// let reader = ParquetFileArrowReader::try_new(bytes).unwrap(); + /// ``` + pub fn try_new(chunk_reader: R) -> Result { + Self::try_new_with_options(chunk_reader, Default::default()) + } + + /// Create a new [`ParquetFileArrowReader`] with the provided [`ChunkReader`] + /// and [`ArrowReaderOptions`] + pub fn try_new_with_options( + chunk_reader: R, + options: ArrowReaderOptions, + ) -> Result { + let file_reader = Arc::new(SerializedFileReader::new(chunk_reader)?); + Ok(Self::new_with_options(file_reader, options)) + } + + /// Create a new [`ParquetFileArrowReader`] with the provided [`Arc`] + pub fn new(file_reader: Arc) -> Self { + Self::new_with_options(file_reader, Default::default()) + } + + /// Create a new [`ParquetFileArrowReader`] with the provided [`Arc`] + /// and [`ArrowReaderOptions`] + pub fn new_with_options( + file_reader: Arc, + options: ArrowReaderOptions, + ) -> Self { + Self { + file_reader, + options, + } + } + + /// Expose the reader metadata + #[deprecated = "use metadata() instead"] + pub fn get_metadata(&mut self) -> ParquetMetaData { + self.file_reader.metadata().clone() + } + + /// Returns the parquet metadata + pub fn metadata(&self) -> &ParquetMetaData { + self.file_reader.metadata() + } + + /// Returns the parquet schema + pub fn parquet_schema(&self) -> &SchemaDescriptor { + 
self.file_reader.metadata().file_metadata().schema_descr() + } + + /// Returns the key value metadata, returns `None` if [`ArrowReaderOptions::skip_arrow_metadata`] + fn get_kv_metadata(&self) -> Option<&Vec> { + if self.options.skip_arrow_metadata { + return None; + } + + self.file_reader + .metadata() + .file_metadata() + .key_value_metadata() + } +} + +#[doc(hidden)] +/// A newtype used within [`ReaderOptionsBuilder`] to distinguish sync readers from async +pub struct SyncReader(SerializedFileReader); + +/// A synchronous builder used to construct [`ParquetRecordBatchReader`] for a file +/// +/// For an async API see [`crate::arrow::async_reader::ParquetRecordBatchStreamBuilder`] +pub type ParquetRecordBatchReaderBuilder = ArrowReaderBuilder>; + +impl ArrowReaderBuilder> { + /// Create a new [`ParquetRecordBatchReaderBuilder`] + pub fn try_new(reader: T) -> Result { + Self::try_new_with_options(reader, Default::default()) + } + + /// Create a new [`ParquetRecordBatchReaderBuilder`] with [`ArrowReaderOptions`] + pub fn try_new_with_options(reader: T, options: ArrowReaderOptions) -> Result { + let reader = match options.page_index { + true => { + let read_options = ReadOptionsBuilder::new().with_page_index().build(); + SerializedFileReader::new_with_options(reader, read_options)? + } + false => SerializedFileReader::new(reader)?, + }; + + let metadata = Arc::clone(reader.metadata_ref()); + Self::new_builder(SyncReader(reader), metadata, options) + } + + /// Build a [`ParquetRecordBatchReader`] + /// + /// Note: this will eagerly evaluate any `RowFilter` before returning + pub fn build(self) -> Result { + let reader = + FileReaderRowGroupCollection::new(Arc::new(self.input.0), self.row_groups); + + let mut filter = self.filter; + let mut selection = self.selection; + + // Try to avoid allocate large buffer + let batch_size = self + .batch_size + .min(self.metadata.file_metadata().num_rows() as usize); + if let Some(filter) = filter.as_mut() { + for predicate in filter.predicates.iter_mut() { + if !selects_any(selection.as_ref()) { + break; + } + + let projection = predicate.projection().clone(); + let array_reader = + build_array_reader(Arc::clone(&self.schema), projection, &reader)?; + + selection = Some(evaluate_predicate( + batch_size, + array_reader, + selection, + predicate.as_mut(), + )?); + } + } + + let array_reader = build_array_reader(self.schema, self.projection, &reader)?; + + // If selection is empty, truncate + if !selects_any(selection.as_ref()) { + selection = Some(RowSelection::from(vec![])); + } + + Ok(ParquetRecordBatchReader::new( + batch_size, + array_reader, + selection, + )) + } +} + +/// An `Iterator>` that yields [`RecordBatch`] +/// read from a parquet data source +pub struct ParquetRecordBatchReader { + batch_size: usize, + array_reader: Box, + schema: SchemaRef, + selection: Option>, +} + +impl Iterator for ParquetRecordBatchReader { + type Item = ArrowResult; + + fn next(&mut self) -> Option { + let mut read_records = 0; + match self.selection.as_mut() { + Some(selection) => { + while read_records < self.batch_size && !selection.is_empty() { + let front = selection.pop_front().unwrap(); + if front.skip { + let skipped = + match self.array_reader.skip_records(front.row_count) { + Ok(skipped) => skipped, + Err(e) => return Some(Err(e.into())), + }; + + if skipped != front.row_count { + return Some(Err(general_err!( + "failed to skip rows, expected {}, got {}", + front.row_count, + skipped + ) + .into())); + } + continue; + } + + // try to read record + let 
need_read = self.batch_size - read_records; + let to_read = match front.row_count.checked_sub(need_read) { + Some(remaining) if remaining != 0 => { + // if page row count less than batch_size we must set batch size to page row count. + // add check avoid dead loop + selection.push_front(RowSelector::select(remaining)); + need_read + } + _ => front.row_count, + }; + match self.array_reader.read_records(to_read) { + Ok(0) => break, + Ok(rec) => read_records += rec, + Err(error) => return Some(Err(error.into())), + } + } + } + None => { + if let Err(error) = self.array_reader.read_records(self.batch_size) { + return Some(Err(error.into())); + } + } + }; + + match self.array_reader.consume_batch() { + Err(error) => Some(Err(error.into())), + Ok(array) => { + let struct_array = + array.as_any().downcast_ref::().ok_or_else(|| { + ArrowError::ParquetError( + "Struct array reader should return struct array".to_string(), + ) + }); + + match struct_array { + Err(err) => Some(Err(err)), + Ok(e) => (e.len() > 0).then(|| Ok(RecordBatch::from(e))), + } + } + } + } +} + +impl RecordBatchReader for ParquetRecordBatchReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl ParquetRecordBatchReader { + /// Create a new [`ParquetRecordBatchReader`] from the provided chunk reader + /// + /// See [`ParquetRecordBatchReaderBuilder`] for more options + pub fn try_new( + reader: T, + batch_size: usize, + ) -> Result { + ParquetRecordBatchReaderBuilder::try_new(reader)? + .with_batch_size(batch_size) + .build() + } + + /// Create a new [`ParquetRecordBatchReader`] that will read at most `batch_size` rows at + /// a time from [`ArrayReader`] based on the configured `selection`. If `selection` is `None` + /// all rows will be returned + pub(crate) fn new( + batch_size: usize, + array_reader: Box, + selection: Option, + ) -> Self { + let schema = match array_reader.get_data_type() { + ArrowType::Struct(ref fields) => Schema::new(fields.clone()), + _ => unreachable!("Struct array reader's data type is not struct!"), + }; + + Self { + batch_size, + array_reader, + schema: Arc::new(schema), + selection: selection.map(Into::into), + } + } +} + +/// Returns `true` if `selection` is `None` or selects some rows +pub(crate) fn selects_any(selection: Option<&RowSelection>) -> bool { + selection.map(|x| x.selects_any()).unwrap_or(true) +} + +/// Evaluates an [`ArrowPredicate`] returning the [`RowSelection`] +/// +/// If this [`ParquetRecordBatchReader`] has a [`RowSelection`], the +/// returned [`RowSelection`] will be the conjunction of this and +/// the rows selected by `predicate` +pub(crate) fn evaluate_predicate( + batch_size: usize, + array_reader: Box, + input_selection: Option, + predicate: &mut dyn ArrowPredicate, +) -> Result { + let reader = + ParquetRecordBatchReader::new(batch_size, array_reader, input_selection.clone()); + let mut filters = vec![]; + for maybe_batch in reader { + let filter = predicate.evaluate(maybe_batch?)?; + match filter.null_count() { + 0 => filters.push(filter), + _ => filters.push(prep_null_mask_filter(&filter)), + }; + } + + let raw = RowSelection::from_filters(&filters); + Ok(match input_selection { + Some(selection) => selection.and_then(&raw), + None => raw, + }) +} + +#[cfg(test)] +mod tests { + use std::cmp::min; + use std::collections::VecDeque; + use std::fmt::Formatter; + use std::fs::File; + use std::io::Seek; + use std::path::PathBuf; + use std::sync::Arc; + + use bytes::Bytes; + use rand::{thread_rng, Rng, RngCore}; + use tempfile::tempfile; + + use 
arrow::array::*; + use arrow::buffer::Buffer; + use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; + use arrow::error::Result as ArrowResult; + use arrow::record_batch::{RecordBatch, RecordBatchReader}; + + use crate::arrow::arrow_reader::{ + ArrowPredicateFn, ArrowReaderOptions, ParquetRecordBatchReader, + ParquetRecordBatchReaderBuilder, RowFilter, RowSelection, RowSelector, + }; + use crate::arrow::schema::add_encoded_arrow_schema_to_metadata; + use crate::arrow::{ArrowWriter, ProjectionMask}; + use crate::basic::{ConvertedType, Encoding, Repetition, Type as PhysicalType}; + use crate::data_type::{ + BoolType, ByteArray, ByteArrayType, DataType, FixedLenByteArray, + FixedLenByteArrayType, Int32Type, Int64Type, Int96Type, + }; + use crate::errors::Result; + use crate::file::properties::{EnabledStatistics, WriterProperties, WriterVersion}; + use crate::file::writer::SerializedFileWriter; + use crate::schema::parser::parse_message_type; + use crate::schema::types::{Type, TypePtr}; + use crate::util::test_common::rand_gen::RandGen; + + #[test] + fn test_arrow_reader_all_columns() { + let file = get_test_file("parquet/generated_simple_numerics/blogs.parquet"); + + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let original_schema = Arc::clone(builder.schema()); + let reader = builder.build().unwrap(); + + // Verify that the schema was correctly parsed + assert_eq!(original_schema.fields(), reader.schema().fields()); + } + + #[test] + fn test_arrow_reader_single_column() { + let file = get_test_file("parquet/generated_simple_numerics/blogs.parquet"); + + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let original_schema = Arc::clone(builder.schema()); + + let mask = ProjectionMask::leaves(builder.parquet_schema(), [2]); + let reader = builder.with_projection(mask).build().unwrap(); + + // Verify that the schema was correctly parsed + assert_eq!(1, reader.schema().fields().len()); + assert_eq!(original_schema.fields()[1], reader.schema().fields()[0]); + } + + #[test] + fn test_null_column_reader_test() { + let mut file = tempfile::tempfile().unwrap(); + + let schema = " + message message { + OPTIONAL INT32 int32; + } + "; + let schema = Arc::new(parse_message_type(schema).unwrap()); + + let def_levels = vec![vec![0, 0, 0], vec![0, 0, 0, 0]]; + generate_single_column_file_with_data::( + &[vec![], vec![]], + Some(&def_levels), + file.try_clone().unwrap(), // Cannot use &mut File (#1163) + schema, + Some(Field::new("int32", ArrowDataType::Null, true)), + &Default::default(), + ) + .unwrap(); + + file.rewind().unwrap(); + + let record_reader = ParquetRecordBatchReader::try_new(file, 2).unwrap(); + let batches = record_reader.collect::>>().unwrap(); + + assert_eq!(batches.len(), 4); + for batch in &batches[0..3] { + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.column(0).null_count(), 2); + } + + assert_eq!(batches[3].num_rows(), 1); + assert_eq!(batches[3].num_columns(), 1); + assert_eq!(batches[3].column(0).null_count(), 1); + } + + #[test] + fn test_primitive_single_column_reader_test() { + run_single_column_reader_tests::( + 2, + ConvertedType::NONE, + None, + |vals| Arc::new(BooleanArray::from_iter(vals.iter().cloned())), + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], + ); + run_single_column_reader_tests::( + 2, + ConvertedType::NONE, + None, + |vals| Arc::new(Int32Array::from_iter(vals.iter().cloned())), + &[ + Encoding::PLAIN, + Encoding::RLE_DICTIONARY, + 
Encoding::DELTA_BINARY_PACKED, + ], + ); + run_single_column_reader_tests::( + 2, + ConvertedType::NONE, + None, + |vals| Arc::new(Int64Array::from_iter(vals.iter().cloned())), + &[ + Encoding::PLAIN, + Encoding::RLE_DICTIONARY, + Encoding::DELTA_BINARY_PACKED, + ], + ); + } + + #[test] + fn test_unsigned_primitive_single_column_reader_test() { + run_single_column_reader_tests::( + 2, + ConvertedType::UINT_32, + Some(ArrowDataType::UInt32), + |vals| { + Arc::new(UInt32Array::from_iter( + vals.iter().map(|x| x.map(|x| x as u32)), + )) + }, + &[ + Encoding::PLAIN, + Encoding::RLE_DICTIONARY, + Encoding::DELTA_BINARY_PACKED, + ], + ); + run_single_column_reader_tests::( + 2, + ConvertedType::UINT_64, + Some(ArrowDataType::UInt64), + |vals| { + Arc::new(UInt64Array::from_iter( + vals.iter().map(|x| x.map(|x| x as u64)), + )) + }, + &[ + Encoding::PLAIN, + Encoding::RLE_DICTIONARY, + Encoding::DELTA_BINARY_PACKED, + ], + ); + } + + #[test] + fn test_unsigned_roundtrip() { + let schema = Arc::new(Schema::new(vec![ + Field::new("uint32", ArrowDataType::UInt32, true), + Field::new("uint64", ArrowDataType::UInt64, true), + ])); + + let mut buf = Vec::with_capacity(1024); + let mut writer = ArrowWriter::try_new(&mut buf, schema.clone(), None).unwrap(); + + let original = RecordBatch::try_new( + schema, + vec![ + Arc::new(UInt32Array::from_iter_values([ + 0, + i32::MAX as u32, + u32::MAX, + ])), + Arc::new(UInt64Array::from_iter_values([ + 0, + i64::MAX as u64, + u64::MAX, + ])), + ], + ) + .unwrap(); + + writer.write(&original).unwrap(); + writer.close().unwrap(); + + let mut reader = + ParquetRecordBatchReader::try_new(Bytes::from(buf), 1024).unwrap(); + let ret = reader.next().unwrap().unwrap(); + assert_eq!(ret, original); + + // Check they can be downcast to the correct type + ret.column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + ret.column(1) + .as_any() + .downcast_ref::() + .unwrap(); + } + + struct RandFixedLenGen {} + + impl RandGen for RandFixedLenGen { + fn gen(len: i32) -> FixedLenByteArray { + let mut v = vec![0u8; len as usize]; + thread_rng().fill_bytes(&mut v); + ByteArray::from(v).into() + } + } + + #[test] + fn test_fixed_length_binary_column_reader() { + run_single_column_reader_tests::( + 20, + ConvertedType::NONE, + None, + |vals| { + let mut builder = FixedSizeBinaryBuilder::with_capacity(vals.len(), 20); + for val in vals { + match val { + Some(b) => builder.append_value(b).unwrap(), + None => builder.append_null(), + } + } + Arc::new(builder.finish()) + }, + &[Encoding::PLAIN, Encoding::RLE_DICTIONARY], + ); + } + + #[test] + fn test_interval_day_time_column_reader() { + run_single_column_reader_tests::( + 12, + ConvertedType::INTERVAL, + None, + |vals| { + Arc::new( + vals.iter() + .map(|x| { + x.as_ref().map(|b| { + i64::from_le_bytes(b.as_ref()[4..12].try_into().unwrap()) + }) + }) + .collect::(), + ) + }, + &[Encoding::PLAIN, Encoding::RLE_DICTIONARY], + ); + } + + #[test] + fn test_int96_single_column_reader_test() { + let encodings = &[Encoding::PLAIN, Encoding::RLE_DICTIONARY]; + run_single_column_reader_tests::( + 2, + ConvertedType::NONE, + None, + |vals| { + Arc::new(TimestampNanosecondArray::from_iter( + vals.iter().map(|x| x.map(|x| x.to_nanos())), + )) as _ + }, + encodings, + ); + } + + struct RandUtf8Gen {} + + impl RandGen for RandUtf8Gen { + fn gen(len: i32) -> ByteArray { + Int32Type::gen(len).to_string().as_str().into() + } + } + + #[test] + fn test_utf8_single_column_reader_test() { + fn string_converter(vals: &[Option]) -> ArrayRef { + 
Arc::new(GenericStringArray::::from_iter(vals.iter().map(|x| { + x.as_ref().map(|b| std::str::from_utf8(b.data()).unwrap()) + }))) + } + + let encodings = &[ + Encoding::PLAIN, + Encoding::RLE_DICTIONARY, + Encoding::DELTA_LENGTH_BYTE_ARRAY, + Encoding::DELTA_BYTE_ARRAY, + ]; + + run_single_column_reader_tests::( + 2, + ConvertedType::NONE, + None, + |vals| { + Arc::new(BinaryArray::from_iter( + vals.iter().map(|x| x.as_ref().map(|x| x.data())), + )) + }, + encodings, + ); + + run_single_column_reader_tests::( + 2, + ConvertedType::UTF8, + None, + string_converter::, + encodings, + ); + + run_single_column_reader_tests::( + 2, + ConvertedType::UTF8, + Some(ArrowDataType::Utf8), + string_converter::, + encodings, + ); + + run_single_column_reader_tests::( + 2, + ConvertedType::UTF8, + Some(ArrowDataType::LargeUtf8), + string_converter::, + encodings, + ); + + let small_key_types = [ArrowDataType::Int8, ArrowDataType::UInt8]; + for key in &small_key_types { + for encoding in encodings { + let mut opts = TestOptions::new(2, 20, 15).with_null_percent(50); + opts.encoding = *encoding; + + let data_type = ArrowDataType::Dictionary( + Box::new(key.clone()), + Box::new(ArrowDataType::Utf8), + ); + + // Cannot run full test suite as keys overflow, run small test instead + single_column_reader_test::( + opts, + 2, + ConvertedType::UTF8, + Some(data_type.clone()), + move |vals| { + let vals = string_converter::(vals); + arrow::compute::cast(&vals, &data_type).unwrap() + }, + ); + } + } + + let key_types = [ + ArrowDataType::Int16, + ArrowDataType::UInt16, + ArrowDataType::Int32, + ArrowDataType::UInt32, + ArrowDataType::Int64, + ArrowDataType::UInt64, + ]; + + for key in &key_types { + let data_type = ArrowDataType::Dictionary( + Box::new(key.clone()), + Box::new(ArrowDataType::Utf8), + ); + + run_single_column_reader_tests::( + 2, + ConvertedType::UTF8, + Some(data_type.clone()), + move |vals| { + let vals = string_converter::(vals); + arrow::compute::cast(&vals, &data_type).unwrap() + }, + encodings, + ); + + // https://github.com/apache/arrow-rs/issues/1179 + // let data_type = ArrowDataType::Dictionary( + // Box::new(key.clone()), + // Box::new(ArrowDataType::LargeUtf8), + // ); + // + // run_single_column_reader_tests::( + // 2, + // ConvertedType::UTF8, + // Some(data_type.clone()), + // move |vals| { + // let vals = string_converter::(vals); + // arrow::compute::cast(&vals, &data_type).unwrap() + // }, + // encodings, + // ); + } + } + + #[test] + fn test_decimal_nullable_struct() { + let decimals = Decimal128Array::from_iter_values([1, 2, 3, 4, 5, 6, 7, 8]); + + let data = ArrayDataBuilder::new(ArrowDataType::Struct(vec![Field::new( + "decimals", + decimals.data_type().clone(), + false, + )])) + .len(8) + .null_bit_buffer(Some(Buffer::from(&[0b11101111]))) + .child_data(vec![decimals.into_data()]) + .build() + .unwrap(); + + let written = RecordBatch::try_from_iter([( + "struct", + Arc::new(StructArray::from(data)) as ArrayRef, + )]) + .unwrap(); + + let mut buffer = Vec::with_capacity(1024); + let mut writer = + ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); + writer.write(&written).unwrap(); + writer.close().unwrap(); + + let read = ParquetRecordBatchReader::try_new(Bytes::from(buffer), 3) + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(&written.slice(0, 3), &read[0]); + assert_eq!(&written.slice(3, 3), &read[1]); + assert_eq!(&written.slice(6, 2), &read[2]); + } + + #[test] + fn test_int32_nullable_struct() { + let int32 = Int32Array::from_iter_values([1, 2, 
3, 4, 5, 6, 7, 8]); + let data = ArrayDataBuilder::new(ArrowDataType::Struct(vec![Field::new( + "int32", + int32.data_type().clone(), + false, + )])) + .len(8) + .null_bit_buffer(Some(Buffer::from(&[0b11101111]))) + .child_data(vec![int32.into_data()]) + .build() + .unwrap(); + + let written = RecordBatch::try_from_iter([( + "struct", + Arc::new(StructArray::from(data)) as ArrayRef, + )]) + .unwrap(); + + let mut buffer = Vec::with_capacity(1024); + let mut writer = + ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); + writer.write(&written).unwrap(); + writer.close().unwrap(); + + let read = ParquetRecordBatchReader::try_new(Bytes::from(buffer), 3) + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(&written.slice(0, 3), &read[0]); + assert_eq!(&written.slice(3, 3), &read[1]); + assert_eq!(&written.slice(6, 2), &read[2]); + } + + #[test] + #[ignore] // https://github.com/apache/arrow-rs/issues/2253 + fn test_decimal_list() { + let decimals = Decimal128Array::from_iter_values([1, 2, 3, 4, 5, 6, 7, 8]); + + // [[], [1], [2, 3], null, [4], null, [6, 7, 8]] + let data = ArrayDataBuilder::new(ArrowDataType::List(Box::new(Field::new( + "item", + decimals.data_type().clone(), + false, + )))) + .len(7) + .add_buffer(Buffer::from_iter([0_i32, 0, 1, 3, 3, 4, 5, 8])) + .null_bit_buffer(Some(Buffer::from(&[0b01010111]))) + .child_data(vec![decimals.into_data()]) + .build() + .unwrap(); + + let written = RecordBatch::try_from_iter([( + "list", + Arc::new(ListArray::from(data)) as ArrayRef, + )]) + .unwrap(); + + let mut buffer = Vec::with_capacity(1024); + let mut writer = + ArrowWriter::try_new(&mut buffer, written.schema(), None).unwrap(); + writer.write(&written).unwrap(); + writer.close().unwrap(); + + let read = ParquetRecordBatchReader::try_new(Bytes::from(buffer), 3) + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(&written.slice(0, 3), &read[0]); + assert_eq!(&written.slice(3, 3), &read[1]); + assert_eq!(&written.slice(6, 1), &read[2]); + } + + #[test] + fn test_read_decimal_file() { + use arrow::array::Decimal128Array; + let testdata = arrow::util::test_util::parquet_test_data(); + let file_variants = vec![ + ("byte_array", 4), + ("fixed_length", 25), + ("int32", 4), + ("int64", 10), + ]; + for (prefix, target_precision) in file_variants { + let path = format!("{}/{}_decimal.parquet", testdata, prefix); + let file = File::open(&path).unwrap(); + let mut record_reader = ParquetRecordBatchReader::try_new(file, 32).unwrap(); + + let batch = record_reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 24); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + let expected = 1..25; + + assert_eq!(col.precision(), target_precision); + assert_eq!(col.scale(), 2); + + for (i, v) in expected.enumerate() { + assert_eq!(col.value(i).as_i128(), v * 100_i128); + } + } + } + + /// Parameters for single_column_reader_test + #[derive(Clone)] + struct TestOptions { + /// Number of row group to write to parquet (row group size = + /// num_row_groups / num_rows) + num_row_groups: usize, + /// Total number of rows per row group + num_rows: usize, + /// Size of batches to read back + record_batch_size: usize, + /// Percentage of nulls in column or None if required + null_percent: Option, + /// Set write batch size + /// + /// This is the number of rows that are written at once to a page and + /// therefore acts as a bound on the page granularity of a row group + write_batch_size: usize, + /// Maximum size of page in bytes + max_data_page_size: 
usize, + /// Maximum size of dictionary page in bytes + max_dict_page_size: usize, + /// Writer version + writer_version: WriterVersion, + /// Enabled statistics + enabled_statistics: EnabledStatistics, + /// Encoding + encoding: Encoding, + /// row selections and total selected row count + row_selections: Option<(RowSelection, usize)>, + /// row filter + row_filter: Option>, + } + + /// Manually implement this to avoid printing entire contents of row_selections and row_filter + impl std::fmt::Debug for TestOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TestOptions") + .field("num_row_groups", &self.num_row_groups) + .field("num_rows", &self.num_rows) + .field("record_batch_size", &self.record_batch_size) + .field("null_percent", &self.null_percent) + .field("write_batch_size", &self.write_batch_size) + .field("max_data_page_size", &self.max_data_page_size) + .field("max_dict_page_size", &self.max_dict_page_size) + .field("writer_version", &self.writer_version) + .field("enabled_statistics", &self.enabled_statistics) + .field("encoding", &self.encoding) + .field("row_selections", &self.row_selections.is_some()) + .field("row_filter", &self.row_filter.is_some()) + .finish() + } + } + + impl Default for TestOptions { + fn default() -> Self { + Self { + num_row_groups: 2, + num_rows: 100, + record_batch_size: 15, + null_percent: None, + write_batch_size: 64, + max_data_page_size: 1024 * 1024, + max_dict_page_size: 1024 * 1024, + writer_version: WriterVersion::PARQUET_1_0, + enabled_statistics: EnabledStatistics::Page, + encoding: Encoding::PLAIN, + row_selections: None, + row_filter: None, + } + } + } + + impl TestOptions { + fn new(num_row_groups: usize, num_rows: usize, record_batch_size: usize) -> Self { + Self { + num_row_groups, + num_rows, + record_batch_size, + ..Default::default() + } + } + + fn with_null_percent(self, null_percent: usize) -> Self { + Self { + null_percent: Some(null_percent), + ..self + } + } + + fn with_max_data_page_size(self, max_data_page_size: usize) -> Self { + Self { + max_data_page_size, + ..self + } + } + + fn with_max_dict_page_size(self, max_dict_page_size: usize) -> Self { + Self { + max_dict_page_size, + ..self + } + } + + fn with_enabled_statistics(self, enabled_statistics: EnabledStatistics) -> Self { + Self { + enabled_statistics, + ..self + } + } + + fn with_row_selections(self) -> Self { + assert!(self.row_filter.is_none(), "Must set row selection first"); + + let mut rng = thread_rng(); + let step = rng.gen_range(self.record_batch_size..self.num_rows); + let row_selections = create_test_selection( + step, + self.num_row_groups * self.num_rows, + rng.gen::(), + ); + Self { + row_selections: Some(row_selections), + ..self + } + } + + fn with_row_filter(self) -> Self { + let row_count = match &self.row_selections { + Some((_, count)) => *count, + None => self.num_row_groups * self.num_rows, + }; + + let mut rng = thread_rng(); + Self { + row_filter: Some((0..row_count).map(|_| rng.gen_bool(0.9)).collect()), + ..self + } + } + + fn writer_props(&self) -> WriterProperties { + let builder = WriterProperties::builder() + .set_data_pagesize_limit(self.max_data_page_size) + .set_write_batch_size(self.write_batch_size) + .set_writer_version(self.writer_version) + .set_statistics_enabled(self.enabled_statistics); + + let builder = match self.encoding { + Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => builder + .set_dictionary_enabled(true) + .set_dictionary_pagesize_limit(self.max_dict_page_size), + _ => 
builder + .set_dictionary_enabled(false) + .set_encoding(self.encoding), + }; + + builder.build() + } + } + + /// Create a parquet file and then read it using + /// `ParquetFileArrowReader` using a standard set of parameters + /// `opts`. + /// + /// `rand_max` represents the maximum size of value to pass to the + /// value generator + fn run_single_column_reader_tests( + rand_max: i32, + converted_type: ConvertedType, + arrow_type: Option, + converter: F, + encodings: &[Encoding], + ) where + T: DataType, + G: RandGen, + F: Fn(&[Option]) -> ArrayRef, + { + let all_options = vec![ + // choose record_batch_size (15) so batches cross row + // group boundaries (50 rows in 2 row groups). + TestOptions::new(2, 100, 15), + // choose record_batch_size (5) so batches sometimes fall + // on row group boundaries (25 rows in 3 row groups + // --> row groups of 10, 10, and 5). Tests buffer + // refilling edge cases. + TestOptions::new(3, 25, 5), + // Choose record_batch_size (25) so all batches fall + // exactly on row group boundary (25). Tests buffer + // refilling edge cases. + TestOptions::new(4, 100, 25), + // Set maximum page size so row groups have multiple pages + TestOptions::new(3, 256, 73).with_max_data_page_size(128), + // Set small dictionary page size to test dictionary fallback + TestOptions::new(3, 256, 57).with_max_dict_page_size(128), + // Test optional but with no nulls + TestOptions::new(2, 256, 127).with_null_percent(0), + // Test optional with nulls + TestOptions::new(2, 256, 93).with_null_percent(25), + // Test with no page-level statistics + TestOptions::new(2, 256, 91) + .with_null_percent(25) + .with_enabled_statistics(EnabledStatistics::Chunk), + // Test with no statistics + TestOptions::new(2, 256, 91) + .with_null_percent(25) + .with_enabled_statistics(EnabledStatistics::None), + // Test with all null + TestOptions::new(2, 128, 91) + .with_null_percent(100) + .with_enabled_statistics(EnabledStatistics::None), + // Test skip + + // choose record_batch_size (15) so batches cross row + // group boundaries (50 rows in 2 row groups). + TestOptions::new(2, 100, 15).with_row_selections(), + // choose record_batch_size (5) so batches sometimes fall + // on row group boundaries (25 rows in 3 row groups + // --> row groups of 10, 10, and 5). Tests buffer + // refilling edge cases. + TestOptions::new(3, 25, 5).with_row_selections(), + // Choose record_batch_size (25) so all batches fall + // exactly on row group boundary (25). Tests buffer + // refilling edge cases. 
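+ // (as above, but additionally exercised through a random row selection)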
+ TestOptions::new(4, 100, 25).with_row_selections(), + // Set maximum page size so row groups have multiple pages + TestOptions::new(3, 256, 73) + .with_max_data_page_size(128) + .with_row_selections(), + // Set small dictionary page size to test dictionary fallback + TestOptions::new(3, 256, 57) + .with_max_dict_page_size(128) + .with_row_selections(), + // Test optional but with no nulls + TestOptions::new(2, 256, 127) + .with_null_percent(0) + .with_row_selections(), + // Test optional with nulls + TestOptions::new(2, 256, 93) + .with_null_percent(25) + .with_row_selections(), + // Test filter + + // Test with row filter + TestOptions::new(4, 100, 25).with_row_filter(), + // Test with row selection and row filter + TestOptions::new(4, 100, 25) + .with_row_selections() + .with_row_filter(), + // Test with nulls and row filter + TestOptions::new(2, 256, 93) + .with_null_percent(25) + .with_max_data_page_size(10) + .with_row_filter(), + // Test with nulls and row filter and small pages + TestOptions::new(2, 256, 93) + .with_null_percent(25) + .with_max_data_page_size(10) + .with_row_selections() + .with_row_filter(), + // Test with row selection and no offset index and small pages + TestOptions::new(2, 256, 93) + .with_enabled_statistics(EnabledStatistics::None) + .with_max_data_page_size(10) + .with_row_selections(), + ]; + + all_options.into_iter().for_each(|opts| { + for writer_version in [WriterVersion::PARQUET_1_0, WriterVersion::PARQUET_2_0] + { + for encoding in encodings { + let opts = TestOptions { + writer_version, + encoding: *encoding, + ..opts.clone() + }; + + single_column_reader_test::( + opts, + rand_max, + converted_type, + arrow_type.clone(), + &converter, + ) + } + } + }); + } + + /// Create a parquet file and then read it using + /// `ParquetFileArrowReader` using the parameters described in + /// `opts`. 
+ fn single_column_reader_test( + opts: TestOptions, + rand_max: i32, + converted_type: ConvertedType, + arrow_type: Option, + converter: F, + ) where + T: DataType, + G: RandGen, + F: Fn(&[Option]) -> ArrayRef, + { + // Print out options to facilitate debugging failures on CI + println!( + "Running type {:?} single_column_reader_test ConvertedType::{}/ArrowType::{:?} with Options: {:?}", + T::get_physical_type(), converted_type, arrow_type, opts + ); + + //according to null_percent generate def_levels + let (repetition, def_levels) = match opts.null_percent.as_ref() { + Some(null_percent) => { + let mut rng = thread_rng(); + + let def_levels: Vec> = (0..opts.num_row_groups) + .map(|_| { + std::iter::from_fn(|| { + Some((rng.next_u32() as usize % 100 >= *null_percent) as i16) + }) + .take(opts.num_rows) + .collect() + }) + .collect(); + (Repetition::OPTIONAL, Some(def_levels)) + } + None => (Repetition::REQUIRED, None), + }; + + //generate random table data + let values: Vec> = (0..opts.num_row_groups) + .map(|idx| { + let null_count = match def_levels.as_ref() { + Some(d) => d[idx].iter().filter(|x| **x == 0).count(), + None => 0, + }; + G::gen_vec(rand_max, opts.num_rows - null_count) + }) + .collect(); + + let len = match T::get_physical_type() { + crate::basic::Type::FIXED_LEN_BYTE_ARRAY => rand_max, + crate::basic::Type::INT96 => 12, + _ => -1, + }; + + let mut fields = vec![Arc::new( + Type::primitive_type_builder("leaf", T::get_physical_type()) + .with_repetition(repetition) + .with_converted_type(converted_type) + .with_length(len) + .build() + .unwrap(), + )]; + + let schema = Arc::new( + Type::group_type_builder("test_schema") + .with_fields(&mut fields) + .build() + .unwrap(), + ); + + let arrow_field = arrow_type.map(|t| Field::new("leaf", t, false)); + + let mut file = tempfile::tempfile().unwrap(); + + generate_single_column_file_with_data::( + &values, + def_levels.as_ref(), + file.try_clone().unwrap(), // Cannot use &mut File (#1163) + schema, + arrow_field, + &opts, + ) + .unwrap(); + + file.rewind().unwrap(); + + let options = ArrowReaderOptions::new() + .with_page_index(opts.enabled_statistics == EnabledStatistics::Page); + + let mut builder = + ParquetRecordBatchReaderBuilder::try_new_with_options(file, options).unwrap(); + + let expected_data = match opts.row_selections { + Some((selections, row_count)) => { + let mut without_skip_data = gen_expected_data::(&def_levels, &values); + + let mut skip_data: Vec> = vec![]; + let dequeue: VecDeque = selections.clone().into(); + for select in dequeue { + if select.skip { + without_skip_data.drain(0..select.row_count); + } else { + skip_data.extend(without_skip_data.drain(0..select.row_count)); + } + } + builder = builder.with_row_selection(selections); + + assert_eq!(skip_data.len(), row_count); + skip_data + } + None => { + //get flatten table data + let expected_data = gen_expected_data::(&def_levels, &values); + assert_eq!(expected_data.len(), opts.num_rows * opts.num_row_groups); + expected_data + } + }; + + let expected_data = match opts.row_filter { + Some(filter) => { + let expected_data = expected_data + .into_iter() + .zip(filter.iter()) + .filter_map(|(d, f)| f.then(|| d)) + .collect(); + + let mut filter_offset = 0; + let filter = RowFilter::new(vec![Box::new(ArrowPredicateFn::new( + ProjectionMask::all(), + move |b| { + let array = BooleanArray::from_iter( + filter + .iter() + .skip(filter_offset) + .take(b.num_rows()) + .map(|x| Some(*x)), + ); + filter_offset += b.num_rows(); + Ok(array) + }, + ))]); + + builder 
= builder.with_row_filter(filter); + expected_data + } + None => expected_data, + }; + + let mut record_reader = builder + .with_batch_size(opts.record_batch_size) + .build() + .unwrap(); + + let mut total_read = 0; + loop { + let maybe_batch = record_reader.next(); + if total_read < expected_data.len() { + let end = min(total_read + opts.record_batch_size, expected_data.len()); + let batch = maybe_batch.unwrap().unwrap(); + assert_eq!(end - total_read, batch.num_rows()); + + let a = converter(&expected_data[total_read..end]); + let b = Arc::clone(batch.column(0)); + + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.data(), b.data(), "{:#?} vs {:#?}", a.data(), b.data()); + assert_eq!( + a.as_any().type_id(), + b.as_any().type_id(), + "incorrect type ids" + ); + + total_read = end; + } else { + assert!(maybe_batch.is_none()); + break; + } + } + } + + fn gen_expected_data( + def_levels: &Option>>, + values: &[Vec], + ) -> Vec> { + let data: Vec> = match def_levels { + Some(levels) => { + let mut values_iter = values.iter().flatten(); + levels + .iter() + .flatten() + .map(|d| match d { + 1 => Some(values_iter.next().cloned().unwrap()), + 0 => None, + _ => unreachable!(), + }) + .collect() + } + None => values.iter().flatten().map(|b| Some(b.clone())).collect(), + }; + data + } + + fn generate_single_column_file_with_data( + values: &[Vec], + def_levels: Option<&Vec>>, + file: File, + schema: TypePtr, + field: Option, + opts: &TestOptions, + ) -> Result { + let mut writer_props = opts.writer_props(); + if let Some(field) = field { + let arrow_schema = Schema::new(vec![field]); + add_encoded_arrow_schema_to_metadata(&arrow_schema, &mut writer_props); + } + + let mut writer = SerializedFileWriter::new(file, schema, Arc::new(writer_props))?; + + for (idx, v) in values.iter().enumerate() { + let def_levels = def_levels.map(|d| d[idx].as_slice()); + let mut row_group_writer = writer.next_row_group()?; + { + let mut column_writer = row_group_writer + .next_column()? 
+ .expect("Column writer is none!"); + + column_writer + .typed::() + .write_batch(v, def_levels, None)?; + + column_writer.close()?; + } + row_group_writer.close()?; + } + + writer.close() + } + + fn get_test_file(file_name: &str) -> File { + let mut path = PathBuf::new(); + path.push(arrow::util::test_util::arrow_test_data()); + path.push(file_name); + + File::open(path.as_path()).expect("File not found!") + } + + #[test] + fn test_read_structs() { + // This particular test file has columns of struct types where there is + // a column that has the same name as one of the struct fields + // (see: ARROW-11452) + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/nested_structs.rust.parquet", testdata); + let file = File::open(&path).unwrap(); + let record_batch_reader = ParquetRecordBatchReader::try_new(file, 60).unwrap(); + + for batch in record_batch_reader { + batch.unwrap(); + } + + let file = File::open(&path).unwrap(); + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + + let mask = ProjectionMask::leaves(builder.parquet_schema(), [3, 8, 10]); + let projected_reader = builder + .with_projection(mask) + .with_batch_size(60) + .build() + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new( + "roll_num", + ArrowDataType::Struct(vec![Field::new( + "count", + ArrowDataType::UInt64, + false, + )]), + false, + ), + Field::new( + "PC_CUR", + ArrowDataType::Struct(vec![ + Field::new("mean", ArrowDataType::Int64, false), + Field::new("sum", ArrowDataType::Int64, false), + ]), + false, + ), + ]); + + // Tests for #1652 and #1654 + assert_eq!(&expected_schema, projected_reader.schema().as_ref()); + + for batch in projected_reader { + let batch = batch.unwrap(); + assert_eq!(batch.schema().as_ref(), &expected_schema); + } + } + + #[test] + fn test_read_maps() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/nested_maps.snappy.parquet", testdata); + let file = File::open(&path).unwrap(); + let record_batch_reader = ParquetRecordBatchReader::try_new(file, 60).unwrap(); + + for batch in record_batch_reader { + batch.unwrap(); + } + } + + #[test] + fn test_nested_nullability() { + let message_type = "message nested { + OPTIONAL Group group { + REQUIRED INT32 leaf; + } + }"; + + let file = tempfile::tempfile().unwrap(); + let schema = Arc::new(parse_message_type(message_type).unwrap()); + + { + // Write using low-level parquet API (#1167) + let writer_props = Arc::new(WriterProperties::builder().build()); + let mut writer = SerializedFileWriter::new( + file.try_clone().unwrap(), + schema, + writer_props, + ) + .unwrap(); + + { + let mut row_group_writer = writer.next_row_group().unwrap(); + let mut column_writer = row_group_writer.next_column().unwrap().unwrap(); + + column_writer + .typed::() + .write_batch(&[34, 76], Some(&[0, 1, 0, 1]), None) + .unwrap(); + + column_writer.close().unwrap(); + row_group_writer.close().unwrap(); + } + + writer.close().unwrap(); + } + + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let mask = ProjectionMask::leaves(builder.parquet_schema(), [0]); + + let reader = builder.with_projection(mask).build().unwrap(); + + let expected_schema = Schema::new(vec![Field::new( + "group", + ArrowDataType::Struct(vec![Field::new("leaf", ArrowDataType::Int32, false)]), + true, + )]); + + let batch = reader.into_iter().next().unwrap().unwrap(); + assert_eq!(batch.schema().as_ref(), &expected_schema); + assert_eq!(batch.num_rows(), 4); + 
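// the def_level 0 entries written above make `group` null in two of the four rows + 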
assert_eq!(batch.column(0).data().null_count(), 2); + } + + #[test] + fn test_invalid_utf8() { + // a parquet file with 1 column with invalid utf8 + let data = vec![ + 80, 65, 82, 49, 21, 6, 21, 22, 21, 22, 92, 21, 2, 21, 0, 21, 2, 21, 0, 21, 4, + 21, 0, 18, 28, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, + 108, 111, 0, 0, 0, 3, 1, 5, 0, 0, 0, 104, 101, 255, 108, 111, 38, 110, 28, + 21, 12, 25, 37, 6, 0, 25, 24, 2, 99, 49, 21, 0, 22, 2, 22, 102, 22, 102, 38, + 8, 60, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, 108, 111, + 0, 0, 0, 21, 4, 25, 44, 72, 4, 114, 111, 111, 116, 21, 2, 0, 21, 12, 37, 2, + 24, 2, 99, 49, 37, 0, 76, 28, 0, 0, 0, 22, 2, 25, 28, 25, 28, 38, 110, 28, + 21, 12, 25, 37, 6, 0, 25, 24, 2, 99, 49, 21, 0, 22, 2, 22, 102, 22, 102, 38, + 8, 60, 54, 0, 40, 5, 104, 101, 255, 108, 111, 24, 5, 104, 101, 255, 108, 111, + 0, 0, 0, 22, 102, 22, 2, 0, 40, 44, 65, 114, 114, 111, 119, 50, 32, 45, 32, + 78, 97, 116, 105, 118, 101, 32, 82, 117, 115, 116, 32, 105, 109, 112, 108, + 101, 109, 101, 110, 116, 97, 116, 105, 111, 110, 32, 111, 102, 32, 65, 114, + 114, 111, 119, 0, 130, 0, 0, 0, 80, 65, 82, 49, + ]; + + let file = Bytes::from(data); + let mut record_batch_reader = + ParquetRecordBatchReader::try_new(file, 10).unwrap(); + + let error = record_batch_reader.next().unwrap().unwrap_err(); + + assert!( + error.to_string().contains("invalid utf-8 sequence"), + "{}", + error + ); + } + + #[test] + fn test_dictionary_preservation() { + let mut fields = vec![Arc::new( + Type::primitive_type_builder("leaf", PhysicalType::BYTE_ARRAY) + .with_repetition(Repetition::OPTIONAL) + .with_converted_type(ConvertedType::UTF8) + .build() + .unwrap(), + )]; + + let schema = Arc::new( + Type::group_type_builder("test_schema") + .with_fields(&mut fields) + .build() + .unwrap(), + ); + + let dict_type = ArrowDataType::Dictionary( + Box::new(ArrowDataType::Int32), + Box::new(ArrowDataType::Utf8), + ); + + let arrow_field = Field::new("leaf", dict_type, true); + + let mut file = tempfile::tempfile().unwrap(); + + let values = vec![ + vec![ + ByteArray::from("hello"), + ByteArray::from("a"), + ByteArray::from("b"), + ByteArray::from("d"), + ], + vec![ + ByteArray::from("c"), + ByteArray::from("a"), + ByteArray::from("b"), + ], + ]; + + let def_levels = vec![ + vec![1, 0, 0, 1, 0, 0, 1, 1], + vec![0, 0, 1, 1, 0, 0, 1, 0, 0], + ]; + + let opts = TestOptions { + encoding: Encoding::RLE_DICTIONARY, + ..Default::default() + }; + + generate_single_column_file_with_data::( + &values, + Some(&def_levels), + file.try_clone().unwrap(), // Cannot use &mut File (#1163) + schema, + Some(arrow_field), + &opts, + ) + .unwrap(); + + file.rewind().unwrap(); + + let record_reader = ParquetRecordBatchReader::try_new(file, 3).unwrap(); + + let batches = record_reader + .collect::>>() + .unwrap(); + + assert_eq!(batches.len(), 6); + assert!(batches.iter().all(|x| x.num_columns() == 1)); + + let row_counts = batches + .iter() + .map(|x| (x.num_rows(), x.column(0).null_count())) + .collect::>(); + + assert_eq!( + row_counts, + vec![(3, 2), (3, 2), (3, 1), (3, 1), (3, 2), (2, 2)] + ); + + let get_dict = + |batch: &RecordBatch| batch.column(0).data().child_data()[0].clone(); + + // First and second batch in same row group -> same dictionary + assert_eq!(get_dict(&batches[0]), get_dict(&batches[1])); + // Third batch spans row group -> computed dictionary + assert_ne!(get_dict(&batches[1]), get_dict(&batches[2])); + assert_ne!(get_dict(&batches[2]), get_dict(&batches[3])); + // Fourth, fifth and 
sixth from same row group -> same dictionary + assert_eq!(get_dict(&batches[3]), get_dict(&batches[4])); + assert_eq!(get_dict(&batches[4]), get_dict(&batches[5])); + } + + #[test] + fn test_read_null_list() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/null_list.parquet", testdata); + let file = File::open(&path).unwrap(); + let mut record_batch_reader = + ParquetRecordBatchReader::try_new(file, 60).unwrap(); + + let batch = record_batch_reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.column(0).len(), 1); + + let list = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(list.len(), 1); + assert!(list.is_valid(0)); + + let val = list.value(0); + assert_eq!(val.len(), 0); + } + + #[test] + fn test_null_schema_inference() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/null_list.parquet", testdata); + let file = File::open(&path).unwrap(); + + let arrow_field = Field::new( + "emptylist", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Null, true))), + true, + ); + + let options = ArrowReaderOptions::new().with_skip_arrow_metadata(true); + let builder = + ParquetRecordBatchReaderBuilder::try_new_with_options(file, options).unwrap(); + let schema = builder.schema(); + assert_eq!(schema.fields().len(), 1); + assert_eq!(schema.field(0), &arrow_field); + } + + #[test] + fn test_skip_metadata() { + let col = Arc::new(TimestampNanosecondArray::from_iter_values(vec![0, 1, 2])); + let field = Field::new("col", col.data_type().clone(), true); + + let schema_without_metadata = Arc::new(Schema::new(vec![field.clone()])); + + let metadata = [("key".to_string(), "value".to_string())] + .into_iter() + .collect(); + + let schema_with_metadata = + Arc::new(Schema::new(vec![field.with_metadata(Some(metadata))])); + + assert_ne!(schema_with_metadata, schema_without_metadata); + + let batch = + RecordBatch::try_new(schema_with_metadata.clone(), vec![col as ArrayRef]) + .unwrap(); + + let file = |version: WriterVersion| { + let props = WriterProperties::builder() + .set_writer_version(version) + .build(); + + let file = tempfile().unwrap(); + let mut writer = ArrowWriter::try_new( + file.try_clone().unwrap(), + batch.schema(), + Some(props), + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); + file + }; + + let skip_options = ArrowReaderOptions::new().with_skip_arrow_metadata(true); + + let v1_reader = file(WriterVersion::PARQUET_1_0); + let v2_reader = file(WriterVersion::PARQUET_2_0); + + let arrow_reader = + ParquetRecordBatchReader::try_new(v1_reader.try_clone().unwrap(), 1024) + .unwrap(); + assert_eq!(arrow_reader.schema(), schema_with_metadata); + + let reader = ParquetRecordBatchReaderBuilder::try_new_with_options( + v1_reader, + skip_options.clone(), + ) + .unwrap() + .build() + .unwrap(); + assert_eq!(reader.schema(), schema_without_metadata); + + let arrow_reader = + ParquetRecordBatchReader::try_new(v2_reader.try_clone().unwrap(), 1024) + .unwrap(); + assert_eq!(arrow_reader.schema(), schema_with_metadata); + + let reader = ParquetRecordBatchReaderBuilder::try_new_with_options( + v2_reader, + skip_options, + ) + .unwrap() + .build() + .unwrap(); + assert_eq!(reader.schema(), schema_without_metadata); + } + + #[test] + fn test_empty_projection() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/alltypes_plain.parquet", testdata); + let 
file = File::open(&path).unwrap(); + + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let file_metadata = builder.metadata().file_metadata(); + let expected_rows = file_metadata.num_rows() as usize; + + let mask = ProjectionMask::leaves(builder.parquet_schema(), []); + let batch_reader = builder + .with_projection(mask) + .with_batch_size(2) + .build() + .unwrap(); + + let mut total_rows = 0; + for maybe_batch in batch_reader { + let batch = maybe_batch.unwrap(); + total_rows += batch.num_rows(); + assert_eq!(batch.num_columns(), 0); + assert!(batch.num_rows() <= 2); + } + + assert_eq!(total_rows, expected_rows); + } + + fn test_row_group_batch(row_group_size: usize, batch_size: usize) { + let schema = Arc::new(Schema::new(vec![Field::new( + "list", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + true, + )])); + + let mut buf = Vec::with_capacity(1024); + + let mut writer = ArrowWriter::try_new( + &mut buf, + schema.clone(), + Some( + WriterProperties::builder() + .set_max_row_group_size(row_group_size) + .build(), + ), + ) + .unwrap(); + for _ in 0..2 { + let mut list_builder = + ListBuilder::new(Int32Builder::with_capacity(batch_size)); + for _ in 0..(batch_size) { + list_builder.append(true); + } + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(list_builder.finish())], + ) + .unwrap(); + writer.write(&batch).unwrap(); + } + writer.close().unwrap(); + + let mut record_reader = + ParquetRecordBatchReader::try_new(Bytes::from(buf), batch_size).unwrap(); + assert_eq!( + batch_size, + record_reader.next().unwrap().unwrap().num_rows() + ); + assert_eq!( + batch_size, + record_reader.next().unwrap().unwrap().num_rows() + ); + } + + #[test] + fn test_row_group_exact_multiple() { + use crate::arrow::record_reader::MIN_BATCH_SIZE; + test_row_group_batch(8, 8); + test_row_group_batch(10, 8); + test_row_group_batch(8, 10); + test_row_group_batch(MIN_BATCH_SIZE, MIN_BATCH_SIZE); + test_row_group_batch(MIN_BATCH_SIZE + 1, MIN_BATCH_SIZE); + test_row_group_batch(MIN_BATCH_SIZE, MIN_BATCH_SIZE + 1); + test_row_group_batch(MIN_BATCH_SIZE, MIN_BATCH_SIZE - 1); + test_row_group_batch(MIN_BATCH_SIZE - 1, MIN_BATCH_SIZE); + } + + /// Given a RecordBatch containing all the column data, return the expected batches given + /// a `batch_size` and `selection` + fn get_expected_batches( + column: &RecordBatch, + selection: &RowSelection, + batch_size: usize, + ) -> Vec { + let mut expected_batches = vec![]; + + let mut selection: VecDeque<_> = selection.clone().into(); + let mut row_offset = 0; + let mut last_start = None; + while row_offset < column.num_rows() && !selection.is_empty() { + let mut batch_remaining = batch_size.min(column.num_rows() - row_offset); + while batch_remaining > 0 && !selection.is_empty() { + let (to_read, skip) = match selection.front_mut() { + Some(selection) if selection.row_count > batch_remaining => { + selection.row_count -= batch_remaining; + (batch_remaining, selection.skip) + } + Some(_) => { + let select = selection.pop_front().unwrap(); + (select.row_count, select.skip) + } + None => break, + }; + + batch_remaining -= to_read; + + match skip { + true => { + if let Some(last_start) = last_start.take() { + expected_batches + .push(column.slice(last_start, row_offset - last_start)) + } + row_offset += to_read + } + false => { + last_start.get_or_insert(row_offset); + row_offset += to_read + } + } + } + } + + if let Some(last_start) = last_start.take() { + 
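// push any trailing run of selected rows; inside the loop a run is only flushed when a skip follows it + 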
expected_batches.push(column.slice(last_start, row_offset - last_start)) + } + + // Sanity check, all batches except the final should be the batch size + for batch in &expected_batches[..expected_batches.len() - 1] { + assert_eq!(batch.num_rows(), batch_size); + } + + expected_batches + } + + fn create_test_selection( + step_len: usize, + total_len: usize, + skip_first: bool, + ) -> (RowSelection, usize) { + let mut remaining = total_len; + let mut skip = skip_first; + let mut vec = vec![]; + let mut selected_count = 0; + while remaining != 0 { + let step = if remaining > step_len { + step_len + } else { + remaining + }; + vec.push(RowSelector { + row_count: step, + skip, + }); + remaining -= step; + if !skip { + selected_count += step; + } + skip = !skip; + } + (vec.into(), selected_count) + } + + #[test] + fn test_scan_row_with_selection() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/alltypes_tiny_pages_plain.parquet", testdata); + let test_file = File::open(&path).unwrap(); + + let mut serial_reader = + ParquetRecordBatchReader::try_new(File::open(path).unwrap(), 7300).unwrap(); + let data = serial_reader.next().unwrap().unwrap(); + + let do_test = |batch_size: usize, selection_len: usize| { + for skip_first in [false, true] { + let selections = + create_test_selection(batch_size, data.num_rows(), skip_first).0; + + let expected = get_expected_batches(&data, &selections, batch_size); + let skip_reader = create_skip_reader(&test_file, batch_size, selections); + assert_eq!( + skip_reader.collect::>>().unwrap(), + expected, + "batch_size: {}, selection_len: {}, skip_first: {}", + batch_size, + selection_len, + skip_first + ); + } + }; + + // total row count 7300 + // 1. test selection len more than one page row count + do_test(1000, 1000); + + // 2. test selection len less than one page row count + do_test(20, 20); + + // 3. test selection_len less than batch_size + do_test(20, 5); + + // 4. test selection_len more than batch_size + // If batch_size < selection_len + do_test(20, 5); + + fn create_skip_reader( + test_file: &File, + batch_size: usize, + selections: RowSelection, + ) -> ParquetRecordBatchReader { + let options = ArrowReaderOptions::new().with_page_index(true); + let file = test_file.try_clone().unwrap(); + ParquetRecordBatchReaderBuilder::try_new_with_options(file, options) + .unwrap() + .with_batch_size(batch_size) + .with_row_selection(selections) + .build() + .unwrap() + } + } + + #[test] + fn test_batch_size_overallocate() { + let testdata = arrow::util::test_util::parquet_test_data(); + // `alltypes_plain.parquet` only have 8 rows + let path = format!("{}/alltypes_plain.parquet", testdata); + let test_file = File::open(&path).unwrap(); + + let builder = ParquetRecordBatchReaderBuilder::try_new(test_file).unwrap(); + let num_rows = builder.metadata.file_metadata().num_rows(); + let reader = builder + .with_batch_size(1024) + .with_projection(ProjectionMask::all()) + .build() + .unwrap(); + assert_ne!(1024, num_rows); + assert_eq!(reader.batch_size, num_rows as usize); + } +} diff --git a/parquet/src/arrow/arrow_reader/selection.rs b/parquet/src/arrow/arrow_reader/selection.rs new file mode 100644 index 000000000000..544b7931a265 --- /dev/null +++ b/parquet/src/arrow/arrow_reader/selection.rs @@ -0,0 +1,618 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, BooleanArray}; +use arrow::compute::SlicesIterator; +use std::cmp::Ordering; +use std::collections::VecDeque; +use std::ops::Range; + +/// [`RowSelection`] is a collection of [`RowSelector`] used to skip rows when +/// scanning a parquet file +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct RowSelector { + /// The number of rows + pub row_count: usize, + + /// If true, skip `row_count` rows + pub skip: bool, +} + +impl RowSelector { + /// Select `row_count` rows + pub fn select(row_count: usize) -> Self { + Self { + row_count, + skip: false, + } + } + + /// Skip `row_count` rows + pub fn skip(row_count: usize) -> Self { + Self { + row_count, + skip: true, + } + } +} + +/// [`RowSelection`] allows selecting or skipping a provided number of rows +/// when scanning the parquet file. +/// +/// This is applied prior to reading column data, and can therefore +/// be used to skip IO to fetch data into memory +/// +/// A typical use-case would be using the [`PageIndex`] to filter out rows +/// that don't satisfy a predicate +/// +/// [`PageIndex`]: [crate::file::page_index::index::PageIndex] +#[derive(Debug, Clone, Default, Eq, PartialEq)] +pub struct RowSelection { + selectors: Vec, +} + +impl RowSelection { + /// Creates a [`RowSelection`] from a slice of [`BooleanArray`] + /// + /// # Panic + /// + /// Panics if any of the [`BooleanArray`] contain nulls + pub fn from_filters(filters: &[BooleanArray]) -> Self { + let mut next_offset = 0; + let total_rows = filters.iter().map(|x| x.len()).sum(); + + let iter = filters.iter().flat_map(|filter| { + let offset = next_offset; + next_offset += filter.len(); + assert_eq!(filter.null_count(), 0); + SlicesIterator::new(filter) + .map(move |(start, end)| start + offset..end + offset) + }); + + Self::from_consecutive_ranges(iter, total_rows) + } + + /// Creates a [`RowSelection`] from an iterator of consecutive ranges to keep + fn from_consecutive_ranges>>( + ranges: I, + total_rows: usize, + ) -> Self { + let mut selectors: Vec = Vec::with_capacity(ranges.size_hint().0); + let mut last_end = 0; + for range in ranges { + let len = range.end - range.start; + + match range.start.cmp(&last_end) { + Ordering::Equal => match selectors.last_mut() { + Some(last) => last.row_count += len, + None => selectors.push(RowSelector::select(len)), + }, + Ordering::Greater => { + selectors.push(RowSelector::skip(range.start - last_end)); + selectors.push(RowSelector::select(len)) + } + Ordering::Less => panic!("out of order"), + } + last_end = range.end; + } + + if last_end != total_rows { + selectors.push(RowSelector::skip(total_rows - last_end)) + } + + Self { selectors } + } + + /// Given an offset index, return the offset ranges for all data pages selected by `self` + #[cfg(any(test, feature = "async"))] + pub(crate) fn scan_ranges( + &self, + page_locations: &[parquet_format::PageLocation], + ) -> Vec> { + let mut ranges = vec![]; + let mut row_offset 
= 0; + + let mut pages = page_locations.iter().peekable(); + let mut selectors = self.selectors.iter().cloned(); + let mut current_selector = selectors.next(); + let mut current_page = pages.next(); + + let mut current_page_included = false; + + while let Some((selector, page)) = current_selector.as_mut().zip(current_page) { + if !(selector.skip || current_page_included) { + let start = page.offset as usize; + let end = start + page.compressed_page_size as usize; + ranges.push(start..end); + current_page_included = true; + } + + if let Some(next_page) = pages.peek() { + if row_offset + selector.row_count > next_page.first_row_index as usize { + let remaining_in_page = + next_page.first_row_index as usize - row_offset; + selector.row_count -= remaining_in_page; + row_offset += remaining_in_page; + current_page = pages.next(); + current_page_included = false; + + continue; + } else { + if row_offset + selector.row_count + == next_page.first_row_index as usize + { + current_page = pages.next(); + current_page_included = false; + } + row_offset += selector.row_count; + current_selector = selectors.next(); + } + } else { + if !(selector.skip || current_page_included) { + let start = page.offset as usize; + let end = start + page.compressed_page_size as usize; + ranges.push(start..end); + } + current_selector = selectors.next() + } + } + + ranges + } + + /// Splits off the first `row_count` from this [`RowSelection`] + pub fn split_off(&mut self, row_count: usize) -> Self { + let mut total_count = 0; + + // Find the index where the selector exceeds the row count + let find = self.selectors.iter().enumerate().find(|(_, selector)| { + total_count += selector.row_count; + total_count > row_count + }); + + let split_idx = match find { + Some((idx, _)) => idx, + None => { + let selectors = std::mem::take(&mut self.selectors); + return Self { selectors }; + } + }; + + let mut remaining = self.selectors.split_off(split_idx); + + // Always present as `split_idx < self.selectors.len` + let next = remaining.first_mut().unwrap(); + let overflow = total_count - row_count; + + if next.row_count != overflow { + self.selectors.push(RowSelector { + row_count: next.row_count - overflow, + skip: next.skip, + }) + } + next.row_count = overflow; + + std::mem::swap(&mut remaining, &mut self.selectors); + Self { + selectors: remaining, + } + } + + /// Given a [`RowSelection`] computed under `self`, returns the [`RowSelection`] + /// representing their conjunction + /// + /// For example: + /// + /// self: NNNNNNNNNNNNYYYYYYYYYYYYYYYYYYYYYYNNNYYYYY + /// other: YYYYYNNNNYYYYYYYYYYYYY YYNNN + /// + /// returned: NNNNNNNNNNNNYYYYYNNNNYYYYYYYYYYYYYNNNYYNNN + /// + /// + pub fn and_then(&self, other: &Self) -> Self { + let mut selectors = vec![]; + let mut first = self.selectors.iter().cloned().peekable(); + let mut second = other.selectors.iter().cloned().peekable(); + + let mut to_skip = 0; + while let Some(b) = second.peek_mut() { + let a = first.peek_mut().unwrap(); + + if b.row_count == 0 { + second.next().unwrap(); + continue; + } + + if a.row_count == 0 { + first.next().unwrap(); + continue; + } + + if a.skip { + // Records were skipped when producing second + to_skip += a.row_count; + first.next().unwrap(); + continue; + } + + let skip = b.skip; + let to_process = a.row_count.min(b.row_count); + + a.row_count -= to_process; + b.row_count -= to_process; + + match skip { + true => to_skip += to_process, + false => { + if to_skip != 0 { + selectors.push(RowSelector::skip(to_skip)); + to_skip = 0; + } + 
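// the accumulated skip (if any) was emitted above, so record the selected run + 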
selectors.push(RowSelector::select(to_process)) + } + } + } + + for v in first { + if v.row_count != 0 { + assert!(v.skip); + to_skip += v.row_count + } + } + + if to_skip != 0 { + selectors.push(RowSelector::skip(to_skip)); + } + + Self { selectors } + } + + /// Returns `true` if this [`RowSelection`] selects any rows + pub fn selects_any(&self) -> bool { + self.selectors.iter().any(|x| !x.skip) + } +} + +impl From> for RowSelection { + fn from(selectors: Vec) -> Self { + Self { selectors } + } +} + +impl From for VecDeque { + fn from(r: RowSelection) -> Self { + r.selectors.into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use parquet_format::PageLocation; + use rand::{thread_rng, Rng}; + + #[test] + fn test_from_filters() { + let filters = vec![ + BooleanArray::from(vec![false, false, false, true, true, true, true]), + BooleanArray::from(vec![true, true, false, false, true, true, true]), + BooleanArray::from(vec![false, false, false, false]), + BooleanArray::from(Vec::::new()), + ]; + + let selection = RowSelection::from_filters(&filters[..1]); + assert!(selection.selects_any()); + assert_eq!( + selection.selectors, + vec![RowSelector::skip(3), RowSelector::select(4)] + ); + + let selection = RowSelection::from_filters(&filters[..2]); + assert!(selection.selects_any()); + assert_eq!( + selection.selectors, + vec![ + RowSelector::skip(3), + RowSelector::select(6), + RowSelector::skip(2), + RowSelector::select(3) + ] + ); + + let selection = RowSelection::from_filters(&filters); + assert!(selection.selects_any()); + assert_eq!( + selection.selectors, + vec![ + RowSelector::skip(3), + RowSelector::select(6), + RowSelector::skip(2), + RowSelector::select(3), + RowSelector::skip(4) + ] + ); + + let selection = RowSelection::from_filters(&filters[2..3]); + assert!(!selection.selects_any()); + assert_eq!(selection.selectors, vec![RowSelector::skip(4)]); + } + + #[test] + fn test_split_off() { + let mut selection = RowSelection::from(vec![ + RowSelector::skip(34), + RowSelector::select(12), + RowSelector::skip(3), + RowSelector::select(35), + ]); + + let split = selection.split_off(34); + assert_eq!(split.selectors, vec![RowSelector::skip(34)]); + assert_eq!( + selection.selectors, + vec![ + RowSelector::select(12), + RowSelector::skip(3), + RowSelector::select(35) + ] + ); + + let split = selection.split_off(5); + assert_eq!(split.selectors, vec![RowSelector::select(5)]); + assert_eq!( + selection.selectors, + vec![ + RowSelector::select(7), + RowSelector::skip(3), + RowSelector::select(35) + ] + ); + + let split = selection.split_off(8); + assert_eq!( + split.selectors, + vec![RowSelector::select(7), RowSelector::skip(1)] + ); + assert_eq!( + selection.selectors, + vec![RowSelector::skip(2), RowSelector::select(35)] + ); + + let split = selection.split_off(200); + assert_eq!( + split.selectors, + vec![RowSelector::skip(2), RowSelector::select(35)] + ); + assert!(selection.selectors.is_empty()); + } + + #[test] + fn test_and() { + let mut a = RowSelection::from(vec![ + RowSelector::skip(12), + RowSelector::select(23), + RowSelector::skip(3), + RowSelector::select(5), + ]); + + let b = RowSelection::from(vec![ + RowSelector::select(5), + RowSelector::skip(4), + RowSelector::select(15), + RowSelector::skip(4), + ]); + + let mut expected = RowSelection::from(vec![ + RowSelector::skip(12), + RowSelector::select(5), + RowSelector::skip(4), + RowSelector::select(14), + RowSelector::skip(3), + RowSelector::select(1), + RowSelector::skip(4), + ]); + + assert_eq!(a.and_then(&b), expected); + 
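+ // dropping the first 7 rows (all skipped by `a`) from both sides leaves the conjunction unchanged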
+ a.split_off(7); + expected.split_off(7); + assert_eq!(a.and_then(&b), expected); + + let a = RowSelection::from(vec![RowSelector::select(5), RowSelector::skip(3)]); + + let b = RowSelection::from(vec![ + RowSelector::select(2), + RowSelector::skip(1), + RowSelector::select(1), + RowSelector::skip(1), + ]); + + assert_eq!( + a.and_then(&b).selectors, + vec![ + RowSelector::select(2), + RowSelector::skip(1), + RowSelector::select(1), + RowSelector::skip(4) + ] + ); + } + + #[test] + fn test_and_fuzz() { + let mut rand = thread_rng(); + for _ in 0..100 { + let a_len = rand.gen_range(10..100); + let a_bools: Vec<_> = (0..a_len).map(|_| rand.gen_bool(0.2)).collect(); + let a = RowSelection::from_filters(&[BooleanArray::from(a_bools.clone())]); + + let b_len: usize = a_bools.iter().map(|x| *x as usize).sum(); + let b_bools: Vec<_> = (0..b_len).map(|_| rand.gen_bool(0.8)).collect(); + let b = RowSelection::from_filters(&[BooleanArray::from(b_bools.clone())]); + + let mut expected_bools = vec![false; a_len]; + + let mut iter_b = b_bools.iter(); + for (idx, b) in a_bools.iter().enumerate() { + if *b && *iter_b.next().unwrap() { + expected_bools[idx] = true; + } + } + + let expected = + RowSelection::from_filters(&[BooleanArray::from(expected_bools)]); + + let total_rows: usize = expected.selectors.iter().map(|s| s.row_count).sum(); + assert_eq!(a_len, total_rows); + + assert_eq!(a.and_then(&b), expected); + } + } + + #[test] + fn test_scan_ranges() { + let index = vec![ + PageLocation { + offset: 0, + compressed_page_size: 10, + first_row_index: 0, + }, + PageLocation { + offset: 10, + compressed_page_size: 10, + first_row_index: 10, + }, + PageLocation { + offset: 20, + compressed_page_size: 10, + first_row_index: 20, + }, + PageLocation { + offset: 30, + compressed_page_size: 10, + first_row_index: 30, + }, + PageLocation { + offset: 40, + compressed_page_size: 10, + first_row_index: 40, + }, + PageLocation { + offset: 50, + compressed_page_size: 10, + first_row_index: 50, + }, + PageLocation { + offset: 60, + compressed_page_size: 10, + first_row_index: 60, + }, + ]; + + let selection = RowSelection::from(vec![ + // Skip first page + RowSelector::skip(10), + // Multiple selects in same page + RowSelector::select(3), + RowSelector::skip(3), + RowSelector::select(4), + // Select to page boundary + RowSelector::skip(5), + RowSelector::select(5), + // Skip full page past page boundary + RowSelector::skip(12), + // Select across page boundaries + RowSelector::select(12), + // Skip final page + RowSelector::skip(12), + ]); + + let ranges = selection.scan_ranges(&index); + + // assert_eq!(mask, vec![false, true, true, false, true, true, false]); + assert_eq!(ranges, vec![10..20, 20..30, 40..50, 50..60]); + + let selection = RowSelection::from(vec![ + // Skip first page + RowSelector::skip(10), + // Multiple selects in same page + RowSelector::select(3), + RowSelector::skip(3), + RowSelector::select(4), + // Select to page boundary + RowSelector::skip(5), + RowSelector::select(5), + // Skip full page past page boundary + RowSelector::skip(12), + // Select across page boundaries + RowSelector::select(12), + RowSelector::skip(1), + // Select across page boundaries including final page + RowSelector::select(8), + ]); + + let ranges = selection.scan_ranges(&index); + + // assert_eq!(mask, vec![false, true, true, false, true, true, true]); + assert_eq!(ranges, vec![10..20, 20..30, 40..50, 50..60, 60..70]); + + let selection = RowSelection::from(vec![ + // Skip first page + RowSelector::skip(10), + // 
Multiple selects in same page + RowSelector::select(3), + RowSelector::skip(3), + RowSelector::select(4), + // Select to page boundary + RowSelector::skip(5), + RowSelector::select(5), + // Skip full page past page boundary + RowSelector::skip(12), + // Select to final page bounday + RowSelector::select(12), + RowSelector::skip(1), + // Skip across final page boundary + RowSelector::skip(8), + // Select from final page + RowSelector::select(4), + ]); + + let ranges = selection.scan_ranges(&index); + + // assert_eq!(mask, vec![false, true, true, false, true, true, true]); + assert_eq!(ranges, vec![10..20, 20..30, 40..50, 50..60, 60..70]); + + let selection = RowSelection::from(vec![ + // Skip first page + RowSelector::skip(10), + // Multiple selects in same page + RowSelector::select(3), + RowSelector::skip(3), + RowSelector::select(4), + // Select to remaining in page and first row of next page + RowSelector::skip(5), + RowSelector::select(6), + // Skip remaining + RowSelector::skip(50), + ]); + + let ranges = selection.scan_ranges(&index); + + // assert_eq!(mask, vec![false, true, true, false, true, true, true]); + assert_eq!(ranges, vec![10..20, 20..30, 30..40]); + } +} diff --git a/parquet/src/arrow/arrow_writer/byte_array.rs b/parquet/src/arrow/arrow_writer/byte_array.rs new file mode 100644 index 000000000000..a25bd8d5c505 --- /dev/null +++ b/parquet/src/arrow/arrow_writer/byte_array.rs @@ -0,0 +1,580 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arrow::arrow_writer::levels::LevelInfo; +use crate::basic::Encoding; +use crate::column::page::PageWriter; +use crate::column::writer::encoder::{ + ColumnValueEncoder, DataPageValues, DictionaryPage, +}; +use crate::column::writer::GenericColumnWriter; +use crate::data_type::{AsBytes, ByteArray, Int32Type}; +use crate::encodings::encoding::{DeltaBitPackEncoder, Encoder}; +use crate::encodings::rle::RleEncoder; +use crate::errors::{ParquetError, Result}; +use crate::file::properties::{WriterProperties, WriterPropertiesPtr, WriterVersion}; +use crate::file::writer::OnCloseColumnChunk; +use crate::schema::types::ColumnDescPtr; +use crate::util::bit_util::num_required_bits; +use crate::util::interner::{Interner, Storage}; +use arrow::array::{ + Array, ArrayAccessor, ArrayRef, BinaryArray, DictionaryArray, LargeBinaryArray, + LargeStringArray, StringArray, +}; +use arrow::datatypes::DataType; + +macro_rules! downcast_dict_impl { + ($array:ident, $key:ident, $val:ident, $op:expr $(, $arg:expr)*) => {{ + $op($array + .as_any() + .downcast_ref::>() + .unwrap() + .downcast_dict::<$val>() + .unwrap()$(, $arg)*) + }}; +} + +macro_rules! 
downcast_dict_op { + ($key_type:expr, $val:ident, $array:ident, $op:expr $(, $arg:expr)*) => { + match $key_type.as_ref() { + DataType::UInt8 => downcast_dict_impl!($array, UInt8Type, $val, $op$(, $arg)*), + DataType::UInt16 => downcast_dict_impl!($array, UInt16Type, $val, $op$(, $arg)*), + DataType::UInt32 => downcast_dict_impl!($array, UInt32Type, $val, $op$(, $arg)*), + DataType::UInt64 => downcast_dict_impl!($array, UInt64Type, $val, $op$(, $arg)*), + DataType::Int8 => downcast_dict_impl!($array, Int8Type, $val, $op$(, $arg)*), + DataType::Int16 => downcast_dict_impl!($array, Int16Type, $val, $op$(, $arg)*), + DataType::Int32 => downcast_dict_impl!($array, Int32Type, $val, $op$(, $arg)*), + DataType::Int64 => downcast_dict_impl!($array, Int64Type, $val, $op$(, $arg)*), + _ => unreachable!(), + } + }; +} + +macro_rules! downcast_op { + ($data_type:expr, $array:ident, $op:expr $(, $arg:expr)*) => { + match $data_type { + DataType::Utf8 => $op($array.as_any().downcast_ref::().unwrap()$(, $arg)*), + DataType::LargeUtf8 => { + $op($array.as_any().downcast_ref::().unwrap()$(, $arg)*) + } + DataType::Binary => { + $op($array.as_any().downcast_ref::().unwrap()$(, $arg)*) + } + DataType::LargeBinary => { + $op($array.as_any().downcast_ref::().unwrap()$(, $arg)*) + } + DataType::Dictionary(key, value) => match value.as_ref() { + DataType::Utf8 => downcast_dict_op!(key, StringArray, $array, $op$(, $arg)*), + DataType::LargeUtf8 => { + downcast_dict_op!(key, LargeStringArray, $array, $op$(, $arg)*) + } + DataType::Binary => downcast_dict_op!(key, BinaryArray, $array, $op$(, $arg)*), + DataType::LargeBinary => { + downcast_dict_op!(key, LargeBinaryArray, $array, $op$(, $arg)*) + } + d => unreachable!("cannot downcast {} dictionary value to byte array", d), + }, + d => unreachable!("cannot downcast {} to byte array", d), + } + }; +} + +/// A writer for byte array types +pub(super) struct ByteArrayWriter<'a> { + writer: GenericColumnWriter<'a, ByteArrayEncoder>, + on_close: Option>, +} + +impl<'a> ByteArrayWriter<'a> { + /// Returns a new [`ByteArrayWriter`] + pub fn new( + descr: ColumnDescPtr, + props: &'a WriterPropertiesPtr, + page_writer: Box, + on_close: OnCloseColumnChunk<'a>, + ) -> Result { + Ok(Self { + writer: GenericColumnWriter::new(descr, props.clone(), page_writer), + on_close: Some(on_close), + }) + } + + pub fn write(&mut self, array: &ArrayRef, levels: LevelInfo) -> Result<()> { + self.writer.write_batch_internal( + array, + Some(levels.non_null_indices()), + levels.def_levels(), + levels.rep_levels(), + None, + None, + None, + )?; + Ok(()) + } + + pub fn close(self) -> Result<()> { + let r = self.writer.close()?; + + if let Some(on_close) = self.on_close { + on_close(r)?; + } + Ok(()) + } +} + +/// A fallback encoder, i.e. non-dictionary, for [`ByteArray`] +struct FallbackEncoder { + encoder: FallbackEncoderImpl, + num_values: usize, +} + +/// The fallback encoder in use +/// +/// Note: DeltaBitPackEncoder is boxed as it is rather large +enum FallbackEncoderImpl { + Plain { + buffer: Vec, + }, + DeltaLength { + buffer: Vec, + lengths: Box>, + }, + Delta { + buffer: Vec, + last_value: Vec, + prefix_lengths: Box>, + suffix_lengths: Box>, + }, +} + +impl FallbackEncoder { + /// Create the fallback encoder for the given [`ColumnDescPtr`] and [`WriterProperties`] + fn new(descr: &ColumnDescPtr, props: &WriterProperties) -> Result { + // Set either main encoder or fallback encoder. 
+ let encoding = props.encoding(descr.path()).unwrap_or_else(|| { + match props.writer_version() { + WriterVersion::PARQUET_1_0 => Encoding::PLAIN, + WriterVersion::PARQUET_2_0 => Encoding::DELTA_BYTE_ARRAY, + } + }); + + let encoder = match encoding { + Encoding::PLAIN => FallbackEncoderImpl::Plain { buffer: vec![] }, + Encoding::DELTA_LENGTH_BYTE_ARRAY => FallbackEncoderImpl::DeltaLength { + buffer: vec![], + lengths: Box::new(DeltaBitPackEncoder::new()), + }, + Encoding::DELTA_BYTE_ARRAY => FallbackEncoderImpl::Delta { + buffer: vec![], + last_value: vec![], + prefix_lengths: Box::new(DeltaBitPackEncoder::new()), + suffix_lengths: Box::new(DeltaBitPackEncoder::new()), + }, + _ => { + return Err(general_err!( + "unsupported encoding {} for byte array", + encoding + )) + } + }; + + Ok(Self { + encoder, + num_values: 0, + }) + } + + /// Encode `values` to the in-progress page + fn encode(&mut self, values: T, indices: &[usize]) + where + T: ArrayAccessor + Copy, + T::Item: AsRef<[u8]>, + { + self.num_values += indices.len(); + match &mut self.encoder { + FallbackEncoderImpl::Plain { buffer } => { + for idx in indices { + let value = values.value(*idx); + let value = value.as_ref(); + buffer.extend_from_slice((value.len() as u32).as_bytes()); + buffer.extend_from_slice(value) + } + } + FallbackEncoderImpl::DeltaLength { buffer, lengths } => { + for idx in indices { + let value = values.value(*idx); + let value = value.as_ref(); + lengths.put(&[value.len() as i32]).unwrap(); + buffer.extend_from_slice(value); + } + } + FallbackEncoderImpl::Delta { + buffer, + last_value, + prefix_lengths, + suffix_lengths, + } => { + for idx in indices { + let value = values.value(*idx); + let value = value.as_ref(); + let mut prefix_length = 0; + + while prefix_length < last_value.len() + && prefix_length < value.len() + && last_value[prefix_length] == value[prefix_length] + { + prefix_length += 1; + } + + let suffix_length = value.len() - prefix_length; + + last_value.clear(); + last_value.extend_from_slice(value); + + buffer.extend_from_slice(&value[prefix_length..]); + prefix_lengths.put(&[prefix_length as i32]).unwrap(); + suffix_lengths.put(&[suffix_length as i32]).unwrap(); + } + } + } + } + + fn estimated_data_page_size(&self) -> usize { + match &self.encoder { + FallbackEncoderImpl::Plain { buffer, .. } => buffer.len(), + FallbackEncoderImpl::DeltaLength { buffer, lengths } => { + buffer.len() + lengths.estimated_data_encoded_size() + } + FallbackEncoderImpl::Delta { + buffer, + prefix_lengths, + suffix_lengths, + .. + } => { + buffer.len() + + prefix_lengths.estimated_data_encoded_size() + + suffix_lengths.estimated_data_encoded_size() + } + } + } + + fn flush_data_page( + &mut self, + min_value: Option, + max_value: Option, + ) -> Result> { + let (buf, encoding) = match &mut self.encoder { + FallbackEncoderImpl::Plain { buffer } => { + (std::mem::take(buffer), Encoding::PLAIN) + } + FallbackEncoderImpl::DeltaLength { buffer, lengths } => { + let lengths = lengths.flush_buffer()?; + + let mut out = Vec::with_capacity(lengths.len() + buffer.len()); + out.extend_from_slice(lengths.data()); + out.extend_from_slice(buffer); + (out, Encoding::DELTA_LENGTH_BYTE_ARRAY) + } + FallbackEncoderImpl::Delta { + buffer, + prefix_lengths, + suffix_lengths, + .. 
+ } => { + let prefix_lengths = prefix_lengths.flush_buffer()?; + let suffix_lengths = suffix_lengths.flush_buffer()?; + + let mut out = Vec::with_capacity( + prefix_lengths.len() + suffix_lengths.len() + buffer.len(), + ); + out.extend_from_slice(prefix_lengths.data()); + out.extend_from_slice(suffix_lengths.data()); + out.extend_from_slice(buffer); + (out, Encoding::DELTA_BYTE_ARRAY) + } + }; + + Ok(DataPageValues { + buf: buf.into(), + num_values: std::mem::take(&mut self.num_values), + encoding, + min_value, + max_value, + }) + } +} + +/// [`Storage`] for the [`Interner`] used by [`DictEncoder`] +#[derive(Debug, Default)] +struct ByteArrayStorage { + /// Encoded dictionary data + page: Vec, + + values: Vec>, +} + +impl Storage for ByteArrayStorage { + type Key = u64; + type Value = [u8]; + + fn get(&self, idx: Self::Key) -> &Self::Value { + &self.page[self.values[idx as usize].clone()] + } + + fn push(&mut self, value: &Self::Value) -> Self::Key { + let key = self.values.len(); + + self.page.reserve(4 + value.len()); + self.page.extend_from_slice((value.len() as u32).as_bytes()); + + let start = self.page.len(); + self.page.extend_from_slice(value); + self.values.push(start..self.page.len()); + + key as u64 + } +} + +/// A dictionary encoder for byte array data +#[derive(Debug, Default)] +struct DictEncoder { + interner: Interner, + indices: Vec, +} + +impl DictEncoder { + /// Encode `values` to the in-progress page + fn encode(&mut self, values: T, indices: &[usize]) + where + T: ArrayAccessor + Copy, + T::Item: AsRef<[u8]>, + { + self.indices.reserve(indices.len()); + + for idx in indices { + let value = values.value(*idx); + let interned = self.interner.intern(value.as_ref()); + self.indices.push(interned); + } + } + + fn bit_width(&self) -> u8 { + let length = self.interner.storage().values.len(); + num_required_bits(length.saturating_sub(1) as u64) + } + + fn estimated_data_page_size(&self) -> usize { + let bit_width = self.bit_width(); + 1 + RleEncoder::min_buffer_size(bit_width) + + RleEncoder::max_buffer_size(bit_width, self.indices.len()) + } + + fn estimated_dict_page_size(&self) -> usize { + self.interner.storage().page.len() + } + + fn flush_dict_page(self) -> DictionaryPage { + let storage = self.interner.into_inner(); + + DictionaryPage { + buf: storage.page.into(), + num_values: storage.values.len(), + is_sorted: false, + } + } + + fn flush_data_page( + &mut self, + min_value: Option, + max_value: Option, + ) -> DataPageValues { + let num_values = self.indices.len(); + let buffer_len = self.estimated_data_page_size(); + let mut buffer = Vec::with_capacity(buffer_len); + buffer.push(self.bit_width() as u8); + + let mut encoder = RleEncoder::new_from_buf(self.bit_width(), buffer); + for index in &self.indices { + encoder.put(*index as u64) + } + + self.indices.clear(); + + DataPageValues { + buf: encoder.consume().into(), + num_values, + encoding: Encoding::RLE_DICTIONARY, + min_value, + max_value, + } + } +} + +struct ByteArrayEncoder { + fallback: FallbackEncoder, + dict_encoder: Option, + num_values: usize, + min_value: Option, + max_value: Option, +} + +impl ColumnValueEncoder for ByteArrayEncoder { + type T = ByteArray; + type Values = ArrayRef; + + fn min_max( + &self, + values: &ArrayRef, + value_indices: Option<&[usize]>, + ) -> Option<(Self::T, Self::T)> { + match value_indices { + Some(indices) => { + let iter = indices.iter().cloned(); + downcast_op!(values.data_type(), values, compute_min_max, iter) + } + None => { + let len = Array::len(values); + 
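// no value indices were supplied, so consider every row in the array + 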
downcast_op!(values.data_type(), values, compute_min_max, 0..len) + } + } + } + + fn try_new(descr: &ColumnDescPtr, props: &WriterProperties) -> Result + where + Self: Sized, + { + let dictionary = props + .dictionary_enabled(descr.path()) + .then(DictEncoder::default); + + let fallback = FallbackEncoder::new(descr, props)?; + + Ok(Self { + fallback, + dict_encoder: dictionary, + num_values: 0, + min_value: None, + max_value: None, + }) + } + + fn write( + &mut self, + _values: &Self::Values, + _offset: usize, + _len: usize, + ) -> Result<()> { + unreachable!("should call write_gather instead") + } + + fn write_gather(&mut self, values: &Self::Values, indices: &[usize]) -> Result<()> { + downcast_op!(values.data_type(), values, encode, indices, self); + Ok(()) + } + + fn num_values(&self) -> usize { + self.num_values + } + + fn has_dictionary(&self) -> bool { + self.dict_encoder.is_some() + } + + fn estimated_dict_page_size(&self) -> Option { + Some(self.dict_encoder.as_ref()?.estimated_dict_page_size()) + } + + fn estimated_data_page_size(&self) -> usize { + match &self.dict_encoder { + Some(encoder) => encoder.estimated_data_page_size(), + None => self.fallback.estimated_data_page_size(), + } + } + + fn flush_dict_page(&mut self) -> Result> { + match self.dict_encoder.take() { + Some(encoder) => { + if self.num_values != 0 { + return Err(general_err!( + "Must flush data pages before flushing dictionary" + )); + } + + Ok(Some(encoder.flush_dict_page())) + } + _ => Ok(None), + } + } + + fn flush_data_page(&mut self) -> Result> { + let min_value = self.min_value.take(); + let max_value = self.max_value.take(); + + match &mut self.dict_encoder { + Some(encoder) => Ok(encoder.flush_data_page(min_value, max_value)), + _ => self.fallback.flush_data_page(min_value, max_value), + } + } +} + +/// Encodes the provided `values` and `indices` to `encoder` +/// +/// This is a free function so it can be used with `downcast_op!` +fn encode(values: T, indices: &[usize], encoder: &mut ByteArrayEncoder) +where + T: ArrayAccessor + Copy, + T::Item: Copy + Ord + AsRef<[u8]>, +{ + if let Some((min, max)) = compute_min_max(values, indices.iter().cloned()) { + if encoder.min_value.as_ref().map_or(true, |m| m > &min) { + encoder.min_value = Some(min); + } + + if encoder.max_value.as_ref().map_or(true, |m| m < &max) { + encoder.max_value = Some(max); + } + } + + match &mut encoder.dict_encoder { + Some(dict_encoder) => dict_encoder.encode(values, indices), + None => encoder.fallback.encode(values, indices), + } +} + +/// Computes the min and max for the provided array and indices +/// +/// This is a free function so it can be used with `downcast_op!` +fn compute_min_max( + array: T, + mut valid: impl Iterator, +) -> Option<(ByteArray, ByteArray)> +where + T: ArrayAccessor, + T::Item: Copy + Ord + AsRef<[u8]>, +{ + let first_idx = valid.next()?; + + let first_val = array.value(first_idx); + let mut min = first_val; + let mut max = first_val; + for idx in valid { + let val = array.value(idx); + min = min.min(val); + max = max.max(val); + } + Some((min.as_ref().to_vec().into(), max.as_ref().to_vec().into())) +} diff --git a/parquet/src/arrow/arrow_writer/levels.rs b/parquet/src/arrow/arrow_writer/levels.rs index 51e494d41be0..49f997ac81ff 100644 --- a/parquet/src/arrow/arrow_writer/levels.rs +++ b/parquet/src/arrow/arrow_writer/levels.rs @@ -88,7 +88,7 @@ fn is_leaf(data_type: &DataType) -> bool { | DataType::Interval(_) | DataType::Binary | DataType::LargeBinary - | DataType::Decimal(_, _) + | 
DataType::Decimal128(_, _) | DataType::FixedSizeBinary(_) ) } @@ -1188,7 +1188,7 @@ mod tests { Field::new("item", DataType::Struct(vec![int_field.clone()]), true); let list_field = Field::new("list", DataType::List(Box::new(item_field)), true); - let int_builder = Int32Builder::new(10); + let int_builder = Int32Builder::with_capacity(10); let struct_builder = StructBuilder::new(vec![int_field], vec![Box::new(int_builder)]); let mut list_builder = ListBuilder::new(struct_builder); @@ -1200,52 +1200,47 @@ mod tests { values .field_builder::(0) .unwrap() - .append_value(1) - .unwrap(); - values.append(true).unwrap(); - list_builder.append(true).unwrap(); + .append_value(1); + values.append(true); + list_builder.append(true); // [] - list_builder.append(true).unwrap(); + list_builder.append(true); // null - list_builder.append(false).unwrap(); + list_builder.append(false); // [null, null] let values = list_builder.values(); values .field_builder::(0) .unwrap() - .append_null() - .unwrap(); - values.append(false).unwrap(); + .append_null(); + values.append(false); values .field_builder::(0) .unwrap() - .append_null() - .unwrap(); - values.append(false).unwrap(); - list_builder.append(true).unwrap(); + .append_null(); + values.append(false); + list_builder.append(true); // [{a: null}] let values = list_builder.values(); values .field_builder::(0) .unwrap() - .append_null() - .unwrap(); - values.append(true).unwrap(); - list_builder.append(true).unwrap(); + .append_null(); + values.append(true); + list_builder.append(true); // [{a: 2}] let values = list_builder.values(); values .field_builder::(0) .unwrap() - .append_value(2) - .unwrap(); - values.append(true).unwrap(); - list_builder.append(true).unwrap(); + .append_value(2); + values.append(true); + list_builder.append(true); let array = Arc::new(list_builder.finish()); diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 73f46f971f95..6f9d5b3aff81 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use arrow::array as arrow_array; use arrow::array::ArrayRef; -use arrow::array::BasicDecimalArray; use arrow::datatypes::{DataType as ArrowDataType, IntervalUnit, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_array::Array; @@ -33,14 +32,16 @@ use super::schema::{ decimal_length_from_precision, }; -use crate::column::writer::ColumnWriter; +use crate::arrow::arrow_writer::byte_array::ByteArrayWriter; +use crate::column::writer::{ColumnWriter, ColumnWriterImpl}; use crate::errors::{ParquetError, Result}; use crate::file::metadata::RowGroupMetaDataPtr; use crate::file::properties::WriterProperties; -use crate::file::writer::{SerializedColumnWriter, SerializedRowGroupWriter}; +use crate::file::writer::SerializedRowGroupWriter; use crate::{data_type::*, file::writer::SerializedFileWriter}; use levels::{calculate_array_levels, LevelInfo}; +mod byte_array; mod levels; /// Arrow writer @@ -222,6 +223,12 @@ impl ArrowWriter { Ok(()) } + /// Flushes any outstanding data and returns the underlying writer. 
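// Aside, not part of this patch: the test rewrites in this diff track the arrow builder
// API change in which appends became infallible and capacity moved to `with_capacity`.
// A minimal sketch of the new shape (assuming arrow 20-style builders):
fn build_int32_column() -> arrow::array::Int32Array {
    use arrow::array::Int32Builder;

    let mut builder = Int32Builder::with_capacity(10);
    builder.append_value(1); // no longer returns Result
    builder.append_null();   // likewise infallible
    builder.finish()
}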
+ pub fn into_inner(mut self) -> Result { + self.flush()?; + self.writer.into_inner() + } + /// Close and finalize the underlying Parquet writer pub fn close(mut self) -> Result { self.flush()?; @@ -229,17 +236,6 @@ impl ArrowWriter { } } -/// Convenience method to get the next ColumnWriter from the RowGroupWriter -#[inline] -fn get_col_writer<'a, W: Write>( - row_group_writer: &'a mut SerializedRowGroupWriter<'_, W>, -) -> Result> { - let col_writer = row_group_writer - .next_column()? - .expect("Unable to get column writer"); - Ok(col_writer) -} - fn write_leaves( row_group_writer: &mut SerializedRowGroupWriter<'_, W>, arrays: &[ArrayRef], @@ -271,22 +267,24 @@ fn write_leaves( | ArrowDataType::Time64(_) | ArrowDataType::Duration(_) | ArrowDataType::Interval(_) - | ArrowDataType::LargeBinary + | ArrowDataType::Decimal128(_, _) + | ArrowDataType::Decimal256(_, _) + | ArrowDataType::FixedSizeBinary(_) => { + let mut col_writer = row_group_writer.next_column()?.unwrap(); + for (array, levels) in arrays.iter().zip(levels.iter_mut()) { + write_leaf(col_writer.untyped(), array, levels.pop().expect("Levels exhausted"))?; + } + col_writer.close() + } + ArrowDataType::LargeBinary | ArrowDataType::Binary | ArrowDataType::Utf8 - | ArrowDataType::LargeUtf8 - | ArrowDataType::Decimal(_, _) - | ArrowDataType::FixedSizeBinary(_) => { - let mut col_writer = get_col_writer(row_group_writer)?; + | ArrowDataType::LargeUtf8 => { + let mut col_writer = row_group_writer.next_column_with_factory(ByteArrayWriter::new)?.unwrap(); for (array, levels) in arrays.iter().zip(levels.iter_mut()) { - write_leaf( - col_writer.untyped(), - array, - levels.pop().expect("Levels exhausted"), - )?; + col_writer.write(array, levels.pop().expect("Levels exhausted"))?; } - col_writer.close()?; - Ok(()) + col_writer.close() } ArrowDataType::List(_) | ArrowDataType::LargeList(_) => { let arrays: Vec<_> = arrays.iter().map(|array|{ @@ -337,19 +335,21 @@ fn write_leaves( write_leaves(row_group_writer, &values, levels)?; Ok(()) } - ArrowDataType::Dictionary(_, value_type) => { - let mut col_writer = get_col_writer(row_group_writer)?; - for (array, levels) in arrays.iter().zip(levels.iter_mut()) { - // cast dictionary to a primitive - let array = arrow::compute::cast(array, value_type)?; - write_leaf( - col_writer.untyped(), - &array, - levels.pop().expect("Levels exhausted"), - )?; + ArrowDataType::Dictionary(_, value_type) => match value_type.as_ref() { + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Binary | ArrowDataType::LargeBinary => { + let mut col_writer = row_group_writer.next_column_with_factory(ByteArrayWriter::new)?.unwrap(); + for (array, levels) in arrays.iter().zip(levels.iter_mut()) { + col_writer.write(array, levels.pop().expect("Levels exhausted"))?; + } + col_writer.close() + } + _ => { + let mut col_writer = row_group_writer.next_column()?.unwrap(); + for (array, levels) in arrays.iter().zip(levels.iter_mut()) { + write_leaf(col_writer.untyped(), array, levels.pop().expect("Levels exhausted"))?; + } + col_writer.close() } - col_writer.close()?; - Ok(()) } ArrowDataType::Float16 => Err(ParquetError::ArrowError( "Float16 arrays not supported".to_string(), @@ -373,33 +373,25 @@ fn write_leaf( let indices = levels.non_null_indices(); let written = match writer { ColumnWriter::Int32ColumnWriter(ref mut typed) => { - let values = match column.data_type() { + match column.data_type() { ArrowDataType::Date64 => { // If the column is a Date64, we cast it to a Date32, and then interpret that as Int32 - let 
array = if let ArrowDataType::Date64 = column.data_type() { - let array = arrow::compute::cast(column, &ArrowDataType::Date32)?; - arrow::compute::cast(&array, &ArrowDataType::Int32)? - } else { - arrow::compute::cast(column, &ArrowDataType::Int32)? - }; + let array = arrow::compute::cast(column, &ArrowDataType::Date32)?; + let array = arrow::compute::cast(&array, &ArrowDataType::Int32)?; + let array = array .as_any() .downcast_ref::() .expect("Unable to get int32 array"); - get_numeric_array_slice::(array, indices) + write_primitive(typed, array.values(), levels)? } ArrowDataType::UInt32 => { + let data = column.data(); + let offset = data.offset(); // follow C++ implementation and use overflow/reinterpret cast from u32 to i32 which will map // `(i32::MAX as u32)..u32::MAX` to `i32::MIN..0` - let array = column - .as_any() - .downcast_ref::() - .expect("Unable to get u32 array"); - let array = arrow::compute::unary::<_, _, arrow::datatypes::Int32Type>( - array, - |x| x as i32, - ); - get_numeric_array_slice::(&array, indices) + let array: &[i32] = data.buffers()[0].typed_data(); + write_primitive(typed, &array[offset..offset + data.len()], levels)? } _ => { let array = arrow::compute::cast(column, &ArrowDataType::Int32)?; @@ -407,14 +399,9 @@ fn write_leaf( .as_any() .downcast_ref::() .expect("Unable to get i32 array"); - get_numeric_array_slice::(array, indices) + write_primitive(typed, array.values(), levels)? } - }; - typed.write_batch( - values.as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? + } } ColumnWriter::BoolColumnWriter(ref mut typed) => { let array = column @@ -428,26 +415,21 @@ fn write_leaf( )? } ColumnWriter::Int64ColumnWriter(ref mut typed) => { - let values = match column.data_type() { + match column.data_type() { ArrowDataType::Int64 => { let array = column .as_any() .downcast_ref::() .expect("Unable to get i64 array"); - get_numeric_array_slice::(array, indices) + write_primitive(typed, array.values(), levels)? } ArrowDataType::UInt64 => { // follow C++ implementation and use overflow/reinterpret cast from u64 to i64 which will map // `(i64::MAX as u64)..u64::MAX` to `i64::MIN..0` - let array = column - .as_any() - .downcast_ref::() - .expect("Unable to get u64 array"); - let array = arrow::compute::unary::<_, _, arrow::datatypes::Int64Type>( - array, - |x| x as i64, - ); - get_numeric_array_slice::(&array, indices) + let data = column.data(); + let offset = data.offset(); + let array: &[i64] = data.buffers()[0].typed_data(); + write_primitive(typed, &array[offset..offset + data.len()], levels)? } _ => { let array = arrow::compute::cast(column, &ArrowDataType::Int64)?; @@ -455,14 +437,9 @@ fn write_leaf( .as_any() .downcast_ref::() .expect("Unable to get i64 array"); - get_numeric_array_slice::(array, indices) + write_primitive(typed, array.values(), levels)? } - }; - typed.write_batch( - values.as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? + } } ColumnWriter::Int96ColumnWriter(ref mut _typed) => { unreachable!("Currently unreachable because data type not supported") @@ -472,70 +449,18 @@ fn write_leaf( .as_any() .downcast_ref::() .expect("Unable to get Float32 array"); - typed.write_batch( - get_numeric_array_slice::(array, indices).as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? + write_primitive(typed, array.values(), levels)? 
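// Aside, not part of this patch: the UInt32 branch above relies on Rust's wrapping `as`
// cast, so `(i32::MAX as u32)..=u32::MAX` maps onto `i32::MIN..=-1` and the original
// unsigned values are recovered bit-for-bit when read back.
fn reinterpret_u32(x: u32) -> i32 {
    x as i32 // bit pattern preserved, following the C++ writer's behaviour
}
// reinterpret_u32(u32::MAX) == -1
// reinterpret_u32(i32::MAX as u32 + 1) == i32::MIN
// reinterpret_u32(5) as u32 == 5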
} ColumnWriter::DoubleColumnWriter(ref mut typed) => { let array = column .as_any() .downcast_ref::() .expect("Unable to get Float64 array"); - typed.write_batch( - get_numeric_array_slice::(array, indices).as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? + write_primitive(typed, array.values(), levels)? + } + ColumnWriter::ByteArrayColumnWriter(_) => { + unreachable!("should use ByteArrayWriter") } - ColumnWriter::ByteArrayColumnWriter(ref mut typed) => match column.data_type() { - ArrowDataType::Binary => { - let array = column - .as_any() - .downcast_ref::() - .expect("Unable to get BinaryArray array"); - typed.write_batch( - get_binary_array(array).as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? - } - ArrowDataType::Utf8 => { - let array = column - .as_any() - .downcast_ref::() - .expect("Unable to get LargeBinaryArray array"); - typed.write_batch( - get_string_array(array).as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? - } - ArrowDataType::LargeBinary => { - let array = column - .as_any() - .downcast_ref::() - .expect("Unable to get LargeBinaryArray array"); - typed.write_batch( - get_large_binary_array(array).as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? - } - ArrowDataType::LargeUtf8 => { - let array = column - .as_any() - .downcast_ref::() - .expect("Unable to get LargeUtf8 array"); - typed.write_batch( - get_large_string_array(array).as_slice(), - levels.def_levels(), - levels.rep_levels(), - )? - } - _ => unreachable!("Currently unreachable because data type not supported"), - }, ColumnWriter::FixedLenByteArrayColumnWriter(ref mut typed) => { let bytes = match column.data_type() { ArrowDataType::Interval(interval_unit) => match interval_unit { @@ -569,10 +494,10 @@ fn write_leaf( .unwrap(); get_fsb_array_slice(array, indices) } - ArrowDataType::Decimal(_, _) => { + ArrowDataType::Decimal128(_, _) => { let array = column .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap(); get_decimal_array_slice(array, indices) } @@ -593,55 +518,20 @@ fn write_leaf( Ok(written as i64) } -macro_rules! def_get_binary_array_fn { - ($name:ident, $ty:ty) => { - fn $name(array: &$ty) -> Vec { - let mut byte_array = ByteArray::new(); - let ptr = crate::util::memory::ByteBufferPtr::new( - array.value_data().as_slice().to_vec(), - ); - byte_array.set_data(ptr); - array - .value_offsets() - .windows(2) - .enumerate() - .filter_map(|(i, offsets)| { - if array.is_valid(i) { - let start = offsets[0] as usize; - let len = offsets[1] as usize - start; - Some(byte_array.slice(start, len)) - } else { - None - } - }) - .collect() - } - }; -} - -// TODO: These methods don't handle non null indices correctly (#1753) -def_get_binary_array_fn!(get_binary_array, arrow_array::BinaryArray); -def_get_binary_array_fn!(get_string_array, arrow_array::StringArray); -def_get_binary_array_fn!(get_large_binary_array, arrow_array::LargeBinaryArray); -def_get_binary_array_fn!(get_large_string_array, arrow_array::LargeStringArray); - -/// Get the underlying numeric array slice, skipping any null values. -/// If there are no null values, it might be quicker to get the slice directly instead of -/// calling this function. 
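// Aside, not part of this patch: `get_decimal_array_slice` in this file writes each
// Decimal128 value as a FIXED_LEN_BYTE_ARRAY. A simplified sketch of that conversion,
// assuming the value already fits in `length` bytes (as guaranteed by
// `decimal_length_from_precision`):
fn decimal_to_fixed_len_bytes(value: i128, length: usize) -> Vec<u8> {
    // big-endian two's complement, truncated to the declared byte width
    value.to_be_bytes()[16 - length..].to_vec()
}
// decimal_to_fixed_len_bytes(10_000, 3) == vec![0x00, 0x27, 0x10]
// decimal_to_fixed_len_bytes(-100, 3)   == vec![0xFF, 0xFF, 0x9C]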
-fn get_numeric_array_slice( - array: &arrow_array::PrimitiveArray, - indices: &[usize], -) -> Vec -where - T: DataType, - A: arrow::datatypes::ArrowNumericType, - T::T: From, -{ - let mut values = Vec::with_capacity(indices.len()); - for i in indices { - values.push(array.value(*i).into()) - } - values +fn write_primitive<'a, T: DataType>( + writer: &mut ColumnWriterImpl<'a, T>, + values: &[T::T], + levels: LevelInfo, +) -> Result { + writer.write_batch_internal( + values, + Some(levels.non_null_indices()), + levels.def_levels(), + levels.rep_levels(), + None, + None, + None, + ) } fn get_bool_array_slice( @@ -689,7 +579,7 @@ fn get_interval_dt_array_slice( } fn get_decimal_array_slice( - array: &arrow_array::DecimalArray, + array: &arrow_array::Decimal128Array, indices: &[usize], ) -> Vec { let mut values = Vec::with_capacity(indices.len()); @@ -722,6 +612,9 @@ mod tests { use std::fs::File; use std::sync::Arc; + use crate::arrow::arrow_reader::{ + ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder, + }; use arrow::datatypes::ToByteSlice; use arrow::datatypes::{DataType, Field, Schema, UInt32Type, UInt8Type}; use arrow::error::Result as ArrowResult; @@ -729,8 +622,9 @@ mod tests { use arrow::util::pretty::pretty_format_batches; use arrow::{array::*, buffer::Buffer}; - use crate::arrow::{ArrowReader, ParquetFileArrowReader}; + use crate::basic::Encoding; use crate::file::metadata::ParquetMetaData; + use crate::file::properties::WriterVersion; use crate::file::{ reader::{FileReader, SerializedFileReader}, statistics::Statistics, @@ -756,6 +650,25 @@ mod tests { roundtrip(batch, Some(SMALL_SIZE / 2)); } + fn get_bytes_after_close(schema: SchemaRef, expected_batch: &RecordBatch) -> Vec { + let mut buffer = vec![]; + + let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap(); + writer.write(expected_batch).unwrap(); + writer.close().unwrap(); + + buffer + } + + fn get_bytes_by_into_inner( + schema: SchemaRef, + expected_batch: &RecordBatch, + ) -> Vec { + let mut writer = ArrowWriter::try_new(Vec::new(), schema, None).unwrap(); + writer.write(expected_batch).unwrap(); + writer.into_inner().unwrap() + } + #[test] fn roundtrip_bytes() { // define schema @@ -772,31 +685,28 @@ mod tests { let expected_batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(a), Arc::new(b)]).unwrap(); - let mut buffer = vec![]; - - { - let mut writer = ArrowWriter::try_new(&mut buffer, schema, None).unwrap(); - writer.write(&expected_batch).unwrap(); - writer.close().unwrap(); - } - - let cursor = Bytes::from(buffer); - let mut arrow_reader = ParquetFileArrowReader::try_new(cursor).unwrap(); - let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); - - let actual_batch = record_batch_reader - .next() - .expect("No batch found") - .expect("Unable to get batch"); - - assert_eq!(expected_batch.schema(), actual_batch.schema()); - assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); - assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); - for i in 0..expected_batch.num_columns() { - let expected_data = expected_batch.column(i).data().clone(); - let actual_data = actual_batch.column(i).data().clone(); - - assert_eq!(expected_data, actual_data); + for buffer in vec![ + get_bytes_after_close(schema.clone(), &expected_batch), + get_bytes_by_into_inner(schema, &expected_batch), + ] { + let cursor = Bytes::from(buffer); + let mut record_batch_reader = + ParquetRecordBatchReader::try_new(cursor, 1024).unwrap(); + + let actual_batch = 
record_batch_reader + .next() + .expect("No batch found") + .expect("Unable to get batch"); + + assert_eq!(expected_batch.schema(), actual_batch.schema()); + assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); + assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); + for i in 0..expected_batch.num_columns() { + let expected_data = expected_batch.column(i).data().clone(); + let actual_data = actual_batch.column(i).data().clone(); + + assert_eq!(expected_data, actual_data); + } } } @@ -926,13 +836,13 @@ mod tests { #[test] fn arrow_writer_decimal() { - let decimal_field = Field::new("a", DataType::Decimal(5, 2), false); + let decimal_field = Field::new("a", DataType::Decimal128(5, 2), false); let schema = Schema::new(vec![decimal_field]); let decimal_values = vec![10_000, 50_000, 0, -100] .into_iter() .map(Some) - .collect::() + .collect::() .with_precision_and_scale(5, 2) .unwrap(); @@ -1200,25 +1110,38 @@ mod tests { const SMALL_SIZE: usize = 7; - fn roundtrip(expected_batch: RecordBatch, max_row_group_size: Option) -> File { + fn roundtrip( + expected_batch: RecordBatch, + max_row_group_size: Option, + ) -> Vec { + let mut files = vec![]; + for version in [WriterVersion::PARQUET_1_0, WriterVersion::PARQUET_2_0] { + let mut props = WriterProperties::builder().set_writer_version(version); + + if let Some(size) = max_row_group_size { + props = props.set_max_row_group_size(size) + } + + let props = props.build(); + files.push(roundtrip_opts(&expected_batch, props)) + } + files + } + + fn roundtrip_opts(expected_batch: &RecordBatch, props: WriterProperties) -> File { let file = tempfile::tempfile().unwrap(); let mut writer = ArrowWriter::try_new( file.try_clone().unwrap(), expected_batch.schema(), - max_row_group_size.map(|size| { - WriterProperties::builder() - .set_max_row_group_size(size) - .build() - }), + Some(props), ) .expect("Unable to write file"); - writer.write(&expected_batch).unwrap(); + writer.write(expected_batch).unwrap(); writer.close().unwrap(); - let mut arrow_reader = - ParquetFileArrowReader::try_new(file.try_clone().unwrap()).unwrap(); - let mut record_batch_reader = arrow_reader.get_record_reader(1024).unwrap(); + let mut record_batch_reader = + ParquetRecordBatchReader::try_new(file.try_clone().unwrap(), 1024).unwrap(); let actual_batch = record_batch_reader .next() @@ -1238,20 +1161,59 @@ mod tests { file } - fn one_column_roundtrip( - values: ArrayRef, - nullable: bool, - max_row_group_size: Option, - ) -> File { - let schema = Schema::new(vec![Field::new( - "col", - values.data_type().clone(), - nullable, - )]); - let expected_batch = - RecordBatch::try_new(Arc::new(schema), vec![values]).unwrap(); + fn one_column_roundtrip(values: ArrayRef, nullable: bool) -> Vec { + let data_type = values.data_type().clone(); + let schema = Schema::new(vec![Field::new("col", data_type, nullable)]); + one_column_roundtrip_with_schema(values, Arc::new(schema)) + } - roundtrip(expected_batch, max_row_group_size) + fn one_column_roundtrip_with_schema( + values: ArrayRef, + schema: SchemaRef, + ) -> Vec { + let encodings = match values.data_type() { + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary => vec![ + Encoding::PLAIN, + Encoding::DELTA_BYTE_ARRAY, + Encoding::DELTA_LENGTH_BYTE_ARRAY, + ], + DataType::Int64 + | DataType::Int32 + | DataType::Int16 + | DataType::Int8 + | DataType::UInt64 + | DataType::UInt32 + | DataType::UInt16 + | DataType::UInt8 => vec![Encoding::PLAIN, Encoding::DELTA_BINARY_PACKED], + _ => 
vec![Encoding::PLAIN], + }; + + let expected_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); + + let row_group_sizes = [1024, SMALL_SIZE, SMALL_SIZE / 2, SMALL_SIZE / 2 + 1, 10]; + + let mut files = vec![]; + for dictionary_size in [0, 1, 1024] { + for encoding in &encodings { + for version in [WriterVersion::PARQUET_1_0, WriterVersion::PARQUET_2_0] { + for row_group_size in row_group_sizes { + let props = WriterProperties::builder() + .set_writer_version(version) + .set_max_row_group_size(row_group_size) + .set_dictionary_enabled(dictionary_size != 0) + .set_dictionary_pagesize_limit(dictionary_size.max(1)) + .set_encoding(*encoding) + .build(); + + files.push(roundtrip_opts(&expected_batch, props)) + } + } + } + } + files } fn values_required(iter: I) @@ -1261,7 +1223,7 @@ mod tests { { let raw_values: Vec<_> = iter.into_iter().collect(); let values = Arc::new(A::from(raw_values)); - one_column_roundtrip(values, false, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, false); } fn values_optional(iter: I) @@ -1275,7 +1237,7 @@ mod tests { .map(|(i, v)| if i % 2 == 0 { None } else { Some(v) }) .collect(); let optional_values = Arc::new(A::from(optional_raw_values)); - one_column_roundtrip(optional_values, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(optional_values, true); } fn required_and_optional(iter: I) @@ -1290,12 +1252,12 @@ mod tests { #[test] fn all_null_primitive_single_column() { let values = Arc::new(Int32Array::from(vec![None; SMALL_SIZE])); - one_column_roundtrip(values, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, true); } #[test] fn null_single_column() { let values = Arc::new(NullArray::new(SMALL_SIZE)); - one_column_roundtrip(values, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, true); // null arrays are always nullable, a test with non-nullable nulls fails } @@ -1391,7 +1353,7 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); let values = Arc::new(TimestampSecondArray::from_vec(raw_values, None)); - one_column_roundtrip(values, false, Some(3)); + one_column_roundtrip(values, false); } #[test] @@ -1399,7 +1361,7 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); let values = Arc::new(TimestampMillisecondArray::from_vec(raw_values, None)); - one_column_roundtrip(values, false, Some(SMALL_SIZE / 2 + 1)); + one_column_roundtrip(values, false); } #[test] @@ -1407,7 +1369,7 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); let values = Arc::new(TimestampMicrosecondArray::from_vec(raw_values, None)); - one_column_roundtrip(values, false, Some(SMALL_SIZE / 2 + 2)); + one_column_roundtrip(values, false); } #[test] @@ -1415,7 +1377,7 @@ mod tests { let raw_values: Vec<_> = (0..SMALL_SIZE as i64).collect(); let values = Arc::new(TimestampNanosecondArray::from_vec(raw_values, None)); - one_column_roundtrip(values, false, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, false); } #[test] @@ -1515,14 +1477,14 @@ mod tests { #[test] fn fixed_size_binary_single_column() { - let mut builder = FixedSizeBinaryBuilder::new(16, 4); + let mut builder = FixedSizeBinaryBuilder::new(4); builder.append_value(b"0123").unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append_value(b"8910").unwrap(); builder.append_value(b"1112").unwrap(); let array = Arc::new(builder.finish()); - one_column_roundtrip(array, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(array, true); } #[test] @@ -1600,7 +1562,7 @@ mod tests { let a = 
ListArray::from(a_list_data); let values = Arc::new(a); - one_column_roundtrip(values, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, true); } #[test] @@ -1626,7 +1588,7 @@ mod tests { let a = LargeListArray::from(a_list_data); let values = Arc::new(a); - one_column_roundtrip(values, true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, true); } #[test] @@ -1642,10 +1604,10 @@ mod tests { ]; let list = ListArray::from_iter_primitive::(data.clone()); - one_column_roundtrip(Arc::new(list), true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(Arc::new(list), true); let list = LargeListArray::from_iter_primitive::(data); - one_column_roundtrip(Arc::new(list), true, Some(SMALL_SIZE / 2)); + one_column_roundtrip(Arc::new(list), true); } #[test] @@ -1655,7 +1617,7 @@ mod tests { let s = StructArray::from(vec![(struct_field_a, Arc::new(a_values) as ArrayRef)]); let values = Arc::new(s); - one_column_roundtrip(values, false, Some(SMALL_SIZE / 2)); + one_column_roundtrip(values, false); } #[test] @@ -1676,9 +1638,7 @@ mod tests { .collect(); // build a record batch - let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); - - roundtrip(expected_batch, Some(SMALL_SIZE / 2)); + one_column_roundtrip_with_schema(Arc::new(d), schema); } #[test] @@ -1693,19 +1653,16 @@ mod tests { )])); // create some data - let key_builder = PrimitiveBuilder::::new(3); - let value_builder = PrimitiveBuilder::::new(2); + let key_builder = PrimitiveBuilder::::with_capacity(3); + let value_builder = PrimitiveBuilder::::with_capacity(2); let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder); builder.append(12345678).unwrap(); - builder.append_null().unwrap(); + builder.append_null(); builder.append(22345678).unwrap(); builder.append(12345678).unwrap(); let d = builder.finish(); - // build a record batch - let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); - - roundtrip(expected_batch, Some(SMALL_SIZE / 2)); + one_column_roundtrip_with_schema(Arc::new(d), schema); } #[test] @@ -1725,16 +1682,13 @@ mod tests { .copied() .collect(); - // build a record batch - let expected_batch = RecordBatch::try_new(schema, vec![Arc::new(d)]).unwrap(); - - roundtrip(expected_batch, Some(SMALL_SIZE / 2)); + one_column_roundtrip_with_schema(Arc::new(d), schema); } #[test] fn u32_min_max() { // check values roundtrip through parquet - let values = Arc::new(UInt32Array::from_iter_values(vec![ + let src = vec![ u32::MIN, u32::MIN + 1, (i32::MAX as u32) - 1, @@ -1742,30 +1696,40 @@ mod tests { (i32::MAX as u32) + 1, u32::MAX - 1, u32::MAX, - ])); - let file = one_column_roundtrip(values, false, None); - - // check statistics are valid - let reader = SerializedFileReader::new(file).unwrap(); - let metadata = reader.metadata(); - assert_eq!(metadata.num_row_groups(), 1); - let row_group = metadata.row_group(0); - assert_eq!(row_group.num_columns(), 1); - let column = row_group.column(0); - let stats = column.statistics().unwrap(); - assert!(stats.has_min_max_set()); - if let Statistics::Int32(stats) = stats { - assert_eq!(*stats.min() as u32, u32::MIN); - assert_eq!(*stats.max() as u32, u32::MAX); - } else { - panic!("Statistics::Int32 missing") + ]; + let values = Arc::new(UInt32Array::from_iter_values(src.iter().cloned())); + let files = one_column_roundtrip(values, false); + + for file in files { + // check statistics are valid + let reader = SerializedFileReader::new(file).unwrap(); + let metadata = reader.metadata(); + + let mut row_offset = 0; + for 
row_group in metadata.row_groups() { + assert_eq!(row_group.num_columns(), 1); + let column = row_group.column(0); + + let num_values = column.num_values() as usize; + let src_slice = &src[row_offset..row_offset + num_values]; + row_offset += column.num_values() as usize; + + let stats = column.statistics().unwrap(); + assert!(stats.has_min_max_set()); + if let Statistics::Int32(stats) = stats { + assert_eq!(*stats.min() as u32, *src_slice.iter().min().unwrap()); + assert_eq!(*stats.max() as u32, *src_slice.iter().max().unwrap()); + } else { + panic!("Statistics::Int32 missing") + } + } } } #[test] fn u64_min_max() { // check values roundtrip through parquet - let values = Arc::new(UInt64Array::from_iter_values(vec![ + let src = vec![ u64::MIN, u64::MIN + 1, (i64::MAX as u64) - 1, @@ -1773,23 +1737,33 @@ mod tests { (i64::MAX as u64) + 1, u64::MAX - 1, u64::MAX, - ])); - let file = one_column_roundtrip(values, false, None); - - // check statistics are valid - let reader = SerializedFileReader::new(file).unwrap(); - let metadata = reader.metadata(); - assert_eq!(metadata.num_row_groups(), 1); - let row_group = metadata.row_group(0); - assert_eq!(row_group.num_columns(), 1); - let column = row_group.column(0); - let stats = column.statistics().unwrap(); - assert!(stats.has_min_max_set()); - if let Statistics::Int64(stats) = stats { - assert_eq!(*stats.min() as u64, u64::MIN); - assert_eq!(*stats.max() as u64, u64::MAX); - } else { - panic!("Statistics::Int64 missing") + ]; + let values = Arc::new(UInt64Array::from_iter_values(src.iter().cloned())); + let files = one_column_roundtrip(values, false); + + for file in files { + // check statistics are valid + let reader = SerializedFileReader::new(file).unwrap(); + let metadata = reader.metadata(); + + let mut row_offset = 0; + for row_group in metadata.row_groups() { + assert_eq!(row_group.num_columns(), 1); + let column = row_group.column(0); + + let num_values = column.num_values() as usize; + let src_slice = &src[row_offset..row_offset + num_values]; + row_offset += column.num_values() as usize; + + let stats = column.statistics().unwrap(); + assert!(stats.has_min_max_set()); + if let Statistics::Int64(stats) = stats { + assert_eq!(*stats.min() as u64, *src_slice.iter().min().unwrap()); + assert_eq!(*stats.max() as u64, *src_slice.iter().max().unwrap()); + } else { + panic!("Statistics::Int64 missing") + } + } } } @@ -1797,17 +1771,19 @@ mod tests { fn statistics_null_counts_only_nulls() { // check that null-count statistics for "only NULL"-columns are correct let values = Arc::new(UInt64Array::from(vec![None, None])); - let file = one_column_roundtrip(values, true, None); - - // check statistics are valid - let reader = SerializedFileReader::new(file).unwrap(); - let metadata = reader.metadata(); - assert_eq!(metadata.num_row_groups(), 1); - let row_group = metadata.row_group(0); - assert_eq!(row_group.num_columns(), 1); - let column = row_group.column(0); - let stats = column.statistics().unwrap(); - assert_eq!(stats.null_count(), 2); + let files = one_column_roundtrip(values, true); + + for file in files { + // check statistics are valid + let reader = SerializedFileReader::new(file).unwrap(); + let metadata = reader.metadata(); + assert_eq!(metadata.num_row_groups(), 1); + let row_group = metadata.row_group(0); + assert_eq!(row_group.num_columns(), 1); + let column = row_group.column(0); + let stats = column.statistics().unwrap(); + assert_eq!(stats.null_count(), 2); + } } #[test] @@ -1816,8 +1792,8 @@ mod tests { let int_field = 
Field::new("a", DataType::Int32, true); let int_field2 = Field::new("b", DataType::Int32, true); - let int_builder = Int32Builder::new(10); - let int_builder2 = Int32Builder::new(10); + let int_builder = Int32Builder::with_capacity(10); + let int_builder2 = Int32Builder::with_capacity(10); let struct_builder = StructBuilder::new( vec![int_field, int_field2], @@ -1833,81 +1809,71 @@ mod tests { values .field_builder::(0) .unwrap() - .append_value(1) - .unwrap(); + .append_value(1); values .field_builder::(1) .unwrap() - .append_value(2) - .unwrap(); - values.append(true).unwrap(); - list_builder.append(true).unwrap(); + .append_value(2); + values.append(true); + list_builder.append(true); // [] - list_builder.append(true).unwrap(); + list_builder.append(true); // null - list_builder.append(false).unwrap(); + list_builder.append(false); // [null, null] let values = list_builder.values(); values .field_builder::(0) .unwrap() - .append_null() - .unwrap(); + .append_null(); values .field_builder::(1) .unwrap() - .append_null() - .unwrap(); - values.append(false).unwrap(); + .append_null(); + values.append(false); values .field_builder::(0) .unwrap() - .append_null() - .unwrap(); + .append_null(); values .field_builder::(1) .unwrap() - .append_null() - .unwrap(); - values.append(false).unwrap(); - list_builder.append(true).unwrap(); + .append_null(); + values.append(false); + list_builder.append(true); // [{a: null, b: 3}] let values = list_builder.values(); values .field_builder::(0) .unwrap() - .append_null() - .unwrap(); + .append_null(); values .field_builder::(1) .unwrap() - .append_value(3) - .unwrap(); - values.append(true).unwrap(); - list_builder.append(true).unwrap(); + .append_value(3); + values.append(true); + list_builder.append(true); // [{a: 2, b: null}] let values = list_builder.values(); values .field_builder::(0) .unwrap() - .append_value(2) - .unwrap(); + .append_value(2); values .field_builder::(1) .unwrap() - .append_null() - .unwrap(); - values.append(true).unwrap(); - list_builder.append(true).unwrap(); + .append_null(); + values.append(true); + list_builder.append(true); let array = Arc::new(list_builder.finish()); - one_column_roundtrip(array, true, Some(10)); + one_column_roundtrip(array, true); } fn row_group_sizes(metadata: &ParquetMetaData) -> Vec { @@ -1946,11 +1912,12 @@ mod tests { writer.close().unwrap(); - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - assert_eq!(&row_group_sizes(arrow_reader.metadata()), &[200, 200, 50]); + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + assert_eq!(&row_group_sizes(builder.metadata()), &[200, 200, 50]); - let batches = arrow_reader - .get_record_reader(100) + let batches = builder + .with_batch_size(100) + .build() .unwrap() .collect::>>() .unwrap(); @@ -2091,11 +2058,12 @@ mod tests { // Should have written entire first batch and first row of second to the first row group // leaving a single row in the second row group - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - assert_eq!(&row_group_sizes(arrow_reader.metadata()), &[6, 1]); + let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + assert_eq!(&row_group_sizes(builder.metadata()), &[6, 1]); - let batches = arrow_reader - .get_record_reader(2) + let batches = builder + .with_batch_size(2) + .build() .unwrap() .collect::>>() .unwrap(); diff --git a/parquet/src/arrow/async_reader.rs b/parquet/src/arrow/async_reader.rs index 923f329eff20..201f2afcf0e8 100644 --- 
a/parquet/src/arrow/async_reader.rs +++ b/parquet/src/arrow/async_reader.rs @@ -77,46 +77,90 @@ use std::collections::VecDeque; use std::fmt::Formatter; + use std::io::{Cursor, SeekFrom}; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use futures::future::{BoxFuture, FutureExt}; +use futures::ready; use futures::stream::Stream; -use parquet_format::PageType; +use parquet_format::OffsetIndex; +use thrift::protocol::TCompactInputProtocol; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use crate::arrow::array_reader::{build_array_reader, RowGroupCollection}; -use crate::arrow::arrow_reader::ParquetRecordBatchReader; -use crate::arrow::schema::parquet_to_arrow_schema; +use crate::arrow::arrow_reader::{ + evaluate_predicate, selects_any, ArrowReaderBuilder, ArrowReaderOptions, + ParquetRecordBatchReader, RowFilter, RowSelection, +}; use crate::arrow::ProjectionMask; -use crate::basic::Compression; -use crate::column::page::{Page, PageIterator, PageMetadata, PageReader}; -use crate::compression::{create_codec, Codec}; + +use crate::column::page::{PageIterator, PageReader}; + use crate::errors::{ParquetError, Result}; use crate::file::footer::{decode_footer, decode_metadata}; -use crate::file::metadata::ParquetMetaData; -use crate::file::serialized_reader::{decode_page, read_page_header}; +use crate::file::metadata::{ParquetMetaData, RowGroupMetaData}; +use crate::file::reader::{ChunkReader, Length, SerializedPageReader}; + +use crate::file::page_index::index_reader; use crate::file::FOOTER_SIZE; -use crate::schema::types::{ColumnDescPtr, SchemaDescPtr, SchemaDescriptor}; + +use crate::schema::types::{ColumnDescPtr, SchemaDescPtr}; /// The asynchronous interface used by [`ParquetRecordBatchStream`] to read parquet files -pub trait AsyncFileReader { +pub trait AsyncFileReader: Send { /// Retrieve the bytes in `range` fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, Result>; + /// Retrieve multiple byte ranges. The default implementation will call `get_bytes` sequentially + fn get_byte_ranges( + &mut self, + ranges: Vec>, + ) -> BoxFuture<'_, Result>> { + async move { + let mut result = Vec::with_capacity(ranges.len()); + + for range in ranges.into_iter() { + let data = self.get_bytes(range).await?; + result.push(data); + } + + Ok(result) + } + .boxed() + } + /// Provides asynchronous access to the [`ParquetMetaData`] of a parquet file, /// allowing fine-grained control over how metadata is sourced, in particular allowing /// for caching, pre-fetching, catalog metadata, etc... 
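// Aside, not part of this patch: a minimal sketch of implementing the trait above for a
// fully in-memory file, similar in spirit to the `TestReader` used by the tests in this
// module. It relies on the `Bytes`, `BoxFuture`/`FutureExt`, `Range`, `Arc` and
// `ParquetMetaData` imports already present in this file; the struct name is hypothetical.
struct InMemoryParquet {
    data: Bytes,
    metadata: Arc<ParquetMetaData>,
}

impl AsyncFileReader for InMemoryParquet {
    fn get_bytes(&mut self, range: Range<usize>) -> BoxFuture<'_, Result<Bytes>> {
        // everything is already resident, so resolve immediately
        futures::future::ready(Ok(self.data.slice(range))).boxed()
    }

    fn get_metadata(&mut self) -> BoxFuture<'_, Result<Arc<ParquetMetaData>>> {
        futures::future::ready(Ok(self.metadata.clone())).boxed()
    }

    // `get_byte_ranges` falls back to the provided default, which calls `get_bytes`
    // once per requested range.
}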
fn get_metadata(&mut self) -> BoxFuture<'_, Result>>; } +impl AsyncFileReader for Box { + fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, Result> { + self.as_mut().get_bytes(range) + } + + fn get_byte_ranges( + &mut self, + ranges: Vec>, + ) -> BoxFuture<'_, Result>> { + self.as_mut().get_byte_ranges(ranges) + } + + fn get_metadata(&mut self) -> BoxFuture<'_, Result>> { + self.as_mut().get_metadata() + } +} + impl AsyncFileReader for T { fn get_bytes(&mut self, range: Range) -> BoxFuture<'_, Result> { async move { @@ -155,80 +199,116 @@ impl AsyncFileReader for T { } } +#[doc(hidden)] +/// A newtype used within [`ReaderOptionsBuilder`] to distinguish sync readers from async +/// +/// Allows sharing the same builder for both the sync and async versions, whilst also not +/// breaking the pre-existing ParquetRecordBatchStreamBuilder API +pub struct AsyncReader(T); + /// A builder used to construct a [`ParquetRecordBatchStream`] for a parquet file /// /// In particular, this handles reading the parquet file metadata, allowing consumers /// to use this information to select what specific columns, row groups, etc... /// they wish to be read by the resulting stream /// -pub struct ParquetRecordBatchStreamBuilder { - input: T, +pub type ParquetRecordBatchStreamBuilder = ArrowReaderBuilder>; - metadata: Arc, - - schema: SchemaRef, - - batch_size: usize, - - row_groups: Option>, - - projection: ProjectionMask, -} - -impl ParquetRecordBatchStreamBuilder { +impl ArrowReaderBuilder> { /// Create a new [`ParquetRecordBatchStreamBuilder`] with the provided parquet file pub async fn new(mut input: T) -> Result { let metadata = input.get_metadata().await?; + Self::new_builder(AsyncReader(input), metadata, Default::default()) + } + + pub async fn new_with_options( + mut input: T, + options: ArrowReaderOptions, + ) -> Result { + let mut metadata = input.get_metadata().await?; + + if options.page_index + && metadata + .page_indexes() + .zip(metadata.offset_indexes()) + .is_none() + { + let mut fetch_ranges = vec![]; + let mut index_lengths: Vec> = vec![]; + + for rg in metadata.row_groups() { + let (loc_offset, loc_length) = + index_reader::get_location_offset_and_total_length(rg.columns())?; + + let (idx_offset, idx_lengths) = + index_reader::get_index_offset_and_lengths(rg.columns())?; + let idx_length = idx_lengths.iter().sum::(); + + // If index data is missing, return without any indexes + if loc_length == 0 || idx_length == 0 { + return Self::new_builder(AsyncReader(input), metadata, options); + } - let schema = Arc::new(parquet_to_arrow_schema( - metadata.file_metadata().schema_descr(), - metadata.file_metadata().key_value_metadata(), - )?); + fetch_ranges.push(loc_offset as usize..loc_offset as usize + loc_length); + fetch_ranges.push(idx_offset as usize..idx_offset as usize + idx_length); + index_lengths.push(idx_lengths); + } - Ok(Self { - input, - metadata, - schema, - batch_size: 1024, - row_groups: None, - projection: ProjectionMask::all(), - }) - } + let mut chunks = input.get_byte_ranges(fetch_ranges).await?.into_iter(); + let mut index_lengths = index_lengths.into_iter(); - /// Returns a reference to the [`ParquetMetaData`] for this parquet file - pub fn metadata(&self) -> &Arc { - &self.metadata - } + let mut row_groups = metadata.row_groups().to_vec(); - /// Returns the parquet [`SchemaDescriptor`] for this parquet file - pub fn parquet_schema(&self) -> &SchemaDescriptor { - self.metadata.file_metadata().schema_descr() - } + let mut columns_indexes = vec![]; + let mut offset_indexes = 
vec![]; - /// Returns the arrow [`SchemaRef`] for this parquet file - pub fn schema(&self) -> &SchemaRef { - &self.schema - } + for rg in row_groups.iter_mut() { + let columns = rg.columns(); - /// Set the size of [`RecordBatch`] to produce - pub fn with_batch_size(self, batch_size: usize) -> Self { - Self { batch_size, ..self } - } + let location_data = chunks.next().unwrap(); + let mut cursor = Cursor::new(location_data); + let mut offset_index = vec![]; - /// Only read data from the provided row group indexes - pub fn with_row_groups(self, row_groups: Vec) -> Self { - Self { - row_groups: Some(row_groups), - ..self - } - } + for _ in 0..columns.len() { + let mut prot = TCompactInputProtocol::new(&mut cursor); + let offset = OffsetIndex::read_from_in_protocol(&mut prot)?; + offset_index.push(offset.page_locations); + } - /// Only read data from the provided column indexes - pub fn with_projection(self, mask: ProjectionMask) -> Self { - Self { - projection: mask, - ..self + rg.set_page_offset(offset_index.clone()); + offset_indexes.push(offset_index); + + let index_data = chunks.next().unwrap(); + let index_lengths = index_lengths.next().unwrap(); + + let mut start = 0; + let data = index_lengths.into_iter().map(|length| { + let r = index_data.slice(start..start + length); + start += length; + r + }); + + let indexes = rg + .columns() + .iter() + .zip(data) + .map(|(column, data)| { + let column_type = column.column_type(); + index_reader::deserialize_column_index(&data, column_type) + }) + .collect::>>()?; + columns_indexes.push(indexes); + } + + metadata = Arc::new(ParquetMetaData::new_with_page_index( + metadata.file_metadata().clone(), + row_groups, + Some(columns_indexes), + Some(offset_indexes), + )); } + + Self::new_builder(AsyncReader(input), metadata, options) } /// Build a new [`ParquetRecordBatchStream`] @@ -249,25 +329,119 @@ impl ParquetRecordBatchStreamBuilder { None => (0..self.metadata.row_groups().len()).collect(), }; + // Try to avoid allocate large buffer + let batch_size = self + .batch_size + .min(self.metadata.file_metadata().num_rows() as usize); + let reader = ReaderFactory { + input: self.input.0, + filter: self.filter, + metadata: self.metadata.clone(), + schema: self.schema.clone(), + }; + Ok(ParquetRecordBatchStream { + metadata: self.metadata, + batch_size, row_groups, projection: self.projection, - batch_size: self.batch_size, - metadata: self.metadata, + selection: self.selection, schema: self.schema, - input: Some(self.input), + reader: Some(reader), state: StreamState::Init, }) } } +type ReadResult = Result<(ReaderFactory, Option)>; + +/// [`ReaderFactory`] is used by [`ParquetRecordBatchStream`] to create +/// [`ParquetRecordBatchReader`] +struct ReaderFactory { + metadata: Arc, + + schema: SchemaRef, + + input: T, + + filter: Option, +} + +impl ReaderFactory +where + T: AsyncFileReader + Send, +{ + /// Reads the next row group with the provided `selection`, `projection` and `batch_size` + /// + /// Note: this captures self so that the resulting future has a static lifetime + async fn read_row_group( + mut self, + row_group_idx: usize, + mut selection: Option, + projection: ProjectionMask, + batch_size: usize, + ) -> ReadResult { + // TODO: calling build_array multiple times is wasteful + + let meta = self.metadata.row_group(row_group_idx); + let mut row_group = InMemoryRowGroup { + metadata: meta, + // schema: meta.schema_descr_ptr(), + row_count: meta.num_rows() as usize, + column_chunks: vec![None; meta.columns().len()], + }; + + if let Some(filter) = 
self.filter.as_mut() { + for predicate in filter.predicates.iter_mut() { + if !selects_any(selection.as_ref()) { + return Ok((self, None)); + } + + let predicate_projection = predicate.projection().clone(); + row_group + .fetch(&mut self.input, &predicate_projection, selection.as_ref()) + .await?; + + let array_reader = build_array_reader( + self.schema.clone(), + predicate_projection, + &row_group, + )?; + + selection = Some(evaluate_predicate( + batch_size, + array_reader, + selection, + predicate.as_mut(), + )?); + } + } + + if !selects_any(selection.as_ref()) { + return Ok((self, None)); + } + + row_group + .fetch(&mut self.input, &projection, selection.as_ref()) + .await?; + + let reader = ParquetRecordBatchReader::new( + batch_size, + build_array_reader(self.schema.clone(), projection, &row_group)?, + selection, + ); + + Ok((self, Some(reader))) + } +} + enum StreamState { /// At the start of a new row group, or the end of the parquet stream Init, /// Decoding a batch Decoding(ParquetRecordBatchReader), /// Reading data from input - Reading(BoxFuture<'static, Result<(T, InMemoryRowGroup)>>), + Reading(BoxFuture<'static, ReadResult>), /// Error Error, } @@ -283,20 +457,23 @@ impl std::fmt::Debug for StreamState { } } -/// An asynchronous [`Stream`] of [`RecordBatch`] for a parquet file +/// An asynchronous [`Stream`] of [`RecordBatch`] for a parquet file that can be +/// constructed using [`ParquetRecordBatchStreamBuilder`] pub struct ParquetRecordBatchStream { metadata: Arc, schema: SchemaRef, - batch_size: usize, + row_groups: VecDeque, projection: ProjectionMask, - row_groups: VecDeque, + batch_size: usize, + + selection: Option, /// This is an option so it can be moved into a future - input: Option, + reader: Option>, state: StreamState, } @@ -348,87 +525,40 @@ where None => return Poll::Ready(None), }; - let metadata = self.metadata.clone(); - let mut input = match self.input.take() { - Some(input) => input, - None => { - self.state = StreamState::Error; - return Poll::Ready(Some(Err(general_err!( - "input stream lost" - )))); - } - }; - - let projection = self.projection.clone(); - self.state = StreamState::Reading( - async move { - let row_group_metadata = metadata.row_group(row_group_idx); - let mut column_chunks = - vec![None; row_group_metadata.columns().len()]; - - // TODO: Combine consecutive ranges - for (idx, chunk) in column_chunks.iter_mut().enumerate() { - if !projection.leaf_included(idx) { - continue; - } - - let column = row_group_metadata.column(idx); - let (start, length) = column.byte_range(); - - let data = input - .get_bytes(start as usize..(start + length) as usize) - .await?; - - *chunk = Some(InMemoryColumnChunk { - num_values: column.num_values(), - compression: column.compression(), - physical_type: column.column_type(), - data, - }); - } - - Ok(( - input, - InMemoryRowGroup { - schema: metadata.file_metadata().schema_descr_ptr(), - row_count: row_group_metadata.num_rows() as usize, - column_chunks, - }, - )) - } - .boxed(), - ) - } - StreamState::Reading(f) => { - let result = futures::ready!(f.poll_unpin(cx)); - self.state = StreamState::Init; - - let row_group: Box = match result { - Ok((input, row_group)) => { - self.input = Some(input); - Box::new(row_group) - } - Err(e) => { - self.state = StreamState::Error; - return Poll::Ready(Some(Err(e))); - } - }; + let reader = self.reader.take().expect("lost reader"); - let parquet_schema = self.metadata.file_metadata().schema_descr_ptr(); + let row_count = + self.metadata.row_group(row_group_idx).num_rows() 
as usize; - let array_reader = build_array_reader( - parquet_schema, - self.schema.clone(), - self.projection.clone(), - row_group, - )?; + let selection = + self.selection.as_mut().map(|s| s.split_off(row_count)); - let batch_reader = - ParquetRecordBatchReader::try_new(self.batch_size, array_reader) - .expect("reader"); + let fut = reader + .read_row_group( + row_group_idx, + selection, + self.projection.clone(), + self.batch_size, + ) + .boxed(); - self.state = StreamState::Decoding(batch_reader) + self.state = StreamState::Reading(fut) } + StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) { + Ok((reader_factory, maybe_reader)) => { + self.reader = Some(reader_factory); + match maybe_reader { + // Read records from [`ParquetRecordBatchReader`] + Some(reader) => self.state = StreamState::Decoding(reader), + // All rows skipped, read next row group + None => self.state = StreamState::Init, + } + } + Err(e) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(e))); + } + }, StreamState::Error => return Poll::Pending, } } @@ -436,128 +566,198 @@ where } /// An in-memory collection of column chunks -struct InMemoryRowGroup { - schema: SchemaDescPtr, - column_chunks: Vec>, +struct InMemoryRowGroup<'a> { + metadata: &'a RowGroupMetaData, + column_chunks: Vec>>, row_count: usize, } -impl RowGroupCollection for InMemoryRowGroup { - fn schema(&self) -> Result { - Ok(self.schema.clone()) - } +impl<'a> InMemoryRowGroup<'a> { + /// Fetches the necessary column data into memory + async fn fetch( + &mut self, + input: &mut T, + projection: &ProjectionMask, + selection: Option<&RowSelection>, + ) -> Result<()> { + if let Some((selection, page_locations)) = + selection.zip(self.metadata.page_offset_index().as_ref()) + { + // If we have a `RowSelection` and an `OffsetIndex` then only fetch pages required for the + // `RowSelection` + let mut page_start_offsets: Vec> = vec![]; + + let fetch_ranges = self + .column_chunks + .iter() + .zip(self.metadata.columns()) + .enumerate() + .into_iter() + .filter_map(|(idx, (chunk, chunk_meta))| { + (chunk.is_none() && projection.leaf_included(idx)).then(|| { + // If the first page does not start at the beginning of the column, + // then we need to also fetch a dictionary page. 
+ let mut ranges = vec![]; + let (start, _len) = chunk_meta.byte_range(); + match page_locations[idx].first() { + Some(first) if first.offset as u64 != start => { + ranges.push(start as usize..first.offset as usize); + } + _ => (), + } - fn num_rows(&self) -> usize { - self.row_count - } + ranges.extend(selection.scan_ranges(&page_locations[idx])); + page_start_offsets + .push(ranges.iter().map(|range| range.start).collect()); - fn column_chunks(&self, i: usize) -> Result> { - let page_reader = self.column_chunks[i].as_ref().unwrap().pages(); + ranges + }) + }) + .flatten() + .collect(); - Ok(Box::new(ColumnChunkIterator { - schema: self.schema.clone(), - column_schema: self.schema.columns()[i].clone(), - reader: Some(page_reader), - })) - } -} + let mut chunk_data = input.get_byte_ranges(fetch_ranges).await?.into_iter(); + let mut page_start_offsets = page_start_offsets.into_iter(); -/// Data for a single column chunk -#[derive(Clone)] -struct InMemoryColumnChunk { - num_values: i64, - compression: Compression, - physical_type: crate::basic::Type, - data: Bytes, -} + for (idx, chunk) in self.column_chunks.iter_mut().enumerate() { + if chunk.is_some() || !projection.leaf_included(idx) { + continue; + } -impl InMemoryColumnChunk { - fn pages(&self) -> Result> { - let page_reader = InMemoryColumnChunkReader::new(self.clone())?; - Ok(Box::new(page_reader)) - } -} + if let Some(offsets) = page_start_offsets.next() { + let mut chunks = Vec::with_capacity(offsets.len()); + for _ in 0..offsets.len() { + chunks.push(chunk_data.next().unwrap()); + } -// A serialized implementation for Parquet [`PageReader`]. -struct InMemoryColumnChunkReader { - chunk: InMemoryColumnChunk, - decompressor: Option>, - offset: usize, - seen_num_values: i64, -} + *chunk = Some(Arc::new(ColumnChunkData::Sparse { + length: self.metadata.column(idx).byte_range().1 as usize, + data: offsets.into_iter().zip(chunks.into_iter()).collect(), + })) + } + } + } else { + let fetch_ranges = self + .column_chunks + .iter() + .enumerate() + .into_iter() + .filter_map(|(idx, chunk)| { + (chunk.is_none() && projection.leaf_included(idx)).then(|| { + let column = self.metadata.column(idx); + let (start, length) = column.byte_range(); + start as usize..(start + length) as usize + }) + }) + .collect(); + + let mut chunk_data = input.get_byte_ranges(fetch_ranges).await?.into_iter(); + + for (idx, chunk) in self.column_chunks.iter_mut().enumerate() { + if chunk.is_some() || !projection.leaf_included(idx) { + continue; + } -impl InMemoryColumnChunkReader { - /// Creates a new serialized page reader from file source. 
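// Aside, not part of this patch: a simplified, self-contained sketch of the fetch logic
// above — with an offset index available, only the byte ranges of the data pages that a
// row selection touches need to be requested. `PageInfo` and `ranges_for_selected_pages`
// are hypothetical names; the real code derives the ranges via `RowSelection::scan_ranges`
// over the parquet page locations.
struct PageInfo {
    /// absolute byte offset of the page within the file
    offset: usize,
    /// compressed size of the page in bytes
    compressed_size: usize,
}

fn ranges_for_selected_pages(
    pages: &[PageInfo],
    selected_pages: &[usize],
) -> Vec<std::ops::Range<usize>> {
    selected_pages
        .iter()
        .map(|&idx| {
            let page = &pages[idx];
            page.offset..page.offset + page.compressed_size
        })
        .collect()
}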
- fn new(chunk: InMemoryColumnChunk) -> Result { - let decompressor = create_codec(chunk.compression)?; - let result = Self { - chunk, - decompressor, - offset: 0, - seen_num_values: 0, - }; - Ok(result) + if let Some(data) = chunk_data.next() { + *chunk = Some(Arc::new(ColumnChunkData::Dense { + offset: self.metadata.column(idx).byte_range().0 as usize, + data, + })); + } + } + } + + Ok(()) } } -impl Iterator for InMemoryColumnChunkReader { - type Item = Result; +impl<'a> RowGroupCollection for InMemoryRowGroup<'a> { + fn schema(&self) -> SchemaDescPtr { + self.metadata.schema_descr_ptr() + } - fn next(&mut self) -> Option { - self.get_next_page().transpose() + fn num_rows(&self) -> usize { + self.row_count + } + + fn column_chunks(&self, i: usize) -> Result> { + match &self.column_chunks[i] { + None => Err(ParquetError::General(format!( + "Invalid column index {}, column was not fetched", + i + ))), + Some(data) => { + let page_locations = self + .metadata + .page_offset_index() + .as_ref() + .map(|index| index[i].clone()); + let page_reader: Box = + Box::new(SerializedPageReader::new( + data.clone(), + self.metadata.column(i), + self.row_count, + page_locations, + )?); + + Ok(Box::new(ColumnChunkIterator { + schema: self.metadata.schema_descr_ptr(), + column_schema: self.metadata.schema_descr_ptr().columns()[i].clone(), + reader: Some(Ok(page_reader)), + })) + } + } } } -impl PageReader for InMemoryColumnChunkReader { - fn get_next_page(&mut self) -> Result> { - while self.seen_num_values < self.chunk.num_values { - let mut cursor = Cursor::new(&self.chunk.data.as_ref()[self.offset..]); - let page_header = read_page_header(&mut cursor)?; - let compressed_size = page_header.compressed_page_size as usize; - - self.offset += cursor.position() as usize; - let start_offset = self.offset; - let end_offset = self.offset + compressed_size; - self.offset = end_offset; - - let buffer = self.chunk.data.slice(start_offset..end_offset); - - let result = match page_header.type_ { - PageType::DataPage | PageType::DataPageV2 => { - let decoded = decode_page( - page_header, - buffer.into(), - self.chunk.physical_type, - self.decompressor.as_mut(), - )?; - self.seen_num_values += decoded.num_values() as i64; - decoded - } - PageType::DictionaryPage => decode_page( - page_header, - buffer.into(), - self.chunk.physical_type, - self.decompressor.as_mut(), - )?, - _ => { - // For unknown page type (e.g., INDEX_PAGE), skip and read next. - continue; - } - }; +/// An in-memory column chunk +#[derive(Clone)] +enum ColumnChunkData { + /// Column chunk data representing only a subset of data pages + Sparse { + /// Length of the full column chunk + length: usize, + /// Set of data pages included in this sparse chunk. Each element is a tuple + /// of (page offset, page data) + data: Vec<(usize, Bytes)>, + }, + /// Full column chunk and its offset + Dense { offset: usize, data: Bytes }, +} - return Ok(Some(result)); +impl Length for ColumnChunkData { + fn len(&self) -> u64 { + match &self { + ColumnChunkData::Sparse { length, .. } => *length as u64, + ColumnChunkData::Dense { data, .. } => data.len() as u64, } - - // We are at the end of this column chunk and no more page left. Return None. 
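// Aside, not part of this patch: a sketch of how the sparse variant of `ColumnChunkData`
// defined above can serve page reads. Pages are kept sorted by absolute byte offset, so a
// read that starts exactly at a fetched page is found with a binary search, while offsets
// that were never fetched yield None. Assumes each request stays within a single page.
fn sparse_get(pages: &[(usize, Vec<u8>)], start: usize, length: usize) -> Option<Vec<u8>> {
    pages
        .binary_search_by_key(&start, |(offset, _)| *offset)
        .ok()
        .map(|idx| pages[idx].1[..length].to_vec())
}
// With pages = [(100, page_a), (900, page_b)], sparse_get(&pages, 900, 16) returns the
// first 16 bytes of page_b, while sparse_get(&pages, 500, 16) returns None.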
- Ok(None) } +} + +impl ChunkReader for ColumnChunkData { + type T = bytes::buf::Reader; - fn peek_next_page(&mut self) -> Result> { - Err(nyi_err!("https://github.com/apache/arrow-rs/issues/1792")) + fn get_read(&self, start: u64, length: usize) -> Result { + Ok(self.get_bytes(start, length)?.reader()) } - fn skip_next_page(&mut self) -> Result<()> { - Err(nyi_err!("https://github.com/apache/arrow-rs/issues/1792")) + fn get_bytes(&self, start: u64, length: usize) -> Result { + match &self { + ColumnChunkData::Sparse { data, .. } => data + .binary_search_by_key(&start, |(offset, _)| *offset as u64) + .map(|idx| data[idx].1.slice(0..length)) + .map_err(|_| { + ParquetError::General(format!( + "Invalid offset in sparse column chunk data: {}", + start + )) + }), + ColumnChunkData::Dense { offset, data } => { + let start = start as usize - *offset; + let end = start + length; + Ok(data.slice(start..end)) + } + } } } @@ -589,7 +789,13 @@ impl PageIterator for ColumnChunkIterator { #[cfg(test)] mod tests { use super::*; - use crate::arrow::{ArrowReader, ParquetFileArrowReader}; + use crate::arrow::arrow_reader::{ + ArrowPredicateFn, ParquetRecordBatchReaderBuilder, RowSelector, + }; + use crate::arrow::{parquet_to_arrow_schema, ArrowWriter}; + use crate::file::footer::parse_metadata; + use crate::file::page_index::index_reader; + use arrow::array::{Array, ArrayRef, Int32Array, StringArray}; use arrow::error::Result as ArrowResult; use futures::TryStreamExt; use std::sync::Mutex; @@ -617,7 +823,7 @@ mod tests { let path = format!("{}/alltypes_plain.parquet", testdata); let data = Bytes::from(std::fs::read(path).unwrap()); - let metadata = crate::file::footer::parse_metadata(&data).unwrap(); + let metadata = parse_metadata(&data).unwrap(); let metadata = Arc::new(metadata); assert_eq!(metadata.num_row_groups(), 1); @@ -642,9 +848,11 @@ mod tests { let async_batches: Vec<_> = stream.try_collect().await.unwrap(); - let mut sync_reader = ParquetFileArrowReader::try_new(data).unwrap(); - let sync_batches = sync_reader - .get_record_reader_by_columns(mask, 1024) + let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data) + .unwrap() + .with_projection(mask) + .with_batch_size(104) + .build() .unwrap() .collect::>>() .unwrap(); @@ -663,4 +871,300 @@ mod tests { ] ); } + + #[tokio::test] + async fn test_async_reader_with_index() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/alltypes_tiny_pages_plain.parquet", testdata); + let data = Bytes::from(std::fs::read(path).unwrap()); + + let metadata = parse_metadata(&data).unwrap(); + let metadata = Arc::new(metadata); + + assert_eq!(metadata.num_row_groups(), 1); + + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let options = ArrowReaderOptions::new().with_page_index(true); + let builder = + ParquetRecordBatchStreamBuilder::new_with_options(async_reader, options) + .await + .unwrap(); + + // The builder should have page and offset indexes loaded now + let metadata_with_index = builder.metadata(); + + // Check offset indexes are present for all columns + for rg in metadata_with_index.row_groups() { + let page_locations = rg + .page_offset_index() + .as_ref() + .expect("expected page offset index"); + assert_eq!(page_locations.len(), rg.columns().len()) + } + + // Check page indexes are present for all columns + let page_indexes = metadata_with_index + .page_indexes() + .expect("expected page indexes"); + for (idx, rg) in 
metadata_with_index.row_groups().iter().enumerate() { + assert_eq!(page_indexes[idx].len(), rg.columns().len()) + } + + let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]); + let stream = builder + .with_projection(mask.clone()) + .with_batch_size(1024) + .build() + .unwrap(); + + let async_batches: Vec<_> = stream.try_collect().await.unwrap(); + + let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data) + .unwrap() + .with_projection(mask) + .with_batch_size(1024) + .build() + .unwrap() + .collect::>>() + .unwrap(); + + assert_eq!(async_batches, sync_batches); + } + + #[tokio::test] + async fn test_row_filter() { + let a = StringArray::from_iter_values(["a", "b", "b", "b", "c", "c"]); + let b = StringArray::from_iter_values(["1", "2", "3", "4", "5", "6"]); + let c = Int32Array::from_iter(0..6); + let data = RecordBatch::try_from_iter([ + ("a", Arc::new(a) as ArrayRef), + ("b", Arc::new(b) as ArrayRef), + ("c", Arc::new(c) as ArrayRef), + ]) + .unwrap(); + + let mut buf = Vec::with_capacity(1024); + let mut writer = ArrowWriter::try_new(&mut buf, data.schema(), None).unwrap(); + writer.write(&data).unwrap(); + writer.close().unwrap(); + + let data: Bytes = buf.into(); + let metadata = parse_metadata(&data).unwrap(); + let parquet_schema = metadata.file_metadata().schema_descr_ptr(); + + let test = TestReader { + data, + metadata: Arc::new(metadata), + requests: Default::default(), + }; + let requests = test.requests.clone(); + + let a_filter = ArrowPredicateFn::new( + ProjectionMask::leaves(&parquet_schema, vec![0]), + |batch| arrow::compute::eq_dyn_utf8_scalar(batch.column(0), "b"), + ); + + let b_filter = ArrowPredicateFn::new( + ProjectionMask::leaves(&parquet_schema, vec![1]), + |batch| arrow::compute::eq_dyn_utf8_scalar(batch.column(0), "4"), + ); + + let filter = RowFilter::new(vec![Box::new(a_filter), Box::new(b_filter)]); + + let mask = ProjectionMask::leaves(&parquet_schema, vec![0, 2]); + let stream = ParquetRecordBatchStreamBuilder::new(test) + .await + .unwrap() + .with_projection(mask.clone()) + .with_batch_size(1024) + .with_row_filter(filter) + .build() + .unwrap(); + + let batches: Vec<_> = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + + let batch = &batches[0]; + assert_eq!(batch.num_rows(), 1); + assert_eq!(batch.num_columns(), 2); + + let col = batch.column(0); + let val = col.as_any().downcast_ref::().unwrap().value(0); + assert_eq!(val, "b"); + + let col = batch.column(1); + let val = col.as_any().downcast_ref::().unwrap().value(0); + assert_eq!(val, 3); + + // Should only have made 3 requests + assert_eq!(requests.lock().unwrap().len(), 3); + } + + #[tokio::test] + async fn test_row_filter_with_index() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/alltypes_tiny_pages_plain.parquet", testdata); + let data = Bytes::from(std::fs::read(path).unwrap()); + + let metadata = parse_metadata(&data).unwrap(); + let parquet_schema = metadata.file_metadata().schema_descr_ptr(); + let metadata = Arc::new(metadata); + + assert_eq!(metadata.num_row_groups(), 1); + + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let a_filter = ArrowPredicateFn::new( + ProjectionMask::leaves(&parquet_schema, vec![1]), + |batch| arrow::compute::eq_dyn_bool_scalar(batch.column(0), true), + ); + + let b_filter = ArrowPredicateFn::new( + ProjectionMask::leaves(&parquet_schema, vec![2]), + |batch| 
arrow::compute::eq_dyn_scalar(batch.column(0), 2_i32), + ); + + let filter = RowFilter::new(vec![Box::new(a_filter), Box::new(b_filter)]); + + let mask = ProjectionMask::leaves(&parquet_schema, vec![0, 2]); + + let options = ArrowReaderOptions::new().with_page_index(true); + let stream = + ParquetRecordBatchStreamBuilder::new_with_options(async_reader, options) + .await + .unwrap() + .with_projection(mask.clone()) + .with_batch_size(1024) + .with_row_filter(filter) + .build() + .unwrap(); + + let batches: Vec = stream.try_collect().await.unwrap(); + + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 730); + } + + #[tokio::test] + async fn test_in_memory_row_group_sparse() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{}/alltypes_tiny_pages.parquet", testdata); + let data = Bytes::from(std::fs::read(path).unwrap()); + + let metadata = parse_metadata(&data).unwrap(); + + let offset_index = + index_reader::read_pages_locations(&data, metadata.row_group(0).columns()) + .expect("reading offset index"); + + let mut row_group_meta = metadata.row_group(0).clone(); + row_group_meta.set_page_offset(offset_index.clone()); + let metadata = + ParquetMetaData::new(metadata.file_metadata().clone(), vec![row_group_meta]); + + let metadata = Arc::new(metadata); + + let num_rows = metadata.row_group(0).num_rows(); + + assert_eq!(metadata.num_row_groups(), 1); + + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let requests = async_reader.requests.clone(); + let schema = Arc::new( + parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) + .expect("building arrow schema"), + ); + + let _schema_desc = metadata.file_metadata().schema_descr(); + + let projection = + ProjectionMask::leaves(metadata.file_metadata().schema_descr(), vec![0]); + + let reader_factory = ReaderFactory { + metadata, + schema, + input: async_reader, + filter: None, + }; + + let mut skip = true; + let mut pages = offset_index[0].iter().peekable(); + + // Setup `RowSelection` so that we can skip every other page, selecting the last page + let mut selectors = vec![]; + let mut expected_page_requests: Vec> = vec![]; + while let Some(page) = pages.next() { + let num_rows = if let Some(next_page) = pages.peek() { + next_page.first_row_index - page.first_row_index + } else { + num_rows - page.first_row_index + }; + + if skip { + selectors.push(RowSelector::skip(num_rows as usize)); + } else { + selectors.push(RowSelector::select(num_rows as usize)); + let start = page.offset as usize; + let end = start + page.compressed_page_size as usize; + expected_page_requests.push(start..end); + } + skip = !skip; + } + + let selection = RowSelection::from(selectors); + + let (_factory, _reader) = reader_factory + .read_row_group(0, Some(selection), projection.clone(), 48) + .await + .expect("reading row group"); + + let requests = requests.lock().unwrap(); + + assert_eq!(&requests[..], &expected_page_requests) + } + + #[tokio::test] + async fn test_batch_size_overallocate() { + let testdata = arrow::util::test_util::parquet_test_data(); + // `alltypes_plain.parquet` only have 8 rows + let path = format!("{}/alltypes_plain.parquet", testdata); + let data = Bytes::from(std::fs::read(path).unwrap()); + + let metadata = parse_metadata(&data).unwrap(); + let file_rows = metadata.file_metadata().num_rows() as usize; + let metadata = Arc::new(metadata); + + let async_reader = TestReader { 
+ data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let builder = ParquetRecordBatchStreamBuilder::new(async_reader) + .await + .unwrap(); + + let stream = builder + .with_projection(ProjectionMask::all()) + .with_batch_size(1024) + .build() + .unwrap(); + assert_ne!(1024, file_rows); + assert_eq!(stream.batch_size, file_rows as usize); + } } diff --git a/parquet/src/arrow/buffer/bit_util.rs b/parquet/src/arrow/buffer/bit_util.rs index 192ab4b72163..04704237c458 100644 --- a/parquet/src/arrow/buffer/bit_util.rs +++ b/parquet/src/arrow/buffer/bit_util.rs @@ -51,6 +51,17 @@ pub fn iter_set_bits_rev(bytes: &[u8]) -> impl Iterator + '_ { }) } +/// Performs big endian sign extension +pub fn sign_extend_be(b: &[u8]) -> [u8; N] { + assert!(b.len() <= N, "Array too large, expected less than {}", N); + let is_negative = (b[0] & 128u8) == 128u8; + let mut result = if is_negative { [255u8; N] } else { [0u8; N] }; + for (d, s) in result.iter_mut().skip(N - b.len()).zip(b) { + *d = *s; + } + result +} + #[cfg(test)] mod tests { use super::*; diff --git a/parquet/src/arrow/buffer/converter.rs b/parquet/src/arrow/buffer/converter.rs deleted file mode 100644 index 51e1d8290ee3..000000000000 --- a/parquet/src/arrow/buffer/converter.rs +++ /dev/null @@ -1,335 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::data_type::{ByteArray, FixedLenByteArray, Int96}; -use arrow::array::{ - Array, ArrayRef, BinaryArray, BinaryBuilder, DecimalArray, FixedSizeBinaryArray, - FixedSizeBinaryBuilder, IntervalDayTimeArray, IntervalDayTimeBuilder, - IntervalYearMonthArray, IntervalYearMonthBuilder, LargeBinaryArray, - LargeBinaryBuilder, LargeStringArray, LargeStringBuilder, StringArray, StringBuilder, - TimestampNanosecondArray, -}; -use std::convert::{From, TryInto}; -use std::sync::Arc; - -use crate::errors::Result; -use std::marker::PhantomData; - -/// A converter is used to consume record reader's content and convert it to arrow -/// primitive array. -pub trait Converter { - /// This method converts record reader's buffered content into arrow array. - /// It will consume record reader's data, but will not reset record reader's - /// state. - fn convert(&self, source: S) -> Result; -} - -pub struct FixedSizeArrayConverter { - byte_width: i32, -} - -impl FixedSizeArrayConverter { - pub fn new(byte_width: i32) -> Self { - Self { byte_width } - } -} - -impl Converter>, FixedSizeBinaryArray> - for FixedSizeArrayConverter -{ - fn convert( - &self, - source: Vec>, - ) -> Result { - let mut builder = FixedSizeBinaryBuilder::new(source.len(), self.byte_width); - for v in source { - match v { - Some(array) => builder.append_value(array.data()), - None => builder.append_null(), - }? 
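The `sign_extend_be` helper added to `bit_util` above widens a big-endian two's-complement value to a fixed width, which is what the decimal read path relies on before calling `i128::from_be_bytes` (the role the deleted `from_bytes_to_i128` converter used to play). A self-contained sketch of the same idea, assuming the output width is a const generic as the `[u8; N]` return type suggests:

/// Sign-extend a big-endian two's-complement value to `N` bytes.
/// Standalone sketch of the `sign_extend_be` helper shown above.
fn sign_extend_be<const N: usize>(b: &[u8]) -> [u8; N] {
    assert!(b.len() <= N, "Array too large, expected at most {} bytes", N);
    let is_negative = (b[0] & 0x80) != 0;
    // Negative values are padded with 0xFF, positive values with 0x00.
    let mut result = if is_negative { [0xFF; N] } else { [0u8; N] };
    result[N - b.len()..].copy_from_slice(b);
    result
}

fn main() {
    // 0x85 is -123 as a signed byte; widening to 16 bytes preserves the value.
    let widened = sign_extend_be::<16>(&[0x85]);
    assert_eq!(i128::from_be_bytes(widened), -123);

    // A positive value widens with zero padding.
    assert_eq!(i128::from_be_bytes(sign_extend_be::<16>(&[0x00, 0x7B])), 123);
}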
- } - - Ok(builder.finish()) - } -} - -pub struct DecimalArrayConverter { - precision: i32, - scale: i32, -} - -impl DecimalArrayConverter { - pub fn new(precision: i32, scale: i32) -> Self { - Self { precision, scale } - } - - fn from_bytes_to_i128(b: &[u8]) -> i128 { - assert!(b.len() <= 16, "DecimalArray supports only up to size 16"); - let first_bit = b[0] & 128u8 == 128u8; - let mut result = if first_bit { [255u8; 16] } else { [0u8; 16] }; - for (i, v) in b.iter().enumerate() { - result[i + (16 - b.len())] = *v; - } - i128::from_be_bytes(result) - } -} - -impl Converter>, DecimalArray> for DecimalArrayConverter { - fn convert(&self, source: Vec>) -> Result { - let array = source - .into_iter() - .map(|array| array.map(|array| Self::from_bytes_to_i128(array.data()))) - .collect::() - .with_precision_and_scale(self.precision as usize, self.scale as usize)?; - - Ok(array) - } -} -/// An Arrow Interval converter, which reads the first 4 bytes of a Parquet interval, -/// and interprets it as an i32 value representing the Arrow YearMonth value -pub struct IntervalYearMonthArrayConverter {} - -impl Converter>, IntervalYearMonthArray> - for IntervalYearMonthArrayConverter -{ - fn convert( - &self, - source: Vec>, - ) -> Result { - let mut builder = IntervalYearMonthBuilder::new(source.len()); - for v in source { - match v { - Some(array) => builder.append_value(i32::from_le_bytes( - array.data()[0..4].try_into().unwrap(), - )), - None => builder.append_null(), - }? - } - - Ok(builder.finish()) - } -} - -/// An Arrow Interval converter, which reads the last 8 bytes of a Parquet interval, -/// and interprets it as an i32 value representing the Arrow DayTime value -pub struct IntervalDayTimeArrayConverter {} - -impl Converter>, IntervalDayTimeArray> - for IntervalDayTimeArrayConverter -{ - fn convert( - &self, - source: Vec>, - ) -> Result { - let mut builder = IntervalDayTimeBuilder::new(source.len()); - for v in source { - match v { - Some(array) => builder.append_value(i64::from_le_bytes( - array.data()[4..12].try_into().unwrap(), - )), - None => builder.append_null(), - }? - } - - Ok(builder.finish()) - } -} - -pub struct Int96ArrayConverter { - pub timezone: Option, -} - -impl Converter>, TimestampNanosecondArray> for Int96ArrayConverter { - fn convert(&self, source: Vec>) -> Result { - Ok(TimestampNanosecondArray::from_opt_vec( - source - .into_iter() - .map(|int96| int96.map(|val| val.to_i64() * 1_000_000)) - .collect(), - self.timezone.clone(), - )) - } -} - -pub struct Utf8ArrayConverter {} - -impl Converter>, StringArray> for Utf8ArrayConverter { - fn convert(&self, source: Vec>) -> Result { - let data_size = source - .iter() - .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0)) - .sum(); - - let mut builder = StringBuilder::with_capacity(source.len(), data_size); - for v in source { - match v { - Some(array) => builder.append_value(array.as_utf8()?), - None => builder.append_null(), - }? - } - - Ok(builder.finish()) - } -} - -pub struct LargeUtf8ArrayConverter {} - -impl Converter>, LargeStringArray> for LargeUtf8ArrayConverter { - fn convert(&self, source: Vec>) -> Result { - let data_size = source - .iter() - .map(|x| x.as_ref().map(|b| b.len()).unwrap_or(0)) - .sum(); - - let mut builder = LargeStringBuilder::with_capacity(source.len(), data_size); - for v in source { - match v { - Some(array) => builder.append_value(array.as_utf8()?), - None => builder.append_null(), - }? 
- } - - Ok(builder.finish()) - } -} - -pub struct BinaryArrayConverter {} - -impl Converter>, BinaryArray> for BinaryArrayConverter { - fn convert(&self, source: Vec>) -> Result { - let mut builder = BinaryBuilder::new(source.len()); - for v in source { - match v { - Some(array) => builder.append_value(array.data()), - None => builder.append_null(), - }? - } - - Ok(builder.finish()) - } -} - -pub struct LargeBinaryArrayConverter {} - -impl Converter>, LargeBinaryArray> for LargeBinaryArrayConverter { - fn convert(&self, source: Vec>) -> Result { - let mut builder = LargeBinaryBuilder::new(source.len()); - for v in source { - match v { - Some(array) => builder.append_value(array.data()), - None => builder.append_null(), - }? - } - - Ok(builder.finish()) - } -} - -pub type Utf8Converter = - ArrayRefConverter>, StringArray, Utf8ArrayConverter>; -pub type LargeUtf8Converter = - ArrayRefConverter>, LargeStringArray, LargeUtf8ArrayConverter>; -pub type BinaryConverter = - ArrayRefConverter>, BinaryArray, BinaryArrayConverter>; -pub type LargeBinaryConverter = ArrayRefConverter< - Vec>, - LargeBinaryArray, - LargeBinaryArrayConverter, ->; - -pub type Int96Converter = - ArrayRefConverter>, TimestampNanosecondArray, Int96ArrayConverter>; - -pub type FixedLenBinaryConverter = ArrayRefConverter< - Vec>, - FixedSizeBinaryArray, - FixedSizeArrayConverter, ->; -pub type IntervalYearMonthConverter = ArrayRefConverter< - Vec>, - IntervalYearMonthArray, - IntervalYearMonthArrayConverter, ->; -pub type IntervalDayTimeConverter = ArrayRefConverter< - Vec>, - IntervalDayTimeArray, - IntervalDayTimeArrayConverter, ->; - -pub type DecimalConverter = ArrayRefConverter< - Vec>, - DecimalArray, - DecimalArrayConverter, ->; - -pub struct FromConverter { - _source: PhantomData, - _dest: PhantomData, -} - -impl FromConverter -where - T: From, -{ - pub fn new() -> Self { - Self { - _source: PhantomData, - _dest: PhantomData, - } - } -} - -impl Converter for FromConverter -where - T: From, -{ - fn convert(&self, source: S) -> Result { - Ok(T::from(source)) - } -} - -pub struct ArrayRefConverter { - _source: PhantomData, - _array: PhantomData, - converter: C, -} - -impl ArrayRefConverter -where - A: Array + 'static, - C: Converter + 'static, -{ - pub fn new(converter: C) -> Self { - Self { - _source: PhantomData, - _array: PhantomData, - converter, - } - } -} - -impl Converter for ArrayRefConverter -where - A: Array + 'static, - C: Converter + 'static, -{ - fn convert(&self, source: S) -> Result { - self.converter - .convert(source) - .map(|array| Arc::new(array) as ArrayRef) - } -} diff --git a/parquet/src/arrow/buffer/dictionary_buffer.rs b/parquet/src/arrow/buffer/dictionary_buffer.rs index b64b2946b91a..ae9e3590de3f 100644 --- a/parquet/src/arrow/buffer/dictionary_buffer.rs +++ b/parquet/src/arrow/buffer/dictionary_buffer.rs @@ -49,6 +49,7 @@ impl Default for DictionaryBuffer { impl DictionaryBuffer { + #[allow(unused)] pub fn len(&self) -> usize { match self { Self::Dict { keys, .. } => keys.len(), diff --git a/parquet/src/arrow/buffer/mod.rs b/parquet/src/arrow/buffer/mod.rs index 5ee89aa1a782..cbc795d94f57 100644 --- a/parquet/src/arrow/buffer/mod.rs +++ b/parquet/src/arrow/buffer/mod.rs @@ -18,6 +18,5 @@ //! 
Logic for reading data into arrow buffers pub mod bit_util; -pub mod converter; pub mod dictionary_buffer; pub mod offset_buffer; diff --git a/parquet/src/arrow/decoder/delta_byte_array.rs b/parquet/src/arrow/decoder/delta_byte_array.rs new file mode 100644 index 000000000000..af73f4f25eb9 --- /dev/null +++ b/parquet/src/arrow/decoder/delta_byte_array.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::data_type::Int32Type; +use crate::encodings::decoding::{Decoder, DeltaBitPackDecoder}; +use crate::errors::{ParquetError, Result}; +use crate::util::memory::ByteBufferPtr; + +/// Decoder for `Encoding::DELTA_BYTE_ARRAY` +pub struct DeltaByteArrayDecoder { + prefix_lengths: Vec, + suffix_lengths: Vec, + data: ByteBufferPtr, + length_offset: usize, + data_offset: usize, + last_value: Vec, +} + +impl DeltaByteArrayDecoder { + /// Create a new [`DeltaByteArrayDecoder`] with the provided data page + pub fn new(data: ByteBufferPtr) -> Result { + let mut prefix = DeltaBitPackDecoder::::new(); + prefix.set_data(data.all(), 0)?; + + let num_prefix = prefix.values_left(); + let mut prefix_lengths = vec![0; num_prefix]; + assert_eq!(prefix.get(&mut prefix_lengths)?, num_prefix); + + let mut suffix = DeltaBitPackDecoder::::new(); + suffix.set_data(data.start_from(prefix.get_offset()), 0)?; + + let num_suffix = suffix.values_left(); + let mut suffix_lengths = vec![0; num_suffix]; + assert_eq!(suffix.get(&mut suffix_lengths)?, num_suffix); + + if num_prefix != num_suffix { + return Err(general_err!(format!( + "inconsistent DELTA_BYTE_ARRAY lengths, prefixes: {}, suffixes: {}", + num_prefix, num_suffix + ))); + } + + assert_eq!(prefix_lengths.len(), suffix_lengths.len()); + + Ok(Self { + prefix_lengths, + suffix_lengths, + data, + length_offset: 0, + data_offset: prefix.get_offset() + suffix.get_offset(), + last_value: vec![], + }) + } + + /// Returns the number of values remaining + pub fn remaining(&self) -> usize { + self.prefix_lengths.len() - self.length_offset + } + + /// Read up to `len` values, returning the number of values read + /// and calling `f` with each decoded byte slice + /// + /// Will short-circuit and return on error + pub fn read(&mut self, len: usize, mut f: F) -> Result + where + F: FnMut(&[u8]) -> Result<()>, + { + let to_read = len.min(self.remaining()); + + let length_range = self.length_offset..self.length_offset + to_read; + let iter = self.prefix_lengths[length_range.clone()] + .iter() + .zip(&self.suffix_lengths[length_range]); + + let data = self.data.as_ref(); + + for (prefix_length, suffix_length) in iter { + let prefix_length = *prefix_length as usize; + let suffix_length = *suffix_length as usize; + + if self.data_offset + suffix_length > self.data.len() { + return Err(ParquetError::EOF("eof 
decoding byte array".into())); + } + + self.last_value.truncate(prefix_length); + self.last_value.extend_from_slice( + &data[self.data_offset..self.data_offset + suffix_length], + ); + f(&self.last_value)?; + + self.data_offset += suffix_length; + } + + self.length_offset += to_read; + Ok(to_read) + } + + /// Skip up to `to_skip` values, returning the number of values skipped + pub fn skip(&mut self, to_skip: usize) -> Result { + let to_skip = to_skip.min(self.prefix_lengths.len() - self.length_offset); + + let length_range = self.length_offset..self.length_offset + to_skip; + let iter = self.prefix_lengths[length_range.clone()] + .iter() + .zip(&self.suffix_lengths[length_range]); + + let data = self.data.as_ref(); + + for (prefix_length, suffix_length) in iter { + let prefix_length = *prefix_length as usize; + let suffix_length = *suffix_length as usize; + + if self.data_offset + suffix_length > self.data.len() { + return Err(ParquetError::EOF("eof decoding byte array".into())); + } + + self.last_value.truncate(prefix_length); + self.last_value.extend_from_slice( + &data[self.data_offset..self.data_offset + suffix_length], + ); + self.data_offset += suffix_length; + } + self.length_offset += to_skip; + Ok(to_skip) + } +} diff --git a/parquet/src/arrow/decoder/dictionary_index.rs b/parquet/src/arrow/decoder/dictionary_index.rs new file mode 100644 index 000000000000..3d258309dd3b --- /dev/null +++ b/parquet/src/arrow/decoder/dictionary_index.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::encodings::rle::RleDecoder; +use crate::errors::Result; +use crate::util::memory::ByteBufferPtr; + +/// Decoder for `Encoding::RLE_DICTIONARY` indices +pub struct DictIndexDecoder { + /// Decoder for the dictionary offsets array + decoder: RleDecoder, + + /// We want to decode the offsets in chunks so we will maintain an internal buffer of decoded + /// offsets + index_buf: Box<[i32; 1024]>, + /// Current length of `index_buf` + index_buf_len: usize, + /// Current offset into `index_buf`. If `index_buf_offset` == `index_buf_len` then we've consumed + /// the entire buffer and need to decode another chunk of offsets. + index_offset: usize, + + /// This is a maximum as the null count is not always known, e.g. 
value data from + /// a v1 data page + max_remaining_values: usize, +} + +impl DictIndexDecoder { + /// Create a new [`DictIndexDecoder`] with the provided data page, the number of levels + /// associated with this data page, and the number of non-null values (if known) + pub fn new( + data: ByteBufferPtr, + num_levels: usize, + num_values: Option, + ) -> Self { + let bit_width = data[0]; + let mut decoder = RleDecoder::new(bit_width); + decoder.set_data(data.start_from(1)); + + Self { + decoder, + index_buf: Box::new([0; 1024]), + index_buf_len: 0, + index_offset: 0, + max_remaining_values: num_values.unwrap_or(num_levels), + } + } + + /// Read up to `len` values, returning the number of values read + /// and calling `f` with each decoded dictionary index + /// + /// Will short-circuit and return on error + pub fn read Result<()>>( + &mut self, + len: usize, + mut f: F, + ) -> Result { + let mut values_read = 0; + + while values_read != len && self.max_remaining_values != 0 { + if self.index_offset == self.index_buf_len { + // We've consumed the entire index buffer so we need to reload it before proceeding + let read = self.decoder.get_batch(self.index_buf.as_mut())?; + if read == 0 { + break; + } + self.index_buf_len = read; + self.index_offset = 0; + } + + let to_read = (len - values_read) + .min(self.index_buf_len - self.index_offset) + .min(self.max_remaining_values); + + f(&self.index_buf[self.index_offset..self.index_offset + to_read])?; + + self.index_offset += to_read; + values_read += to_read; + self.max_remaining_values -= to_read; + } + Ok(values_read) + } + + /// Skip up to `to_skip` values, returning the number of values skipped + pub fn skip(&mut self, to_skip: usize) -> Result { + let to_skip = to_skip.min(self.max_remaining_values); + + let mut values_skip = 0; + while values_skip < to_skip { + if self.index_offset == self.index_buf_len { + // Instead of reloading the buffer, just skip in the decoder + let skip = self.decoder.skip(to_skip - values_skip)?; + + if skip == 0 { + break; + } + + self.max_remaining_values -= skip; + values_skip += skip; + } else { + // We still have indices buffered, so skip within the buffer + let skip = + (to_skip - values_skip).min(self.index_buf_len - self.index_offset); + + self.index_offset += skip; + self.max_remaining_values -= skip; + values_skip += skip; + } + } + Ok(values_skip) + } +} diff --git a/parquet/src/arrow/decoder/mod.rs b/parquet/src/arrow/decoder/mod.rs new file mode 100644 index 000000000000..dc1000ffd15e --- /dev/null +++ b/parquet/src/arrow/decoder/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Specialized decoders optimised for decoding to arrow format + +mod delta_byte_array; +mod dictionary_index; + +pub use delta_byte_array::DeltaByteArrayDecoder; +pub use dictionary_index::DictIndexDecoder; diff --git a/parquet/src/arrow/mod.rs b/parquet/src/arrow/mod.rs index 3aee7cf42cbc..c0de656bf9c5 100644 --- a/parquet/src/arrow/mod.rs +++ b/parquet/src/arrow/mod.rs @@ -119,19 +119,20 @@ //!} //! ``` -experimental_mod!(array_reader); +experimental!(mod array_reader); pub mod arrow_reader; pub mod arrow_writer; mod buffer; +mod decoder; #[cfg(feature = "async")] pub mod async_reader; mod record_reader; -experimental_mod!(schema); +experimental!(mod schema); -pub use self::arrow_reader::ArrowReader; -pub use self::arrow_reader::ParquetFileArrowReader; +#[allow(deprecated)] +pub use self::arrow_reader::{ArrowReader, ParquetFileArrowReader}; pub use self::arrow_writer::ArrowWriter; #[cfg(feature = "async")] pub use self::async_reader::ParquetRecordBatchStreamBuilder; diff --git a/parquet/src/arrow/record_reader/buffer.rs b/parquet/src/arrow/record_reader/buffer.rs index 7101eaa9ccc9..64ea38f801d9 100644 --- a/parquet/src/arrow/record_reader/buffer.rs +++ b/parquet/src/arrow/record_reader/buffer.rs @@ -18,6 +18,7 @@ use std::marker::PhantomData; use crate::arrow::buffer::bit_util::iter_set_bits_rev; +use crate::data_type::Int96; use arrow::buffer::{Buffer, MutableBuffer}; use arrow::datatypes::ArrowNativeType; @@ -85,6 +86,7 @@ impl ScalarValue for u64 {} impl ScalarValue for i64 {} impl ScalarValue for f32 {} impl ScalarValue for f64 {} +impl ScalarValue for Int96 {} /// A typed buffer similar to [`Vec`] but using [`MutableBuffer`] for storage #[derive(Debug)] diff --git a/parquet/src/arrow/record_reader/definition_levels.rs b/parquet/src/arrow/record_reader/definition_levels.rs index 21526f21f6ce..2d65db77fa69 100644 --- a/parquet/src/arrow/record_reader/definition_levels.rs +++ b/parquet/src/arrow/record_reader/definition_levels.rs @@ -20,6 +20,7 @@ use std::ops::Range; use arrow::array::BooleanBufferBuilder; use arrow::bitmap::Bitmap; use arrow::buffer::Buffer; +use arrow::util::bit_chunk_iterator::UnalignedBitChunk; use crate::arrow::buffer::bit_util::count_set_bits; use crate::arrow::record_reader::buffer::BufferQueue; @@ -146,51 +147,50 @@ impl LevelsBufferSlice for DefinitionLevelBuffer { } } +enum MaybePacked { + Packed(PackedDecoder), + Fallback(ColumnLevelDecoderImpl), +} + pub struct DefinitionLevelBufferDecoder { max_level: i16, - encoding: Encoding, - data: Option, - column_decoder: Option, - packed_decoder: Option, + decoder: MaybePacked, +} + +impl DefinitionLevelBufferDecoder { + pub fn new(max_level: i16, packed: bool) -> Self { + let decoder = match packed { + true => MaybePacked::Packed(PackedDecoder::new()), + false => MaybePacked::Fallback(ColumnLevelDecoderImpl::new(max_level)), + }; + + Self { max_level, decoder } + } } impl ColumnLevelDecoder for DefinitionLevelBufferDecoder { type Slice = DefinitionLevelBuffer; - fn new(max_level: i16, encoding: Encoding, data: ByteBufferPtr) -> Self { - Self { - max_level, - encoding, - data: Some(data), - column_decoder: None, - packed_decoder: None, + fn set_data(&mut self, encoding: Encoding, data: ByteBufferPtr) { + match &mut self.decoder { + MaybePacked::Packed(d) => d.set_data(encoding, data), + MaybePacked::Fallback(d) => d.set_data(encoding, data), } } - fn read( - &mut self, - writer: &mut Self::Slice, - range: Range, - ) -> crate::errors::Result { - match &mut writer.inner { - BufferInner::Full { - levels, - nulls, - 
max_level, - } => { + fn read(&mut self, writer: &mut Self::Slice, range: Range) -> Result { + match (&mut writer.inner, &mut self.decoder) { + ( + BufferInner::Full { + levels, + nulls, + max_level, + }, + MaybePacked::Fallback(decoder), + ) => { assert_eq!(self.max_level, *max_level); assert_eq!(range.start + writer.len, nulls.len()); - let decoder = match self.data.take() { - Some(data) => self.column_decoder.insert( - ColumnLevelDecoderImpl::new(self.max_level, self.encoding, data), - ), - None => self - .column_decoder - .as_mut() - .expect("consistent null_mask_only"), - }; - levels.resize(range.end + writer.len); let slice = &mut levels.as_slice_mut()[writer.len..]; @@ -203,22 +203,13 @@ impl ColumnLevelDecoder for DefinitionLevelBufferDecoder { Ok(levels_read) } - BufferInner::Mask { nulls } => { + (BufferInner::Mask { nulls }, MaybePacked::Packed(decoder)) => { assert_eq!(self.max_level, 1); assert_eq!(range.start + writer.len, nulls.len()); - let decoder = match self.data.take() { - Some(data) => self - .packed_decoder - .insert(PackedDecoder::new(self.encoding, data)), - None => self - .packed_decoder - .as_mut() - .expect("consistent null_mask_only"), - }; - decoder.read(nulls, range.end - range.start) } + _ => unreachable!("inconsistent null mask"), } } } @@ -226,10 +217,15 @@ impl ColumnLevelDecoder for DefinitionLevelBufferDecoder { impl DefinitionLevelDecoder for DefinitionLevelBufferDecoder { fn skip_def_levels( &mut self, - _num_levels: usize, - _max_def_level: i16, + num_levels: usize, + max_def_level: i16, ) -> Result<(usize, usize)> { - Err(nyi_err!("https://github.com/apache/arrow-rs/issues/1792")) + match &mut self.decoder { + MaybePacked::Fallback(decoder) => { + decoder.skip_def_levels(num_levels, max_def_level) + } + MaybePacked::Packed(decoder) => decoder.skip(num_levels), + } } } @@ -306,28 +302,30 @@ impl PackedDecoder { } impl PackedDecoder { - fn new(encoding: Encoding, data: ByteBufferPtr) -> Self { - match encoding { - Encoding::RLE => Self { - data, - data_offset: 0, - rle_left: 0, - rle_value: false, - packed_count: 0, - packed_offset: 0, - }, - Encoding::BIT_PACKED => Self { - data_offset: 0, - rle_left: 0, - rle_value: false, - packed_count: data.len() * 8, - packed_offset: 0, - data, - }, - _ => unreachable!("invalid level encoding: {}", encoding), + fn new() -> Self { + Self { + data: ByteBufferPtr::new(vec![]), + data_offset: 0, + rle_left: 0, + rle_value: false, + packed_count: 0, + packed_offset: 0, } } + fn set_data(&mut self, encoding: Encoding, data: ByteBufferPtr) { + self.rle_left = 0; + self.rle_value = false; + self.packed_offset = 0; + self.packed_count = match encoding { + Encoding::RLE => 0, + Encoding::BIT_PACKED => data.len() * 8, + _ => unreachable!("invalid level encoding: {}", encoding), + }; + self.data = data; + self.data_offset = 0; + } + fn read(&mut self, buffer: &mut BooleanBufferBuilder, len: usize) -> Result { let mut read = 0; while read != len { @@ -354,6 +352,41 @@ impl PackedDecoder { } Ok(read) } + + /// Skips `level_num` definition levels + /// + /// Returns the number of values skipped and the number of levels skipped + fn skip(&mut self, level_num: usize) -> Result<(usize, usize)> { + let mut skipped_value = 0; + let mut skipped_level = 0; + while skipped_level != level_num { + if self.rle_left != 0 { + let to_skip = self.rle_left.min(level_num - skipped_level); + self.rle_left -= to_skip; + skipped_level += to_skip; + if self.rle_value { + skipped_value += to_skip; + } + } else if self.packed_count != 
self.packed_offset { + let to_skip = (self.packed_count - self.packed_offset) + .min(level_num - skipped_level); + let offset = self.data_offset * 8 + self.packed_offset; + let bit_chunk = + UnalignedBitChunk::new(self.data.as_ref(), offset, to_skip); + skipped_value += bit_chunk.count_ones(); + self.packed_offset += to_skip; + skipped_level += to_skip; + if self.packed_offset == self.packed_count { + self.data_offset += self.packed_count / 8; + } + } else if self.data_offset == self.data.len() { + break; + } else { + self.next_rle_block()? + } + } + Ok((skipped_value, skipped_level)) + } } #[cfg(test)] @@ -375,13 +408,14 @@ mod tests { let mut encoder = RleEncoder::new(1, 1024); for _ in 0..len { let bool = rng.gen_bool(0.8); - assert!(encoder.put(bool as u64).unwrap()); + encoder.put(bool as u64); expected.append(bool); } assert_eq!(expected.len(), len); - let encoded = encoder.consume().unwrap(); - let mut decoder = PackedDecoder::new(Encoding::RLE, ByteBufferPtr::new(encoded)); + let encoded = encoder.consume(); + let mut decoder = PackedDecoder::new(); + decoder.set_data(Encoding::RLE, ByteBufferPtr::new(encoded)); // Decode data in random length intervals let mut decoded = BooleanBufferBuilder::new(len); @@ -399,6 +433,67 @@ mod tests { assert_eq!(decoded.as_slice(), expected.as_slice()); } + #[test] + fn test_packed_decoder_skip() { + let mut rng = thread_rng(); + let len: usize = rng.gen_range(512..1024); + + let mut expected = BooleanBufferBuilder::new(len); + let mut encoder = RleEncoder::new(1, 1024); + + let mut total_value = 0; + for _ in 0..len { + let bool = rng.gen_bool(0.8); + encoder.put(bool as u64); + expected.append(bool); + if bool { + total_value += 1; + } + } + assert_eq!(expected.len(), len); + + let encoded = encoder.consume(); + let mut decoder = PackedDecoder::new(); + decoder.set_data(Encoding::RLE, ByteBufferPtr::new(encoded)); + + let mut skip_value = 0; + let mut read_value = 0; + let mut skip_level = 0; + let mut read_level = 0; + + loop { + let offset = skip_level + read_level; + let remaining_levels = len - offset; + if remaining_levels == 0 { + break; + } + let to_read_or_skip_level = rng.gen_range(1..=remaining_levels); + if rng.gen_bool(0.5) { + let (skip_val_num, skip_level_num) = + decoder.skip(to_read_or_skip_level).unwrap(); + skip_value += skip_val_num; + skip_level += skip_level_num + } else { + let mut decoded = BooleanBufferBuilder::new(to_read_or_skip_level); + let read_level_num = + decoder.read(&mut decoded, to_read_or_skip_level).unwrap(); + read_level += read_level_num; + for i in 0..read_level_num { + assert!(!decoded.is_empty()); + //check each read bit + let read_bit = decoded.get_bit(i); + if read_bit { + read_value += 1; + } + let expect_bit = expected.get_bit(i + offset); + assert_eq!(read_bit, expect_bit); + } + } + } + assert_eq!(read_level + skip_level, len); + assert_eq!(read_value + skip_value, total_value); + } + #[test] fn test_split_off() { let t = Type::primitive_type_builder("col", PhysicalType::INT32) diff --git a/parquet/src/arrow/record_reader/mod.rs b/parquet/src/arrow/record_reader/mod.rs index d2720aedeb86..b7318af9e85a 100644 --- a/parquet/src/arrow/record_reader/mod.rs +++ b/parquet/src/arrow/record_reader/mod.rs @@ -45,7 +45,9 @@ pub(crate) const MIN_BATCH_SIZE: usize = 1024; pub type RecordReader = GenericRecordReader::T>, ColumnValueDecoderImpl>; -#[doc(hidden)] +pub(crate) type ColumnReader = + GenericColumnReader; + /// A generic stateful column reader that delimits semantic records /// /// This type is hidden 
from the docs, and relies on private traits with no @@ -57,12 +59,11 @@ pub struct GenericRecordReader { records: V, def_levels: Option, rep_levels: Option>, - column_reader: Option< - GenericColumnReader, - >, + column_reader: Option>, /// Number of records accumulated in records num_records: usize, + /// Number of values `num_records` contains. num_values: usize, @@ -77,38 +78,23 @@ where { /// Create a new [`GenericRecordReader`] pub fn new(desc: ColumnDescPtr) -> Self { - Self::new_with_options(desc, false) + Self::new_with_records(desc, V::default()) } +} - /// Create a new [`GenericRecordReader`] with the ability to only generate the bitmask - /// - /// If `null_mask_only` is true only the null bitmask will be generated and - /// [`Self::consume_def_levels`] and [`Self::consume_rep_levels`] will always return `None` - /// - /// It is insufficient to solely check that that the max definition level is 1 as we - /// need there to be no nullable parent array that will required decoded definition levels - /// - /// In particular consider the case of: - /// - /// ```ignore - /// message nested { - /// OPTIONAL Group group { - /// REQUIRED INT32 leaf; - /// } - /// } - /// ``` - /// - /// The maximum definition level of leaf is 1, however, we still need to decode the - /// definition levels so that the parent group can be constructed correctly - /// - pub(crate) fn new_with_options(desc: ColumnDescPtr, null_mask_only: bool) -> Self { +impl GenericRecordReader +where + V: ValuesBuffer, + CV: ColumnValueDecoder, +{ + pub fn new_with_records(desc: ColumnDescPtr, records: V) -> Self { let def_levels = (desc.max_def_level() > 0) - .then(|| DefinitionLevelBuffer::new(&desc, null_mask_only)); + .then(|| DefinitionLevelBuffer::new(&desc, packed_null_mask(&desc))); let rep_levels = (desc.max_rep_level() > 0).then(ScalarBuffer::new); Self { - records: Default::default(), + records, def_levels, rep_levels, column_reader: None, @@ -121,9 +107,25 @@ where /// Set the current page reader. pub fn set_page_reader(&mut self, page_reader: Box) -> Result<()> { - self.column_reader = Some(GenericColumnReader::new( + let descr = &self.column_desc; + let values_decoder = CV::new(descr); + + let def_level_decoder = (descr.max_def_level() != 0).then(|| { + DefinitionLevelBufferDecoder::new( + descr.max_def_level(), + packed_null_mask(descr), + ) + }); + + let rep_level_decoder = (descr.max_rep_level() != 0) + .then(|| ColumnLevelDecoderImpl::new(descr.max_rep_level())); + + self.column_reader = Some(GenericColumnReader::new_with_decoders( self.column_desc.clone(), page_reader, + values_decoder, + def_level_decoder, + rep_level_decoder, )); Ok(()) } @@ -143,7 +145,11 @@ where loop { // Try to find some records from buffers that has been read into memory // but not counted as seen records. - let end_of_column = !self.column_reader.as_mut().unwrap().has_next()?; + + // Check to see if the column is exhausted. Only peek the next page since in + // case we are reading to a page boundary and do not actually need to read + // the next page. + let end_of_column = !self.column_reader.as_mut().unwrap().peek_next()?; let (record_count, value_count) = self.count_records(num_records - records_read, end_of_column); @@ -152,7 +158,9 @@ where self.num_values += value_count; records_read += record_count; - if records_read == num_records || end_of_column { + if records_read == num_records + || !self.column_reader.as_mut().unwrap().has_next()? 
+ { break; } @@ -196,7 +204,7 @@ where pub fn skip_records(&mut self, num_records: usize) -> Result { // First need to clear the buffer let end_of_column = match self.column_reader.as_mut() { - Some(reader) => !reader.has_next()?, + Some(reader) => !reader.peek_next()?, None => return Ok(0), }; @@ -206,12 +214,6 @@ where self.num_records += buffered_records; self.num_values += buffered_values; - self.consume_def_levels(); - self.consume_rep_levels(); - self.consume_record_data(); - self.consume_bitmap(); - self.reset(); - let remaining = num_records - buffered_records; if remaining == 0 { @@ -228,6 +230,7 @@ where } /// Returns number of records stored in buffer. + #[allow(unused)] pub fn num_records(&self) -> usize { self.num_records } @@ -393,6 +396,15 @@ where } } +/// Returns true if we do not need to unpack the nullability for this column, this is +/// only possible if the max defiition level is 1, and corresponds to nulls at the +/// leaf level, as opposed to a nullable parent nested type +fn packed_null_mask(descr: &ColumnDescPtr) -> bool { + descr.max_def_level() == 1 + && descr.max_rep_level() == 0 + && descr.self_type().is_optional() +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -790,4 +802,186 @@ mod tests { assert_eq!(record_reader.num_records(), 8); assert_eq!(record_reader.num_values(), 14); } + + #[test] + fn test_skip_required_records() { + // Construct column schema + let message_type = " + message test_schema { + REQUIRED INT32 leaf; + } + "; + let desc = parse_message_type(message_type) + .map(|t| SchemaDescriptor::new(Arc::new(t))) + .map(|s| s.column(0)) + .unwrap(); + + // Construct record reader + let mut record_reader = RecordReader::::new(desc.clone()); + + // First page + + // Records data: + // test_schema + // leaf: 4 + // test_schema + // leaf: 7 + // test_schema + // leaf: 6 + // test_schema + // left: 3 + // test_schema + // left: 2 + { + let values = [4, 7, 6, 3, 2]; + let mut pb = DataPageBuilderImpl::new(desc.clone(), 5, true); + pb.add_values::(Encoding::PLAIN, &values); + let page = pb.consume(); + + let page_reader = Box::new(InMemoryPageReader::new(vec![page])); + record_reader.set_page_reader(page_reader).unwrap(); + assert_eq!(2, record_reader.skip_records(2).unwrap()); + assert_eq!(0, record_reader.num_records()); + assert_eq!(0, record_reader.num_values()); + assert_eq!(3, record_reader.read_records(3).unwrap()); + assert_eq!(3, record_reader.num_records()); + assert_eq!(3, record_reader.num_values()); + } + + // Second page + + // Records data: + // test_schema + // leaf: 8 + // test_schema + // leaf: 9 + { + let values = [8, 9]; + let mut pb = DataPageBuilderImpl::new(desc, 2, true); + pb.add_values::(Encoding::PLAIN, &values); + let page = pb.consume(); + + let page_reader = Box::new(InMemoryPageReader::new(vec![page])); + record_reader.set_page_reader(page_reader).unwrap(); + assert_eq!(2, record_reader.skip_records(10).unwrap()); + assert_eq!(3, record_reader.num_records()); + assert_eq!(3, record_reader.num_values()); + assert_eq!(0, record_reader.read_records(10).unwrap()); + } + + let mut bb = Int32BufferBuilder::new(3); + bb.append_slice(&[6, 3, 2]); + let expected_buffer = bb.finish(); + assert_eq!(expected_buffer, record_reader.consume_record_data()); + assert_eq!(None, record_reader.consume_def_levels()); + assert_eq!(None, record_reader.consume_bitmap()); + } + + #[test] + fn test_skip_optional_records() { + // Construct column schema + let message_type = " + message test_schema { + OPTIONAL Group test_struct { + OPTIONAL INT32 
leaf; + } + } + "; + + let desc = parse_message_type(message_type) + .map(|t| SchemaDescriptor::new(Arc::new(t))) + .map(|s| s.column(0)) + .unwrap(); + + // Construct record reader + let mut record_reader = RecordReader::::new(desc.clone()); + + // First page + + // Records data: + // test_schema + // test_struct + // test_schema + // test_struct + // leaf: 7 + // test_schema + // test_schema + // test_struct + // leaf: 6 + // test_schema + // test_struct + // leaf: 6 + { + let values = [7, 6, 3]; + //empty, non-empty, empty, non-empty, non-empty + let def_levels = [1i16, 2i16, 0i16, 2i16, 2i16]; + let mut pb = DataPageBuilderImpl::new(desc.clone(), 5, true); + pb.add_def_levels(2, &def_levels); + pb.add_values::(Encoding::PLAIN, &values); + let page = pb.consume(); + + let page_reader = Box::new(InMemoryPageReader::new(vec![page])); + record_reader.set_page_reader(page_reader).unwrap(); + assert_eq!(2, record_reader.skip_records(2).unwrap()); + assert_eq!(0, record_reader.num_records()); + assert_eq!(0, record_reader.num_values()); + assert_eq!(3, record_reader.read_records(3).unwrap()); + assert_eq!(3, record_reader.num_records()); + assert_eq!(3, record_reader.num_values()); + } + + // Second page + + // Records data: + // test_schema + // test_schema + // test_struct + // left: 8 + { + let values = [8]; + //empty, non-empty + let def_levels = [0i16, 2i16]; + let mut pb = DataPageBuilderImpl::new(desc, 2, true); + pb.add_def_levels(2, &def_levels); + pb.add_values::(Encoding::PLAIN, &values); + let page = pb.consume(); + + let page_reader = Box::new(InMemoryPageReader::new(vec![page])); + record_reader.set_page_reader(page_reader).unwrap(); + assert_eq!(2, record_reader.skip_records(10).unwrap()); + assert_eq!(3, record_reader.num_records()); + assert_eq!(3, record_reader.num_values()); + assert_eq!(0, record_reader.read_records(10).unwrap()); + } + + // Verify result def levels + let mut bb = Int16BufferBuilder::new(7); + bb.append_slice(&[0i16, 2i16, 2i16]); + let expected_def_levels = bb.finish(); + assert_eq!( + Some(expected_def_levels), + record_reader.consume_def_levels() + ); + + // Verify bitmap + let expected_valid = &[false, true, true]; + let expected_buffer = Buffer::from_iter(expected_valid.iter().cloned()); + let expected_bitmap = Bitmap::from(expected_buffer); + assert_eq!(Some(expected_bitmap), record_reader.consume_bitmap()); + + // Verify result record data + let actual = record_reader.consume_record_data(); + let actual_values = actual.typed_data::(); + + let expected = &[0, 6, 3]; + assert_eq!(actual_values.len(), expected.len()); + + // Only validate valid values are equal + let iter = expected_valid.iter().zip(actual_values).zip(expected); + for ((valid, actual), expected) in iter { + if *valid { + assert_eq!(actual, expected) + } + } + } } diff --git a/parquet/src/arrow/schema.rs b/parquet/src/arrow/schema.rs index 97611d0ec300..ad5b6b1f5f80 100644 --- a/parquet/src/arrow/schema.rs +++ b/parquet/src/arrow/schema.rs @@ -73,7 +73,7 @@ pub fn parquet_to_arrow_schema_by_columns( // Add the Arrow metadata to the Parquet metadata skipping keys that collide if let Some(arrow_schema) = &maybe_schema { arrow_schema.metadata().iter().for_each(|(k, v)| { - metadata.entry(k.clone()).or_insert(v.clone()); + metadata.entry(k.clone()).or_insert_with(|| v.clone()); }); } @@ -100,7 +100,7 @@ fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Result { Ok(message) => message .header_as_schema() .map(arrow::ipc::convert::fb_to_schema) - .ok_or(arrow_err!("the message is not 
Arrow Schema")), + .ok_or_else(|| arrow_err!("the message is not Arrow Schema")), Err(err) => { // The flatbuffers implementation returns an error on verification error. Err(arrow_err!( @@ -220,7 +220,7 @@ pub fn parquet_to_arrow_field(parquet_column: &ColumnDescriptor) -> Result usize { +pub fn decimal_length_from_precision(precision: u8) -> usize { (10.0_f64.powi(precision as i32).log2() / 8.0).ceil() as usize } @@ -380,7 +380,8 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_length(*length) .build() } - DataType::Decimal(precision, scale) => { + DataType::Decimal128(precision, scale) + | DataType::Decimal256(precision, scale) => { // Decimal precision determines the Parquet physical type to use. // TODO(ARROW-12018): Enable the below after ARROW-10818 Decimal support // @@ -487,7 +488,7 @@ mod tests { use crate::file::metadata::KeyValue; use crate::{ - arrow::{ArrowReader, ArrowWriter, ParquetFileArrowReader}, + arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter}, schema::{parser::parse_message_type, types::SchemaDescriptor}, }; @@ -531,6 +532,32 @@ mod tests { assert_eq!(&arrow_fields, converted_arrow_schema.fields()); } + #[test] + fn test_decimal_fields() { + let message_type = " + message test_schema { + REQUIRED INT32 decimal1 (DECIMAL(4,2)); + REQUIRED INT64 decimal2 (DECIMAL(12,2)); + REQUIRED FIXED_LEN_BYTE_ARRAY (16) decimal3 (DECIMAL(30,2)); + REQUIRED BYTE_ARRAY decimal4 (DECIMAL(33,2)); + } + "; + + let parquet_group_type = parse_message_type(message_type).unwrap(); + + let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_group_type)); + let converted_arrow_schema = + parquet_to_arrow_schema(&parquet_schema, None).unwrap(); + + let arrow_fields = vec![ + Field::new("decimal1", DataType::Decimal128(4, 2), false), + Field::new("decimal2", DataType::Decimal128(12, 2), false), + Field::new("decimal3", DataType::Decimal128(30, 2), false), + Field::new("decimal4", DataType::Decimal128(33, 2), false), + ]; + assert_eq!(&arrow_fields, converted_arrow_schema.fields()); + } + #[test] fn test_byte_array_fields() { let message_type = " @@ -1206,6 +1233,9 @@ mod tests { OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS); REQUIRED INT64 ts_micro (TIMESTAMP_MICROS); REQUIRED INT64 ts_nano (TIMESTAMP(NANOS,true)); + REPEATED INT32 int_list; + REPEATED BINARY byte_list; + REPEATED BINARY string_list (UTF8); } "; let parquet_group_type = parse_message_type(message_type).unwrap(); @@ -1252,6 +1282,29 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".to_string())), false, ), + Field::new( + "int_list", + DataType::List(Box::new(Field::new("int_list", DataType::Int32, false))), + false, + ), + Field::new( + "byte_list", + DataType::List(Box::new(Field::new( + "byte_list", + DataType::Binary, + false, + ))), + false, + ), + Field::new( + "string_list", + DataType::List(Box::new(Field::new( + "string_list", + DataType::Utf8, + false, + ))), + false, + ), ]; assert_eq!(arrow_fields, converted_arrow_fields); @@ -1549,9 +1602,9 @@ mod tests { // true, // ), Field::new("c35", DataType::Null, true), - Field::new("c36", DataType::Decimal(2, 1), false), - Field::new("c37", DataType::Decimal(50, 20), false), - Field::new("c38", DataType::Decimal(18, 12), true), + Field::new("c36", DataType::Decimal128(2, 1), false), + Field::new("c37", DataType::Decimal128(50, 20), false), + Field::new("c38", DataType::Decimal128(18, 12), true), Field::new( "c39", DataType::Map( @@ -1635,14 +1688,9 @@ mod tests { writer.close()?; // read file back - let mut arrow_reader = 
ParquetFileArrowReader::try_new(file).unwrap(); - let read_schema = arrow_reader.get_schema()?; - assert_eq!(schema, read_schema); - - // read all fields by columns - let partial_read_schema = - arrow_reader.get_schema_by_columns(ProjectionMask::all())?; - assert_eq!(schema, partial_read_schema); + let arrow_reader = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let read_schema = arrow_reader.schema(); + assert_eq!(&schema, read_schema.as_ref()); Ok(()) } @@ -1704,15 +1752,9 @@ mod tests { writer.close()?; // read file back - let mut arrow_reader = ParquetFileArrowReader::try_new(file).unwrap(); - let read_schema = arrow_reader.get_schema()?; - assert_eq!(schema, read_schema); - - // read all fields by columns - let partial_read_schema = - arrow_reader.get_schema_by_columns(ProjectionMask::all())?; - assert_eq!(schema, partial_read_schema); - + let arrow_reader = ParquetRecordBatchReaderBuilder::try_new(file).unwrap(); + let read_schema = arrow_reader.schema(); + assert_eq!(&schema, read_schema.as_ref()); Ok(()) } } diff --git a/parquet/src/arrow/schema/primitive.rs b/parquet/src/arrow/schema/primitive.rs index 0cee5aa1e961..87edd75b0b8d 100644 --- a/parquet/src/arrow/schema/primitive.rs +++ b/parquet/src/arrow/schema/primitive.rs @@ -94,7 +94,7 @@ fn from_parquet(parquet_type: &Type) -> Result { PhysicalType::INT96 => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), PhysicalType::FLOAT => Ok(DataType::Float32), PhysicalType::DOUBLE => Ok(DataType::Float64), - PhysicalType::BYTE_ARRAY => from_byte_array(basic_info), + PhysicalType::BYTE_ARRAY => from_byte_array(basic_info, *precision, *scale), PhysicalType::FIXED_LEN_BYTE_ARRAY => { from_fixed_len_byte_array(basic_info, *scale, *precision, *type_length) } @@ -112,7 +112,7 @@ fn decimal_type(scale: i32, precision: i32) -> Result { .try_into() .map_err(|_| arrow_err!("precision cannot be negative: {}", precision))?; - Ok(DataType::Decimal(precision, scale)) + Ok(DataType::Decimal128(precision, scale)) } fn from_int32(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result { @@ -224,7 +224,7 @@ fn from_int64(info: &BasicTypeInfo, scale: i32, precision: i32) -> Result Result { +fn from_byte_array(info: &BasicTypeInfo, precision: i32, scale: i32) -> Result { match (info.logical_type(), info.converted_type()) { (Some(LogicalType::String), _) => Ok(DataType::Utf8), (Some(LogicalType::Json), _) => Ok(DataType::Binary), @@ -235,6 +235,14 @@ fn from_byte_array(info: &BasicTypeInfo) -> Result { (None, ConvertedType::BSON) => Ok(DataType::Binary), (None, ConvertedType::ENUM) => Ok(DataType::Binary), (None, ConvertedType::UTF8) => Ok(DataType::Utf8), + (Some(LogicalType::Decimal { precision, scale }), _) => Ok(DataType::Decimal128( + precision.try_into().unwrap(), + scale.try_into().unwrap(), + )), + (None, ConvertedType::DECIMAL) => Ok(DataType::Decimal128( + precision.try_into().unwrap(), + scale.try_into().unwrap(), + )), (logical, converted) => Err(arrow_err!( "Unable to convert parquet BYTE_ARRAY logical type {:?} or converted type {}", logical, diff --git a/parquet/src/basic.rs b/parquet/src/basic.rs index 59a0fe07b7de..7adbc8c1b6d0 100644 --- a/parquet/src/basic.rs +++ b/parquet/src/basic.rs @@ -18,7 +18,7 @@ //! Contains Rust mappings for Thrift definition. //! Refer to `parquet.thrift` file to see raw definitions. 
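For the decimal schema changes above, the FIXED_LEN_BYTE_ARRAY width of a decimal column still comes from `decimal_length_from_precision` in `schema.rs`, now taking a `u8` precision. A worked sketch reproducing that formula, with a few spot checks that follow directly from the expression in the diff:

/// Minimal number of bytes needed to hold any unscaled decimal value with the
/// given precision, matching the formula used in the schema conversion above.
fn decimal_length_from_precision(precision: u8) -> usize {
    (10.0_f64.powi(precision as i32).log2() / 8.0).ceil() as usize
}

fn main() {
    // log2(10^4) ~= 13.3 bits -> 2 bytes; log2(10^18) ~= 59.8 -> 8 bytes;
    // log2(10^38) ~= 126.2 -> 16 bytes (the full Decimal128 width).
    assert_eq!(decimal_length_from_precision(4), 2);
    assert_eq!(decimal_length_from_precision(18), 8);
    assert_eq!(decimal_length_from_precision(38), 16);
}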
-use std::{convert, fmt, result, str}; +use std::{fmt, result, str}; use parquet_format as parquet; @@ -42,6 +42,7 @@ pub use parquet_format::{ /// For example INT16 is not included as a type since a good encoding of INT32 /// would handle this. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[allow(non_camel_case_types)] pub enum Type { BOOLEAN, INT32, @@ -62,7 +63,8 @@ pub enum Type { /// /// This struct was renamed from `LogicalType` in version 4.0.0. /// If targeting Parquet format 2.4.0 or above, please use [LogicalType] instead. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(non_camel_case_types)] pub enum ConvertedType { NONE, /// A BYTE_ARRAY actually contains UTF8 encoded chars. @@ -163,7 +165,7 @@ pub enum ConvertedType { /// This is an *entirely new* struct as of version /// 4.0.0. The struct previously named `LogicalType` was renamed to /// [`ConvertedType`]. Please see the README.md for more details. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum LogicalType { String, Map, @@ -196,7 +198,8 @@ pub enum LogicalType { // Mirrors `parquet::FieldRepetitionType` /// Representation of field types in schema. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(non_camel_case_types)] pub enum Repetition { /// Field is required (can not be null) and each record has exactly 1 value. REQUIRED, @@ -212,7 +215,8 @@ pub enum Repetition { /// Encodings supported by Parquet. /// Not all encodings are valid for all types. These enums are also used to specify the /// encoding of definition and repetition levels. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[allow(non_camel_case_types)] pub enum Encoding { /// Default byte encoding. /// - BOOLEAN - 1 bit per value, 0 is false; 1 is true. @@ -277,7 +281,7 @@ pub enum Encoding { // Mirrors `parquet::CompressionCodec` /// Supported compression algorithms. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Compression { UNCOMPRESSED, SNAPPY, @@ -293,7 +297,8 @@ pub enum Compression { /// Available data pages for Parquet file format. /// Note that some of the page types may not be supported. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(non_camel_case_types)] pub enum PageType { DATA_PAGE, INDEX_PAGE, @@ -312,7 +317,8 @@ pub enum PageType { /// /// See reference in /// -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(non_camel_case_types)] pub enum SortOrder { /// Signed (either value or legacy byte-wise) comparison. SIGNED, @@ -327,7 +333,8 @@ pub enum SortOrder { /// /// If column order is undefined, then it is the legacy behaviour and all values should /// be compared as signed values/bytes. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(non_camel_case_types)] pub enum ColumnOrder { /// Column uses the order defined by its logical or physical type /// (if there is no logical type), parquet-format 2.4.0+. 
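For context, and not part of the patch itself: the new `Ord`/`PartialOrd` derives on `Encoding` are what let the rewritten column writer further down in this diff collect the encodings it has used into a `BTreeSet`, so the encoding list written to the column chunk metadata comes out in a deterministic order. A minimal, illustrative sketch of that pattern, assuming the `parquet` crate as a dependency:

use parquet::basic::Encoding;
use std::collections::BTreeSet;

fn main() {
    // BTreeSet requires Ord; iteration is in sorted order, so the collected
    // list is deterministic no matter the order the encodings were inserted.
    let mut encodings = BTreeSet::new();
    encodings.insert(Encoding::RLE_DICTIONARY);
    encodings.insert(Encoding::PLAIN);
    encodings.insert(Encoding::RLE);
    let as_list: Vec<Encoding> = encodings.iter().copied().collect();
    println!("{:?}", as_list); // prints [PLAIN, RLE, RLE_DICTIONARY]
}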
@@ -489,7 +496,7 @@ impl fmt::Display for ColumnOrder { // ---------------------------------------------------------------------- // parquet::Type <=> Type conversion -impl convert::From for Type { +impl From for Type { fn from(value: parquet::Type) -> Self { match value { parquet::Type::Boolean => Type::BOOLEAN, @@ -504,7 +511,7 @@ impl convert::From for Type { } } -impl convert::From for parquet::Type { +impl From for parquet::Type { fn from(value: Type) -> Self { match value { Type::BOOLEAN => parquet::Type::Boolean, @@ -522,7 +529,7 @@ impl convert::From for parquet::Type { // ---------------------------------------------------------------------- // parquet::ConvertedType <=> ConvertedType conversion -impl convert::From> for ConvertedType { +impl From> for ConvertedType { fn from(option: Option) -> Self { match option { None => ConvertedType::NONE, @@ -558,7 +565,7 @@ impl convert::From> for ConvertedType { } } -impl convert::From for Option { +impl From for Option { fn from(value: ConvertedType) -> Self { match value { ConvertedType::NONE => None, @@ -595,7 +602,7 @@ impl convert::From for Option { // ---------------------------------------------------------------------- // parquet::LogicalType <=> LogicalType conversion -impl convert::From for LogicalType { +impl From for LogicalType { fn from(value: parquet::LogicalType) -> Self { match value { parquet::LogicalType::STRING(_) => LogicalType::String, @@ -627,7 +634,7 @@ impl convert::From for LogicalType { } } -impl convert::From for parquet::LogicalType { +impl From for parquet::LogicalType { fn from(value: LogicalType) -> Self { match value { LogicalType::String => parquet::LogicalType::STRING(Default::default()), @@ -723,7 +730,7 @@ impl From> for ConvertedType { // ---------------------------------------------------------------------- // parquet::FieldRepetitionType <=> Repetition conversion -impl convert::From for Repetition { +impl From for Repetition { fn from(value: parquet::FieldRepetitionType) -> Self { match value { parquet::FieldRepetitionType::Required => Repetition::REQUIRED, @@ -733,7 +740,7 @@ impl convert::From for Repetition { } } -impl convert::From for parquet::FieldRepetitionType { +impl From for parquet::FieldRepetitionType { fn from(value: Repetition) -> Self { match value { Repetition::REQUIRED => parquet::FieldRepetitionType::Required, @@ -746,7 +753,7 @@ impl convert::From for parquet::FieldRepetitionType { // ---------------------------------------------------------------------- // parquet::Encoding <=> Encoding conversion -impl convert::From for Encoding { +impl From for Encoding { fn from(value: parquet::Encoding) -> Self { match value { parquet::Encoding::Plain => Encoding::PLAIN, @@ -762,7 +769,7 @@ impl convert::From for Encoding { } } -impl convert::From for parquet::Encoding { +impl From for parquet::Encoding { fn from(value: Encoding) -> Self { match value { Encoding::PLAIN => parquet::Encoding::Plain, @@ -781,7 +788,7 @@ impl convert::From for parquet::Encoding { // ---------------------------------------------------------------------- // parquet::CompressionCodec <=> Compression conversion -impl convert::From for Compression { +impl From for Compression { fn from(value: parquet::CompressionCodec) -> Self { match value { parquet::CompressionCodec::Uncompressed => Compression::UNCOMPRESSED, @@ -795,7 +802,7 @@ impl convert::From for Compression { } } -impl convert::From for parquet::CompressionCodec { +impl From for parquet::CompressionCodec { fn from(value: Compression) -> Self { match value { 
Compression::UNCOMPRESSED => parquet::CompressionCodec::Uncompressed, @@ -812,7 +819,7 @@ impl convert::From for parquet::CompressionCodec { // ---------------------------------------------------------------------- // parquet::PageType <=> PageType conversion -impl convert::From for PageType { +impl From for PageType { fn from(value: parquet::PageType) -> Self { match value { parquet::PageType::DataPage => PageType::DATA_PAGE, @@ -823,7 +830,7 @@ impl convert::From for PageType { } } -impl convert::From for parquet::PageType { +impl From for parquet::PageType { fn from(value: PageType) -> Self { match value { PageType::DATA_PAGE => parquet::PageType::DataPage, @@ -1059,6 +1066,7 @@ mod tests { assert_eq!(ConvertedType::JSON.to_string(), "JSON"); assert_eq!(ConvertedType::BSON.to_string(), "BSON"); assert_eq!(ConvertedType::INTERVAL.to_string(), "INTERVAL"); + assert_eq!(ConvertedType::DECIMAL.to_string(), "DECIMAL") } #[test] @@ -1153,6 +1161,10 @@ mod tests { ConvertedType::from(Some(parquet::ConvertedType::Interval)), ConvertedType::INTERVAL ); + assert_eq!( + ConvertedType::from(Some(parquet::ConvertedType::Decimal)), + ConvertedType::DECIMAL + ) } #[test] @@ -1244,6 +1256,10 @@ mod tests { Some(parquet::ConvertedType::Interval), ConvertedType::INTERVAL.into() ); + assert_eq!( + Some(parquet::ConvertedType::Decimal), + ConvertedType::DECIMAL.into() + ) } #[test] @@ -1409,6 +1425,13 @@ mod tests { .unwrap(), ConvertedType::INTERVAL ); + assert_eq!( + ConvertedType::DECIMAL + .to_string() + .parse::() + .unwrap(), + ConvertedType::DECIMAL + ) } #[test] diff --git a/parquet/src/bin/parquet-fromcsv.rs b/parquet/src/bin/parquet-fromcsv.rs index aa1d50563cd9..827aa7311f58 100644 --- a/parquet/src/bin/parquet-fromcsv.rs +++ b/parquet/src/bin/parquet-fromcsv.rs @@ -439,7 +439,7 @@ mod tests { // test default values assert_eq!(args.input_format, CsvDialect::Csv); assert_eq!(args.batch_size, 1000); - assert_eq!(args.has_header, false); + assert!(!args.has_header); assert_eq!(args.delimiter, None); assert_eq!(args.get_delimiter(), b','); assert_eq!(args.record_terminator, None); @@ -553,7 +553,7 @@ mod tests { Field::new("field5", DataType::Utf8, false), ])); - let reader_builder = configure_reader_builder(&args, arrow_schema.clone()); + let reader_builder = configure_reader_builder(&args, arrow_schema); let builder_debug = format!("{:?}", reader_builder); assert_debug_text(&builder_debug, "has_header", "false"); assert_debug_text(&builder_debug, "delimiter", "Some(44)"); @@ -585,7 +585,7 @@ mod tests { Field::new("field4", DataType::Utf8, false), Field::new("field5", DataType::Utf8, false), ])); - let reader_builder = configure_reader_builder(&args, arrow_schema.clone()); + let reader_builder = configure_reader_builder(&args, arrow_schema); let builder_debug = format!("{:?}", reader_builder); assert_debug_text(&builder_debug, "has_header", "true"); assert_debug_text(&builder_debug, "delimiter", "Some(9)"); diff --git a/parquet/src/bin/parquet-read.rs b/parquet/src/bin/parquet-read.rs index 0530afaa786a..733e56173aa2 100644 --- a/parquet/src/bin/parquet-read.rs +++ b/parquet/src/bin/parquet-read.rs @@ -41,12 +41,13 @@ extern crate parquet; use clap::Parser; use parquet::file::reader::{FileReader, SerializedFileReader}; use parquet::record::Row; +use std::io::{self, Read}; use std::{fs::File, path::Path}; #[derive(Debug, Parser)] #[clap(author, version, about("Binary file to read data from a Parquet file"), long_about = None)] struct Args { - #[clap(short, long, help("Path to a parquet file"))] + 
#[clap(short, long, help("Path to a parquet file, or - for stdin"))] file_name: String, #[clap( short, @@ -66,10 +67,20 @@ fn main() { let num_records = args.num_records; let json = args.json; - let path = Path::new(&filename); - let file = File::open(&path).expect("Unable to open file"); - let parquet_reader = - SerializedFileReader::new(file).expect("Failed to create reader"); + let parquet_reader: Box = if filename == "-" { + let mut buf = Vec::new(); + io::stdin() + .read_to_end(&mut buf) + .expect("Failed to read stdin into a buffer"); + Box::new( + SerializedFileReader::new(bytes::Bytes::from(buf)) + .expect("Failed to create reader"), + ) + } else { + let path = Path::new(&filename); + let file = File::open(&path).expect("Unable to open file"); + Box::new(SerializedFileReader::new(file).expect("Failed to create reader")) + }; // Use full schema as projected schema let mut iter = parquet_reader @@ -93,6 +104,6 @@ fn print_row(row: &Row, json: bool) { if json { println!("{}", row.to_json_value()) } else { - println!("{}", row.to_string()); + println!("{}", row); } } diff --git a/parquet/src/bin/parquet-schema.rs b/parquet/src/bin/parquet-schema.rs index b875b0e7102b..68c52def7c44 100644 --- a/parquet/src/bin/parquet-schema.rs +++ b/parquet/src/bin/parquet-schema.rs @@ -67,9 +67,9 @@ fn main() { println!("Metadata for file: {}", &filename); println!(); if verbose { - print_parquet_metadata(&mut std::io::stdout(), &metadata); + print_parquet_metadata(&mut std::io::stdout(), metadata); } else { - print_file_metadata(&mut std::io::stdout(), &metadata.file_metadata()); + print_file_metadata(&mut std::io::stdout(), metadata.file_metadata()); } } } diff --git a/parquet/src/column/page.rs b/parquet/src/column/page.rs index 78890f36a47f..ab2d885a23f7 100644 --- a/parquet/src/column/page.rs +++ b/parquet/src/column/page.rs @@ -18,10 +18,11 @@ //! Contains Parquet Page definitions and page reader interface. use crate::basic::{Encoding, PageType}; -use crate::errors::Result; +use crate::errors::{ParquetError, Result}; use crate::file::{metadata::ColumnChunkMetaData, statistics::Statistics}; use crate::schema::types::{ColumnDescPtr, SchemaDescPtr}; use crate::util::memory::ByteBufferPtr; +use parquet_format::PageHeader; /// Parquet Page definition. /// @@ -173,6 +174,12 @@ pub struct PageWriteSpec { pub bytes_written: u64, } +impl Default for PageWriteSpec { + fn default() -> Self { + Self::new() + } +} + impl PageWriteSpec { /// Creates new spec with default page write metrics. pub fn new() -> Self { @@ -188,6 +195,7 @@ impl PageWriteSpec { } /// Contains metadata for a page +#[derive(Clone)] pub struct PageMetadata { /// The number of rows in this page pub num_rows: usize, @@ -196,6 +204,31 @@ pub struct PageMetadata { pub is_dict: bool, } +impl TryFrom<&PageHeader> for PageMetadata { + type Error = ParquetError; + + fn try_from(value: &PageHeader) -> std::result::Result { + match value.type_ { + parquet_format::PageType::DataPage => Ok(PageMetadata { + num_rows: value.data_page_header.as_ref().unwrap().num_values as usize, + is_dict: false, + }), + parquet_format::PageType::DictionaryPage => Ok(PageMetadata { + num_rows: usize::MIN, + is_dict: true, + }), + parquet_format::PageType::DataPageV2 => Ok(PageMetadata { + num_rows: value.data_page_header_v2.as_ref().unwrap().num_rows as usize, + is_dict: false, + }), + other => Err(ParquetError::General(format!( + "page type {:?} cannot be converted to PageMetadata", + other + ))), + } + } +} + /// API for reading pages from a column chunk. 
/// This offers a iterator like API to get the next page. pub trait PageReader: Iterator> + Send { diff --git a/parquet/src/column/reader.rs b/parquet/src/column/reader.rs index 80174d756791..09254999bdd3 100644 --- a/parquet/src/column/reader.rs +++ b/parquet/src/column/reader.rs @@ -22,13 +22,13 @@ use std::cmp::min; use super::page::{Page, PageReader}; use crate::basic::*; use crate::column::reader::decoder::{ - ColumnValueDecoder, DefinitionLevelDecoder, LevelsBufferSlice, - RepetitionLevelDecoder, ValuesBufferSlice, + ColumnLevelDecoderImpl, ColumnValueDecoder, ColumnValueDecoderImpl, + DefinitionLevelDecoder, LevelsBufferSlice, RepetitionLevelDecoder, ValuesBufferSlice, }; use crate::data_type::*; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; -use crate::util::bit_util::{ceil, num_required_bits}; +use crate::util::bit_util::{ceil, num_required_bits, read_num_bytes}; use crate::util::memory::ByteBufferPtr; pub(crate) mod decoder; @@ -103,17 +103,16 @@ pub fn get_typed_column_reader( /// Typed value reader for a particular primitive column. pub type ColumnReaderImpl = GenericColumnReader< - decoder::ColumnLevelDecoderImpl, - decoder::ColumnLevelDecoderImpl, - decoder::ColumnValueDecoderImpl, + ColumnLevelDecoderImpl, + ColumnLevelDecoderImpl, + ColumnValueDecoderImpl, >; -#[doc(hidden)] /// Reads data for a given column chunk, using the provided decoders: /// -/// - R: [`ColumnLevelDecoder`] used to decode repetition levels -/// - D: [`ColumnLevelDecoder`] used to decode definition levels -/// - V: [`ColumnValueDecoder`] used to decode value data +/// - R: `ColumnLevelDecoder` used to decode repetition levels +/// - D: `ColumnLevelDecoder` used to decode definition levels +/// - V: `ColumnValueDecoder` used to decode value data pub struct GenericColumnReader { descr: ColumnDescPtr, @@ -136,27 +135,47 @@ pub struct GenericColumnReader { values_decoder: V, } -impl GenericColumnReader +impl GenericColumnReader where - R: RepetitionLevelDecoder, - D: DefinitionLevelDecoder, V: ColumnValueDecoder, { /// Creates new column reader based on column descriptor and page reader. pub fn new(descr: ColumnDescPtr, page_reader: Box) -> Self { let values_decoder = V::new(&descr); - Self::new_with_decoder(descr, page_reader, values_decoder) + + let def_level_decoder = (descr.max_def_level() != 0) + .then(|| ColumnLevelDecoderImpl::new(descr.max_def_level())); + + let rep_level_decoder = (descr.max_rep_level() != 0) + .then(|| ColumnLevelDecoderImpl::new(descr.max_rep_level())); + + Self::new_with_decoders( + descr, + page_reader, + values_decoder, + def_level_decoder, + rep_level_decoder, + ) } +} - fn new_with_decoder( +impl GenericColumnReader +where + R: RepetitionLevelDecoder, + D: DefinitionLevelDecoder, + V: ColumnValueDecoder, +{ + pub(crate) fn new_with_decoders( descr: ColumnDescPtr, page_reader: Box, values_decoder: V, + def_level_decoder: Option, + rep_level_decoder: Option, ) -> Self { Self { descr, - def_level_decoder: None, - rep_level_decoder: None, + def_level_decoder, + rep_level_decoder, page_reader, num_buffered_values: 0, num_decoded_values: 0, @@ -176,7 +195,6 @@ where /// /// `values` will be contiguously populated with the non-null values. 
Note that if the column /// is not required, this may be less than either `batch_size` or the number of levels read - #[inline] pub fn read_batch( &mut self, batch_size: usize, @@ -288,19 +306,25 @@ where // If dictionary, we must read it if metadata.is_dict { - self.read_new_page()?; + self.read_dictionary_page()?; continue; } // If page has less rows than the remaining records to // be skipped, skip entire page - if metadata.num_rows < remaining { + if metadata.num_rows <= remaining { self.page_reader.skip_next_page()?; remaining -= metadata.num_rows; continue; + }; + // because self.num_buffered_values == self.num_decoded_values means + // we need to read a new page and set up the decoders for levels + if !self.read_new_page()? { + return Ok(num_records - remaining); + } } + // start skipping values at the page level let to_read = remaining .min((self.num_buffered_values - self.num_decoded_values) as usize); @@ -338,6 +362,24 @@ where Ok(num_records - remaining) } + /// Read the next page as a dictionary page. If the next page is not a dictionary page, + /// this will return an error. + fn read_dictionary_page(&mut self) -> Result<()> { + match self.page_reader.get_next_page()? { + Some(Page::DictionaryPage { + buf, + num_values, + encoding, + is_sorted, + }) => self + .values_decoder + .set_dict(buf, num_values, encoding, is_sorted), + _ => Err(ParquetError::General( + "Invalid page. Expecting dictionary page".to_string(), + )), + } + } + /// Reads a new page and set up the decoders for levels, values or dictionary. /// Returns false if there's no page left. fn read_new_page(&mut self) -> Result { @@ -384,10 +426,10 @@ where )?; offset += bytes_read; - let decoder = - R::new(max_rep_level, rep_level_encoding, level_data); - - self.rep_level_decoder = Some(decoder); + self.rep_level_decoder + .as_mut() + .unwrap() + .set_data(rep_level_encoding, level_data); } if max_def_level > 0 { @@ -399,10 +441,10 @@ where )?; offset += bytes_read; - let decoder = - D::new(max_def_level, def_level_encoding, level_data); - - self.def_level_decoder = Some(decoder); + self.def_level_decoder + .as_mut() + .unwrap() + .set_data(def_level_encoding, level_data); } self.values_decoder.set_data( @@ -435,26 +477,22 @@ where // DataPage v2 only supports RLE encoding for repetition // levels if self.descr.max_rep_level() > 0 { - let decoder = R::new( - self.descr.max_rep_level(), + self.rep_level_decoder.as_mut().unwrap().set_data( Encoding::RLE, buf.range(0, rep_levels_byte_len as usize), ); - self.rep_level_decoder = Some(decoder); } // DataPage v2 only supports RLE encoding for definition // levels if self.descr.max_def_level() > 0 { - let decoder = D::new( - self.descr.max_def_level(), + self.def_level_decoder.as_mut().unwrap().set_data( Encoding::RLE, buf.range( rep_levels_byte_len as usize, def_levels_byte_len as usize, ), ); - self.def_level_decoder = Some(decoder); } self.values_decoder.set_data( @@ -473,6 +511,28 @@ where } } + /// Check whether there is more data to read from this column, + /// If the current page is fully decoded, this will NOT load the next page + /// into the buffer + #[inline] + pub(crate) fn peek_next(&mut self) -> Result { + if self.num_buffered_values == 0 + || self.num_buffered_values == self.num_decoded_values + { + // TODO: should we return false if read_new_page() = true and + // num_buffered_values = 0? + match self.page_reader.peek_next_page()?
{ + Some(next_page) => Ok(next_page.num_rows != 0), + None => Ok(false), + } + } else { + Ok(true) + } + } + + /// Check whether there is more data to read from this column, + /// If the current page is fully decoded, this will load the next page + /// (if it exists) into the buffer #[inline] pub(crate) fn has_next(&mut self) -> Result { if self.num_buffered_values == 0 @@ -500,7 +560,7 @@ fn parse_v1_level( match encoding { Encoding::RLE => { let i32_size = std::mem::size_of::(); - let data_size = read_num_bytes!(i32, i32_size, buf.as_ref()) as usize; + let data_size = read_num_bytes::(i32_size, buf.as_ref()) as usize; Ok((i32_size + data_size, buf.range(i32_size, data_size))) } Encoding::BIT_PACKED => { @@ -524,8 +584,8 @@ mod tests { use crate::basic::Type as PhysicalType; use crate::schema::types::{ColumnDescriptor, ColumnPath, Type as SchemaType}; - use crate::util::test_common::make_pages; use crate::util::test_common::page_util::InMemoryPageReader; + use crate::util::test_common::rand_gen::make_pages; const NUM_LEVELS: usize = 128; const NUM_PAGES: usize = 2; @@ -1211,6 +1271,7 @@ mod tests { // Helper function for the general case of `read_batch()` where `values`, // `def_levels` and `rep_levels` are always provided with enough space. + #[allow(clippy::too_many_arguments)] fn test_read_batch_general( &mut self, desc: ColumnDescPtr, @@ -1242,6 +1303,7 @@ mod tests { // Helper function to test `read_batch()` method with custom buffers for values, // definition and repetition levels. + #[allow(clippy::too_many_arguments)] fn test_read_batch( &mut self, desc: ColumnDescPtr, diff --git a/parquet/src/column/reader/decoder.rs b/parquet/src/column/reader/decoder.rs index 6fefdca23e1b..b95b24a21c4b 100644 --- a/parquet/src/column/reader/decoder.rs +++ b/parquet/src/column/reader/decoder.rs @@ -66,8 +66,8 @@ impl ValuesBufferSlice for [T] { pub trait ColumnLevelDecoder { type Slice: LevelsBufferSlice + ?Sized; - /// Create a new [`ColumnLevelDecoder`] - fn new(max_level: i16, encoding: Encoding, data: ByteBufferPtr) -> Self; + /// Set data for this [`ColumnLevelDecoder`] + fn set_data(&mut self, encoding: Encoding, data: ByteBufferPtr); /// Read level data into `out[range]` returning the number of levels read /// @@ -250,14 +250,34 @@ impl ColumnValueDecoder for ColumnValueDecoderImpl { current_decoder.get(&mut out[range]) } - fn skip_values(&mut self, _num_values: usize) -> Result { - Err(nyi_err!("https://github.com/apache/arrow-rs/issues/1792")) + fn skip_values(&mut self, num_values: usize) -> Result { + let encoding = self + .current_encoding + .expect("current_encoding should be set"); + + let current_decoder = self + .decoders + .get_mut(&encoding) + .unwrap_or_else(|| panic!("decoder for encoding {} should be set", encoding)); + + current_decoder.skip(num_values) } } /// An implementation of [`ColumnLevelDecoder`] for `[i16]` pub struct ColumnLevelDecoderImpl { - inner: LevelDecoderInner, + decoder: Option, + bit_width: u8, +} + +impl ColumnLevelDecoderImpl { + pub fn new(max_level: i16) -> Self { + let bit_width = num_required_bits(max_level as u64); + Self { + decoder: None, + bit_width, + } + } } enum LevelDecoderInner { @@ -268,25 +288,25 @@ enum LevelDecoderInner { impl ColumnLevelDecoder for ColumnLevelDecoderImpl { type Slice = [i16]; - fn new(max_level: i16, encoding: Encoding, data: ByteBufferPtr) -> Self { - let bit_width = num_required_bits(max_level as u64); + fn set_data(&mut self, encoding: Encoding, data: ByteBufferPtr) { match encoding { Encoding::RLE => { - let mut 
decoder = RleDecoder::new(bit_width); + let mut decoder = RleDecoder::new(self.bit_width); decoder.set_data(data); - Self { - inner: LevelDecoderInner::Rle(decoder), - } + self.decoder = Some(LevelDecoderInner::Rle(decoder)); + } + Encoding::BIT_PACKED => { + self.decoder = Some(LevelDecoderInner::Packed( + BitReader::new(data), + self.bit_width, + )); } - Encoding::BIT_PACKED => Self { - inner: LevelDecoderInner::Packed(BitReader::new(data), bit_width), - }, _ => unreachable!("invalid level encoding: {}", encoding), } } fn read(&mut self, out: &mut Self::Slice, range: Range) -> Result { - match &mut self.inner { + match self.decoder.as_mut().unwrap() { LevelDecoderInner::Packed(reader, bit_width) => { Ok(reader.get_batch::(&mut out[range], *bit_width as usize)) } @@ -298,10 +318,41 @@ impl ColumnLevelDecoder for ColumnLevelDecoderImpl { impl DefinitionLevelDecoder for ColumnLevelDecoderImpl { fn skip_def_levels( &mut self, - _num_levels: usize, - _max_def_level: i16, + num_levels: usize, + max_def_level: i16, ) -> Result<(usize, usize)> { - Err(nyi_err!("https://github.com/apache/arrow-rs/issues/1792")) + let mut level_skip = 0; + let mut value_skip = 0; + match self.decoder.as_mut().unwrap() { + LevelDecoderInner::Packed(reader, bit_width) => { + for _ in 0..num_levels { + // Values are delimited by max_def_level + if max_def_level + == reader + .get_value::(*bit_width as usize) + .expect("Not enough values in Packed ColumnLevelDecoderImpl.") + { + value_skip += 1; + } + level_skip += 1; + } + } + LevelDecoderInner::Rle(reader) => { + for _ in 0..num_levels { + if let Some(level) = reader + .get::() + .expect("Not enough values in Rle ColumnLevelDecoderImpl.") + { + // Values are delimited by max_def_level + if level == max_def_level { + value_skip += 1; + } + } + level_skip += 1; + } + } + } + Ok((value_skip, level_skip)) } } diff --git a/parquet/src/column/writer/encoder.rs b/parquet/src/column/writer/encoder.rs new file mode 100644 index 000000000000..4fb4f210e146 --- /dev/null +++ b/parquet/src/column/writer/encoder.rs @@ -0,0 +1,290 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::basic::Encoding; +use crate::column::writer::{ + compare_greater, fallback_encoding, has_dictionary_support, is_nan, update_max, + update_min, +}; +use crate::data_type::private::ParquetValueType; +use crate::data_type::DataType; +use crate::encodings::encoding::{get_encoder, DictEncoder, Encoder}; +use crate::errors::{ParquetError, Result}; +use crate::file::properties::{EnabledStatistics, WriterProperties}; +use crate::schema::types::{ColumnDescPtr, ColumnDescriptor}; +use crate::util::memory::ByteBufferPtr; + +/// A collection of [`ParquetValueType`] encoded by a [`ColumnValueEncoder`] +pub trait ColumnValues { + /// The number of values in this collection + fn len(&self) -> usize; +} + +#[cfg(any(feature = "arrow", test))] +impl ColumnValues for T { + fn len(&self) -> usize { + arrow::array::Array::len(self) + } +} + +impl ColumnValues for [T] { + fn len(&self) -> usize { + self.len() + } +} + +/// The encoded data for a dictionary page +pub struct DictionaryPage { + pub buf: ByteBufferPtr, + pub num_values: usize, + pub is_sorted: bool, +} + +/// The encoded values for a data page, with optional statistics +pub struct DataPageValues { + pub buf: ByteBufferPtr, + pub num_values: usize, + pub encoding: Encoding, + pub min_value: Option, + pub max_value: Option, +} + +/// A generic encoder of [`ColumnValues`] to data and dictionary pages used by +/// [super::GenericColumnWriter`] +pub trait ColumnValueEncoder { + /// The underlying value type of [`Self::Values`] + /// + /// Note: this avoids needing to fully qualify `::T` + type T: ParquetValueType; + + /// The values encoded by this encoder + type Values: ColumnValues + ?Sized; + + /// Returns the min and max values in this collection, skipping any NaN values + /// + /// Returns `None` if no values found + fn min_max( + &self, + values: &Self::Values, + value_indices: Option<&[usize]>, + ) -> Option<(Self::T, Self::T)>; + + /// Create a new [`ColumnValueEncoder`] + fn try_new(descr: &ColumnDescPtr, props: &WriterProperties) -> Result + where + Self: Sized; + + /// Write the corresponding values to this [`ColumnValueEncoder`] + fn write(&mut self, values: &Self::Values, offset: usize, len: usize) -> Result<()>; + + /// Write the values at the indexes in `indices` to this [`ColumnValueEncoder`] + fn write_gather(&mut self, values: &Self::Values, indices: &[usize]) -> Result<()>; + + /// Returns the number of buffered values + fn num_values(&self) -> usize; + + /// Returns true if this encoder has a dictionary page + fn has_dictionary(&self) -> bool; + + /// Returns an estimate of the dictionary page size in bytes, or `None` if no dictionary + fn estimated_dict_page_size(&self) -> Option; + + /// Returns an estimate of the data page size in bytes + fn estimated_data_page_size(&self) -> usize; + + /// Flush the dictionary page for this column chunk if any. 
Any subsequent calls to + /// [`Self::write`] will not be dictionary encoded + /// + /// Note: [`Self::flush_data_page`] must be called first, as this will error if there + /// are any pending page values + fn flush_dict_page(&mut self) -> Result>; + + /// Flush the next data page for this column chunk + fn flush_data_page(&mut self) -> Result>; +} + +pub struct ColumnValueEncoderImpl { + encoder: Box>, + dict_encoder: Option>, + descr: ColumnDescPtr, + num_values: usize, + statistics_enabled: EnabledStatistics, + min_value: Option, + max_value: Option, +} + +impl ColumnValueEncoderImpl { + fn write_slice(&mut self, slice: &[T::T]) -> Result<()> { + if self.statistics_enabled == EnabledStatistics::Page { + if let Some((min, max)) = self.min_max(slice, None) { + update_min(&self.descr, &min, &mut self.min_value); + update_max(&self.descr, &max, &mut self.max_value); + } + } + + match &mut self.dict_encoder { + Some(encoder) => encoder.put(slice), + _ => self.encoder.put(slice), + } + } +} + +impl ColumnValueEncoder for ColumnValueEncoderImpl { + type T = T::T; + + type Values = [T::T]; + + fn min_max( + &self, + values: &Self::Values, + value_indices: Option<&[usize]>, + ) -> Option<(Self::T, Self::T)> { + match value_indices { + Some(indices) => { + get_min_max(&self.descr, indices.iter().map(|x| &values[*x])) + } + None => get_min_max(&self.descr, values.iter()), + } + } + + fn try_new(descr: &ColumnDescPtr, props: &WriterProperties) -> Result { + let dict_supported = props.dictionary_enabled(descr.path()) + && has_dictionary_support(T::get_physical_type(), props); + let dict_encoder = dict_supported.then(|| DictEncoder::new(descr.clone())); + + // Set either main encoder or fallback encoder. + let encoder = get_encoder( + props + .encoding(descr.path()) + .unwrap_or_else(|| fallback_encoding(T::get_physical_type(), props)), + )?; + + let statistics_enabled = props.statistics_enabled(descr.path()); + + Ok(Self { + encoder, + dict_encoder, + descr: descr.clone(), + num_values: 0, + statistics_enabled, + min_value: None, + max_value: None, + }) + } + + fn write(&mut self, values: &[T::T], offset: usize, len: usize) -> Result<()> { + self.num_values += len; + + let slice = values.get(offset..offset + len).ok_or_else(|| { + general_err!( + "Expected to write {} values, but have only {}", + len, + values.len() - offset + ) + })?; + + self.write_slice(slice) + } + + fn write_gather(&mut self, values: &Self::Values, indices: &[usize]) -> Result<()> { + let slice: Vec<_> = indices.iter().map(|idx| values[*idx].clone()).collect(); + self.write_slice(&slice) + } + + fn num_values(&self) -> usize { + self.num_values + } + + fn has_dictionary(&self) -> bool { + self.dict_encoder.is_some() + } + + fn estimated_dict_page_size(&self) -> Option { + Some(self.dict_encoder.as_ref()?.dict_encoded_size()) + } + + fn estimated_data_page_size(&self) -> usize { + match &self.dict_encoder { + Some(encoder) => encoder.estimated_data_encoded_size(), + _ => self.encoder.estimated_data_encoded_size(), + } + } + + fn flush_dict_page(&mut self) -> Result> { + match self.dict_encoder.take() { + Some(encoder) => { + if self.num_values != 0 { + return Err(general_err!( + "Must flush data pages before flushing dictionary" + )); + } + + let buf = encoder.write_dict()?; + + Ok(Some(DictionaryPage { + buf, + num_values: encoder.num_entries(), + is_sorted: encoder.is_sorted(), + })) + } + _ => Ok(None), + } + } + + fn flush_data_page(&mut self) -> Result> { + let (buf, encoding) = match &mut self.dict_encoder { + 
Some(encoder) => (encoder.write_indices()?, Encoding::RLE_DICTIONARY), + _ => (self.encoder.flush_buffer()?, self.encoder.encoding()), + }; + + Ok(DataPageValues { + buf, + encoding, + num_values: std::mem::take(&mut self.num_values), + min_value: self.min_value.take(), + max_value: self.max_value.take(), + }) + } +} + +fn get_min_max<'a, T, I>(descr: &ColumnDescriptor, mut iter: I) -> Option<(T, T)> +where + T: ParquetValueType + 'a, + I: Iterator, +{ + let first = loop { + let next = iter.next()?; + if !is_nan(next) { + break next; + } + }; + + let mut min = first; + let mut max = first; + for val in iter { + if is_nan(val) { + continue; + } + if compare_greater(descr, min, val) { + min = val; + } + if compare_greater(descr, val, max) { + max = val; + } + } + Some((min.clone(), max.clone())) +} diff --git a/parquet/src/column/writer.rs b/parquet/src/column/writer/mod.rs similarity index 74% rename from parquet/src/column/writer.rs rename to parquet/src/column/writer/mod.rs index 1fc5207f6b4f..05e32f7e48ad 100644 --- a/parquet/src/column/writer.rs +++ b/parquet/src/column/writer/mod.rs @@ -17,18 +17,17 @@ //! Contains column writer API. use parquet_format::{ColumnIndex, OffsetIndex}; -use std::{collections::VecDeque, convert::TryFrom, marker::PhantomData}; +use std::collections::{BTreeSet, VecDeque}; use crate::basic::{Compression, ConvertedType, Encoding, LogicalType, PageType, Type}; use crate::column::page::{CompressedPage, Page, PageWriteSpec, PageWriter}; +use crate::column::writer::encoder::{ + ColumnValueEncoder, ColumnValueEncoderImpl, ColumnValues, +}; use crate::compression::{create_codec, Codec}; use crate::data_type::private::ParquetValueType; -use crate::data_type::AsBytes; use crate::data_type::*; -use crate::encodings::{ - encoding::{get_encoder, DictEncoder, Encoder}, - levels::{max_buffer_size, LevelEncoder}, -}; +use crate::encodings::levels::LevelEncoder; use crate::errors::{ParquetError, Result}; use crate::file::metadata::{ColumnIndexBuilder, OffsetIndexBuilder}; use crate::file::properties::EnabledStatistics; @@ -38,9 +37,10 @@ use crate::file::{ properties::{WriterProperties, WriterPropertiesPtr, WriterVersion}, }; use crate::schema::types::{ColumnDescPtr, ColumnDescriptor}; -use crate::util::bit_util::FromBytes; use crate::util::memory::ByteBufferPtr; +pub(crate) mod encoder; + /// Column writer for a Parquet type. pub enum ColumnWriter<'a> { BoolColumnWriter(ColumnWriterImpl<'a, BoolType>), @@ -58,26 +58,6 @@ pub enum Level { Column, } -macro_rules! gen_stats_section { - ($physical_ty: ty, $stat_fn: ident, $min: ident, $max: ident, $distinct: ident, $nulls: ident) => {{ - let min = $min.as_ref().and_then(|v| { - Some(read_num_bytes!( - $physical_ty, - v.as_bytes().len(), - &v.as_bytes() - )) - }); - let max = $max.as_ref().and_then(|v| { - Some(read_num_bytes!( - $physical_ty, - v.as_bytes().len(), - &v.as_bytes() - )) - }); - Statistics::$stat_fn(min, max, $distinct, $nulls, false) - }}; -} - /// Gets a specific column writer corresponding to column descriptor `descr`. pub fn get_column_writer<'a>( descr: ColumnDescPtr, @@ -165,36 +145,31 @@ pub fn get_typed_column_writer_mut<'a, 'b: 'a, T: DataType>( }) } -type ColumnCloseResult = ( - u64, - u64, - ColumnChunkMetaData, - Option, - Option, -); - -/// Typed column writer for a primitive column. 
-pub struct ColumnWriterImpl<'a, T: DataType> { - // Column writer properties - descr: ColumnDescPtr, - props: WriterPropertiesPtr, - statistics_enabled: EnabledStatistics, +/// Metadata returned by [`GenericColumnWriter::close`] +#[derive(Debug, Clone)] +pub struct ColumnCloseResult { + /// The total number of bytes written + pub bytes_written: u64, + /// The total number of rows written + pub rows_written: u64, + /// Metadata for this column chunk + pub metadata: ColumnChunkMetaData, + /// Optional column index, for filtering + pub column_index: Option, + /// Optional offset index, identifying page locations + pub offset_index: Option, +} - page_writer: Box, - has_dictionary: bool, - dict_encoder: Option>, - encoder: Box>, - codec: Compression, - compressor: Option>, - // Metrics per page +// Metrics per page +#[derive(Default)] +struct PageMetrics { num_buffered_values: u32, - num_buffered_encoded_values: u32, num_buffered_rows: u32, - min_page_value: Option, - max_page_value: Option, num_page_nulls: u64, - page_distinct_count: Option, - // Metrics per column writer +} + +// Metrics per column writer +struct ColumnMetrics { total_bytes_written: u64, total_rows_written: u64, total_uncompressed_size: u64, @@ -202,21 +177,44 @@ pub struct ColumnWriterImpl<'a, T: DataType> { total_num_values: u64, dictionary_page_offset: Option, data_page_offset: Option, - min_column_value: Option, - max_column_value: Option, + min_column_value: Option, + max_column_value: Option, num_column_nulls: u64, column_distinct_count: Option, +} + +/// Typed column writer for a primitive column. +pub type ColumnWriterImpl<'a, T> = GenericColumnWriter<'a, ColumnValueEncoderImpl>; + +pub struct GenericColumnWriter<'a, E: ColumnValueEncoder> { + // Column writer properties + descr: ColumnDescPtr, + props: WriterPropertiesPtr, + statistics_enabled: EnabledStatistics, + + page_writer: Box, + codec: Compression, + compressor: Option>, + encoder: E, + + page_metrics: PageMetrics, + // Metrics per column writer + column_metrics: ColumnMetrics, + + /// The order of encodings within the generated metadata does not impact its meaning, + /// but we use a BTreeSet so that the output is deterministic + encodings: BTreeSet, // Reused buffers def_levels_sink: Vec, rep_levels_sink: Vec, data_pages: VecDeque, - _phantom: PhantomData, + // column index and offset index column_index_builder: ColumnIndexBuilder, offset_index_builder: OffsetIndexBuilder, } -impl<'a, T: DataType> ColumnWriterImpl<'a, T> { +impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { pub fn new( descr: ColumnDescPtr, props: WriterPropertiesPtr, @@ -224,74 +222,58 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { ) -> Self { let codec = props.compression(descr.path()); let compressor = create_codec(codec).unwrap(); - - // Optionally set dictionary encoder. - let dict_encoder = if props.dictionary_enabled(descr.path()) - && has_dictionary_support(T::get_physical_type(), &props) - { - Some(DictEncoder::new(descr.clone())) - } else { - None - }; - - // Whether or not this column writer has a dictionary encoding. - let has_dictionary = dict_encoder.is_some(); - - // Set either main encoder or fallback encoder. 
- let fallback_encoder = get_encoder( - descr.clone(), - props - .encoding(descr.path()) - .unwrap_or_else(|| fallback_encoding(T::get_physical_type(), &props)), - ) - .unwrap(); + let encoder = E::try_new(&descr, props.as_ref()).unwrap(); let statistics_enabled = props.statistics_enabled(descr.path()); + let mut encodings = BTreeSet::new(); + // Used for level information + encodings.insert(Encoding::RLE); + Self { descr, props, statistics_enabled, page_writer, - has_dictionary, - dict_encoder, - encoder: fallback_encoder, codec, compressor, - num_buffered_values: 0, - num_buffered_encoded_values: 0, - num_buffered_rows: 0, - total_bytes_written: 0, - total_rows_written: 0, - total_uncompressed_size: 0, - total_compressed_size: 0, - total_num_values: 0, - dictionary_page_offset: None, - data_page_offset: None, + encoder, def_levels_sink: vec![], rep_levels_sink: vec![], data_pages: VecDeque::new(), - min_page_value: None, - max_page_value: None, - num_page_nulls: 0, - page_distinct_count: None, - min_column_value: None, - max_column_value: None, - num_column_nulls: 0, - column_distinct_count: None, - _phantom: PhantomData, + page_metrics: PageMetrics { + num_buffered_values: 0, + num_buffered_rows: 0, + num_page_nulls: 0, + }, + column_metrics: ColumnMetrics { + total_bytes_written: 0, + total_rows_written: 0, + total_uncompressed_size: 0, + total_compressed_size: 0, + total_num_values: 0, + dictionary_page_offset: None, + data_page_offset: None, + min_column_value: None, + max_column_value: None, + num_column_nulls: 0, + column_distinct_count: None, + }, column_index_builder: ColumnIndexBuilder::new(), offset_index_builder: OffsetIndexBuilder::new(), + encodings, } } - fn write_batch_internal( + #[allow(clippy::too_many_arguments)] + pub(crate) fn write_batch_internal( &mut self, - values: &[T::T], + values: &E::Values, + value_indices: Option<&[usize]>, def_levels: Option<&[i16]>, rep_levels: Option<&[i16]>, - min: Option<&T::T>, - max: Option<&T::T>, + min: Option<&E::T>, + max: Option<&E::T>, distinct_count: Option, ) -> Result { // We check for DataPage limits only after we have inserted the values. If a user @@ -304,18 +286,14 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { // TODO: find out why we don't account for size of levels when we estimate page // size. - // Find out the minimal length to prevent index out of bound errors. - let mut min_len = values.len(); - if let Some(levels) = def_levels { - min_len = min_len.min(levels.len()); - } - if let Some(levels) = rep_levels { - min_len = min_len.min(levels.len()); - } + let num_levels = match def_levels { + Some(def_levels) => def_levels.len(), + None => values.len(), + }; // Find out number of batches to process. 
let write_batch_size = self.props.write_batch_size(); - let num_batches = min_len / write_batch_size; + let num_batches = num_levels / write_batch_size; // If only computing chunk-level statistics compute them here, page-level statistics // are computed in [`Self::write_mini_batch`] and used to update chunk statistics in @@ -323,33 +301,53 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { if self.statistics_enabled == EnabledStatistics::Chunk { match (min, max) { (Some(min), Some(max)) => { - Self::update_min(&self.descr, min, &mut self.min_column_value); - Self::update_max(&self.descr, max, &mut self.max_column_value); + update_min( + &self.descr, + min, + &mut self.column_metrics.min_column_value, + ); + update_max( + &self.descr, + max, + &mut self.column_metrics.max_column_value, + ); } (None, Some(_)) | (Some(_), None) => { panic!("min/max should be both set or both None") } (None, None) => { - for val in values { - Self::update_min(&self.descr, val, &mut self.min_column_value); - Self::update_max(&self.descr, val, &mut self.max_column_value); + if let Some((min, max)) = self.encoder.min_max(values, value_indices) + { + update_min( + &self.descr, + &min, + &mut self.column_metrics.min_column_value, + ); + update_max( + &self.descr, + &max, + &mut self.column_metrics.max_column_value, + ); } } }; } // We can only set the distinct count if there are no other writes - if self.num_buffered_values == 0 && self.num_page_nulls == 0 { - self.column_distinct_count = distinct_count; + if self.encoder.num_values() == 0 { + self.column_metrics.column_distinct_count = distinct_count; } else { - self.column_distinct_count = None; + self.column_metrics.column_distinct_count = None; } let mut values_offset = 0; let mut levels_offset = 0; for _ in 0..num_batches { values_offset += self.write_mini_batch( - &values[values_offset..values_offset + write_batch_size], + values, + values_offset, + value_indices, + write_batch_size, def_levels.map(|lv| &lv[levels_offset..levels_offset + write_batch_size]), rep_levels.map(|lv| &lv[levels_offset..levels_offset + write_batch_size]), )?; @@ -357,7 +355,10 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { } values_offset += self.write_mini_batch( - &values[values_offset..], + values, + values_offset, + value_indices, + num_levels - levels_offset, def_levels.map(|lv| &lv[levels_offset..]), rep_levels.map(|lv| &lv[levels_offset..]), )?; @@ -380,11 +381,11 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { /// non-nullable and/or non-repeated. pub fn write_batch( &mut self, - values: &[T::T], + values: &E::Values, def_levels: Option<&[i16]>, rep_levels: Option<&[i16]>, ) -> Result { - self.write_batch_internal(values, def_levels, rep_levels, None, None, None) + self.write_batch_internal(values, None, def_levels, rep_levels, None, None, None) } /// Writer may optionally provide pre-calculated statistics for use when computing @@ -396,15 +397,16 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { /// computed page statistics pub fn write_batch_with_statistics( &mut self, - values: &[T::T], + values: &E::Values, def_levels: Option<&[i16]>, rep_levels: Option<&[i16]>, - min: Option<&T::T>, - max: Option<&T::T>, + min: Option<&E::T>, + max: Option<&E::T>, distinct_count: Option, ) -> Result { self.write_batch_internal( values, + None, def_levels, rep_levels, min, @@ -416,24 +418,26 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { /// Returns total number of bytes written by this column writer so far. /// This value is also returned when column writer is closed. 
pub fn get_total_bytes_written(&self) -> u64 { - self.total_bytes_written + self.column_metrics.total_bytes_written } /// Returns total number of rows written by this column writer so far. /// This value is also returned when column writer is closed. pub fn get_total_rows_written(&self) -> u64 { - self.total_rows_written + self.column_metrics.total_rows_written } /// Finalises writes and closes the column writer. /// Returns total bytes written, total rows written and column chunk metadata. pub fn close(mut self) -> Result { - if self.dict_encoder.is_some() { + if self.page_metrics.num_buffered_values > 0 { + self.add_data_page()?; + } + if self.encoder.has_dictionary() { self.write_dictionary_page()?; } self.flush_data_pages()?; let metadata = self.write_column_metadata()?; - self.dict_encoder = None; self.page_writer.close()?; let (column_index, offset_index) = if self.column_index_builder.valid() { @@ -445,13 +449,13 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { (None, None) }; - Ok(( - self.total_bytes_written, - self.total_rows_written, + Ok(ColumnCloseResult { + bytes_written: self.column_metrics.total_bytes_written, + rows_written: self.column_metrics.total_rows_written, metadata, column_index, offset_index, - )) + }) } /// Writes mini batch of values, definition and repetition levels. @@ -459,12 +463,13 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { /// page size. fn write_mini_batch( &mut self, - values: &[T::T], + values: &E::Values, + values_offset: usize, + value_indices: Option<&[usize]>, + num_levels: usize, def_levels: Option<&[i16]>, rep_levels: Option<&[i16]>, ) -> Result { - let mut values_to_write = 0; - // Check if number of definition levels is the same as number of repetition // levels. if let (Some(def), Some(rep)) = (def_levels, rep_levels) { @@ -478,7 +483,7 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { } // Process definition levels and determine how many values to write. - let num_values = if self.descr.max_def_level() > 0 { + let values_to_write = if self.descr.max_def_level() > 0 { let levels = def_levels.ok_or_else(|| { general_err!( "Definition levels are required, because max definition level = {}", @@ -486,19 +491,20 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { ) })?; + let mut values_to_write = 0; for &level in levels { if level == self.descr.max_def_level() { values_to_write += 1; - } else if self.statistics_enabled == EnabledStatistics::Page { - self.num_page_nulls += 1 + } else { + // We must always compute this as it is used to populate v2 pages + self.page_metrics.num_page_nulls += 1 } } - self.write_definition_levels(levels); - u32::try_from(levels.len()).unwrap() + self.def_levels_sink.extend_from_slice(levels); + values_to_write } else { - values_to_write = values.len(); - u32::try_from(values_to_write).unwrap() + num_levels }; // Process repetition levels and determine how many rows we are about to process. @@ -513,35 +519,25 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { // Count the occasions where we start a new row for &level in levels { - self.num_buffered_rows += (level == 0) as u32 + self.page_metrics.num_buffered_rows += (level == 0) as u32 } - self.write_repetition_levels(levels); + self.rep_levels_sink.extend_from_slice(levels); } else { // Each value is exactly one row. // Equals to the number of values, we count nulls as well. - self.num_buffered_rows += num_values; + self.page_metrics.num_buffered_rows += num_levels as u32; } - // Check that we have enough values to write. 
- let values_to_write = values.get(0..values_to_write).ok_or_else(|| { - general_err!( - "Expected to write {} values, but have only {}", - values_to_write, - values.len() - ) - })?; - - if self.statistics_enabled == EnabledStatistics::Page { - for val in values_to_write { - self.update_page_min_max(val); + match value_indices { + Some(indices) => { + let indices = &indices[values_offset..values_offset + values_to_write]; + self.encoder.write_gather(values, indices)?; } + None => self.encoder.write(values, values_offset, values_to_write)?, } - self.write_values(values_to_write)?; - - self.num_buffered_values += num_values; - self.num_buffered_encoded_values += u32::try_from(values_to_write.len()).unwrap(); + self.page_metrics.num_buffered_values += num_levels as u32; if self.should_add_data_page() { self.add_data_page()?; @@ -551,25 +547,7 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { self.dict_fallback()?; } - Ok(values_to_write.len()) - } - - #[inline] - fn write_definition_levels(&mut self, def_levels: &[i16]) { - self.def_levels_sink.extend_from_slice(def_levels); - } - - #[inline] - fn write_repetition_levels(&mut self, rep_levels: &[i16]) { - self.rep_levels_sink.extend_from_slice(rep_levels); - } - - #[inline] - fn write_values(&mut self, values: &[T::T]) -> Result<()> { - match self.dict_encoder { - Some(ref mut encoder) => encoder.put(values), - None => self.encoder.put(values), - } + Ok(values_to_write) } /// Returns true if we need to fall back to non-dictionary encoding. @@ -578,10 +556,8 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { /// size. #[inline] fn should_dict_fallback(&self) -> bool { - match self.dict_encoder { - Some(ref encoder) => { - encoder.dict_encoded_size() >= self.props.dictionary_pagesize_limit() - } + match self.encoder.estimated_dict_page_size() { + Some(size) => size >= self.props.dictionary_pagesize_limit(), None => false, } } @@ -593,35 +569,30 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { // // In such a scenario the dictionary decoder may return an estimated encoded // size in excess of the page size limit, even when there are no buffered values - if self.num_buffered_values == 0 { + if self.encoder.num_values() == 0 { return false; } - match self.dict_encoder { - Some(ref encoder) => { - encoder.estimated_data_encoded_size() >= self.props.data_pagesize_limit() - } - None => { - self.encoder.estimated_data_encoded_size() - >= self.props.data_pagesize_limit() - } - } + self.encoder.estimated_data_page_size() >= self.props.data_pagesize_limit() } /// Performs dictionary fallback. /// Prepares and writes dictionary and all data pages into page writer. fn dict_fallback(&mut self) -> Result<()> { // At this point we know that we need to fall back. 
+ if self.page_metrics.num_buffered_values > 0 { + self.add_data_page()?; + } self.write_dictionary_page()?; self.flush_data_pages()?; - self.dict_encoder = None; Ok(()) } /// Update the column index and offset index when adding the data page fn update_column_offset_index(&mut self, page_statistics: &Option) { // update the column index - let null_page = (self.num_buffered_rows as u64) == self.num_page_nulls; + let null_page = (self.page_metrics.num_buffered_rows as u64) + == self.page_metrics.num_page_nulls; // a page contains only null values, // and writers have to set the corresponding entries in min_values and max_values to byte[0] if null_page && self.column_index_builder.valid() { @@ -629,7 +600,7 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { null_page, &[0; 1], &[0; 1], - self.num_page_nulls as i64, + self.page_metrics.num_page_nulls as i64, ); } else if self.column_index_builder.valid() { // from page statistics @@ -643,7 +614,7 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { null_page, stat.min_bytes(), stat.max_bytes(), - self.num_page_nulls as i64, + self.page_metrics.num_page_nulls as i64, ); } } @@ -651,35 +622,31 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { // update the offset index self.offset_index_builder - .append_row_count(self.num_buffered_rows as i64); + .append_row_count(self.page_metrics.num_buffered_rows as i64); } /// Adds data page. /// Data page is either buffered in case of dictionary encoding or written directly. fn add_data_page(&mut self) -> Result<()> { // Extract encoded values - let value_bytes = match self.dict_encoder { - Some(ref mut encoder) => encoder.write_indices()?, - None => self.encoder.flush_buffer()?, - }; - - // Select encoding based on current encoder and writer version (v1 or v2). - let encoding = if self.dict_encoder.is_some() { - self.props.dictionary_data_page_encoding() - } else { - self.encoder.encoding() - }; + let values_data = self.encoder.flush_data_page()?; let max_def_level = self.descr.max_def_level(); let max_rep_level = self.descr.max_rep_level(); - self.num_column_nulls += self.num_page_nulls; - - let has_min_max = self.min_page_value.is_some() && self.max_page_value.is_some(); - let page_statistics = match self.statistics_enabled { - EnabledStatistics::Page if has_min_max => { - self.update_column_min_max(); - Some(self.make_page_statistics()) + self.column_metrics.num_column_nulls += self.page_metrics.num_page_nulls; + + let page_statistics = match (values_data.min_value, values_data.max_value) { + (Some(min), Some(max)) => { + update_min(&self.descr, &min, &mut self.column_metrics.min_column_value); + update_max(&self.descr, &max, &mut self.column_metrics.max_column_value); + Some(Statistics::new( + Some(min), + Some(max), + None, + self.page_metrics.num_page_nulls, + false, + )) } _ => None, }; @@ -697,7 +664,7 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { Encoding::RLE, &self.rep_levels_sink[..], max_rep_level, - )?[..], + )[..], ); } @@ -707,23 +674,23 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { Encoding::RLE, &self.def_levels_sink[..], max_def_level, - )?[..], + )[..], ); } - buffer.extend_from_slice(value_bytes.data()); + buffer.extend_from_slice(values_data.buf.data()); let uncompressed_size = buffer.len(); if let Some(ref mut cmpr) = self.compressor { - let mut compressed_buf = Vec::with_capacity(value_bytes.data().len()); + let mut compressed_buf = Vec::with_capacity(uncompressed_size); cmpr.compress(&buffer[..], &mut compressed_buf)?; buffer = compressed_buf; } let data_page = 
Page::DataPage { buf: ByteBufferPtr::new(buffer), - num_values: self.num_buffered_values, - encoding, + num_values: self.page_metrics.num_buffered_values, + encoding: values_data.encoding, def_level_encoding: Encoding::RLE, rep_level_encoding: Encoding::RLE, statistics: page_statistics, @@ -738,35 +705,35 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { if max_rep_level > 0 { let levels = - self.encode_levels_v2(&self.rep_levels_sink[..], max_rep_level)?; + self.encode_levels_v2(&self.rep_levels_sink[..], max_rep_level); rep_levels_byte_len = levels.len(); buffer.extend_from_slice(&levels[..]); } if max_def_level > 0 { let levels = - self.encode_levels_v2(&self.def_levels_sink[..], max_def_level)?; + self.encode_levels_v2(&self.def_levels_sink[..], max_def_level); def_levels_byte_len = levels.len(); buffer.extend_from_slice(&levels[..]); } let uncompressed_size = - rep_levels_byte_len + def_levels_byte_len + value_bytes.len(); + rep_levels_byte_len + def_levels_byte_len + values_data.buf.len(); // Data Page v2 compresses values only. match self.compressor { Some(ref mut cmpr) => { - cmpr.compress(value_bytes.data(), &mut buffer)?; + cmpr.compress(values_data.buf.data(), &mut buffer)?; } - None => buffer.extend_from_slice(value_bytes.data()), + None => buffer.extend_from_slice(values_data.buf.data()), } let data_page = Page::DataPageV2 { buf: ByteBufferPtr::new(buffer), - num_values: self.num_buffered_values, - encoding, - num_nulls: self.num_page_nulls as u32, - num_rows: self.num_buffered_rows, + num_values: self.page_metrics.num_buffered_values, + encoding: values_data.encoding, + num_nulls: self.page_metrics.num_page_nulls as u32, + num_rows: self.page_metrics.num_buffered_rows, def_levels_byte_len: def_levels_byte_len as u32, rep_levels_byte_len: rep_levels_byte_len as u32, is_compressed: self.compressor.is_some(), @@ -778,25 +745,20 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { }; // Check if we need to buffer data page or flush it to the sink directly. - if self.dict_encoder.is_some() { + if self.encoder.has_dictionary() { self.data_pages.push_back(compressed_page); } else { self.write_data_page(compressed_page)?; } // Update total number of rows. - self.total_rows_written += self.num_buffered_rows as u64; + self.column_metrics.total_rows_written += + self.page_metrics.num_buffered_rows as u64; // Reset state. self.rep_levels_sink.clear(); self.def_levels_sink.clear(); - self.num_buffered_values = 0; - self.num_buffered_encoded_values = 0; - self.num_buffered_rows = 0; - self.min_page_value = None; - self.max_page_value = None; - self.num_page_nulls = 0; - self.page_distinct_count = None; + self.page_metrics = PageMetrics::default(); Ok(()) } @@ -806,7 +768,7 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { #[inline] fn flush_data_pages(&mut self) -> Result<()> { // Write all outstanding data to a new page. - if self.num_buffered_values > 0 { + if self.page_metrics.num_buffered_values > 0 { self.add_data_page()?; } @@ -819,46 +781,41 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { /// Assembles and writes column chunk metadata. 
fn write_column_metadata(&mut self) -> Result { - let total_compressed_size = self.total_compressed_size as i64; - let total_uncompressed_size = self.total_uncompressed_size as i64; - let num_values = self.total_num_values as i64; - let dict_page_offset = self.dictionary_page_offset.map(|v| v as i64); + let total_compressed_size = self.column_metrics.total_compressed_size as i64; + let total_uncompressed_size = self.column_metrics.total_uncompressed_size as i64; + let num_values = self.column_metrics.total_num_values as i64; + let dict_page_offset = + self.column_metrics.dictionary_page_offset.map(|v| v as i64); // If data page offset is not set, then no pages have been written - let data_page_offset = self.data_page_offset.unwrap_or(0) as i64; - - let file_offset; - let mut encodings = Vec::new(); - - if self.has_dictionary { - assert!(dict_page_offset.is_some(), "Dictionary offset is not set"); - file_offset = dict_page_offset.unwrap() + total_compressed_size; - // NOTE: This should be in sync with writing dictionary pages. - encodings.push(self.props.dictionary_page_encoding()); - encodings.push(self.props.dictionary_data_page_encoding()); - // Fallback to alternative encoding, add it to the list. - if self.dict_encoder.is_none() { - encodings.push(self.encoder.encoding()); - } - } else { - file_offset = data_page_offset + total_compressed_size; - encodings.push(self.encoder.encoding()); - } - // We use only RLE level encoding for data page v1 and data page v2. - encodings.push(Encoding::RLE); + let data_page_offset = self.column_metrics.data_page_offset.unwrap_or(0) as i64; + + let file_offset = match dict_page_offset { + Some(dict_offset) => dict_offset + total_compressed_size, + None => data_page_offset + total_compressed_size, + }; - let statistics = self.make_column_statistics(); - let metadata = ColumnChunkMetaData::builder(self.descr.clone()) + let mut builder = ColumnChunkMetaData::builder(self.descr.clone()) .set_compression(self.codec) - .set_encodings(encodings) + .set_encodings(self.encodings.iter().cloned().collect()) .set_file_offset(file_offset) .set_total_compressed_size(total_compressed_size) .set_total_uncompressed_size(total_uncompressed_size) .set_num_values(num_values) .set_data_page_offset(data_page_offset) - .set_dictionary_page_offset(dict_page_offset) - .set_statistics(statistics) - .build()?; + .set_dictionary_page_offset(dict_page_offset); + + if self.statistics_enabled != EnabledStatistics::None { + let statistics = Statistics::new( + self.column_metrics.min_column_value.clone(), + self.column_metrics.max_column_value.clone(), + self.column_metrics.column_distinct_count, + self.column_metrics.num_column_nulls, + false, + ); + builder = builder.set_statistics(statistics); + } + let metadata = builder.build()?; self.page_writer.write_metadata(&metadata)?; Ok(metadata) @@ -871,26 +828,25 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { encoding: Encoding, levels: &[i16], max_level: i16, - ) -> Result> { - let size = max_buffer_size(encoding, max_level, levels.len()); - let mut encoder = LevelEncoder::v1(encoding, max_level, vec![0; size]); - encoder.put(levels)?; + ) -> Vec { + let mut encoder = LevelEncoder::v1(encoding, max_level, levels.len()); + encoder.put(levels); encoder.consume() } /// Encodes definition or repetition levels for Data Page v2. /// Encoding is always RLE. 
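(Editorial aside on the metadata assembly above: the chunk's file offset is derived from whichever page comes first -- the dictionary page when one was written, otherwise the first data page -- plus the total compressed size. A small self-contained sketch of that computation; the function name is an assumption for illustration, not the crate's API.)

// Sketch of the file-offset computation used in write_column_metadata above.
fn chunk_file_offset(
    dict_page_offset: Option<i64>,
    data_page_offset: i64,
    total_compressed_size: i64,
) -> i64 {
    match dict_page_offset {
        // A dictionary page, when present, precedes every data page.
        Some(dict_offset) => dict_offset + total_compressed_size,
        None => data_page_offset + total_compressed_size,
    }
}

fn main() {
    assert_eq!(chunk_file_offset(Some(4), 40, 100), 104);
    assert_eq!(chunk_file_offset(None, 40, 100), 140);
}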
#[inline] - fn encode_levels_v2(&self, levels: &[i16], max_level: i16) -> Result> { - let size = max_buffer_size(Encoding::RLE, max_level, levels.len()); - let mut encoder = LevelEncoder::v2(max_level, vec![0; size]); - encoder.put(levels)?; + fn encode_levels_v2(&self, levels: &[i16], max_level: i16) -> Vec { + let mut encoder = LevelEncoder::v2(max_level, levels.len()); + encoder.put(levels); encoder.consume() } /// Writes compressed data page into underlying sink and updates global metrics. #[inline] fn write_data_page(&mut self, page: CompressedPage) -> Result<()> { + self.encodings.insert(page.encoding()); let page_spec = self.page_writer.write_page(page)?; // update offset index // compressed_size = header_size + compressed_data_size @@ -906,31 +862,29 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { #[inline] fn write_dictionary_page(&mut self) -> Result<()> { let compressed_page = { - let encoder = self - .dict_encoder - .as_ref() + let mut page = self + .encoder + .flush_dict_page()? .ok_or_else(|| general_err!("Dictionary encoder is not set"))?; - let is_sorted = encoder.is_sorted(); - let num_values = encoder.num_entries(); - let mut values_buf = encoder.write_dict()?; - let uncompressed_size = values_buf.len(); + let uncompressed_size = page.buf.len(); if let Some(ref mut cmpr) = self.compressor { let mut output_buf = Vec::with_capacity(uncompressed_size); - cmpr.compress(values_buf.data(), &mut output_buf)?; - values_buf = ByteBufferPtr::new(output_buf); + cmpr.compress(page.buf.data(), &mut output_buf)?; + page.buf = ByteBufferPtr::new(output_buf); } let dict_page = Page::DictionaryPage { - buf: values_buf, - num_values: num_values as u32, + buf: page.buf, + num_values: page.num_values as u32, encoding: self.props.dictionary_page_encoding(), - is_sorted, + is_sorted: page.is_sorted, }; CompressedPage::new(dict_page, uncompressed_size) }; + self.encodings.insert(compressed_page.encoding()); let page_spec = self.page_writer.write_page(compressed_page)?; self.update_metrics_for_page(page_spec); // For the directory page, don't need to update column/offset index. @@ -940,176 +894,110 @@ impl<'a, T: DataType> ColumnWriterImpl<'a, T> { /// Updates column writer metrics with each page metadata. #[inline] fn update_metrics_for_page(&mut self, page_spec: PageWriteSpec) { - self.total_uncompressed_size += page_spec.uncompressed_size as u64; - self.total_compressed_size += page_spec.compressed_size as u64; - self.total_num_values += page_spec.num_values as u64; - self.total_bytes_written += page_spec.bytes_written; + self.column_metrics.total_uncompressed_size += page_spec.uncompressed_size as u64; + self.column_metrics.total_compressed_size += page_spec.compressed_size as u64; + self.column_metrics.total_num_values += page_spec.num_values as u64; + self.column_metrics.total_bytes_written += page_spec.bytes_written; match page_spec.page_type { PageType::DATA_PAGE | PageType::DATA_PAGE_V2 => { - if self.data_page_offset.is_none() { - self.data_page_offset = Some(page_spec.offset); + if self.column_metrics.data_page_offset.is_none() { + self.column_metrics.data_page_offset = Some(page_spec.offset); } } PageType::DICTIONARY_PAGE => { assert!( - self.dictionary_page_offset.is_none(), + self.column_metrics.dictionary_page_offset.is_none(), "Dictionary offset is already set" ); - self.dictionary_page_offset = Some(page_spec.offset); + self.column_metrics.dictionary_page_offset = Some(page_spec.offset); } _ => {} } } +} - /// Returns reference to the underlying page writer. 
- /// This method is intended to use in tests only. - fn get_page_writer_ref(&self) -> &dyn PageWriter { - self.page_writer.as_ref() - } - - fn make_column_statistics(&self) -> Statistics { - self.make_typed_statistics(Level::Column) - } - - fn make_page_statistics(&self) -> Statistics { - self.make_typed_statistics(Level::Page) - } +fn update_min( + descr: &ColumnDescriptor, + val: &T, + min: &mut Option, +) { + update_stat::(val, min, |cur| compare_greater(descr, cur, val)) +} - pub fn make_typed_statistics(&self, level: Level) -> Statistics { - let (min, max, distinct, nulls) = match level { - Level::Page => ( - self.min_page_value.as_ref(), - self.max_page_value.as_ref(), - self.page_distinct_count, - self.num_page_nulls, - ), - Level::Column => ( - self.min_column_value.as_ref(), - self.max_column_value.as_ref(), - self.column_distinct_count, - self.num_column_nulls, - ), - }; - match self.descr.physical_type() { - Type::INT32 => gen_stats_section!(i32, int32, min, max, distinct, nulls), - Type::BOOLEAN => gen_stats_section!(bool, boolean, min, max, distinct, nulls), - Type::INT64 => gen_stats_section!(i64, int64, min, max, distinct, nulls), - Type::INT96 => gen_stats_section!(Int96, int96, min, max, distinct, nulls), - Type::FLOAT => gen_stats_section!(f32, float, min, max, distinct, nulls), - Type::DOUBLE => gen_stats_section!(f64, double, min, max, distinct, nulls), - Type::BYTE_ARRAY => { - let min = min.as_ref().map(|v| ByteArray::from(v.as_bytes().to_vec())); - let max = max.as_ref().map(|v| ByteArray::from(v.as_bytes().to_vec())); - Statistics::byte_array(min, max, distinct, nulls, false) - } - Type::FIXED_LEN_BYTE_ARRAY => { - let min = min - .as_ref() - .map(|v| ByteArray::from(v.as_bytes().to_vec())) - .map(|ba| { - let ba: FixedLenByteArray = ba.into(); - ba - }); - let max = max - .as_ref() - .map(|v| ByteArray::from(v.as_bytes().to_vec())) - .map(|ba| { - let ba: FixedLenByteArray = ba.into(); - ba - }); - Statistics::fixed_len_byte_array(min, max, distinct, nulls, false) - } - } - } +fn update_max( + descr: &ColumnDescriptor, + val: &T, + max: &mut Option, +) { + update_stat::(val, max, |cur| compare_greater(descr, val, cur)) +} - fn update_page_min_max(&mut self, val: &T::T) { - Self::update_min(&self.descr, val, &mut self.min_page_value); - Self::update_max(&self.descr, val, &mut self.max_page_value); +#[inline] +#[allow(clippy::eq_op)] +fn is_nan(val: &T) -> bool { + match T::PHYSICAL_TYPE { + Type::FLOAT | Type::DOUBLE => val != val, + _ => false, } +} - fn update_column_min_max(&mut self) { - let min = self.min_page_value.as_ref().unwrap(); - Self::update_min(&self.descr, min, &mut self.min_column_value); - - let max = self.max_page_value.as_ref().unwrap(); - Self::update_max(&self.descr, max, &mut self.max_column_value); - } +/// Perform a conditional update of `cur`, skipping any NaN values +/// +/// If `cur` is `None`, sets `cur` to `Some(val)`, otherwise calls `should_update` with +/// the value of `cur`, and updates `cur` to `Some(val)` if it returns `true` - fn update_min(descr: &ColumnDescriptor, val: &T::T, min: &mut Option) { - Self::update_stat(val, min, |cur| Self::compare_greater(descr, cur, val)) +fn update_stat(val: &T, cur: &mut Option, should_update: F) +where + F: Fn(&T) -> bool, +{ + if is_nan(val) { + return; } - fn update_max(descr: &ColumnDescriptor, val: &T::T, max: &mut Option) { - Self::update_stat(val, max, |cur| Self::compare_greater(descr, val, cur)) + if cur.as_ref().map_or(true, should_update) { + *cur = Some(val.clone()); } +} - /// 
Perform a conditional update of `cur`, skipping any NaN values - /// - /// If `cur` is `None`, sets `cur` to `Some(val)`, otherwise calls `should_update` with - /// the value of `cur`, and updates `cur` to `Some(val)` if it returns `true` - #[allow(clippy::eq_op)] - fn update_stat(val: &T::T, cur: &mut Option, should_update: F) - where - F: Fn(&T::T) -> bool, - { - if let Type::FLOAT | Type::DOUBLE = T::get_physical_type() { - // Skip NaN values - if val != val { - return; - } - } - - if cur.as_ref().map_or(true, should_update) { - *cur = Some(val.clone()); +/// Evaluate `a > b` according to underlying logical type. +fn compare_greater(descr: &ColumnDescriptor, a: &T, b: &T) -> bool { + if let Some(LogicalType::Integer { is_signed, .. }) = descr.logical_type() { + if !is_signed { + // need to compare unsigned + return a.as_u64().unwrap() > b.as_u64().unwrap(); } } - /// Evaluate `a > b` according to underlying logical type. - fn compare_greater(descr: &ColumnDescriptor, a: &T::T, b: &T::T) -> bool { - if let Some(LogicalType::Integer { is_signed, .. }) = descr.logical_type() { - if !is_signed { - // need to compare unsigned - return a.as_u64().unwrap() > b.as_u64().unwrap(); - } + match descr.converted_type() { + ConvertedType::UINT_8 + | ConvertedType::UINT_16 + | ConvertedType::UINT_32 + | ConvertedType::UINT_64 => { + return a.as_u64().unwrap() > b.as_u64().unwrap(); } + _ => {} + }; - match descr.converted_type() { - ConvertedType::UINT_8 - | ConvertedType::UINT_16 - | ConvertedType::UINT_32 - | ConvertedType::UINT_64 => { - return a.as_u64().unwrap() > b.as_u64().unwrap(); + if let Some(LogicalType::Decimal { .. }) = descr.logical_type() { + match T::PHYSICAL_TYPE { + Type::FIXED_LEN_BYTE_ARRAY | Type::BYTE_ARRAY => { + return compare_greater_byte_array_decimals(a.as_bytes(), b.as_bytes()); } _ => {} }; + } - if let Some(LogicalType::Decimal { .. 
}) = descr.logical_type() { - match T::get_physical_type() { - Type::FIXED_LEN_BYTE_ARRAY | Type::BYTE_ARRAY => { - return compare_greater_byte_array_decimals( - a.as_bytes(), - b.as_bytes(), - ); - } - _ => {} - }; - } - - if descr.converted_type() == ConvertedType::DECIMAL { - match T::get_physical_type() { - Type::FIXED_LEN_BYTE_ARRAY | Type::BYTE_ARRAY => { - return compare_greater_byte_array_decimals( - a.as_bytes(), - b.as_bytes(), - ); - } - _ => {} - }; + if descr.converted_type() == ConvertedType::DECIMAL { + match T::PHYSICAL_TYPE { + Type::FIXED_LEN_BYTE_ARRAY | Type::BYTE_ARRAY => { + return compare_greater_byte_array_decimals(a.as_bytes(), b.as_bytes()); + } + _ => {} }; + }; - a > b - } + a > b } // ---------------------------------------------------------------------- @@ -1201,6 +1089,7 @@ fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool { #[cfg(test)] mod tests { + use bytes::Bytes; use parquet_format::BoundaryOrder; use rand::distributions::uniform::SampleUniform; use std::sync::Arc; @@ -1215,7 +1104,7 @@ mod tests { writer::SerializedPageWriter, }; use crate::schema::types::{ColumnDescriptor, ColumnPath, Type as SchemaType}; - use crate::util::{io::FileSource, test_common::random_numbers_range}; + use crate::util::test_common::rand_gen::random_numbers_range; use super::*; @@ -1280,16 +1169,16 @@ mod tests { } #[test] - #[should_panic(expected = "Dictionary offset is already set")] fn test_column_writer_write_only_one_dictionary_page() { let page_writer = get_test_page_writer(); let props = Arc::new(WriterProperties::builder().build()); let mut writer = get_test_column_writer::(page_writer, 0, 0, props); writer.write_batch(&[1, 2, 3, 4], None, None).unwrap(); // First page should be correctly written. - let res = writer.write_dictionary_page(); - assert!(res.is_ok()); + writer.add_data_page().unwrap(); writer.write_dictionary_page().unwrap(); + let err = writer.write_dictionary_page().unwrap_err().to_string(); + assert_eq!(err, "Parquet error: Dictionary encoder is not set"); } #[test] @@ -1302,14 +1191,8 @@ mod tests { ); let mut writer = get_test_column_writer::(page_writer, 0, 0, props); writer.write_batch(&[1, 2, 3, 4], None, None).unwrap(); - let res = writer.write_dictionary_page(); - assert!(res.is_err()); - if let Err(err) = res { - assert_eq!( - format!("{}", err), - "Parquet error: Dictionary encoder is not set" - ); - } + let err = writer.write_dictionary_page().unwrap_err().to_string(); + assert_eq!(err, "Parquet error: Dictionary encoder is not set"); } #[test] @@ -1325,11 +1208,13 @@ mod tests { .write_batch(&[true, false, true, false], None, None) .unwrap(); - let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap(); + let r = writer.close().unwrap(); // PlainEncoder uses bit writer to write boolean values, which all fit into 1 // byte. 
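(Editorial aside on the statistics helpers above: `update_stat` skips NaN values so a single NaN cannot poison the column min/max, and `compare_greater` falls back to unsigned or decimal comparisons when the logical/converted type requires it. The snippet below is a simplified standalone sketch of the NaN-skipping update, using `f64` directly rather than the crate's generic value type.)

// Sketch of a NaN-skipping max update; the real code is generic over the
// Parquet value type and also handles unsigned and decimal comparisons.
fn update_max_skip_nan(val: f64, cur: &mut Option<f64>) {
    // NaN is the only value that is not equal to itself.
    if val != val {
        return;
    }
    if cur.as_ref().map_or(true, |c| val > *c) {
        *cur = Some(val);
    }
}

fn main() {
    let mut max = None;
    for v in [1.0, f64::NAN, 3.5, 2.0] {
        update_max_skip_nan(v, &mut max);
    }
    assert_eq!(max, Some(3.5));
}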
- assert_eq!(bytes_written, 1); - assert_eq!(rows_written, 4); + assert_eq!(r.bytes_written, 1); + assert_eq!(r.rows_written, 4); + + let metadata = r.metadata; assert_eq!(metadata.encodings(), &vec![Encoding::PLAIN, Encoding::RLE]); assert_eq!(metadata.num_values(), 4); // just values assert_eq!(metadata.dictionary_page_offset(), None); @@ -1356,14 +1241,14 @@ mod tests { true, &[true, false], None, - &[Encoding::RLE, Encoding::RLE], + &[Encoding::RLE], ); check_encoding_write_support::( WriterVersion::PARQUET_2_0, false, &[true, false], None, - &[Encoding::RLE, Encoding::RLE], + &[Encoding::RLE], ); } @@ -1374,7 +1259,7 @@ mod tests { true, &[1, 2], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_1_0, @@ -1388,14 +1273,14 @@ mod tests { true, &[1, 2], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_2_0, false, &[1, 2], None, - &[Encoding::DELTA_BINARY_PACKED, Encoding::RLE], + &[Encoding::RLE, Encoding::DELTA_BINARY_PACKED], ); } @@ -1406,7 +1291,7 @@ mod tests { true, &[1, 2], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_1_0, @@ -1420,14 +1305,14 @@ mod tests { true, &[1, 2], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_2_0, false, &[1, 2], None, - &[Encoding::DELTA_BINARY_PACKED, Encoding::RLE], + &[Encoding::RLE, Encoding::DELTA_BINARY_PACKED], ); } @@ -1438,7 +1323,7 @@ mod tests { true, &[Int96::from(vec![1, 2, 3])], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_1_0, @@ -1452,7 +1337,7 @@ mod tests { true, &[Int96::from(vec![1, 2, 3])], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_2_0, @@ -1470,7 +1355,7 @@ mod tests { true, &[1.0, 2.0], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_1_0, @@ -1484,7 +1369,7 @@ mod tests { true, &[1.0, 2.0], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_2_0, @@ -1502,7 +1387,7 @@ mod tests { true, &[1.0, 2.0], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_1_0, @@ -1516,7 +1401,7 @@ mod tests { true, &[1.0, 2.0], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_2_0, @@ -1534,7 +1419,7 @@ mod tests { true, &[ByteArray::from(vec![1u8])], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + 
&[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_1_0, @@ -1548,14 +1433,14 @@ mod tests { true, &[ByteArray::from(vec![1u8])], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_2_0, false, &[ByteArray::from(vec![1u8])], None, - &[Encoding::DELTA_BYTE_ARRAY, Encoding::RLE], + &[Encoding::RLE, Encoding::DELTA_BYTE_ARRAY], ); } @@ -1580,14 +1465,14 @@ mod tests { true, &[ByteArray::from(vec![1u8]).into()], Some(0), - &[Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE], + &[Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY], ); check_encoding_write_support::( WriterVersion::PARQUET_2_0, false, &[ByteArray::from(vec![1u8]).into()], None, - &[Encoding::DELTA_BYTE_ARRAY, Encoding::RLE], + &[Encoding::RLE, Encoding::DELTA_BYTE_ARRAY], ); } @@ -1598,12 +1483,14 @@ mod tests { let mut writer = get_test_column_writer::(page_writer, 0, 0, props); writer.write_batch(&[1, 2, 3, 4], None, None).unwrap(); - let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap(); - assert_eq!(bytes_written, 20); - assert_eq!(rows_written, 4); + let r = writer.close().unwrap(); + assert_eq!(r.bytes_written, 20); + assert_eq!(r.rows_written, 4); + + let metadata = r.metadata; assert_eq!( metadata.encodings(), - &vec![Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE] + &vec![Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY] ); assert_eq!(metadata.num_values(), 8); // dictionary + value indexes assert_eq!(metadata.compressed_size(), 20); @@ -1655,7 +1542,7 @@ mod tests { None, ) .unwrap(); - let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap(); + let metadata = writer.close().unwrap().metadata; if let Some(stats) = metadata.statistics() { assert!(stats.has_min_max_set()); if let Statistics::ByteArray(stats) = stats { @@ -1689,7 +1576,7 @@ mod tests { Int32Type, >(page_writer, 0, 0, props); writer.write_batch(&[0, 1, 2, 3, 4, 5], None, None).unwrap(); - let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap(); + let metadata = writer.close().unwrap().metadata; if let Some(stats) = metadata.statistics() { assert!(stats.has_min_max_set()); if let Statistics::Int32(stats) = stats { @@ -1723,12 +1610,14 @@ mod tests { ) .unwrap(); - let (bytes_written, rows_written, metadata, _, _) = writer.close().unwrap(); - assert_eq!(bytes_written, 20); - assert_eq!(rows_written, 4); + let r = writer.close().unwrap(); + assert_eq!(r.bytes_written, 20); + assert_eq!(r.rows_written, 4); + + let metadata = r.metadata; assert_eq!( metadata.encodings(), - &vec![Encoding::PLAIN, Encoding::RLE_DICTIONARY, Encoding::RLE] + &vec![Encoding::PLAIN, Encoding::RLE, Encoding::RLE_DICTIONARY] ); assert_eq!(metadata.num_values(), 8); // dictionary + value indexes assert_eq!(metadata.compressed_size(), 20); @@ -1770,19 +1659,19 @@ mod tests { ) .unwrap(); - let (_, _, metadata, _, _) = writer.close().unwrap(); + let r = writer.close().unwrap(); - let stats = metadata.statistics().unwrap(); + let stats = r.metadata.statistics().unwrap(); assert_eq!(stats.min_bytes(), 1_i32.to_le_bytes()); assert_eq!(stats.max_bytes(), 7_i32.to_le_bytes()); assert_eq!(stats.null_count(), 0); assert!(stats.distinct_count().is_none()); let reader = SerializedPageReader::new( - std::io::Cursor::new(buf), - 7, - Compression::UNCOMPRESSED, - Type::INT32, + 
Arc::new(Bytes::from(buf)), + &r.metadata, + r.rows_written as usize, + None, ) .unwrap(); @@ -1799,6 +1688,56 @@ mod tests { assert!(page_statistics.distinct_count().is_none()); } + #[test] + fn test_disabled_statistics() { + let mut buf = Vec::with_capacity(100); + let mut write = TrackedWrite::new(&mut buf); + let page_writer = Box::new(SerializedPageWriter::new(&mut write)); + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::None) + .set_writer_version(WriterVersion::PARQUET_2_0) + .build(); + let props = Arc::new(props); + + let mut writer = get_test_column_writer::(page_writer, 1, 0, props); + writer + .write_batch(&[1, 2, 3, 4], Some(&[1, 0, 0, 1, 1, 1]), None) + .unwrap(); + + let r = writer.close().unwrap(); + assert!(r.metadata.statistics().is_none()); + + let reader = SerializedPageReader::new( + Arc::new(Bytes::from(buf)), + &r.metadata, + r.rows_written as usize, + None, + ) + .unwrap(); + + let pages = reader.collect::>>().unwrap(); + assert_eq!(pages.len(), 2); + + assert_eq!(pages[0].page_type(), PageType::DICTIONARY_PAGE); + assert_eq!(pages[1].page_type(), PageType::DATA_PAGE_V2); + + match &pages[1] { + Page::DataPageV2 { + num_values, + num_nulls, + num_rows, + statistics, + .. + } => { + assert_eq!(*num_values, 6); + assert_eq!(*num_nulls, 2); + assert_eq!(*num_rows, 6); + assert!(statistics.is_none()); + } + _ => unreachable!(), + } + } + #[test] fn test_column_writer_empty_column_roundtrip() { let props = WriterProperties::builder().build(); @@ -1808,40 +1747,19 @@ mod tests { #[test] fn test_column_writer_non_nullable_values_roundtrip() { let props = WriterProperties::builder().build(); - column_roundtrip_random::( - props, - 1024, - std::i32::MIN, - std::i32::MAX, - 0, - 0, - ); + column_roundtrip_random::(props, 1024, i32::MIN, i32::MAX, 0, 0); } #[test] fn test_column_writer_nullable_non_repeated_values_roundtrip() { let props = WriterProperties::builder().build(); - column_roundtrip_random::( - props, - 1024, - std::i32::MIN, - std::i32::MAX, - 10, - 0, - ); + column_roundtrip_random::(props, 1024, i32::MIN, i32::MAX, 10, 0); } #[test] fn test_column_writer_nullable_repeated_values_roundtrip() { let props = WriterProperties::builder().build(); - column_roundtrip_random::( - props, - 1024, - std::i32::MIN, - std::i32::MAX, - 10, - 10, - ); + column_roundtrip_random::(props, 1024, i32::MIN, i32::MAX, 10, 10); } #[test] @@ -1850,14 +1768,7 @@ mod tests { .set_dictionary_pagesize_limit(32) .set_data_pagesize_limit(32) .build(); - column_roundtrip_random::( - props, - 1024, - std::i32::MIN, - std::i32::MAX, - 10, - 10, - ); + column_roundtrip_random::(props, 1024, i32::MIN, i32::MAX, 10, 10); } #[test] @@ -1865,14 +1776,7 @@ mod tests { for i in &[1usize, 2, 5, 10, 11, 1023] { let props = WriterProperties::builder().set_write_batch_size(*i).build(); - column_roundtrip_random::( - props, - 1024, - std::i32::MIN, - std::i32::MAX, - 10, - 10, - ); + column_roundtrip_random::(props, 1024, i32::MIN, i32::MAX, 10, 10); } } @@ -1882,14 +1786,7 @@ mod tests { .set_writer_version(WriterVersion::PARQUET_1_0) .set_dictionary_enabled(false) .build(); - column_roundtrip_random::( - props, - 1024, - std::i32::MIN, - std::i32::MAX, - 10, - 10, - ); + column_roundtrip_random::(props, 1024, i32::MIN, i32::MAX, 10, 10); } #[test] @@ -1898,14 +1795,7 @@ mod tests { .set_writer_version(WriterVersion::PARQUET_2_0) .set_dictionary_enabled(false) .build(); - column_roundtrip_random::( - props, - 1024, - std::i32::MIN, - std::i32::MAX, - 10, - 10, - ); + 
column_roundtrip_random::(props, 1024, i32::MIN, i32::MAX, 10, 10); } #[test] @@ -1914,14 +1804,7 @@ mod tests { .set_writer_version(WriterVersion::PARQUET_1_0) .set_compression(Compression::SNAPPY) .build(); - column_roundtrip_random::( - props, - 2048, - std::i32::MIN, - std::i32::MAX, - 10, - 10, - ); + column_roundtrip_random::(props, 2048, i32::MIN, i32::MAX, 10, 10); } #[test] @@ -1930,14 +1813,7 @@ mod tests { .set_writer_version(WriterVersion::PARQUET_2_0) .set_compression(Compression::SNAPPY) .build(); - column_roundtrip_random::( - props, - 2048, - std::i32::MIN, - std::i32::MAX, - 10, - 10, - ); + column_roundtrip_random::(props, 2048, i32::MIN, i32::MAX, 10, 10); } #[test] @@ -1956,16 +1832,15 @@ mod tests { let data = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; let mut writer = get_test_column_writer::(page_writer, 0, 0, props); writer.write_batch(data, None, None).unwrap(); - let (bytes_written, _, _, _, _) = writer.close().unwrap(); + let r = writer.close().unwrap(); // Read pages and check the sequence - let source = FileSource::new(&file, 0, bytes_written as usize); let mut page_reader = Box::new( SerializedPageReader::new( - source, - data.len() as i64, - Compression::UNCOMPRESSED, - Int32Type::get_physical_type(), + Arc::new(file), + &r.metadata, + r.rows_written as usize, + None, ) .unwrap(), ); @@ -2202,22 +2077,11 @@ mod tests { // second page writer.write_batch(&[4, 8, 2, -5], None, None).unwrap(); - let (_, rows_written, metadata, column_index, offset_index) = - writer.close().unwrap(); - let column_index = match column_index { - None => { - panic!("Can't fine the column index"); - } - Some(column_index) => column_index, - }; - let offset_index = match offset_index { - None => { - panic!("Can't find the offset index"); - } - Some(offset_index) => offset_index, - }; + let r = writer.close().unwrap(); + let column_index = r.column_index.unwrap(); + let offset_index = r.offset_index.unwrap(); - assert_eq!(8, rows_written); + assert_eq!(8, r.rows_written); // column index assert_eq!(2, column_index.null_pages.len()); @@ -2228,7 +2092,7 @@ mod tests { assert_eq!(0, column_index.null_counts.as_ref().unwrap()[idx]); } - if let Some(stats) = metadata.statistics() { + if let Some(stats) = r.metadata.statistics() { assert!(stats.has_min_max_set()); assert_eq!(stats.null_count(), 0); assert_eq!(stats.distinct_count(), None); @@ -2339,16 +2203,14 @@ mod tests { let values_written = writer.write_batch(values, def_levels, rep_levels).unwrap(); assert_eq!(values_written, values.len()); - let (bytes_written, rows_written, column_metadata, _, _) = - writer.close().unwrap(); + let result = writer.close().unwrap(); - let source = FileSource::new(&file, 0, bytes_written as usize); let page_reader = Box::new( SerializedPageReader::new( - source, - column_metadata.num_values(), - column_metadata.compression(), - T::get_physical_type(), + Arc::new(file), + &result.metadata, + result.rows_written as usize, + None, ) .unwrap(), ); @@ -2388,11 +2250,11 @@ mod tests { actual_rows_written += 1; } } - assert_eq!(actual_rows_written, rows_written); + assert_eq!(actual_rows_written, result.rows_written); } else if actual_def_levels.is_some() { - assert_eq!(levels_read as u64, rows_written); + assert_eq!(levels_read as u64, result.rows_written); } else { - assert_eq!(values_read as u64, rows_written); + assert_eq!(values_read as u64, result.rows_written); } } @@ -2406,8 +2268,7 @@ mod tests { let props = Arc::new(props); let mut writer = get_test_column_writer::(page_writer, 0, 0, props); 
writer.write_batch(values, None, None).unwrap(); - let (_, _, metadata, _, _) = writer.close().unwrap(); - metadata + writer.close().unwrap().metadata } // Function to use in tests for EncodingWriteSupport. This checks that dictionary @@ -2518,7 +2379,7 @@ mod tests { let mut writer = get_test_column_writer::(page_writer, 0, 0, props); writer.write_batch(values, None, None).unwrap(); - let (_bytes_written, _rows_written, metadata, _, _) = writer.close().unwrap(); + let metadata = writer.close().unwrap().metadata; if let Some(stats) = metadata.statistics() { stats.clone() } else { @@ -2541,20 +2402,6 @@ mod tests { get_typed_column_writer::(column_writer) } - /// Returns decimals column reader. - fn get_test_decimals_column_reader( - page_reader: Box, - max_def_level: i16, - max_rep_level: i16, - ) -> ColumnReaderImpl { - let descr = Arc::new(get_test_decimals_column_descr::( - max_def_level, - max_rep_level, - )); - let column_reader = get_column_reader(descr, page_reader); - get_typed_column_reader::(column_reader) - } - /// Returns descriptor for Decimal type with primitive column. fn get_test_decimals_column_descr( max_def_level: i16, @@ -2589,20 +2436,6 @@ mod tests { get_typed_column_writer::(column_writer) } - /// Returns column reader for UINT32 Column provided as ConvertedType only - fn get_test_unsigned_int_given_as_converted_column_reader( - page_reader: Box, - max_def_level: i16, - max_rep_level: i16, - ) -> ColumnReaderImpl { - let descr = Arc::new(get_test_converted_type_unsigned_integer_column_descr::( - max_def_level, - max_rep_level, - )); - let column_reader = get_column_reader(descr, page_reader); - get_typed_column_reader::(column_reader) - } - /// Returns column descriptor for UINT32 Column provided as ConvertedType only fn get_test_converted_type_unsigned_integer_column_descr( max_def_level: i16, diff --git a/parquet/src/compression.rs b/parquet/src/compression.rs index a5e49360a28a..ee5141cbe140 100644 --- a/parquet/src/compression.rs +++ b/parquet/src/compression.rs @@ -329,7 +329,7 @@ pub use zstd_codec::*; mod tests { use super::*; - use crate::util::test_common::*; + use crate::util::test_common::rand_gen::random_bytes; fn test_roundtrip(c: CodecType, data: &[u8]) { let mut c1 = create_codec(c).unwrap().unwrap(); diff --git a/parquet/src/data_type.rs b/parquet/src/data_type.rs index 86ccefbd85eb..9cd36cf43dc8 100644 --- a/parquet/src/data_type.rs +++ b/parquet/src/data_type.rs @@ -23,8 +23,6 @@ use std::mem; use std::ops::{Deref, DerefMut}; use std::str::from_utf8; -use byteorder::{BigEndian, ByteOrder}; - use crate::basic::Type; use crate::column::reader::{ColumnReader, ColumnReaderImpl}; use crate::column::writer::{ColumnWriter, ColumnWriterImpl}; @@ -36,52 +34,54 @@ use crate::util::{ /// Rust representation for logical type INT96, value is backed by an array of `u32`. /// The type only takes 12 bytes, without extra padding. -#[derive(Clone, Debug, PartialOrd, Default)] +#[derive(Clone, Copy, Debug, PartialOrd, Default, PartialEq, Eq)] pub struct Int96 { - value: Option<[u32; 3]>, + value: [u32; 3], } impl Int96 { /// Creates new INT96 type struct with no data set. pub fn new() -> Self { - Self { value: None } + Self { value: [0; 3] } } /// Returns underlying data as slice of [`u32`]. #[inline] pub fn data(&self) -> &[u32] { - self.value - .as_ref() - .expect("set_data should have been called") + &self.value } /// Sets data for this INT96 type. 
#[inline] pub fn set_data(&mut self, elem0: u32, elem1: u32, elem2: u32) { - self.value = Some([elem0, elem1, elem2]); + self.value = [elem0, elem1, elem2]; } /// Converts this INT96 into an i64 representing the number of MILLISECONDS since Epoch pub fn to_i64(&self) -> i64 { + let (seconds, nanoseconds) = self.to_seconds_and_nanos(); + seconds * 1_000 + nanoseconds / 1_000_000 + } + + /// Converts this INT96 into an i64 representing the number of NANOSECONDS since EPOCH + /// + /// Will wrap around on overflow + pub fn to_nanos(&self) -> i64 { + let (seconds, nanoseconds) = self.to_seconds_and_nanos(); + seconds + .wrapping_mul(1_000_000_000) + .wrapping_add(nanoseconds) + } + + /// Converts this INT96 to a number of seconds and nanoseconds since EPOCH + pub fn to_seconds_and_nanos(&self) -> (i64, i64) { const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; const SECONDS_PER_DAY: i64 = 86_400; - const MILLIS_PER_SECOND: i64 = 1_000; let day = self.data()[2] as i64; let nanoseconds = ((self.data()[1] as i64) << 32) + self.data()[0] as i64; let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; - - seconds * MILLIS_PER_SECOND + nanoseconds / 1_000_000 - } -} - -impl PartialEq for Int96 { - fn eq(&self, other: &Int96) -> bool { - match (&self.value, &other.value) { - (Some(v1), Some(v2)) => v1 == v2, - (None, None) => true, - _ => false, - } + (seconds, nanoseconds) } } @@ -313,6 +313,12 @@ impl From for FixedLenByteArray { } } +impl From> for FixedLenByteArray { + fn from(buf: Vec) -> FixedLenByteArray { + FixedLenByteArray(ByteArray::from(buf)) + } +} + impl From for ByteArray { fn from(other: FixedLenByteArray) -> Self { other.0 @@ -349,8 +355,7 @@ pub enum Decimal { impl Decimal { /// Creates new decimal value from `i32`. pub fn from_i32(value: i32, precision: i32, scale: i32) -> Self { - let mut bytes = [0; 4]; - BigEndian::write_i32(&mut bytes, value); + let bytes = value.to_be_bytes(); Decimal::Int32 { value: bytes, precision, @@ -360,8 +365,7 @@ impl Decimal { /// Creates new decimal value from `i64`. pub fn from_i64(value: i64, precision: i32, scale: i32) -> Self { - let mut bytes = [0; 8]; - BigEndian::write_i64(&mut bytes, value); + let bytes = value.to_be_bytes(); Decimal::Int64 { value: bytes, precision, @@ -565,34 +569,35 @@ impl AsBytes for str { pub(crate) mod private { use crate::encodings::decoding::PlainDecoderDetails; - use crate::util::bit_util::{round_upto_power_of_2, BitReader, BitWriter}; + use crate::util::bit_util::{read_num_bytes, BitReader, BitWriter}; use crate::util::memory::ByteBufferPtr; - use byteorder::ByteOrder; + use crate::basic::Type; use std::convert::TryInto; use super::{ParquetError, Result, SliceAsBytes}; - pub type BitIndex = u64; - /// Sealed trait to start to remove specialisation from implementations /// /// This is done to force the associated value type to be unimplementable outside of this /// crate, and thus hint to the type system (and end user) traits are public for the contract /// and not for extension. 
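(Editorial aside for readers unfamiliar with INT96 timestamps: the three u32 words hold the nanoseconds-within-day in the low 64 bits and a Julian day number in the third word, which is what `to_seconds_and_nanos` above decodes. Below is a standalone sketch of the same arithmetic, with the constants taken from the code above; the free-function name is illustrative only.)

// Sketch of the INT96 decoding performed by Int96::to_seconds_and_nanos.
// data[0], data[1]: low/high halves of nanoseconds within the day;
// data[2]: Julian day number.
fn int96_to_seconds_and_nanos(data: [u32; 3]) -> (i64, i64) {
    const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; // 1970-01-01 as a Julian day
    const SECONDS_PER_DAY: i64 = 86_400;

    let day = data[2] as i64;
    let nanoseconds = ((data[1] as i64) << 32) + data[0] as i64;
    let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY;
    (seconds, nanoseconds)
}

fn main() {
    // The Unix epoch itself: Julian day 2_440_588, zero nanoseconds into the day.
    assert_eq!(int96_to_seconds_and_nanos([0, 0, 2_440_588]), (0, 0));
}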
pub trait ParquetValueType: - std::cmp::PartialEq + PartialEq + std::fmt::Debug + std::fmt::Display - + std::default::Default - + std::clone::Clone + + Default + + Clone + super::AsBytes + super::FromBytes - + super::SliceAsBytes + + SliceAsBytes + PartialOrd + Send + crate::encodings::decoding::private::GetDecoder + + crate::file::statistics::private::MakeStatistics { + const PHYSICAL_TYPE: Type; + /// Encode the value directly from a higher level encoder fn encode( values: &[Self], @@ -613,6 +618,8 @@ pub(crate) mod private { decoder: &mut PlainDecoderDetails, ) -> Result; + fn skip(decoder: &mut PlainDecoderDetails, num_values: usize) -> Result; + /// Return the encoded size for a type fn dict_encoding_size(&self) -> (usize, usize) { (std::mem::size_of::(), 1) @@ -644,26 +651,16 @@ pub(crate) mod private { } impl ParquetValueType for bool { + const PHYSICAL_TYPE: Type = Type::BOOLEAN; + #[inline] fn encode( values: &[Self], _: &mut W, bit_writer: &mut BitWriter, ) -> Result<()> { - if bit_writer.bytes_written() + values.len() / 8 >= bit_writer.capacity() { - let bits_available = - (bit_writer.capacity() - bit_writer.bytes_written()) * 8; - let bits_needed = values.len() - bits_available; - let bytes_needed = (bits_needed + 7) / 8; - let bytes_needed = round_upto_power_of_2(bytes_needed, 256); - bit_writer.extend(bytes_needed); - } for value in values { - if !bit_writer.put_value(*value as u64, 1) { - return Err(ParquetError::EOF( - "unable to put boolean value".to_string(), - )); - } + bit_writer.put_value(*value as u64, 1) } Ok(()) } @@ -690,6 +687,14 @@ pub(crate) mod private { Ok(values_read) } + fn skip(decoder: &mut PlainDecoderDetails, num_values: usize) -> Result { + let bit_reader = decoder.bit_reader.as_mut().unwrap(); + let num_values = std::cmp::min(num_values, decoder.num_values); + let values_read = bit_reader.skip(num_values, 1); + decoder.num_values -= values_read; + Ok(values_read) + } + #[inline] fn as_i64(&self) -> Result { Ok(*self as i64) @@ -706,22 +711,11 @@ pub(crate) mod private { } } - /// Hopelessly unsafe function that emulates `num::as_ne_bytes` - /// - /// It is not recommended to use this outside of this private module as, while it - /// _should_ work for primitive values, it is little better than a transmutation - /// and can act as a backdoor into mis-interpreting types as arbitary byte slices - #[inline] - fn as_raw<'a, T>(value: *const T) -> &'a [u8] { - unsafe { - let value = value as *const u8; - std::slice::from_raw_parts(value, std::mem::size_of::()) - } - } - macro_rules! 
impl_from_raw { - ($ty: ty, $self: ident => $as_i64: block) => { + ($ty: ty, $physical_ty: expr, $self: ident => $as_i64: block) => { impl ParquetValueType for $ty { + const PHYSICAL_TYPE: Type = $physical_ty; + #[inline] fn encode(values: &[Self], writer: &mut W, _: &mut BitWriter) -> Result<()> { let raw = unsafe { @@ -764,6 +758,23 @@ pub(crate) mod private { Ok(num_values) } + #[inline] + fn skip(decoder: &mut PlainDecoderDetails, num_values: usize) -> Result { + let data = decoder.data.as_ref().expect("set_data should have been called"); + let num_values = num_values.min(decoder.num_values); + let bytes_left = data.len() - decoder.start; + let bytes_to_skip = std::mem::size_of::() * num_values; + + if bytes_left < bytes_to_skip { + return Err(eof_err!("Not enough bytes to skip")); + } + + decoder.start += bytes_to_skip; + decoder.num_values -= num_values; + + Ok(num_values) + } + #[inline] fn as_i64(&$self) -> Result { $as_i64 @@ -782,12 +793,14 @@ pub(crate) mod private { } } - impl_from_raw!(i32, self => { Ok(*self as i64) }); - impl_from_raw!(i64, self => { Ok(*self) }); - impl_from_raw!(f32, self => { Err(general_err!("Type cannot be converted to i64")) }); - impl_from_raw!(f64, self => { Err(general_err!("Type cannot be converted to i64")) }); + impl_from_raw!(i32, Type::INT32, self => { Ok(*self as i64) }); + impl_from_raw!(i64, Type::INT64, self => { Ok(*self) }); + impl_from_raw!(f32, Type::FLOAT, self => { Err(general_err!("Type cannot be converted to i64")) }); + impl_from_raw!(f64, Type::DOUBLE, self => { Err(general_err!("Type cannot be converted to i64")) }); impl ParquetValueType for super::Int96 { + const PHYSICAL_TYPE: Type = Type::INT96; + #[inline] fn encode( values: &[Self], @@ -841,9 +854,11 @@ pub(crate) mod private { let mut pos = 0; // position in byte array for item in buffer.iter_mut().take(num_values) { - let elem0 = byteorder::LittleEndian::read_u32(&bytes[pos..pos + 4]); - let elem1 = byteorder::LittleEndian::read_u32(&bytes[pos + 4..pos + 8]); - let elem2 = byteorder::LittleEndian::read_u32(&bytes[pos + 8..pos + 12]); + let elem0 = u32::from_le_bytes(bytes[pos..pos + 4].try_into().unwrap()); + let elem1 = + u32::from_le_bytes(bytes[pos + 4..pos + 8].try_into().unwrap()); + let elem2 = + u32::from_le_bytes(bytes[pos + 8..pos + 12].try_into().unwrap()); item.set_data(elem0, elem1, elem2); pos += 12; @@ -853,6 +868,24 @@ pub(crate) mod private { Ok(num_values) } + fn skip(decoder: &mut PlainDecoderDetails, num_values: usize) -> Result { + let data = decoder + .data + .as_ref() + .expect("set_data should have been called"); + let num_values = std::cmp::min(num_values, decoder.num_values); + let bytes_left = data.len() - decoder.start; + let bytes_to_skip = 12 * num_values; + + if bytes_left < bytes_to_skip { + return Err(eof_err!("Not enough bytes to skip")); + } + decoder.start += bytes_to_skip; + decoder.num_values -= num_values; + + Ok(num_values) + } + #[inline] fn as_any(&self) -> &dyn std::any::Any { self @@ -864,22 +897,9 @@ pub(crate) mod private { } } - // TODO - Why does macro importing fail? - /// Reads `$size` of bytes from `$src`, and reinterprets them as type `$ty`, in - /// little-endian order. `$ty` must implement the `Default` trait. Otherwise this won't - /// compile. - /// This is copied and modified from byteorder crate. - macro_rules! 
read_num_bytes { - ($ty:ty, $size:expr, $src:expr) => {{ - assert!($size <= $src.len()); - let mut buffer = - <$ty as $crate::util::bit_util::FromBytes>::Buffer::default(); - buffer.as_mut()[..$size].copy_from_slice(&$src[..$size]); - <$ty>::from_ne_bytes(buffer) - }}; - } - impl ParquetValueType for super::ByteArray { + const PHYSICAL_TYPE: Type = Type::BYTE_ARRAY; + #[inline] fn encode( values: &[Self], @@ -916,9 +936,9 @@ pub(crate) mod private { .as_mut() .expect("set_data should have been called"); let num_values = std::cmp::min(buffer.len(), decoder.num_values); - for i in 0..num_values { + for val_array in buffer.iter_mut().take(num_values) { let len: usize = - read_num_bytes!(u32, 4, data.start_from(decoder.start).as_ref()) + read_num_bytes::(4, data.start_from(decoder.start).as_ref()) as usize; decoder.start += std::mem::size_of::(); @@ -926,7 +946,7 @@ pub(crate) mod private { return Err(eof_err!("Not enough bytes to decode")); } - let val: &mut Self = buffer[i].as_mut_any().downcast_mut().unwrap(); + let val: &mut Self = val_array.as_mut_any().downcast_mut().unwrap(); val.set_data(data.range(decoder.start, len)); decoder.start += len; @@ -936,6 +956,24 @@ pub(crate) mod private { Ok(num_values) } + fn skip(decoder: &mut PlainDecoderDetails, num_values: usize) -> Result { + let data = decoder + .data + .as_mut() + .expect("set_data should have been called"); + let num_values = num_values.min(decoder.num_values); + + for _ in 0..num_values { + let len: usize = + read_num_bytes::(4, data.start_from(decoder.start).as_ref()) + as usize; + decoder.start += std::mem::size_of::() + len; + } + decoder.num_values -= num_values; + + Ok(num_values) + } + #[inline] fn dict_encoding_size(&self) -> (usize, usize) { (std::mem::size_of::(), self.len()) @@ -953,6 +991,8 @@ pub(crate) mod private { } impl ParquetValueType for super::FixedLenByteArray { + const PHYSICAL_TYPE: Type = Type::FIXED_LEN_BYTE_ARRAY; + #[inline] fn encode( values: &[Self], @@ -1005,6 +1045,28 @@ pub(crate) mod private { Ok(num_values) } + fn skip(decoder: &mut PlainDecoderDetails, num_values: usize) -> Result { + assert!(decoder.type_length > 0); + + let data = decoder + .data + .as_mut() + .expect("set_data should have been called"); + let num_values = std::cmp::min(num_values, decoder.num_values); + for _ in 0..num_values { + let len = decoder.type_length as usize; + + if data.len() < decoder.start + len { + return Err(eof_err!("Not enough bytes to skip")); + } + + decoder.start += len; + } + decoder.num_values -= num_values; + + Ok(num_values) + } + #[inline] fn dict_encoding_size(&self) -> (usize, usize) { (std::mem::size_of::(), self.len()) @@ -1028,7 +1090,9 @@ pub trait DataType: 'static + Send { type T: private::ParquetValueType; /// Returns Parquet physical type. - fn get_physical_type() -> Type; + fn get_physical_type() -> Type { + ::PHYSICAL_TYPE + } /// Returns size in bytes for Rust representation of the physical type. fn get_type_size() -> usize; @@ -1071,25 +1135,21 @@ where } macro_rules! 
make_type { - ($name:ident, $physical_ty:path, $reader_ident: ident, $writer_ident: ident, $native_ty:ty, $size:expr) => { + ($name:ident, $reader_ident: ident, $writer_ident: ident, $native_ty:ty, $size:expr) => { #[derive(Clone)] pub struct $name {} impl DataType for $name { type T = $native_ty; - fn get_physical_type() -> Type { - $physical_ty - } - fn get_type_size() -> usize { $size } fn get_column_reader( - column_writer: ColumnReader, + column_reader: ColumnReader, ) -> Option> { - match column_writer { + match column_reader { ColumnReader::$reader_ident(w) => Some(w), _ => None, } @@ -1127,57 +1187,20 @@ macro_rules! make_type { // Generate struct definitions for all physical types -make_type!( - BoolType, - Type::BOOLEAN, - BoolColumnReader, - BoolColumnWriter, - bool, - 1 -); -make_type!( - Int32Type, - Type::INT32, - Int32ColumnReader, - Int32ColumnWriter, - i32, - 4 -); -make_type!( - Int64Type, - Type::INT64, - Int64ColumnReader, - Int64ColumnWriter, - i64, - 8 -); +make_type!(BoolType, BoolColumnReader, BoolColumnWriter, bool, 1); +make_type!(Int32Type, Int32ColumnReader, Int32ColumnWriter, i32, 4); +make_type!(Int64Type, Int64ColumnReader, Int64ColumnWriter, i64, 8); make_type!( Int96Type, - Type::INT96, Int96ColumnReader, Int96ColumnWriter, Int96, mem::size_of::() ); -make_type!( - FloatType, - Type::FLOAT, - FloatColumnReader, - FloatColumnWriter, - f32, - 4 -); -make_type!( - DoubleType, - Type::DOUBLE, - DoubleColumnReader, - DoubleColumnWriter, - f64, - 8 -); +make_type!(FloatType, FloatColumnReader, FloatColumnWriter, f32, 4); +make_type!(DoubleType, DoubleColumnReader, DoubleColumnWriter, f64, 8); make_type!( ByteArrayType, - Type::BYTE_ARRAY, ByteArrayColumnReader, ByteArrayColumnWriter, ByteArray, @@ -1185,13 +1208,24 @@ make_type!( ); make_type!( FixedLenByteArrayType, - Type::FIXED_LEN_BYTE_ARRAY, FixedLenByteArrayColumnReader, FixedLenByteArrayColumnWriter, FixedLenByteArray, mem::size_of::() ); +impl AsRef<[u8]> for ByteArray { + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} + +impl AsRef<[u8]> for FixedLenByteArray { + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} + impl FromBytes for Int96 { type Buffer = [u8; 12]; fn from_le_bytes(bs: Self::Buffer) -> Self { @@ -1220,29 +1254,29 @@ impl FromBytes for Int96 { // FIXME Needed to satisfy the constraint of many decoding functions but ByteArray does not // appear to actual be converted directly from bytes impl FromBytes for ByteArray { - type Buffer = [u8; 8]; + type Buffer = Vec; fn from_le_bytes(bs: Self::Buffer) -> Self { - ByteArray::from(bs.to_vec()) + ByteArray::from(bs) } fn from_be_bytes(_bs: Self::Buffer) -> Self { unreachable!() } fn from_ne_bytes(bs: Self::Buffer) -> Self { - ByteArray::from(bs.to_vec()) + ByteArray::from(bs) } } impl FromBytes for FixedLenByteArray { - type Buffer = [u8; 8]; + type Buffer = Vec; fn from_le_bytes(bs: Self::Buffer) -> Self { - Self(ByteArray::from(bs.to_vec())) + Self(ByteArray::from(bs)) } fn from_be_bytes(_bs: Self::Buffer) -> Self { unreachable!() } fn from_ne_bytes(bs: Self::Buffer) -> Self { - Self(ByteArray::from(bs.to_vec())) + Self(ByteArray::from(bs)) } } diff --git a/parquet/src/encodings/decoding.rs b/parquet/src/encodings/decoding.rs index b33514aaf629..86941ffe0eeb 100644 --- a/parquet/src/encodings/decoding.rs +++ b/parquet/src/encodings/decoding.rs @@ -206,6 +206,9 @@ pub trait Decoder: Send { /// Returns the encoding for this decoder. fn encoding(&self) -> Encoding; + + /// Skip the specified number of values in this decoder stream. 
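(Editorial aside on the new `skip` support introduced above: for plain-encoded fixed-width values, skipping is just offset arithmetic -- clamp to the remaining value count, check there are enough bytes, and advance the cursor without materializing anything. The sketch below mirrors that logic with a plain byte-slice cursor; it is an illustration, not the crate's `PlainDecoderDetails` implementation.)

// Simplified sketch of skipping plain-encoded fixed-width values.
struct FixedWidthCursor<'a> {
    data: &'a [u8],
    start: usize,
    num_values: usize,
}

impl<'a> FixedWidthCursor<'a> {
    /// Skip up to `num_values` values of `width` bytes each,
    /// returning how many were actually skipped.
    fn skip(&mut self, num_values: usize, width: usize) -> Result<usize, String> {
        let num_values = num_values.min(self.num_values);
        let bytes_to_skip = width * num_values;
        if self.data.len() - self.start < bytes_to_skip {
            return Err("Not enough bytes to skip".to_string());
        }
        self.start += bytes_to_skip;
        self.num_values -= num_values;
        Ok(num_values)
    }
}

fn main() {
    let bytes = [0u8; 12]; // three 4-byte (i32) values
    let mut cursor = FixedWidthCursor { data: &bytes, start: 0, num_values: 3 };
    assert_eq!(cursor.skip(2, 4), Ok(2));
    assert_eq!(cursor.start, 8);
    assert_eq!(cursor.skip(5, 4), Ok(1)); // clamped to the one remaining value
}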
+ fn skip(&mut self, num_values: usize) -> Result; } /// Gets a decoder for the column descriptor `descr` and encoding type `encoding`. @@ -291,6 +294,11 @@ impl Decoder for PlainDecoder { fn get(&mut self, buffer: &mut [T::T]) -> Result { T::T::decode(buffer, &mut self.inner) } + + #[inline] + fn skip(&mut self, num_values: usize) -> Result { + T::T::skip(&mut self.inner, num_values) + } } // ---------------------------------------------------------------------- @@ -314,6 +322,12 @@ pub struct DictDecoder { num_values: usize, } +impl Default for DictDecoder { + fn default() -> Self { + Self::new() + } +} + impl DictDecoder { /// Creates new dictionary decoder. pub fn new() -> Self { @@ -363,6 +377,15 @@ impl Decoder for DictDecoder { fn encoding(&self) -> Encoding { Encoding::RLE_DICTIONARY } + + fn skip(&mut self, num_values: usize) -> Result { + assert!(self.rle_decoder.is_some()); + assert!(self.has_dictionary, "Must call set_dict() first!"); + + let rle = self.rle_decoder.as_mut().unwrap(); + let num_values = cmp::min(num_values, self.num_values); + rle.skip(num_values) + } } // ---------------------------------------------------------------------- @@ -377,6 +400,12 @@ pub struct RleValueDecoder { _phantom: PhantomData, } +impl Default for RleValueDecoder { + fn default() -> Self { + Self::new() + } +} + impl RleValueDecoder { pub fn new() -> Self { Self { @@ -395,7 +424,7 @@ impl Decoder for RleValueDecoder { // We still need to remove prefix of i32 from the stream. const I32_SIZE: usize = mem::size_of::(); - let data_size = read_num_bytes!(i32, I32_SIZE, data.as_ref()) as usize; + let data_size = bit_util::read_num_bytes::(I32_SIZE, data.as_ref()) as usize; self.decoder = RleDecoder::new(1); self.decoder.set_data(data.range(I32_SIZE, data_size)); self.values_left = num_values; @@ -419,6 +448,14 @@ impl Decoder for RleValueDecoder { self.values_left -= values_read; Ok(values_read) } + + #[inline] + fn skip(&mut self, num_values: usize) -> Result { + let num_values = cmp::min(num_values, self.values_left); + let values_skipped = self.decoder.skip(num_values)?; + self.values_left -= values_skipped; + Ok(values_skipped) + } } // ---------------------------------------------------------------------- @@ -460,6 +497,15 @@ pub struct DeltaBitPackDecoder { last_value: T::T, } +impl Default for DeltaBitPackDecoder +where + T::T: Default + FromPrimitive + WrappingAdd + Copy, +{ + fn default() -> Self { + Self::new() + } +} + impl DeltaBitPackDecoder where T::T: Default + FromPrimitive + WrappingAdd + Copy, @@ -688,6 +734,64 @@ where fn encoding(&self) -> Encoding { Encoding::DELTA_BINARY_PACKED } + + fn skip(&mut self, num_values: usize) -> Result { + let mut skip = 0; + let to_skip = num_values.min(self.values_left); + if to_skip == 0 { + return Ok(0); + } + + // try to consume first value in header. 
+ if let Some(value) = self.first_value.take() { + self.last_value = value; + skip += 1; + self.values_left -= 1; + } + + let mini_block_batch_size = match T::T::PHYSICAL_TYPE { + Type::INT32 => 32, + Type::INT64 => 64, + _ => unreachable!(), + }; + + let mut skip_buffer = vec![T::T::default(); mini_block_batch_size]; + while skip < to_skip { + if self.mini_block_remaining == 0 { + self.next_mini_block()?; + } + + let bit_width = self.mini_block_bit_widths[self.mini_block_idx] as usize; + let mini_block_to_skip = self.mini_block_remaining.min(to_skip - skip); + let mini_block_should_skip = mini_block_to_skip; + + let skip_count = self + .bit_reader + .get_batch(&mut skip_buffer[0..mini_block_to_skip], bit_width); + + if skip_count != mini_block_to_skip { + return Err(general_err!( + "Expected to skip {} values from mini block got {}.", + mini_block_batch_size, + skip_count + )); + } + + for v in &mut skip_buffer[0..skip_count] { + *v = v + .wrapping_add(&self.min_delta) + .wrapping_add(&self.last_value); + + self.last_value = *v; + } + + skip += mini_block_should_skip; + self.mini_block_remaining -= mini_block_should_skip; + self.values_left -= mini_block_should_skip; + } + + Ok(to_skip) + } } // ---------------------------------------------------------------------- @@ -719,6 +823,12 @@ pub struct DeltaLengthByteArrayDecoder { _phantom: PhantomData, } +impl Default for DeltaLengthByteArrayDecoder { + fn default() -> Self { + Self::new() + } +} + impl DeltaLengthByteArrayDecoder { /// Creates new delta length byte array decoder. pub fn new() -> Self { @@ -791,6 +901,29 @@ impl Decoder for DeltaLengthByteArrayDecoder { fn encoding(&self) -> Encoding { Encoding::DELTA_LENGTH_BYTE_ARRAY } + + fn skip(&mut self, num_values: usize) -> Result { + match T::get_physical_type() { + Type::BYTE_ARRAY => { + let num_values = cmp::min(num_values, self.num_values); + + let next_offset: i32 = self.lengths + [self.current_idx..self.current_idx + num_values] + .iter() + .sum(); + + self.current_idx += num_values; + self.offset += next_offset as usize; + + self.num_values -= num_values; + Ok(num_values) + } + other_type => Err(general_err!( + "DeltaLengthByteArrayDecoder not support {}, only support byte array", + other_type + )), + } + } } // ---------------------------------------------------------------------- @@ -823,6 +956,12 @@ pub struct DeltaByteArrayDecoder { _phantom: PhantomData, } +impl Default for DeltaByteArrayDecoder { + fn default() -> Self { + Self::new() + } +} + impl DeltaByteArrayDecoder { /// Creates new delta byte array decoder. 
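(Editorial aside on `DeltaLengthByteArrayDecoder::skip` above: because all value lengths are decoded up front, skipping n values only requires summing the next n lengths to advance the data offset. A small sketch of that bookkeeping with plain vectors; the struct below is illustrative, not the decoder's real layout.)

// Sketch of length-prefix skipping as done by DeltaLengthByteArrayDecoder::skip.
struct LengthPrefixed {
    lengths: Vec<i32>,  // decoded up front by the real decoder
    current_idx: usize,
    offset: usize,      // byte offset into the concatenated value data
    num_values: usize,
}

impl LengthPrefixed {
    fn skip(&mut self, num_values: usize) -> usize {
        let num_values = num_values.min(self.num_values);
        let skipped_bytes: i32 = self.lengths
            [self.current_idx..self.current_idx + num_values]
            .iter()
            .sum();
        self.current_idx += num_values;
        self.offset += skipped_bytes as usize;
        self.num_values -= num_values;
        num_values
    }
}

fn main() {
    let mut dec = LengthPrefixed {
        lengths: vec![5, 7, 3],
        current_idx: 0,
        offset: 0,
        num_values: 3,
    };
    assert_eq!(dec.skip(2), 2);
    assert_eq!(dec.offset, 12); // 5 + 7 bytes of value data skipped
    assert_eq!(dec.num_values, 1);
}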
pub fn new() -> Self { @@ -922,6 +1061,11 @@ impl Decoder for DeltaByteArrayDecoder { fn encoding(&self) -> Encoding { Encoding::DELTA_BYTE_ARRAY } + + fn skip(&mut self, num_values: usize) -> Result { + let mut buffer = vec![T::T::default(); num_values]; + self.get(&mut buffer) + } } #[cfg(test)] @@ -934,7 +1078,7 @@ mod tests { use crate::schema::types::{ ColumnDescPtr, ColumnDescriptor, ColumnPath, Type as SchemaType, }; - use crate::util::{bit_util::set_array_bit, test_common::RandGen}; + use crate::util::test_common::rand_gen::RandGen; #[test] fn test_get_decoders() { @@ -995,6 +1139,26 @@ mod tests { ); } + #[test] + fn test_plain_skip_int32() { + let data = vec![42, 18, 52]; + let data_bytes = Int32Type::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 3, + 1, + -1, + &data[1..], + ); + } + + #[test] + fn test_plain_skip_all_int32() { + let data = vec![42, 18, 52]; + let data_bytes = Int32Type::to_byte_array(&data[..]); + test_plain_skip::(ByteBufferPtr::new(data_bytes), 3, 5, -1, &[]); + } + #[test] fn test_plain_decode_int32_spaced() { let data = [42, 18, 52]; @@ -1028,6 +1192,26 @@ mod tests { ); } + #[test] + fn test_plain_skip_int64() { + let data = vec![42, 18, 52]; + let data_bytes = Int64Type::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 3, + 2, + -1, + &data[2..], + ); + } + + #[test] + fn test_plain_skip_all_int64() { + let data = vec![42, 18, 52]; + let data_bytes = Int64Type::to_byte_array(&data[..]); + test_plain_skip::(ByteBufferPtr::new(data_bytes), 3, 3, -1, &[]); + } + #[test] fn test_plain_decode_float() { let data = vec![3.14, 2.414, 12.51]; @@ -1042,6 +1226,46 @@ mod tests { ); } + #[test] + fn test_plain_skip_float() { + let data = vec![3.14, 2.414, 12.51]; + let data_bytes = FloatType::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 3, + 1, + -1, + &data[1..], + ); + } + + #[test] + fn test_plain_skip_all_float() { + let data = vec![3.14, 2.414, 12.51]; + let data_bytes = FloatType::to_byte_array(&data[..]); + test_plain_skip::(ByteBufferPtr::new(data_bytes), 3, 4, -1, &[]); + } + + #[test] + fn test_plain_skip_double() { + let data = vec![3.14f64, 2.414f64, 12.51f64]; + let data_bytes = DoubleType::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 3, + 1, + -1, + &data[1..], + ); + } + + #[test] + fn test_plain_skip_all_double() { + let data = vec![3.14f64, 2.414f64, 12.51f64]; + let data_bytes = DoubleType::to_byte_array(&data[..]); + test_plain_skip::(ByteBufferPtr::new(data_bytes), 3, 5, -1, &[]); + } + #[test] fn test_plain_decode_double() { let data = vec![3.14f64, 2.414f64, 12.51f64]; @@ -1074,6 +1298,34 @@ mod tests { ); } + #[test] + fn test_plain_skip_int96() { + let mut data = vec![Int96::new(); 4]; + data[0].set_data(11, 22, 33); + data[1].set_data(44, 55, 66); + data[2].set_data(10, 20, 30); + data[3].set_data(40, 50, 60); + let data_bytes = Int96Type::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 4, + 2, + -1, + &data[2..], + ); + } + + #[test] + fn test_plain_skip_all_int96() { + let mut data = vec![Int96::new(); 4]; + data[0].set_data(11, 22, 33); + data[1].set_data(44, 55, 66); + data[2].set_data(10, 20, 30); + data[3].set_data(40, 50, 60); + let data_bytes = Int96Type::to_byte_array(&data[..]); + test_plain_skip::(ByteBufferPtr::new(data_bytes), 4, 8, -1, &[]); + } + #[test] fn test_plain_decode_bool() { let data = vec![ @@ -1090,6 +1342,30 @@ mod tests { ); } + #[test] 
+ fn test_plain_skip_bool() { + let data = vec![ + false, true, false, false, true, false, true, true, false, true, + ]; + let data_bytes = BoolType::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 10, + 5, + -1, + &data[5..], + ); + } + + #[test] + fn test_plain_skip_all_bool() { + let data = vec![ + false, true, false, false, true, false, true, true, false, true, + ]; + let data_bytes = BoolType::to_byte_array(&data[..]); + test_plain_skip::(ByteBufferPtr::new(data_bytes), 10, 20, -1, &[]); + } + #[test] fn test_plain_decode_byte_array() { let mut data = vec![ByteArray::new(); 2]; @@ -1106,6 +1382,30 @@ mod tests { ); } + #[test] + fn test_plain_skip_byte_array() { + let mut data = vec![ByteArray::new(); 2]; + data[0].set_data(ByteBufferPtr::new(String::from("hello").into_bytes())); + data[1].set_data(ByteBufferPtr::new(String::from("parquet").into_bytes())); + let data_bytes = ByteArrayType::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 2, + 1, + -1, + &data[1..], + ); + } + + #[test] + fn test_plain_skip_all_byte_array() { + let mut data = vec![ByteArray::new(); 2]; + data[0].set_data(ByteBufferPtr::new(String::from("hello").into_bytes())); + data[1].set_data(ByteBufferPtr::new(String::from("parquet").into_bytes())); + let data_bytes = ByteArrayType::to_byte_array(&data[..]); + test_plain_skip::(ByteBufferPtr::new(data_bytes), 2, 2, -1, &[]); + } + #[test] fn test_plain_decode_fixed_len_byte_array() { let mut data = vec![FixedLenByteArray::default(); 3]; @@ -1123,6 +1423,38 @@ mod tests { ); } + #[test] + fn test_plain_skip_fixed_len_byte_array() { + let mut data = vec![FixedLenByteArray::default(); 3]; + data[0].set_data(ByteBufferPtr::new(String::from("bird").into_bytes())); + data[1].set_data(ByteBufferPtr::new(String::from("come").into_bytes())); + data[2].set_data(ByteBufferPtr::new(String::from("flow").into_bytes())); + let data_bytes = FixedLenByteArrayType::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 3, + 1, + 4, + &data[1..], + ); + } + + #[test] + fn test_plain_skip_all_fixed_len_byte_array() { + let mut data = vec![FixedLenByteArray::default(); 3]; + data[0].set_data(ByteBufferPtr::new(String::from("bird").into_bytes())); + data[1].set_data(ByteBufferPtr::new(String::from("come").into_bytes())); + data[2].set_data(ByteBufferPtr::new(String::from("flow").into_bytes())); + let data_bytes = FixedLenByteArrayType::to_byte_array(&data[..]); + test_plain_skip::( + ByteBufferPtr::new(data_bytes), + 3, + 6, + 4, + &[], + ); + } + fn test_plain_decode( data: ByteBufferPtr, num_values: usize, @@ -1139,6 +1471,34 @@ mod tests { assert_eq!(buffer, expected); } + fn test_plain_skip( + data: ByteBufferPtr, + num_values: usize, + skip: usize, + type_length: i32, + expected: &[T::T], + ) { + let mut decoder: PlainDecoder = PlainDecoder::new(type_length); + let result = decoder.set_data(data, num_values); + assert!(result.is_ok()); + let skipped = decoder.skip(skip).expect("skipping values"); + + if skip >= num_values { + assert_eq!(skipped, num_values); + + let mut buffer = vec![T::T::default(); 1]; + let remaining = decoder.get(&mut buffer).expect("getting remaining values"); + assert_eq!(remaining, 0); + } else { + assert_eq!(skipped, skip); + let mut buffer = vec![T::T::default(); num_values - skip]; + let remaining = decoder.get(&mut buffer).expect("getting remaining values"); + assert_eq!(remaining, num_values - skip); + assert_eq!(decoder.values_left(), 0); + 
assert_eq!(buffer, expected); + } + } + fn test_plain_decode_spaced( data: ByteBufferPtr, num_values: usize, @@ -1217,12 +1577,29 @@ mod tests { test_delta_bit_packed_decode::(vec![block_data]); } + #[test] + fn test_skip_delta_bit_packed_int32_repeat() { + let block_data = vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, + 3, 4, 5, 6, 7, 8, + ]; + test_skip::(block_data.clone(), Encoding::DELTA_BINARY_PACKED, 10); + test_skip::(block_data, Encoding::DELTA_BINARY_PACKED, 100); + } + #[test] fn test_delta_bit_packed_int32_uneven() { let block_data = vec![1, -2, 3, -4, 5, 6, 7, 8, 9, 10, 11]; test_delta_bit_packed_decode::(vec![block_data]); } + #[test] + fn test_skip_delta_bit_packed_int32_uneven() { + let block_data = vec![1, -2, 3, -4, 5, 6, 7, 8, 9, 10, 11]; + test_skip::(block_data.clone(), Encoding::DELTA_BINARY_PACKED, 5); + test_skip::(block_data, Encoding::DELTA_BINARY_PACKED, 100); + } + #[test] fn test_delta_bit_packed_int32_same_values() { let block_data = vec![ @@ -1238,21 +1615,54 @@ mod tests { test_delta_bit_packed_decode::(vec![block_data]); } + #[test] + fn test_skip_delta_bit_packed_int32_same_values() { + let block_data = vec![ + 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, 127, + 127, + ]; + test_skip::(block_data.clone(), Encoding::DELTA_BINARY_PACKED, 5); + test_skip::(block_data, Encoding::DELTA_BINARY_PACKED, 100); + + let block_data = vec![ + -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, -127, + -127, -127, -127, + ]; + test_skip::(block_data.clone(), Encoding::DELTA_BINARY_PACKED, 5); + test_skip::(block_data, Encoding::DELTA_BINARY_PACKED, 100); + } + #[test] fn test_delta_bit_packed_int32_min_max() { let block_data = vec![ - i32::min_value(), - i32::max_value(), - i32::min_value(), - i32::max_value(), - i32::min_value(), - i32::max_value(), - i32::min_value(), - i32::max_value(), + i32::MIN, + i32::MIN, + i32::MIN, + i32::MAX, + i32::MIN, + i32::MAX, + i32::MIN, + i32::MAX, ]; test_delta_bit_packed_decode::(vec![block_data]); } + #[test] + fn test_skip_delta_bit_packed_int32_min_max() { + let block_data = vec![ + i32::MIN, + i32::MIN, + i32::MIN, + i32::MAX, + i32::MIN, + i32::MAX, + i32::MIN, + i32::MAX, + ]; + test_skip::(block_data.clone(), Encoding::DELTA_BINARY_PACKED, 5); + test_skip::(block_data, Encoding::DELTA_BINARY_PACKED, 100); + } + #[test] fn test_delta_bit_packed_int32_multiple_blocks() { // Test multiple 'put' calls on the same encoder @@ -1465,8 +1875,7 @@ mod tests { let col_descr = create_test_col_desc_ptr(-1, T::get_physical_type()); // Encode data - let mut encoder = - get_encoder::(col_descr.clone(), encoding).expect("get encoder"); + let mut encoder = get_encoder::(encoding).expect("get encoder"); for v in &data[..] 
{ encoder.put(&v[..]).expect("ok to encode"); @@ -1493,6 +1902,41 @@ mod tests { assert_eq!(result, expected); } + fn test_skip(data: Vec, encoding: Encoding, skip: usize) { + // Type length should not really matter for encode/decode test, + // otherwise change it based on type + let col_descr = create_test_col_desc_ptr(-1, T::get_physical_type()); + + // Encode data + let mut encoder = get_encoder::(encoding).expect("get encoder"); + + encoder.put(&data).expect("ok to encode"); + + let bytes = encoder.flush_buffer().expect("ok to flush buffer"); + + let mut decoder = get_decoder::(col_descr, encoding).expect("get decoder"); + decoder.set_data(bytes, data.len()).expect("ok to set data"); + + if skip >= data.len() { + let skipped = decoder.skip(skip).expect("ok to skip"); + assert_eq!(skipped, data.len()); + + let skipped_again = decoder.skip(skip).expect("ok to skip again"); + assert_eq!(skipped_again, 0); + } else { + let skipped = decoder.skip(skip).expect("ok to skip"); + assert_eq!(skipped, skip); + + let remaining = data.len() - skip; + + let expected = &data[skip..]; + let mut buffer = vec![T::T::default(); remaining]; + let fetched = decoder.get(&mut buffer).expect("ok to decode"); + assert_eq!(remaining, fetched); + assert_eq!(&buffer, expected); + } + } + fn create_and_check_decoder( encoding: Encoding, err: Option, @@ -1560,7 +2004,7 @@ mod tests { v.push(0); } if *item { - set_array_bit(&mut v[..], i); + v[i / 8] |= 1 << (i % 8); } } v diff --git a/parquet/src/encodings/encoding/dict_encoder.rs b/parquet/src/encodings/encoding/dict_encoder.rs new file mode 100644 index 000000000000..18deba65e687 --- /dev/null +++ b/parquet/src/encodings/encoding/dict_encoder.rs @@ -0,0 +1,172 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
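The new module below writes dictionary data pages as a single bit-width byte followed by the RLE/bit-packed dictionary keys. A minimal sketch of that layout, assuming in-crate access to the `RleEncoder` and `num_required_bits` helpers used elsewhere in this patch (the real path through the encoder is `DictEncoder::write_indices`):

use crate::encodings::rle::RleEncoder;
use crate::util::bit_util::num_required_bits;

// Illustrative sketch only: lay out a dictionary data page body the same way
// `write_indices` does in the module below.
fn encode_dict_keys(keys: &[u64], num_entries: usize) -> Vec<u8> {
    // Width needed to address the largest key; 0 or 1 entries need 0 bits.
    let bit_width = num_required_bits(num_entries.saturating_sub(1) as u64);
    let mut buffer = Vec::with_capacity(1 + keys.len());
    buffer.push(bit_width); // the first byte of the page body is the bit width
    let mut encoder = RleEncoder::new_from_buf(bit_width, buffer);
    for &key in keys {
        encoder.put(key); // RLE/bit-packed encoding of the keys
    }
    encoder.consume()
}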
+ +// ---------------------------------------------------------------------- +// Dictionary encoding + +use crate::basic::{Encoding, Type}; +use crate::data_type::private::ParquetValueType; +use crate::data_type::DataType; +use crate::encodings::encoding::{Encoder, PlainEncoder}; +use crate::encodings::rle::RleEncoder; +use crate::errors::Result; +use crate::schema::types::ColumnDescPtr; +use crate::util::bit_util::num_required_bits; +use crate::util::interner::{Interner, Storage}; +use crate::util::memory::ByteBufferPtr; + +#[derive(Debug)] +struct KeyStorage { + uniques: Vec, + + size_in_bytes: usize, + + type_length: usize, +} + +impl Storage for KeyStorage { + type Key = u64; + type Value = T::T; + + fn get(&self, idx: Self::Key) -> &Self::Value { + &self.uniques[idx as usize] + } + + fn push(&mut self, value: &Self::Value) -> Self::Key { + let (base_size, num_elements) = value.dict_encoding_size(); + + let unique_size = match T::get_physical_type() { + Type::BYTE_ARRAY => base_size + num_elements, + Type::FIXED_LEN_BYTE_ARRAY => self.type_length, + _ => base_size, + }; + self.size_in_bytes += unique_size; + + let key = self.uniques.len() as u64; + self.uniques.push(value.clone()); + key + } +} + +/// Dictionary encoder. +/// The dictionary encoding builds a dictionary of values encountered in a given column. +/// The dictionary page is written first, before the data pages of the column chunk. +/// +/// Dictionary page format: the entries in the dictionary - in dictionary order - +/// using the plain encoding. +/// +/// Data page format: the bit width used to encode the entry ids stored as 1 byte +/// (max bit width = 32), followed by the values encoded using RLE/Bit packed described +/// above (with the given bit width). +pub struct DictEncoder { + interner: Interner>, + + /// The buffered indices + indices: Vec, +} + +impl DictEncoder { + /// Creates new dictionary encoder. + pub fn new(desc: ColumnDescPtr) -> Self { + let storage = KeyStorage { + uniques: vec![], + size_in_bytes: 0, + type_length: desc.type_length() as usize, + }; + + Self { + interner: Interner::new(storage), + indices: vec![], + } + } + + /// Returns true if dictionary entries are sorted, false otherwise. + pub fn is_sorted(&self) -> bool { + // Sorting is not supported currently. + false + } + + /// Returns number of unique values (keys) in the dictionary. + pub fn num_entries(&self) -> usize { + self.interner.storage().uniques.len() + } + + /// Returns size of unique values (keys) in the dictionary, in bytes. + pub fn dict_encoded_size(&self) -> usize { + self.interner.storage().size_in_bytes + } + + /// Writes out the dictionary values with PLAIN encoding in a byte buffer, and return + /// the result. + pub fn write_dict(&self) -> Result { + let mut plain_encoder = PlainEncoder::::new(); + plain_encoder.put(&self.interner.storage().uniques)?; + plain_encoder.flush_buffer() + } + + /// Writes out the dictionary values with RLE encoding in a byte buffer, and return + /// the result. 
+ pub fn write_indices(&mut self) -> Result { + let buffer_len = self.estimated_data_encoded_size(); + let mut buffer = Vec::with_capacity(buffer_len); + buffer.push(self.bit_width() as u8); + + // Write bit width in the first byte + let mut encoder = RleEncoder::new_from_buf(self.bit_width(), buffer); + for index in &self.indices { + encoder.put(*index as u64) + } + self.indices.clear(); + Ok(ByteBufferPtr::new(encoder.consume())) + } + + fn put_one(&mut self, value: &T::T) { + self.indices.push(self.interner.intern(value)); + } + + #[inline] + fn bit_width(&self) -> u8 { + num_required_bits(self.num_entries().saturating_sub(1) as u64) + } +} + +impl Encoder for DictEncoder { + fn put(&mut self, values: &[T::T]) -> Result<()> { + self.indices.reserve(values.len()); + for i in values { + self.put_one(i) + } + Ok(()) + } + + // Performance Note: + // As far as can be seen these functions are rarely called and as such we can hint to the + // compiler that they dont need to be folded into hot locations in the final output. + fn encoding(&self) -> Encoding { + Encoding::PLAIN_DICTIONARY + } + + fn estimated_data_encoded_size(&self) -> usize { + let bit_width = self.bit_width(); + 1 + RleEncoder::min_buffer_size(bit_width) + + RleEncoder::max_buffer_size(bit_width, self.indices.len()) + } + + fn flush_buffer(&mut self) -> Result { + self.write_indices() + } +} diff --git a/parquet/src/encodings/encoding.rs b/parquet/src/encodings/encoding/mod.rs similarity index 79% rename from parquet/src/encodings/encoding.rs rename to parquet/src/encodings/encoding/mod.rs index 651635af59ce..34d3bb3d4c75 100644 --- a/parquet/src/encodings/encoding.rs +++ b/parquet/src/encodings/encoding/mod.rs @@ -17,20 +17,22 @@ //! Contains all supported encoders for Parquet. -use std::{cmp, io::Write, marker::PhantomData}; +use std::{cmp, marker::PhantomData}; use crate::basic::*; use crate::data_type::private::ParquetValueType; use crate::data_type::*; use crate::encodings::rle::RleEncoder; use crate::errors::{ParquetError, Result}; -use crate::schema::types::ColumnDescPtr; use crate::util::{ bit_util::{self, num_required_bits, BitWriter}, - hash_util, memory::ByteBufferPtr, }; +pub use dict_encoder::DictEncoder; + +mod dict_encoder; + // ---------------------------------------------------------------------- // Encoders @@ -73,12 +75,9 @@ pub trait Encoder { /// Gets a encoder for the particular data type `T` and encoding `encoding`. Memory usage /// for the encoder instance is tracked by `mem_tracker`. -pub fn get_encoder( - desc: ColumnDescPtr, - encoding: Encoding, -) -> Result>> { +pub fn get_encoder(encoding: Encoding) -> Result>> { let encoder: Box> = match encoding { - Encoding::PLAIN => Box::new(PlainEncoder::new(desc, vec![])), + Encoding::PLAIN => Box::new(PlainEncoder::new()), Encoding::RLE_DICTIONARY | Encoding::PLAIN_DICTIONARY => { return Err(general_err!( "Cannot initialize this encoding through this function" @@ -110,17 +109,21 @@ pub fn get_encoder( pub struct PlainEncoder { buffer: Vec, bit_writer: BitWriter, - desc: ColumnDescPtr, _phantom: PhantomData, } +impl Default for PlainEncoder { + fn default() -> Self { + Self::new() + } +} + impl PlainEncoder { /// Creates new plain encoder. 
- pub fn new(desc: ColumnDescPtr, buffer: Vec) -> Self { + pub fn new() -> Self { Self { - buffer, + buffer: vec![], bit_writer: BitWriter::new(256), - desc, _phantom: PhantomData, } } @@ -154,225 +157,6 @@ impl Encoder for PlainEncoder { } } -// ---------------------------------------------------------------------- -// Dictionary encoding - -const INITIAL_HASH_TABLE_SIZE: usize = 1024; -const MAX_HASH_LOAD: f32 = 0.7; -const HASH_SLOT_EMPTY: i32 = -1; - -/// Dictionary encoder. -/// The dictionary encoding builds a dictionary of values encountered in a given column. -/// The dictionary page is written first, before the data pages of the column chunk. -/// -/// Dictionary page format: the entries in the dictionary - in dictionary order - -/// using the plain encoding. -/// -/// Data page format: the bit width used to encode the entry ids stored as 1 byte -/// (max bit width = 32), followed by the values encoded using RLE/Bit packed described -/// above (with the given bit width). -pub struct DictEncoder { - // Descriptor for the column to be encoded. - desc: ColumnDescPtr, - - // Size of the table. **Must be** a power of 2. - hash_table_size: usize, - - // Store `hash_table_size` - 1, so that `j & mod_bitmask` is equivalent to - // `j % hash_table_size`, but uses far fewer CPU cycles. - mod_bitmask: u32, - - // Stores indices which map (many-to-one) to the values in the `uniques` array. - // Here we are using fix-sized array with linear probing. - // A slot with `HASH_SLOT_EMPTY` indicates the slot is not currently occupied. - hash_slots: Vec, - - // Indices that have not yet be written out by `write_indices()`. - buffered_indices: Vec, - - // The unique observed values. - uniques: Vec, - - // Size in bytes needed to encode this dictionary. - uniques_size_in_bytes: usize, -} - -impl DictEncoder { - /// Creates new dictionary encoder. - pub fn new(desc: ColumnDescPtr) -> Self { - let mut slots = vec![]; - slots.resize(INITIAL_HASH_TABLE_SIZE, -1); - Self { - desc, - hash_table_size: INITIAL_HASH_TABLE_SIZE, - mod_bitmask: (INITIAL_HASH_TABLE_SIZE - 1) as u32, - hash_slots: slots, - buffered_indices: vec![], - uniques: vec![], - uniques_size_in_bytes: 0, - } - } - - /// Returns true if dictionary entries are sorted, false otherwise. - #[inline] - pub fn is_sorted(&self) -> bool { - // Sorting is not supported currently. - false - } - - /// Returns number of unique values (keys) in the dictionary. - pub fn num_entries(&self) -> usize { - self.uniques.len() - } - - /// Returns size of unique values (keys) in the dictionary, in bytes. - pub fn dict_encoded_size(&self) -> usize { - self.uniques_size_in_bytes - } - - /// Writes out the dictionary values with PLAIN encoding in a byte buffer, and return - /// the result. - #[inline] - pub fn write_dict(&self) -> Result { - let mut plain_encoder = PlainEncoder::::new(self.desc.clone(), vec![]); - plain_encoder.put(&self.uniques)?; - plain_encoder.flush_buffer() - } - - /// Writes out the dictionary values with RLE encoding in a byte buffer, and return - /// the result. - pub fn write_indices(&mut self) -> Result { - // TODO: the caller should allocate the buffer - let buffer_len = self.estimated_data_encoded_size(); - let mut buffer: Vec = vec![0; buffer_len as usize]; - buffer[0] = self.bit_width() as u8; - - // Write bit width in the first byte - buffer.write_all((self.bit_width() as u8).as_bytes())?; - let mut encoder = RleEncoder::new_from_buf(self.bit_width(), buffer, 1); - for index in &self.buffered_indices { - if !encoder.put(*index as u64)? 
{ - return Err(general_err!("Encoder doesn't have enough space")); - } - } - self.buffered_indices.clear(); - Ok(ByteBufferPtr::new(encoder.consume()?)) - } - - #[inline] - #[allow(clippy::unnecessary_wraps)] - fn put_one(&mut self, value: &T::T) -> Result<()> { - let mut j = (hash_util::hash(value, 0) & self.mod_bitmask) as usize; - let mut index = self.hash_slots[j]; - - while index != HASH_SLOT_EMPTY && self.uniques[index as usize] != *value { - j += 1; - if j == self.hash_table_size { - j = 0; - } - index = self.hash_slots[j]; - } - - if index == HASH_SLOT_EMPTY { - index = self.insert_fresh_slot(j, value.clone()); - } - - self.buffered_indices.push(index); - Ok(()) - } - - #[inline(never)] - fn insert_fresh_slot(&mut self, slot: usize, value: T::T) -> i32 { - let index = self.uniques.len() as i32; - self.hash_slots[slot] = index; - - let (base_size, num_elements) = value.dict_encoding_size(); - - let unique_size = match T::get_physical_type() { - Type::BYTE_ARRAY => base_size + num_elements, - Type::FIXED_LEN_BYTE_ARRAY => self.desc.type_length() as usize, - _ => base_size, - }; - - self.uniques_size_in_bytes += unique_size; - self.uniques.push(value); - - if self.uniques.len() > (self.hash_table_size as f32 * MAX_HASH_LOAD) as usize { - self.double_table_size(); - } - - index - } - - #[inline] - fn bit_width(&self) -> u8 { - let num_entries = self.uniques.len(); - if num_entries <= 1 { - num_entries as u8 - } else { - num_required_bits(num_entries as u64 - 1) - } - } - - fn double_table_size(&mut self) { - let new_size = self.hash_table_size * 2; - let mut new_hash_slots = vec![]; - new_hash_slots.resize(new_size, HASH_SLOT_EMPTY); - for i in 0..self.hash_table_size { - let index = self.hash_slots[i]; - if index == HASH_SLOT_EMPTY { - continue; - } - let value = &self.uniques[index as usize]; - let mut j = (hash_util::hash(value, 0) & ((new_size - 1) as u32)) as usize; - let mut slot = new_hash_slots[j]; - while slot != HASH_SLOT_EMPTY && self.uniques[slot as usize] != *value { - j += 1; - if j == new_size { - j = 0; - } - slot = new_hash_slots[j]; - } - - new_hash_slots[j] = index; - } - - self.hash_table_size = new_size; - self.mod_bitmask = (new_size - 1) as u32; - self.hash_slots = new_hash_slots; - } -} - -impl Encoder for DictEncoder { - #[inline] - fn put(&mut self, values: &[T::T]) -> Result<()> { - for i in values { - self.put_one(i)? - } - Ok(()) - } - - // Performance Note: - // As far as can be seen these functions are rarely called and as such we can hint to the - // compiler that they dont need to be folded into hot locations in the final output. - #[cold] - fn encoding(&self) -> Encoding { - Encoding::PLAIN_DICTIONARY - } - - #[inline] - fn estimated_data_encoded_size(&self) -> usize { - let bit_width = self.bit_width(); - 1 + RleEncoder::min_buffer_size(bit_width) - + RleEncoder::max_buffer_size(bit_width, self.buffered_indices.len()) - } - - #[inline] - fn flush_buffer(&mut self) -> Result { - self.write_indices() - } -} - // ---------------------------------------------------------------------- // RLE encoding @@ -387,6 +171,12 @@ pub struct RleValueEncoder { _phantom: PhantomData, } +impl Default for RleValueEncoder { + fn default() -> Self { + Self::new() + } +} + impl RleValueEncoder { /// Creates new rle value encoder. 
pub fn new() -> Self { @@ -402,15 +192,16 @@ impl Encoder for RleValueEncoder { fn put(&mut self, values: &[T::T]) -> Result<()> { ensure_phys_ty!(Type::BOOLEAN, "RleValueEncoder only supports BoolType"); - if self.encoder.is_none() { - self.encoder = Some(RleEncoder::new(1, DEFAULT_RLE_BUFFER_LEN)); - } - let rle_encoder = self.encoder.as_mut().unwrap(); + let rle_encoder = self.encoder.get_or_insert_with(|| { + let mut buffer = Vec::with_capacity(DEFAULT_RLE_BUFFER_LEN); + // Reserve space for length + buffer.extend_from_slice(&[0; 4]); + RleEncoder::new_from_buf(1, buffer) + }); + for value in values { let value = value.as_u64()?; - if !rle_encoder.put(value)? { - return Err(general_err!("RLE buffer is full")); - } + rle_encoder.put(value) } Ok(()) } @@ -436,25 +227,18 @@ impl Encoder for RleValueEncoder { ensure_phys_ty!(Type::BOOLEAN, "RleValueEncoder only supports BoolType"); let rle_encoder = self .encoder - .as_mut() + .take() .expect("RLE value encoder is not initialized"); // Flush all encoder buffers and raw values - let encoded_data = { - let buf = rle_encoder.flush_buffer()?; - - // Note that buf does not have any offset, all data is encoded bytes - let len = (buf.len() as i32).to_le(); - let len_bytes = len.as_bytes(); - let mut encoded_data = vec![]; - encoded_data.extend_from_slice(len_bytes); - encoded_data.extend_from_slice(buf); - encoded_data - }; - // Reset rle encoder for the next batch - rle_encoder.clear(); + let mut buf = rle_encoder.consume(); + assert!(buf.len() >= 4, "should have had padding inserted"); + + // Note that buf does not have any offset, all data is encoded bytes + let len = (buf.len() - 4) as i32; + buf[..4].copy_from_slice(&len.to_le_bytes()); - Ok(ByteBufferPtr::new(encoded_data)) + Ok(ByteBufferPtr::new(buf)) } } @@ -463,7 +247,6 @@ impl Encoder for RleValueEncoder { const MAX_PAGE_HEADER_WRITER_SIZE: usize = 32; const MAX_BIT_WRITER_SIZE: usize = 10 * 1024 * 1024; -const DEFAULT_BLOCK_SIZE: usize = 128; const DEFAULT_NUM_MINI_BLOCKS: usize = 4; /// Delta bit packed encoder. @@ -503,15 +286,28 @@ pub struct DeltaBitPackEncoder { _phantom: PhantomData, } +impl Default for DeltaBitPackEncoder { + fn default() -> Self { + Self::new() + } +} + impl DeltaBitPackEncoder { /// Creates new delta bit packed encoder. 
pub fn new() -> Self { - let block_size = DEFAULT_BLOCK_SIZE; - let num_mini_blocks = DEFAULT_NUM_MINI_BLOCKS; - let mini_block_size = block_size / num_mini_blocks; - assert!(mini_block_size % 8 == 0); Self::assert_supported_type(); + // Size miniblocks so that they can be efficiently decoded + let mini_block_size = match T::T::PHYSICAL_TYPE { + Type::INT32 => 32, + Type::INT64 => 64, + _ => unreachable!(), + }; + + let num_mini_blocks = DEFAULT_NUM_MINI_BLOCKS; + let block_size = mini_block_size * num_mini_blocks; + assert_eq!(block_size % 128, 0); + DeltaBitPackEncoder { page_header_writer: BitWriter::new(MAX_PAGE_HEADER_WRITER_SIZE), bit_writer: BitWriter::new(MAX_BIT_WRITER_SIZE), @@ -562,7 +358,7 @@ impl DeltaBitPackEncoder { self.bit_writer.put_zigzag_vlq_int(min_delta); // Slice to store bit width for each mini block - let offset = self.bit_writer.skip(self.num_mini_blocks)?; + let offset = self.bit_writer.skip(self.num_mini_blocks); for i in 0..self.num_mini_blocks { // Find how many values we need to encode - either block size or whatever @@ -580,14 +376,15 @@ impl DeltaBitPackEncoder { } // Compute the max delta in current mini block - let mut max_delta = i64::min_value(); + let mut max_delta = i64::MIN; for j in 0..n { max_delta = cmp::max(max_delta, self.deltas[i * self.mini_block_size + j]); } // Compute bit width to store (max_delta - min_delta) - let bit_width = num_required_bits(self.subtract_u64(max_delta, min_delta)) as usize; + let bit_width = + num_required_bits(self.subtract_u64(max_delta, min_delta)) as usize; self.bit_writer.write_at(offset + i, bit_width as u8); // Encode values in current mini block using min_delta and bit_width @@ -746,6 +543,12 @@ pub struct DeltaLengthByteArrayEncoder { _phantom: PhantomData, } +impl Default for DeltaLengthByteArrayEncoder { + fn default() -> Self { + Self::new() + } +} + impl DeltaLengthByteArrayEncoder { /// Creates new delta length byte array encoder. pub fn new() -> Self { @@ -825,6 +628,12 @@ pub struct DeltaByteArrayEncoder { _phantom: PhantomData, } +impl Default for DeltaByteArrayEncoder { + fn default() -> Self { + Self::new() + } +} + impl DeltaByteArrayEncoder { /// Creates new delta byte array encoder. 
pub fn new() -> Self { @@ -920,7 +729,7 @@ mod tests { use crate::schema::types::{ ColumnDescPtr, ColumnDescriptor, ColumnPath, Type as SchemaType, }; - use crate::util::test_common::{random_bytes, RandGen}; + use crate::util::test_common::rand_gen::{random_bytes, RandGen}; const TEST_SET_SIZE: usize = 1024; @@ -1034,7 +843,7 @@ mod tests { run_test::( -1, &[Int96::from(vec![1, 2, 3]), Int96::from(vec![2, 3, 4])], - 32, + 24, ); run_test::( -1, @@ -1062,7 +871,7 @@ mod tests { Encoding::PLAIN_DICTIONARY | Encoding::RLE_DICTIONARY => { Box::new(create_test_dict_encoder::(type_length)) } - _ => create_test_encoder::(type_length, encoding), + _ => create_test_encoder::(encoding), }; assert_eq!(encoder.estimated_data_encoded_size(), initial_size); @@ -1088,7 +897,7 @@ mod tests { let mut values = vec![]; values.extend_from_slice(&[true; 16]); values.extend_from_slice(&[false; 16]); - run_test::(Encoding::RLE, -1, &values, 0, 2, 0); + run_test::(Encoding::RLE, -1, &values, 0, 6, 0); // DELTA_LENGTH_BYTE_ARRAY run_test::( @@ -1115,7 +924,7 @@ mod tests { #[test] fn test_issue_47() { let mut encoder = - create_test_encoder::(0, Encoding::DELTA_BYTE_ARRAY); + create_test_encoder::(Encoding::DELTA_BYTE_ARRAY); let mut decoder = create_test_decoder::(0, Encoding::DELTA_BYTE_ARRAY); @@ -1167,7 +976,7 @@ mod tests { impl> EncodingTester for T { fn test_internal(enc: Encoding, total: usize, type_length: i32) -> Result<()> { - let mut encoder = create_test_encoder::(type_length, enc); + let mut encoder = create_test_encoder::(enc); let mut decoder = create_test_decoder::(type_length, enc); let mut values = >::gen_vec(type_length, total); let mut result_data = vec![T::T::default(); total]; @@ -1269,8 +1078,7 @@ mod tests { encoding: Encoding, err: Option, ) { - let descr = create_test_col_desc_ptr(-1, T::get_physical_type()); - let encoder = get_encoder::(descr, encoding); + let encoder = get_encoder::(encoding); match err { Some(parquet_error) => { assert!(encoder.is_err()); @@ -1297,12 +1105,8 @@ mod tests { )) } - fn create_test_encoder( - type_len: i32, - enc: Encoding, - ) -> Box> { - let desc = create_test_col_desc_ptr(type_len, T::get_physical_type()); - get_encoder(desc, enc).unwrap() + fn create_test_encoder(enc: Encoding) -> Box> { + get_encoder(enc).unwrap() } fn create_test_decoder( diff --git a/parquet/src/encodings/levels.rs b/parquet/src/encodings/levels.rs index 28fb63881693..95384926ddba 100644 --- a/parquet/src/encodings/levels.rs +++ b/parquet/src/encodings/levels.rs @@ -21,9 +21,9 @@ use super::rle::{RleDecoder, RleEncoder}; use crate::basic::Encoding; use crate::data_type::AsBytes; -use crate::errors::{ParquetError, Result}; +use crate::errors::Result; use crate::util::{ - bit_util::{ceil, num_required_bits, BitReader, BitWriter}, + bit_util::{ceil, num_required_bits, read_num_bytes, BitReader, BitWriter}, memory::ByteBufferPtr, }; @@ -65,22 +65,21 @@ impl LevelEncoder { /// Used to encode levels for Data Page v1. /// /// Panics, if encoding is not supported. 
- pub fn v1(encoding: Encoding, max_level: i16, byte_buffer: Vec) -> Self { + pub fn v1(encoding: Encoding, max_level: i16, capacity: usize) -> Self { + let capacity_bytes = max_buffer_size(encoding, max_level, capacity); + let mut buffer = Vec::with_capacity(capacity_bytes); let bit_width = num_required_bits(max_level as u64); match encoding { - Encoding::RLE => LevelEncoder::Rle(RleEncoder::new_from_buf( - bit_width, - byte_buffer, - mem::size_of::(), - )), + Encoding::RLE => { + // Reserve space for length header + buffer.extend_from_slice(&[0; 4]); + LevelEncoder::Rle(RleEncoder::new_from_buf(bit_width, buffer)) + } Encoding::BIT_PACKED => { // Here we set full byte buffer without adjusting for num_buffered_values, // because byte buffer will already be allocated with size from // `max_buffer_size()` method. - LevelEncoder::BitPacked( - bit_width, - BitWriter::new_from_buf(byte_buffer, 0), - ) + LevelEncoder::BitPacked(bit_width, BitWriter::new_from_buf(buffer)) } _ => panic!("Unsupported encoding type {}", encoding), } @@ -88,59 +87,54 @@ impl LevelEncoder { /// Creates new level encoder based on RLE encoding. Used to encode Data Page v2 /// repetition and definition levels. - pub fn v2(max_level: i16, byte_buffer: Vec) -> Self { + pub fn v2(max_level: i16, capacity: usize) -> Self { + let capacity_bytes = max_buffer_size(Encoding::RLE, max_level, capacity); + let buffer = Vec::with_capacity(capacity_bytes); let bit_width = num_required_bits(max_level as u64); - LevelEncoder::RleV2(RleEncoder::new_from_buf(bit_width, byte_buffer, 0)) + LevelEncoder::RleV2(RleEncoder::new_from_buf(bit_width, buffer)) } /// Put/encode levels vector into this level encoder. /// Returns number of encoded values that are less than or equal to length of the /// input buffer. - /// - /// RLE and BIT_PACKED level encoders return Err() when internal buffer overflows or - /// flush fails. #[inline] - pub fn put(&mut self, buffer: &[i16]) -> Result { + pub fn put(&mut self, buffer: &[i16]) -> usize { let mut num_encoded = 0; match *self { LevelEncoder::Rle(ref mut encoder) | LevelEncoder::RleV2(ref mut encoder) => { for value in buffer { - if !encoder.put(*value as u64)? { - return Err(general_err!("RLE buffer is full")); - } + encoder.put(*value as u64); num_encoded += 1; } - encoder.flush()?; + encoder.flush(); } LevelEncoder::BitPacked(bit_width, ref mut encoder) => { for value in buffer { - if !encoder.put_value(*value as u64, bit_width as usize) { - return Err(general_err!("Not enough bytes left")); - } + encoder.put_value(*value as u64, bit_width as usize); num_encoded += 1; } encoder.flush(); } } - Ok(num_encoded) + num_encoded } /// Finalizes level encoder, flush all intermediate buffers and return resulting /// encoded buffer. Returned buffer is already truncated to encoded bytes only. 
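A minimal round-trip sketch of the capacity-based constructors and the infallible `put`/`consume` pair documented above, mirroring the updated tests later in this file; `LevelEncoder` is crate-internal, so this assumes in-crate access:

use crate::basic::Encoding;
use crate::encodings::levels::LevelEncoder;

fn encode_levels_sketch() -> Vec<u8> {
    let levels: Vec<i16> = vec![0, 1, 1, 2, 2, 2, 1, 0];
    let max_level: i16 = 2;
    // v1 sizes its own buffer from the value count instead of taking a Vec.
    let mut encoder = LevelEncoder::v1(Encoding::RLE, max_level, levels.len());
    let num_encoded = encoder.put(&levels); // returns how many levels were encoded
    assert_eq!(num_encoded, levels.len());
    // For RLE in v1 pages the first four bytes are a little-endian length
    // header that `consume` fills in before returning the buffer.
    encoder.consume()
}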
#[inline] - pub fn consume(self) -> Result> { + pub fn consume(self) -> Vec { match self { LevelEncoder::Rle(encoder) => { - let mut encoded_data = encoder.consume()?; + let mut encoded_data = encoder.consume(); // Account for the buffer offset let encoded_len = encoded_data.len() - mem::size_of::(); let len = (encoded_len as i32).to_le(); let len_bytes = len.as_bytes(); encoded_data[0..len_bytes.len()].copy_from_slice(len_bytes); - Ok(encoded_data) + encoded_data } LevelEncoder::RleV2(encoder) => encoder.consume(), - LevelEncoder::BitPacked(_, encoder) => Ok(encoder.consume()), + LevelEncoder::BitPacked(_, encoder) => encoder.consume(), } } } @@ -148,12 +142,14 @@ impl LevelEncoder { /// Decoder for definition/repetition levels. /// Currently only supports RLE and BIT_PACKED encoding for Data Page v1 and /// RLE for Data Page v2. +#[allow(unused)] pub enum LevelDecoder { Rle(Option, RleDecoder), RleV2(Option, RleDecoder), BitPacked(Option, u8, BitReader), } +#[allow(unused)] impl LevelDecoder { /// Creates new level decoder based on encoding and max definition/repetition level. /// This method only initializes level decoder, `set_data` method must be called @@ -196,7 +192,7 @@ impl LevelDecoder { LevelDecoder::Rle(ref mut num_values, ref mut decoder) => { *num_values = Some(num_buffered_values); let i32_size = mem::size_of::(); - let data_size = read_num_bytes!(i32, i32_size, data.as_ref()) as usize; + let data_size = read_num_bytes::(i32_size, data.as_ref()) as usize; decoder.set_data(data.range(i32_size, data_size)); i32_size + data_size } @@ -280,17 +276,16 @@ impl LevelDecoder { mod tests { use super::*; - use crate::util::test_common::random_numbers_range; + use crate::util::test_common::rand_gen::random_numbers_range; fn test_internal_roundtrip(enc: Encoding, levels: &[i16], max_level: i16, v2: bool) { - let size = max_buffer_size(enc, max_level, levels.len()); let mut encoder = if v2 { - LevelEncoder::v2(max_level, vec![0; size]) + LevelEncoder::v2(max_level, levels.len()) } else { - LevelEncoder::v1(enc, max_level, vec![0; size]) + LevelEncoder::v1(enc, max_level, levels.len()) }; - encoder.put(levels).expect("put() should be OK"); - let encoded_levels = encoder.consume().expect("consume() should be OK"); + encoder.put(levels); + let encoded_levels = encoder.consume(); let byte_buf = ByteBufferPtr::new(encoded_levels); let mut decoder; @@ -315,14 +310,13 @@ mod tests { max_level: i16, v2: bool, ) { - let size = max_buffer_size(enc, max_level, levels.len()); let mut encoder = if v2 { - LevelEncoder::v2(max_level, vec![0; size]) + LevelEncoder::v2(max_level, levels.len()) } else { - LevelEncoder::v1(enc, max_level, vec![0; size]) + LevelEncoder::v1(enc, max_level, levels.len()) }; - encoder.put(levels).expect("put() should be OK"); - let encoded_levels = encoder.consume().expect("consume() should be OK"); + encoder.put(levels); + let encoded_levels = encoder.consume(); let byte_buf = ByteBufferPtr::new(encoded_levels); let mut decoder; @@ -363,15 +357,14 @@ mod tests { max_level: i16, v2: bool, ) { - let size = max_buffer_size(enc, max_level, levels.len()); let mut encoder = if v2 { - LevelEncoder::v2(max_level, vec![0; size]) + LevelEncoder::v2(max_level, levels.len()) } else { - LevelEncoder::v1(enc, max_level, vec![0; size]) + LevelEncoder::v1(enc, max_level, levels.len()) }; // Encode only one value - let num_encoded = encoder.put(&levels[0..1]).expect("put() should be OK"); - let encoded_levels = encoder.consume().expect("consume() should be OK"); + let num_encoded = 
encoder.put(&levels[0..1]); + let encoded_levels = encoder.consume(); assert_eq!(num_encoded, 1); let byte_buf = ByteBufferPtr::new(encoded_levels); @@ -391,33 +384,6 @@ mod tests { assert_eq!(buffer[0..num_decoded], levels[0..num_decoded]); } - // Tests when encoded values are larger than encoder's buffer - fn test_internal_roundtrip_overflow( - enc: Encoding, - levels: &[i16], - max_level: i16, - v2: bool, - ) { - let size = max_buffer_size(enc, max_level, levels.len()); - let mut encoder = if v2 { - LevelEncoder::v2(max_level, vec![0; size]) - } else { - LevelEncoder::v1(enc, max_level, vec![0; size]) - }; - let mut found_err = false; - // Insert a large number of values, so we run out of space - for _ in 0..100 { - if let Err(err) = encoder.put(levels) { - assert!(format!("{}", err).contains("Not enough bytes left")); - found_err = true; - break; - }; - } - if !found_err { - panic!("Failed test: no buffer overflow"); - } - } - #[test] fn test_roundtrip_one() { let levels = vec![0, 1, 1, 1, 1, 0, 0, 0, 0, 1]; @@ -470,6 +436,15 @@ mod tests { test_internal_roundtrip(Encoding::RLE, &levels, max_level, true); } + #[test] + fn test_rountrip_max() { + let levels = vec![0, i16::MAX, i16::MAX, i16::MAX, 0]; + let max_level = i16::MAX; + test_internal_roundtrip(Encoding::RLE, &levels, max_level, false); + test_internal_roundtrip(Encoding::BIT_PACKED, &levels, max_level, false); + test_internal_roundtrip(Encoding::RLE, &levels, max_level, true); + } + #[test] fn test_roundtrip_underflow() { let levels = vec![1, 1, 2, 3, 2, 1, 1, 2, 3, 1]; @@ -484,15 +459,6 @@ mod tests { test_internal_roundtrip_underflow(Encoding::RLE, &levels, max_level, true); } - #[test] - fn test_roundtrip_overflow() { - let levels = vec![1, 1, 2, 3, 2, 1, 1, 2, 3, 1]; - let max_level = 3; - test_internal_roundtrip_overflow(Encoding::RLE, &levels, max_level, false); - test_internal_roundtrip_overflow(Encoding::BIT_PACKED, &levels, max_level, false); - test_internal_roundtrip_overflow(Encoding::RLE, &levels, max_level, true); - } - #[test] fn test_rle_decoder_set_data_range() { // Buffer containing both repetition and definition levels diff --git a/parquet/src/encodings/mod.rs b/parquet/src/encodings/mod.rs index 9577a8e624f6..894c4fb961ee 100644 --- a/parquet/src/encodings/mod.rs +++ b/parquet/src/encodings/mod.rs @@ -18,4 +18,4 @@ pub mod decoding; pub mod encoding; pub mod levels; -experimental_mod_crate!(rle); +experimental!(pub(crate) mod rle); diff --git a/parquet/src/encodings/rle.rs b/parquet/src/encodings/rle.rs index 5f6f91a8bd0a..39a0aa4d03da 100644 --- a/parquet/src/encodings/rle.rs +++ b/parquet/src/encodings/rle.rs @@ -45,7 +45,6 @@ use crate::util::{ /// Maximum groups per bit-packed run. Current value is 64. const MAX_GROUPS_PER_BIT_PACKED_RUN: usize = 1 << 6; const MAX_VALUES_PER_BIT_PACKED_RUN: usize = MAX_GROUPS_PER_BIT_PACKED_RUN * 8; -const MAX_WRITER_BUF_SIZE: usize = 1 << 10; /// A RLE/Bit-Packing hybrid encoder. // TODO: tracking memory usage @@ -56,9 +55,6 @@ pub struct RleEncoder { // Underlying writer which holds an internal buffer. bit_writer: BitWriter, - // The maximum byte size a single run can take. - max_run_byte_size: usize, - // Buffered values for bit-packed runs. 
buffered_values: [u64; 8], @@ -82,26 +78,18 @@ pub struct RleEncoder { } impl RleEncoder { + #[allow(unused)] pub fn new(bit_width: u8, buffer_len: usize) -> Self { - let buffer = vec![0; buffer_len]; - RleEncoder::new_from_buf(bit_width, buffer, 0) - } - - /// Initialize the encoder from existing `buffer` and the starting offset `start`. - pub fn new_from_buf(bit_width: u8, buffer: Vec, start: usize) -> Self { - assert!(bit_width <= 64, "bit_width ({}) out of range.", bit_width); - let max_run_byte_size = RleEncoder::min_buffer_size(bit_width); - assert!( - buffer.len() >= max_run_byte_size, - "buffer length {} must be greater than {}", - buffer.len(), - max_run_byte_size - ); - let bit_writer = BitWriter::new_from_buf(buffer, start); + let buffer = Vec::with_capacity(buffer_len); + RleEncoder::new_from_buf(bit_width, buffer) + } + + /// Initialize the encoder from existing `buffer` + pub fn new_from_buf(bit_width: u8, buffer: Vec) -> Self { + let bit_writer = BitWriter::new_from_buf(buffer); RleEncoder { bit_width, bit_writer, - max_run_byte_size, buffered_values: [0; 8], num_buffered_values: 0, current_value: 0, @@ -139,23 +127,21 @@ impl RleEncoder { } /// Encodes `value`, which must be representable with `bit_width` bits. - /// Returns true if the value fits in buffer, false if it doesn't, or - /// error if something is wrong. #[inline] - pub fn put(&mut self, value: u64) -> Result { + pub fn put(&mut self, value: u64) { // This function buffers 8 values at a time. After seeing 8 values, it // decides whether the current run should be encoded in bit-packed or RLE. if self.current_value == value { self.repeat_count += 1; if self.repeat_count > 8 { // A continuation of last value. No need to buffer. - return Ok(true); + return; } } else { if self.repeat_count >= 8 { // The current RLE run has ended and we've gathered enough. Flush first. assert_eq!(self.bit_packed_count, 0); - self.flush_rle_run()?; + self.flush_rle_run(); } self.repeat_count = 1; self.current_value = value; @@ -166,13 +152,12 @@ impl RleEncoder { if self.num_buffered_values == 8 { // Buffered values are full. Flush them. assert_eq!(self.bit_packed_count % 8, 0); - self.flush_buffered_values()?; + self.flush_buffered_values(); } - - Ok(true) } #[inline] + #[allow(unused)] pub fn buffer(&self) -> &[u8] { self.bit_writer.buffer() } @@ -182,23 +167,30 @@ impl RleEncoder { self.bit_writer.bytes_written() } + #[allow(unused)] + pub fn is_empty(&self) -> bool { + self.bit_writer.bytes_written() == 0 + } + #[inline] - pub fn consume(mut self) -> Result> { - self.flush()?; - Ok(self.bit_writer.consume()) + pub fn consume(mut self) -> Vec { + self.flush(); + self.bit_writer.consume() } /// Borrow equivalent of the `consume` method. /// Call `clear()` after invoking this method. #[inline] - pub fn flush_buffer(&mut self) -> Result<&[u8]> { - self.flush()?; - Ok(self.bit_writer.flush_buffer()) + #[allow(unused)] + pub fn flush_buffer(&mut self) -> &[u8] { + self.flush(); + self.bit_writer.flush_buffer() } /// Clears the internal state so this encoder can be reused (e.g., after becoming /// full). #[inline] + #[allow(unused)] pub fn clear(&mut self) { self.bit_writer.clear(); self.num_buffered_values = 0; @@ -211,7 +203,7 @@ impl RleEncoder { /// Flushes all remaining values and return the final byte buffer maintained by the /// internal writer. 
#[inline] - pub fn flush(&mut self) -> Result<()> { + pub fn flush(&mut self) { if self.bit_packed_count > 0 || self.repeat_count > 0 || self.num_buffered_values > 0 @@ -220,7 +212,7 @@ impl RleEncoder { && (self.repeat_count == self.num_buffered_values || self.num_buffered_values == 0); if self.repeat_count > 0 && all_repeat { - self.flush_rle_run()?; + self.flush_rle_run(); } else { // Buffer the last group of bit-packed values to 8 by padding with 0s. if self.num_buffered_values > 0 { @@ -230,38 +222,32 @@ impl RleEncoder { } } self.bit_packed_count += self.num_buffered_values; - self.flush_bit_packed_run(true)?; + self.flush_bit_packed_run(true); self.repeat_count = 0; } } - Ok(()) } - fn flush_rle_run(&mut self) -> Result<()> { + fn flush_rle_run(&mut self) { assert!(self.repeat_count > 0); let indicator_value = self.repeat_count << 1; - let mut result = self.bit_writer.put_vlq_int(indicator_value as u64); - result &= self.bit_writer.put_aligned( + self.bit_writer.put_vlq_int(indicator_value as u64); + self.bit_writer.put_aligned( self.current_value, bit_util::ceil(self.bit_width as i64, 8) as usize, ); - if !result { - return Err(general_err!("Failed to write RLE run")); - } self.num_buffered_values = 0; self.repeat_count = 0; - Ok(()) } - fn flush_bit_packed_run(&mut self, update_indicator_byte: bool) -> Result<()> { + fn flush_bit_packed_run(&mut self, update_indicator_byte: bool) { if self.indicator_byte_pos < 0 { - self.indicator_byte_pos = self.bit_writer.skip(1)? as i64; + self.indicator_byte_pos = self.bit_writer.skip(1) as i64; } // Write all buffered values as bit-packed literals for i in 0..self.num_buffered_values { - let _ = self - .bit_writer + self.bit_writer .put_value(self.buffered_values[i], self.bit_width as usize); } self.num_buffered_values = 0; @@ -269,30 +255,27 @@ impl RleEncoder { // Write the indicator byte to the reserved position in `bit_writer` let num_groups = self.bit_packed_count / 8; let indicator_byte = ((num_groups << 1) | 1) as u8; - if !self.bit_writer.put_aligned_offset( + self.bit_writer.put_aligned_offset( indicator_byte, 1, self.indicator_byte_pos as usize, - ) { - return Err(general_err!("Not enough space to write indicator byte")); - } + ); self.indicator_byte_pos = -1; self.bit_packed_count = 0; } - Ok(()) } #[inline(never)] - fn flush_buffered_values(&mut self) -> Result<()> { + fn flush_buffered_values(&mut self) { if self.repeat_count >= 8 { self.num_buffered_values = 0; if self.bit_packed_count > 0 { // In this case we choose RLE encoding. Flush the current buffered values // as bit-packed encoding. assert_eq!(self.bit_packed_count % 8, 0); - self.flush_bit_packed_run(true)? + self.flush_bit_packed_run(true) } - return Ok(()); + return; } self.bit_packed_count += self.num_buffered_values; @@ -301,12 +284,11 @@ impl RleEncoder { // We've reached the maximum value that can be hold in a single bit-packed // run. 
assert!(self.indicator_byte_pos >= 0); - self.flush_bit_packed_run(true)?; + self.flush_bit_packed_run(true); } else { - self.flush_bit_packed_run(false)?; + self.flush_bit_packed_run(false); } self.repeat_count = 0; - Ok(()) } } @@ -434,6 +416,37 @@ impl RleDecoder { Ok(values_read) } + #[inline(never)] + pub fn skip(&mut self, num_values: usize) -> Result { + let mut values_skipped = 0; + while values_skipped < num_values { + if self.rle_left > 0 { + let num_values = + cmp::min(num_values - values_skipped, self.rle_left as usize); + self.rle_left -= num_values as u32; + values_skipped += num_values; + } else if self.bit_packed_left > 0 { + let mut num_values = + cmp::min(num_values - values_skipped, self.bit_packed_left as usize); + let bit_reader = + self.bit_reader.as_mut().expect("bit_reader should be set"); + + num_values = bit_reader.skip(num_values, self.bit_width as usize); + if num_values == 0 { + // Handle writers which truncate the final block + self.bit_packed_left = 0; + continue; + } + self.bit_packed_left -= num_values as u32; + values_skipped += num_values; + } else if !self.reload() { + break; + } + } + + Ok(values_skipped) + } + #[inline(never)] pub fn get_batch_with_dict( &mut self, @@ -538,17 +551,36 @@ mod tests { assert_eq!(buffer, expected); } + #[test] + fn test_rle_skip_int32() { + // Test data: 0-7 with bit width 3 + // 00000011 10001000 11000110 11111010 + let data = ByteBufferPtr::new(vec![0x03, 0x88, 0xC6, 0xFA]); + let mut decoder: RleDecoder = RleDecoder::new(3); + decoder.set_data(data); + let expected = vec![2, 3, 4, 5, 6, 7]; + let skipped = decoder.skip(2).expect("skipping values"); + assert_eq!(skipped, 2); + + let mut buffer = vec![0; 6]; + let remaining = decoder + .get_batch::(&mut buffer) + .expect("getting remaining"); + assert_eq!(remaining, 6); + assert_eq!(buffer, expected); + } + #[test] fn test_rle_consume_flush_buffer() { let data = vec![1, 1, 1, 2, 2, 3, 3, 3]; let mut encoder1 = RleEncoder::new(3, 256); let mut encoder2 = RleEncoder::new(3, 256); for value in data { - encoder1.put(value as u64).unwrap(); - encoder2.put(value as u64).unwrap(); + encoder1.put(value as u64); + encoder2.put(value as u64); } - let res1 = encoder1.flush_buffer().unwrap(); - let res2 = encoder2.consume().unwrap(); + let res1 = encoder1.flush_buffer(); + let res2 = encoder2.consume(); assert_eq!(res1, &res2[..]); } @@ -596,6 +628,52 @@ mod tests { assert_eq!(buffer, expected); } + #[test] + fn test_rle_skip_bool() { + // RLE test data: 50 1s followed by 50 0s + // 01100100 00000001 01100100 00000000 + let data1 = ByteBufferPtr::new(vec![0x64, 0x01, 0x64, 0x00]); + + // Bit-packing test data: alternating 1s and 0s, 100 total + // 100 / 8 = 13 groups + // 00011011 10101010 ... 
00001010 + let data2 = ByteBufferPtr::new(vec![ + 0x1B, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, + 0x0A, + ]); + + let mut decoder: RleDecoder = RleDecoder::new(1); + decoder.set_data(data1); + let mut buffer = vec![true; 50]; + let expected = vec![false; 50]; + + let skipped = decoder.skip(50).expect("skipping first 50"); + assert_eq!(skipped, 50); + let remainder = decoder + .get_batch::(&mut buffer) + .expect("getting remaining 50"); + assert_eq!(remainder, 50); + assert_eq!(buffer, expected); + + decoder.set_data(data2); + let mut buffer = vec![false; 50]; + let mut expected = vec![]; + for i in 0..50 { + if i % 2 == 0 { + expected.push(false); + } else { + expected.push(true); + } + } + let skipped = decoder.skip(50).expect("skipping first 50"); + assert_eq!(skipped, 50); + let remainder = decoder + .get_batch::(&mut buffer) + .expect("getting remaining 50"); + assert_eq!(remainder, 50); + assert_eq!(buffer, expected); + } + #[test] fn test_rle_decode_with_dict_int32() { // Test RLE encoding: 3 0s followed by 4 1s followed by 5 2s @@ -631,6 +709,42 @@ mod tests { assert_eq!(buffer, expected); } + #[test] + fn test_rle_skip_dict() { + // Test RLE encoding: 3 0s followed by 4 1s followed by 5 2s + // 00000110 00000000 00001000 00000001 00001010 00000010 + let dict = vec![10, 20, 30]; + let data = ByteBufferPtr::new(vec![0x06, 0x00, 0x08, 0x01, 0x0A, 0x02]); + let mut decoder: RleDecoder = RleDecoder::new(3); + decoder.set_data(data); + let mut buffer = vec![0; 10]; + let expected = vec![10, 20, 20, 20, 20, 30, 30, 30, 30, 30]; + let skipped = decoder.skip(2).expect("skipping two values"); + assert_eq!(skipped, 2); + let remainder = decoder + .get_batch_with_dict::(&dict, &mut buffer, 10) + .expect("getting remainder"); + assert_eq!(remainder, 10); + assert_eq!(buffer, expected); + + // Test bit-pack encoding: 345345345455 (2 groups: 8 and 4) + // 011 100 101 011 100 101 011 100 101 100 101 101 + // 00000011 01100011 11000111 10001110 00000011 01100101 00001011 + let dict = vec!["aaa", "bbb", "ccc", "ddd", "eee", "fff"]; + let data = ByteBufferPtr::new(vec![0x03, 0x63, 0xC7, 0x8E, 0x03, 0x65, 0x0B]); + let mut decoder: RleDecoder = RleDecoder::new(3); + decoder.set_data(data); + let mut buffer = vec![""; 8]; + let expected = vec!["eee", "fff", "ddd", "eee", "fff", "eee", "fff", "fff"]; + let skipped = decoder.skip(4).expect("skipping four values"); + assert_eq!(skipped, 4); + let remainder = decoder + .get_batch_with_dict::<&str>(dict.as_slice(), buffer.as_mut_slice(), 8) + .expect("getting remainder"); + assert_eq!(remainder, 8); + assert_eq!(buffer, expected); + } + fn validate_rle( values: &[i64], bit_width: u8, @@ -640,10 +754,9 @@ mod tests { let buffer_len = 64 * 1024; let mut encoder = RleEncoder::new(bit_width, buffer_len); for v in values { - let result = encoder.put(*v as u64); - assert!(result.is_ok()); + encoder.put(*v as u64) } - let buffer = ByteBufferPtr::new(encoder.consume().expect("Expect consume() OK")); + let buffer = ByteBufferPtr::new(encoder.consume()); if expected_len != -1 { assert_eq!(buffer.len(), expected_len as usize); } @@ -796,9 +909,9 @@ mod tests { let values: Vec = vec![0, 1, 1, 1, 1, 0, 0, 0, 0, 1]; let mut encoder = RleEncoder::new(bit_width, buffer_len); for v in &values { - assert!(encoder.put(*v as u64).expect("put() should be OK")); + encoder.put(*v as u64) } - let buffer = encoder.consume().expect("consume() should be OK"); + let buffer = encoder.consume(); let mut decoder = RleDecoder::new(bit_width); 
decoder.set_data(ByteBufferPtr::new(buffer)); let mut actual_values: Vec = vec![0; values.len()]; @@ -812,12 +925,10 @@ mod tests { let buffer_len = 64 * 1024; let mut encoder = RleEncoder::new(bit_width, buffer_len); for v in values { - let result = encoder.put(*v as u64).expect("put() should be OK"); - assert!(result, "put() should not return false"); + encoder.put(*v as u64) } - let buffer = - ByteBufferPtr::new(encoder.consume().expect("consume() should be OK")); + let buffer = ByteBufferPtr::new(encoder.consume()); // Verify read let mut decoder = RleDecoder::new(bit_width); diff --git a/parquet/src/errors.rs b/parquet/src/errors.rs index c2fb5bd66cf9..c4f5faaaacae 100644 --- a/parquet/src/errors.rs +++ b/parquet/src/errors.rs @@ -22,7 +22,7 @@ use std::{cell, io, result, str}; #[cfg(any(feature = "arrow", test))] use arrow::error::ArrowError; -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, Eq)] pub enum ParquetError { /// General Parquet error. /// Returned when code violates normal workflow of working with Parquet files. @@ -148,8 +148,8 @@ macro_rules! arrow_err { // Convert parquet error into other errors #[cfg(any(feature = "arrow", test))] -impl Into for ParquetError { - fn into(self) -> ArrowError { - ArrowError::ParquetError(format!("{}", self)) +impl From for ArrowError { + fn from(p: ParquetError) -> Self { + Self::ParquetError(format!("{}", p)) } } diff --git a/parquet/src/file/footer.rs b/parquet/src/file/footer.rs index dc1d66d0fa44..30afec55eb3a 100644 --- a/parquet/src/file/footer.rs +++ b/parquet/src/file/footer.rs @@ -17,7 +17,6 @@ use std::{io::Read, sync::Arc}; -use byteorder::{ByteOrder, LittleEndian}; use parquet_format::{ColumnOrder as TColumnOrder, FileMetaData as TFileMetaData}; use thrift::protocol::TCompactInputProtocol; @@ -62,19 +61,8 @@ pub fn parse_metadata(chunk_reader: &R) -> Result Result { } // get the metadata length from the footer - let metadata_len = LittleEndian::read_i32(&slice[..4]); + let metadata_len = i32::from_le_bytes(slice[..4].try_into().unwrap()); metadata_len.try_into().map_err(|_| { general_err!( "Invalid Parquet file. Metadata length is less than zero ({})", diff --git a/parquet/src/file/metadata.rs b/parquet/src/file/metadata.rs index bffe538cc72f..018dd95d9f35 100644 --- a/parquet/src/file/metadata.rs +++ b/parquet/src/file/metadata.rs @@ -50,15 +50,18 @@ use crate::schema::types::{ Type as SchemaType, }; +pub type ParquetColumnIndex = Vec>; +pub type ParquetOffsetIndex = Vec>>; + /// Global Parquet metadata. #[derive(Debug, Clone)] pub struct ParquetMetaData { file_metadata: FileMetaData, row_groups: Vec, /// Page index for all pages in each column chunk - page_indexes: Option>>, + page_indexes: Option, /// Offset index for all pages in each column chunk - offset_indexes: Option>>>, + offset_indexes: Option, } impl ParquetMetaData { @@ -76,8 +79,8 @@ impl ParquetMetaData { pub fn new_with_page_index( file_metadata: FileMetaData, row_groups: Vec, - page_indexes: Option>>, - offset_indexes: Option>>>, + page_indexes: Option, + offset_indexes: Option, ) -> Self { ParquetMetaData { file_metadata, @@ -109,12 +112,12 @@ impl ParquetMetaData { } /// Returns page indexes in this file. - pub fn page_indexes(&self) -> Option<&Vec>> { + pub fn page_indexes(&self) -> Option<&ParquetColumnIndex> { self.page_indexes.as_ref() } /// Returns offset indexes in this file. 
- pub fn offset_indexes(&self) -> Option<&Vec>>> { + pub fn offset_indexes(&self) -> Option<&ParquetOffsetIndex> { self.offset_indexes.as_ref() } } @@ -831,6 +834,12 @@ pub struct ColumnIndexBuilder { valid: bool, } +impl Default for ColumnIndexBuilder { + fn default() -> Self { + Self::new() + } +} + impl ColumnIndexBuilder { pub fn new() -> Self { ColumnIndexBuilder { @@ -884,6 +893,12 @@ pub struct OffsetIndexBuilder { current_first_row_index: i64, } +impl Default for OffsetIndexBuilder { + fn default() -> Self { + Self::new() + } +} + impl OffsetIndexBuilder { pub fn new() -> Self { OffsetIndexBuilder { diff --git a/parquet/src/file/page_encoding_stats.rs b/parquet/src/file/page_encoding_stats.rs index 3180c7820802..e499a094ae00 100644 --- a/parquet/src/file/page_encoding_stats.rs +++ b/parquet/src/file/page_encoding_stats.rs @@ -21,7 +21,7 @@ use parquet_format::{ }; /// PageEncodingStats for a column chunk and data page. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct PageEncodingStats { /// the page type (data/dic/...) pub page_type: PageType, diff --git a/parquet/src/file/page_index/index.rs b/parquet/src/file/page_index/index.rs index 45381234c027..f29b80accae2 100644 --- a/parquet/src/file/page_index/index.rs +++ b/parquet/src/file/page_index/index.rs @@ -47,6 +47,7 @@ impl PageIndex { } #[derive(Debug, Clone, PartialEq)] +#[allow(non_camel_case_types)] pub enum Index { /// Sometimes reading page index from parquet file /// will only return pageLocations without min_max index, diff --git a/parquet/src/file/page_index/index_reader.rs b/parquet/src/file/page_index/index_reader.rs index 33499e7426a5..e6a4e5981022 100644 --- a/parquet/src/file/page_index/index_reader.rs +++ b/parquet/src/file/page_index/index_reader.rs @@ -34,8 +34,12 @@ pub fn read_columns_indexes( let (offset, lengths) = get_index_offset_and_lengths(chunks)?; let length = lengths.iter().sum::(); + if length == 0 { + return Ok(vec![Index::NONE; chunks.len()]); + } + //read all need data into buffer - let mut reader = reader.get_read(offset, reader.len() as usize)?; + let mut reader = reader.get_read(offset, length)?; let mut data = vec![0; length]; reader.read_exact(&mut data)?; @@ -64,8 +68,12 @@ pub fn read_pages_locations( ) -> Result>, ParquetError> { let (offset, total_length) = get_location_offset_and_total_length(chunks)?; + if total_length == 0 { + return Ok(vec![]); + } + //read all need data into buffer - let mut reader = reader.get_read(offset, reader.len() as usize)?; + let mut reader = reader.get_read(offset, total_length)?; let mut data = vec![0; total_length]; reader.read_exact(&mut data)?; @@ -82,7 +90,7 @@ pub fn read_pages_locations( //Get File offsets of every ColumnChunk's page_index //If there are invalid offset return a zero offset with empty lengths. -fn get_index_offset_and_lengths( +pub(crate) fn get_index_offset_and_lengths( chunks: &[ColumnChunkMetaData], ) -> Result<(u64, Vec), ParquetError> { let first_col_metadata = if let Some(chunk) = chunks.first() { @@ -111,7 +119,7 @@ fn get_index_offset_and_lengths( //Get File offset of ColumnChunk's pages_locations //If there are invalid offset return a zero offset with zero length. 
-fn get_location_offset_and_total_length( +pub(crate) fn get_location_offset_and_total_length( chunks: &[ColumnChunkMetaData], ) -> Result<(u64, usize), ParquetError> { let metadata = if let Some(chunk) = chunks.first() { @@ -133,7 +141,7 @@ fn get_location_offset_and_total_length( Ok((offset, total_length)) } -fn deserialize_column_index( +pub(crate) fn deserialize_column_index( data: &[u8], column_type: Type, ) -> Result { diff --git a/parquet/src/file/page_index/mod.rs b/parquet/src/file/page_index/mod.rs index fc87ef20448f..bb7808f16487 100644 --- a/parquet/src/file/page_index/mod.rs +++ b/parquet/src/file/page_index/mod.rs @@ -17,4 +17,6 @@ pub mod index; pub mod index_reader; + +#[cfg(test)] pub(crate) mod range; diff --git a/parquet/src/file/page_index/range.rs b/parquet/src/file/page_index/range.rs index 06c06553ccd5..e9741ec8e7fd 100644 --- a/parquet/src/file/page_index/range.rs +++ b/parquet/src/file/page_index/range.rs @@ -213,6 +213,7 @@ impl RowRanges { result } + #[allow(unused)] pub fn row_count(&self) -> usize { self.ranges.iter().map(|x| x.count()).sum() } diff --git a/parquet/src/file/properties.rs b/parquet/src/file/properties.rs index 9ca7c4daa597..57dae323d892 100644 --- a/parquet/src/file/properties.rs +++ b/parquet/src/file/properties.rs @@ -68,7 +68,8 @@ const DEFAULT_CREATED_BY: &str = env!("PARQUET_CREATED_BY"); /// Parquet writer version. /// /// Basic constant, which is not part of the Thrift definition. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(non_camel_case_types)] pub enum WriterVersion { PARQUET_1_0, PARQUET_2_0, @@ -360,7 +361,7 @@ impl WriterPropertiesBuilder { fn get_mut_props(&mut self, col: ColumnPath) -> &mut ColumnProperties { self.column_properties .entry(col) - .or_insert(ColumnProperties::new()) + .or_insert_with(ColumnProperties::new) } /// Sets encoding for a column. diff --git a/parquet/src/file/reader.rs b/parquet/src/file/reader.rs index d752273655c5..70ff37a41e15 100644 --- a/parquet/src/file/reader.rs +++ b/parquet/src/file/reader.rs @@ -18,6 +18,7 @@ //! Contains file reader API and provides methods to access file metadata, row group //! readers to read individual column chunks, or access record iterator. +use bytes::Bytes; use std::{boxed::Box, io::Read, sync::Arc}; use crate::column::page::PageIterator; @@ -45,9 +46,25 @@ pub trait Length { /// For an object store reader, each read can be mapped to a range request. pub trait ChunkReader: Length + Send + Sync { type T: Read + Send; - /// get a serialy readeable slice of the current reader + /// Get a serially readable slice of the current reader /// This should fail if the slice exceeds the current bounds fn get_read(&self, start: u64, length: usize) -> Result; + + /// Get a range as bytes + /// This should fail if the exact number of bytes cannot be read + fn get_bytes(&self, start: u64, length: usize) -> Result { + let mut buffer = Vec::with_capacity(length); + let read = self.get_read(start, length)?.read_to_end(&mut buffer)?; + + if read != length { + return Err(eof_err!( + "Expected to read {} bytes, read only {}", + length, + read + )); + } + Ok(buffer.into()) + } } // ---------------------------------------------------------------------- diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 766813f11aee..f3beb57c02e5 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -18,14 +18,15 @@ //! 
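The new provided `get_bytes` method on `ChunkReader` above layers a whole-range read on top of `get_read` and rejects short reads. A simplified, self-contained sketch of the same shape (the trait and impl here are illustrative stand-ins, not the crate's definitions):

```rust
use std::io::{self, Cursor, Read};

trait ChunkReader {
    type T: Read;
    fn get_read(&self, start: u64, length: usize) -> io::Result<Self::T>;

    // Default method built on `get_read`, failing if fewer bytes are available.
    fn get_bytes(&self, start: u64, length: usize) -> io::Result<Vec<u8>> {
        let mut buffer = Vec::with_capacity(length);
        let read = self.get_read(start, length)?.read_to_end(&mut buffer)?;
        if read != length {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                format!("expected {} bytes, read {}", length, read),
            ));
        }
        Ok(buffer)
    }
}

impl ChunkReader for Vec<u8> {
    type T = Cursor<Vec<u8>>;
    fn get_read(&self, start: u64, length: usize) -> io::Result<Self::T> {
        let start = start as usize;
        Ok(Cursor::new(self[start..start + length].to_vec()))
    }
}

fn main() {
    let data = vec![1u8, 2, 3, 4, 5];
    assert_eq!(data.get_bytes(1, 3).unwrap(), vec![2, 3, 4]);
}
```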
Contains implementations of the reader traits FileReader, RowGroupReader and PageReader //! Also contains implementations of the ChunkReader for files (with buffering) and byte arrays (RAM) -use bytes::{Buf, Bytes}; use std::collections::VecDeque; +use std::io::Cursor; use std::{convert::TryFrom, fs::File, io::Read, path::Path, sync::Arc}; +use bytes::{Buf, Bytes}; use parquet_format::{PageHeader, PageLocation, PageType}; use thrift::protocol::TCompactInputProtocol; -use crate::basic::{Compression, Encoding, Type}; +use crate::basic::{Encoding, Type}; use crate::column::page::{Page, PageMetadata, PageReader}; use crate::compression::{create_codec, Codec}; use crate::errors::{ParquetError, Result}; @@ -34,13 +35,10 @@ use crate::file::{footer, metadata::*, reader::*, statistics}; use crate::record::reader::RowIter; use crate::record::Row; use crate::schema::types::Type as SchemaType; -use crate::util::page_util::{calculate_row_count, get_pages_readable_slices}; use crate::util::{io::TryClone, memory::ByteBufferPtr}; - // export `SliceableCursor` and `FileSource` publically so clients can // re-use the logic in their own ParquetFileWriter wrappers -#[allow(deprecated)] -pub use crate::util::{cursor::SliceableCursor, io::FileSource}; +pub use crate::util::io::FileSource; // ---------------------------------------------------------------------- // Implementations of traits facilitating the creation of a new reader @@ -81,24 +79,12 @@ impl ChunkReader for Bytes { type T = bytes::buf::Reader; fn get_read(&self, start: u64, length: usize) -> Result { - let start = start as usize; - Ok(self.slice(start..start + length).reader()) - } -} - -#[allow(deprecated)] -impl Length for SliceableCursor { - fn len(&self) -> u64 { - SliceableCursor::len(self) + Ok(self.get_bytes(start, length)?.reader()) } -} - -#[allow(deprecated)] -impl ChunkReader for SliceableCursor { - type T = SliceableCursor; - fn get_read(&self, start: u64, length: usize) -> Result { - self.slice(start, length).map_err(|e| e.into()) + fn get_bytes(&self, start: u64, length: usize) -> Result { + let start = start as usize; + Ok(self.slice(start..start + length)) } } @@ -152,32 +138,32 @@ impl IntoIterator for SerializedFileReader { /// A serialized implementation for Parquet [`FileReader`]. pub struct SerializedFileReader { chunk_reader: Arc, - metadata: ParquetMetaData, + metadata: Arc, } +/// A predicate for filtering row groups, invoked with the metadata and index +/// of each row group in the file. Only row groups for which the predicate +/// evaluates to `true` will be scanned +pub type ReadGroupPredicate = Box bool>; + /// A builder for [`ReadOptions`]. /// For the predicates that are added to the builder, /// they will be chained using 'AND' to filter the row groups. +#[derive(Default)] pub struct ReadOptionsBuilder { - predicates: Vec bool>>, + predicates: Vec, enable_page_index: bool, } impl ReadOptionsBuilder { /// New builder pub fn new() -> Self { - ReadOptionsBuilder { - predicates: vec![], - enable_page_index: false, - } + Self::default() } /// Add a predicate on row group metadata to the reading option, /// Filter only row groups that match the predicate criteria - pub fn with_predicate( - mut self, - predicate: Box bool>, - ) -> Self { + pub fn with_predicate(mut self, predicate: ReadGroupPredicate) -> Self { self.predicates.push(predicate); self } @@ -214,7 +200,7 @@ impl ReadOptionsBuilder { /// Currently, only predicates on row group metadata are supported. 
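An illustrative usage sketch for the `ReadGroupPredicate` alias and builder above, assuming the `new_with_options` constructor that this change refactors; the file path and filters are made up:

```rust
use parquet::file::reader::FileReader;
use parquet::file::serialized_reader::{ReadOptionsBuilder, SerializedFileReader};
use std::fs::File;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("data.parquet")?; // illustrative path

    // Predicates are chained with AND: keep even-numbered row groups that
    // contain at least one row.
    let options = ReadOptionsBuilder::new()
        .with_predicate(Box::new(|_metadata, index| index % 2 == 0))
        .with_predicate(Box::new(|metadata, _index| metadata.num_rows() > 0))
        .build();

    let reader = SerializedFileReader::new_with_options(file, options)?;
    println!("row groups after filtering: {}", reader.num_row_groups());
    Ok(())
}
```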
/// All predicates will be chained using 'AND' to filter the row groups. pub struct ReadOptions { - predicates: Vec bool>>, + predicates: Vec, enable_page_index: bool, } @@ -225,7 +211,7 @@ impl SerializedFileReader { let metadata = footer::parse_metadata(&chunk_reader)?; Ok(Self { chunk_reader: Arc::new(chunk_reader), - metadata, + metadata: Arc::new(metadata), }) } @@ -265,23 +251,27 @@ impl SerializedFileReader { Ok(Self { chunk_reader: Arc::new(chunk_reader), - metadata: ParquetMetaData::new_with_page_index( + metadata: Arc::new(ParquetMetaData::new_with_page_index( metadata.file_metadata().clone(), filtered_row_groups, Some(columns_indexes), Some(offset_indexes), - ), + )), }) } else { Ok(Self { chunk_reader: Arc::new(chunk_reader), - metadata: ParquetMetaData::new( + metadata: Arc::new(ParquetMetaData::new( metadata.file_metadata().clone(), filtered_row_groups, - ), + )), }) } } + + pub(crate) fn metadata_ref(&self) -> &Arc { + &self.metadata + } } /// Get midpoint offset for a row group @@ -348,33 +338,19 @@ impl<'a, R: 'static + ChunkReader> RowGroupReader for SerializedRowGroupReader<' // TODO: fix PARQUET-816 fn get_column_page_reader(&self, i: usize) -> Result> { let col = self.metadata.column(i); - let (col_start, col_length) = col.byte_range(); - let page_reader = if let Some(offset_index) = self.metadata.page_offset_index() { - let col_chunk_offset_index = &offset_index[i]; - let (page_bufs, has_dict) = get_pages_readable_slices( - col_chunk_offset_index, - col_start, - self.chunk_reader.clone(), - )?; - SerializedPageReader::new_with_page_offsets( - col.num_values(), - col.compression(), - col.column_descr().physical_type(), - col_chunk_offset_index.clone(), - has_dict, - page_bufs, - )? - } else { - let file_chunk = - self.chunk_reader.get_read(col_start, col_length as usize)?; - SerializedPageReader::new( - file_chunk, - col.num_values(), - col.compression(), - col.column_descr().physical_type(), - )? - }; - Ok(Box::new(page_reader)) + + let page_locations = self + .metadata + .page_offset_index() + .as_ref() + .map(|x| x[i].clone()); + + Ok(Box::new(SerializedPageReader::new( + Arc::clone(&self.chunk_reader), + col, + self.metadata.num_rows() as usize, + page_locations, + )?)) } fn get_row_iter(&self, projection: Option) -> Result { @@ -389,6 +365,30 @@ pub(crate) fn read_page_header(input: &mut T) -> Result { Ok(page_header) } +/// Reads a [`PageHeader`] from the provided [`Read`] returning the number of bytes read +fn read_page_header_len(input: &mut T) -> Result<(usize, PageHeader)> { + /// A wrapper around a [`std::io::Read`] that keeps track of the bytes read + struct TrackedRead { + inner: R, + bytes_read: usize, + } + + impl Read for TrackedRead { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let v = self.inner.read(buf)?; + self.bytes_read += v; + Ok(v) + } + } + + let mut tracked = TrackedRead { + inner: input, + bytes_read: 0, + }; + let header = read_page_header(&mut tracked)?; + Ok((tracked.bytes_read, header)) +} + /// Decodes a [`Page`] from the provided `buffer` pub(crate) fn decode_page( page_header: PageHeader, @@ -484,83 +484,89 @@ pub(crate) fn decode_page( Ok(result) } -enum SerializedPages { - /// Read entire chunk - Chunk { buf: T }, - /// Read operate pages which can skip. 
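`read_page_header_len` above wraps its input in a byte-counting reader so the caller knows how far to advance the column chunk offset. A standalone sketch of that wrapper idea:

```rust
use std::io::{self, Read};

/// Counts how many bytes have been pulled through the inner reader.
struct TrackedRead<R> {
    inner: R,
    bytes_read: usize,
}

impl<R: Read> Read for TrackedRead<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let n = self.inner.read(buf)?;
        self.bytes_read += n;
        Ok(n)
    }
}

fn main() -> io::Result<()> {
    let data = [7u8; 16];
    let mut tracked = TrackedRead { inner: &data[..], bytes_read: 0 };

    let mut header = [0u8; 10];
    tracked.read_exact(&mut header)?;

    // The count is what an offset-based caller would advance by.
    assert_eq!(tracked.bytes_read, 10);
    Ok(())
}
```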
+enum SerializedPageReaderState { + Values { + /// The current byte offset in the reader + offset: usize, + + /// The length of the chunk in bytes + remaining_bytes: usize, + + // If the next page header has already been "peeked", we will cache it and it`s length here + next_page_header: Option>, + }, Pages { - offset_index: Vec, - seen_num_data_pages: usize, - has_dictionary_page_to_read: bool, - page_bufs: VecDeque, + /// Remaining page locations + page_locations: VecDeque, + /// Remaining dictionary location if any + dictionary_page: Option, + /// The total number of rows in this column chunk + total_rows: usize, }, } /// A serialized implementation for Parquet [`PageReader`]. -pub struct SerializedPageReader { - // The file source buffer which references exactly the bytes for the column trunk - // to be read by this page reader. - buf: SerializedPages, +pub struct SerializedPageReader { + /// The chunk reader + reader: Arc, - // The compression codec for this column chunk. Only set for non-PLAIN codec. + /// The compression codec for this column chunk. Only set for non-PLAIN codec. decompressor: Option>, - // The number of values we have seen so far. - seen_num_values: i64, - - // The number of total values in this column chunk. - total_num_values: i64, - - // Column chunk type. + /// Column chunk type. physical_type: Type, + + state: SerializedPageReaderState, } -impl SerializedPageReader { - /// Creates a new serialized page reader from file source. +impl SerializedPageReader { + /// Creates a new serialized page reader from a chunk reader and metadata pub fn new( - buf: T, - total_num_values: i64, - compression: Compression, - physical_type: Type, + reader: Arc, + meta: &ColumnChunkMetaData, + total_rows: usize, + page_locations: Option>, ) -> Result { - let decompressor = create_codec(compression)?; - let result = Self { - buf: SerializedPages::Chunk { buf }, - total_num_values, - seen_num_values: 0, - decompressor, - physical_type, - }; - Ok(result) - } + let decompressor = create_codec(meta.compression())?; + let (start, len) = meta.byte_range(); + + let state = match page_locations { + Some(locations) => { + let dictionary_page = match locations.first() { + Some(dict_offset) if dict_offset.offset as u64 != start => { + Some(PageLocation { + offset: start as i64, + compressed_page_size: (dict_offset.offset as u64 - start) + as i32, + first_row_index: 0, + }) + } + _ => None, + }; - /// Creates a new serialized page reader from file source. 
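When page locations are available, the constructor above infers a dictionary page from the gap between the column chunk's starting byte offset and the first data page. A standalone sketch of that inference (the `PageLocation` struct is local to the example rather than `parquet_format`'s):

```rust
#[derive(Debug, PartialEq)]
struct PageLocation {
    offset: i64,
    compressed_page_size: i32,
    first_row_index: i64,
}

// If the first data page does not begin at the chunk start, the bytes in
// between are assumed to hold the dictionary page.
fn infer_dictionary_page(chunk_start: u64, first: &PageLocation) -> Option<PageLocation> {
    (first.offset as u64 != chunk_start).then(|| PageLocation {
        offset: chunk_start as i64,
        compressed_page_size: (first.offset as u64 - chunk_start) as i32,
        first_row_index: 0,
    })
}

fn main() {
    let first = PageLocation { offset: 100, compressed_page_size: 40, first_row_index: 0 };
    let dict = infer_dictionary_page(64, &first).expect("dictionary page expected");
    assert_eq!(dict.compressed_page_size, 36);

    // No gap between chunk start and first data page means no dictionary page.
    assert!(infer_dictionary_page(100, &first).is_none());
}
```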
- pub fn new_with_page_offsets( - total_num_values: i64, - compression: Compression, - physical_type: Type, - offset_index: Vec, - has_dictionary_page_to_read: bool, - page_bufs: VecDeque, - ) -> Result { - let decompressor = create_codec(compression)?; - let result = Self { - buf: SerializedPages::Pages { - offset_index, - seen_num_data_pages: 0, - has_dictionary_page_to_read, - page_bufs, + SerializedPageReaderState::Pages { + page_locations: locations.into(), + dictionary_page, + total_rows, + } + } + None => SerializedPageReaderState::Values { + offset: start as usize, + remaining_bytes: len as usize, + next_page_header: None, }, - total_num_values, - seen_num_values: 0, - decompressor, - physical_type, }; - Ok(result) + + Ok(Self { + reader, + decompressor, + state, + physical_type: meta.column_type(), + }) } } -impl Iterator for SerializedPageReader { +impl Iterator for SerializedPageReader { type Item = Result; fn next(&mut self) -> Option { @@ -568,133 +574,177 @@ impl Iterator for SerializedPageReader { } } -impl PageReader for SerializedPageReader { +impl PageReader for SerializedPageReader { fn get_next_page(&mut self) -> Result> { - let mut cursor; - let mut dictionary_cursor; - while self.seen_num_values < self.total_num_values { - match &mut self.buf { - SerializedPages::Chunk { buf } => { - cursor = buf; - } - SerializedPages::Pages { - offset_index, - seen_num_data_pages, - has_dictionary_page_to_read, - page_bufs, + loop { + let page = match &mut self.state { + SerializedPageReaderState::Values { + offset, + remaining_bytes: remaining, + next_page_header, } => { - if offset_index.len() <= *seen_num_data_pages { + if *remaining == 0 { return Ok(None); - } else if *seen_num_data_pages == 0 && *has_dictionary_page_to_read { - dictionary_cursor = page_bufs.pop_front().unwrap(); - cursor = &mut dictionary_cursor; - } else { - cursor = page_bufs.get_mut(*seen_num_data_pages).unwrap(); } - } - } - let page_header = read_page_header(cursor)?; + let mut read = self.reader.get_read(*offset as u64, *remaining)?; + let header = if let Some(header) = next_page_header.take() { + *header + } else { + let (header_len, header) = read_page_header_len(&mut read)?; + *offset += header_len; + *remaining -= header_len; + header + }; + let data_len = header.compressed_page_size as usize; + *offset += data_len; + *remaining -= data_len; + + if header.type_ == PageType::IndexPage { + continue; + } - let to_read = page_header.compressed_page_size as usize; - let mut buffer = Vec::with_capacity(to_read); - let read = cursor.take(to_read as u64).read_to_end(&mut buffer)?; + let mut buffer = Vec::with_capacity(data_len); + let read = read.take(data_len as u64).read_to_end(&mut buffer)?; - if read != to_read { - return Err(eof_err!( - "Expected to read {} bytes of page, read only {}", - to_read, - read - )); - } + if read != data_len { + return Err(eof_err!( + "Expected to read {} bytes of page, read only {}", + data_len, + read + )); + } - let buffer = ByteBufferPtr::new(buffer); - let result = match page_header.type_ { - PageType::DataPage | PageType::DataPageV2 => { - let decoded = decode_page( - page_header, - buffer, + decode_page( + header, + ByteBufferPtr::new(buffer), self.physical_type, self.decompressor.as_mut(), - )?; - self.seen_num_values += decoded.num_values() as i64; - if let SerializedPages::Pages { - seen_num_data_pages, - .. - } = &mut self.buf - { - *seen_num_data_pages += 1; - } - decoded + )? 
} - PageType::DictionaryPage => { - if let SerializedPages::Pages { - has_dictionary_page_to_read, - .. - } = &mut self.buf + SerializedPageReaderState::Pages { + page_locations, + dictionary_page, + .. + } => { + let front = match dictionary_page + .take() + .or_else(|| page_locations.pop_front()) { - *has_dictionary_page_to_read = false; - } + Some(front) => front, + None => return Ok(None), + }; + + let page_len = front.compressed_page_size as usize; + + let buffer = self.reader.get_bytes(front.offset as u64, page_len)?; + + let mut cursor = Cursor::new(buffer.as_ref()); + let header = read_page_header(&mut cursor)?; + let offset = cursor.position(); + + let bytes = buffer.slice(offset as usize..); decode_page( - page_header, - buffer, + header, + bytes.into(), self.physical_type, self.decompressor.as_mut(), )? } - _ => { - // For unknown page type (e.g., INDEX_PAGE), skip and read next. - continue; - } }; - return Ok(Some(result)); - } - // We are at the end of this column chunk and no more page left. Return None. - Ok(None) + return Ok(Some(page)); + } } fn peek_next_page(&mut self) -> Result> { - match &mut self.buf { - SerializedPages::Chunk { .. } => { Err(general_err!("Must set page_offset_index when using peek_next_page in SerializedPageReader.")) } - SerializedPages::Pages { offset_index, seen_num_data_pages, has_dictionary_page_to_read, .. } => { - if *seen_num_data_pages >= offset_index.len() { - Ok(None) - } else if *seen_num_data_pages == 0 && *has_dictionary_page_to_read { - // Will set `has_dictionary_page_to_read` false in `get_next_page`, - // assume dictionary page must be read and cannot be skipped. + match &mut self.state { + SerializedPageReaderState::Values { + offset, + remaining_bytes, + next_page_header, + } => { + loop { + if *remaining_bytes == 0 { + return Ok(None); + } + return if let Some(header) = next_page_header.as_ref() { + if let Ok(page_meta) = (&**header).try_into() { + Ok(Some(page_meta)) + } else { + // For unknown page type (e.g., INDEX_PAGE), skip and read next. + *next_page_header = None; + continue; + } + } else { + let mut read = + self.reader.get_read(*offset as u64, *remaining_bytes)?; + let (header_len, header) = read_page_header_len(&mut read)?; + *offset += header_len; + *remaining_bytes -= header_len; + let page_meta = if let Ok(page_meta) = (&header).try_into() { + Ok(Some(page_meta)) + } else { + // For unknown page type (e.g., INDEX_PAGE), skip and read next. + continue; + }; + *next_page_header = Some(Box::new(header)); + page_meta + }; + } + } + SerializedPageReaderState::Pages { + page_locations, + dictionary_page, + total_rows, + } => { + if dictionary_page.is_some() { Ok(Some(PageMetadata { - num_rows: usize::MIN, + num_rows: 0, is_dict: true, })) - } else { - let row_count = calculate_row_count( - offset_index, - *seen_num_data_pages, - self.total_num_values, - )?; + } else if let Some(page) = page_locations.front() { + let next_rows = page_locations + .get(1) + .map(|x| x.first_row_index as usize) + .unwrap_or(*total_rows); + Ok(Some(PageMetadata { - num_rows: row_count, + num_rows: next_rows - page.first_row_index as usize, is_dict: false, })) + } else { + Ok(None) } } } } fn skip_next_page(&mut self) -> Result<()> { - match &mut self.buf { - SerializedPages::Chunk { .. } => { Err(general_err!("Must set page_offset_index when using skip_next_page in SerializedPageReader.")) } - SerializedPages::Pages { offset_index, seen_num_data_pages, .. 
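In the `Pages` branch of `peek_next_page` above, a page's row count is the difference between consecutive `first_row_index` values, with the chunk's total row count bounding the final page. A small worked sketch (the numbers mirror the `alltypes_tiny_pages` expectations in the tests below):

```rust
// first_row_indexes[i] is the index of the first row in data page i.
fn page_row_count(first_row_indexes: &[i64], page: usize, total_rows: usize) -> usize {
    let next = first_row_indexes
        .get(page + 1)
        .map(|x| *x as usize)
        .unwrap_or(total_rows);
    next - first_row_indexes[page] as usize
}

fn main() {
    // The last page starts at row 7290 of a 7300-row chunk, so it holds 10 rows.
    let first_rows = [0i64, 21, 41, 7290];
    assert_eq!(page_row_count(&first_rows, 0, 7300), 21);
    assert_eq!(page_row_count(&first_rows, 1, 7300), 20);
    assert_eq!(page_row_count(&first_rows, 3, 7300), 10);
}
```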
} => { - if offset_index.len() <= *seen_num_data_pages { - Err(general_err!( - "seen_num_data_pages is out of bound in SerializedPageReader." - )) + match &mut self.state { + SerializedPageReaderState::Values { + offset, + remaining_bytes, + next_page_header, + } => { + if let Some(buffered_header) = next_page_header.take() { + // The next page header has already been peeked, so just advance the offset + *offset += buffered_header.compressed_page_size as usize; + *remaining_bytes -= buffered_header.compressed_page_size as usize; } else { - *seen_num_data_pages += 1; - // Notice: maybe need 'self.seen_num_values += xxx', for now we can not get skip values in skip_next_page. - Ok(()) + let mut read = + self.reader.get_read(*offset as u64, *remaining_bytes)?; + let (header_len, header) = read_page_header_len(&mut read)?; + let data_page_size = header.compressed_page_size as usize; + *offset += header_len + data_page_size; + *remaining_bytes -= header_len + data_page_size; } + Ok(()) + } + SerializedPageReaderState::Pages { page_locations, .. } => { + page_locations.pop_front(); + + Ok(()) } } } @@ -702,16 +752,19 @@ impl PageReader for SerializedPageReader { #[cfg(test)] mod tests { - use super::*; + use std::sync::Arc; + + use parquet_format::BoundaryOrder; + use crate::basic::{self, ColumnOrder}; use crate::data_type::private::ParquetValueType; use crate::file::page_index::index::{ByteArrayIndex, Index, NativeIndex}; use crate::record::RowAccessor; use crate::schema::parser::parse_message_type; use crate::util::bit_util::from_le_slice; - use crate::util::test_common::{get_test_file, get_test_path}; - use parquet_format::BoundaryOrder; - use std::sync::Arc; + use crate::util::test_common::file_util::{get_test_file, get_test_path}; + + use super::*; #[test] fn test_cursor_and_file_has_the_same_behaviour() { @@ -1485,6 +1538,36 @@ mod tests { assert_eq!(vec.len(), 163); } + #[test] + fn test_skip_page_without_offset_index() { + let test_file = get_test_file("alltypes_tiny_pages_plain.parquet"); + + // use default SerializedFileReader without read offsetIndex + let reader_result = SerializedFileReader::new(test_file); + let reader = reader_result.unwrap(); + + let row_group_reader = reader.get_row_group(0).unwrap(); + + //use 'int_col', Boundary order: ASCENDING, total 325 pages. + let mut column_page_reader = row_group_reader.get_column_page_reader(4).unwrap(); + + let mut vec = vec![]; + + for i in 0..325 { + if i % 2 == 0 { + vec.push(column_page_reader.get_next_page().unwrap().unwrap()); + } else { + column_page_reader.peek_next_page().unwrap().unwrap(); + column_page_reader.skip_next_page().unwrap(); + } + } + //check read all pages. + assert!(column_page_reader.peek_next_page().unwrap().is_none()); + assert!(column_page_reader.get_next_page().unwrap().is_none()); + + assert_eq!(vec.len(), 163); + } + #[test] fn test_peek_page_with_dictionary_page() { let test_file = get_test_file("alltypes_tiny_pages.parquet"); @@ -1512,7 +1595,51 @@ mod tests { if i != 351 { assert!((meta.num_rows == 21) || (meta.num_rows == 20)); } else { - assert_eq!(meta.num_rows, 11); + // last page first row index is 7290, total row count is 7300 + // because first row start with zero, last page row count should be 10. + assert_eq!(meta.num_rows, 10); + } + assert!(!meta.is_dict); + vec.push(meta); + let page = column_page_reader.get_next_page().unwrap().unwrap(); + assert!(matches!(page.page_type(), basic::PageType::DATA_PAGE)); + } + + //check read all pages. 
+ assert!(column_page_reader.peek_next_page().unwrap().is_none()); + assert!(column_page_reader.get_next_page().unwrap().is_none()); + + assert_eq!(vec.len(), 352); + } + + #[test] + fn test_peek_page_with_dictionary_page_without_offset_index() { + let test_file = get_test_file("alltypes_tiny_pages.parquet"); + + let reader_result = SerializedFileReader::new(test_file); + let reader = reader_result.unwrap(); + let row_group_reader = reader.get_row_group(0).unwrap(); + + //use 'string_col', Boundary order: UNORDERED, total 352 data ages and 1 dictionary page. + let mut column_page_reader = row_group_reader.get_column_page_reader(9).unwrap(); + + let mut vec = vec![]; + + let meta = column_page_reader.peek_next_page().unwrap().unwrap(); + assert!(meta.is_dict); + let page = column_page_reader.get_next_page().unwrap().unwrap(); + assert!(matches!(page.page_type(), basic::PageType::DICTIONARY_PAGE)); + + for i in 0..352 { + let meta = column_page_reader.peek_next_page().unwrap().unwrap(); + // have checked with `parquet-tools column-index -c string_col ./alltypes_tiny_pages.parquet` + // page meta has two scenarios(21, 20) of num_rows expect last page has 11 rows. + if i != 351 { + assert!((meta.num_rows == 21) || (meta.num_rows == 20)); + } else { + // last page first row index is 7290, total row count is 7300 + // because first row start with zero, last page row count should be 10. + assert_eq!(meta.num_rows, 10); } assert!(!meta.is_dict); vec.push(meta); diff --git a/parquet/src/file/statistics.rs b/parquet/src/file/statistics.rs index 40db3c1017fe..da2ec2e9a149 100644 --- a/parquet/src/file/statistics.rs +++ b/parquet/src/file/statistics.rs @@ -37,15 +37,47 @@ //! } //! ``` -use std::{cmp, fmt}; +use std::fmt; -use byteorder::{ByteOrder, LittleEndian}; use parquet_format::Statistics as TStatistics; use crate::basic::Type; +use crate::data_type::private::ParquetValueType; use crate::data_type::*; use crate::util::bit_util::from_ne_slice; +pub(crate) mod private { + use super::*; + + pub trait MakeStatistics { + fn make_statistics(statistics: ValueStatistics) -> Statistics + where + Self: Sized; + } + + macro_rules! gen_make_statistics { + ($value_ty:ty, $stat:ident) => { + impl MakeStatistics for $value_ty { + fn make_statistics(statistics: ValueStatistics) -> Statistics + where + Self: Sized, + { + Statistics::$stat(statistics) + } + } + }; + } + + gen_make_statistics!(bool, Boolean); + gen_make_statistics!(i32, Int32); + gen_make_statistics!(i64, Int64); + gen_make_statistics!(Int96, Int96); + gen_make_statistics!(f32, Float); + gen_make_statistics!(f64, Double); + gen_make_statistics!(ByteArray, ByteArray); + gen_make_statistics!(FixedLenByteArray, FixedLenByteArray); +} + // Macro to generate methods create Statistics. macro_rules! statistics_new_func { ($func:ident, $vtype:ty, $stat:ident) => { @@ -56,7 +88,7 @@ macro_rules! 
statistics_new_func { nulls: u64, is_deprecated: bool, ) -> Self { - Statistics::$stat(TypedStatistics::new( + Statistics::$stat(ValueStatistics::new( min, max, distinct, @@ -130,15 +162,15 @@ pub fn from_thrift( old_format, ), Type::INT32 => Statistics::int32( - min.map(|data| LittleEndian::read_i32(&data)), - max.map(|data| LittleEndian::read_i32(&data)), + min.map(|data| i32::from_le_bytes(data[..4].try_into().unwrap())), + max.map(|data| i32::from_le_bytes(data[..4].try_into().unwrap())), distinct_count, null_count, old_format, ), Type::INT64 => Statistics::int64( - min.map(|data| LittleEndian::read_i64(&data)), - max.map(|data| LittleEndian::read_i64(&data)), + min.map(|data| i64::from_le_bytes(data[..8].try_into().unwrap())), + max.map(|data| i64::from_le_bytes(data[..8].try_into().unwrap())), distinct_count, null_count, old_format, @@ -158,15 +190,15 @@ pub fn from_thrift( Statistics::int96(min, max, distinct_count, null_count, old_format) } Type::FLOAT => Statistics::float( - min.map(|data| LittleEndian::read_f32(&data)), - max.map(|data| LittleEndian::read_f32(&data)), + min.map(|data| f32::from_le_bytes(data[..4].try_into().unwrap())), + max.map(|data| f32::from_le_bytes(data[..4].try_into().unwrap())), distinct_count, null_count, old_format, ), Type::DOUBLE => Statistics::double( - min.map(|data| LittleEndian::read_f64(&data)), - max.map(|data| LittleEndian::read_f64(&data)), + min.map(|data| f64::from_le_bytes(data[..8].try_into().unwrap())), + max.map(|data| f64::from_le_bytes(data[..8].try_into().unwrap())), distinct_count, null_count, old_format, @@ -234,17 +266,39 @@ pub fn to_thrift(stats: Option<&Statistics>) -> Option { /// Statistics for a column chunk and data page. #[derive(Debug, Clone, PartialEq)] pub enum Statistics { - Boolean(TypedStatistics), - Int32(TypedStatistics), - Int64(TypedStatistics), - Int96(TypedStatistics), - Float(TypedStatistics), - Double(TypedStatistics), - ByteArray(TypedStatistics), - FixedLenByteArray(TypedStatistics), + Boolean(ValueStatistics), + Int32(ValueStatistics), + Int64(ValueStatistics), + Int96(ValueStatistics), + Float(ValueStatistics), + Double(ValueStatistics), + ByteArray(ValueStatistics), + FixedLenByteArray(ValueStatistics), +} + +impl From> for Statistics { + fn from(t: ValueStatistics) -> Self { + T::make_statistics(t) + } } impl Statistics { + pub fn new( + min: Option, + max: Option, + distinct_count: Option, + null_count: u64, + is_deprecated: bool, + ) -> Self { + Self::from(ValueStatistics::new( + min, + max, + distinct_count, + null_count, + is_deprecated, + )) + } + statistics_new_func![boolean, Option, Boolean]; statistics_new_func![int32, Option, Int32]; @@ -341,21 +395,24 @@ impl fmt::Display for Statistics { } /// Typed implementation for [`Statistics`]. -#[derive(Clone)] -pub struct TypedStatistics { - min: Option, - max: Option, +pub type TypedStatistics = ValueStatistics<::T>; + +/// Statistics for a particular `ParquetValueType` +#[derive(Clone, Eq, PartialEq)] +pub struct ValueStatistics { + min: Option, + max: Option, // Distinct count could be omitted in some cases distinct_count: Option, null_count: u64, is_min_max_deprecated: bool, } -impl TypedStatistics { +impl ValueStatistics { /// Creates new typed statistics. pub fn new( - min: Option, - max: Option, + min: Option, + max: Option, distinct_count: Option, null_count: u64, is_min_max_deprecated: bool, @@ -373,7 +430,7 @@ impl TypedStatistics { /// /// Panics if min value is not set, e.g. all values are `null`. 
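An illustrative usage sketch for the statistics API above: the typed constructors generated by `statistics_new_func!` keep working after the move to `ValueStatistics`, and matching the enum reaches the typed accessors:

```rust
use parquet::file::statistics::Statistics;

fn main() {
    // min 1, max 10, unknown distinct count, 3 nulls, non-deprecated format.
    let stats = Statistics::int32(Some(1), Some(10), None, 3, false);

    assert_eq!(stats.null_count(), 3);
    if let Statistics::Int32(typed) = &stats {
        assert_eq!(*typed.min(), 1);
        assert_eq!(*typed.max(), 10);
    }
}
```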
/// Use `has_min_max_set` method to check that. - pub fn min(&self) -> &T::T { + pub fn min(&self) -> &T { self.min.as_ref().unwrap() } @@ -381,7 +438,7 @@ impl TypedStatistics { /// /// Panics if max value is not set, e.g. all values are `null`. /// Use `has_min_max_set` method to check that. - pub fn max(&self) -> &T::T { + pub fn max(&self) -> &T { self.max.as_ref().unwrap() } @@ -423,7 +480,7 @@ impl TypedStatistics { } } -impl fmt::Display for TypedStatistics { +impl fmt::Display for ValueStatistics { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{{")?; write!(f, "min: ")?; @@ -447,7 +504,7 @@ impl fmt::Display for TypedStatistics { } } -impl fmt::Debug for TypedStatistics { +impl fmt::Debug for ValueStatistics { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -462,16 +519,6 @@ impl fmt::Debug for TypedStatistics { } } -impl cmp::PartialEq for TypedStatistics { - fn eq(&self, other: &TypedStatistics) -> bool { - self.min == other.min - && self.max == other.max - && self.distinct_count == other.distinct_count - && self.null_count == other.null_count - && self.is_min_max_deprecated == other.is_min_max_deprecated - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs index 10983c741355..7af4b0fa2c94 100644 --- a/parquet/src/file/writer.rs +++ b/parquet/src/file/writer.rs @@ -20,13 +20,14 @@ use std::{io::Write, sync::Arc}; -use byteorder::{ByteOrder, LittleEndian}; use parquet_format as parquet; use parquet_format::{ColumnIndex, OffsetIndex, RowGroup}; use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol}; use crate::basic::PageType; -use crate::column::writer::{get_typed_column_writer_mut, ColumnWriterImpl}; +use crate::column::writer::{ + get_typed_column_writer_mut, ColumnCloseResult, ColumnWriterImpl, +}; use crate::column::{ page::{CompressedPage, Page, PageWriteSpec, PageWriter}, writer::{get_column_writer, ColumnWriter}, @@ -35,10 +36,11 @@ use crate::data_type::DataType; use crate::errors::{ParquetError, Result}; use crate::file::{ metadata::*, properties::WriterPropertiesPtr, - statistics::to_thrift as statistics_to_thrift, FOOTER_SIZE, PARQUET_MAGIC, + statistics::to_thrift as statistics_to_thrift, PARQUET_MAGIC, +}; +use crate::schema::types::{ + self, ColumnDescPtr, SchemaDescPtr, SchemaDescriptor, TypePtr, }; -use crate::schema::types::{self, SchemaDescPtr, SchemaDescriptor, TypePtr}; -use crate::util::io::TryClone; /// A wrapper around a [`Write`] that keeps track of the number /// of bytes that have been written @@ -60,6 +62,11 @@ impl TrackedWrite { pub fn bytes_written(&self) -> usize { self.bytes_written } + + /// Returns the underlying writer. 
+ pub fn into_inner(self) -> W { + self.inner + } } impl Write for TrackedWrite { @@ -74,24 +81,8 @@ impl Write for TrackedWrite { } } -/// Callback invoked on closing a column chunk, arguments are: -/// -/// - the number of bytes written -/// - the number of rows written -/// - the column chunk metadata -/// - the column index -/// - the offset index -/// -pub type OnCloseColumnChunk<'a> = Box< - dyn FnOnce( - u64, - u64, - ColumnChunkMetaData, - Option, - Option, - ) -> Result<()> - + 'a, ->; +/// Callback invoked on closing a column chunk +pub type OnCloseColumnChunk<'a> = Box Result<()> + 'a>; /// Callback invoked on closing a row group, arguments are: /// @@ -107,11 +98,6 @@ pub type OnCloseRowGroup<'a> = Box< + 'a, >; -#[deprecated = "use std::io::Write"] -pub trait ParquetWriter: Write + std::io::Seek + TryClone {} -#[allow(deprecated)] -impl ParquetWriter for T {} - // ---------------------------------------------------------------------- // Serialized impl for file & row group writers @@ -296,11 +282,10 @@ impl SerializedFileWriter { let end_pos = self.buf.bytes_written(); // Write footer - let mut footer_buffer: [u8; FOOTER_SIZE] = [0; FOOTER_SIZE]; let metadata_len = (end_pos - start_pos) as i32; - LittleEndian::write_i32(&mut footer_buffer, metadata_len); - (&mut footer_buffer[4..]).write_all(&PARQUET_MAGIC)?; - self.buf.write_all(&footer_buffer)?; + + self.buf.write_all(&metadata_len.to_le_bytes())?; + self.buf.write_all(&PARQUET_MAGIC)?; Ok(file_metadata) } @@ -312,6 +297,14 @@ impl SerializedFileWriter { Ok(()) } } + + /// Writes the file footer and returns the underlying writer. + pub fn into_inner(mut self) -> Result { + self.assert_previous_writer_closed()?; + let _ = self.write_metadata()?; + + Ok(self.buf.into_inner()) + } } /// Parquet row group writer API. @@ -367,22 +360,26 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> { } } - /// Returns the next column writer, if available; otherwise returns `None`. - /// In case of any IO error or Thrift error, or if row group writer has already been - /// closed returns `Err`. - pub fn next_column(&mut self) -> Result>> { + /// Returns the next column writer, if available, using the factory function; + /// otherwise returns `None`. 
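The footer write above swaps `byteorder` for `i32::to_le_bytes`; the on-disk layout stays the same, a 4-byte little-endian metadata length followed by the 4-byte `PAR1` magic. A minimal sketch of that layout:

```rust
fn footer_bytes(metadata_len: i32) -> [u8; 8] {
    let mut footer = [0u8; 8];
    footer[..4].copy_from_slice(&metadata_len.to_le_bytes());
    footer[4..].copy_from_slice(b"PAR1");
    footer
}

fn main() {
    assert_eq!(footer_bytes(1), [1, 0, 0, 0, b'P', b'A', b'R', b'1']);
}
```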
+ pub(crate) fn next_column_with_factory<'b, F, C>( + &'b mut self, + factory: F, + ) -> Result> + where + F: FnOnce( + ColumnDescPtr, + &'b WriterPropertiesPtr, + Box, + OnCloseColumnChunk<'b>, + ) -> Result, + { self.assert_previous_writer_closed()?; if self.column_index >= self.descr.num_columns() { return Ok(None); } let page_writer = Box::new(SerializedPageWriter::new(self.buf)); - let column_writer = get_column_writer( - self.descr.column(self.column_index), - self.props.clone(), - page_writer, - ); - self.column_index += 1; let total_bytes_written = &mut self.total_bytes_written; let total_rows_written = &mut self.total_rows_written; @@ -390,33 +387,47 @@ impl<'a, W: Write> SerializedRowGroupWriter<'a, W> { let column_indexes = &mut self.column_indexes; let offset_indexes = &mut self.offset_indexes; - let on_close = - |bytes_written, rows_written, metadata, column_index, offset_index| { - // Update row group writer metrics - *total_bytes_written += bytes_written; - column_chunks.push(metadata); - column_indexes.push(column_index); - offset_indexes.push(offset_index); - - if let Some(rows) = *total_rows_written { - if rows != rows_written { - return Err(general_err!( - "Incorrect number of rows, expected {} != {} rows", - rows, - rows_written - )); - } - } else { - *total_rows_written = Some(rows_written); + let on_close = |r: ColumnCloseResult| { + // Update row group writer metrics + *total_bytes_written += r.bytes_written; + column_chunks.push(r.metadata); + column_indexes.push(r.column_index); + offset_indexes.push(r.offset_index); + + if let Some(rows) = *total_rows_written { + if rows != r.rows_written { + return Err(general_err!( + "Incorrect number of rows, expected {} != {} rows", + rows, + r.rows_written + )); } + } else { + *total_rows_written = Some(r.rows_written); + } - Ok(()) - }; + Ok(()) + }; - Ok(Some(SerializedColumnWriter::new( - column_writer, - Some(Box::new(on_close)), - ))) + let column = self.descr.column(self.column_index); + self.column_index += 1; + + Ok(Some(factory( + column, + &self.props, + page_writer, + Box::new(on_close), + )?)) + } + + /// Returns the next column writer, if available; otherwise returns `None`. + /// In case of any IO error or Thrift error, or if row group writer has already been + /// closed returns `Err`. + pub fn next_column(&mut self) -> Result>> { + self.next_column_with_factory(|descr, props, page_writer, on_close| { + let column_writer = get_column_writer(descr, props.clone(), page_writer); + Ok(SerializedColumnWriter::new(column_writer, Some(on_close))) + }) } /// Closes this row group writer and returns row group metadata. 
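An illustrative end-to-end sketch of the writer APIs touched above (`next_column`, the simplified close callback, and the new `into_inner`). The schema is made up and no values are written, so the resulting file holds only metadata:

```rust
use std::sync::Arc;

use parquet::file::properties::WriterProperties;
use parquet::file::writer::SerializedFileWriter;
use parquet::schema::parser::parse_message_type;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(parse_message_type(
        "message schema { REQUIRED INT32 id; }",
    )?);
    let props = Arc::new(WriterProperties::builder().build());

    // Write into an in-memory buffer instead of a file.
    let mut writer = SerializedFileWriter::new(Vec::new(), schema, props)?;

    let mut row_group = writer.next_row_group()?;
    while let Some(col) = row_group.next_column()? {
        // Values would be written through the typed column writer here.
        col.close()?;
    }
    row_group.close()?;

    // `into_inner` writes the footer and hands back the Vec<u8>.
    let buffer = writer.into_inner()?;
    assert!(buffer.len() > 8);
    Ok(())
}
```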
@@ -489,26 +500,19 @@ impl<'a> SerializedColumnWriter<'a> { /// Close this [`SerializedColumnWriter] pub fn close(mut self) -> Result<()> { - let (bytes_written, rows_written, metadata, column_index, offset_index) = - match self.inner { - ColumnWriter::BoolColumnWriter(typed) => typed.close()?, - ColumnWriter::Int32ColumnWriter(typed) => typed.close()?, - ColumnWriter::Int64ColumnWriter(typed) => typed.close()?, - ColumnWriter::Int96ColumnWriter(typed) => typed.close()?, - ColumnWriter::FloatColumnWriter(typed) => typed.close()?, - ColumnWriter::DoubleColumnWriter(typed) => typed.close()?, - ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?, - ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?, - }; + let r = match self.inner { + ColumnWriter::BoolColumnWriter(typed) => typed.close()?, + ColumnWriter::Int32ColumnWriter(typed) => typed.close()?, + ColumnWriter::Int64ColumnWriter(typed) => typed.close()?, + ColumnWriter::Int96ColumnWriter(typed) => typed.close()?, + ColumnWriter::FloatColumnWriter(typed) => typed.close()?, + ColumnWriter::DoubleColumnWriter(typed) => typed.close()?, + ColumnWriter::ByteArrayColumnWriter(typed) => typed.close()?, + ColumnWriter::FixedLenByteArrayColumnWriter(typed) => typed.close()?, + }; if let Some(on_close) = self.on_close.take() { - on_close( - bytes_written, - rows_written, - metadata, - column_index, - offset_index, - )? + on_close(r)? } Ok(()) @@ -648,7 +652,7 @@ mod tests { use super::*; use bytes::Bytes; - use std::{fs::File, io::Cursor}; + use std::fs::File; use crate::basic::{Compression, Encoding, LogicalType, Repetition, Type}; use crate::column::page::PageReader; @@ -660,6 +664,7 @@ mod tests { statistics::{from_thrift, to_thrift, Statistics}, }; use crate::record::RowAccessor; + use crate::schema::types::{ColumnDescriptor, ColumnPath}; use crate::util::memory::ByteBufferPtr; #[test] @@ -1047,11 +1052,25 @@ mod tests { page_writer.close().unwrap(); } { + let reader = bytes::Bytes::from(buffer); + + let t = types::Type::primitive_type_builder("t", physical_type) + .build() + .unwrap(); + + let desc = ColumnDescriptor::new(Arc::new(t), 0, 0, ColumnPath::new(vec![])); + let meta = ColumnChunkMetaData::builder(Arc::new(desc)) + .set_compression(codec) + .set_total_compressed_size(reader.len() as i64) + .set_num_values(total_num_values) + .build() + .unwrap(); + let mut page_reader = SerializedPageReader::new( - Cursor::new(&buffer), - total_num_values, - codec, - physical_type, + Arc::new(reader), + &meta, + total_num_values as usize, + None, ) .unwrap(); diff --git a/parquet/src/lib.rs b/parquet/src/lib.rs index e86b9e65917a..90fe399e78d7 100644 --- a/parquet/src/lib.rs +++ b/parquet/src/lib.rs @@ -19,6 +19,9 @@ //! [Apache Parquet](https://parquet.apache.org/), part of //! the [Apache Arrow](https://arrow.apache.org/) project. //! +//! Please see the [parquet crates.io](https://crates.io/crates/parquet) +//! page for feature flags and tips to improve performance. +//! //! # Getting Started //! Start with some examples: //! @@ -30,43 +33,28 @@ //! //! 3. [arrow::async_reader] for `async` reading and writing parquet //! files to Arrow `RecordBatch`es (requires the `async` feature). 
-#![allow(dead_code)] -#![allow(non_camel_case_types)] -#![allow( - clippy::from_over_into, - clippy::new_without_default, - clippy::or_fun_call, - clippy::too_many_arguments -)] -/// Defines a module with an experimental public API +/// Defines a an item with an experimental public API /// /// The module will not be documented, and will only be public if the /// experimental feature flag is enabled /// -/// Experimental modules have no stability guarantees -macro_rules! experimental_mod { - ($module:ident $(, #[$meta:meta])*) => { - #[cfg(feature = "experimental")] +/// Experimental components have no stability guarantees +#[cfg(feature = "experimental")] +macro_rules! experimental { + ($(#[$meta:meta])* $vis:vis mod $module:ident) => { #[doc(hidden)] $(#[$meta])* pub mod $module; - #[cfg(not(feature = "experimental"))] - $(#[$meta])* - mod $module; - }; + } } -macro_rules! experimental_mod_crate { - ($module:ident $(, #[$meta:meta])*) => { - #[cfg(feature = "experimental")] - #[doc(hidden)] - $(#[$meta])* - pub mod $module; - #[cfg(not(feature = "experimental"))] +#[cfg(not(feature = "experimental"))] +macro_rules! experimental { + ($(#[$meta:meta])* $vis:vis mod $module:ident) => { $(#[$meta])* - pub(crate) mod $module; - }; + $vis mod $module; + } } #[macro_use] @@ -85,12 +73,12 @@ pub use self::encodings::{decoding, encoding}; #[doc(hidden)] pub use self::util::memory; -experimental_mod!(util, #[macro_use]); +experimental!(#[macro_use] mod util); #[cfg(any(feature = "arrow", test))] pub mod arrow; pub mod column; -experimental_mod!(compression); -experimental_mod!(encodings); +experimental!(mod compression); +experimental!(mod encodings); pub mod file; pub mod record; pub mod schema; diff --git a/parquet/src/record/api.rs b/parquet/src/record/api.rs index 0a360fd29648..7e1c484bf881 100644 --- a/parquet/src/record/api.rs +++ b/parquet/src/record/api.rs @@ -27,7 +27,7 @@ use crate::data_type::{ByteArray, Decimal, Int96}; use crate::errors::{ParquetError, Result}; use crate::schema::types::ColumnDescPtr; -#[cfg(any(feature = "cli", test))] +#[cfg(any(feature = "json", test))] use serde_json::Value; /// Macro as a shortcut to generate 'not yet implemented' panic error. @@ -79,7 +79,7 @@ impl Row { } } - #[cfg(any(feature = "cli", test))] + #[cfg(any(feature = "json", test))] pub fn to_json_value(&self) -> Value { Value::Object( self.fields @@ -667,7 +667,7 @@ impl Field { } } - #[cfg(any(feature = "cli", test))] + #[cfg(any(feature = "json", test))] pub fn to_json_value(&self) -> Value { match &self { Field::Null => Value::Null, @@ -1685,7 +1685,6 @@ mod tests { } #[test] - #[cfg(any(feature = "cli", test))] fn test_to_json_value() { assert_eq!(Field::Null.to_json_value(), Value::Null); assert_eq!(Field::Bool(true).to_json_value(), Value::Bool(true)); diff --git a/parquet/src/record/reader.rs b/parquet/src/record/reader.rs index 05b63661f09b..0b7e04587354 100644 --- a/parquet/src/record/reader.rs +++ b/parquet/src/record/reader.rs @@ -40,6 +40,12 @@ pub struct TreeBuilder { batch_size: usize, } +impl Default for TreeBuilder { + fn default() -> Self { + Self::new() + } +} + impl TreeBuilder { /// Creates new tree builder with default parameters. 
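The `cli` to `json` feature rename above gates `to_json_value`. An illustrative sketch of reading rows and printing them as JSON (requires the parquet `json` feature; the path is made up):

```rust
use std::fs::File;

use parquet::file::reader::{FileReader, SerializedFileReader};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("data.parquet")?; // illustrative path
    let reader = SerializedFileReader::new(file)?;

    // Each row renders as a serde_json::Value object keyed by column name.
    for row in reader.get_row_iter(None)? {
        println!("{}", row.to_json_value());
    }
    Ok(())
}
```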
pub fn new() -> Self { @@ -822,7 +828,7 @@ mod tests { use crate::file::reader::{FileReader, SerializedFileReader}; use crate::record::api::{Field, Row, RowAccessor, RowFormatter}; use crate::schema::parser::parse_message_type; - use crate::util::test_common::{get_test_file, get_test_path}; + use crate::util::test_common::file_util::{get_test_file, get_test_path}; use std::convert::TryFrom; // Convenient macros to assemble row, list, map, and group. diff --git a/parquet/src/record/triplet.rs b/parquet/src/record/triplet.rs index de566a122e20..b4b4ea2f4a55 100644 --- a/parquet/src/record/triplet.rs +++ b/parquet/src/record/triplet.rs @@ -151,7 +151,7 @@ impl TripletIter { Field::convert_int64(typed.column_descr(), *typed.current_value()) } TripletIter::Int96TripletIter(ref typed) => { - Field::convert_int96(typed.column_descr(), typed.current_value().clone()) + Field::convert_int96(typed.column_descr(), *typed.current_value()) } TripletIter::FloatTripletIter(ref typed) => { Field::convert_float(typed.column_descr(), *typed.current_value()) @@ -363,7 +363,7 @@ mod tests { use crate::file::reader::{FileReader, SerializedFileReader}; use crate::schema::types::ColumnPath; - use crate::util::test_common::get_test_file; + use crate::util::test_common::file_util::get_test_file; #[test] #[should_panic(expected = "Expected positive batch size, found: 0")] diff --git a/parquet/src/schema/types.rs b/parquet/src/schema/types.rs index 8d624fe3d185..823803167ca1 100644 --- a/parquet/src/schema/types.rs +++ b/parquet/src/schema/types.rs @@ -593,7 +593,7 @@ impl<'a> GroupTypeBuilder<'a> { /// Basic type info. This contains information such as the name of the type, /// the repetition level, the logical type and the kind of the type (group, primitive). -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct BasicTypeInfo { name: String, repetition: Option, diff --git a/parquet/src/util/bit_pack.rs b/parquet/src/util/bit_pack.rs new file mode 100644 index 000000000000..8cea20de2539 --- /dev/null +++ b/parquet/src/util/bit_pack.rs @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Vectorised bit-packing utilities + +/// Macro that generates an unpack function taking the number of bits as a const generic +macro_rules! 
unpack_impl { + ($t:ty, $bytes:literal, $bits:tt) => { + pub fn unpack(input: &[u8], output: &mut [$t; $bits]) { + if NUM_BITS == 0 { + for out in output { + *out = 0; + } + return; + } + + assert!(NUM_BITS <= $bytes * 8); + + let mask = match NUM_BITS { + $bits => <$t>::MAX, + _ => ((1 << NUM_BITS) - 1), + }; + + assert!(input.len() >= NUM_BITS * $bytes); + + let r = |output_idx: usize| { + <$t>::from_le_bytes( + input[output_idx * $bytes..output_idx * $bytes + $bytes] + .try_into() + .unwrap(), + ) + }; + + seq_macro::seq!(i in 0..$bits { + let start_bit = i * NUM_BITS; + let end_bit = start_bit + NUM_BITS; + + let start_bit_offset = start_bit % $bits; + let end_bit_offset = end_bit % $bits; + let start_byte = start_bit / $bits; + let end_byte = end_bit / $bits; + if start_byte != end_byte && end_bit_offset != 0 { + let val = r(start_byte); + let a = val >> start_bit_offset; + let val = r(end_byte); + let b = val << (NUM_BITS - end_bit_offset); + + output[i] = a | (b & mask); + } else { + let val = r(start_byte); + output[i] = (val >> start_bit_offset) & mask; + } + }); + } + }; +} + +/// Macro that generates unpack functions that accept num_bits as a parameter +macro_rules! unpack { + ($name:ident, $t:ty, $bytes:literal, $bits:tt) => { + mod $name { + unpack_impl!($t, $bytes, $bits); + } + + /// Unpack packed `input` into `output` with a bit width of `num_bits` + pub fn $name(input: &[u8], output: &mut [$t; $bits], num_bits: usize) { + // This will get optimised into a jump table + seq_macro::seq!(i in 0..=$bits { + if i == num_bits { + return $name::unpack::(input, output); + } + }); + unreachable!("invalid num_bits {}", num_bits); + } + }; +} + +unpack!(unpack8, u8, 1, 8); +unpack!(unpack16, u16, 2, 16); +unpack!(unpack32, u32, 4, 32); +unpack!(unpack64, u64, 8, 64); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic() { + let input = [0xFF; 4096]; + + for i in 0..=8 { + let mut output = [0; 8]; + unpack8(&input, &mut output, i); + for (idx, out) in output.iter().enumerate() { + assert_eq!(out.trailing_ones() as usize, i, "out[{}] = {}", idx, out); + } + } + + for i in 0..=16 { + let mut output = [0; 16]; + unpack16(&input, &mut output, i); + for (idx, out) in output.iter().enumerate() { + assert_eq!(out.trailing_ones() as usize, i, "out[{}] = {}", idx, out); + } + } + + for i in 0..=32 { + let mut output = [0; 32]; + unpack32(&input, &mut output, i); + for (idx, out) in output.iter().enumerate() { + assert_eq!(out.trailing_ones() as usize, i, "out[{}] = {}", idx, out); + } + } + + for i in 0..=64 { + let mut output = [0; 64]; + unpack64(&input, &mut output, i); + for (idx, out) in output.iter().enumerate() { + assert_eq!(out.trailing_ones() as usize, i, "out[{}] = {}", idx, out); + } + } + } +} diff --git a/parquet/src/util/bit_packing.rs b/parquet/src/util/bit_packing.rs deleted file mode 100644 index 758992ab2723..000000000000 --- a/parquet/src/util/bit_packing.rs +++ /dev/null @@ -1,3662 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
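To make the jump-table dispatch above concrete, here is an illustrative check of `unpack8` with a bit width of 3, written as it might appear inside the module's own tests since the API is crate-internal; the packed bytes are hand-computed:

```rust
#[test]
fn unpack8_three_bit_values() {
    // The values 0..=7, each packed LSB-first into 3 bits, occupy 3 bytes.
    let input = [0b1000_1000u8, 0b1100_0110, 0b1111_1010];
    let mut output = [0u8; 8];

    unpack8(&input, &mut output, 3);

    assert_eq!(output, [0, 1, 2, 3, 4, 5, 6, 7]);
}
```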
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -/// Unpack 32 values with bit width `num_bits` from `in_ptr`, and write to `out_ptr`. -/// Return the `in_ptr` where the starting offset points to the first byte after all the -/// bytes that were consumed. -// TODO: may be better to make these more compact using if-else conditions. -// However, this may require const generics: -// https://github.com/rust-lang/rust/issues/44580 -// to eliminate the branching cost. -// TODO: we should use SIMD instructions to further optimize this. I have explored -// https://github.com/tantivy-search/bitpacking -// but the layout it uses for SIMD is different from Parquet. -// TODO: support packing as well, which is used for encoding. -pub unsafe fn unpack32( - mut in_ptr: *const u32, - out_ptr: *mut u32, - num_bits: usize, -) -> *const u32 { - in_ptr = match num_bits { - 0 => nullunpacker32(in_ptr, out_ptr), - 1 => unpack1_32(in_ptr, out_ptr), - 2 => unpack2_32(in_ptr, out_ptr), - 3 => unpack3_32(in_ptr, out_ptr), - 4 => unpack4_32(in_ptr, out_ptr), - 5 => unpack5_32(in_ptr, out_ptr), - 6 => unpack6_32(in_ptr, out_ptr), - 7 => unpack7_32(in_ptr, out_ptr), - 8 => unpack8_32(in_ptr, out_ptr), - 9 => unpack9_32(in_ptr, out_ptr), - 10 => unpack10_32(in_ptr, out_ptr), - 11 => unpack11_32(in_ptr, out_ptr), - 12 => unpack12_32(in_ptr, out_ptr), - 13 => unpack13_32(in_ptr, out_ptr), - 14 => unpack14_32(in_ptr, out_ptr), - 15 => unpack15_32(in_ptr, out_ptr), - 16 => unpack16_32(in_ptr, out_ptr), - 17 => unpack17_32(in_ptr, out_ptr), - 18 => unpack18_32(in_ptr, out_ptr), - 19 => unpack19_32(in_ptr, out_ptr), - 20 => unpack20_32(in_ptr, out_ptr), - 21 => unpack21_32(in_ptr, out_ptr), - 22 => unpack22_32(in_ptr, out_ptr), - 23 => unpack23_32(in_ptr, out_ptr), - 24 => unpack24_32(in_ptr, out_ptr), - 25 => unpack25_32(in_ptr, out_ptr), - 26 => unpack26_32(in_ptr, out_ptr), - 27 => unpack27_32(in_ptr, out_ptr), - 28 => unpack28_32(in_ptr, out_ptr), - 29 => unpack29_32(in_ptr, out_ptr), - 30 => unpack30_32(in_ptr, out_ptr), - 31 => unpack31_32(in_ptr, out_ptr), - 32 => unpack32_32(in_ptr, out_ptr), - _ => unimplemented!(), - }; - in_ptr -} - -unsafe fn nullunpacker32(in_buf: *const u32, mut out: *mut u32) -> *const u32 { - for _ in 0..32 { - *out = 0; - out = out.offset(1); - } - in_buf -} - -unsafe fn unpack1_32(in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 1) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 2) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 3) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 4) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 5) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 6) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 7) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 9) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) 
>> 11) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 13) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 15) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 17) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 19) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 21) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 23) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 25) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 26) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 27) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 29) & 1; - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 30) & 1; - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - - in_buf.offset(1) -} - -unsafe fn unpack2_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 26) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 2); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - out = out.offset(1); - in_buf = in_buf.offset(1); - *out = (in_buf.read_unaligned()) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 2); - out = out.offset(1); - *out 
= ((in_buf.read_unaligned()) >> 18) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 26) % (1u32 << 2); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 2); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - - in_buf.offset(1) -} - -unsafe fn unpack3_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 21) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 27) % (1u32 << 3); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (3 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 25) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 3); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (3 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 23) % (1u32 << 3); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 26) % (1u32 << 3); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - - in_buf.offset(1) -} - -unsafe fn unpack4_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 
4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 4); - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 4); - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 4); - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 4); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 4); - - in_buf.offset(1) -} - -unsafe fn unpack5_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 25) % (1u32 << 5); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (5 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 5); - out = 
out.offset(1); - *out = ((in_buf.read_unaligned()) >> 23) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 28) % (1u32 << 5); - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (5 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 21) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 26) % (1u32 << 5); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (5 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 5); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (5 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 5); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 5); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - - in_buf.offset(1) -} - -unsafe fn unpack6_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 6); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (6 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 6); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (6 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 6); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 6); - out = 
out.offset(1); - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 6); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (6 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 6); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (6 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 6); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 6); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - - in_buf.offset(1) -} - -unsafe fn unpack7_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 21) % (1u32 << 7); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (7 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 24) % (1u32 << 7); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (7 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 7); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (7 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 23) % (1u32 << 7); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (7 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 7); - out = out.offset(1); - *out = 
(in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (7 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 7); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (7 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 7); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 7); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - - in_buf.offset(1) -} - -unsafe fn unpack8_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 8); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 8); - out = 
out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - - in_buf.offset(1) -} - -unsafe fn unpack9_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (9 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 22) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (9 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (9 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 21) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (9 - 7); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (9 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (9 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (9 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (9 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 9); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 9); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 23; - - in_buf.offset(1) -} - -unsafe fn unpack10_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 
10) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (10 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (10 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (10 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (10 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (10 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (10 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (10 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (10 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 10); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 10); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - - in_buf.offset(1) -} - -unsafe fn unpack11_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (11 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) 
>> 1) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (11 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (11 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (11 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (11 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (11 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (11 - 7); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (11 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 19) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (11 - 9); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 11); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 20) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (11 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 11); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 21; - - in_buf.offset(1) -} - -unsafe fn unpack12_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 12); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (12 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 12); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) 
% (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (12 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 12); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (12 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 12); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (12 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 12); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (12 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 12); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (12 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 12); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (12 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 12); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (12 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 12); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - - in_buf.offset(1) -} - -unsafe fn unpack13_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 13); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (13 - 7); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (13 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 13); - out = out.offset(1); - *out = 
((in_buf.read_unaligned()) >> 14) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (13 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (13 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 13); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (13 - 9); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (13 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 13); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (13 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (13 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 13); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 17) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (13 - 11); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (13 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 13); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 18) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (13 - 12); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (13 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 13); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 19; - - in_buf.offset(1) -} - -unsafe fn unpack14_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 14); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (14 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (14 - 6); - out = out.offset(1); - - *out = 
((in_buf.read_unaligned()) >> 6) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (14 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 14); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (14 - 12); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (14 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (14 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 14); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (14 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (14 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (14 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 14); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (14 - 12); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (14 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (14 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 14); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 18; - - in_buf.offset(1) -} - -unsafe fn unpack15_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 15); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 15) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (15 - 13); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= 
((in_buf.read_unaligned()) % (1u32 << 11)) << (15 - 11); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (15 - 9); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (15 - 7); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (15 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (15 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (15 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 15); - out = out.offset(1); - *out = ((in_buf.read_unaligned()) >> 16) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (15 - 14); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (15 - 12); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (15 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (15 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (15 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (15 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (15 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 15); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 17; - - in_buf.offset(1) -} - -unsafe fn unpack16_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - 
*out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - out = out.offset(1); - in_buf = in_buf.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 16); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 16; - - in_buf.offset(1) -} - -unsafe fn unpack17_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 17; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (17 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (17 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (17 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (17 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 17); - out = 
out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (17 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (17 - 12); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (17 - 14); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 14) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (17 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (17 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (17 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (17 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (17 - 7); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (17 - 9); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (17 - 11); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (17 - 13); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 13) % (1u32 << 17); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (17 - 15); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 15; - - in_buf.offset(1) -} - -unsafe fn unpack18_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (18 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (18 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); 
- *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (18 - 12); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (18 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (18 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (18 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (18 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (18 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (18 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (18 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (18 - 12); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (18 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (18 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (18 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (18 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 18); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (18 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - - in_buf.offset(1) -} - -unsafe fn unpack19_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (19 - 6); - out = 
out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (19 - 12); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 12) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (19 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (19 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (19 - 11); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 11) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (19 - 17); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 17; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (19 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (19 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (19 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (19 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (19 - 9); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (19 - 15); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 15; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (19 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (19 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (19 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (19 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (19 - 7); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 
<< 19); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (19 - 13); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 13; - - in_buf.offset(1) -} - -unsafe fn unpack20_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (20 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (20 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (20 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (20 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (20 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (20 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (20 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (20 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (20 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (20 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (20 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (20 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (20 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) 
>> 8) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (20 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (20 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 20); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (20 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - - in_buf.offset(1) -} - -unsafe fn unpack21_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (21 - 10); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 10) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (21 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (21 - 9); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 9) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (21 - 19); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (21 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (21 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (21 - 7); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (21 - 17); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 17; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (21 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (21 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (21 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (21 - 15); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 15; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (21 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; 
- in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (21 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (21 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (21 - 13); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 13; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (21 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (21 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (21 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 21); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (21 - 11); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 11; - - in_buf.offset(1) -} - -unsafe fn unpack22_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (22 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (22 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (22 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (22 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (22 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (22 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (22 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (22 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (22 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (22 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - 
in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (22 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (22 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (22 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (22 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (22 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (22 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (22 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (22 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 22); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (22 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (22 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - - in_buf.offset(1) -} - -unsafe fn unpack23_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (23 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (23 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (23 - 19); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (23 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (23 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (23 - 15); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 15; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) 
<< (23 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (23 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (23 - 11); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 11; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (23 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (23 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (23 - 7); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 7) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (23 - 21); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (23 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (23 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (23 - 17); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 17; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (23 - 8); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 8) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (23 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (23 - 13); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 13; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (23 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 23); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (23 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (23 - 9); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 9; - - in_buf.offset(1) -} - -unsafe fn unpack24_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 24); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf 
= in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 24); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 24); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 24); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 24); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 24); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 24); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 24); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (24 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (24 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - - in_buf.offset(1) -} - -unsafe fn unpack25_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - 
*out = (in_buf.read_unaligned()) % (1u32 << 25); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (25 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (25 - 11); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 11; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (25 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 25); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (25 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (25 - 15); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 15; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (25 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (25 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 25); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (25 - 19); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (25 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (25 - 5); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 5) % (1u32 << 25); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 23)) << (25 - 23); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (25 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (25 - 9); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 9; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (25 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 25); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (25 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (25 - 13); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 13; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (25 - 6); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 6) % (1u32 << 25); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (25 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (25 - 17); - out = 
out.offset(1); - - *out = (in_buf.read_unaligned()) >> 17; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (25 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (25 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 25); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (25 - 21); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (25 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (25 - 7); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 7; - - in_buf.offset(1) -} - -unsafe fn unpack26_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 26); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (26 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (26 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (26 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (26 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 26); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (26 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (26 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (26 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (26 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 26); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (26 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (26 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (26 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (26 - 6); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 6; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 26); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (26 - 20); - out 
= out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (26 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (26 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (26 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 26); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (26 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (26 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (26 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (26 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 26); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (26 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (26 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (26 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (26 - 6); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 6; - - in_buf.offset(1) -} - -unsafe fn unpack27_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 27); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (27 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (27 - 17); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 17; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (27 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (27 - 7); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 7; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (27 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 27); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (27 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (27 - 19); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (27 - 14); - out = 
out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (27 - 9); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 9; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (27 - 4); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 4) % (1u32 << 27); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (27 - 26); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (27 - 21); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (27 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (27 - 11); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 11; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (27 - 6); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 6; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (27 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 27); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 23)) << (27 - 23); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (27 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (27 - 13); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 13; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (27 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (27 - 3); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 3) % (1u32 << 27); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 25)) << (27 - 25); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (27 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (27 - 15); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 15; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (27 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (27 - 5); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 5; - - in_buf.offset(1) -} - -unsafe fn unpack28_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 28); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (28 - 24); - out = 
out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (28 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (28 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (28 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (28 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (28 - 4); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 4; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 28); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (28 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (28 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (28 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (28 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (28 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (28 - 4); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 4; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 28); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (28 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (28 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (28 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (28 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (28 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (28 - 4); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 4; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 28); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (28 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % 
(1u32 << 20)) << (28 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (28 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (28 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (28 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (28 - 4); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 4; - - in_buf.offset(1) -} - -unsafe fn unpack29_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 29); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (29 - 26); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 23)) << (29 - 23); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (29 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (29 - 17); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 17; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (29 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (29 - 11); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 11; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (29 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (29 - 5); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 5; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (29 - 2); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 2) % (1u32 << 29); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 28)) << (29 - 28); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 25)) << (29 - 25); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (29 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (29 - 19); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (29 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (29 - 13); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 13; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % 
(1u32 << 10)) << (29 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (29 - 7); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 7; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (29 - 4); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 4; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (29 - 1); - out = out.offset(1); - - *out = ((in_buf.read_unaligned()) >> 1) % (1u32 << 29); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 27)) << (29 - 27); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (29 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (29 - 21); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (29 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (29 - 15); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 15; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (29 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (29 - 9); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 9; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (29 - 6); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 6; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (29 - 3); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 3; - - in_buf.offset(1) -} - -unsafe fn unpack30_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 30); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 28)) << (30 - 28); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (30 - 26); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (30 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (30 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (30 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (30 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (30 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 
<< 14)) << (30 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (30 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (30 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (30 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (30 - 6); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 6; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (30 - 4); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 4; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (30 - 2); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 2; - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) % (1u32 << 30); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 28)) << (30 - 28); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (30 - 26); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (30 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (30 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (30 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (30 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (30 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (30 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (30 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (30 - 10); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (30 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (30 - 6); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 6; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (30 - 4); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 4; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (30 - 2); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 2; - - in_buf.offset(1) -} - -unsafe fn unpack31_32(mut 
in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = (in_buf.read_unaligned()) % (1u32 << 31); - out = out.offset(1); - *out = (in_buf.read_unaligned()) >> 31; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 30)) << (31 - 30); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 30; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 29)) << (31 - 29); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 29; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 28)) << (31 - 28); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 28; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 27)) << (31 - 27); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 27; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 26)) << (31 - 26); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 26; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 25)) << (31 - 25); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 25; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 24)) << (31 - 24); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 24; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 23)) << (31 - 23); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 23; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 22)) << (31 - 22); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 22; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 21)) << (31 - 21); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 21; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 20)) << (31 - 20); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 20; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 19)) << (31 - 19); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 19; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 18)) << (31 - 18); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 18; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 17)) << (31 - 17); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 17; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 16)) << (31 - 16); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 16; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 15)) << (31 - 15); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 15; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 14)) << (31 - 14); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 14; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 13)) << (31 - 13); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 13; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 12)) << (31 - 12); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 12; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 11)) << (31 - 11); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 11; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 10)) << (31 - 10); - out = out.offset(1); - - 
*out = (in_buf.read_unaligned()) >> 10; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 9)) << (31 - 9); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 9; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 8)) << (31 - 8); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 8; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 7)) << (31 - 7); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 7; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 6)) << (31 - 6); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 6; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 5)) << (31 - 5); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 5; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 4)) << (31 - 4); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 4; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 3)) << (31 - 3); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 3; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 2)) << (31 - 2); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 2; - in_buf = in_buf.offset(1); - *out |= ((in_buf.read_unaligned()) % (1u32 << 1)) << (31 - 1); - out = out.offset(1); - - *out = (in_buf.read_unaligned()) >> 1; - - in_buf.offset(1) -} - -unsafe fn unpack32_32(mut in_buf: *const u32, mut out: *mut u32) -> *const u32 { - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = 
in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - in_buf = in_buf.offset(1); - out = out.offset(1); - - *out = in_buf.read_unaligned(); - - in_buf.offset(1) -} diff --git a/parquet/src/util/bit_util.rs b/parquet/src/util/bit_util.rs index c4c1f96f9f4c..68b2f2b2550d 100644 --- a/parquet/src/util/bit_util.rs +++ b/parquet/src/util/bit_util.rs @@ -18,8 +18,8 @@ use std::{cmp, mem::size_of}; use crate::data_type::AsBytes; -use crate::errors::{ParquetError, Result}; -use crate::util::{bit_packing::unpack32, memory::ByteBufferPtr}; +use crate::util::bit_pack::{unpack16, unpack32, unpack64, unpack8}; +use crate::util::memory::ByteBufferPtr; #[inline] pub fn from_ne_slice(bs: &[u8]) -> T { @@ -88,49 +88,17 @@ impl FromBytes for bool { from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64, f32, f64 } -/// Reads `$size` of bytes from `$src`, and reinterprets them as type `$ty`, in -/// little-endian order. `$ty` must implement the `Default` trait. Otherwise this won't -/// compile. +/// Reads `size` of bytes from `src`, and reinterprets them as type `ty`, in +/// little-endian order. /// This is copied and modified from byteorder crate. -macro_rules! read_num_bytes { - ($ty:ty, $size:expr, $src:expr) => {{ - assert!($size <= $src.len()); - let mut buffer = <$ty as $crate::util::bit_util::FromBytes>::Buffer::default(); - buffer.as_mut()[..$size].copy_from_slice(&$src[..$size]); - <$ty>::from_ne_bytes(buffer) - }}; -} - -/// Converts value `val` of type `T` to a byte vector, by reading `num_bytes` from `val`. -/// NOTE: if `val` is less than the size of `T` then it can be truncated. -#[inline] -pub fn convert_to_bytes(val: &T, num_bytes: usize) -> Vec +pub(crate) fn read_num_bytes(size: usize, src: &[u8]) -> T where - T: ?Sized + AsBytes, + T: FromBytes, { - let mut bytes: Vec = vec![0; num_bytes]; - memcpy_value(val.as_bytes(), num_bytes, &mut bytes); - bytes -} - -#[inline] -pub fn memcpy(source: &[u8], target: &mut [u8]) { - assert!(target.len() >= source.len()); - target[..source.len()].copy_from_slice(source) -} - -#[inline] -pub fn memcpy_value(source: &T, num_bytes: usize, target: &mut [u8]) -where - T: ?Sized + AsBytes, -{ - assert!( - target.len() >= num_bytes, - "Not enough space. Only had {} bytes but need to put {} bytes", - target.len(), - num_bytes - ); - memcpy(&source.as_bytes()[..num_bytes], target) + assert!(size <= src.len()); + let mut buffer = ::Buffer::default(); + buffer.as_mut()[..size].copy_from_slice(&src[..size]); + ::from_ne_bytes(buffer) } /// Returns the ceil of value/divisor. @@ -138,7 +106,7 @@ where /// This function should be removed after /// [`int_roundings`](https://github.com/rust-lang/rust/issues/88581) is stable. 
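The hand-unrolled `unpackN_32` functions removed above all implement the same LSB-first bit unpacking: each block of 32 values is read low-bits-first out of consecutive little-endian words, and a value that straddles a word boundary is reassembled from the high bits of one word and the low bits of the next. As a rough, self-contained illustration of that technique (not the replacement code in `util::bit_pack`), a generic scalar unpacker could look like this:

```rust
/// Illustrative scalar unpacker, assuming LSB-first packing into a
/// little-endian byte stream; the deleted unpackN_32 functions are manually
/// unrolled specializations of this loop, one per bit width 1..=32.
fn unpack_scalar(input: &[u8], num_bits: usize, out: &mut [u32]) {
    assert!(num_bits >= 1 && num_bits <= 32);
    let mask = if num_bits == 32 {
        u32::MAX
    } else {
        (1u32 << num_bits) - 1
    };
    let mut bit_pos = 0usize; // absolute bit offset into `input`
    for v in out.iter_mut() {
        let byte = bit_pos / 8;
        // Load 8 bytes (zero-padded near the end) so a value that straddles
        // a word boundary is always covered by a single 64-bit read.
        let mut buf = [0u8; 8];
        let n = input.len().saturating_sub(byte).min(8);
        buf[..n].copy_from_slice(&input[byte..byte + n]);
        let word = u64::from_le_bytes(buf);
        *v = ((word >> (bit_pos % 8)) as u32) & mask;
        bit_pos += num_bits;
    }
}
```

Working a full `u64` at a time keeps the inner loop branch-free regardless of where a value falls within a word; the unrolled versions go further and resolve the straddle points statically for each width.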
#[inline] -pub fn ceil(value: i64, divisor: i64) -> i64 { +pub fn ceil(value: T, divisor: T) -> T { num::Integer::div_ceil(&value, &divisor) } @@ -148,20 +116,10 @@ pub fn trailing_bits(v: u64, num_bits: usize) -> u64 { if num_bits >= 64 { v } else { - v & ((1< u8 { @@ -180,59 +138,32 @@ pub fn get_bit(data: &[u8], i: usize) -> bool { /// bit packed or byte aligned fashion. pub struct BitWriter { buffer: Vec, - max_bytes: usize, buffered_values: u64, - byte_offset: usize, - bit_offset: usize, - start: usize, + bit_offset: u8, } impl BitWriter { pub fn new(max_bytes: usize) -> Self { Self { - buffer: vec![0; max_bytes], - max_bytes, + buffer: Vec::with_capacity(max_bytes), buffered_values: 0, - byte_offset: 0, bit_offset: 0, - start: 0, } } - /// Initializes the writer from the existing buffer `buffer` and starting - /// offset `start`. - pub fn new_from_buf(buffer: Vec, start: usize) -> Self { - assert!(start < buffer.len()); - let len = buffer.len(); + /// Initializes the writer appending to the existing buffer `buffer` + pub fn new_from_buf(buffer: Vec) -> Self { Self { buffer, - max_bytes: len, buffered_values: 0, - byte_offset: start, bit_offset: 0, - start, } } - /// Extend buffer size by `increment` bytes - #[inline] - pub fn extend(&mut self, increment: usize) { - self.max_bytes += increment; - let extra = vec![0; increment]; - self.buffer.extend(extra); - } - - /// Report buffer size, in bytes - #[inline] - pub fn capacity(&mut self) -> usize { - self.max_bytes - } - /// Consumes and returns the current buffer. #[inline] pub fn consume(mut self) -> Vec { self.flush(); - self.buffer.truncate(self.byte_offset); self.buffer } @@ -241,53 +172,37 @@ impl BitWriter { #[inline] pub fn flush_buffer(&mut self) -> &[u8] { self.flush(); - &self.buffer()[0..self.byte_offset] + self.buffer() } /// Clears the internal state so the buffer can be reused. #[inline] pub fn clear(&mut self) { + self.buffer.clear(); self.buffered_values = 0; - self.byte_offset = self.start; self.bit_offset = 0; } /// Flushes the internal buffered bits and the align the buffer to the next byte. #[inline] pub fn flush(&mut self) { - let num_bytes = ceil(self.bit_offset as i64, 8) as usize; - assert!(self.byte_offset + num_bytes <= self.max_bytes); - memcpy_value( - &self.buffered_values, - num_bytes, - &mut self.buffer[self.byte_offset..], - ); + let num_bytes = ceil(self.bit_offset, 8); + let slice = &self.buffered_values.to_le_bytes()[..num_bytes as usize]; + self.buffer.extend_from_slice(slice); self.buffered_values = 0; self.bit_offset = 0; - self.byte_offset += num_bytes; } /// Advances the current offset by skipping `num_bytes`, flushing the internal bit /// buffer first. /// This is useful when you want to jump over `num_bytes` bytes and come back later /// to fill these bytes. - /// - /// Returns error if `num_bytes` is beyond the boundary of the internal buffer. - /// Otherwise, returns the old offset. #[inline] - pub fn skip(&mut self, num_bytes: usize) -> Result { + pub fn skip(&mut self, num_bytes: usize) -> usize { self.flush(); - assert!(self.byte_offset <= self.max_bytes); - if self.byte_offset + num_bytes > self.max_bytes { - return Err(general_err!( - "Not enough bytes left in BitWriter. 
Need {} but only have {}", - self.byte_offset + num_bytes, - self.max_bytes - )); - } - let result = self.byte_offset; - self.byte_offset += num_bytes; - Ok(result) + let result = self.buffer.len(); + self.buffer.extend(std::iter::repeat(0).take(num_bytes)); + result } /// Returns a slice containing the next `num_bytes` bytes starting from the current @@ -295,32 +210,24 @@ impl BitWriter { /// This is useful when you want to jump over `num_bytes` bytes and come back later /// to fill these bytes. #[inline] - pub fn get_next_byte_ptr(&mut self, num_bytes: usize) -> Result<&mut [u8]> { - let offset = self.skip(num_bytes)?; - Ok(&mut self.buffer[offset..offset + num_bytes]) + pub fn get_next_byte_ptr(&mut self, num_bytes: usize) -> &mut [u8] { + let offset = self.skip(num_bytes); + &mut self.buffer[offset..offset + num_bytes] } #[inline] pub fn bytes_written(&self) -> usize { - self.byte_offset - self.start + ceil(self.bit_offset as i64, 8) as usize + self.buffer.len() + ceil(self.bit_offset, 8) as usize } #[inline] pub fn buffer(&self) -> &[u8] { - &self.buffer[self.start..] + &self.buffer } #[inline] pub fn byte_offset(&self) -> usize { - self.byte_offset - } - - /// Returns the internal buffer length. This is the maximum number of bytes that this - /// writer can write. User needs to call `consume` to consume the current buffer - /// before more data can be written. - #[inline] - pub fn buffer_len(&self) -> usize { - self.max_bytes + self.buffer.len() } /// Writes the entire byte `value` at the byte `offset` @@ -330,53 +237,36 @@ impl BitWriter { /// Writes the `num_bits` LSB of value `v` to the internal buffer of this writer. /// The `num_bits` must not be greater than 64. This is bit packed. - /// - /// Returns false if there's not enough room left. True otherwise. #[inline] - pub fn put_value(&mut self, v: u64, num_bits: usize) -> bool { + pub fn put_value(&mut self, v: u64, num_bits: usize) { assert!(num_bits <= 64); + let num_bits = num_bits as u8; assert_eq!(v.checked_shr(num_bits as u32).unwrap_or(0), 0); // covers case v >> 64 - if self.byte_offset * 8 + self.bit_offset + num_bits > self.max_bytes as usize * 8 - { - return false; - } - + // Add value to buffered_values self.buffered_values |= v << self.bit_offset; self.bit_offset += num_bits; - if self.bit_offset >= 64 { - memcpy_value( - &self.buffered_values, - 8, - &mut self.buffer[self.byte_offset..], - ); - self.byte_offset += 8; - self.bit_offset -= 64; - self.buffered_values = 0; + if let Some(remaining) = self.bit_offset.checked_sub(64) { + self.buffer + .extend_from_slice(&self.buffered_values.to_le_bytes()); + self.bit_offset = remaining; + // Perform checked right shift: v >> offset, where offset < 64, otherwise we // shift all bits self.buffered_values = v .checked_shr((num_bits - self.bit_offset) as u32) .unwrap_or(0); } - assert!(self.bit_offset < 64); - true } /// Writes `val` of `num_bytes` bytes to the next aligned byte. If size of `T` is /// larger than `num_bytes`, extra higher ordered bytes will be ignored. - /// - /// Returns false if there's not enough room left. True otherwise. #[inline] - pub fn put_aligned(&mut self, val: T, num_bytes: usize) -> bool { - let result = self.get_next_byte_ptr(num_bytes); - if result.is_err() { - // TODO: should we return `Result` for this func? 
- return false; - } - let ptr = result.unwrap(); - memcpy_value(&val, num_bytes, ptr); - true + pub fn put_aligned(&mut self, val: T, num_bytes: usize) { + self.flush(); + let slice = val.as_bytes(); + let len = num_bytes.min(slice.len()); + self.buffer.extend_from_slice(&slice[..len]); } /// Writes `val` of `num_bytes` bytes at the designated `offset`. The `offset` is the @@ -384,49 +274,34 @@ impl BitWriter { /// maintains. Note that this will overwrite any existing data between `offset` and /// `offset + num_bytes`. Also that if size of `T` is larger than `num_bytes`, extra /// higher ordered bytes will be ignored. - /// - /// Returns false if there's not enough room left, or the `pos` is not valid. - /// True otherwise. #[inline] pub fn put_aligned_offset( &mut self, val: T, num_bytes: usize, offset: usize, - ) -> bool { - if num_bytes + offset > self.max_bytes { - return false; - } - memcpy_value( - &val, - num_bytes, - &mut self.buffer[offset..offset + num_bytes], - ); - true + ) { + let slice = val.as_bytes(); + let len = num_bytes.min(slice.len()); + self.buffer[offset..offset + len].copy_from_slice(&slice[..len]) } /// Writes a VLQ encoded integer `v` to this buffer. The value is byte aligned. - /// - /// Returns false if there's not enough room left. True otherwise. #[inline] - pub fn put_vlq_int(&mut self, mut v: u64) -> bool { - let mut result = true; + pub fn put_vlq_int(&mut self, mut v: u64) { while v & 0xFFFFFFFFFFFFFF80 != 0 { - result &= self.put_aligned::(((v & 0x7F) | 0x80) as u8, 1); + self.put_aligned::(((v & 0x7F) | 0x80) as u8, 1); v >>= 7; } - result &= self.put_aligned::((v & 0x7F) as u8, 1); - result + self.put_aligned::((v & 0x7F) as u8, 1); } /// Writes a zigzag-VLQ encoded (in little endian order) int `v` to this buffer. /// Zigzag-VLQ is a variant of VLQ encoding where negative and positive /// numbers are encoded in a zigzag fashion. /// See: https://developers.google.com/protocol-buffers/docs/encoding - /// - /// Returns false if there's not enough room left. True otherwise. #[inline] - pub fn put_zigzag_vlq_int(&mut self, v: i64) -> bool { + pub fn put_zigzag_vlq_int(&mut self, v: i64) { let u: u64 = ((v << 1) ^ (v >> 63)) as u64; self.put_vlq_int(u) } @@ -437,50 +312,43 @@ impl BitWriter { pub const MAX_VLQ_BYTE_LEN: usize = 10; pub struct BitReader { - // The byte buffer to read from, passed in by client + /// The byte buffer to read from, passed in by client buffer: ByteBufferPtr, - // Bytes are memcpy'd from `buffer` and values are read from this variable. - // This is faster than reading values byte by byte directly from `buffer` + /// Bytes are memcpy'd from `buffer` and values are read from this variable. + /// This is faster than reading values byte by byte directly from `buffer` + /// + /// This is only populated when `self.bit_offset != 0` buffered_values: u64, - // - // End Start - // |............|B|B|B|B|B|B|B|B|..............| - // ^ ^ - // bit_offset byte_offset - // - // Current byte offset in `buffer` + /// + /// End Start + /// |............|B|B|B|B|B|B|B|B|..............| + /// ^ ^ + /// bit_offset byte_offset + /// + /// Current byte offset in `buffer` byte_offset: usize, - // Current bit offset in `buffered_values` + /// Current bit offset in `buffered_values` bit_offset: usize, - - // Total number of bytes in `buffer` - total_bytes: usize, } /// Utility class to read bit/byte stream. This class can read bits or bytes that are /// either byte aligned or not. 
impl BitReader { pub fn new(buffer: ByteBufferPtr) -> Self { - let total_bytes = buffer.len(); - let num_bytes = cmp::min(8, total_bytes); - let buffered_values = read_num_bytes!(u64, num_bytes, buffer.as_ref()); BitReader { buffer, - buffered_values, + buffered_values: 0, byte_offset: 0, bit_offset: 0, - total_bytes, } } pub fn reset(&mut self, buffer: ByteBufferPtr) { self.buffer = buffer; - self.total_bytes = self.buffer.len(); - let num_bytes = cmp::min(8, self.total_bytes); - self.buffered_values = read_num_bytes!(u64, num_bytes, self.buffer.as_ref()); + self.buffered_values = 0; self.byte_offset = 0; self.bit_offset = 0; } @@ -488,7 +356,7 @@ impl BitReader { /// Gets the current byte offset #[inline] pub fn get_byte_offset(&self) -> usize { - self.byte_offset + ceil(self.bit_offset as i64, 8) as usize + self.byte_offset + ceil(self.bit_offset, 8) } /// Reads a value of type `T` and of size `num_bits`. @@ -498,10 +366,16 @@ impl BitReader { assert!(num_bits <= 64); assert!(num_bits <= size_of::() * 8); - if self.byte_offset * 8 + self.bit_offset + num_bits > self.total_bytes * 8 { + if self.byte_offset * 8 + self.bit_offset + num_bits > self.buffer.len() * 8 { return None; } + // If buffer is not byte aligned, `self.buffered_values` will + // have already been populated + if self.bit_offset == 0 { + self.load_buffered_values() + } + let mut v = trailing_bits(self.buffered_values, self.bit_offset + num_bits) >> self.bit_offset; self.bit_offset += num_bits; @@ -510,45 +384,40 @@ impl BitReader { self.byte_offset += 8; self.bit_offset -= 64; - self.reload_buffer_values(); - v |= trailing_bits(self.buffered_values, self.bit_offset) - .wrapping_shl((num_bits - self.bit_offset) as u32); + // If the new bit_offset is not 0, we need to read the next 64-bit chunk + // to buffered_values and update `v` + if self.bit_offset != 0 { + self.load_buffered_values(); + + v |= trailing_bits(self.buffered_values, self.bit_offset) + .wrapping_shl((num_bits - self.bit_offset) as u32); + } } // TODO: better to avoid copying here Some(from_ne_slice(v.as_bytes())) } - /// Read multiple values from their packed representation + /// Read multiple values from their packed representation where each element is represented + /// by `num_bits` bits. /// /// # Panics /// /// This function panics if - /// - `bit_width` is larger than the bit-capacity of `T` + /// - `num_bits` is larger than the bit-capacity of `T` /// pub fn get_batch(&mut self, batch: &mut [T], num_bits: usize) -> usize { assert!(num_bits <= size_of::() * 8); let mut values_to_read = batch.len(); let needed_bits = num_bits * values_to_read; - let remaining_bits = (self.total_bytes - self.byte_offset) * 8 - self.bit_offset; + let remaining_bits = (self.buffer.len() - self.byte_offset) * 8 - self.bit_offset; if remaining_bits < needed_bits { values_to_read = remaining_bits / num_bits; } let mut i = 0; - if num_bits > 32 { - // No fast path - read values individually - while i < values_to_read { - batch[i] = self - .get_value(num_bits) - .expect("expected to have more data"); - i += 1; - } - return values_to_read - } - // First align bit offset to byte offset if self.bit_offset != 0 { while i < values_to_read && self.bit_offset != 0 { @@ -559,52 +428,135 @@ impl BitReader { } } - let in_buf = &self.buffer.data()[self.byte_offset..]; - let mut in_ptr = in_buf as *const [u8] as *const u8 as *const u32; - if size_of::() == 4 { - while values_to_read - i >= 32 { - let out_ptr = &mut batch[i..] 
as *mut [T] as *mut T as *mut u32; - in_ptr = unsafe { unpack32(in_ptr, out_ptr, num_bits) }; - self.byte_offset += 4 * num_bits; - i += 32; + let in_buf = self.buffer.data(); + + // Read directly into output buffer + match size_of::() { + 1 => { + let ptr = batch.as_mut_ptr() as *mut u8; + let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; + while values_to_read - i >= 8 { + let out_slice = (&mut out[i..i + 8]).try_into().unwrap(); + unpack8(&in_buf[self.byte_offset..], out_slice, num_bits); + self.byte_offset += num_bits; + i += 8; + } } - } else { - let mut out_buf = [0u32; 32]; - let out_ptr = &mut out_buf as &mut [u32] as *mut [u32] as *mut u32; - while values_to_read - i >= 32 { - in_ptr = unsafe { unpack32(in_ptr, out_ptr, num_bits) }; - self.byte_offset += 4 * num_bits; - - for out in out_buf { - // Zero-allocate buffer - let mut out_bytes = T::Buffer::default(); - let in_bytes = out.to_le_bytes(); - - { - let out_bytes = out_bytes.as_mut(); - let len = out_bytes.len().min(in_bytes.len()); - (&mut out_bytes[..len]).copy_from_slice(&in_bytes[..len]); - } - - batch[i] = T::from_le_bytes(out_bytes); - i += 1; + 2 => { + let ptr = batch.as_mut_ptr() as *mut u16; + let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; + while values_to_read - i >= 16 { + let out_slice = (&mut out[i..i + 16]).try_into().unwrap(); + unpack16(&in_buf[self.byte_offset..], out_slice, num_bits); + self.byte_offset += 2 * num_bits; + i += 16; } } + 4 => { + let ptr = batch.as_mut_ptr() as *mut u32; + let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; + while values_to_read - i >= 32 { + let out_slice = (&mut out[i..i + 32]).try_into().unwrap(); + unpack32(&in_buf[self.byte_offset..], out_slice, num_bits); + self.byte_offset += 4 * num_bits; + i += 32; + } + } + 8 => { + let ptr = batch.as_mut_ptr() as *mut u64; + let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) }; + while values_to_read - i >= 64 { + let out_slice = (&mut out[i..i + 64]).try_into().unwrap(); + unpack64(&in_buf[self.byte_offset..], out_slice, num_bits); + self.byte_offset += 8 * num_bits; + i += 64; + } + } + _ => unreachable!(), } - assert!(values_to_read - i < 32); + // Try to read smaller batches if possible + if size_of::() > 4 && values_to_read - i >= 32 && num_bits <= 32 { + let mut out_buf = [0_u32; 32]; + unpack32(&in_buf[self.byte_offset..], &mut out_buf, num_bits); + self.byte_offset += 4 * num_bits; + + for out in out_buf { + // Zero-allocate buffer + let mut out_bytes = T::Buffer::default(); + out_bytes.as_mut()[..4].copy_from_slice(&out.to_le_bytes()); + batch[i] = T::from_le_bytes(out_bytes); + i += 1; + } + } - self.reload_buffer_values(); + if size_of::() > 2 && values_to_read - i >= 16 && num_bits <= 16 { + let mut out_buf = [0_u16; 16]; + unpack16(&in_buf[self.byte_offset..], &mut out_buf, num_bits); + self.byte_offset += 2 * num_bits; + + for out in out_buf { + // Zero-allocate buffer + let mut out_bytes = T::Buffer::default(); + out_bytes.as_mut()[..2].copy_from_slice(&out.to_le_bytes()); + batch[i] = T::from_le_bytes(out_bytes); + i += 1; + } + } + + if size_of::() > 1 && values_to_read - i >= 8 && num_bits <= 8 { + let mut out_buf = [0_u8; 8]; + unpack8(&in_buf[self.byte_offset..], &mut out_buf, num_bits); + self.byte_offset += num_bits; + + for out in out_buf { + // Zero-allocate buffer + let mut out_bytes = T::Buffer::default(); + out_bytes.as_mut()[..1].copy_from_slice(&out.to_le_bytes()); + batch[i] = T::from_le_bytes(out_bytes); + i += 1; + } 
+ } + + // Read any trailing values while i < values_to_read { - batch[i] = self + let value = self .get_value(num_bits) .expect("expected to have more data"); + batch[i] = value; i += 1; } values_to_read } + /// Skip num_value values with num_bits bit width + /// + /// Return the number of values skipped (up to num_values) + pub fn skip(&mut self, num_values: usize, num_bits: usize) -> usize { + assert!(num_bits <= 64); + + let needed_bits = num_bits * num_values; + let remaining_bits = (self.buffer.len() - self.byte_offset) * 8 - self.bit_offset; + + let values_to_read = match remaining_bits < needed_bits { + true => remaining_bits / num_bits, + false => num_values, + }; + + let end_bit_offset = + self.byte_offset * 8 + values_to_read * num_bits + self.bit_offset; + + self.byte_offset = end_bit_offset / 8; + self.bit_offset = end_bit_offset % 8; + + if self.bit_offset != 0 { + self.load_buffered_values() + } + + values_to_read + } + /// Reads up to `num_bytes` to `buf` returning the number of bytes read pub(crate) fn get_aligned_bytes( &mut self, @@ -612,7 +564,7 @@ impl BitReader { num_bytes: usize, ) -> usize { // Align to byte offset - self.byte_offset += ceil(self.bit_offset as i64, 8) as usize; + self.byte_offset = self.get_byte_offset(); self.bit_offset = 0; let src = &self.buffer.data()[self.byte_offset..]; @@ -620,7 +572,6 @@ impl BitReader { buf.extend_from_slice(&src[..to_read]); self.byte_offset += to_read; - self.reload_buffer_values(); to_read } @@ -633,19 +584,17 @@ impl BitReader { /// Returns `Some` if there's enough bytes left to form a value of `T`. /// Otherwise `None`. pub fn get_aligned(&mut self, num_bytes: usize) -> Option { - let bytes_read = ceil(self.bit_offset as i64, 8) as usize; - if self.byte_offset + bytes_read + num_bytes > self.total_bytes { + self.byte_offset = self.get_byte_offset(); + self.bit_offset = 0; + + if self.byte_offset + num_bytes > self.buffer.len() { return None; } // Advance byte_offset to next unread byte and read num_bytes - self.byte_offset += bytes_read; - let v = read_num_bytes!(T, num_bytes, self.buffer.data()[self.byte_offset..]); + let v = read_num_bytes::(num_bytes, &self.buffer.data()[self.byte_offset..]); self.byte_offset += num_bytes; - // Reset buffered_values - self.bit_offset = 0; - self.reload_buffer_values(); Some(v) } @@ -688,10 +637,15 @@ impl BitReader { }) } - fn reload_buffer_values(&mut self) { - let bytes_to_read = cmp::min(self.total_bytes - self.byte_offset, 8); + /// Loads up to the the next 8 bytes from `self.buffer` at `self.byte_offset` + /// into `self.buffered_values`. + /// + /// Reads fewer than 8 bytes if there are fewer than 8 bytes left + #[inline] + fn load_buffered_values(&mut self) { + let bytes_to_read = cmp::min(self.buffer.len() - self.byte_offset, 8); self.buffered_values = - read_num_bytes!(u64, bytes_to_read, self.buffer.data()[self.byte_offset..]); + read_num_bytes::(bytes_to_read, &self.buffer.data()[self.byte_offset..]); } } @@ -702,20 +656,11 @@ impl From> for BitReader { } } -/// Returns the nearest multiple of `factor` that is `>=` than `num`. Here `factor` must -/// be a power of 2. 
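With the rewritten `get_batch` dispatching on `size_of::<T>()` to the `unpack8`/`unpack16`/`unpack32`/`unpack64` kernels, plus the new `skip` method, callers can decode or jump over bit-packed runs without buffering bytes by hand. A rough usage sketch, again assuming the crate-internal `BitWriter`/`BitReader` are in scope (the widths and value counts below are made up for illustration):

```rust
// Pack 64 three-bit values (0, 1, .., 7 repeating) and read a batch back.
let mut writer = BitWriter::new(64);
for i in 0..64u64 {
    writer.put_value(i % 8, 3);
}
let mut reader = BitReader::from(writer.consume());

// Skip the first 16 encoded values; skip() reports how many it really skipped.
assert_eq!(reader.skip(16, 3), 16);

// Decode the remaining 48 values straight into a typed output slice.
let mut batch = vec![0u32; 48];
let read = reader.get_batch(&mut batch, 3);
assert_eq!(read, 48);
assert_eq!(batch[0], 0); // the 17th value written was 16 % 8 == 0
assert_eq!(batch[1], 1);
```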
-/// -/// Copied from the arrow crate to make arrow optional -pub fn round_upto_power_of_2(num: usize, factor: usize) -> usize { - debug_assert!(factor > 0 && (factor & (factor - 1)) == 0); - (num + (factor - 1)) & !(factor - 1) -} - #[cfg(test)] mod tests { - use super::super::test_common::*; use super::*; + use crate::util::test_common::rand_gen::random_numbers; use rand::distributions::{Distribution, Standard}; use std::fmt::Debug; @@ -729,9 +674,9 @@ mod tests { assert_eq!(ceil(8, 8), 1); assert_eq!(ceil(9, 8), 2); assert_eq!(ceil(9, 9), 1); - assert_eq!(ceil(10000000000, 10), 1000000000); - assert_eq!(ceil(10, 10000000000), 1); - assert_eq!(ceil(10000000000, 1000000000), 10); + assert_eq!(ceil(10000000000_u64, 10), 1000000000); + assert_eq!(ceil(10_u64, 10000000000), 1); + assert_eq!(ceil(10000000000_u64, 1000000000), 10); } #[test] @@ -759,6 +704,23 @@ mod tests { assert_eq!(bit_reader.get_value::(4), Some(3)); } + #[test] + fn test_bit_reader_skip() { + let buffer = vec![255, 0]; + let mut bit_reader = BitReader::from(buffer); + let skipped = bit_reader.skip(1, 1); + assert_eq!(skipped, 1); + assert_eq!(bit_reader.get_value::(1), Some(1)); + let skipped = bit_reader.skip(2, 2); + assert_eq!(skipped, 2); + assert_eq!(bit_reader.get_value::(2), Some(3)); + let skipped = bit_reader.skip(4, 1); + assert_eq!(skipped, 4); + assert_eq!(bit_reader.get_value::(4), Some(0)); + let skipped = bit_reader.skip(1, 1); + assert_eq!(skipped, 0); + } + #[test] fn test_bit_reader_get_value_boundary() { let buffer = vec![10, 0, 0, 0, 20, 0, 30, 0, 0, 0, 40, 0]; @@ -769,6 +731,16 @@ mod tests { assert_eq!(bit_reader.get_value::(16), Some(40)); } + #[test] + fn test_bit_reader_skip_boundary() { + let buffer = vec![10, 0, 0, 0, 20, 0, 30, 0, 0, 0, 40, 0]; + let mut bit_reader = BitReader::from(buffer); + assert_eq!(bit_reader.get_value::(32), Some(10)); + assert_eq!(bit_reader.skip(1, 16), 1); + assert_eq!(bit_reader.get_value::(32), Some(30)); + assert_eq!(bit_reader.get_value::(16), Some(40)); + } + #[test] fn test_bit_reader_get_aligned() { // 01110101 11001011 @@ -800,25 +772,6 @@ mod tests { assert_eq!(bit_reader.get_zigzag_vlq_int(), Some(-2)); } - #[test] - fn test_set_array_bit() { - let mut buffer = vec![0, 0, 0]; - set_array_bit(&mut buffer[..], 1); - assert_eq!(buffer, vec![2, 0, 0]); - set_array_bit(&mut buffer[..], 4); - assert_eq!(buffer, vec![18, 0, 0]); - unset_array_bit(&mut buffer[..], 1); - assert_eq!(buffer, vec![16, 0, 0]); - set_array_bit(&mut buffer[..], 10); - assert_eq!(buffer, vec![16, 4, 0]); - set_array_bit(&mut buffer[..], 10); - assert_eq!(buffer, vec![16, 4, 0]); - set_array_bit(&mut buffer[..], 11); - assert_eq!(buffer, vec![16, 12, 0]); - unset_array_bit(&mut buffer[..], 10); - assert_eq!(buffer, vec![16, 8, 0]); - } - #[test] fn test_num_required_bits() { assert_eq!(num_required_bits(0), 0); @@ -862,7 +815,7 @@ mod tests { #[test] fn test_skip() { let mut writer = BitWriter::new(5); - let old_offset = writer.skip(1).expect("skip() should return OK"); + let old_offset = writer.skip(1); writer.put_aligned(42, 4); writer.put_aligned_offset(0x10, 1, old_offset); let result = writer.consume(); @@ -870,16 +823,15 @@ mod tests { writer = BitWriter::new(4); let result = writer.skip(5); - assert!(result.is_err()); + assert_eq!(result, 0); + assert_eq!(writer.buffer(), &[0; 5]) } #[test] fn test_get_next_byte_ptr() { let mut writer = BitWriter::new(5); { - let first_byte = writer - .get_next_byte_ptr(1) - .expect("get_next_byte_ptr() should return OK"); + let first_byte = 
writer.get_next_byte_ptr(1); first_byte[0] = 0x10; } writer.put_aligned(42, 4); @@ -906,8 +858,7 @@ mod tests { let mut writer = BitWriter::new(len); for i in 0..8 { - let result = writer.put_value(i % 2, 1); - assert!(result); + writer.put_value(i % 2, 1); } writer.flush(); @@ -918,11 +869,10 @@ mod tests { // Write 00110011 for i in 0..8 { - let result = match i { + match i { 0 | 1 | 4 | 5 => writer.put_value(false as u64, 1), _ => writer.put_value(true as u64, 1), - }; - assert!(result); + } } writer.flush(); { @@ -967,19 +917,13 @@ mod tests { fn test_put_value_rand_numbers(total: usize, num_bits: usize) { assert!(num_bits < 64); - let num_bytes = ceil(num_bits as i64, 8); + let num_bytes = ceil(num_bits, 8); let mut writer = BitWriter::new(num_bytes as usize * total); let values: Vec = random_numbers::(total) .iter() .map(|v| v & ((1 << num_bits) - 1)) .collect(); - (0..total).for_each(|i| { - assert!( - writer.put_value(values[i] as u64, num_bits), - "[{}]: put_value() failed", - i - ); - }); + (0..total).for_each(|i| writer.put_value(values[i] as u64, num_bits)); let mut reader = BitReader::from(writer.consume()); (0..total).for_each(|i| { @@ -998,11 +942,12 @@ mod tests { fn test_get_batch() { const SIZE: &[usize] = &[1, 31, 32, 33, 128, 129]; for s in SIZE { - for i in 0..33 { + for i in 0..=64 { match i { 0..=8 => test_get_batch_helper::(*s, i), 9..=16 => test_get_batch_helper::(*s, i), - _ => test_get_batch_helper::(*s, i), + 17..=32 => test_get_batch_helper::(*s, i), + _ => test_get_batch_helper::(*s, i), } } } @@ -1012,22 +957,25 @@ mod tests { where T: FromBytes + Default + Clone + Debug + Eq, { - assert!(num_bits <= 32); - let num_bytes = ceil(num_bits as i64, 8); + assert!(num_bits <= 64); + let num_bytes = ceil(num_bits, 8); let mut writer = BitWriter::new(num_bytes as usize * total); - let values: Vec = random_numbers::(total) + let mask = match num_bits { + 64 => u64::MAX, + _ => (1 << num_bits) - 1, + }; + + let values: Vec = random_numbers::(total) .iter() - .map(|v| v & ((1u64 << num_bits) - 1) as u32) + .map(|v| v & mask) .collect(); // Generic values used to check against actual values read from `get_batch`. 
let expected_values: Vec = values.iter().map(|v| from_ne_slice(v.as_bytes())).collect(); - (0..total).for_each(|i| { - assert!(writer.put_value(values[i] as u64, num_bits)); - }); + (0..total).for_each(|i| writer.put_value(values[i] as u64, num_bits)); let buf = writer.consume(); let mut reader = BitReader::from(buf); @@ -1036,9 +984,12 @@ mod tests { assert_eq!(values_read, values.len()); for i in 0..batch.len() { assert_eq!( - batch[i], expected_values[i], - "num_bits = {}, index = {}", - num_bits, i + batch[i], + expected_values[i], + "max_num_bits = {}, num_bits = {}, index = {}", + size_of::() * 8, + num_bits, + i ); } } @@ -1064,7 +1015,7 @@ mod tests { assert!(total % 2 == 0); let aligned_value_byte_width = std::mem::size_of::(); - let value_byte_width = ceil(num_bits as i64, 8) as usize; + let value_byte_width = ceil(num_bits, 8); let mut writer = BitWriter::new((total / 2) * (aligned_value_byte_width + value_byte_width)); let values: Vec = random_numbers::(total / 2) @@ -1076,17 +1027,9 @@ mod tests { for i in 0..total { let j = i / 2; if i % 2 == 0 { - assert!( - writer.put_value(values[j] as u64, num_bits), - "[{}]: put_value() failed", - i - ); + writer.put_value(values[j] as u64, num_bits); } else { - assert!( - writer.put_aligned::(aligned_values[j], aligned_value_byte_width), - "[{}]: put_aligned() failed", - i - ); + writer.put_aligned::(aligned_values[j], aligned_value_byte_width) } } @@ -1120,13 +1063,7 @@ mod tests { let total = 64; let mut writer = BitWriter::new(total * 32); let values = random_numbers::(total); - (0..total).for_each(|i| { - assert!( - writer.put_vlq_int(values[i] as u64), - "[{}]; put_vlq_int() failed", - i - ); - }); + (0..total).for_each(|i| writer.put_vlq_int(values[i] as u64)); let mut reader = BitReader::from(writer.consume()); (0..total).for_each(|i| { @@ -1146,13 +1083,7 @@ mod tests { let total = 64; let mut writer = BitWriter::new(total * 32); let values = random_numbers::(total); - (0..total).for_each(|i| { - assert!( - writer.put_zigzag_vlq_int(values[i] as i64), - "[{}]; put_zigzag_vlq_int() failed", - i - ); - }); + (0..total).for_each(|i| writer.put_zigzag_vlq_int(values[i] as i64)); let mut reader = BitReader::from(writer.consume()); (0..total).for_each(|i| { diff --git a/parquet/src/util/cursor.rs b/parquet/src/util/cursor.rs deleted file mode 100644 index 706724dbf52a..000000000000 --- a/parquet/src/util/cursor.rs +++ /dev/null @@ -1,284 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::util::io::TryClone; -use std::io::{self, Cursor, Error, ErrorKind, Read, Seek, SeekFrom, Write}; -use std::sync::{Arc, Mutex}; -use std::{cmp, fmt}; - -/// This is object to use if your file is already in memory. 
-/// The sliceable cursor is similar to std::io::Cursor, except that it makes it easy to create "cursor slices". -/// To achieve this, it uses Arc instead of shared references. Indeed reference fields are painful -/// because the lack of Generic Associated Type implies that you would require complex lifetime propagation when -/// returning such a cursor. -#[allow(clippy::rc_buffer)] -#[deprecated = "use bytes::Bytes instead"] -pub struct SliceableCursor { - inner: Arc>, - start: u64, - length: usize, - pos: u64, -} - -#[allow(deprecated)] -impl fmt::Debug for SliceableCursor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SliceableCursor") - .field("start", &self.start) - .field("length", &self.length) - .field("pos", &self.pos) - .field("inner.len", &self.inner.len()) - .finish() - } -} - -#[allow(deprecated)] -impl SliceableCursor { - pub fn new(content: impl Into>>) -> Self { - let inner = content.into(); - let size = inner.len(); - SliceableCursor { - inner, - start: 0, - pos: 0, - length: size, - } - } - - /// Create a slice cursor using the same data as a current one. - pub fn slice(&self, start: u64, length: usize) -> io::Result { - let new_start = self.start + start; - if new_start >= self.inner.len() as u64 - || new_start as usize + length > self.inner.len() - { - return Err(Error::new(ErrorKind::InvalidInput, "out of bound")); - } - Ok(SliceableCursor { - inner: Arc::clone(&self.inner), - start: new_start, - pos: new_start, - length, - }) - } - - fn remaining_slice(&self) -> &[u8] { - let end = self.start as usize + self.length; - let offset = cmp::min(self.pos, end as u64) as usize; - &self.inner[offset..end] - } - - /// Get the length of the current cursor slice - pub fn len(&self) -> u64 { - self.length as u64 - } - - /// return true if the cursor is empty (self.len() == 0) - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -/// Implementation inspired by std::io::Cursor -#[allow(deprecated)] -impl Read for SliceableCursor { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let n = Read::read(&mut self.remaining_slice(), buf)?; - self.pos += n as u64; - Ok(n) - } -} - -#[allow(deprecated)] -impl Seek for SliceableCursor { - fn seek(&mut self, pos: SeekFrom) -> io::Result { - let new_pos = match pos { - SeekFrom::Start(pos) => pos as i64, - SeekFrom::End(pos) => self.inner.len() as i64 + pos as i64, - SeekFrom::Current(pos) => self.pos as i64 + pos as i64, - }; - - if new_pos < 0 { - Err(Error::new( - ErrorKind::InvalidInput, - format!( - "Request out of bounds: cur position {} + seek {:?} < 0: {}", - self.pos, pos, new_pos - ), - )) - } else if new_pos >= self.inner.len() as i64 { - Err(Error::new( - ErrorKind::InvalidInput, - format!( - "Request out of bounds: cur position {} + seek {:?} >= length {}: {}", - self.pos, - pos, - self.inner.len(), - new_pos - ), - )) - } else { - self.pos = new_pos as u64; - Ok(self.start) - } - } -} - -/// Use this type to write Parquet to memory rather than a file. -#[deprecated = "use Vec instead"] -#[derive(Debug, Default, Clone)] -pub struct InMemoryWriteableCursor { - buffer: Arc>>>, -} - -#[allow(deprecated)] -impl InMemoryWriteableCursor { - /// Consume this instance and return the underlying buffer as long as there are no other - /// references to this instance. 
- pub fn into_inner(self) -> Option> { - Arc::try_unwrap(self.buffer) - .ok() - .and_then(|mutex| mutex.into_inner().ok()) - .map(|cursor| cursor.into_inner()) - } - - /// Returns a clone of the underlying buffer - pub fn data(&self) -> Vec { - let inner = self.buffer.lock().unwrap(); - inner.get_ref().to_vec() - } - - /// Returns a length of the underlying buffer - pub fn len(&self) -> usize { - let inner = self.buffer.lock().unwrap(); - inner.get_ref().len() - } - - /// Returns true if the underlying buffer contains no elements - pub fn is_empty(&self) -> bool { - let inner = self.buffer.lock().unwrap(); - inner.get_ref().is_empty() - } -} - -#[allow(deprecated)] -impl TryClone for InMemoryWriteableCursor { - fn try_clone(&self) -> std::io::Result { - Ok(Self { - buffer: self.buffer.clone(), - }) - } -} - -#[allow(deprecated)] -impl Write for InMemoryWriteableCursor { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut inner = self.buffer.lock().unwrap(); - inner.write(buf) - } - - fn flush(&mut self) -> std::io::Result<()> { - let mut inner = self.buffer.lock().unwrap(); - inner.flush() - } -} - -#[allow(deprecated)] -impl Seek for InMemoryWriteableCursor { - fn seek(&mut self, pos: SeekFrom) -> std::io::Result { - let mut inner = self.buffer.lock().unwrap(); - inner.seek(pos) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - /// Create a SliceableCursor of all u8 values in ascending order - #[allow(deprecated)] - fn get_u8_range() -> SliceableCursor { - let data: Vec = (0u8..=255).collect(); - SliceableCursor::new(data) - } - - /// Reads all the bytes in the slice and checks that it matches the u8 range from start to end_included - #[allow(deprecated)] - fn check_read_all(mut cursor: SliceableCursor, start: u8, end_included: u8) { - let mut target = vec![]; - let cursor_res = cursor.read_to_end(&mut target); - println!("{:?}", cursor_res); - assert!(cursor_res.is_ok(), "reading error"); - assert_eq!((end_included - start) as usize + 1, cursor_res.unwrap()); - assert_eq!((start..=end_included).collect::>(), target); - } - - #[test] - fn read_all_whole() { - let cursor = get_u8_range(); - check_read_all(cursor, 0, 255); - } - - #[test] - fn read_all_slice() { - let cursor = get_u8_range().slice(10, 10).expect("error while slicing"); - check_read_all(cursor, 10, 19); - } - - #[test] - fn seek_cursor_start() { - let mut cursor = get_u8_range(); - - cursor.seek(SeekFrom::Start(5)).unwrap(); - check_read_all(cursor, 5, 255); - } - - #[test] - fn seek_cursor_current() { - let mut cursor = get_u8_range(); - cursor.seek(SeekFrom::Start(10)).unwrap(); - cursor.seek(SeekFrom::Current(10)).unwrap(); - check_read_all(cursor, 20, 255); - } - - #[test] - fn seek_cursor_end() { - let mut cursor = get_u8_range(); - - cursor.seek(SeekFrom::End(-10)).unwrap(); - check_read_all(cursor, 246, 255); - } - - #[test] - fn seek_cursor_error_too_long() { - let mut cursor = get_u8_range(); - let res = cursor.seek(SeekFrom::Start(1000)); - let actual_error = res.expect_err("expected error").to_string(); - let expected_error = - "Request out of bounds: cur position 0 + seek Start(1000) >= length 256: 1000"; - assert_eq!(actual_error, expected_error); - } - - #[test] - fn seek_cursor_error_too_short() { - let mut cursor = get_u8_range(); - let res = cursor.seek(SeekFrom::End(-1000)); - let actual_error = res.expect_err("expected error").to_string(); - let expected_error = - "Request out of bounds: cur position 0 + seek End(-1000) < 0: -744"; - assert_eq!(actual_error, expected_error); - } -} diff 
--git a/parquet/src/util/hash_util.rs b/parquet/src/util/hash_util.rs deleted file mode 100644 index dd23e7a65f44..000000000000 --- a/parquet/src/util/hash_util.rs +++ /dev/null @@ -1,162 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::data_type::AsBytes; - -/// Computes hash value for `data`, with a seed value `seed`. -/// The data type `T` must implement the `AsBytes` trait. -pub fn hash(data: &T, seed: u32) -> u32 { - hash_(data.as_bytes(), seed) -} - -fn hash_(data: &[u8], seed: u32) -> u32 { - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - { - if is_x86_feature_detected!("sse4.2") { - unsafe { crc32_hash(data, seed) } - } else { - murmur_hash2_64a(data, seed as u64) as u32 - } - } - - #[cfg(any( - target_arch = "aarch64", - target_arch = "arm", - target_arch = "riscv64", - target_arch = "wasm32" - ))] - { - murmur_hash2_64a(data, seed as u64) as u32 - } -} - -const MURMUR_PRIME: u64 = 0xc6a4a7935bd1e995; -const MURMUR_R: i32 = 47; - -/// Rust implementation of MurmurHash2, 64-bit version for 64-bit platforms -fn murmur_hash2_64a(data_bytes: &[u8], seed: u64) -> u64 { - let len = data_bytes.len(); - let len_64 = (len / 8) * 8; - - let mut h = seed ^ (MURMUR_PRIME.wrapping_mul(data_bytes.len() as u64)); - for mut k in data_bytes - .chunks_exact(8) - .map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())) - { - k = k.wrapping_mul(MURMUR_PRIME); - k ^= k >> MURMUR_R; - k = k.wrapping_mul(MURMUR_PRIME); - h ^= k; - h = h.wrapping_mul(MURMUR_PRIME); - } - - let data2 = &data_bytes[len_64..]; - - let v = len & 7; - if v == 7 { - h ^= (data2[6] as u64) << 48; - } - if v >= 6 { - h ^= (data2[5] as u64) << 40; - } - if v >= 5 { - h ^= (data2[4] as u64) << 32; - } - if v >= 4 { - h ^= (data2[3] as u64) << 24; - } - if v >= 3 { - h ^= (data2[2] as u64) << 16; - } - if v >= 2 { - h ^= (data2[1] as u64) << 8; - } - if v >= 1 { - h ^= data2[0] as u64; - } - if v > 0 { - h = h.wrapping_mul(MURMUR_PRIME); - } - - h ^= h >> MURMUR_R; - h = h.wrapping_mul(MURMUR_PRIME); - h ^= h >> MURMUR_R; - h -} - -/// CRC32 hash implementation using SSE4 instructions. Borrowed from Impala. -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -#[target_feature(enable = "sse4.2")] -unsafe fn crc32_hash(bytes: &[u8], seed: u32) -> u32 { - #[cfg(target_arch = "x86")] - use std::arch::x86::*; - #[cfg(target_arch = "x86_64")] - use std::arch::x86_64::*; - - let mut hash = seed; - for chunk in bytes - .chunks_exact(4) - .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) - { - hash = _mm_crc32_u32(hash, chunk); - } - - let remainder = bytes.len() % 4; - - for byte in &bytes[bytes.len() - remainder..] 
{ - hash = _mm_crc32_u8(hash, *byte); - } - - // The lower half of the CRC hash has poor uniformity, so swap the halves - // for anyone who only uses the first several bits of the hash. - hash = (hash << 16) | (hash >> 16); - hash -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_murmur2_64a() { - let result = murmur_hash2_64a(b"hello", 123); - assert_eq!(result, 2597646618390559622); - - let result = murmur_hash2_64a(b"helloworld", 123); - assert_eq!(result, 4934371746140206573); - - let result = murmur_hash2_64a(b"helloworldparquet", 123); - assert_eq!(result, 2392198230801491746); - } - - #[test] - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - fn test_crc32() { - if is_x86_feature_detected!("sse4.2") { - unsafe { - let result = crc32_hash(b"hello", 123); - assert_eq!(result, 3359043980); - - let result = crc32_hash(b"helloworld", 123); - assert_eq!(result, 3971745255); - - let result = crc32_hash(b"helloworldparquet", 123); - assert_eq!(result, 1124504676); - } - } - } -} diff --git a/parquet/src/util/interner.rs b/parquet/src/util/interner.rs new file mode 100644 index 000000000000..e638237e06c5 --- /dev/null +++ b/parquet/src/util/interner.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::data_type::AsBytes; +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; + +const DEFAULT_DEDUP_CAPACITY: usize = 4096; + +/// Storage trait for [`Interner`] +pub trait Storage { + type Key: Copy; + + type Value: AsBytes + PartialEq + ?Sized; + + /// Gets an element by its key + fn get(&self, idx: Self::Key) -> &Self::Value; + + /// Adds a new element, returning the key + fn push(&mut self, value: &Self::Value) -> Self::Key; +} + +/// A generic value interner supporting various different [`Storage`] +#[derive(Debug, Default)] +pub struct Interner { + state: ahash::RandomState, + + /// Used to provide a lookup from value to unique value + /// + /// Note: `S::Key`'s hash implementation is not used, instead the raw entry + /// API is used to store keys w.r.t the hash of the strings themselves + /// + dedup: HashMap, + + storage: S, +} + +impl Interner { + /// Create a new `Interner` with the provided storage + pub fn new(storage: S) -> Self { + Self { + state: Default::default(), + dedup: HashMap::with_capacity_and_hasher(DEFAULT_DEDUP_CAPACITY, ()), + storage, + } + } + + /// Intern the value, returning the interned key, and if this was a new value + pub fn intern(&mut self, value: &S::Value) -> S::Key { + let hash = self.state.hash_one(value.as_bytes()); + + let entry = self + .dedup + .raw_entry_mut() + .from_hash(hash, |index| value == self.storage.get(*index)); + + match entry { + RawEntryMut::Occupied(entry) => *entry.into_key(), + RawEntryMut::Vacant(entry) => { + let key = self.storage.push(value); + + *entry + .insert_with_hasher(hash, key, (), |key| { + self.state.hash_one(self.storage.get(*key).as_bytes()) + }) + .0 + } + } + } + + /// Returns the storage for this interner + pub fn storage(&self) -> &S { + &self.storage + } + + /// Unwraps the inner storage + pub fn into_inner(self) -> S { + self.storage + } +} diff --git a/parquet/src/util/io.rs b/parquet/src/util/io.rs index a7b5e73074c6..43d78866d9ef 100644 --- a/parquet/src/util/io.rs +++ b/parquet/src/util/io.rs @@ -18,8 +18,6 @@ use std::{cell::RefCell, cmp, fmt, io::*}; use crate::file::reader::Length; -#[allow(deprecated)] -use crate::file::writer::ParquetWriter; const DEFAULT_BUF_SIZE: usize = 8 * 1024; @@ -40,14 +38,6 @@ impl ParquetReader for T {} // Read/Write wrappers for `File`. -/// Position trait returns the current position in the stream. -/// Should be viewed as a lighter version of `Seek` that does not allow seek operations, -/// and does not require mutable reference for the current position. -pub trait Position { - /// Returns position in the stream. - fn pos(&self) -> u64; -} - /// Struct that represents a slice of a file data with independent start position and /// length. Internally clones provided file handle, wraps with a custom implementation /// of BufReader that resets position before any read. @@ -144,77 +134,19 @@ impl Read for FileSource { } } -impl Position for FileSource { - fn pos(&self) -> u64 { - self.start - } -} - impl Length for FileSource { fn len(&self) -> u64 { self.end - self.start } } -/// Struct that represents `File` output stream with position tracking. -/// Used as a sink in file writer. -#[deprecated = "use TrackedWrite instead"] -#[allow(deprecated)] -pub struct FileSink { - buf: BufWriter, - // This is not necessarily position in the underlying file, - // but rather current position in the sink. - pos: u64, -} - -#[allow(deprecated)] -impl FileSink { - /// Creates new file sink. - /// Position is set to whatever position file has. 
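The new `interner` module gives dictionary encoders a reusable deduplication layer: values are hashed with `ahash`, looked up through hashbrown's raw-entry API, and only pushed into the backing `Storage` when they have not been seen before. A hedged sketch of how a `Storage` implementation plugs into the crate-internal `Interner` (the `VecStorage` type below is hypothetical and not part of this diff):

```rust
// Hypothetical storage backed by a Vec of owned byte strings; Key is the index.
#[derive(Debug, Default)]
struct VecStorage(Vec<Vec<u8>>);

impl Storage for VecStorage {
    type Key = usize;
    type Value = [u8];

    fn get(&self, idx: usize) -> &[u8] {
        &self.0[idx]
    }

    fn push(&mut self, value: &[u8]) -> usize {
        self.0.push(value.to_vec());
        self.0.len() - 1
    }
}

let mut interner = Interner::new(VecStorage::default());
let a = interner.intern(b"hello");
let b = interner.intern(b"world");
let c = interner.intern(b"hello"); // deduplicated: returns the same key as `a`
assert_eq!(a, c);
assert_ne!(a, b);
assert_eq!(interner.storage().0.len(), 2); // only two distinct values stored
```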
- pub fn new(buf: &W) -> Self { - let mut owned_buf = buf.try_clone().unwrap(); - let pos = owned_buf.seek(SeekFrom::Current(0)).unwrap(); - Self { - buf: BufWriter::new(owned_buf), - pos, - } - } -} - -#[allow(deprecated)] -impl Write for FileSink { - fn write(&mut self, buf: &[u8]) -> Result { - let num_bytes = self.buf.write(buf)?; - self.pos += num_bytes as u64; - Ok(num_bytes) - } - - fn flush(&mut self) -> Result<()> { - self.buf.flush() - } -} - -#[allow(deprecated)] -impl Position for FileSink { - fn pos(&self) -> u64 { - self.pos - } -} - -// Position implementation for Cursor to use in various tests. -impl<'a> Position for Cursor<&'a mut Vec> { - fn pos(&self) -> u64 { - self.position() - } -} - #[cfg(test)] mod tests { use super::*; use std::iter; - use crate::util::test_common::get_test_file; + use crate::util::test_common::file_util::get_test_file; #[test] fn test_io_read_fully() { @@ -243,10 +175,10 @@ mod tests { let mut src = FileSource::new(&get_test_file("alltypes_plain.parquet"), 0, 4); let _ = src.read(&mut [0; 1]).unwrap(); - assert_eq!(src.pos(), 1); + assert_eq!(src.start, 1); let _ = src.read(&mut [0; 4]).unwrap(); - assert_eq!(src.pos(), 4); + assert_eq!(src.start, 4); } #[test] @@ -255,12 +187,12 @@ mod tests { // Read all bytes from source let _ = src.read(&mut [0; 128]).unwrap(); - assert_eq!(src.pos(), 4); + assert_eq!(src.start, 4); // Try reading again, should return 0 bytes. let bytes_read = src.read(&mut [0; 128]).unwrap(); assert_eq!(bytes_read, 0); - assert_eq!(src.pos(), 4); + assert_eq!(src.start, 4); } #[test] @@ -277,30 +209,6 @@ mod tests { assert_eq!(buf, vec![b'P', b'A', b'R', b'1']); } - #[test] - #[allow(deprecated)] - fn test_io_write_with_pos() { - let mut file = tempfile::tempfile().unwrap(); - file.write_all(&[b'a', b'b', b'c']).unwrap(); - - // Write into sink - let mut sink = FileSink::new(&file); - assert_eq!(sink.pos(), 3); - - sink.write_all(&[b'd', b'e', b'f', b'g']).unwrap(); - assert_eq!(sink.pos(), 7); - - sink.flush().unwrap(); - assert_eq!(sink.pos(), file.seek(SeekFrom::Current(0)).unwrap()); - - // Read data using file chunk - let mut res = vec![0u8; 7]; - let mut chunk = - FileSource::new(&file, 0, file.metadata().unwrap().len() as usize); - chunk.read_exact(&mut res[..]).unwrap(); - assert_eq!(res, vec![b'a', b'b', b'c', b'd', b'e', b'f', b'g']); - } - #[test] fn test_io_large_read() { // Generate repeated 'abcdef' pattern and write it into a file diff --git a/parquet/src/util/mod.rs b/parquet/src/util/mod.rs index b49e32516921..5f43023941fd 100644 --- a/parquet/src/util/mod.rs +++ b/parquet/src/util/mod.rs @@ -19,12 +19,10 @@ pub mod io; pub mod memory; #[macro_use] pub mod bit_util; -mod bit_packing; -pub mod cursor; -pub mod hash_util; +mod bit_pack; +pub(crate) mod interner; #[cfg(any(test, feature = "test_common"))] pub(crate) mod test_common; -pub(crate)mod page_util; #[cfg(any(test, feature = "test_common"))] pub use self::test_common::page_util::{ diff --git a/parquet/src/util/page_util.rs b/parquet/src/util/page_util.rs deleted file mode 100644 index 5cdcf7535c63..000000000000 --- a/parquet/src/util/page_util.rs +++ /dev/null @@ -1,54 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::VecDeque; -use std::io::Read; -use std::sync::Arc; -use crate::errors::Result; -use parquet_format::PageLocation; -use crate::file::reader::ChunkReader; - -/// Use column chunk's offset index to get the `page_num` page row count. -pub(crate) fn calculate_row_count(indexes: &[PageLocation], page_num: usize, total_row_count: i64) -> Result { - if page_num == indexes.len() - 1 { - Ok((total_row_count - indexes[page_num].first_row_index + 1) as usize) - } else { - Ok((indexes[page_num + 1].first_row_index - indexes[page_num].first_row_index) as usize) - } -} - -/// Use column chunk's offset index to get each page serially readable slice -/// and a flag indicates whether having one dictionary page in this column chunk. -pub(crate) fn get_pages_readable_slices>(col_chunk_offset_index: &[PageLocation], col_start: u64, chunk_reader: Arc) -> Result<(VecDeque, bool)> { - let first_data_page_offset = col_chunk_offset_index[0].offset as u64; - let has_dictionary_page = first_data_page_offset != col_start; - let mut page_readers = VecDeque::with_capacity(col_chunk_offset_index.len() + 1); - - if has_dictionary_page { - let length = (first_data_page_offset - col_start) as usize; - let reader: T = chunk_reader.get_read(col_start, length)?; - page_readers.push_back(reader); - } - - for index in col_chunk_offset_index { - let start = index.offset as u64; - let length = index.compressed_page_size as usize; - let reader: T = chunk_reader.get_read(start, length)?; - page_readers.push_back(reader) - } - Ok((page_readers, has_dictionary_page)) -} diff --git a/parquet/src/util/test_common/mod.rs b/parquet/src/util/test_common/mod.rs index f0beb16ca954..504219ecae19 100644 --- a/parquet/src/util/test_common/mod.rs +++ b/parquet/src/util/test_common/mod.rs @@ -15,17 +15,10 @@ // specific language governing permissions and limitations // under the License. -pub mod file_util; pub mod page_util; -pub mod rand_gen; - -pub use self::rand_gen::random_bools; -pub use self::rand_gen::random_bytes; -pub use self::rand_gen::random_numbers; -pub use self::rand_gen::random_numbers_range; -pub use self::rand_gen::RandGen; -pub use self::file_util::get_test_file; -pub use self::file_util::get_test_path; +#[cfg(test)] +pub mod file_util; -pub use self::page_util::make_pages; +#[cfg(test)] +pub mod rand_gen; \ No newline at end of file diff --git a/parquet/src/util/test_common/page_util.rs b/parquet/src/util/test_common/page_util.rs index f56eaf85e636..243fb6f8b897 100644 --- a/parquet/src/util/test_common/page_util.rs +++ b/parquet/src/util/test_common/page_util.rs @@ -16,18 +16,15 @@ // under the License. 
use crate::basic::Encoding; -use crate::column::page::{PageMetadata, PageReader}; use crate::column::page::{Page, PageIterator}; +use crate::column::page::{PageMetadata, PageReader}; use crate::data_type::DataType; -use crate::encodings::encoding::{get_encoder, DictEncoder, Encoder}; -use crate::encodings::levels::max_buffer_size; +use crate::encodings::encoding::{get_encoder, Encoder}; use crate::encodings::levels::LevelEncoder; use crate::errors::Result; use crate::schema::types::{ColumnDescPtr, SchemaDescPtr}; use crate::util::memory::ByteBufferPtr; -use crate::util::test_common::random_numbers_range; -use rand::distributions::uniform::SampleUniform; -use std::collections::VecDeque; +use std::iter::Peekable; use std::mem; pub trait DataPageBuilder { @@ -45,7 +42,6 @@ pub trait DataPageBuilder { /// - consume() /// in order to populate and obtain a data page. pub struct DataPageBuilderImpl { - desc: ColumnDescPtr, encoding: Option, num_values: u32, buffer: Vec, @@ -58,9 +54,8 @@ impl DataPageBuilderImpl { // `num_values` is the number of non-null values to put in the data page. // `datapage_v2` flag is used to indicate if the generated data page should use V2 // format or not. - pub fn new(desc: ColumnDescPtr, num_values: u32, datapage_v2: bool) -> Self { + pub fn new(_desc: ColumnDescPtr, num_values: u32, datapage_v2: bool) -> Self { DataPageBuilderImpl { - desc, encoding: None, num_values, buffer: vec![], @@ -75,10 +70,9 @@ impl DataPageBuilderImpl { if max_level <= 0 { return 0; } - let size = max_buffer_size(Encoding::RLE, max_level, levels.len()); - let mut level_encoder = LevelEncoder::v1(Encoding::RLE, max_level, vec![0; size]); - level_encoder.put(levels).expect("put() should be OK"); - let encoded_levels = level_encoder.consume().expect("consume() should be OK"); + let mut level_encoder = LevelEncoder::v1(Encoding::RLE, max_level, levels.len()); + level_encoder.put(levels); + let encoded_levels = level_encoder.consume(); // Actual encoded bytes (without length offset) let encoded_bytes = &encoded_levels[mem::size_of::()..]; if self.datapage_v2 { @@ -113,8 +107,7 @@ impl DataPageBuilder for DataPageBuilderImpl { ); self.encoding = Some(encoding); let mut encoder: Box> = - get_encoder::(self.desc.clone(), encoding) - .expect("get_encoder() should be OK"); + get_encoder::(encoding).expect("get_encoder() should be OK"); encoder.put(values).expect("put() should be OK"); let encoded_values = encoder .flush_buffer() @@ -135,8 +128,8 @@ impl DataPageBuilder for DataPageBuilderImpl { encoding: self.encoding.unwrap(), num_nulls: 0, /* set to dummy value - don't need this when reading * data page */ - num_rows: self.num_values, /* also don't need this when reading - * data page */ + num_rows: self.num_values, /* num_rows only needs in skip_records, now we not support skip REPEATED field, + * so we can assume num_values == num_rows */ def_levels_byte_len: self.def_levels_byte_len, rep_levels_byte_len: self.rep_levels_byte_len, is_compressed: false, @@ -157,13 +150,13 @@ impl DataPageBuilder for DataPageBuilderImpl { /// A utility page reader which stores pages in memory. pub struct InMemoryPageReader> { - page_iter: P, + page_iter: Peekable

<P>, } impl<P: Iterator<Item = Page>> InMemoryPageReader<P> { pub fn new(pages: impl IntoIterator<Item = Page, IntoIter = P>) -> Self { Self { - page_iter: pages.into_iter(), + page_iter: pages.into_iter().peekable(), } } } @@ -174,11 +167,29 @@ impl<P: Iterator<Item = Page> + Send> PageReader for InMemoryPageReader<P>
{ } fn peek_next_page(&mut self) -> Result> { - unimplemented!() + if let Some(x) = self.page_iter.peek() { + match x { + Page::DataPage { num_values, .. } => Ok(Some(PageMetadata { + num_rows: *num_values as usize, + is_dict: false, + })), + Page::DataPageV2 { num_rows, .. } => Ok(Some(PageMetadata { + num_rows: *num_rows as usize, + is_dict: false, + })), + Page::DictionaryPage { .. } => Ok(Some(PageMetadata { + num_rows: 0, + is_dict: true, + })), + } + } else { + Ok(None) + } } fn skip_next_page(&mut self) -> Result<()> { - unimplemented!() + self.page_iter.next(); + Ok(()) } } @@ -231,88 +242,3 @@ impl> + Send> PageIterator for InMemoryPageIterator Ok(self.column_desc.clone()) } } - -pub fn make_pages( - desc: ColumnDescPtr, - encoding: Encoding, - num_pages: usize, - levels_per_page: usize, - min: T::T, - max: T::T, - def_levels: &mut Vec, - rep_levels: &mut Vec, - values: &mut Vec, - pages: &mut VecDeque, - use_v2: bool, -) where - T::T: PartialOrd + SampleUniform + Copy, -{ - let mut num_values = 0; - let max_def_level = desc.max_def_level(); - let max_rep_level = desc.max_rep_level(); - - let mut dict_encoder = DictEncoder::::new(desc.clone()); - - for i in 0..num_pages { - let mut num_values_cur_page = 0; - let level_range = i * levels_per_page..(i + 1) * levels_per_page; - - if max_def_level > 0 { - random_numbers_range(levels_per_page, 0, max_def_level + 1, def_levels); - for dl in &def_levels[level_range.clone()] { - if *dl == max_def_level { - num_values_cur_page += 1; - } - } - } else { - num_values_cur_page = levels_per_page; - } - if max_rep_level > 0 { - random_numbers_range(levels_per_page, 0, max_rep_level + 1, rep_levels); - } - random_numbers_range(num_values_cur_page, min, max, values); - - // Generate the current page - - let mut pb = - DataPageBuilderImpl::new(desc.clone(), num_values_cur_page as u32, use_v2); - if max_rep_level > 0 { - pb.add_rep_levels(max_rep_level, &rep_levels[level_range.clone()]); - } - if max_def_level > 0 { - pb.add_def_levels(max_def_level, &def_levels[level_range]); - } - - let value_range = num_values..num_values + num_values_cur_page; - match encoding { - Encoding::PLAIN_DICTIONARY | Encoding::RLE_DICTIONARY => { - let _ = dict_encoder.put(&values[value_range.clone()]); - let indices = dict_encoder - .write_indices() - .expect("write_indices() should be OK"); - pb.add_indices(indices); - } - Encoding::PLAIN => { - pb.add_values::(encoding, &values[value_range]); - } - enc => panic!("Unexpected encoding {}", enc), - } - - let data_page = pb.consume(); - pages.push_back(data_page); - num_values += num_values_cur_page; - } - - if encoding == Encoding::PLAIN_DICTIONARY || encoding == Encoding::RLE_DICTIONARY { - let dict = dict_encoder - .write_dict() - .expect("write_dict() should be OK"); - let dict_page = Page::DictionaryPage { - buf: dict, - num_values: dict_encoder.num_entries() as u32, - encoding: Encoding::RLE_DICTIONARY, - is_sorted: false, - }; - pages.push_front(dict_page); - } -} diff --git a/parquet/src/util/test_common/rand_gen.rs b/parquet/src/util/test_common/rand_gen.rs index d9c256577684..4e54aa7999cf 100644 --- a/parquet/src/util/test_common/rand_gen.rs +++ b/parquet/src/util/test_common/rand_gen.rs @@ -15,13 +15,19 @@ // specific language governing permissions and limitations // under the License. 
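The `peek_next_page` and `skip_next_page` implementations added to the test-only `InMemoryPageReader` let it exercise readers that skip whole pages. A hedged sketch of the intended flow (`make_data_page` below is a hypothetical helper standing in for `DataPageBuilderImpl` usage; it is not part of this diff):

```rust
// Two in-memory v1 data pages with 100 and 200 values respectively.
let pages = vec![make_data_page(100), make_data_page(200)];
let mut reader = InMemoryPageReader::new(pages);

// peek_next_page reports metadata without consuming the page.
let meta = reader.peek_next_page().unwrap().unwrap();
assert_eq!(meta.num_rows, 100); // for v1 pages num_values is used as num_rows
assert!(!meta.is_dict);

// skip_next_page drops the peeked page, so the next peek sees the second page.
reader.skip_next_page().unwrap();
assert_eq!(reader.peek_next_page().unwrap().unwrap().num_rows, 200);
```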
+use crate::basic::Encoding; +use crate::column::page::Page; use rand::{ distributions::{uniform::SampleUniform, Distribution, Standard}, thread_rng, Rng, }; +use std::collections::VecDeque; use crate::data_type::*; +use crate::encodings::encoding::{DictEncoder, Encoder}; +use crate::schema::types::ColumnDescPtr; use crate::util::memory::ByteBufferPtr; +use crate::util::{DataPageBuilder, DataPageBuilderImpl}; /// Random generator of data type `T` values and sequences. pub trait RandGen { @@ -106,15 +112,6 @@ pub fn random_bytes(n: usize) -> Vec { result } -pub fn random_bools(n: usize) -> Vec { - let mut result = vec![]; - let mut rng = thread_rng(); - for _ in 0..n { - result.push(rng.gen::()); - } - result -} - pub fn random_numbers(n: usize) -> Vec where Standard: Distribution, @@ -132,3 +129,89 @@ where result.push(rng.gen_range(low..high)); } } + +#[allow(clippy::too_many_arguments)] +pub fn make_pages( + desc: ColumnDescPtr, + encoding: Encoding, + num_pages: usize, + levels_per_page: usize, + min: T::T, + max: T::T, + def_levels: &mut Vec, + rep_levels: &mut Vec, + values: &mut Vec, + pages: &mut VecDeque, + use_v2: bool, +) where + T::T: PartialOrd + SampleUniform + Copy, +{ + let mut num_values = 0; + let max_def_level = desc.max_def_level(); + let max_rep_level = desc.max_rep_level(); + + let mut dict_encoder = DictEncoder::::new(desc.clone()); + + for i in 0..num_pages { + let mut num_values_cur_page = 0; + let level_range = i * levels_per_page..(i + 1) * levels_per_page; + + if max_def_level > 0 { + random_numbers_range(levels_per_page, 0, max_def_level + 1, def_levels); + for dl in &def_levels[level_range.clone()] { + if *dl == max_def_level { + num_values_cur_page += 1; + } + } + } else { + num_values_cur_page = levels_per_page; + } + if max_rep_level > 0 { + random_numbers_range(levels_per_page, 0, max_rep_level + 1, rep_levels); + } + random_numbers_range(num_values_cur_page, min, max, values); + + // Generate the current page + + let mut pb = + DataPageBuilderImpl::new(desc.clone(), num_values_cur_page as u32, use_v2); + if max_rep_level > 0 { + pb.add_rep_levels(max_rep_level, &rep_levels[level_range.clone()]); + } + if max_def_level > 0 { + pb.add_def_levels(max_def_level, &def_levels[level_range]); + } + + let value_range = num_values..num_values + num_values_cur_page; + match encoding { + Encoding::PLAIN_DICTIONARY | Encoding::RLE_DICTIONARY => { + let _ = dict_encoder.put(&values[value_range.clone()]); + let indices = dict_encoder + .write_indices() + .expect("write_indices() should be OK"); + pb.add_indices(indices); + } + Encoding::PLAIN => { + pb.add_values::(encoding, &values[value_range]); + } + enc => panic!("Unexpected encoding {}", enc), + } + + let data_page = pb.consume(); + pages.push_back(data_page); + num_values += num_values_cur_page; + } + + if encoding == Encoding::PLAIN_DICTIONARY || encoding == Encoding::RLE_DICTIONARY { + let dict = dict_encoder + .write_dict() + .expect("write_dict() should be OK"); + let dict_page = Page::DictionaryPage { + buf: dict, + num_values: dict_encoder.num_entries() as u32, + encoding: Encoding::RLE_DICTIONARY, + is_sorted: false, + }; + pages.push_front(dict_page); + } +} diff --git a/parquet_derive/Cargo.toml b/parquet_derive/Cargo.toml index cf0c943cc248..e32ee1ace5b8 100644 --- a/parquet_derive/Cargo.toml +++ b/parquet_derive/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "parquet_derive" -version = "18.0.0" +version = "22.0.0" license = "Apache-2.0" description = "Derive macros for the Rust implementation of Apache 
Parquet" homepage = "https://github.com/apache/arrow-rs" @@ -26,7 +26,7 @@ authors = ["Apache Arrow "] keywords = [ "parquet" ] readme = "README.md" edition = "2021" -rust-version = "1.57" +rust-version = "1.62" [lib] proc-macro = true @@ -35,4 +35,4 @@ proc-macro = true proc-macro2 = { version = "1.0", default-features = false } quote = { version = "1.0", default-features = false } syn = { version = "1.0", default-features = false } -parquet = { path = "../parquet", version = "18.0.0" } +parquet = { path = "../parquet", version = "22.0.0" } diff --git a/parquet_derive/README.md b/parquet_derive/README.md index 5b74a89524c6..d3d7f56ebf67 100644 --- a/parquet_derive/README.md +++ b/parquet_derive/README.md @@ -32,8 +32,8 @@ Add this to your Cargo.toml: ```toml [dependencies] -parquet = "18.0.0" -parquet_derive = "18.0.0" +parquet = "22.0.0" +parquet_derive = "22.0.0" ``` and this to your crate root: diff --git a/parquet_derive_test/Cargo.toml b/parquet_derive_test/Cargo.toml index 9b8de68cb8a3..4b814c4c088d 100644 --- a/parquet_derive_test/Cargo.toml +++ b/parquet_derive_test/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "parquet_derive_test" -version = "18.0.0" +version = "22.0.0" license = "Apache-2.0" description = "Integration test package for parquet-derive" homepage = "https://github.com/apache/arrow-rs" @@ -26,9 +26,9 @@ authors = ["Apache Arrow "] keywords = [ "parquet" ] edition = "2021" publish = false -rust-version = "1.57" +rust-version = "1.62" [dependencies] -parquet = { path = "../parquet", version = "18.0.0", default-features = false } -parquet_derive = { path = "../parquet_derive", version = "18.0.0", default-features = false } +parquet = { path = "../parquet", version = "22.0.0", default-features = false } +parquet_derive = { path = "../parquet_derive", version = "22.0.0", default-features = false } chrono = { version="0.4.19", default-features = false, features = [ "clock" ] }